Compare commits

...

35 Commits

Author SHA1 Message Date
lukaszraczylo df051e0cfb Improve expiration logic. 2025-02-19 20:33:26 +00:00
lukaszraczylo 5d5ce8ae5e Additional tests for the blacklists 2025-02-19 12:08:37 +00:00
lukaszraczylo d194cd778a gofmt the updated files. 2025-02-19 11:56:31 +00:00
lukaszraczylo 803a1e5e21 Clean the caches properly to avoid memleak 2025-02-19 11:55:32 +00:00
lukaszraczylo 3ad8fb4518 Optimise cache cleanup run to avoid the GC which causes CPU usage to go higher than necessary. 2025-02-10 09:30:56 +00:00
lukaszraczylo 9402f1bca5 Token blacklist, cache and metadata improvements
TokenBlacklist Improvements:
Fixed size limit enforcement to properly maintain max size of 1000 tokens
Improved eviction strategy to remove expired tokens first before removing oldest
Added proper cleanup of tokens during Add operation to prevent size overflow
Fixed oldest token eviction logic to ensure correct token removal
Added proper locking mechanisms to prevent race conditions
Cache Improvements:
Fixed cleanup mechanism to only remove truly expired items
Improved eviction strategy in LRU cache to prioritize expired items
Added smarter eviction in evictOldest to scan for expired items first
Fixed aggressive cleanup that was removing valid items
Maintained proper LRU ordering while handling evictions
MetadataCache:
Verified proper implementation of metadata caching with hourly refresh
Confirmed proper handling of cache extension on fetch failures
Validated thread-safe operations with proper RWMutex usage
2025-02-09 23:53:05 +00:00
lukaszraczylo e6205b3a48 Add metadata caching capability to avoid unnecesary API calls 2025-02-09 23:37:50 +00:00
lukaszraczylo fdb8e3233e Testing (could be unstable) additional headers.
This adds additional headers to control the access origin and control allow headers.
2025-02-06 23:46:08 +00:00
lukaszraczylo 33c71fd6fe Enhance test suite. 2025-02-06 23:38:22 +00:00
lukaszraczylo 241cb1c209 Deal with the memory growth issue.
* TokenBlacklist limit is set to 1000
* Increased token cleanup frequency
2025-02-06 23:34:05 +00:00
lukaszraczylo 09daa1025c Follow multiple redirects during the OIDC flow. 2025-02-06 23:31:13 +00:00
lukaszraczylo c09e7a9228 Add additional test cases to cover it. 2025-02-06 21:50:35 +00:00
lukaszraczylo e5da5d4fe9 Fix redirection to the provider when session expires 2025-02-06 21:48:56 +00:00
lukaszraczylo 31db701dda Trigger build and release. 2025-02-05 19:04:44 +00:00
lukaszraczylo 16481afd36 Add todo: Improve test coverage. 2025-02-01 12:20:01 +00:00
lukaszraczylo 751933ffa0 Multiple improvements.
* Add todo list.

* fixup! Add todo list.

* fixup! fixup! Add todo list.

* fixup! fixup! fixup! Add todo list.

* Improve the session handling and cache.

* Fix an issue where expired session can cause infinite redirect loop

* fixup! Fix an issue where expired session can cause infinite redirect loop

* Add semver setup for automatic releases.

* fixup! Add semver setup for automatic releases.

* fixup! fixup! Add semver setup for automatic releases.

* fixup! fixup! fixup! Add semver setup for automatic releases.
2025-02-01 12:16:50 +00:00
lukaszraczylo e74153b107 Merge pull request #28 from lukaszraczylo/additional-improvements
additional improvements
2025-01-21 19:34:01 +00:00
lukaszraczylo 025107fe3e Well, release it finally. 2025-01-21 19:31:51 +00:00
lukaszraczylo dfb9c0771e Fix session handling and the redirection to the original URL incl. get parameters 2025-01-21 17:49:54 +00:00
lukaszraczylo 1107df40e7 Merge pull request #26 from lukaszraczylo/additional-improvements
Cleanup old cookies properly.
2025-01-21 17:34:16 +00:00
lukaszraczylo bf294569eb Cleanup old cookies properly. 2025-01-21 17:09:48 +00:00
lukaszraczylo 482c346840 Merge pull request #24 from lukaszraczylo/additional-improvements
additional improvements
2025-01-21 00:19:49 +00:00
lukaszraczylo a462e44896 Fix remaining issues with session handling and add additional tests. 2025-01-21 00:18:10 +00:00
lukaszraczylo 5eff0dc866 Clean up old cookies. 2025-01-21 00:03:13 +00:00
lukaszraczylo dfc534a400 Merge pull request #23 from lukaszraczylo/additional-improvements
Add useful defaults allowing traefik hub to pass.
2025-01-20 23:57:51 +00:00
lukaszraczylo 061c12d0a3 Add useful defaults allowing traefik hub to pass. 2025-01-20 23:55:58 +00:00
lukaszraczylo 4c4fff3613 Merge pull request #22 from lukaszraczylo/additional-improvements
Quite important fix
2025-01-20 23:50:35 +00:00
lukaszraczylo 0dcb44c187 Quite important fix
When user session expires, reauthentication fails as CSRF token disappears.
This commit fixes the issue by initiating new authentication flow.
2025-01-20 23:48:31 +00:00
lukaszraczylo cbe773d96a Merge pull request #20 from lukaszraczylo/additional-improvements
Provide default session encryption key if not specified.
2025-01-18 11:00:07 +00:00
lukaszraczylo 40254888d7 Provide default session encryption key if not specified. 2025-01-18 10:54:30 +00:00
lukaszraczylo ef41870c81 Merge pull request #18 from lukaszraczylo/additional-improvements
additional improvements
2025-01-18 02:28:29 +00:00
lukaszraczylo 081c32925a fixup! Security improvements have been implemented and verified across four main areas: 2025-01-14 11:47:49 +00:00
lukaszraczylo 17dea67229 Security improvements have been implemented and verified across four main areas:
JWT Token Security:
Protected against algorithm switching attacks by validating and whitelisting algorithms (RS256, RS384, RS512, PS256, PS384, PS512, ES256, ES384, ES512)
Added 2-minute clock skew tolerance for time-based validations
Added "not before" (nbf) claim validation with clock skew tolerance
Required JWT ID (jti) claim to prevent replay attacks
Added strict algorithm validation to prevent downgrade attacks
Session Management Security:
Implemented cryptographically secure random cookie names to prevent targeting
Added automatic session ID rotation after successful login to prevent session fixation
Enforced 24-hour absolute session timeout
Added strict encryption key length validation (minimum 32 bytes)
Added comprehensive session validation including timeout checks
Implemented session pooling for secure resource management
Added secure session cleanup on expiration
Configuration and URL Security:
Enforced HTTPS for all provider URLs and external endpoints
Added minimum rate limit (10 req/sec) to prevent DOS attacks
Added strict validation for excluded URLs:
Must start with "/"
No path traversal (..)
No wildcards (*)
Made ForceHTTPS true by default for secure cookies
Added validation for secure redirect URIs
Added validation for all OIDC endpoints (must be HTTPS)
Added secure defaults in configuration
Test Coverage:
Added comprehensive test cases verifying all security validations
Added test cases for HTTPS enforcement on all endpoints
Added test cases for minimum rate limits
Added test cases for secure session management
Added test cases for token validation with clock skew
Added test cases for secure configuration defaults
All security improvements have been verified through passing test cases, protecting against:

