mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bd7eaf6dff | |||
| 3df19e6d90 | |||
| 1910cd6000 | |||
| 46c2f98a15 | |||
| 9e8634bfc0 | |||
| 23e019092a | |||
| 4322407129 |
@@ -62,6 +62,7 @@ testData:
|
||||
# Advanced parameters (usually discovered automatically from provider metadata)
|
||||
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
|
||||
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
||||
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
|
||||
|
||||
# Configuration documentation
|
||||
configuration:
|
||||
@@ -230,3 +231,15 @@ configuration:
|
||||
|
||||
Example: https://accounts.google.com/logout
|
||||
required: false
|
||||
|
||||
enablePKCE:
|
||||
type: boolean
|
||||
description: |
|
||||
Enables PKCE (Proof Key for Code Exchange) for the OAuth 2.0 authorization code flow.
|
||||
PKCE adds an extra layer of security to protect against authorization code interception attacks.
|
||||
|
||||
Not all OIDC providers support PKCE, so this should only be enabled if your provider supports it.
|
||||
If enabled, the middleware will generate and use a code verifier/challenge pair during authentication.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
@@ -71,11 +71,13 @@ The middleware supports the following configuration options:
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
|
||||
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
@@ -233,6 +235,30 @@ spec:
|
||||
- profile
|
||||
```
|
||||
|
||||
### With PKCE Enabled
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-pkce
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
enablePKCE: true # Enables PKCE for added security
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
```
|
||||
|
||||
### Keeping Secrets Secret in Kubernetes
|
||||
|
||||
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
|
||||
@@ -378,6 +404,17 @@ http:
|
||||
|
||||
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
|
||||
|
||||
### PKCE Support
|
||||
|
||||
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
|
||||
|
||||
PKCE is recommended when:
|
||||
- Your OIDC provider supports it (most modern providers do)
|
||||
- You need an additional layer of security for the authorization code flow
|
||||
- You're concerned about potential authorization code interception attacks
|
||||
|
||||
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
|
||||
|
||||
### Token Caching and Blacklisting
|
||||
|
||||
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
|
||||
|
||||
+10
-2
@@ -2,8 +2,16 @@ package traefikoidc
|
||||
|
||||
import "time"
|
||||
|
||||
// autoCleanupRoutine runs a ticker that calls the provided cleanup function at the specified interval.
|
||||
// It stops when a value is received on the stop channel.
|
||||
// autoCleanupRoutine periodically calls the provided cleanup function.
|
||||
// It starts a ticker with the given interval and executes the cleanup function
|
||||
// on each tick. The routine stops gracefully when a signal is received on the
|
||||
// stop channel. This is typically used for background cleanup tasks like
|
||||
// expiring cache entries.
|
||||
//
|
||||
// Parameters:
|
||||
// - interval: The time duration between cleanup calls.
|
||||
// - stop: A channel used to signal the routine to stop. Receiving any value will terminate the loop.
|
||||
// - cleanup: The function to call periodically for cleanup tasks.
|
||||
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
-110
@@ -1,110 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenBlacklist manages a thread-safe list of revoked tokens with expiration.
|
||||
type TokenBlacklist struct {
|
||||
tokens map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new token blacklist instance.
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
tokens: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist with an expiration time.
|
||||
func (b *TokenBlacklist) Add(token string, expiry time.Time) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
// Clean up expired tokens if we're at capacity
|
||||
if len(b.tokens) >= 1000 {
|
||||
now := time.Now()
|
||||
futureThreshold := now.Add(time.Minute)
|
||||
for t, exp := range b.tokens {
|
||||
if now.After(exp) || futureThreshold.After(exp) {
|
||||
delete(b.tokens, t)
|
||||
}
|
||||
}
|
||||
|
||||
// If still at capacity, remove oldest token
|
||||
if len(b.tokens) >= 1000 {
|
||||
var oldestToken string
|
||||
var oldestTime time.Time
|
||||
first := true
|
||||
for t, exp := range b.tokens {
|
||||
if first || exp.Before(oldestTime) {
|
||||
oldestToken = t
|
||||
oldestTime = exp
|
||||
first = false
|
||||
}
|
||||
}
|
||||
if oldestToken != "" {
|
||||
delete(b.tokens, oldestToken)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.tokens[token] = expiry
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is in the blacklist and not expired.
|
||||
func (b *TokenBlacklist) IsBlacklisted(token string) bool {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
|
||||
expiry, exists := b.tokens[token]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// If token is expired, remove it and return false
|
||||
if time.Now().After(expiry) {
|
||||
// Switch to write lock to remove expired token
|
||||
b.mutex.RUnlock()
|
||||
b.mutex.Lock()
|
||||
delete(b.tokens, token)
|
||||
b.mutex.Unlock()
|
||||
b.mutex.RLock()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist.
|
||||
// Also removes tokens that will expire within the next minute to prevent edge cases.
|
||||
func (b *TokenBlacklist) Cleanup() {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
futureThreshold := now.Add(time.Minute)
|
||||
|
||||
for token, expiry := range b.tokens {
|
||||
// Remove tokens that are expired or will expire soon
|
||||
if now.After(expiry) || futureThreshold.After(expiry) {
|
||||
delete(b.tokens, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove removes a token from the blacklist regardless of its expiration.
|
||||
func (b *TokenBlacklist) Remove(token string) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
delete(b.tokens, token)
|
||||
}
|
||||
|
||||
// Count returns the current number of tokens in the blacklist.
|
||||
func (b *TokenBlacklist) Count() int {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
return len(b.tokens)
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTokenBlacklist_Add(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
|
||||
if !blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be blacklisted, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_IsBlacklisted(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
|
||||
if !blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be blacklisted, but it was not")
|
||||
}
|
||||
|
||||
if blacklist.IsBlacklisted("nonExistentToken") {
|
||||
t.Errorf("Expected non-existent token to not be blacklisted, but it was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Cleanup(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(-time.Hour) // Expired token
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
blacklist.Cleanup()
|
||||
|
||||
if blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected expired token to be removed after cleanup, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Remove(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
blacklist.Remove(token)
|
||||
|
||||
if blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be removed, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Count(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token1 := "token1"
|
||||
token2 := "token2"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token1, expiry)
|
||||
blacklist.Add(token2, expiry)
|
||||
|
||||
if blacklist.Count() != 2 {
|
||||
t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count())
|
||||
}
|
||||
}
|
||||
@@ -46,7 +46,9 @@ type Cache struct {
|
||||
// DefaultMaxSize is the default maximum number of items in the cache.
|
||||
const DefaultMaxSize = 500
|
||||
|
||||
// NewCache creates a new empty cache instance that is ready for use.
|
||||
// NewCache creates a new empty cache instance with default settings.
|
||||
// It initializes the internal maps and list, sets the default maximum size,
|
||||
// and starts the automatic cleanup goroutine.
|
||||
func NewCache() *Cache {
|
||||
c := &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
@@ -60,8 +62,12 @@ func NewCache() *Cache {
|
||||
return c
|
||||
}
|
||||
|
||||
// Set adds or updates an item in the cache with the specified expiration duration.
|
||||
// It moves the item to the most recently used position.
|
||||
// Set adds or updates an item in the cache with the specified key, value, and expiration duration.
|
||||
// If the key already exists, its value and expiration time are updated, and it's moved
|
||||
// to the most recently used position in the LRU list.
|
||||
// If the key does not exist and the cache is full, the least recently used item is evicted
|
||||
// before adding the new item.
|
||||
// The expiration duration is relative to the time Set is called.
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -95,8 +101,11 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.elems[key] = elem
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache if it exists and hasn't expired.
|
||||
// Moving the accessed item to the most recently used position.
|
||||
// Get retrieves an item from the cache by its key.
|
||||
// If the item exists and has not expired, its value and true are returned.
|
||||
// Accessing an item moves it to the most recently used position in the LRU list.
|
||||
// If the item does not exist or has expired, nil and false are returned, and the
|
||||
// expired item is removed from the cache.
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -120,7 +129,9 @@ func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache.
|
||||
// Delete removes an item from the cache by its key.
|
||||
// If the key exists, the corresponding item is removed from the cache storage
|
||||
// and the LRU list.
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -128,8 +139,10 @@ func (c *Cache) Delete(key string) {
|
||||
c.removeItem(key)
|
||||
}
|
||||
|
||||
// Cleanup removes all expired items from the cache. This should be called periodically
|
||||
// to prevent memory bloat from expired entries.
|
||||
// Cleanup iterates through the cache and removes all items that have expired.
|
||||
// An item is considered expired if the current time is after its ExpiresAt timestamp.
|
||||
// This method is called automatically by the auto-cleanup goroutine, but can also
|
||||
// be called manually.
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -143,7 +156,11 @@ func (c *Cache) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldest removes the least recently used item from the cache.
|
||||
// evictOldest removes the least recently used (oldest) item from the cache.
|
||||
// It first attempts to find and remove an expired item from the front of the LRU list.
|
||||
// If no expired items are found at the front, it removes the absolute oldest item (front of the list).
|
||||
// This method is called internally by Set when the cache reaches its maximum size.
|
||||
// Note: This function assumes the write lock is already held.
|
||||
func (c *Cache) evictOldest() {
|
||||
now := time.Now()
|
||||
elem := c.order.Front()
|
||||
@@ -167,7 +184,9 @@ func (c *Cache) evictOldest() {
|
||||
}
|
||||
}
|
||||
|
||||
// removeItem removes an item from both the cache and the LRU tracking structures.
|
||||
// removeItem removes an item specified by the key from the cache's internal storage (items map)
|
||||
// and its corresponding entry from the LRU list (order list and elems map).
|
||||
// Note: This function assumes the write lock is already held.
|
||||
func (c *Cache) removeItem(key string) {
|
||||
delete(c.items, key)
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
@@ -176,12 +195,15 @@ func (c *Cache) removeItem(key string) {
|
||||
}
|
||||
}
|
||||
|
||||
// startAutoCleanup initiates a goroutine that periodically cleans up expired cache items.
|
||||
// startAutoCleanup starts the background goroutine that automatically calls the Cleanup method
|
||||
// at the interval specified by c.autoCleanupInterval.
|
||||
// It uses the autoCleanupRoutine helper function.
|
||||
func (c *Cache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
// Close terminates the auto cleanup goroutine.
|
||||
// Close stops the automatic cleanup goroutine associated with this cache instance.
|
||||
// It should be called when the cache is no longer needed to prevent resource leaks.
|
||||
func (c *Cache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
|
||||
+166
-174
@@ -3,6 +3,7 @@ package traefikoidc
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -14,10 +15,14 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce
|
||||
// for use in the OIDC authentication flow. The nonce is used to
|
||||
// prevent replay attacks by ensuring the token received matches
|
||||
// the authentication request.
|
||||
// generateNonce creates a cryptographically secure random string suitable for use as an OIDC nonce.
|
||||
// The nonce is used during the authentication flow to mitigate replay attacks by associating
|
||||
// the ID token with the specific authentication request.
|
||||
// It generates 32 random bytes and encodes them using base64 URL encoding.
|
||||
//
|
||||
// Returns:
|
||||
// - A base64 URL encoded random string (nonce).
|
||||
// - An error if the random byte generation fails.
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
@@ -27,6 +32,42 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// generateCodeVerifier creates a cryptographically secure random string suitable for use as a PKCE code verifier.
|
||||
// According to RFC 7636, the verifier should be a high-entropy string between 43 and 128 characters long.
|
||||
// This function generates 32 random bytes, resulting in a 43-character base64 URL encoded string.
|
||||
//
|
||||
// Returns:
|
||||
// - A base64 URL encoded random string (code verifier).
|
||||
// - An error if the random byte generation fails.
|
||||
func generateCodeVerifier() (string, error) {
|
||||
// Using 32 bytes (256 bits) will produce a 43 character base64url string
|
||||
verifierBytes := make([]byte, 32)
|
||||
_, err := rand.Read(verifierBytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not generate code verifier: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
|
||||
}
|
||||
|
||||
// deriveCodeChallenge computes the PKCE code challenge from a given code verifier.
|
||||
// It uses the S256 challenge method (SHA-256 hash followed by base64 URL encoding)
|
||||
// as defined in RFC 7636.
|
||||
//
|
||||
// Parameters:
|
||||
// - codeVerifier: The high-entropy string generated by generateCodeVerifier.
|
||||
//
|
||||
// Returns:
|
||||
// - The base64 URL encoded SHA-256 hash of the code verifier (code challenge).
|
||||
func deriveCodeChallenge(codeVerifier string) string {
|
||||
// Calculate SHA-256 hash of the code verifier
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(codeVerifier))
|
||||
hash := hasher.Sum(nil)
|
||||
|
||||
// Base64url encode the hash to get the code challenge
|
||||
return base64.RawURLEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from the OIDC token endpoint.
|
||||
// It contains the various tokens and metadata returned after successful
|
||||
// code exchange or token refresh operations.
|
||||
@@ -47,14 +88,23 @@ type TokenResponse struct {
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
|
||||
// It supports both authorization code and refresh token grant types.
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint.
|
||||
// It handles both the "authorization_code" grant type (exchanging an authorization code for tokens)
|
||||
// and the "refresh_token" grant type (using a refresh token to obtain new tokens).
|
||||
// It includes necessary parameters like client credentials and handles PKCE verification if applicable.
|
||||
// The function follows redirects and handles potential errors during the exchange.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
|
||||
// - codeOrToken: Either the authorization code or refresh token
|
||||
// - redirectURL: The callback URL for authorization code grant
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
||||
// - ctx: The context for the outgoing HTTP request.
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token").
|
||||
// - codeOrToken: The authorization code (for "authorization_code" grant) or the refresh token (for "refresh_token" grant).
|
||||
// - redirectURL: The redirect URI that was used in the initial authorization request (required for "authorization_code" grant).
|
||||
// - codeVerifier: The PKCE code verifier (required for "authorization_code" grant if PKCE was used).
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the obtained tokens (ID, access, refresh).
|
||||
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant).
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
"client_id": {t.clientID},
|
||||
@@ -64,6 +114,11 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
if grantType == "authorization_code" {
|
||||
data.Set("code", codeOrToken)
|
||||
data.Set("redirect_uri", redirectURL)
|
||||
|
||||
// Add code_verifier if PKCE is being used
|
||||
if codeVerifier != "" {
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
} else if grantType == "refresh_token" {
|
||||
data.Set("refresh_token", codeOrToken)
|
||||
}
|
||||
@@ -108,11 +163,19 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
|
||||
// This is used to refresh access tokens before they expire.
|
||||
// getNewTokenWithRefreshToken uses a refresh token to obtain a new set of tokens (ID, access, refresh)
|
||||
// from the OIDC provider's token endpoint. It wraps the exchangeTokens function with the
|
||||
// "refresh_token" grant type.
|
||||
//
|
||||
// Parameters:
|
||||
// - refreshToken: The refresh token previously obtained during authentication or a prior refresh.
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the newly obtained tokens.
|
||||
// - An error if the refresh operation fails.
|
||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
@@ -121,148 +184,17 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// handleExpiredToken manages token expiration by clearing the session
|
||||
// and initiating a new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Clear authentication data but preserve CSRF state
|
||||
session.SetAuthenticated(false)
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
|
||||
// Save the cleared session state
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save cleared session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback processes the authentication callback from the OIDC provider.
|
||||
// It validates the callback parameters, exchanges the authorization code for
|
||||
// tokens, verifies the tokens, and establishes the user's session.
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Check for errors in the callback
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF state
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
http.Error(rw, "No code in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify tokens and claims
|
||||
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify nonce to prevent replay attacks
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate user's email domain
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Update session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to original path or root
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// extractClaims parses a JWT token and extracts its claims.
|
||||
// It handles base64url decoding and JSON parsing of the token payload.
|
||||
// extractClaims decodes the payload (claims set) part of a JWT string.
|
||||
// It splits the JWT into its three parts, base64 URL decodes the second part (payload),
|
||||
// and unmarshals the resulting JSON into a map.
|
||||
// Note: This function does *not* validate the token's signature or claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT string.
|
||||
//
|
||||
// Returns:
|
||||
// - A map representing the JSON claims extracted from the token payload.
|
||||
// - An error if the token format is invalid, decoding fails, or JSON unmarshaling fails.
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -290,21 +222,36 @@ type TokenCache struct {
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewTokenCache creates a new TokenCache instance.
|
||||
// NewTokenCache creates and initializes a new TokenCache.
|
||||
// It internally creates a new generic Cache instance for storage.
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
cache: NewCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a token's claims in the cache with an expiration time.
|
||||
// Set stores the claims associated with a specific token string in the cache.
|
||||
// It prefixes the token string to avoid potential collisions with other cache types
|
||||
// and sets the provided expiration duration.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string (used as the key).
|
||||
// - claims: The map of claims associated with the token.
|
||||
// - expiration: The duration for which the cache entry should be valid.
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
}
|
||||
|
||||
// Get retrieves a token's claims from the cache.
|
||||
// Returns the claims and a boolean indicating if the token was found.
|
||||
// Get retrieves the cached claims for a given token string.
|
||||
// It prefixes the token string before querying the underlying cache.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to look up.
|
||||
//
|
||||
// Returns:
|
||||
// - The cached claims map if found and valid.
|
||||
// - A boolean indicating whether the token was found in the cache (true if found, false otherwise).
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
@@ -315,29 +262,59 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// Delete removes a token from the cache.
|
||||
// Delete removes the cached entry for a specific token string.
|
||||
// It prefixes the token string before calling the underlying cache's Delete method.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to remove from the cache.
|
||||
func (tc *TokenCache) Delete(token string) {
|
||||
token = "t-" + token
|
||||
tc.cache.Delete(token)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the cache.
|
||||
// Cleanup triggers the cleanup process for the underlying generic cache,
|
||||
// removing expired token entries.
|
||||
func (tc *TokenCache) Cleanup() {
|
||||
tc.cache.Cleanup()
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges an authorization code for tokens.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
|
||||
// exchangeCodeForToken is a convenience function that wraps exchangeTokens specifically
|
||||
// for the "authorization_code" grant type. It handles the conditional inclusion of the
|
||||
// PKCE code verifier based on the middleware's configuration (t.enablePKCE).
|
||||
//
|
||||
// Parameters:
|
||||
// - code: The authorization code received from the OIDC provider.
|
||||
// - redirectURL: The redirect URI used in the initial authorization request.
|
||||
// - codeVerifier: The PKCE code verifier stored in the session (if PKCE is enabled).
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the obtained tokens.
|
||||
// - An error if the code exchange fails.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
|
||||
|
||||
// Only include code verifier if PKCE is enabled
|
||||
effectiveCodeVerifier := ""
|
||||
if t.enablePKCE && codeVerifier != "" {
|
||||
effectiveCodeVerifier = codeVerifier
|
||||
}
|
||||
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// createStringMap creates a map from a slice of strings.
|
||||
// Used for efficient lookups in allowed domains and roles.
|
||||
// createStringMap converts a slice of strings into a map[string]struct{} (a set).
|
||||
// This is useful for creating efficient lookups (O(1) average time complexity)
|
||||
// for checking the presence of items like allowed domains, roles, or groups.
|
||||
//
|
||||
// Parameters:
|
||||
// - keys: A slice of strings to be added to the set.
|
||||
//
|
||||
// Returns:
|
||||
// - A map where the keys are the strings from the input slice and the values are empty structs.
|
||||
func createStringMap(keys []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
@@ -346,9 +323,17 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout manages the OIDC logout process.
|
||||
// It clears the session and redirects either to the OIDC provider's
|
||||
// end session endpoint (if available) or to the configured post-logout URL.
|
||||
// handleLogout processes requests to the configured logout path.
|
||||
// It performs the following steps:
|
||||
// 1. Retrieves the current user session.
|
||||
// 2. Gets the access token (ID token hint) from the session.
|
||||
// 3. Clears all authentication-related data from the session cookies.
|
||||
// 4. Determines the final post-logout redirect URI.
|
||||
// 5. If an OIDC end_session_endpoint is configured and an ID token hint is available,
|
||||
// it builds the OIDC logout URL and redirects the user agent to the provider for logout.
|
||||
// 6. Otherwise, it redirects the user agent directly to the post-logout redirect URI.
|
||||
//
|
||||
// It handles potential errors during session retrieval or clearing.
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -390,11 +375,18 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
|
||||
// BuildLogoutURL constructs the URL for redirecting the user agent to the OIDC provider's
|
||||
// end_session_endpoint, including the required id_token_hint and optional
|
||||
// post_logout_redirect_uri parameters as query arguments.
|
||||
//
|
||||
// Parameters:
|
||||
// - endSessionURL: The OIDC provider's end session endpoint
|
||||
// - idToken: The ID token to be invalidated
|
||||
// - postLogoutRedirectURI: Where to redirect after logout completes
|
||||
// - endSessionURL: The URL of the OIDC provider's end session endpoint.
|
||||
// - idToken: The ID token previously issued to the user (used as id_token_hint).
|
||||
// - postLogoutRedirectURI: The optional URI where the provider should redirect the user agent after logout.
|
||||
//
|
||||
// Returns:
|
||||
// - The fully constructed logout URL string.
|
||||
// - An error if the provided endSessionURL is invalid.
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
|
||||
+6
-166
@@ -7,172 +7,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTokenBlacklistSizeLimit(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
// Add tokens up to maxSize
|
||||
for i := 0; i < 1000; i++ {
|
||||
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
|
||||
}
|
||||
|
||||
// Verify size is at max
|
||||
if tb.Count() != 1000 {
|
||||
t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count())
|
||||
}
|
||||
|
||||
// Add one more token, should trigger cleanup/eviction
|
||||
tb.Add("newtoken", time.Now().Add(time.Hour))
|
||||
|
||||
// Size should still be at max
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
// Add some expired tokens
|
||||
for i := 0; i < 500; i++ {
|
||||
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
|
||||
}
|
||||
|
||||
// Add some valid tokens
|
||||
for i := 0; i < 500; i++ {
|
||||
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
|
||||
}
|
||||
|
||||
// Force cleanup
|
||||
tb.Cleanup()
|
||||
|
||||
// Only valid tokens should remain
|
||||
if tb.Count() != 500 {
|
||||
t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count())
|
||||
}
|
||||
|
||||
// Verify only valid tokens remain
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
for token, expiry := range tb.tokens {
|
||||
if time.Now().After(expiry) {
|
||||
t.Errorf("Found expired token after cleanup: %s", token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistOldestEviction(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
// Add tokens at capacity with different expiration times
|
||||
baseTime := time.Now()
|
||||
oldestToken := "oldest"
|
||||
|
||||
// Add oldest token first
|
||||
tb.Add(oldestToken, baseTime.Add(time.Hour))
|
||||
|
||||
// Fill up to capacity with newer tokens
|
||||
for i := 0; i < 999; i++ {
|
||||
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
|
||||
}
|
||||
|
||||
// Add a new token that should evict the oldest
|
||||
newToken := "newest"
|
||||
tb.Add(newToken, baseTime.Add(time.Hour*3))
|
||||
|
||||
// Verify oldest token was evicted
|
||||
if tb.IsBlacklisted(oldestToken) {
|
||||
t.Error("Oldest token should have been evicted")
|
||||
}
|
||||
|
||||
// Verify newest token is present
|
||||
if !tb.IsBlacklisted(newToken) {
|
||||
t.Error("Newest token should be present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistMemoryUsage(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
iterations := 10000
|
||||
|
||||
// Force initial GC
|
||||
runtime.GC()
|
||||
|
||||
// Record initial memory stats
|
||||
var m1, m2 runtime.MemStats
|
||||
runtime.ReadMemStats(&m1)
|
||||
|
||||
// Simulate heavy usage
|
||||
for i := 0; i < iterations; i++ {
|
||||
// Add new token
|
||||
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
|
||||
|
||||
// Periodically check blacklisted status
|
||||
if i%100 == 0 {
|
||||
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
|
||||
}
|
||||
|
||||
// Periodically cleanup
|
||||
if i%1000 == 0 {
|
||||
tb.Cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// Force GC and wait for it to complete
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
runtime.ReadMemStats(&m2)
|
||||
|
||||
// Check memory growth (using HeapAlloc for more accurate measurement)
|
||||
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
|
||||
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
|
||||
|
||||
if memoryGrowth > maxAllowedGrowth {
|
||||
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
|
||||
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
|
||||
}
|
||||
|
||||
// Verify size stayed within limits
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
iterations := 1000
|
||||
concurrency := 10
|
||||
done := make(chan bool)
|
||||
|
||||
// Start multiple goroutines performing operations
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(id int) {
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Add tokens
|
||||
token := fmt.Sprintf("token%d-%d", id, j)
|
||||
tb.Add(token, time.Now().Add(time.Hour))
|
||||
|
||||
// Check blacklist status
|
||||
tb.IsBlacklisted(token)
|
||||
|
||||
// Periodic cleanup
|
||||
if j%100 == 0 {
|
||||
tb.Cleanup()
|
||||
}
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify size constraints were maintained
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
// Removed tests related to the old TokenBlacklist implementation:
|
||||
// - TestTokenBlacklistSizeLimit
|
||||
// - TestTokenBlacklistExpiredCleanup
|
||||
// - TestTokenBlacklistOldestEviction
|
||||
// - TestTokenBlacklistMemoryUsage
|
||||
// - TestConcurrentTokenBlacklistOperations
|
||||
|
||||
func TestTokenCacheMemoryUsage(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
|
||||
@@ -45,6 +45,21 @@ type JWKCacheInterface interface {
|
||||
Cleanup()
|
||||
}
|
||||
|
||||
// GetJWKS retrieves the JSON Web Key Set (JWKS) from the cache or fetches it from the provider.
|
||||
// It first checks if a valid, non-expired JWKS is present in the cache. If so, it returns the cached version.
|
||||
// Otherwise, it attempts to fetch the JWKS from the specified jwksURL using the provided httpClient.
|
||||
// If the fetch is successful, the JWKS is stored in the cache with an expiration time based on CacheLifetime
|
||||
// (defaulting to 1 hour if not set) and returned.
|
||||
// This method uses double-checked locking to minimize contention when the cache needs refreshing.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request if fetching is required.
|
||||
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
|
||||
// - httpClient: The HTTP client to use for fetching the JWKS.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the JWKSet containing the keys.
|
||||
// - An error if fetching fails or the response cannot be decoded.
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
@@ -74,6 +89,8 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// Cleanup removes the cached JWKS if it has expired.
|
||||
// This is intended to be called periodically to ensure stale JWKS data is cleared.
|
||||
func (c *JWKCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -84,6 +101,17 @@ func (c *JWKCache) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// fetchJWKS retrieves the JSON Web Key Set (JWKS) from the specified URL.
|
||||
// It uses the provided context and HTTP client to make the request.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request.
|
||||
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
|
||||
// - httpClient: The HTTP client to use for the request.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the fetched JWKSet.
|
||||
// - An error if the request fails, the status code is not OK, or the response body cannot be decoded.
|
||||
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Create a request with context to enforce timeout
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
||||
@@ -109,6 +137,16 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JWK (JSON Web Key) object into PEM (Privacy-Enhanced Mail) format.
|
||||
// It selects the appropriate conversion function based on the JWK's key type ("kty").
|
||||
// Currently supports "RSA" and "EC" key types.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the JWK object to convert.
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the public key in PEM format.
|
||||
// - An error if the key type is unsupported or conversion fails.
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
@@ -124,6 +162,17 @@ var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JWK into PEM format.
|
||||
// It decodes the modulus (n) and exponent (e) from base64 URL encoding,
|
||||
// constructs an rsa.PublicKey, marshals it into PKIX format, and then
|
||||
// encodes it as a PEM block.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the RSA JWK object (must have "kty": "RSA").
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the RSA public key in PEM format.
|
||||
// - An error if decoding parameters fails or key marshaling fails.
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
@@ -155,6 +204,18 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC (Elliptic Curve) JWK into PEM format.
|
||||
// It decodes the X and Y coordinates from base64 URL encoding, determines the
|
||||
// elliptic curve based on the "crv" parameter (P-256, P-384, P-521),
|
||||
// constructs an ecdsa.PublicKey, marshals it into PKIX format, and then
|
||||
// encodes it as a PEM block.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the EC JWK object (must have "kty": "EC").
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the EC public key in PEM format.
|
||||
// - An error if decoding parameters fails, the curve is unsupported, or key marshaling fails.
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
|
||||
@@ -15,9 +15,15 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var replayCacheMu sync.Mutex
|
||||
var replayCache = make(map[string]time.Time)
|
||||
var (
|
||||
replayCacheMu sync.Mutex
|
||||
replayCache = make(map[string]time.Time)
|
||||
)
|
||||
|
||||
// cleanupReplayCache iterates through the replay cache and removes entries
|
||||
// whose expiration time is before the current time. This function should be
|
||||
// called periodically to prevent the cache from growing indefinitely.
|
||||
// It acquires a mutex to ensure thread safety during cleanup.
|
||||
func cleanupReplayCache() {
|
||||
now := time.Now()
|
||||
for token, expiry := range replayCache {
|
||||
@@ -27,8 +33,16 @@ func cleanupReplayCache() {
|
||||
}
|
||||
}
|
||||
|
||||
// ClockSkewTolerance is configurable to adjust time-based validations.
|
||||
var ClockSkewTolerance = 2 * time.Minute
|
||||
// ClockSkewToleranceFuture defines the tolerance for future-based claims like 'exp'.
|
||||
// Allows for more leniency with expiration checks.
|
||||
var ClockSkewToleranceFuture = 2 * time.Minute
|
||||
|
||||
// ClockSkewTolerancePast defines the tolerance for past-based claims like 'iat' and 'nbf'.
|
||||
// A smaller tolerance is typically used here to prevent accepting tokens issued too far in the future.
|
||||
var (
|
||||
ClockSkewTolerancePast = 10 * time.Second
|
||||
ClockSkewTolerance = 2 * time.Minute
|
||||
)
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
type JWT struct {
|
||||
@@ -38,7 +52,18 @@ type JWT struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
// parseJWT parses a JWT token string into a JWT struct.
|
||||
// parseJWT decodes a raw JWT string into its constituent parts: header, claims, and signature.
|
||||
// It splits the token string by '.', decodes each part using base64 URL decoding,
|
||||
// and unmarshals the header and claims JSON into maps. The raw signature bytes are stored.
|
||||
// It performs basic format validation (expecting 3 parts).
|
||||
// Note: This function does *not* validate the signature or the claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT string.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a JWT struct containing the decoded parts.
|
||||
// - An error if the token format is invalid or decoding/unmarshaling fails.
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -74,8 +99,24 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return jwt, nil
|
||||
}
|
||||
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// Verify performs standard claim validation on the JWT according to RFC 7519.
|
||||
// It checks the following:
|
||||
// - Algorithm ('alg') is supported.
|
||||
// - Issuer ('iss') matches the expected issuerURL.
|
||||
// - Audience ('aud') contains the expected clientID.
|
||||
// - Expiration time ('exp') is in the future (within tolerance).
|
||||
// - Issued at time ('iat') is in the past (within tolerance).
|
||||
// - Not before time ('nbf'), if present, is in the past (within tolerance).
|
||||
// - Subject ('sub') claim exists and is not empty.
|
||||
// - JWT ID ('jti'), if present, is checked against a replay cache to prevent token reuse.
|
||||
//
|
||||
// Parameters:
|
||||
// - issuerURL: The expected issuer URL (e.g., "https://accounts.google.com").
|
||||
// - clientID: The expected audience value (the client ID of this application).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if all standard claims are valid.
|
||||
// - An error describing the first validation failure encountered.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Validate algorithm to prevent algorithm switching attacks
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
@@ -164,6 +205,17 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience checks if the expected audience is present in the token's 'aud' claim.
|
||||
// The 'aud' claim can be a single string or an array of strings.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenAudience: The 'aud' claim value extracted from the token (can be string or []interface{}).
|
||||
// - expectedAudience: The audience value expected for this application (client ID).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the expected audience is found.
|
||||
// - An error if the claim type is invalid or the expected audience is not present.
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
@@ -187,6 +239,15 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer checks if the token's 'iss' claim matches the expected issuer URL.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenIssuer: The 'iss' claim value from the token.
|
||||
// - expectedIssuer: The expected issuer URL configured for the OIDC provider.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the issuers match.
|
||||
// - An error if the issuers do not match.
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
|
||||
@@ -194,54 +255,76 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyTimeConstraint is a generic function to verify time-based claims
|
||||
// verifyTimeConstraint checks time-based claims ('exp', 'iat', 'nbf') against the current time,
|
||||
// allowing for configurable clock skew. It uses different tolerances for past and future checks.
|
||||
//
|
||||
// Parameters:
|
||||
// - unixTime: The timestamp value from the claim (as a float64 Unix time).
|
||||
// - claimName: The name of the claim being verified ("exp", "iat", "nbf").
|
||||
// - future: A boolean indicating the direction of the check (true for 'exp', false for 'iat'/'nbf').
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the time constraint is met within the allowed tolerance.
|
||||
// - An error describing the failure (e.g., "token has expired", "token used before issued").
|
||||
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
|
||||
claimTime := time.Unix(int64(unixTime), 0)
|
||||
now := time.Now().Truncate(time.Second)
|
||||
now := time.Now() // Use current time without truncation
|
||||
|
||||
// For expiration (future=true), we add skew to now (making now later)
|
||||
// For iat/nbf (future=false), we subtract skew from now (making now earlier)
|
||||
skewDirection := 1
|
||||
if !future {
|
||||
skewDirection = -1
|
||||
}
|
||||
skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance)
|
||||
|
||||
if claimTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For expiration: if skewedNow (later) is after expiration, token expired
|
||||
// For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid
|
||||
if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) {
|
||||
var reason string
|
||||
if future {
|
||||
reason = "has expired"
|
||||
} else {
|
||||
var err error
|
||||
if future { // 'exp' check
|
||||
// Token is expired if Now is after (ClaimTime + FutureTolerance)
|
||||
allowedExpiry := claimTime.Add(ClockSkewToleranceFuture)
|
||||
if now.After(allowedExpiry) {
|
||||
err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC())
|
||||
}
|
||||
} else { // 'iat' or 'nbf' check
|
||||
// Token is invalid if Now is before (ClaimTime - PastTolerance)
|
||||
allowedStart := claimTime.Add(-ClockSkewTolerancePast)
|
||||
if now.Before(allowedStart) {
|
||||
reason := "not yet valid"
|
||||
if claimName == "iat" {
|
||||
reason = "used before issued"
|
||||
} else {
|
||||
reason = "not yet valid"
|
||||
}
|
||||
err = fmt.Errorf("token %s (%s: %v, now: %v, allowed_from: %v)", reason, claimName, claimTime.UTC(), now.UTC(), allowedStart.UTC())
|
||||
}
|
||||
return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC())
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// verifyExpiration checks the 'exp' (Expiration Time) claim.
|
||||
// It calls verifyTimeConstraint with future=true.
|
||||
func verifyExpiration(expiration float64) error {
|
||||
return verifyTimeConstraint(expiration, "exp", true)
|
||||
}
|
||||
|
||||
// verifyIssuedAt checks the 'iat' (Issued At) claim.
|
||||
// It calls verifyTimeConstraint with future=false.
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
return verifyTimeConstraint(issuedAt, "iat", false)
|
||||
}
|
||||
|
||||
// verifyNotBefore checks the 'nbf' (Not Before) claim.
|
||||
// It calls verifyTimeConstraint with future=false.
|
||||
func verifyNotBefore(notBefore float64) error {
|
||||
return verifyTimeConstraint(notBefore, "nbf", false)
|
||||
}
|
||||
|
||||
// verifySignature validates the JWT's signature using the provided public key.
|
||||
// It parses the public key from PEM format, selects the appropriate hashing algorithm
|
||||
// based on the 'alg' parameter (SHA256/384/512), hashes the token's signing input
|
||||
// (header + "." + payload), and then verifies the signature against the hash using
|
||||
// the corresponding RSA (PKCS1v15 or PSS) or ECDSA verification method.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw, complete JWT string.
|
||||
// - publicKeyPEM: The public key corresponding to the private key used for signing, in PEM format.
|
||||
// - alg: The algorithm specified in the JWT header (e.g., "RS256", "ES384").
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the signature is valid.
|
||||
// - An error if the token format is invalid, decoding fails, key parsing fails,
|
||||
// the algorithm is unsupported, or the signature verification fails.
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
|
||||
+870
-99
File diff suppressed because it is too large
Load Diff
+32
-11
@@ -15,6 +15,8 @@ type MetadataCache struct {
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
// NewMetadataCache creates a new MetadataCache instance.
|
||||
// It initializes the cache structure and starts the background cleanup goroutine.
|
||||
func NewMetadataCache() *MetadataCache {
|
||||
c := &MetadataCache{
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
@@ -24,7 +26,8 @@ func NewMetadataCache() *MetadataCache {
|
||||
return c
|
||||
}
|
||||
|
||||
// Cleanup removes expired metadata from the cache.
|
||||
// Cleanup removes the cached provider metadata if it has expired.
|
||||
// This is called periodically by the auto-cleanup goroutine.
|
||||
func (c *MetadataCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -34,11 +37,32 @@ func (c *MetadataCache) Cleanup() {
|
||||
c.metadata = nil
|
||||
}
|
||||
}
|
||||
|
||||
// isCacheValid checks if the cached metadata is present and has not expired.
|
||||
// Note: This function assumes the read lock is held or it's called from a context
|
||||
// where the lock is already held (like within GetMetadata after locking).
|
||||
func (c *MetadataCache) isCacheValid() bool {
|
||||
return c.metadata != nil && time.Now().Before(c.expiresAt)
|
||||
}
|
||||
|
||||
// GetMetadata retrieves the metadata from cache or fetches it if expired
|
||||
// GetMetadata retrieves the OIDC provider metadata.
|
||||
// It first checks the cache for valid, non-expired metadata. If found, it's returned immediately.
|
||||
// If the cache is empty or expired, it attempts to fetch the metadata from the provider's
|
||||
// well-known endpoint using discoverProviderMetadata.
|
||||
// If fetching is successful, the new metadata is cached for 1 hour.
|
||||
// If fetching fails but valid metadata exists in the cache (even if expired), the cache expiry
|
||||
// is extended by 5 minutes, and the cached data is returned to prevent thundering herd issues.
|
||||
// If fetching fails and there's no cached data, an error is returned.
|
||||
// It employs double-checked locking for thread safety and performance.
|
||||
//
|
||||
// Parameters:
|
||||
// - providerURL: The base URL of the OIDC provider.
|
||||
// - httpClient: The HTTP client to use for fetching metadata.
|
||||
// - logger: The logger instance for recording errors or warnings.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the ProviderMetadata struct.
|
||||
// - An error if metadata cannot be retrieved from cache or fetched from the provider.
|
||||
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
|
||||
c.mutex.RLock()
|
||||
if c.isCacheValid() {
|
||||
@@ -67,24 +91,21 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
|
||||
}
|
||||
|
||||
c.metadata = metadata
|
||||
// Calculate expiration time based on usage patterns
|
||||
usageCount := 0 // This should be replaced with actual usage tracking logic
|
||||
if usageCount < 10 {
|
||||
c.expiresAt = time.Now().Add(30 * time.Minute)
|
||||
} else if usageCount < 50 {
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
} else {
|
||||
c.expiresAt = time.Now().Add(2 * time.Hour)
|
||||
}
|
||||
// Set a fixed cache lifetime (e.g., 1 hour)
|
||||
// TODO: Consider making this configurable or respecting HTTP cache headers
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
|
||||
// End of GetMetadata
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// startAutoCleanup starts the background goroutine that periodically calls Cleanup
|
||||
// to remove expired metadata from the cache.
|
||||
func (c *MetadataCache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
// Close stops the automatic cleanup goroutine associated with this metadata cache.
|
||||
func (c *MetadataCache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
|
||||
+200
-40
@@ -16,8 +16,15 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// generateSecureRandomString creates a cryptographically secure random string of specified length.
|
||||
// It returns the generated string or an error if random generation fails.
|
||||
// generateSecureRandomString creates a cryptographically secure, hex-encoded random string.
|
||||
// It reads the specified number of bytes from crypto/rand and encodes them as a hexadecimal string.
|
||||
//
|
||||
// Parameters:
|
||||
// - length: The number of random bytes to generate (the resulting hex string will be twice this length).
|
||||
//
|
||||
// Returns:
|
||||
// - A hex-encoded random string.
|
||||
// - An error if reading random bytes fails.
|
||||
func generateSecureRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
@@ -56,7 +63,14 @@ const (
|
||||
minEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// compressToken compresses a token using gzip and base64 encodes it.
|
||||
// compressToken compresses the input string using gzip and then encodes the result using standard base64 encoding.
|
||||
// If any error occurs during compression, it returns the original uncompressed token as a fallback.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The string to compress.
|
||||
//
|
||||
// Returns:
|
||||
// - The base64 encoded, gzipped string, or the original string if compression fails.
|
||||
func compressToken(token string) string {
|
||||
var b bytes.Buffer
|
||||
gz := gzip.NewWriter(&b)
|
||||
@@ -69,7 +83,15 @@ func compressToken(token string) string {
|
||||
return base64.StdEncoding.EncodeToString(b.Bytes())
|
||||
}
|
||||
|
||||
// decompressToken decompresses a base64 encoded gzipped token.
|
||||
// decompressToken decodes a standard base64 encoded string and then decompresses the result using gzip.
|
||||
// If base64 decoding or gzip decompression fails, it returns the original input string as a fallback,
|
||||
// assuming it might not have been compressed.
|
||||
//
|
||||
// Parameters:
|
||||
// - compressed: The base64 encoded, gzipped string.
|
||||
//
|
||||
// Returns:
|
||||
// - The decompressed original string, or the input string if decompression fails.
|
||||
func decompressToken(compressed string) string {
|
||||
data, err := base64.StdEncoding.DecodeString(compressed)
|
||||
if err != nil {
|
||||
@@ -128,25 +150,27 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
|
||||
|
||||
// Initialize session pool.
|
||||
sm.sessionPool.New = func() interface{} {
|
||||
// Initialize SessionData with necessary fields and the mutex.
|
||||
return &SessionData{
|
||||
manager: sm,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshMutex: sync.Mutex{}, // Initialize the mutex
|
||||
}
|
||||
}
|
||||
|
||||
return sm, nil
|
||||
}
|
||||
|
||||
// getSessionOptions returns secure session options configured for the current request.
|
||||
// Parameters:
|
||||
// - isSecure: Whether the current request is using HTTPS.
|
||||
// getSessionOptions returns a sessions.Options struct configured with security best practices.
|
||||
// It sets HttpOnly to true, Secure based on the request scheme or forceHTTPS setting,
|
||||
// SameSite to LaxMode, MaxAge to the absoluteSessionTimeout, and Path to "/".
|
||||
//
|
||||
// The options ensure cookies are:
|
||||
// - HTTP-only (not accessible via JavaScript)
|
||||
// - Secure when using HTTPS or when forceHTTPS is enabled
|
||||
// - Using SameSite=Lax for CSRF protection
|
||||
// - Set with appropriate timeout and path settings
|
||||
// Parameters:
|
||||
// - isSecure: A boolean indicating if the current request context is secure (HTTPS).
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a configured sessions.Options struct.
|
||||
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
@@ -208,11 +232,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// getTokenChunkSessions retrieves all session chunks for a given token type.
|
||||
// getTokenChunkSessions retrieves all cookie chunks associated with a large token (access or refresh).
|
||||
// It iteratively attempts to load cookies named "{baseName}_0", "{baseName}_1", etc., until
|
||||
// a cookie is not found or returns an error. The loaded sessions are stored in the provided chunks map.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The HTTP request
|
||||
// - baseName: The base name for the token's session cookies
|
||||
// - chunks: Map to store the chunks in
|
||||
// - r: The incoming HTTP request containing the cookies.
|
||||
// - baseName: The base name of the cookie (e.g., accessTokenCookie).
|
||||
// - chunks: The map (typically SessionData.accessTokenChunks or SessionData.refreshTokenChunks) to populate with the found session chunks.
|
||||
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
||||
@@ -251,12 +278,21 @@ type SessionData struct {
|
||||
// refreshTokenChunks stores additional chunks of the refresh token
|
||||
// when it exceeds the maximum cookie size.
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
|
||||
// refreshMutex protects refresh token operations within this session instance.
|
||||
refreshMutex sync.Mutex
|
||||
}
|
||||
|
||||
// Save persists all session data to cookies in the HTTP response.
|
||||
// It saves the main session, token sessions, and any token chunks,
|
||||
// applying appropriate security options to each cookie. All cookies
|
||||
// are saved with consistent security settings based on the request scheme.
|
||||
// Save persists all parts of the session (main, access token, refresh token, and any chunks)
|
||||
// back to the client as cookies in the HTTP response. It applies secure cookie options
|
||||
// obtained via getSessionOptions based on the request's security context.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The original HTTP request (used to determine security context for cookie options).
|
||||
// - w: The HTTP response writer to which the Set-Cookie headers will be added.
|
||||
//
|
||||
// Returns:
|
||||
// - An error if saving any of the session components fails.
|
||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
@@ -300,7 +336,19 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all session data by expiring all cookies and clearing their values.
|
||||
// Clear removes all session data associated with this SessionData instance.
|
||||
// It clears the values map of the main, access, and refresh sessions, sets their MaxAge to -1
|
||||
// to expire the cookies immediately, and clears any associated token chunk cookies.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired sessions to send the
|
||||
// expiring Set-Cookie headers. Finally, it clears internal fields and returns the SessionData
|
||||
// object to the pool.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The HTTP request (required by the underlying session store).
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
//
|
||||
// Returns:
|
||||
// - An error if saving the expired sessions fails (only if w is not nil).
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions.
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
@@ -335,7 +383,12 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// clearTokenChunks removes all session chunks for a given token type.
|
||||
// clearTokenChunks iterates through a map of session chunks, clears their values,
|
||||
// and sets their MaxAge to -1 to expire them. This is used internally by Clear.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The HTTP request (required by the underlying session store, though not directly used here).
|
||||
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
session.Options.MaxAge = -1
|
||||
@@ -345,7 +398,12 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthenticated returns whether the current session is authenticated.
|
||||
// GetAuthenticated checks if the session is marked as authenticated and has not exceeded
|
||||
// the absolute session timeout.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the "authenticated" flag is set to true and the session creation time is within the allowed timeout.
|
||||
// - false otherwise.
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
if !auth {
|
||||
@@ -360,8 +418,15 @@ func (sd *SessionData) GetAuthenticated() bool {
|
||||
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
|
||||
}
|
||||
|
||||
// SetAuthenticated updates the session's authentication status and rotates session ID.
|
||||
// Returns an error if generating a new session ID fails.
|
||||
// SetAuthenticated sets the authentication status of the session.
|
||||
// If setting to true, it generates a new secure session ID for the main session
|
||||
// to prevent session fixation attacks and records the current time as the creation time.
|
||||
//
|
||||
// Parameters:
|
||||
// - value: The boolean authentication status (true for authenticated, false otherwise).
|
||||
//
|
||||
// Returns:
|
||||
// - An error if generating a new session ID fails when setting value to true.
|
||||
func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
if value {
|
||||
id, err := generateSecureRandomString(32)
|
||||
@@ -375,7 +440,12 @@ func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccessToken retrieves the complete access token from the session.
|
||||
// GetAccessToken retrieves the access token stored in the session.
|
||||
// It handles reassembling the token from multiple cookie chunks if necessary
|
||||
// and decompresses it if it was stored compressed.
|
||||
//
|
||||
// Returns:
|
||||
// - The complete, decompressed access token string, or an empty string if not found.
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -409,7 +479,14 @@ func (sd *SessionData) GetAccessToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// SetAccessToken stores the access token in the session.
|
||||
// SetAccessToken stores the provided access token in the session.
|
||||
// It first expires any existing access token chunk cookies.
|
||||
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
|
||||
// it's stored directly in the primary access token session. Otherwise, the compressed token
|
||||
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_a_0, _oidc_raczylo_a_1, etc.).
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The access token string to store.
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
@@ -439,7 +516,12 @@ func (sd *SessionData) SetAccessToken(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetRefreshToken retrieves the complete refresh token from the session.
|
||||
// GetRefreshToken retrieves the refresh token stored in the session.
|
||||
// It handles reassembling the token from multiple cookie chunks if necessary
|
||||
// and decompresses it if it was stored compressed.
|
||||
//
|
||||
// Returns:
|
||||
// - The complete, decompressed refresh token string, or an empty string if not found.
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -473,7 +555,14 @@ func (sd *SessionData) GetRefreshToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// SetRefreshToken stores the refresh token in the session.
|
||||
// SetRefreshToken stores the provided refresh token in the session.
|
||||
// It first expires any existing refresh token chunk cookies.
|
||||
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
|
||||
// it's stored directly in the primary refresh token session. Otherwise, the compressed token
|
||||
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_r_0, _oidc_raczylo_r_1, etc.).
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The refresh token string to store.
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
@@ -503,7 +592,13 @@ func (sd *SessionData) SetRefreshToken(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
// expireAccessTokenChunks expires any existing access token chunk cookies.
|
||||
// expireAccessTokenChunks finds all existing access token chunk cookies (_oidc_raczylo_a_N)
|
||||
// associated with the current request, clears their values, and sets their MaxAge to -1.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
|
||||
// the expiring Set-Cookie headers. This is used internally when setting a new access token.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
@@ -521,7 +616,13 @@ func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
|
||||
// expireRefreshTokenChunks finds all existing refresh token chunk cookies (_oidc_raczylo_r_N)
|
||||
// associated with the current request, clears their values, and sets their MaxAge to -1.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
|
||||
// the expiring Set-Cookie headers. This is used internally when setting a new refresh token.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
@@ -539,7 +640,15 @@ func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks splits a string into chunks of specified size.
|
||||
// splitIntoChunks divides a string `s` into a slice of strings, where each element
|
||||
// has a maximum length of `chunkSize`.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The string to split.
|
||||
// - chunkSize: The maximum size of each chunk.
|
||||
//
|
||||
// Returns:
|
||||
// - A slice of strings representing the chunks.
|
||||
func splitIntoChunks(s string, chunkSize int) []string {
|
||||
var chunks []string
|
||||
for len(s) > 0 {
|
||||
@@ -554,46 +663,97 @@ func splitIntoChunks(s string, chunkSize int) []string {
|
||||
return chunks
|
||||
}
|
||||
|
||||
// GetCSRF retrieves the CSRF token from the session.
|
||||
// GetCSRF retrieves the Cross-Site Request Forgery (CSRF) token stored in the main session.
|
||||
//
|
||||
// Returns:
|
||||
// - The CSRF token string, or an empty string if not set.
|
||||
func (sd *SessionData) GetCSRF() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// SetCSRF stores a new CSRF token in the session.
|
||||
// SetCSRF stores the provided CSRF token string in the main session.
|
||||
// This token is typically generated at the start of the authentication flow.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The CSRF token to store.
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
}
|
||||
|
||||
// GetNonce retrieves the nonce value from the session.
|
||||
// GetNonce retrieves the OIDC nonce value stored in the main session.
|
||||
// The nonce is used to associate an ID token with the specific authentication request.
|
||||
//
|
||||
// Returns:
|
||||
// - The nonce string, or an empty string if not set.
|
||||
func (sd *SessionData) GetNonce() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// SetNonce stores a new nonce value in the session.
|
||||
// SetNonce stores the provided OIDC nonce string in the main session.
|
||||
// This nonce is typically generated at the start of the authentication flow.
|
||||
//
|
||||
// Parameters:
|
||||
// - nonce: The nonce string to store.
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address from the session.
|
||||
// GetCodeVerifier retrieves the PKCE (Proof Key for Code Exchange) code verifier
|
||||
// stored in the main session. This is only relevant if PKCE is enabled.
|
||||
//
|
||||
// Returns:
|
||||
// - The code verifier string, or an empty string if not set or PKCE is disabled.
|
||||
func (sd *SessionData) GetCodeVerifier() string {
|
||||
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
|
||||
return codeVerifier
|
||||
}
|
||||
|
||||
// SetCodeVerifier stores the provided PKCE code verifier string in the main session.
|
||||
// This is typically called at the start of the authentication flow if PKCE is enabled.
|
||||
//
|
||||
// Parameters:
|
||||
// - codeVerifier: The PKCE code verifier string to store.
|
||||
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
sd.mainSession.Values["code_verifier"] = codeVerifier
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address stored in the main session.
|
||||
// This is typically extracted from the ID token claims after successful authentication.
|
||||
//
|
||||
// Returns:
|
||||
// - The user's email address string, or an empty string if not set.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// SetEmail stores the user's email address in the session.
|
||||
// SetEmail stores the provided user email address string in the main session.
|
||||
// This is typically called after successful authentication and claim extraction.
|
||||
//
|
||||
// Parameters:
|
||||
// - email: The user's email address to store.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
}
|
||||
|
||||
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
|
||||
// GetIncomingPath retrieves the original request URI (including query parameters)
|
||||
// that the user was trying to access before being redirected for authentication.
|
||||
// This is stored in the main session to allow redirection back after successful login.
|
||||
//
|
||||
// Returns:
|
||||
// - The original request URI string, or an empty string if not set.
|
||||
func (sd *SessionData) GetIncomingPath() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// SetIncomingPath stores the original request path that triggered the authentication flow.
|
||||
// SetIncomingPath stores the original request URI (path and query parameters)
|
||||
// in the main session. This is typically called at the start of the authentication flow.
|
||||
//
|
||||
// Parameters:
|
||||
// - path: The original request URI string (e.g., "/protected/resource?id=123").
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
|
||||
+9
-2
@@ -1,7 +1,9 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -12,7 +14,12 @@ func generateRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[rand.Intn(len(charset))]
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
// Handle error appropriately in a real application, maybe panic in test helper
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
+113
-37
@@ -22,6 +22,11 @@ type Config struct {
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
|
||||
// EnablePKCE enables Proof Key for Code Exchange (PKCE) for the authorization code flow (optional)
|
||||
// This enhances security but might not be supported by all OIDC providers
|
||||
// Default: false
|
||||
EnablePKCE bool `json:"enablePKCE"`
|
||||
|
||||
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
|
||||
// Example: /oauth2/callback
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
@@ -79,6 +84,11 @@ type Config struct {
|
||||
|
||||
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
|
||||
HTTPClient *http.Client
|
||||
|
||||
// RefreshGracePeriodSeconds defines how many seconds before a token expires
|
||||
// the plugin should attempt to refresh it proactively (optional)
|
||||
// Default: 60
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -103,20 +113,36 @@ const (
|
||||
// - RateLimit: 100 requests per second
|
||||
// - PostLogoutRedirectURI: "/"
|
||||
// - ForceHTTPS: true (for security)
|
||||
// - EnablePKCE: false (PKCE is opt-in)
|
||||
//
|
||||
// CreateConfig initializes a new Config struct with default values for optional fields.
|
||||
// It sets default scopes, log level, rate limit, enables ForceHTTPS, and sets the
|
||||
// default refresh grace period. Required fields like ProviderURL, ClientID, ClientSecret,
|
||||
// CallbackURL, and SessionEncryptionKey must be set explicitly after creation.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a new Config struct with default settings applied.
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Validate performs validation checks on the Config.
|
||||
// It ensures all required fields are set and have valid values.
|
||||
// Returns an error if any validation check fails.
|
||||
// Validate checks the configuration settings for validity.
|
||||
// It ensures that required fields (ProviderURL, CallbackURL, ClientID, ClientSecret, SessionEncryptionKey)
|
||||
// are present and that URLs are well-formed (HTTPS where required). It also validates
|
||||
// the session key length, log level, rate limit, and refresh grace period.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the configuration is valid.
|
||||
// - An error describing the first validation failure encountered.
|
||||
func (c *Config) Validate() error {
|
||||
// Validate provider URL
|
||||
if c.ProviderURL == "" {
|
||||
@@ -190,16 +216,34 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
|
||||
}
|
||||
|
||||
// Validate refresh grace period
|
||||
if c.RefreshGracePeriodSeconds < 0 {
|
||||
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidSecureURL checks if the provided string is a valid HTTPS URL
|
||||
// isValidSecureURL checks if a given string represents a valid, absolute HTTPS URL.
|
||||
// It uses url.Parse and checks for a nil error, an "https" scheme, and a non-empty host.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The URL string to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the string is a valid HTTPS URL, false otherwise.
|
||||
func isValidSecureURL(s string) bool {
|
||||
u, err := url.Parse(s)
|
||||
return err == nil && u.Scheme == "https" && u.Host != ""
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the provided log level is valid
|
||||
// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error").
|
||||
//
|
||||
// Parameters:
|
||||
// - level: The log level string to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the log level is valid, false otherwise.
|
||||
func isValidLogLevel(level string) bool {
|
||||
return level == "debug" || level == "info" || level == "error"
|
||||
}
|
||||
@@ -216,14 +260,20 @@ type Logger struct {
|
||||
logDebug *log.Logger
|
||||
}
|
||||
|
||||
// NewLogger creates a new Logger with the specified log level.
|
||||
// The log level determines which messages are output:
|
||||
// - "debug": Outputs all messages (debug, info, error)
|
||||
// - "info": Outputs info and error messages
|
||||
// - "error": Outputs only error messages
|
||||
// NewLogger creates and configures a new Logger instance based on the provided log level.
|
||||
// It initializes loggers for ERROR (stderr), INFO (stdout), and DEBUG (stdout) levels,
|
||||
// enabling output based on the specified level:
|
||||
// - "error": Only ERROR messages are output.
|
||||
// - "info": INFO and ERROR messages are output.
|
||||
// - "debug": DEBUG, INFO, and ERROR messages are output.
|
||||
//
|
||||
// Error messages are always written to stderr, while info and debug
|
||||
// messages are written to stdout when enabled.
|
||||
// If an invalid level is provided, it defaults to behavior similar to "error".
|
||||
//
|
||||
// Parameters:
|
||||
// - logLevel: The desired logging level ("debug", "info", or "error").
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the configured Logger instance.
|
||||
func NewLogger(logLevel string) *Logger {
|
||||
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
@@ -245,51 +295,77 @@ func NewLogger(logLevel string) *Logger {
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an informational message.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
// Info logs a message at the INFO level using Printf style formatting.
|
||||
// Output is directed to stdout if the configured log level is "info" or "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debug logs a debug message.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
// Debug logs a message at the DEBUG level using Printf style formatting.
|
||||
// Output is directed to stdout only if the configured log level is "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
// Error logs a message at the ERROR level using Printf style formatting.
|
||||
// Output is always directed to stderr, regardless of the configured log level.
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs an informational message using Printf formatting.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
// Infof logs a message at the INFO level using Printf style formatting.
|
||||
// Equivalent to calling l.Info(format, args...).
|
||||
// Output is directed to stdout if the configured log level is "info" or "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debugf logs a debug message using Printf formatting.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
// Debugf logs a message at the DEBUG level using Printf style formatting.
|
||||
// Equivalent to calling l.Debug(format, args...).
|
||||
// Output is directed to stdout only if the configured log level is "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs an error message using Printf formatting.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
// Errorf logs a message at the ERROR level using Printf style formatting.
|
||||
// Equivalent to calling l.Error(format, args...).
|
||||
// Output is always directed to stderr, regardless of the configured log level.
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// handleError writes an error message to both the HTTP response and the error log.
|
||||
// It ensures consistent error handling across the middleware by logging the error
|
||||
// and sending an appropriate HTTP response to the client.
|
||||
// handleError logs an error message using the provided logger and sends an HTTP error
|
||||
// response to the client with the specified message and status code.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The http.ResponseWriter to send the error response to.
|
||||
// - message: The error message string.
|
||||
// - code: The HTTP status code for the response.
|
||||
// - logger: The Logger instance to use for logging the error.
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Error(message)
|
||||
http.Error(w, message, code)
|
||||
|
||||
Reference in New Issue
Block a user