mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c09e7a9228 | |||
| e5da5d4fe9 | |||
| 31db701dda | |||
| 16481afd36 | |||
| 751933ffa0 | |||
| e74153b107 | |||
| 025107fe3e | |||
| dfb9c0771e | |||
| 1107df40e7 | |||
| bf294569eb | |||
| 482c346840 | |||
| a462e44896 | |||
| 5eff0dc866 |
+1
-1
@@ -22,7 +22,7 @@ testData:
|
||||
- raczylo.com
|
||||
allowedRolesAndGroups:
|
||||
- guest-endpoints
|
||||
sessionEncryptionKey: potato-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
forceHTTPS: false
|
||||
logLevel: debug # debug, info, warn, error
|
||||
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
|
||||
|
||||
@@ -19,6 +19,8 @@ Middleware currently supports following scenarios:
|
||||
|
||||
#### How to configure...
|
||||
|
||||
* `sessionEncryptionKey` should be at least 32 bytes long.
|
||||
|
||||
##### Keeping secrets secret
|
||||
|
||||
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
### TODO / wishlist
|
||||
|
||||
- [] Improve test coverage
|
||||
- [x] Improve caching mechanism
|
||||
- [x] Add automatic release and semver generation
|
||||
@@ -1,169 +1,153 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheItem represents an item stored in the cache with its associated metadata.
|
||||
type CacheItem struct {
|
||||
// Value is the cached data of any type
|
||||
// Value is the cached data of any type.
|
||||
Value interface{}
|
||||
|
||||
// ExpiresAt is the timestamp when this item should be considered expired
|
||||
// and removed from the cache during cleanup operations
|
||||
// ExpiresAt is the timestamp when this item should be considered expired.
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
|
||||
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
|
||||
type Cache struct {
|
||||
// items stores the cached data with string keys
|
||||
items map[string]CacheItem
|
||||
|
||||
// mutex protects concurrent access to the items map
|
||||
// Use RLock/RUnlock for reads and Lock/Unlock for writes
|
||||
mutex sync.RWMutex
|
||||
|
||||
// maxSize is the maximum number of items allowed in the cache
|
||||
maxSize int
|
||||
|
||||
// accessList maintains the order of item access for eviction
|
||||
accessList []string
|
||||
// lruEntry represents an entry in the LRU list.
|
||||
type lruEntry struct {
|
||||
key string
|
||||
}
|
||||
|
||||
// DefaultMaxSize is the default maximum number of items in the cache
|
||||
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
|
||||
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
|
||||
type Cache struct {
|
||||
// items stores the cached data with string keys.
|
||||
items map[string]CacheItem
|
||||
|
||||
// order maintains the usage order; most recently used items are at the back.
|
||||
order *list.List
|
||||
|
||||
// elems maps keys to their corresponding list elements for O(1) access.
|
||||
elems map[string]*list.Element
|
||||
|
||||
// mutex protects concurrent access to the cache.
|
||||
mutex sync.RWMutex
|
||||
|
||||
// maxSize is the maximum number of items allowed in the cache.
|
||||
maxSize int
|
||||
}
|
||||
|
||||
// DefaultMaxSize is the default maximum number of items in the cache.
|
||||
const DefaultMaxSize = 1000
|
||||
|
||||
// NewCache creates a new empty cache instance.
|
||||
// The cache is immediately ready for use and is thread-safe.
|
||||
// NewCache creates a new empty cache instance that is ready for use.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
items: make(map[string]CacheItem),
|
||||
maxSize: DefaultMaxSize,
|
||||
accessList: make([]string, 0, DefaultMaxSize),
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
order: list.New(),
|
||||
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||
maxSize: DefaultMaxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds or updates an item in the cache with the specified expiration duration.
|
||||
// Parameters:
|
||||
// - key: Unique identifier for the cached item
|
||||
// - value: The data to cache (can be of any type)
|
||||
// - expiration: How long the item should remain in the cache
|
||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
||||
// It moves the item to the most recently used position.
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// If key exists, update it
|
||||
now := time.Now()
|
||||
expTime := now.Add(expiration)
|
||||
|
||||
// Update existing item.
|
||||
if _, exists := c.items[key]; exists {
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: time.Now().Add(expiration),
|
||||
ExpiresAt: expTime,
|
||||
}
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
c.order.MoveToBack(elem)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If cache is full, remove oldest item
|
||||
// Evict oldest item if cache is full.
|
||||
if len(c.items) >= c.maxSize {
|
||||
c.evictOldest()
|
||||
}
|
||||
|
||||
// Add new item
|
||||
// Add new item.
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: time.Now().Add(expiration),
|
||||
ExpiresAt: expTime,
|
||||
}
|
||||
c.accessList = append(c.accessList, key)
|
||||
elem := c.order.PushBack(lruEntry{key: key})
|
||||
c.elems[key] = elem
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache if it exists and hasn't expired.
|
||||
// Parameters:
|
||||
// - key: The identifier of the item to retrieve
|
||||
// Returns:
|
||||
// - value: The cached data (nil if not found or expired)
|
||||
// - found: true if the item was found and is valid, false otherwise
|
||||
// Thread-safe: Uses read locking to ensure safe concurrent access.
|
||||
// Moving the accessed item to the most recently used position.
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.RLock()
|
||||
item, found := c.items[key]
|
||||
c.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
c.mutex.Lock()
|
||||
c.removeItem(key)
|
||||
c.mutex.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Update access order
|
||||
c.mutex.Lock()
|
||||
c.updateAccessOrder(key)
|
||||
c.mutex.Unlock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check for expiration.
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
c.removeItem(key)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Move item to the back (most recently used).
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
c.order.MoveToBack(elem)
|
||||
}
|
||||
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache if it exists.
|
||||
// If the item doesn't exist, this operation is a no-op.
|
||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
||||
// Delete removes an item from the cache.
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.items, key)
|
||||
|
||||
c.removeItem(key)
|
||||
}
|
||||
|
||||
// Cleanup removes all expired items from the cache.
|
||||
// This should be called periodically to prevent memory leaks from
|
||||
// expired items that haven't been accessed (and thus not removed during Get operations).
|
||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
||||
// Cleanup removes all expired items from the cache. This should be called periodically
|
||||
// to prevent memory bloat from expired entries.
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
var newAccessList []string
|
||||
|
||||
for _, key := range c.accessList {
|
||||
if item, exists := c.items[key]; exists && !now.After(item.ExpiresAt) {
|
||||
newAccessList = append(newAccessList, key)
|
||||
} else {
|
||||
delete(c.items, key)
|
||||
for key, item := range c.items {
|
||||
if now.After(item.ExpiresAt) {
|
||||
c.removeItem(key)
|
||||
}
|
||||
}
|
||||
|
||||
c.accessList = newAccessList
|
||||
}
|
||||
|
||||
// evictOldest removes the least recently used item from the cache
|
||||
// evictOldest removes the least recently used item from the cache.
|
||||
func (c *Cache) evictOldest() {
|
||||
if len(c.accessList) > 0 {
|
||||
oldest := c.accessList[0]
|
||||
c.removeItem(oldest)
|
||||
elem := c.order.Front()
|
||||
if elem != nil {
|
||||
entry := elem.Value.(lruEntry)
|
||||
c.removeItem(entry.key)
|
||||
}
|
||||
}
|
||||
|
||||
// removeItem removes an item from both the cache and access list
|
||||
// removeItem removes an item from both the cache and the LRU tracking structures.
|
||||
func (c *Cache) removeItem(key string) {
|
||||
delete(c.items, key)
|
||||
for i, k := range c.accessList {
|
||||
if k == key {
|
||||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateAccessOrder moves the accessed key to the end of the access list
|
||||
func (c *Cache) updateAccessOrder(key string) {
|
||||
for i, k := range c.accessList {
|
||||
if k == key {
|
||||
c.accessList = append(append(c.accessList[:i], c.accessList[i+1:]...), key)
|
||||
break
|
||||
}
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
c.order.Remove(elem)
|
||||
delete(c.elems, key)
|
||||
}
|
||||
}
|
||||
|
||||
+2
-21
@@ -12,27 +12,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// newSessionOptions creates secure session cookie options.
|
||||
// Parameters:
|
||||
// - isSecure: Whether to set the Secure flag on cookies
|
||||
// Returns session options configured for security with:
|
||||
// - HttpOnly flag to prevent JavaScript access
|
||||
// - SameSite=Lax for CSRF protection
|
||||
// - Appropriate timeout and path settings
|
||||
func newSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce
|
||||
// for use in the OIDC authentication flow. The nonce is used to
|
||||
// prevent replay attacks by ensuring the token received matches
|
||||
@@ -133,14 +114,14 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
|
||||
|
||||
// Save the cleared session state
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save cleared session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ type JWKCacheInterface interface {
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for fetching keys
|
||||
//
|
||||
// Returns:
|
||||
// - The JSON Web Key Set
|
||||
// - An error if the keys cannot be retrieved or parsed
|
||||
@@ -115,6 +116,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for the request
|
||||
//
|
||||
// Returns:
|
||||
// - The parsed JSON Web Key Set
|
||||
// - An error if the request fails or the response is invalid
|
||||
|
||||
@@ -38,6 +38,7 @@ type JWT struct {
|
||||
// (header, claims, signature) using base64url decoding.
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT token string
|
||||
//
|
||||
// Returns:
|
||||
// - A parsed JWT struct
|
||||
// - An error if the token format is invalid or parsing fails
|
||||
@@ -88,6 +89,7 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
// - not before time (nbf) is in the past (with clock skew tolerance)
|
||||
// - subject (sub) is present and not empty
|
||||
// - algorithm matches expected value to prevent algorithm switching attacks
|
||||
//
|
||||
// Returns an error if any validation fails.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Debug logging of validation parameters
|
||||
@@ -111,7 +113,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
}
|
||||
|
||||
claims := j.Claims
|
||||
|
||||
|
||||
// Debug logging of all claims
|
||||
fmt.Printf("Token claims: %+v\n", claims)
|
||||
|
||||
@@ -174,10 +176,11 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Parameters:
|
||||
// - tokenAudience: The audience claim from the token
|
||||
// - expectedAudience: The expected audience value
|
||||
//
|
||||
// Returns an error if validation fails.
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
|
||||
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
|
||||
tokenAudience, expectedAudience)
|
||||
|
||||
switch aud := tokenAudience.(type) {
|
||||
@@ -207,10 +210,11 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
// Parameters:
|
||||
// - tokenIssuer: The issuer claim from the token
|
||||
// - expectedIssuer: The expected issuer URL
|
||||
//
|
||||
// Returns an error if validation fails.
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
|
||||
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
|
||||
tokenIssuer, expectedIssuer)
|
||||
|
||||
if tokenIssuer != expectedIssuer {
|
||||
@@ -227,25 +231,26 @@ const clockSkewTolerance = 2 * time.Minute
|
||||
// The expiration time is compared against the current time with clock skew tolerance.
|
||||
// Parameters:
|
||||
// - expiration: The expiration timestamp from the token
|
||||
//
|
||||
// Returns an error if the token has expired.
|
||||
func verifyExpiration(expiration float64) error {
|
||||
expirationTime := time.Unix(int64(expiration), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(clockSkewTolerance)
|
||||
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
expirationTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
|
||||
// Allow tokens that expire exactly now
|
||||
if expirationTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
if skewedNow.After(expirationTime) {
|
||||
return fmt.Errorf("token has expired (exp: %v, now: %v)",
|
||||
expirationTime.UTC(), now.UTC())
|
||||
@@ -257,27 +262,28 @@ func verifyExpiration(expiration float64) error {
|
||||
// Ensures the token wasn't issued in the future, accounting for clock skew.
|
||||
// Parameters:
|
||||
// - issuedAt: The issued-at timestamp from the token
|
||||
//
|
||||
// Returns an error if the token was issued in the future.
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(-clockSkewTolerance)
|
||||
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
issuedAtTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
|
||||
// Allow tokens issued in the same second as current time
|
||||
if issuedAtTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
if skewedNow.Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
|
||||
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
|
||||
issuedAtTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
@@ -287,25 +293,26 @@ func verifyIssuedAt(issuedAt float64) error {
|
||||
// Ensures the token is not used before its valid time period, accounting for clock skew.
|
||||
// Parameters:
|
||||
// - notBefore: The not-before timestamp from the token
|
||||
//
|
||||
// Returns an error if the token is not yet valid.
|
||||
func verifyNotBefore(notBefore float64) error {
|
||||
notBeforeTime := time.Unix(int64(notBefore), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(-clockSkewTolerance)
|
||||
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
notBeforeTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
|
||||
// Allow tokens that become valid exactly now
|
||||
if notBeforeTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
if skewedNow.Before(notBeforeTime) {
|
||||
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
|
||||
notBeforeTime.UTC(), now.UTC())
|
||||
@@ -318,15 +325,17 @@ func verifyNotBefore(notBefore float64) error {
|
||||
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
|
||||
// - RSA-PSS: PS256, PS384, PS512
|
||||
// - ECDSA: ES256, ES384, ES512
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The complete JWT token string
|
||||
// - publicKeyPEM: The PEM-encoded public key for verification
|
||||
// - alg: The signature algorithm identifier
|
||||
//
|
||||
// Returns an error if signature verification fails.
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying signature with algorithm: %s\n", alg)
|
||||
|
||||
|
||||
// Split the token into its three parts
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
|
||||
@@ -2,7 +2,6 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -13,6 +12,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"runtime"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
@@ -61,7 +62,6 @@ type TraefikOidc struct {
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
initComplete chan struct{}
|
||||
endSessionURL string
|
||||
baseURL string
|
||||
postLogoutRedirectURI string
|
||||
sessionManager *SessionManager
|
||||
}
|
||||
@@ -81,8 +81,6 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon": {},
|
||||
}
|
||||
|
||||
var newTicker = time.NewTicker
|
||||
|
||||
// VerifyToken verifies the provided JWT token
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
t.logger.Debugf("Verifying token")
|
||||
@@ -182,11 +180,22 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
|
||||
// Generate default session encryption key if not provided
|
||||
if config.SessionEncryptionKey == "" {
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session encryption key: %w", err)
|
||||
// Generate a fixed key for Traefik Hub testing
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
logger := NewLogger(config.LogLevel)
|
||||
|
||||
// Ensure key meets minimum length requirement
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
if runtime.Compiler == "yaegi" {
|
||||
// Set default encryption key for Yaegi (Traefik Plugin Analyzer)
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
logger.Infof("Session encryption key is too short; using default key for analyzer")
|
||||
} else {
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
config.SessionEncryptionKey = fmt.Sprintf("%x", key) // Convert to hex string
|
||||
}
|
||||
|
||||
// Setup HTTP client
|
||||
@@ -194,19 +203,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second, // Reduced timeout
|
||||
KeepAlive: 15 * time.Second, // Reduced keepalive
|
||||
Timeout: 15 * time.Second, // Reduced timeout
|
||||
KeepAlive: 15 * time.Second, // Reduced keepalive
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 30, // Reduced from 100
|
||||
MaxIdleConnsPerHost: 10, // Reduced from 100
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 50, // Limit max connections
|
||||
MaxIdleConns: 30, // Reduced from 100
|
||||
MaxIdleConnsPerHost: 10, // Reduced from 100
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 50, // Limit max connections
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
@@ -244,14 +253,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: NewTokenCache(),
|
||||
httpClient: httpClient,
|
||||
logger: NewLogger(config.LogLevel),
|
||||
excludedURLs: createStringMap(config.ExcludedURLs),
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||
initComplete: make(chan struct{}),
|
||||
}
|
||||
// Assign the initialized logger
|
||||
t.logger = logger
|
||||
|
||||
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
@@ -274,17 +284,17 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
// initializeMetadata discovers and initializes the provider metadata
|
||||
func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
t.logger.Debug("Starting provider metadata discovery")
|
||||
|
||||
|
||||
// Keep retrying until successful
|
||||
backoff := time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
for {
|
||||
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
|
||||
|
||||
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
|
||||
time.Sleep(backoff)
|
||||
|
||||
|
||||
// Exponential backoff with max
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
@@ -292,7 +302,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
if metadata != nil {
|
||||
t.logger.Debug("Successfully initialized provider metadata")
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
@@ -301,12 +311,12 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
|
||||
|
||||
// Only close channel on success
|
||||
close(t.initComplete)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
t.logger.Error("Received nil metadata, retrying")
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
@@ -403,7 +413,18 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
|
||||
// Obtain a new session and clear any residual session cookies
|
||||
session, _ = t.sessionManager.GetSession(req)
|
||||
session.Clear(req, rw)
|
||||
|
||||
// Build redirect URL
|
||||
scheme := t.determineScheme(req)
|
||||
host := t.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
// Initiate authentication
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -507,9 +528,6 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
|
||||
// determineScheme determines the scheme (http or https) of the request
|
||||
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
if t.forceHTTPS {
|
||||
return "https"
|
||||
}
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
@@ -578,17 +596,20 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
// defaultInitiateAuthentication initiates the authentication process
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Generate CSRF token and nonce
|
||||
csrfToken := uuid.New().String()
|
||||
csrfToken := uuid.NewString()
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session values
|
||||
// Clear any existing session data to avoid stale state causing redirect loops
|
||||
session.Clear(req, rw)
|
||||
|
||||
// Set new session values
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
session.SetIncomingPath(req.URL.Path)
|
||||
session.SetIncomingPath(req.URL.RequestURI())
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
@@ -597,7 +618,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
return
|
||||
}
|
||||
|
||||
// Build and redirect to auth URL
|
||||
// Build and redirect to authentication URL
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
@@ -618,12 +639,25 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||
if len(t.scopes) > 0 {
|
||||
params.Set("scope", strings.Join(t.scopes, " "))
|
||||
}
|
||||
|
||||
// Ensure authURL is absolute
|
||||
if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") {
|
||||
// Extract issuer base URL
|
||||
issuerURL, err := url.Parse(t.issuerURL)
|
||||
if err == nil {
|
||||
return fmt.Sprintf("%s://%s%s?%s",
|
||||
issuerURL.Scheme,
|
||||
issuerURL.Host,
|
||||
t.authURL,
|
||||
params.Encode())
|
||||
}
|
||||
}
|
||||
return t.authURL + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// startTokenCleanup starts the token cleanup goroutine
|
||||
func (t *TraefikOidc) startTokenCleanup() {
|
||||
ticker := newTicker(1 * time.Minute)
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
t.logger.Debug("Cleaning up token cache")
|
||||
|
||||
+129
-10
@@ -13,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -89,7 +90,7 @@ func (ts *TestSuite) Setup() {
|
||||
}
|
||||
|
||||
logger := NewLogger("info")
|
||||
ts.sessionManager = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
|
||||
// Common TraefikOidc instance
|
||||
ts.tOidc = &TraefikOidc{
|
||||
@@ -619,7 +620,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
|
||||
// Create a new instance for each test to avoid state carryover
|
||||
tOidc := &TraefikOidc{
|
||||
@@ -924,7 +925,7 @@ func TestHandleLogout(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
tOidc := &TraefikOidc{
|
||||
revocationURL: mockRevocationServer.URL,
|
||||
endSessionURL: tc.endSessionURL,
|
||||
@@ -1213,7 +1214,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
sessionManager: sessionManager,
|
||||
@@ -1370,23 +1371,23 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
|
||||
// Create base config
|
||||
config := &Config{
|
||||
ProviderURL: mockServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
ProviderURL: mockServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
}
|
||||
|
||||
// Create multiple middleware instances
|
||||
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
|
||||
var middlewares []*TraefikOidc
|
||||
|
||||
|
||||
for _, route := range routes {
|
||||
config.CallbackURL = route + "/callback"
|
||||
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}), config, "test")
|
||||
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
|
||||
}
|
||||
@@ -1646,6 +1647,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
}
|
||||
|
||||
// Helper function to compare string slices
|
||||
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
@@ -1657,3 +1659,120 @@ func stringSliceEqual(a, b []string) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
|
||||
func TestBuildAuthURL(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authURL string
|
||||
issuerURL string
|
||||
redirectURL string
|
||||
state string
|
||||
nonce string
|
||||
expectedPrefix string
|
||||
}{
|
||||
{
|
||||
name: "Absolute Auth URL",
|
||||
authURL: "https://auth.example.com/oauth/authorize",
|
||||
issuerURL: "https://auth.example.com",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||
},
|
||||
{
|
||||
name: "Relative Auth URL",
|
||||
authURL: "/oidc/auth",
|
||||
issuerURL: "https://logto.example.com",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
expectedPrefix: "https://logto.example.com/oidc/auth?",
|
||||
},
|
||||
{
|
||||
name: "Relative Auth URL with Different Issuer",
|
||||
authURL: "/sign-in",
|
||||
issuerURL: "https://auth.example.com:8443",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
expectedPrefix: "https://auth.example.com:8443/sign-in?",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Configure the test instance
|
||||
tOidc := ts.tOidc
|
||||
tOidc.authURL = tc.authURL
|
||||
tOidc.issuerURL = tc.issuerURL
|
||||
|
||||
// Call buildAuthURL
|
||||
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
|
||||
|
||||
// Verify the URL starts with the expected prefix
|
||||
if !strings.HasPrefix(result, tc.expectedPrefix) {
|
||||
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
|
||||
}
|
||||
|
||||
// Parse the resulting URL to verify query parameters
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse resulting URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
expectedParams := map[string]string{
|
||||
"client_id": tOidc.clientID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": tc.redirectURL,
|
||||
"state": tc.state,
|
||||
"nonce": tc.nonce,
|
||||
}
|
||||
|
||||
for key, expected := range expectedParams {
|
||||
if got := query.Get(key); got != expected {
|
||||
t.Errorf("Expected %s=%q, got %q", key, expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify scopes are present and correct
|
||||
if len(tOidc.scopes) > 0 {
|
||||
expectedScopes := strings.Join(tOidc.scopes, " ")
|
||||
if got := query.Get("scope"); got != expectedScopes {
|
||||
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
|
||||
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create a request with query parameters
|
||||
req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Get session
|
||||
session, err := ts.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Call defaultInitiateAuthentication
|
||||
redirectURL := "http://example.com/callback"
|
||||
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
|
||||
// Verify that the incoming path includes query parameters
|
||||
incomingPath := session.GetIncomingPath()
|
||||
expectedPath := "/protected/resource?param1=value1¶m2=value2"
|
||||
if incomingPath != expectedPath {
|
||||
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
|
||||
}
|
||||
}
|
||||
|
||||
+10
@@ -0,0 +1,10 @@
|
||||
version: 1
|
||||
force:
|
||||
existing: true
|
||||
wording:
|
||||
patch:
|
||||
- patch-release
|
||||
minor:
|
||||
- minor-release
|
||||
major:
|
||||
- breaking
|
||||
+119
-115
@@ -16,21 +16,22 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// generateSecureRandomString creates a cryptographically secure random string of specified length
|
||||
func generateSecureRandomString(length int) string {
|
||||
// generateSecureRandomString creates a cryptographically secure random string of specified length.
|
||||
// It returns the generated string or an error if random generation fails.
|
||||
func generateSecureRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
panic("failed to generate random string")
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// Cookie names and configuration constants used for session management
|
||||
var (
|
||||
// Using random prefixes to make cookie names less predictable
|
||||
mainCookieName = "_oidc_m_" + generateSecureRandomString(8)
|
||||
accessTokenCookie = "_oidc_a_" + generateSecureRandomString(8)
|
||||
refreshTokenCookie = "_oidc_r_" + generateSecureRandomString(8)
|
||||
const (
|
||||
// Using fixed prefixes for consistent cookie naming across restarts
|
||||
mainCookieName = "_oidc_raczylo_m"
|
||||
accessTokenCookie = "_oidc_raczylo_a"
|
||||
refreshTokenCookie = "_oidc_raczylo_r"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -55,7 +56,7 @@ const (
|
||||
minEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// compressToken compresses a token using gzip and base64 encodes it
|
||||
// compressToken compresses a token using gzip and base64 encodes it.
|
||||
func compressToken(token string) string {
|
||||
var b bytes.Buffer
|
||||
gz := gzip.NewWriter(&b)
|
||||
@@ -68,41 +69,41 @@ func compressToken(token string) string {
|
||||
return base64.StdEncoding.EncodeToString(b.Bytes())
|
||||
}
|
||||
|
||||
// decompressToken decompresses a base64 encoded gzipped token
|
||||
// decompressToken decompresses a base64 encoded gzipped token.
|
||||
func decompressToken(compressed string) string {
|
||||
data, err := base64.StdEncoding.DecodeString(compressed)
|
||||
if err != nil {
|
||||
return compressed // return as-is if not base64
|
||||
}
|
||||
|
||||
|
||||
gz, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return compressed
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
|
||||
decompressed, err := io.ReadAll(gz)
|
||||
if err != nil {
|
||||
return compressed
|
||||
}
|
||||
|
||||
|
||||
return string(decompressed)
|
||||
}
|
||||
|
||||
// SessionManager handles the management of multiple session cookies for OIDC authentication.
|
||||
// It provides functionality for storing and retrieving authentication state, tokens,
|
||||
// and other session-related data across multiple cookies to handle large tokens.
|
||||
// and other session-related data across multiple cookies.
|
||||
type SessionManager struct {
|
||||
// store is the underlying session store for cookie management
|
||||
// store is the underlying session store for cookie management.
|
||||
store sessions.Store
|
||||
|
||||
// forceHTTPS enforces secure cookie attributes regardless of request scheme
|
||||
// forceHTTPS enforces secure cookie attributes regardless of request scheme.
|
||||
forceHTTPS bool
|
||||
|
||||
// logger provides structured logging capabilities
|
||||
// logger provides structured logging capabilities.
|
||||
logger *Logger
|
||||
|
||||
// sessionPool is a sync.Pool for reusing SessionData objects
|
||||
// sessionPool is a sync.Pool for reusing SessionData objects.
|
||||
sessionPool sync.Pool
|
||||
}
|
||||
|
||||
@@ -111,11 +112,12 @@ type SessionManager struct {
|
||||
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
|
||||
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
|
||||
// - logger: Logger instance for recording session-related events
|
||||
// The manager handles session creation, storage, and cookie security settings.
|
||||
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
|
||||
// Validate encryption key length
|
||||
//
|
||||
// Returns an error if the encryption key does not meet minimum length requirements.
|
||||
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*SessionManager, error) {
|
||||
// Validate encryption key length.
|
||||
if len(encryptionKey) < minEncryptionKeyLength {
|
||||
panic(fmt.Sprintf("encryption key must be at least %d bytes long", minEncryptionKeyLength))
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
|
||||
sm := &SessionManager{
|
||||
@@ -124,21 +126,22 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Initialize session pool
|
||||
// Initialize session pool.
|
||||
sm.sessionPool.New = func() interface{} {
|
||||
return &SessionData{
|
||||
manager: sm,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
manager: sm,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
}
|
||||
}
|
||||
|
||||
return sm
|
||||
return sm, nil
|
||||
}
|
||||
|
||||
// getSessionOptions returns secure session options configured for the current request.
|
||||
// Parameters:
|
||||
// - isSecure: Whether the current request is using HTTPS
|
||||
// - isSecure: Whether the current request is using HTTPS.
|
||||
//
|
||||
// The options ensure cookies are:
|
||||
// - HTTP-only (not accessible via JavaScript)
|
||||
// - Secure when using HTTPS or when forceHTTPS is enabled
|
||||
@@ -159,7 +162,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
// and combines them into a single SessionData structure for easy access.
|
||||
// Returns an error if any session component cannot be loaded.
|
||||
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
// Get session from pool
|
||||
// Get session from pool.
|
||||
sessionData := sm.sessionPool.Get().(*SessionData)
|
||||
sessionData.request = r
|
||||
|
||||
@@ -170,11 +173,10 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
return nil, fmt.Errorf("failed to get main session: %w", err)
|
||||
}
|
||||
|
||||
// Check for absolute session timeout
|
||||
// Check for absolute session timeout.
|
||||
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
|
||||
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
|
||||
sessionData.Clear(r, nil) // Clear expired session
|
||||
sm.sessionPool.Put(sessionData)
|
||||
sessionData.Clear(r, nil)
|
||||
return nil, fmt.Errorf("session expired")
|
||||
}
|
||||
}
|
||||
@@ -191,7 +193,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
|
||||
}
|
||||
|
||||
// Clear and reuse chunk maps
|
||||
// Clear and reuse chunk maps.
|
||||
for k := range sessionData.accessTokenChunks {
|
||||
delete(sessionData.accessTokenChunks, k)
|
||||
}
|
||||
@@ -199,7 +201,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
delete(sessionData.refreshTokenChunks, k)
|
||||
}
|
||||
|
||||
// Retrieve chunked token sessions
|
||||
// Retrieve chunked token sessions.
|
||||
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
|
||||
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
|
||||
|
||||
@@ -216,7 +218,6 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
|
||||
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
||||
session, err := sm.store.Get(r, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
// No more sessions
|
||||
break
|
||||
}
|
||||
chunks[i] = session
|
||||
@@ -228,27 +229,27 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
|
||||
// and potentially large access and refresh tokens that may need to be
|
||||
// split across multiple cookies due to browser size limitations.
|
||||
type SessionData struct {
|
||||
// manager is the SessionManager that created this SessionData
|
||||
// manager is the SessionManager that created this SessionData.
|
||||
manager *SessionManager
|
||||
|
||||
// request is the current HTTP request associated with this session
|
||||
// request is the current HTTP request associated with this session.
|
||||
request *http.Request
|
||||
|
||||
// mainSession stores authentication state and basic user info
|
||||
// mainSession stores authentication state and basic user info.
|
||||
mainSession *sessions.Session
|
||||
|
||||
// accessSession stores the primary access token cookie
|
||||
// accessSession stores the primary access token cookie.
|
||||
accessSession *sessions.Session
|
||||
|
||||
// refreshSession stores the primary refresh token cookie
|
||||
// refreshSession stores the primary refresh token cookie.
|
||||
refreshSession *sessions.Session
|
||||
|
||||
// accessTokenChunks stores additional chunks of the access token
|
||||
// when it exceeds the maximum cookie size
|
||||
// when it exceeds the maximum cookie size.
|
||||
accessTokenChunks map[int]*sessions.Session
|
||||
|
||||
// refreshTokenChunks stores additional chunks of the refresh token
|
||||
// when it exceeds the maximum cookie size
|
||||
// when it exceeds the maximum cookie size.
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
}
|
||||
|
||||
@@ -259,28 +260,28 @@ type SessionData struct {
|
||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
// Set options for all sessions
|
||||
// Set options for all sessions.
|
||||
options := sd.manager.getSessionOptions(isSecure)
|
||||
sd.mainSession.Options = options
|
||||
sd.accessSession.Options = options
|
||||
sd.refreshSession.Options = options
|
||||
|
||||
// Save main session
|
||||
// Save main session.
|
||||
if err := sd.mainSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save main session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token session
|
||||
// Save access token session.
|
||||
if err := sd.accessSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token session: %w", err)
|
||||
}
|
||||
|
||||
// Save refresh token session
|
||||
// Save refresh token session.
|
||||
if err := sd.refreshSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token chunks
|
||||
// Save access token chunks.
|
||||
for _, session := range sd.accessTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
@@ -288,7 +289,7 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Save refresh token chunks
|
||||
// Save refresh token chunks.
|
||||
for _, session := range sd.refreshTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
@@ -300,10 +301,8 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
}
|
||||
|
||||
// Clear removes all session data by expiring all cookies and clearing their values.
|
||||
// This is typically used during logout to ensure all session data is properly cleaned up.
|
||||
// It handles both main session data and any token chunks that may exist.
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions
|
||||
// Clear and expire all sessions.
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
sd.accessSession.Options.MaxAge = -1
|
||||
sd.refreshSession.Options.MaxAge = -1
|
||||
@@ -318,7 +317,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
delete(sd.refreshSession.Values, k)
|
||||
}
|
||||
|
||||
// Clear chunk sessions
|
||||
// Clear chunk sessions.
|
||||
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
||||
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
||||
|
||||
@@ -326,16 +325,14 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
if w != nil {
|
||||
err = sd.Save(r, w)
|
||||
}
|
||||
|
||||
// Return session to pool
|
||||
|
||||
// Return session to pool.
|
||||
sd.manager.sessionPool.Put(sd)
|
||||
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// clearTokenChunks removes all session chunks for a given token type.
|
||||
// It expires the cookies and removes all stored values to ensure
|
||||
// no token data remains after logout or token invalidation.
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
session.Options.MaxAge = -1
|
||||
@@ -346,15 +343,13 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
|
||||
}
|
||||
|
||||
// GetAuthenticated returns whether the current session is authenticated.
|
||||
// Returns true if the user has successfully completed OIDC authentication
|
||||
// and the session hasn't expired, false otherwise.
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
if !auth {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check session expiration
|
||||
// Check session expiration.
|
||||
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
|
||||
if !ok {
|
||||
return false
|
||||
@@ -363,21 +358,21 @@ func (sd *SessionData) GetAuthenticated() bool {
|
||||
}
|
||||
|
||||
// SetAuthenticated updates the session's authentication status and rotates session ID.
|
||||
// This should be called after successful OIDC authentication or during logout.
|
||||
// Session ID rotation helps prevent session fixation attacks.
|
||||
func (sd *SessionData) SetAuthenticated(value bool) {
|
||||
// Returns an error if generating a new session ID fails.
|
||||
func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
if value {
|
||||
// Generate new session ID and set creation time
|
||||
sd.mainSession.ID = generateSecureRandomString(32)
|
||||
id, err := generateSecureRandomString(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate secure session id: %w", err)
|
||||
}
|
||||
sd.mainSession.ID = id
|
||||
sd.mainSession.Values["created_at"] = time.Now().Unix()
|
||||
}
|
||||
sd.mainSession.Values["authenticated"] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccessToken retrieves the complete access token from the session.
|
||||
// If the token was split into chunks due to size limitations, it will
|
||||
// automatically reassemble the complete token from all chunks.
|
||||
// Returns an empty string if no token is found.
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -388,7 +383,7 @@ func (sd *SessionData) GetAccessToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks
|
||||
// Reassemble token from chunks.
|
||||
if len(sd.accessTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -412,23 +407,23 @@ func (sd *SessionData) GetAccessToken() string {
|
||||
}
|
||||
|
||||
// SetAccessToken stores the access token in the session.
|
||||
// If the token exceeds maxCookieSize, it is automatically compressed and split into
|
||||
// multiple cookie chunks to handle large tokens while staying within
|
||||
// browser cookie size limits. Any existing token or chunks are cleared
|
||||
// before setting the new token.
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Clear existing chunks
|
||||
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called.
|
||||
}
|
||||
|
||||
// Clear and prepare chunks map for new token.
|
||||
sd.accessTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
// Compress token
|
||||
// Compress token.
|
||||
compressed := compressToken(token)
|
||||
|
||||
if len(compressed) <= maxCookieSize {
|
||||
sd.accessSession.Values["token"] = compressed
|
||||
sd.accessSession.Values["compressed"] = true
|
||||
} else {
|
||||
// Split compressed token into chunks
|
||||
// Split compressed token into chunks.
|
||||
sd.accessSession.Values["token"] = ""
|
||||
sd.accessSession.Values["compressed"] = true
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
@@ -442,9 +437,6 @@ func (sd *SessionData) SetAccessToken(token string) {
|
||||
}
|
||||
|
||||
// GetRefreshToken retrieves the complete refresh token from the session.
|
||||
// If the token was split into chunks due to size limitations, it will
|
||||
// automatically reassemble the complete token from all chunks.
|
||||
// Returns an empty string if no token is found.
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -455,7 +447,7 @@ func (sd *SessionData) GetRefreshToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks
|
||||
// Reassemble token from chunks.
|
||||
if len(sd.refreshTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -479,23 +471,23 @@ func (sd *SessionData) GetRefreshToken() string {
|
||||
}
|
||||
|
||||
// SetRefreshToken stores the refresh token in the session.
|
||||
// If the token exceeds maxCookieSize, it is automatically compressed and split into
|
||||
// multiple cookie chunks to handle large tokens while staying within
|
||||
// browser cookie size limits. Any existing token or chunks are cleared
|
||||
// before setting the new token.
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Clear existing chunks
|
||||
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called.
|
||||
}
|
||||
|
||||
// Clear and prepare chunks map for new token.
|
||||
sd.refreshTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
// Compress token
|
||||
// Compress token.
|
||||
compressed := compressToken(token)
|
||||
|
||||
if len(compressed) <= maxCookieSize {
|
||||
sd.refreshSession.Values["token"] = compressed
|
||||
sd.refreshSession.Values["compressed"] = true
|
||||
} else {
|
||||
// Split compressed token into chunks
|
||||
// Split compressed token into chunks.
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
sd.refreshSession.Values["compressed"] = true
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
@@ -508,12 +500,43 @@ func (sd *SessionData) SetRefreshToken(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
// expireAccessTokenChunks expires any existing access token chunk cookies.
|
||||
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
if w != nil {
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("failed to save expired access token cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
|
||||
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
if w != nil {
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("failed to save expired refresh token cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks splits a string into chunks of specified size.
|
||||
// This is used internally to handle large tokens that exceed cookie size limits.
|
||||
// Parameters:
|
||||
// - s: The string to split
|
||||
// - chunkSize: Maximum size of each chunk
|
||||
// Returns an array of string chunks, each no larger than chunkSize.
|
||||
func splitIntoChunks(s string, chunkSize int) []string {
|
||||
var chunks []string
|
||||
for len(s) > 0 {
|
||||
@@ -529,64 +552,45 @@ func splitIntoChunks(s string, chunkSize int) []string {
|
||||
}
|
||||
|
||||
// GetCSRF retrieves the CSRF token from the session.
|
||||
// This token is used to prevent cross-site request forgery attacks
|
||||
// by ensuring requests originate from the authenticated user.
|
||||
// Returns an empty string if no CSRF token is found.
|
||||
func (sd *SessionData) GetCSRF() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// SetCSRF stores a new CSRF token in the session.
|
||||
// This should be called when initiating authentication to generate
|
||||
// a new token for the authentication flow.
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
}
|
||||
|
||||
// GetNonce retrieves the nonce value from the session.
|
||||
// The nonce is used to prevent replay attacks in the OIDC flow
|
||||
// by ensuring the token received matches the authentication request.
|
||||
// Returns an empty string if no nonce is found.
|
||||
func (sd *SessionData) GetNonce() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// SetNonce stores a new nonce value in the session.
|
||||
// This should be called when initiating authentication to generate
|
||||
// a new nonce for the OIDC authentication flow.
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address from the session.
|
||||
// The email is typically extracted from the OIDC ID token claims.
|
||||
// Returns an empty string if no email is found.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// SetEmail stores the user's email address in the session.
|
||||
// This should be called after successful authentication when
|
||||
// processing the OIDC ID token claims.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
}
|
||||
|
||||
// GetIncomingPath retrieves the original request path that triggered
|
||||
// the authentication flow. This is used to redirect the user back
|
||||
// to their intended destination after successful authentication.
|
||||
// Returns an empty string if no path was stored.
|
||||
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
|
||||
func (sd *SessionData) GetIncomingPath() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// SetIncomingPath stores the original request path that triggered
|
||||
// the authentication flow. This should be called before redirecting
|
||||
// to the OIDC provider to remember where to send the user afterward.
|
||||
// SetIncomingPath stores the original request path that triggered the authentication flow.
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
|
||||
+236
-122
@@ -5,14 +5,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize random seed
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// generateRandomString creates a random string of specified length
|
||||
func generateRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
@@ -56,9 +50,9 @@ func TestTokenCompression(t *testing.T) {
|
||||
if len(tt.token) > 100 {
|
||||
compressionRatio := float64(len(compressed)) / float64(len(tt.token))
|
||||
t.Logf("Compression ratio for %s: %.2f", tt.name, compressionRatio)
|
||||
|
||||
|
||||
if compressionRatio > 1.1 { // Allow up to 10% size increase
|
||||
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
|
||||
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
|
||||
len(tt.token), len(compressed), compressionRatio)
|
||||
}
|
||||
}
|
||||
@@ -77,6 +71,120 @@ func TestTokenCompression(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestSessionManager tests the SessionManager functionality
|
||||
|
||||
func TestCookiePrefix(t *testing.T) {
|
||||
// Create a session and verify cookie names
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set some data to ensure cookies are created
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Expire any existing cookies
|
||||
session.expireAccessTokenChunks(rr)
|
||||
session.expireRefreshTokenChunks(rr)
|
||||
|
||||
// Set new tokens
|
||||
session.SetAccessToken("test_token")
|
||||
session.SetRefreshToken("test_refresh_token")
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Check cookie prefixes
|
||||
cookies := rr.Result().Cookies()
|
||||
for _, cookie := range cookies {
|
||||
if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") {
|
||||
t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRefreshCleanup(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set a large token that will be split into chunks
|
||||
largeToken := strings.Repeat("x", 5000)
|
||||
session.SetAccessToken(largeToken)
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get initial cookies
|
||||
initialCookies := rr.Result().Cookies()
|
||||
|
||||
// Create a new request with the initial cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range initialCookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
newRr := httptest.NewRecorder()
|
||||
|
||||
// Get session with cookies and set a new token
|
||||
newSession, err := sm.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
// Create a response recorder for expired cookies
|
||||
expiredRr := httptest.NewRecorder()
|
||||
|
||||
// Expire old chunk cookies
|
||||
newSession.expireAccessTokenChunks(expiredRr)
|
||||
|
||||
// Set a smaller token that won't need chunks
|
||||
newSession.SetAccessToken("small_token")
|
||||
|
||||
// Save session with new token
|
||||
if err := newSession.Save(newReq, newRr); err != nil {
|
||||
t.Fatalf("Failed to save new session: %v", err)
|
||||
}
|
||||
|
||||
// Check cookies in response where old cookies are expired
|
||||
intermediateResponse := expiredRr.Result()
|
||||
intermediateCount := 0
|
||||
chunkCount := 0
|
||||
expiredCount := 0
|
||||
|
||||
for _, cookie := range intermediateResponse.Cookies() {
|
||||
if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 {
|
||||
chunkCount++
|
||||
if cookie.MaxAge < 0 {
|
||||
expiredCount++
|
||||
t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
} else if cookie.MaxAge >= 0 {
|
||||
intermediateCount++
|
||||
t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
}
|
||||
|
||||
// All chunk cookies should be expired
|
||||
if chunkCount > 0 && chunkCount != expiredCount {
|
||||
t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount)
|
||||
}
|
||||
|
||||
// Should have fewer active cookies after setting smaller token
|
||||
if intermediateCount >= len(initialCookies) {
|
||||
t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManager(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
@@ -84,154 +192,160 @@ func TestSessionManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authenticated bool
|
||||
email string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
email string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
expectedCookieCount int
|
||||
wantCompressed bool // Whether tokens should be compressed
|
||||
wantCompressed bool // Whether tokens should be compressed
|
||||
}{
|
||||
{
|
||||
name: "Short tokens",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: "shortaccesstoken",
|
||||
refreshToken: "shortrefreshtoken",
|
||||
email: "test@example.com",
|
||||
accessToken: "shortaccesstoken",
|
||||
refreshToken: "shortrefreshtoken",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
wantCompressed: true,
|
||||
wantCompressed: true,
|
||||
},
|
||||
{
|
||||
name: "Long tokens exceeding 4096 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 5000),
|
||||
refreshToken: strings.Repeat("y", 6000),
|
||||
name: "Long tokens exceeding 4096 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 5000),
|
||||
refreshToken: strings.Repeat("y", 6000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
|
||||
wantCompressed: true,
|
||||
wantCompressed: true,
|
||||
},
|
||||
{
|
||||
name: "REALLY long tokens, exceeding 25000 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 25000),
|
||||
refreshToken: strings.Repeat("y", 25000),
|
||||
name: "REALLY long tokens, exceeding 25000 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 25000),
|
||||
refreshToken: strings.Repeat("y", 25000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
|
||||
wantCompressed: true,
|
||||
wantCompressed: true,
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated session",
|
||||
authenticated: false,
|
||||
email: "",
|
||||
accessToken: "",
|
||||
refreshToken: "",
|
||||
email: "",
|
||||
accessToken: "",
|
||||
refreshToken: "",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
wantCompressed: false,
|
||||
wantCompressed: false,
|
||||
},
|
||||
{
|
||||
name: "Random content tokens",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: generateRandomString(5000),
|
||||
refreshToken: generateRandomString(5000),
|
||||
name: "Random content tokens",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: generateRandomString(5000),
|
||||
refreshToken: generateRandomString(5000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)),
|
||||
wantCompressed: true,
|
||||
wantCompressed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc // Capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
tc := tc // Capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
session, err := ts.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session, err := ts.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set session values
|
||||
session.SetAuthenticated(tc.authenticated)
|
||||
session.SetEmail(tc.email)
|
||||
session.SetAccessToken(tc.accessToken)
|
||||
session.SetRefreshToken(tc.refreshToken)
|
||||
// Set session values
|
||||
session.SetAuthenticated(tc.authenticated)
|
||||
session.SetEmail(tc.email)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
// Expire any existing cookies
|
||||
session.expireAccessTokenChunks(rr)
|
||||
session.expireRefreshTokenChunks(rr)
|
||||
|
||||
// Verify cookies are set and compression is used when appropriate
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) != tc.expectedCookieCount {
|
||||
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
|
||||
}
|
||||
// Set new tokens
|
||||
session.SetAccessToken(tc.accessToken)
|
||||
session.SetRefreshToken(tc.refreshToken)
|
||||
|
||||
// Verify compression is working by checking token sizes
|
||||
for _, cookie := range cookies {
|
||||
if strings.Contains(cookie.Name, accessTokenCookie) {
|
||||
// Get original and stored sizes
|
||||
originalSize := len(tc.accessToken)
|
||||
storedSize := len(cookie.Value)
|
||||
|
||||
if originalSize > 100 && tc.wantCompressed {
|
||||
// For large tokens, verify some compression occurred
|
||||
compressionRatio := float64(storedSize) / float64(originalSize)
|
||||
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
|
||||
compressionRatio, originalSize, storedSize)
|
||||
|
||||
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
|
||||
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
|
||||
cookie.Name, compressionRatio)
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
|
||||
originalSize := len(tc.refreshToken)
|
||||
storedSize := len(cookie.Value)
|
||||
|
||||
if originalSize > 100 && tc.wantCompressed {
|
||||
compressionRatio := float64(storedSize) / float64(originalSize)
|
||||
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
|
||||
compressionRatio, originalSize, storedSize)
|
||||
|
||||
if compressionRatio > 0.9 {
|
||||
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
|
||||
cookie.Name, compressionRatio)
|
||||
}
|
||||
}
|
||||
// Save session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies are set and compression is used when appropriate
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) != tc.expectedCookieCount {
|
||||
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
|
||||
}
|
||||
|
||||
// Verify compression is working by checking token sizes
|
||||
for _, cookie := range cookies {
|
||||
if strings.Contains(cookie.Name, accessTokenCookie) {
|
||||
// Get original and stored sizes
|
||||
originalSize := len(tc.accessToken)
|
||||
storedSize := len(cookie.Value)
|
||||
|
||||
if originalSize > 100 && tc.wantCompressed {
|
||||
// For large tokens, verify some compression occurred
|
||||
compressionRatio := float64(storedSize) / float64(originalSize)
|
||||
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
|
||||
compressionRatio, originalSize, storedSize)
|
||||
|
||||
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
|
||||
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
|
||||
cookie.Name, compressionRatio)
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
|
||||
originalSize := len(tc.refreshToken)
|
||||
storedSize := len(cookie.Value)
|
||||
|
||||
// Create a new request with the cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
if originalSize > 100 && tc.wantCompressed {
|
||||
compressionRatio := float64(storedSize) / float64(originalSize)
|
||||
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
|
||||
compressionRatio, originalSize, storedSize)
|
||||
|
||||
// Get the session again and verify values
|
||||
newSession, err := ts.sessionManager.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
if compressionRatio > 0.9 {
|
||||
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
|
||||
cookie.Name, compressionRatio)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify session values
|
||||
if newSession.GetAuthenticated() != tc.authenticated {
|
||||
t.Errorf("Authentication status not preserved")
|
||||
}
|
||||
if email := newSession.GetEmail(); email != tc.email {
|
||||
t.Errorf("Expected email %s, got %s", tc.email, email)
|
||||
}
|
||||
if token := newSession.GetAccessToken(); token != tc.accessToken {
|
||||
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
|
||||
}
|
||||
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
|
||||
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
|
||||
}
|
||||
// Create a new request with the cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Verify session pooling by checking if the session is reused
|
||||
session2, _ := ts.sessionManager.GetSession(newReq)
|
||||
if session2 == newSession {
|
||||
t.Error("Session not properly pooled")
|
||||
}
|
||||
})
|
||||
// Get the session again and verify values
|
||||
newSession, err := ts.sessionManager.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
// Verify session values
|
||||
if newSession.GetAuthenticated() != tc.authenticated {
|
||||
t.Errorf("Authentication status not preserved")
|
||||
}
|
||||
if email := newSession.GetEmail(); email != tc.email {
|
||||
t.Errorf("Expected email %s, got %s", tc.email, email)
|
||||
}
|
||||
if token := newSession.GetAccessToken(); token != tc.accessToken {
|
||||
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
|
||||
}
|
||||
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
|
||||
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
|
||||
}
|
||||
|
||||
// Verify session pooling by checking if the session is reused
|
||||
session2, _ := ts.sessionManager.GetSession(newReq)
|
||||
if session2 == newSession {
|
||||
t.Error("Session not properly pooled")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,12 +356,12 @@ func calculateExpectedCookieCount(accessToken, refreshToken string) int {
|
||||
calculateChunks := func(token string) int {
|
||||
// Compress token (matching the actual implementation)
|
||||
compressed := compressToken(token)
|
||||
|
||||
|
||||
// If compressed token fits in one cookie, no additional chunks needed
|
||||
if len(compressed) <= maxCookieSize {
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
// Calculate chunks needed for compressed token
|
||||
return len(splitIntoChunks(compressed, maxCookieSize))
|
||||
}
|
||||
|
||||
+2
-1
@@ -221,6 +221,7 @@ type Logger struct {
|
||||
// - "debug": Outputs all messages (debug, info, error)
|
||||
// - "info": Outputs info and error messages
|
||||
// - "error": Outputs only error messages
|
||||
//
|
||||
// Error messages are always written to stderr, while info and debug
|
||||
// messages are written to stdout when enabled.
|
||||
func NewLogger(logLevel string) *Logger {
|
||||
@@ -229,7 +230,7 @@ func NewLogger(logLevel string) *Logger {
|
||||
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
|
||||
logError.SetOutput(os.Stderr)
|
||||
|
||||
|
||||
if logLevel == "debug" || logLevel == "info" {
|
||||
logInfo.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user