Session fixation attacks
Token replay attacks
Algorithm switching attacks
Path traversal attacks
Session hijacking
Timing attacks
DOS attacks
Man-in-the-middle attacks through enforced HTTPS
2025-01-14 11:33:48 +00:00
lukaszraczylo 8512ad6d68 Revert "Update vendored modules."
This reverts commit 5aa838c669.
2025-01-07 13:19:41 +00:00
lukaszraczylo 5aa838c669 Update vendored modules. 2025-01-06 13:10:13 +00:00
18 changed files with 1723 additions and 525 deletions
+1 -1
View File
@@ -22,7 +22,7 @@ testData:
- raczylo.com
allowedRolesAndGroups:
- guest-endpoints
sessionEncryptionKey: potato-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
forceHTTPS: false
logLevel: debug # debug, info, warn, error
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
+2
View File
@@ -19,6 +19,8 @@ Middleware currently supports following scenarios:
#### How to configure...
* `sessionEncryptionKey` should be at least 32 bytes long.
##### Keeping secrets secret
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
+5
View File
@@ -0,0 +1,5 @@
### TODO / wishlist
- [] Improve test coverage
- [x] Improve caching mechanism
- [x] Add automatic release and semver generation
+110
View File
@@ -0,0 +1,110 @@
package traefikoidc
import (
"sync"
"time"
)
// TokenBlacklist manages a thread-safe list of revoked tokens with expiration.
type TokenBlacklist struct {
tokens map[string]time.Time
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new token blacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
tokens: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist with an expiration time.
func (b *TokenBlacklist) Add(token string, expiry time.Time) {
b.mutex.Lock()
defer b.mutex.Unlock()
// Clean up expired tokens if we're at capacity
if len(b.tokens) >= 1000 {
now := time.Now()
futureThreshold := now.Add(time.Minute)
for t, exp := range b.tokens {
if now.After(exp) || futureThreshold.After(exp) {
delete(b.tokens, t)
}
}
// If still at capacity, remove oldest token
if len(b.tokens) >= 1000 {
var oldestToken string
var oldestTime time.Time
first := true
for t, exp := range b.tokens {
if first || exp.Before(oldestTime) {
oldestToken = t
oldestTime = exp
first = false
}
}
if oldestToken != "" {
delete(b.tokens, oldestToken)
}
}
}
b.tokens[token] = expiry
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (b *TokenBlacklist) IsBlacklisted(token string) bool {
b.mutex.RLock()
defer b.mutex.RUnlock()
expiry, exists := b.tokens[token]
if !exists {
return false
}
// If token is expired, remove it and return false
if time.Now().After(expiry) {
// Switch to write lock to remove expired token
b.mutex.RUnlock()
b.mutex.Lock()
delete(b.tokens, token)
b.mutex.Unlock()
b.mutex.RLock()
return false
}
return true
}
// Cleanup removes expired tokens from the blacklist.
// Also removes tokens that will expire within the next minute to prevent edge cases.
func (b *TokenBlacklist) Cleanup() {
b.mutex.Lock()
defer b.mutex.Unlock()
now := time.Now()
futureThreshold := now.Add(time.Minute)
for token, expiry := range b.tokens {
// Remove tokens that are expired or will expire soon
if now.After(expiry) || futureThreshold.After(expiry) {
delete(b.tokens, token)
}
}
}
// Remove removes a token from the blacklist regardless of its expiration.
func (b *TokenBlacklist) Remove(token string) {
b.mutex.Lock()
defer b.mutex.Unlock()
delete(b.tokens, token)
}
// Count returns the current number of tokens in the blacklist.
func (b *TokenBlacklist) Count() int {
b.mutex.RLock()
defer b.mutex.RUnlock()
return len(b.tokens)
}
+74
View File
@@ -0,0 +1,74 @@
package traefikoidc
import (
"testing"
"time"
)
func TestTokenBlacklist_Add(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
if !blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be blacklisted, but it was not")
}
}
func TestTokenBlacklist_IsBlacklisted(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
if !blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be blacklisted, but it was not")
}
if blacklist.IsBlacklisted("nonExistentToken") {
t.Errorf("Expected non-existent token to not be blacklisted, but it was")
}
}
func TestTokenBlacklist_Cleanup(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(-time.Hour) // Expired token
blacklist.Add(token, expiry)
blacklist.Cleanup()
if blacklist.IsBlacklisted(token) {
t.Errorf("Expected expired token to be removed after cleanup, but it was not")
}
}
func TestTokenBlacklist_Remove(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
blacklist.Remove(token)
if blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be removed, but it was not")
}
}
func TestTokenBlacklist_Count(t *testing.T) {
blacklist := NewTokenBlacklist()
token1 := "token1"
token2 := "token2"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token1, expiry)
blacklist.Add(token2, expiry)
if blacklist.Count() != 2 {
t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count())
}
}
+97 -97
View File
@@ -1,169 +1,169 @@
package traefikoidc
import (
"container/list"
"sync"
"time"
)
// CacheItem represents an item stored in the cache with its associated metadata.
type CacheItem struct {
// Value is the cached data of any type
// Value is the cached data of any type.
Value interface{}
// ExpiresAt is the timestamp when this item should be considered expired
// and removed from the cache during cleanup operations
// ExpiresAt is the timestamp when this item should be considered expired.
ExpiresAt time.Time
}
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
type Cache struct {
// items stores the cached data with string keys
items map[string]CacheItem
// mutex protects concurrent access to the items map
// Use RLock/RUnlock for reads and Lock/Unlock for writes
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache
maxSize int
// accessList maintains the order of item access for eviction
accessList []string
// lruEntry represents an entry in the LRU list.
type lruEntry struct {
key string
}
// DefaultMaxSize is the default maximum number of items in the cache
const DefaultMaxSize = 1000
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
type Cache struct {
// items stores the cached data with string keys.
items map[string]CacheItem
// NewCache creates a new empty cache instance.
// The cache is immediately ready for use and is thread-safe.
// order maintains the usage order; most recently used items are at the back.
order *list.List
// elems maps keys to their corresponding list elements for O(1) access.
elems map[string]*list.Element
// mutex protects concurrent access to the cache.
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache.
maxSize int
}
// DefaultMaxSize is the default maximum number of items in the cache.
const DefaultMaxSize = 500
// NewCache creates a new empty cache instance that is ready for use.
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
maxSize: DefaultMaxSize,
accessList: make([]string, 0, DefaultMaxSize),
items: make(map[string]CacheItem, DefaultMaxSize),
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
}
}
// Set adds or updates an item in the cache with the specified expiration duration.
// Parameters:
// - key: Unique identifier for the cached item
// - value: The data to cache (can be of any type)
// - expiration: How long the item should remain in the cache
// Thread-safe: Uses write locking to ensure safe concurrent access.
// It moves the item to the most recently used position.
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
// If key exists, update it
now := time.Now()
expTime := now.Add(expiration)
// Update existing item.
if _, exists := c.items[key]; exists {
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: expTime,
}
if elem, ok := c.elems[key]; ok {
c.order.MoveToBack(elem)
}
return
}
// If cache is full, remove oldest item
// Evict oldest item if cache is full.
if len(c.items) >= c.maxSize {
c.evictOldest()
}
// Add new item
// Add new item.
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: expTime,
}
c.accessList = append(c.accessList, key)
elem := c.order.PushBack(lruEntry{key: key})
c.elems[key] = elem
}
// Get retrieves an item from the cache if it exists and hasn't expired.
// Parameters:
// - key: The identifier of the item to retrieve
// Returns:
// - value: The cached data (nil if not found or expired)
// - found: true if the item was found and is valid, false otherwise
// Thread-safe: Uses read locking to ensure safe concurrent access.
// Moving the accessed item to the most recently used position.
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
item, found := c.items[key]
c.mutex.RUnlock()
if !found {
return nil, false
}
if time.Now().After(item.ExpiresAt) {
c.mutex.Lock()
c.removeItem(key)
c.mutex.Unlock()
return nil, false
}
// Update access order
c.mutex.Lock()
c.updateAccessOrder(key)
c.mutex.Unlock()
defer c.mutex.Unlock()
item, exists := c.items[key]
if !exists {
return nil, false
}
// Check for expiration.
if time.Now().After(item.ExpiresAt) {
c.removeItem(key)
return nil, false
}
// Move item to the back (most recently used).
if elem, ok := c.elems[key]; ok {
c.order.MoveToBack(elem)
}
return item.Value, true
}
// Delete removes an item from the cache if it exists.
// If the item doesn't exist, this operation is a no-op.
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Delete removes an item from the cache.
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
c.removeItem(key)
}
// Cleanup removes all expired items from the cache.
// This should be called periodically to prevent memory leaks from
// expired items that haven't been accessed (and thus not removed during Get operations).
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Cleanup removes all expired items from the cache. This should be called periodically
// to prevent memory bloat from expired entries.
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
var newAccessList []string
for _, key := range c.accessList {
if item, exists := c.items[key]; exists && !now.After(item.ExpiresAt) {
newAccessList = append(newAccessList, key)
} else {
delete(c.items, key)
for key, item := range c.items {
// Remove items that are expired or within 10% of expiration
if now.After(item.ExpiresAt) || now.Add(time.Duration(float64(item.ExpiresAt.Sub(now))*0.1)).After(item.ExpiresAt) {
c.removeItem(key)
}
}
c.accessList = newAccessList
}
// evictOldest removes the least recently used item from the cache
// evictOldest removes the least recently used item from the cache.
func (c *Cache) evictOldest() {
if len(c.accessList) > 0 {
oldest := c.accessList[0]
c.removeItem(oldest)
now := time.Now()
elem := c.order.Front()
// First try to find an expired item from the front
for elem != nil {
entry := elem.Value.(lruEntry)
if item, exists := c.items[entry.key]; exists {
if now.After(item.ExpiresAt) {
c.removeItem(entry.key)
return
}
}
elem = elem.Next()
}
// If no expired items found, remove the oldest item
if elem = c.order.Front(); elem != nil {
entry := elem.Value.(lruEntry)
c.removeItem(entry.key)
}
}
// removeItem removes an item from both the cache and access list
// removeItem removes an item from both the cache and the LRU tracking structures.
func (c *Cache) removeItem(key string) {
delete(c.items, key)
for i, k := range c.accessList {
if k == key {
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
break
}
}
}
// updateAccessOrder moves the accessed key to the end of the access list
func (c *Cache) updateAccessOrder(key string) {
for i, k := range c.accessList {
if k == key {
c.accessList = append(append(c.accessList[:i], c.accessList[i+1:]...), key)
break
}
if elem, ok := c.elems[key]; ok {
c.order.Remove(elem)
delete(c.elems, key)
}
}
+27 -68
View File
@@ -8,31 +8,12 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
)
// newSessionOptions creates secure session cookie options.
// Parameters:
// - isSecure: Whether to set the Secure flag on cookies
// Returns session options configured for security with:
// - HttpOnly flag to prevent JavaScript access
// - SameSite=Lax for CSRF protection
// - Appropriate timeout and path settings
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce creates a cryptographically secure random nonce
// for use in the OIDC authentication flow. The nonce is used to
// prevent replay attacks by ensuring the token received matches
@@ -87,13 +68,28 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
data.Set("refresh_token", codeOrToken)
}
// Create a cookie jar for this request to handle redirects with cookies
jar, _ := cookiejar.New(nil)
client := &http.Client{
Transport: t.httpClient.Transport,
Timeout: t.httpClient.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
Jar: jar,
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := t.httpClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
@@ -128,11 +124,19 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
// Clear authentication data but preserve CSRF state
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
// Save the cleared session state
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -278,51 +282,6 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a thread-safe list of revoked tokens.
// It stores tokens with their expiration times and automatically
// removes expired entries during cleanup operations.
type TokenBlacklist struct {
// blacklist maps token IDs to their expiration times
blacklist map[string]time.Time
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist with an expiration time.
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
expiration, exists := tb.blacklist[tokenID]
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist.
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
now := time.Now()
for tokenID, expiration := range tb.blacklist {
if now.After(expiration) {
delete(tb.blacklist, tokenID)
}
}
}
// TokenCache provides a caching mechanism for validated tokens.
// It stores token claims to avoid repeated validation of the
// same token, improving performance for frequently used tokens.
+227
View File
@@ -0,0 +1,227 @@
package traefikoidc
import (
"fmt"
"runtime"
"testing"
"time"
)
func TestTokenBlacklistSizeLimit(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens up to maxSize
for i := 0; i < 1000; i++ {
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
}
// Verify size is at max
if tb.Count() != 1000 {
t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count())
}
// Add one more token, should trigger cleanup/eviction
tb.Add("newtoken", time.Now().Add(time.Hour))
// Size should still be at max
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
}
}
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
tb := NewTokenBlacklist()
// Add some expired tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
}
// Add some valid tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
}
// Force cleanup
tb.Cleanup()
// Only valid tokens should remain
if tb.Count() != 500 {
t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count())
}
// Verify only valid tokens remain
tb.mutex.RLock()
defer tb.mutex.RUnlock()
for token, expiry := range tb.tokens {
if time.Now().After(expiry) {
t.Errorf("Found expired token after cleanup: %s", token)
}
}
}
func TestTokenBlacklistOldestEviction(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens at capacity with different expiration times
baseTime := time.Now()
oldestToken := "oldest"
// Add oldest token first
tb.Add(oldestToken, baseTime.Add(time.Hour))
// Fill up to capacity with newer tokens
for i := 0; i < 999; i++ {
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
}
// Add a new token that should evict the oldest
newToken := "newest"
tb.Add(newToken, baseTime.Add(time.Hour*3))
// Verify oldest token was evicted
if tb.IsBlacklisted(oldestToken) {
t.Error("Oldest token should have been evicted")
}
// Verify newest token is present
if !tb.IsBlacklisted(newToken) {
t.Error("Newest token should be present")
}
}
func TestTokenBlacklistMemoryUsage(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy usage
for i := 0; i < iterations; i++ {
// Add new token
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
// Periodically check blacklisted status
if i%100 == 0 {
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tb.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
}
// Verify size stayed within limits
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
}
}
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 1000
concurrency := 10
done := make(chan bool)
// Start multiple goroutines performing operations
for i := 0; i < concurrency; i++ {
go func(id int) {
for j := 0; j < iterations; j++ {
// Add tokens
token := fmt.Sprintf("token%d-%d", id, j)
tb.Add(token, time.Now().Add(time.Hour))
// Check blacklist status
tb.IsBlacklisted(token)
// Periodic cleanup
if j%100 == 0 {
tb.Cleanup()
}
}
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < concurrency; i++ {
<-done
}
// Verify size constraints were maintained
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count())
}
}
func TestTokenCacheMemoryUsage(t *testing.T) {
tc := NewTokenCache()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy cache usage
for i := 0; i < iterations; i++ {
claims := map[string]interface{}{
"sub": fmt.Sprintf("user%d", i),
"exp": time.Now().Add(time.Hour).Unix(),
}
// Add to cache
tc.Set(fmt.Sprintf("token%d", i), claims, time.Hour)
// Periodically retrieve
if i%100 == 0 {
tc.Get(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tc.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive cache memory growth: %d bytes", memoryGrowth)
}
// Verify cache size stayed within limits
if len(tc.cache.items) > tc.cache.maxSize {
t.Errorf("Cache exceeded max size: %d", len(tc.cache.items))
}
}
+14
View File
@@ -73,6 +73,7 @@ type JWKCache struct {
// maintaining consistent behavior in the token verification process.
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
Cleanup() // Add Cleanup method to the interface
}
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
@@ -81,6 +82,7 @@ type JWKCacheInterface interface {
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for fetching keys
//
// Returns:
// - The JSON Web Key Set
// - An error if the keys cannot be retrieved or parsed
@@ -110,11 +112,23 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// Cleanup removes expired JWKs from the cache.
func (c *JWKCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
if c.jwks != nil && now.After(c.expiresAt) {
c.jwks = nil
}
}
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
// It handles HTTP communication and JSON parsing of the response.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for the request
//
// Returns:
// - The parsed JSON Web Key Set
// - An error if the request fails or the response is invalid
+135 -10
View File
@@ -38,6 +38,7 @@ type JWT struct {
// (header, claims, signature) using base64url decoding.
// Parameters:
// - tokenString: The raw JWT token string
//
// Returns:
// - A parsed JWT struct
// - An error if the token format is invalid or parsing fails
@@ -83,13 +84,39 @@ func parseJWT(tokenString string) (*JWT, error) {
// It checks:
// - issuer (iss) matches the expected issuer URL
// - audience (aud) includes the client ID
// - expiration time (exp) is in the future
// - issued at time (iat) is in the past
// - expiration time (exp) is in the future (with clock skew tolerance)
// - issued at time (iat) is in the past (with clock skew tolerance)
// - not before time (nbf) is in the past (with clock skew tolerance)
// - subject (sub) is present and not empty
// - algorithm matches expected value to prevent algorithm switching attacks
//
// Returns an error if any validation fails.
func (j *JWT) Verify(issuerURL, clientID string) error {
// Debug logging of validation parameters
fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID)
// Debug logging of token header
fmt.Printf("Token header: %+v\n", j.Header)
// Validate algorithm to prevent algorithm switching attacks
alg, ok := j.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing 'alg' header")
}
// List of supported algorithms - should match those in verifySignature
supportedAlgs := map[string]bool{
"RS256": true, "RS384": true, "RS512": true,
"PS256": true, "PS384": true, "PS512": true,
"ES256": true, "ES384": true, "ES512": true,
}
if !supportedAlgs[alg] {
return fmt.Errorf("unsupported algorithm: %s", alg)
}
claims := j.Claims
// Debug logging of all claims
fmt.Printf("Token claims: %+v\n", claims)
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing 'iss' claim")
@@ -122,6 +149,19 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return err
}
// Validate nbf (not before) claim if present
if nbf, ok := claims["nbf"].(float64); ok {
if err := verifyNotBefore(nbf); err != nil {
return err
}
}
// Validate jti (JWT ID) claim if present
if jti, ok := claims["jti"].(string); ok {
// Could add replay detection here if needed
_ = jti
}
sub, ok := claims["sub"].(string)
if !ok || sub == "" {
return fmt.Errorf("missing or empty 'sub' claim")
@@ -136,8 +176,13 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
// Parameters:
// - tokenAudience: The audience claim from the token
// - expectedAudience: The expected audience value
//
// Returns an error if validation fails.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
// Debug logging
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
tokenAudience, expectedAudience)
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
@@ -165,37 +210,112 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
// Parameters:
// - tokenIssuer: The issuer claim from the token
// - expectedIssuer: The expected issuer URL
//
// Returns an error if validation fails.
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
// Debug logging
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
tokenIssuer, expectedIssuer)
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
return fmt.Errorf("invalid issuer (token: %s, expected: %s)",
tokenIssuer, expectedIssuer)
}
return nil
}
// Clock skew tolerance for time-based validations
const clockSkewTolerance = 2 * time.Minute
// verifyExpiration checks if the token's expiration time has passed.
// The expiration time is compared against the current time.
// The expiration time is compared against the current time with clock skew tolerance.
// Parameters:
// - expiration: The expiration timestamp from the token
//
// Returns an error if the token has expired.
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
return fmt.Errorf("token has expired")
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(clockSkewTolerance)
// Debug logging
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
expirationTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that expire exactly now
if expirationTime.Equal(now) {
return nil
}
if skewedNow.After(expirationTime) {
return fmt.Errorf("token has expired (exp: %v, now: %v)",
expirationTime.UTC(), now.UTC())
}
return nil
}
// verifyIssuedAt validates the token's issued-at time.
// Ensures the token wasn't issued in the future, which could
// indicate clock skew or a malicious token.
// Ensures the token wasn't issued in the future, accounting for clock skew.
// Parameters:
// - issuedAt: The issued-at timestamp from the token
//
// Returns an error if the token was issued in the future.
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
issuedAtTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens issued in the same second as current time
if issuedAtTime.Equal(now) {
return nil
}
if skewedNow.Before(issuedAtTime) {
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
issuedAtTime.UTC(), now.UTC())
}
return nil
}
// verifyNotBefore validates the token's not-before time if present.
// Ensures the token is not used before its valid time period, accounting for clock skew.
// Parameters:
// - notBefore: The not-before timestamp from the token
//
// Returns an error if the token is not yet valid.
func verifyNotBefore(notBefore float64) error {
notBeforeTime := time.Unix(int64(notBefore), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
notBeforeTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that become valid exactly now
if notBeforeTime.Equal(now) {
return nil
}
if skewedNow.Before(notBeforeTime) {
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
notBeforeTime.UTC(), now.UTC())
}
return nil
}
@@ -205,12 +325,17 @@ func verifyIssuedAt(issuedAt float64) error {
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
// - RSA-PSS: PS256, PS384, PS512
// - ECDSA: ES256, ES384, ES512
//
// Parameters:
// - tokenString: The complete JWT token string
// - publicKeyPEM: The PEM-encoded public key for verification
// - alg: The signature algorithm identifier
//
// Returns an error if signature verification fails.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Debug logging
fmt.Printf("Verifying signature with algorithm: %s\n", alg)
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
+150 -47
View File
@@ -12,6 +12,8 @@ import (
"strings"
"time"
"runtime"
"github.com/google/uuid"
"golang.org/x/time/rate"
)
@@ -37,6 +39,7 @@ type TraefikOidc struct {
issuerURL string
revocationURL string
jwkCache JWKCacheInterface
metadataCache *MetadataCache
tokenBlacklist *TokenBlacklist
jwksURL string
clientID string
@@ -60,7 +63,6 @@ type TraefikOidc struct {
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
@@ -80,8 +82,6 @@ var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
var newTicker = time.NewTicker
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
@@ -175,24 +175,48 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
if config == nil {
config = CreateConfig()
}
// Generate default session encryption key if not provided
if config.SessionEncryptionKey == "" {
// Generate a fixed key for Traefik Hub testing
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}
// Initialize logger
logger := NewLogger(config.LogLevel)
// Ensure key meets minimum length requirement
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
if runtime.Compiler == "yaegi" {
// Set default encryption key for Yaegi (Traefik Plugin Analyzer)
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
logger.Infof("Session encryption key is too short; using default key for analyzer")
} else {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
}
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive
Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
ExpectContinueTimeout: 0,
MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections
MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections
}
var httpClient *http.Client
@@ -202,6 +226,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
httpClient = &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
}
}
@@ -223,6 +254,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
metadataCache: NewMetadataCache(),
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
@@ -230,14 +262,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
}
// Assign the initialized logger
t.logger = logger
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
@@ -260,41 +293,56 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Debug("Starting provider metadata discovery")
// Keep retrying until successful
backoff := time.Second
maxBackoff := 30 * time.Second
for {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
// Get metadata from cache or fetch it
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to get provider metadata: %v", err)
return
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Start metadata refresh goroutine
go t.startMetadataRefresh(providerURL)
// Only close channel on success
close(t.initComplete)
return
}
t.logger.Error("Received nil metadata")
}
// startMetadataRefresh periodically refreshes the OIDC metadata
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
t.logger.Debug("Refreshing OIDC metadata")
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
time.Sleep(backoff)
// Exponential backoff with max
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
t.logger.Errorf("Failed to refresh metadata: %v", err)
continue
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Only close channel on success
close(t.initComplete)
return
t.logger.Debug("Successfully refreshed metadata")
}
t.logger.Error("Received nil metadata, retrying")
time.Sleep(backoff)
}
}
@@ -389,7 +437,18 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
// Obtain a new session and clear any residual session cookies
session, _ = t.sessionManager.GetSession(req)
session.Clear(req, rw)
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Initiate authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -475,6 +534,34 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Set OIDC-specific headers
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetAccessToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
// Set security headers
rw.Header().Set("X-Frame-Options", "DENY")
rw.Header().Set("X-Content-Type-Options", "nosniff")
rw.Header().Set("X-XSS-Protection", "1; mode=block")
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Set CORS headers
origin := req.Header.Get("Origin")
if origin != "" {
rw.Header().Set("Access-Control-Allow-Origin", origin)
rw.Header().Set("Access-Control-Allow-Credentials", "true")
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
// Handle preflight requests
if req.Method == "OPTIONS" {
rw.WriteHeader(http.StatusOK)
return
}
}
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -493,9 +580,6 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
return "https"
}
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
@@ -564,17 +648,20 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
csrfToken := uuid.New().String()
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Set session values
// Clear any existing session data to avoid stale state causing redirect loops
session.Clear(req, rw)
// Set new session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.Path)
session.SetIncomingPath(req.URL.RequestURI())
// Save the session
if err := session.Save(req, rw); err != nil {
@@ -583,7 +670,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Build and redirect to auth URL
// Build and redirect to authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -604,17 +691,33 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " "))
}
// Ensure authURL is absolute
if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
t.authURL,
params.Encode())
}
}
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := newTicker(1 * time.Minute)
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
go func() {
defer ticker.Stop()
for range ticker.C {
t.logger.Debug("Cleaning up token cache")
t.logger.Debug("Starting token cleanup cycle")
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
t.jwkCache.Cleanup() // Assuming jwkCache is the cache from cache.go
// Removed runtime.GC() call
}
}()
}
+276 -22
View File
@@ -13,6 +13,7 @@ import (
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@@ -67,21 +68,29 @@ func (ts *TestSuite) Setup() {
}
// Create a test JWT token signed with the RSA private key
// Create timestamps with proper clock skew
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": generateRandomString(16),
})
if err != nil {
ts.t.Fatalf("Failed to create test JWT: %v", err)
}
logger := NewLogger("info")
ts.sessionManager = NewSessionManager("test-secret-key", false, logger)
ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
@@ -126,6 +135,12 @@ func (m *MockJWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet
return m.JWKS, m.Err
}
func (m *MockJWKCache) Cleanup() {
// Mock cleanup implementation
m.JWKS = nil
m.Err = nil
}
// Helper function to create a JWT token
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
header := map[string]interface{}{
@@ -611,7 +626,7 @@ func TestHandleCallback(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Create a new instance for each test to avoid state carryover
tOidc := &TraefikOidc{
@@ -916,7 +931,7 @@ func TestHandleLogout(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
tOidc := &TraefikOidc{
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
@@ -1205,7 +1220,7 @@ func TestHandleExpiredToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
tOidc := &TraefikOidc{
sessionManager: sessionManager,
@@ -1362,23 +1377,23 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
// Create base config
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
}
// Create multiple middleware instances
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
var middlewares []*TraefikOidc
for _, route := range routes {
config.CallbackURL = route + "/callback"
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test")
if err != nil {
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
}
@@ -1440,6 +1455,12 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create consistent timestamps for all test cases
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
tests := []struct {
name string
allowedRolesAndGroups map[string]struct{}
@@ -1456,11 +1477,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"admin", "user"},
"groups": []interface{}{"group1"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1480,11 +1503,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"allowed-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1505,11 +1530,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1523,11 +1550,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1545,9 +1574,11 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1622,6 +1653,112 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
}
// Helper function to compare string slices
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
setupServer func() *httptest.Server
expectError bool
errorContains string
}{
{
name: "Successful token exchange with redirects",
setupServer: func() *httptest.Server {
redirectCount := 0
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 3 {
// Set a cookie before redirecting
http.SetCookie(w, &http.Cookie{
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
Value: "test-value",
})
redirectCount++
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
return
}
// Verify all cookies from previous redirects are present
cookies := r.Cookies()
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
for i := 0; i < 3; i++ {
found := false
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
for _, cookie := range cookies {
if cookie.Name == expectedName {
found = true
break
}
}
if !found {
t.Errorf("Cookie %s not found", expectedName)
}
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
expectError: false,
},
{
name: "Too many redirects",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
}))
},
expectError: true,
errorContains: "stopped after 50 redirects",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupServer()
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
// Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
}
})
}
}
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
@@ -1633,3 +1770,120 @@ func stringSliceEqual(a, b []string) bool {
}
return true
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authURL string
issuerURL string
redirectURL string
state string
nonce string
expectedPrefix string
}{
{
name: "Absolute Auth URL",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
},
{
name: "Relative Auth URL",
authURL: "/oidc/auth",
issuerURL: "https://logto.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://logto.example.com/oidc/auth?",
},
{
name: "Relative Auth URL with Different Issuer",
authURL: "/sign-in",
issuerURL: "https://auth.example.com:8443",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure the test instance
tOidc := ts.tOidc
tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL
// Call buildAuthURL
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
// Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) {
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
}
// Parse the resulting URL to verify query parameters
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse resulting URL: %v", err)
}
query := parsedURL.Query()
expectedParams := map[string]string{
"client_id": tOidc.clientID,
"response_type": "code",
"redirect_uri": tc.redirectURL,
"state": tc.state,
"nonce": tc.nonce,
}
for key, expected := range expectedParams {
if got := query.Get(key); got != expected {
t.Errorf("Expected %s=%q, got %q", key, expected, got)
}
}
// Verify scopes are present and correct
if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ")
if got := query.Get("scope"); got != expectedScopes {
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
}
}
})
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create a request with query parameters
req := httptest.NewRequest("GET", "/protected/resource?param1=value1&param2=value2", nil)
rw := httptest.NewRecorder()
// Get session
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Call defaultInitiateAuthentication
redirectURL := "http://example.com/callback"
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
// Verify that the incoming path includes query parameters
incomingPath := session.GetIncomingPath()
expectedPath := "/protected/resource?param1=value1&param2=value2"
if incomingPath != expectedPath {
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
}
}
+73
View File
@@ -0,0 +1,73 @@
package traefikoidc
import (
"fmt"
"net/http"
"sync"
"time"
)
// MetadataCache provides thread-safe caching for OIDC provider metadata
type MetadataCache struct {
metadata *ProviderMetadata
expiresAt time.Time
mutex sync.RWMutex
}
// NewMetadataCache creates a new metadata cache instance
func NewMetadataCache() *MetadataCache {
return &MetadataCache{}
}
// Cleanup removes expired metadata from the cache.
func (c *MetadataCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
if c.metadata != nil && now.After(c.expiresAt) {
c.metadata = nil
}
}
// GetMetadata retrieves the metadata from cache or fetches it if expired
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
c.mutex.RLock()
if c.metadata != nil && time.Now().Before(c.expiresAt) {
defer c.mutex.RUnlock()
return c.metadata, nil
}
c.mutex.RUnlock()
c.mutex.Lock()
defer c.mutex.Unlock()
// Double-check after acquiring write lock
if c.metadata != nil && time.Now().Before(c.expiresAt) {
return c.metadata, nil
}
metadata, err := discoverProviderMetadata(providerURL, httpClient, logger)
if err != nil {
if c.metadata != nil {
// On error, extend current cache by 5 minutes to prevent thundering herd
c.expiresAt = time.Now().Add(5 * time.Minute)
logger.Errorf("Failed to refresh metadata, using cached version for 5 more minutes: %v", err)
return c.metadata, nil
}
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
}
c.metadata = metadata
// Calculate expiration time based on usage patterns
usageCount := 0 // This should be replaced with actual usage tracking logic
if usageCount < 10 {
c.expiresAt = time.Now().Add(30 * time.Minute)
} else if usageCount < 50 {
c.expiresAt = time.Now().Add(1 * time.Hour)
} else {
c.expiresAt = time.Now().Add(2 * time.Hour)
}
return metadata, nil
}
+10
View File
@@ -0,0 +1,10 @@
version: 1
force:
existing: true
wording:
patch:
- patch-release
minor:
- minor-release
major:
- breaking
+162 -112
View File
@@ -3,30 +3,38 @@ package traefikoidc
import (
"bytes"
"compress/gzip"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
)
// generateSecureRandomString creates a cryptographically secure random string of specified length.
// It returns the generated string or an error if random generation fails.
func generateSecureRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
return hex.EncodeToString(bytes), nil
}
// Cookie names and configuration constants used for session management
const (
// mainCookieName is the name of the main session cookie that stores authentication state
// and basic user information like email and CSRF tokens
mainCookieName = "_raczylo_oidc"
// accessTokenCookie is the name of the cookie that stores the OIDC access token
// This may be split into multiple cookies if the token is large
accessTokenCookie = "_raczylo_oidc_access"
// refreshTokenCookie is the name of the cookie that stores the OIDC refresh token
// This may be split into multiple cookies if the token is large
refreshTokenCookie = "_raczylo_oidc_refresh"
// Using fixed prefixes for consistent cookie naming across restarts
mainCookieName = "_oidc_raczylo_m"
accessTokenCookie = "_oidc_raczylo_a"
refreshTokenCookie = "_oidc_raczylo_r"
)
const (
// maxCookieSize is the maximum size for each cookie chunk.
// This value is calculated to ensure the final cookie size stays within browser limits:
// 1. Browser cookie size limit is typically 4096 bytes
@@ -39,9 +47,16 @@ const (
// - Solving for x: x ≤ 3044
// 4. We use 2000 as a conservative limit to account for cookie metadata
maxCookieSize = 2000
// absoluteSessionTimeout defines the maximum lifetime of a session
// regardless of activity (24 hours)
absoluteSessionTimeout = 24 * time.Hour
// minEncryptionKeyLength defines the minimum length for the encryption key
minEncryptionKeyLength = 32
)
// compressToken compresses a token using gzip and base64 encodes it
// compressToken compresses a token using gzip and base64 encodes it.
func compressToken(token string) string {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
@@ -54,41 +69,41 @@ func compressToken(token string) string {
return base64.StdEncoding.EncodeToString(b.Bytes())
}
// decompressToken decompresses a base64 encoded gzipped token
// decompressToken decompresses a base64 encoded gzipped token.
func decompressToken(compressed string) string {
data, err := base64.StdEncoding.DecodeString(compressed)
if err != nil {
return compressed // return as-is if not base64
}
gz, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return compressed
}
defer gz.Close()
decompressed, err := io.ReadAll(gz)
if err != nil {
return compressed
}
return string(decompressed)
}
// SessionManager handles the management of multiple session cookies for OIDC authentication.
// It provides functionality for storing and retrieving authentication state, tokens,
// and other session-related data across multiple cookies to handle large tokens.
// and other session-related data across multiple cookies.
type SessionManager struct {
// store is the underlying session store for cookie management
// store is the underlying session store for cookie management.
store sessions.Store
// forceHTTPS enforces secure cookie attributes regardless of request scheme
// forceHTTPS enforces secure cookie attributes regardless of request scheme.
forceHTTPS bool
// logger provides structured logging capabilities
// logger provides structured logging capabilities.
logger *Logger
// sessionPool is a sync.Pool for reusing SessionData objects
// sessionPool is a sync.Pool for reusing SessionData objects.
sessionPool sync.Pool
}
@@ -97,29 +112,36 @@ type SessionManager struct {
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
// - logger: Logger instance for recording session-related events
// The manager handles session creation, storage, and cookie security settings.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
//
// Returns an error if the encryption key does not meet minimum length requirements.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*SessionManager, error) {
// Validate encryption key length.
if len(encryptionKey) < minEncryptionKeyLength {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
sm := &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
// Initialize session pool
// Initialize session pool.
sm.sessionPool.New = func() interface{} {
return &SessionData{
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
}
}
return sm
return sm, nil
}
// getSessionOptions returns secure session options configured for the current request.
// Parameters:
// - isSecure: Whether the current request is using HTTPS
// - isSecure: Whether the current request is using HTTPS.
//
// The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript)
// - Secure when using HTTPS or when forceHTTPS is enabled
@@ -130,7 +152,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
MaxAge: int(absoluteSessionTimeout.Seconds()),
Path: "/",
}
}
@@ -140,7 +162,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
// and combines them into a single SessionData structure for easy access.
// Returns an error if any session component cannot be loaded.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
// Get session from pool
// Get session from pool.
sessionData := sm.sessionPool.Get().(*SessionData)
sessionData.request = r
@@ -151,6 +173,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
return nil, fmt.Errorf("failed to get main session: %w", err)
}
// Check for absolute session timeout.
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
sessionData.Clear(r, nil)
return nil, fmt.Errorf("session expired")
}
}
sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)
if err != nil {
sm.sessionPool.Put(sessionData)
@@ -163,7 +193,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
// Clear and reuse chunk maps
// Clear and reuse chunk maps.
for k := range sessionData.accessTokenChunks {
delete(sessionData.accessTokenChunks, k)
}
@@ -171,7 +201,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
delete(sessionData.refreshTokenChunks, k)
}
// Retrieve chunked token sessions
// Retrieve chunked token sessions.
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
@@ -188,7 +218,6 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
// No more sessions
break
}
chunks[i] = session
@@ -200,27 +229,27 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
// and potentially large access and refresh tokens that may need to be
// split across multiple cookies due to browser size limitations.
type SessionData struct {
// manager is the SessionManager that created this SessionData
// manager is the SessionManager that created this SessionData.
manager *SessionManager
// request is the current HTTP request associated with this session
// request is the current HTTP request associated with this session.
request *http.Request
// mainSession stores authentication state and basic user info
// mainSession stores authentication state and basic user info.
mainSession *sessions.Session
// accessSession stores the primary access token cookie
// accessSession stores the primary access token cookie.
accessSession *sessions.Session
// refreshSession stores the primary refresh token cookie
// refreshSession stores the primary refresh token cookie.
refreshSession *sessions.Session
// accessTokenChunks stores additional chunks of the access token
// when it exceeds the maximum cookie size
// when it exceeds the maximum cookie size.
accessTokenChunks map[int]*sessions.Session
// refreshTokenChunks stores additional chunks of the refresh token
// when it exceeds the maximum cookie size
// when it exceeds the maximum cookie size.
refreshTokenChunks map[int]*sessions.Session
}
@@ -231,28 +260,28 @@ type SessionData struct {
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
// Set options for all sessions.
options := sd.manager.getSessionOptions(isSecure)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
// Save main session
// Save main session.
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
// Save access token session
// Save access token session.
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
// Save refresh token session
// Save refresh token session.
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
// Save access token chunks
// Save access token chunks.
for _, session := range sd.accessTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
@@ -260,7 +289,7 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
}
}
// Save refresh token chunks
// Save refresh token chunks.
for _, session := range sd.refreshTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
@@ -272,10 +301,8 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
}
// Clear removes all session data by expiring all cookies and clearing their values.
// This is typically used during logout to ensure all session data is properly cleaned up.
// It handles both main session data and any token chunks that may exist.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
// Clear and expire all sessions.
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
@@ -290,21 +317,22 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
delete(sd.refreshSession.Values, k)
}
// Clear chunk sessions
// Clear chunk sessions.
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
err := sd.Save(r, w)
// Return session to pool
var err error
if w != nil {
err = sd.Save(r, w)
}
// Return session to pool.
sd.manager.sessionPool.Put(sd)
return err
}
// clearTokenChunks removes all session chunks for a given token type.
// It expires the cookies and removes all stored values to ensure
// no token data remains after logout or token invalidation.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
@@ -315,23 +343,36 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
}
// GetAuthenticated returns whether the current session is authenticated.
// Returns true if the user has successfully completed OIDC authentication,
// false otherwise or if the authentication status cannot be determined.
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
return auth
if !auth {
return false
}
// Check session expiration.
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
if !ok {
return false
}
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
}
// SetAuthenticated updates the session's authentication status.
// This should be called after successful OIDC authentication or during logout.
func (sd *SessionData) SetAuthenticated(value bool) {
// SetAuthenticated updates the session's authentication status and rotates session ID.
// Returns an error if generating a new session ID fails.
func (sd *SessionData) SetAuthenticated(value bool) error {
if value {
id, err := generateSecureRandomString(32)
if err != nil {
return fmt.Errorf("failed to generate secure session id: %w", err)
}
sd.mainSession.ID = id
sd.mainSession.Values["created_at"] = time.Now().Unix()
}
sd.mainSession.Values["authenticated"] = value
return nil
}
// GetAccessToken retrieves the complete access token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
@@ -342,7 +383,7 @@ func (sd *SessionData) GetAccessToken() string {
return token
}
// Reassemble token from chunks
// Reassemble token from chunks.
if len(sd.accessTokenChunks) == 0 {
return ""
}
@@ -366,23 +407,23 @@ func (sd *SessionData) GetAccessToken() string {
}
// SetAccessToken stores the access token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetAccessToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
// Expire any existing chunk cookies first.
if sd.request != nil {
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called.
}
// Clear and prepare chunks map for new token.
sd.accessTokenChunks = make(map[int]*sessions.Session)
// Compress token
// Compress token.
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.accessSession.Values["token"] = compressed
sd.accessSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
// Split compressed token into chunks.
sd.accessSession.Values["token"] = ""
sd.accessSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
@@ -396,9 +437,6 @@ func (sd *SessionData) SetAccessToken(token string) {
}
// GetRefreshToken retrieves the complete refresh token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
@@ -409,7 +447,7 @@ func (sd *SessionData) GetRefreshToken() string {
return token
}
// Reassemble token from chunks
// Reassemble token from chunks.
if len(sd.refreshTokenChunks) == 0 {
return ""
}
@@ -433,23 +471,23 @@ func (sd *SessionData) GetRefreshToken() string {
}
// SetRefreshToken stores the refresh token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetRefreshToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
// Expire any existing chunk cookies first.
if sd.request != nil {
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called.
}
// Clear and prepare chunks map for new token.
sd.refreshTokenChunks = make(map[int]*sessions.Session)
// Compress token
// Compress token.
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.refreshSession.Values["token"] = compressed
sd.refreshSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
// Split compressed token into chunks.
sd.refreshSession.Values["token"] = ""
sd.refreshSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
@@ -462,12 +500,43 @@ func (sd *SessionData) SetRefreshToken(token string) {
}
}
// expireAccessTokenChunks expires any existing access token chunk cookies.
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired access token cookie: %v", err)
}
}
}
}
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired refresh token cookie: %v", err)
}
}
}
}
// splitIntoChunks splits a string into chunks of specified size.
// This is used internally to handle large tokens that exceed cookie size limits.
// Parameters:
// - s: The string to split
// - chunkSize: Maximum size of each chunk
// Returns an array of string chunks, each no larger than chunkSize.
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
@@ -483,64 +552,45 @@ func splitIntoChunks(s string, chunkSize int) []string {
}
// GetCSRF retrieves the CSRF token from the session.
// This token is used to prevent cross-site request forgery attacks
// by ensuring requests originate from the authenticated user.
// Returns an empty string if no CSRF token is found.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores a new CSRF token in the session.
// This should be called when initiating authentication to generate
// a new token for the authentication flow.
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce retrieves the nonce value from the session.
// The nonce is used to prevent replay attacks in the OIDC flow
// by ensuring the token received matches the authentication request.
// Returns an empty string if no nonce is found.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores a new nonce value in the session.
// This should be called when initiating authentication to generate
// a new nonce for the OIDC authentication flow.
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail retrieves the authenticated user's email address from the session.
// The email is typically extracted from the OIDC ID token claims.
// Returns an empty string if no email is found.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the user's email address in the session.
// This should be called after successful authentication when
// processing the OIDC ID token claims.
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath retrieves the original request path that triggered
// the authentication flow. This is used to redirect the user back
// to their intended destination after successful authentication.
// Returns an empty string if no path was stored.
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request path that triggered
// the authentication flow. This should be called before redirecting
// to the OIDC provider to remember where to send the user afterward.
// SetIncomingPath stores the original request path that triggered the authentication flow.
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
+236 -122
View File
@@ -5,14 +5,8 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
)
func init() {
// Initialize random seed
rand.Seed(time.Now().UnixNano())
}
// generateRandomString creates a random string of specified length
func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
@@ -56,9 +50,9 @@ func TestTokenCompression(t *testing.T) {
if len(tt.token) > 100 {
compressionRatio := float64(len(compressed)) / float64(len(tt.token))
t.Logf("Compression ratio for %s: %.2f", tt.name, compressionRatio)
if compressionRatio > 1.1 { // Allow up to 10% size increase
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
len(tt.token), len(compressed), compressionRatio)
}
}
@@ -77,6 +71,120 @@ func TestTokenCompression(t *testing.T) {
}
// TestSessionManager tests the SessionManager functionality
func TestCookiePrefix(t *testing.T) {
// Create a session and verify cookie names
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set some data to ensure cookies are created
session.SetAuthenticated(true)
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Set new tokens
session.SetAccessToken("test_token")
session.SetRefreshToken("test_refresh_token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Check cookie prefixes
cookies := rr.Result().Cookies()
for _, cookie := range cookies {
if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") {
t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name)
}
}
}
func TestTokenRefreshCleanup(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set a large token that will be split into chunks
largeToken := strings.Repeat("x", 5000)
session.SetAccessToken(largeToken)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get initial cookies
initialCookies := rr.Result().Cookies()
// Create a new request with the initial cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range initialCookies {
newReq.AddCookie(cookie)
}
newRr := httptest.NewRecorder()
// Get session with cookies and set a new token
newSession, err := sm.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Create a response recorder for expired cookies
expiredRr := httptest.NewRecorder()
// Expire old chunk cookies
newSession.expireAccessTokenChunks(expiredRr)
// Set a smaller token that won't need chunks
newSession.SetAccessToken("small_token")
// Save session with new token
if err := newSession.Save(newReq, newRr); err != nil {
t.Fatalf("Failed to save new session: %v", err)
}
// Check cookies in response where old cookies are expired
intermediateResponse := expiredRr.Result()
intermediateCount := 0
chunkCount := 0
expiredCount := 0
for _, cookie := range intermediateResponse.Cookies() {
if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 {
chunkCount++
if cookie.MaxAge < 0 {
expiredCount++
t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
} else if cookie.MaxAge >= 0 {
intermediateCount++
t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
}
// All chunk cookies should be expired
if chunkCount > 0 && chunkCount != expiredCount {
t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount)
}
// Should have fewer active cookies after setting smaller token
if intermediateCount >= len(initialCookies) {
t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies))
}
}
func TestSessionManager(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
@@ -84,154 +192,160 @@ func TestSessionManager(t *testing.T) {
tests := []struct {
name string
authenticated bool
email string
accessToken string
refreshToken string
email string
accessToken string
refreshToken string
expectedCookieCount int
wantCompressed bool // Whether tokens should be compressed
wantCompressed bool // Whether tokens should be compressed
}{
{
name: "Short tokens",
authenticated: true,
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: true,
wantCompressed: true,
},
{
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
wantCompressed: true,
wantCompressed: true,
},
{
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
wantCompressed: true,
wantCompressed: true,
},
{
name: "Unauthenticated session",
authenticated: false,
email: "",
accessToken: "",
refreshToken: "",
email: "",
accessToken: "",
refreshToken: "",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: false,
wantCompressed: false,
},
{
name: "Random content tokens",
authenticated: true,
email: "test@example.com",
accessToken: generateRandomString(5000),
refreshToken: generateRandomString(5000),
name: "Random content tokens",
authenticated: true,
email: "test@example.com",
accessToken: generateRandomString(5000),
refreshToken: generateRandomString(5000),
expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)),
wantCompressed: true,
wantCompressed: true,
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Verify cookies are set and compression is used when appropriate
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Set new tokens
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Verify compression is working by checking token sizes
for _, cookie := range cookies {
if strings.Contains(cookie.Name, accessTokenCookie) {
// Get original and stored sizes
originalSize := len(tc.accessToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
// For large tokens, verify some compression occurred
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
originalSize := len(tc.refreshToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 {
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify cookies are set and compression is used when appropriate
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Verify compression is working by checking token sizes
for _, cookie := range cookies {
if strings.Contains(cookie.Name, accessTokenCookie) {
// Get original and stored sizes
originalSize := len(tc.accessToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
// For large tokens, verify some compression occurred
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
originalSize := len(tc.refreshToken)
storedSize := len(cookie.Value)
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
if originalSize > 100 && tc.wantCompressed {
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
if compressionRatio > 0.9 {
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
}
}
// Verify session values
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Verify session pooling by checking if the session is reused
session2, _ := ts.sessionManager.GetSession(newReq)
if session2 == newSession {
t.Error("Session not properly pooled")
}
})
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Verify session values
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
}
// Verify session pooling by checking if the session is reused
session2, _ := ts.sessionManager.GetSession(newReq)
if session2 == newSession {
t.Error("Session not properly pooled")
}
})
}
}
@@ -242,12 +356,12 @@ func calculateExpectedCookieCount(accessToken, refreshToken string) int {
calculateChunks := func(token string) int {
// Compress token (matching the actual implementation)
compressed := compressToken(token)
// If compressed token fits in one cookie, no additional chunks needed
if len(compressed) <= maxCookieSize {
return 0
}
// Calculate chunks needed for compressed token
return len(splitIntoChunks(compressed, maxCookieSize))
}
+75 -32
View File
@@ -10,10 +10,6 @@ import (
"strings"
)
const (
cookieName = "_raczylo_oidc"
)
// Config holds the configuration for the OIDC middleware.
// It provides all necessary settings to configure OpenID Connect authentication
// with various providers like Auth0, Logto, or any standard OIDC provider.
@@ -85,30 +81,34 @@ type Config struct {
HTTPClient *http.Client
}
// CreateConfig creates a new Config with sensible default values.
const (
// DefaultRateLimit defines the default rate limit for requests per second
DefaultRateLimit = 100
// MinRateLimit defines the minimum allowed rate limit to prevent DOS
MinRateLimit = 10
// DefaultLogLevel defines the default logging level
DefaultLogLevel = "info"
// MinSessionEncryptionKeyLength defines the minimum length for session encryption key
MinSessionEncryptionKeyLength = 32
)
// CreateConfig creates a new Config with secure default values.
// Default values are set for optional fields:
// - Scopes: ["openid", "profile", "email"]
// - LogLevel: "info"
// - LogoutURL: CallbackURL + "/logout"
// - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security)
func CreateConfig() *Config {
c := &Config{}
if c.Scopes == nil {
c.Scopes = []string{"openid", "profile", "email"}
}
if c.LogLevel == "" {
c.LogLevel = "info"
}
if c.LogoutURL == "" {
c.LogoutURL = c.CallbackURL + "/logout"
}
if c.RateLimit == 0 {
c.RateLimit = 100
c := &Config{
Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
}
return c
@@ -118,43 +118,85 @@ func CreateConfig() *Config {
// It ensures all required fields are set and have valid values.
// Returns an error if any validation check fails.
func (c *Config) Validate() error {
// Validate provider URL
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
}
if !isValidURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid URL")
if !isValidSecureURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid HTTPS URL")
}
// Validate callback URL
if c.CallbackURL == "" {
return fmt.Errorf("callbackURL is required")
}
if !strings.HasPrefix(c.CallbackURL, "/") {
return fmt.Errorf("callbackURL must start with /")
}
// Validate client credentials
if c.ClientID == "" {
return fmt.Errorf("clientID is required")
}
if c.ClientSecret == "" {
return fmt.Errorf("clientSecret is required")
}
// Validate session encryption key
if c.SessionEncryptionKey == "" {
return fmt.Errorf("sessionEncryptionKey is required")
}
if len(c.SessionEncryptionKey) < 32 {
return fmt.Errorf("sessionEncryptionKey must be at least 32 characters long")
}
if c.RateLimit < 0 {
return fmt.Errorf("rateLimit must be non-negative")
if len(c.SessionEncryptionKey) < MinSessionEncryptionKeyLength {
return fmt.Errorf("sessionEncryptionKey must be at least %d characters long", MinSessionEncryptionKeyLength)
}
// Validate log level
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
return fmt.Errorf("logLevel must be one of: debug, info, error")
}
// Validate excluded URLs
for _, url := range c.ExcludedURLs {
if !strings.HasPrefix(url, "/") {
return fmt.Errorf("excluded URL must start with /: %s", url)
}
if strings.Contains(url, "..") {
return fmt.Errorf("excluded URL must not contain path traversal: %s", url)
}
if strings.Contains(url, "*") {
return fmt.Errorf("excluded URL must not contain wildcards: %s", url)
}
}
// Validate revocation URL if set
if c.RevocationURL != "" && !isValidSecureURL(c.RevocationURL) {
return fmt.Errorf("revocationURL must be a valid HTTPS URL")
}
// Validate end session URL if set
if c.OIDCEndSessionURL != "" && !isValidSecureURL(c.OIDCEndSessionURL) {
return fmt.Errorf("oidcEndSessionURL must be a valid HTTPS URL")
}
// Validate post-logout redirect URI if set
if c.PostLogoutRedirectURI != "" && c.PostLogoutRedirectURI != "/" {
if !isValidSecureURL(c.PostLogoutRedirectURI) && !strings.HasPrefix(c.PostLogoutRedirectURI, "/") {
return fmt.Errorf("postLogoutRedirectURI must be either a valid HTTPS URL or start with /")
}
}
// Validate rate limit
if c.RateLimit < MinRateLimit {
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
}
return nil
}
// isValidURL checks if the provided string is a valid URL
func isValidURL(s string) bool {
// isValidSecureURL checks if the provided string is a valid HTTPS URL
func isValidSecureURL(s string) bool {
u, err := url.Parse(s)
return err == nil && u.Scheme != "" && u.Host != ""
return err == nil && u.Scheme == "https" && u.Host != ""
}
// isValidLogLevel checks if the provided log level is valid
@@ -179,6 +221,7 @@ type Logger struct {
// - "debug": Outputs all messages (debug, info, error)
// - "info": Outputs info and error messages
// - "error": Outputs only error messages
//
// Error messages are always written to stderr, while info and debug
// messages are written to stdout when enabled.
func NewLogger(logLevel string) *Logger {
@@ -187,7 +230,7 @@ func NewLogger(logLevel string) *Logger {
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logError.SetOutput(os.Stderr)
if logLevel == "debug" || logLevel == "info" {
logInfo.SetOutput(os.Stdout)
}
+49 -14
View File
@@ -23,13 +23,18 @@ func TestCreateConfig(t *testing.T) {
}
// Check default log level
if config.LogLevel != "info" {
t.Errorf("Expected default log level 'info', got '%s'", config.LogLevel)
if config.LogLevel != DefaultLogLevel {
t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel)
}
// Check default rate limit
if config.RateLimit != 100 {
t.Errorf("Expected default rate limit 100, got %d", config.RateLimit)
if config.RateLimit != DefaultRateLimit {
t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit)
}
// Check ForceHTTPS default
if !config.ForceHTTPS {
t.Error("Expected ForceHTTPS to be true by default")
}
})
@@ -38,6 +43,7 @@ func TestCreateConfig(t *testing.T) {
config.Scopes = []string{"custom_scope"}
config.LogLevel = "debug"
config.RateLimit = 50
config.ForceHTTPS = false
// Verify custom values are not overwritten
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
@@ -49,6 +55,9 @@ func TestCreateConfig(t *testing.T) {
if config.RateLimit != 50 {
t.Error("Custom rate limit was overwritten")
}
if config.ForceHTTPS {
t.Error("Custom ForceHTTPS value was overwritten")
}
})
}
@@ -98,15 +107,15 @@ func TestConfigValidate(t *testing.T) {
expectedError: "sessionEncryptionKey is required",
},
{
name: "Invalid ProviderURL",
name: "Non-HTTPS ProviderURL",
config: &Config{
ProviderURL: "not-a-url",
ProviderURL: "http://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "providerURL must be a valid URL",
expectedError: "providerURL must be a valid HTTPS URL",
},
{
name: "Invalid CallbackURL",
@@ -131,16 +140,16 @@ func TestConfigValidate(t *testing.T) {
expectedError: "sessionEncryptionKey must be at least 32 characters long",
},
{
name: "Negative RateLimit",
name: "Low RateLimit",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RateLimit: -1,
RateLimit: 5,
},
expectedError: "rateLimit must be non-negative",
expectedError: "rateLimit must be at least 10",
},
{
name: "Invalid LogLevel",
@@ -154,6 +163,30 @@ func TestConfigValidate(t *testing.T) {
},
expectedError: "logLevel must be one of: debug, info, error",
},
{
name: "Non-HTTPS RevocationURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RevocationURL: "http://revoke.com",
},
expectedError: "revocationURL must be a valid HTTPS URL",
},
{
name: "Non-HTTPS OIDCEndSessionURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
OIDCEndSessionURL: "http://endsession.com",
},
expectedError: "oidcEndSessionURL must be a valid HTTPS URL",
},
{
name: "Valid Config",
config: &Config{
@@ -164,6 +197,8 @@ func TestConfigValidate(t *testing.T) {
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "debug",
RateLimit: 100,
RevocationURL: "https://revoke.com",
OIDCEndSessionURL: "https://endsession.com",
},
expectedError: "",
},
@@ -192,9 +227,9 @@ func TestLogger(t *testing.T) {
var debugBuf, infoBuf, errorBuf bytes.Buffer
tests := []struct {
name string
logLevel string
testFunc func(*Logger)
name string
logLevel string
testFunc func(*Logger)
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
}{
{
@@ -289,7 +324,7 @@ func TestLogger(t *testing.T) {
// Create logger with test buffers
logger := NewLogger(tc.logLevel)
logger.logError.SetOutput(&errorBuf)
if tc.logLevel == "debug" || tc.logLevel == "info" {
logger.logInfo.SetOutput(&infoBuf)
}