Compare commits

...

32 Commits

Author SHA1 Message Date
lukaszraczylo fdb8e3233e Testing (could be unstable) additional headers.
This adds additional headers to control the access origin and control allow headers.
2025-02-06 23:46:08 +00:00
lukaszraczylo 33c71fd6fe Enhance test suite. 2025-02-06 23:38:22 +00:00
lukaszraczylo 241cb1c209 Deal with the memory growth issue.
* TokenBlacklist limit is set to 1000
* Increased token cleanup frequency
2025-02-06 23:34:05 +00:00
lukaszraczylo 09daa1025c Follow multiple redirects during the OIDC flow. 2025-02-06 23:31:13 +00:00
lukaszraczylo c09e7a9228 Add additional test cases to cover it. 2025-02-06 21:50:35 +00:00
lukaszraczylo e5da5d4fe9 Fix redirection to the provider when session expires 2025-02-06 21:48:56 +00:00
lukaszraczylo 31db701dda Trigger build and release. 2025-02-05 19:04:44 +00:00
lukaszraczylo 16481afd36 Add todo: Improve test coverage. 2025-02-01 12:20:01 +00:00
lukaszraczylo 751933ffa0 Multiple improvements.
* Add todo list.

* fixup! Add todo list.

* fixup! fixup! Add todo list.

* fixup! fixup! fixup! Add todo list.

* Improve the session handling and cache.

* Fix an issue where expired session can cause infinite redirect loop

* fixup! Fix an issue where expired session can cause infinite redirect loop

* Add semver setup for automatic releases.

* fixup! Add semver setup for automatic releases.

* fixup! fixup! Add semver setup for automatic releases.

* fixup! fixup! fixup! Add semver setup for automatic releases.
2025-02-01 12:16:50 +00:00
lukaszraczylo e74153b107 Merge pull request #28 from lukaszraczylo/additional-improvements
additional improvements
2025-01-21 19:34:01 +00:00
lukaszraczylo 025107fe3e Well, release it finally. 2025-01-21 19:31:51 +00:00
lukaszraczylo dfb9c0771e Fix session handling and the redirection to the original URL incl. get parameters 2025-01-21 17:49:54 +00:00
lukaszraczylo 1107df40e7 Merge pull request #26 from lukaszraczylo/additional-improvements
Cleanup old cookies properly.
2025-01-21 17:34:16 +00:00
lukaszraczylo bf294569eb Cleanup old cookies properly. 2025-01-21 17:09:48 +00:00
lukaszraczylo 482c346840 Merge pull request #24 from lukaszraczylo/additional-improvements
additional improvements
2025-01-21 00:19:49 +00:00
lukaszraczylo a462e44896 Fix remaining issues with session handling and add additional tests. 2025-01-21 00:18:10 +00:00
lukaszraczylo 5eff0dc866 Clean up old cookies. 2025-01-21 00:03:13 +00:00
lukaszraczylo dfc534a400 Merge pull request #23 from lukaszraczylo/additional-improvements
Add useful defaults allowing traefik hub to pass.
2025-01-20 23:57:51 +00:00
lukaszraczylo 061c12d0a3 Add useful defaults allowing traefik hub to pass. 2025-01-20 23:55:58 +00:00
lukaszraczylo 4c4fff3613 Merge pull request #22 from lukaszraczylo/additional-improvements
Quite important fix
2025-01-20 23:50:35 +00:00
lukaszraczylo 0dcb44c187 Quite important fix
When user session expires, reauthentication fails as CSRF token disappears.
This commit fixes the issue by initiating new authentication flow.
2025-01-20 23:48:31 +00:00
lukaszraczylo cbe773d96a Merge pull request #20 from lukaszraczylo/additional-improvements
Provide default session encryption key if not specified.
2025-01-18 11:00:07 +00:00
lukaszraczylo 40254888d7 Provide default session encryption key if not specified. 2025-01-18 10:54:30 +00:00
lukaszraczylo ef41870c81 Merge pull request #18 from lukaszraczylo/additional-improvements
additional improvements
2025-01-18 02:28:29 +00:00
lukaszraczylo 081c32925a fixup! Security improvements have been implemented and verified across four main areas: 2025-01-14 11:47:49 +00:00
lukaszraczylo 17dea67229 Security improvements have been implemented and verified across four main areas:
JWT Token Security:
Protected against algorithm switching attacks by validating and whitelisting algorithms (RS256, RS384, RS512, PS256, PS384, PS512, ES256, ES384, ES512)
Added 2-minute clock skew tolerance for time-based validations
Added "not before" (nbf) claim validation with clock skew tolerance
Required JWT ID (jti) claim to prevent replay attacks
Added strict algorithm validation to prevent downgrade attacks
Session Management Security:
Implemented cryptographically secure random cookie names to prevent targeting
Added automatic session ID rotation after successful login to prevent session fixation
Enforced 24-hour absolute session timeout
Added strict encryption key length validation (minimum 32 bytes)
Added comprehensive session validation including timeout checks
Implemented session pooling for secure resource management
Added secure session cleanup on expiration
Configuration and URL Security:
Enforced HTTPS for all provider URLs and external endpoints
Added minimum rate limit (10 req/sec) to prevent DOS attacks
Added strict validation for excluded URLs:
Must start with "/"
No path traversal (..)
No wildcards (*)
Made ForceHTTPS true by default for secure cookies
Added validation for secure redirect URIs
Added validation for all OIDC endpoints (must be HTTPS)
Added secure defaults in configuration
Test Coverage:
Added comprehensive test cases verifying all security validations
Added test cases for HTTPS enforcement on all endpoints
Added test cases for minimum rate limits
Added test cases for secure session management
Added test cases for token validation with clock skew
Added test cases for secure configuration defaults
All security improvements have been verified through passing test cases, protecting against:

Session fixation attacks
Token replay attacks
Algorithm switching attacks
Path traversal attacks
Session hijacking
Timing attacks
DOS attacks
Man-in-the-middle attacks through enforced HTTPS
2025-01-14 11:33:48 +00:00
lukaszraczylo 8512ad6d68 Revert "Update vendored modules."
This reverts commit 5aa838c669.
2025-01-07 13:19:41 +00:00
lukaszraczylo 5aa838c669 Update vendored modules. 2025-01-06 13:10:13 +00:00
lukaszraczylo 6f359e5ef1 Add tests for the compression of tokens. 2025-01-06 13:00:28 +00:00
lukaszraczylo bd18d6041c Implement cookie compression, decrease memory footprint, reduce allocations 2025-01-06 12:54:48 +00:00
lukaszraczylo 74c620ad51 HTTP Client Optimization:
Reduced connection timeouts from 30s to 15s
Decreased idle connection limits from 100 to 30
Lowered keepalive duration from 90s to 30s
Added MaxConnsPerHost limit of 50 to prevent connection flooding
Optimized TLS handshake timeout to 5s

Cache System Optimization:
Implemented size-limited LRU cache with max 1000 items
Added efficient eviction of least recently used items
Improved cleanup process with batch operations
Reduced lock contention by splitting read/write operations
Optimized memory usage with access tracking
Added immediate cleanup of expired items during access

Connection Management:
Enabled connection reuse with keepalives
Reduced connection pool size to conserve memory
Implemented more aggressive connection timeout
Added connection limits per host to prevent resource exhaustion
2025-01-06 12:48:33 +00:00
lukaszraczylo 7e3dc46b6e Improve initial fetch of the provider metadata until successful. 2025-01-06 12:19:11 +00:00
15 changed files with 1715 additions and 396 deletions
+1 -1
View File
@@ -22,7 +22,7 @@ testData:
- raczylo.com
allowedRolesAndGroups:
- guest-endpoints
sessionEncryptionKey: potato-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
forceHTTPS: false
logLevel: debug # debug, info, warn, error
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
+2
View File
@@ -19,6 +19,8 @@ Middleware currently supports following scenarios:
#### How to configure...
* `sessionEncryptionKey` should be at least 32 bytes long.
##### Keeping secrets secret
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
+5
View File
@@ -0,0 +1,5 @@
### TODO / wishlist
- [] Improve test coverage
- [x] Improve caching mechanism
- [x] Add automatic release and semver generation
+99 -41
View File
@@ -1,95 +1,153 @@
package traefikoidc
import (
"container/list"
"sync"
"time"
)
// CacheItem represents an item stored in the cache with its associated metadata.
type CacheItem struct {
// Value is the cached data of any type
// Value is the cached data of any type.
Value interface{}
// ExpiresAt is the timestamp when this item should be considered expired
// and removed from the cache during cleanup operations
// ExpiresAt is the timestamp when this item should be considered expired.
ExpiresAt time.Time
}
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
type Cache struct {
// items stores the cached data with string keys
items map[string]CacheItem
// mutex protects concurrent access to the items map
// Use RLock/RUnlock for reads and Lock/Unlock for writes
mutex sync.RWMutex
// lruEntry represents an entry in the LRU list.
type lruEntry struct {
key string
}
// NewCache creates a new empty cache instance.
// The cache is immediately ready for use and is thread-safe.
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
type Cache struct {
// items stores the cached data with string keys.
items map[string]CacheItem
// order maintains the usage order; most recently used items are at the back.
order *list.List
// elems maps keys to their corresponding list elements for O(1) access.
elems map[string]*list.Element
// mutex protects concurrent access to the cache.
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache.
maxSize int
}
// DefaultMaxSize is the default maximum number of items in the cache.
const DefaultMaxSize = 1000
// NewCache creates a new empty cache instance that is ready for use.
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
items: make(map[string]CacheItem, DefaultMaxSize),
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
}
}
// Set adds or updates an item in the cache with the specified expiration duration.
// Parameters:
// - key: Unique identifier for the cached item
// - value: The data to cache (can be of any type)
// - expiration: How long the item should remain in the cache
// Thread-safe: Uses write locking to ensure safe concurrent access.
// It moves the item to the most recently used position.
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
expTime := now.Add(expiration)
// Update existing item.
if _, exists := c.items[key]; exists {
c.items[key] = CacheItem{
Value: value,
ExpiresAt: expTime,
}
if elem, ok := c.elems[key]; ok {
c.order.MoveToBack(elem)
}
return
}
// Evict oldest item if cache is full.
if len(c.items) >= c.maxSize {
c.evictOldest()
}
// Add new item.
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: expTime,
}
elem := c.order.PushBack(lruEntry{key: key})
c.elems[key] = elem
}
// Get retrieves an item from the cache if it exists and hasn't expired.
// Parameters:
// - key: The identifier of the item to retrieve
// Returns:
// - value: The cached data (nil if not found or expired)
// - found: true if the item was found and is valid, false otherwise
// Thread-safe: Uses read locking to ensure safe concurrent access.
// Moving the accessed item to the most recently used position.
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
item, found := c.items[key]
if !found {
c.mutex.Lock()
defer c.mutex.Unlock()
item, exists := c.items[key]
if !exists {
return nil, false
}
// Check for expiration.
if time.Now().After(item.ExpiresAt) {
delete(c.items, key)
c.removeItem(key)
return nil, false
}
// Move item to the back (most recently used).
if elem, ok := c.elems[key]; ok {
c.order.MoveToBack(elem)
}
return item.Value, true
}
// Delete removes an item from the cache if it exists.
// If the item doesn't exist, this operation is a no-op.
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Delete removes an item from the cache.
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
c.removeItem(key)
}
// Cleanup removes all expired items from the cache.
// This should be called periodically to prevent memory leaks from
// expired items that haven't been accessed (and thus not removed during Get operations).
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Cleanup removes all expired items from the cache. This should be called periodically
// to prevent memory bloat from expired entries.
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
for key, item := range c.items {
if now.After(item.ExpiresAt) {
delete(c.items, key)
c.removeItem(key)
}
}
}
// evictOldest removes the least recently used item from the cache.
func (c *Cache) evictOldest() {
elem := c.order.Front()
if elem != nil {
entry := elem.Value.(lruEntry)
c.removeItem(entry.key)
}
}
// removeItem removes an item from both the cache and the LRU tracking structures.
func (c *Cache) removeItem(key string) {
delete(c.items, key)
if elem, ok := c.elems[key]; ok {
c.order.Remove(elem)
delete(c.elems, key)
}
}
+55 -22
View File
@@ -8,31 +8,13 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
)
// newSessionOptions creates secure session cookie options.
// Parameters:
// - isSecure: Whether to set the Secure flag on cookies
// Returns session options configured for security with:
// - HttpOnly flag to prevent JavaScript access
// - SameSite=Lax for CSRF protection
// - Appropriate timeout and path settings
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce creates a cryptographically secure random nonce
// for use in the OIDC authentication flow. The nonce is used to
// prevent replay attacks by ensuring the token received matches
@@ -87,13 +69,28 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
data.Set("refresh_token", codeOrToken)
}
// Create a cookie jar for this request to handle redirects with cookies
jar, _ := cookiejar.New(nil)
client := &http.Client{
Transport: t.httpClient.Transport,
Timeout: t.httpClient.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
Jar: jar,
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := t.httpClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
@@ -128,11 +125,19 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
// Clear authentication data but preserve CSRF state
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
// Save the cleared session state
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -287,12 +292,16 @@ type TokenBlacklist struct {
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
// maxSize is the maximum number of tokens in the blacklist
maxSize int
}
// NewTokenBlacklist creates a new TokenBlacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
maxSize: 1000, // Limit the size to prevent unbounded growth
}
}
@@ -300,6 +309,30 @@ func NewTokenBlacklist() *TokenBlacklist {
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
// Clean up expired tokens if we're at capacity
if len(tb.blacklist) >= tb.maxSize {
now := time.Now()
for token, exp := range tb.blacklist {
if now.After(exp) {
delete(tb.blacklist, token)
}
}
// If still at capacity after cleanup, remove oldest token
if len(tb.blacklist) >= tb.maxSize {
var oldestToken string
var oldestTime time.Time
first := true
for token, exp := range tb.blacklist {
if first || exp.Before(oldestTime) {
oldestToken = token
oldestTime = exp
first = false
}
}
delete(tb.blacklist, oldestToken)
}
}
tb.blacklist[tokenID] = expiration
}
+225
View File
@@ -0,0 +1,225 @@
package traefikoidc
import (
"fmt"
"runtime"
"testing"
"time"
)
func TestTokenBlacklistSizeLimit(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens up to maxSize
for i := 0; i < 1000; i++ {
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
}
// Verify size is at max
if len(tb.blacklist) != 1000 {
t.Errorf("Expected blacklist size to be 1000, got %d", len(tb.blacklist))
}
// Add one more token, should trigger cleanup/eviction
tb.Add("newtoken", time.Now().Add(time.Hour))
// Size should still be at max
if len(tb.blacklist) > 1000 {
t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist))
}
}
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
tb := NewTokenBlacklist()
// Add some expired tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
}
// Add some valid tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
}
// Force cleanup
tb.Cleanup()
// Only valid tokens should remain
if len(tb.blacklist) != 500 {
t.Errorf("Expected 500 valid tokens after cleanup, got %d", len(tb.blacklist))
}
// Verify only valid tokens remain
for token, expiry := range tb.blacklist {
if time.Now().After(expiry) {
t.Errorf("Found expired token after cleanup: %s", token)
}
}
}
func TestTokenBlacklistOldestEviction(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens at capacity with different expiration times
baseTime := time.Now()
oldestToken := "oldest"
// Add oldest token first
tb.Add(oldestToken, baseTime.Add(time.Hour))
// Fill up to capacity with newer tokens
for i := 0; i < 999; i++ {
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
}
// Add a new token that should evict the oldest
newToken := "newest"
tb.Add(newToken, baseTime.Add(time.Hour*3))
// Verify oldest token was evicted
if tb.IsBlacklisted(oldestToken) {
t.Error("Oldest token should have been evicted")
}
// Verify newest token is present
if !tb.IsBlacklisted(newToken) {
t.Error("Newest token should be present")
}
}
func TestTokenBlacklistMemoryUsage(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy usage
for i := 0; i < iterations; i++ {
// Add new token
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
// Periodically check blacklisted status
if i%100 == 0 {
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tb.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
}
// Verify size stayed within limits
if len(tb.blacklist) > tb.maxSize {
t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist))
}
}
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 1000
concurrency := 10
done := make(chan bool)
// Start multiple goroutines performing operations
for i := 0; i < concurrency; i++ {
go func(id int) {
for j := 0; j < iterations; j++ {
// Add tokens
token := fmt.Sprintf("token%d-%d", id, j)
tb.Add(token, time.Now().Add(time.Hour))
// Check blacklist status
tb.IsBlacklisted(token)
// Periodic cleanup
if j%100 == 0 {
tb.Cleanup()
}
}
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < concurrency; i++ {
<-done
}
// Verify size constraints were maintained
if len(tb.blacklist) > tb.maxSize {
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", len(tb.blacklist))
}
}
func TestTokenCacheMemoryUsage(t *testing.T) {
tc := NewTokenCache()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy cache usage
for i := 0; i < iterations; i++ {
claims := map[string]interface{}{
"sub": fmt.Sprintf("user%d", i),
"exp": time.Now().Add(time.Hour).Unix(),
}
// Add to cache
tc.Set(fmt.Sprintf("token%d", i), claims, time.Hour)
// Periodically retrieve
if i%100 == 0 {
tc.Get(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tc.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive cache memory growth: %d bytes", memoryGrowth)
}
// Verify cache size stayed within limits
if len(tc.cache.items) > tc.cache.maxSize {
t.Errorf("Cache exceeded max size: %d", len(tc.cache.items))
}
}
+2
View File
@@ -81,6 +81,7 @@ type JWKCacheInterface interface {
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for fetching keys
//
// Returns:
// - The JSON Web Key Set
// - An error if the keys cannot be retrieved or parsed
@@ -115,6 +116,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for the request
//
// Returns:
// - The parsed JSON Web Key Set
// - An error if the request fails or the response is invalid
+135 -10
View File
@@ -38,6 +38,7 @@ type JWT struct {
// (header, claims, signature) using base64url decoding.
// Parameters:
// - tokenString: The raw JWT token string
//
// Returns:
// - A parsed JWT struct
// - An error if the token format is invalid or parsing fails
@@ -83,13 +84,39 @@ func parseJWT(tokenString string) (*JWT, error) {
// It checks:
// - issuer (iss) matches the expected issuer URL
// - audience (aud) includes the client ID
// - expiration time (exp) is in the future
// - issued at time (iat) is in the past
// - expiration time (exp) is in the future (with clock skew tolerance)
// - issued at time (iat) is in the past (with clock skew tolerance)
// - not before time (nbf) is in the past (with clock skew tolerance)
// - subject (sub) is present and not empty
// - algorithm matches expected value to prevent algorithm switching attacks
//
// Returns an error if any validation fails.
func (j *JWT) Verify(issuerURL, clientID string) error {
// Debug logging of validation parameters
fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID)
// Debug logging of token header
fmt.Printf("Token header: %+v\n", j.Header)
// Validate algorithm to prevent algorithm switching attacks
alg, ok := j.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing 'alg' header")
}
// List of supported algorithms - should match those in verifySignature
supportedAlgs := map[string]bool{
"RS256": true, "RS384": true, "RS512": true,
"PS256": true, "PS384": true, "PS512": true,
"ES256": true, "ES384": true, "ES512": true,
}
if !supportedAlgs[alg] {
return fmt.Errorf("unsupported algorithm: %s", alg)
}
claims := j.Claims
// Debug logging of all claims
fmt.Printf("Token claims: %+v\n", claims)
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing 'iss' claim")
@@ -122,6 +149,19 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return err
}
// Validate nbf (not before) claim if present
if nbf, ok := claims["nbf"].(float64); ok {
if err := verifyNotBefore(nbf); err != nil {
return err
}
}
// Validate jti (JWT ID) claim if present
if jti, ok := claims["jti"].(string); ok {
// Could add replay detection here if needed
_ = jti
}
sub, ok := claims["sub"].(string)
if !ok || sub == "" {
return fmt.Errorf("missing or empty 'sub' claim")
@@ -136,8 +176,13 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
// Parameters:
// - tokenAudience: The audience claim from the token
// - expectedAudience: The expected audience value
//
// Returns an error if validation fails.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
// Debug logging
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
tokenAudience, expectedAudience)
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
@@ -165,37 +210,112 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
// Parameters:
// - tokenIssuer: The issuer claim from the token
// - expectedIssuer: The expected issuer URL
//
// Returns an error if validation fails.
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
// Debug logging
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
tokenIssuer, expectedIssuer)
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
return fmt.Errorf("invalid issuer (token: %s, expected: %s)",
tokenIssuer, expectedIssuer)
}
return nil
}
// Clock skew tolerance for time-based validations
const clockSkewTolerance = 2 * time.Minute
// verifyExpiration checks if the token's expiration time has passed.
// The expiration time is compared against the current time.
// The expiration time is compared against the current time with clock skew tolerance.
// Parameters:
// - expiration: The expiration timestamp from the token
//
// Returns an error if the token has expired.
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
return fmt.Errorf("token has expired")
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(clockSkewTolerance)
// Debug logging
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
expirationTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that expire exactly now
if expirationTime.Equal(now) {
return nil
}
if skewedNow.After(expirationTime) {
return fmt.Errorf("token has expired (exp: %v, now: %v)",
expirationTime.UTC(), now.UTC())
}
return nil
}
// verifyIssuedAt validates the token's issued-at time.
// Ensures the token wasn't issued in the future, which could
// indicate clock skew or a malicious token.
// Ensures the token wasn't issued in the future, accounting for clock skew.
// Parameters:
// - issuedAt: The issued-at timestamp from the token
//
// Returns an error if the token was issued in the future.
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
issuedAtTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens issued in the same second as current time
if issuedAtTime.Equal(now) {
return nil
}
if skewedNow.Before(issuedAtTime) {
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
issuedAtTime.UTC(), now.UTC())
}
return nil
}
// verifyNotBefore validates the token's not-before time if present.
// Ensures the token is not used before its valid time period, accounting for clock skew.
// Parameters:
// - notBefore: The not-before timestamp from the token
//
// Returns an error if the token is not yet valid.
func verifyNotBefore(notBefore float64) error {
notBeforeTime := time.Unix(int64(notBefore), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
notBeforeTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that become valid exactly now
if notBeforeTime.Equal(now) {
return nil
}
if skewedNow.Before(notBeforeTime) {
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
notBeforeTime.UTC(), now.UTC())
}
return nil
}
@@ -205,12 +325,17 @@ func verifyIssuedAt(issuedAt float64) error {
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
// - RSA-PSS: PS256, PS384, PS512
// - ECDSA: ES256, ES384, ES512
//
// Parameters:
// - tokenString: The complete JWT token string
// - publicKeyPEM: The PEM-encoded public key for verification
// - alg: The signature algorithm identifier
//
// Returns an error if signature verification fails.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Debug logging
fmt.Printf("Verifying signature with algorithm: %s\n", alg)
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
+182 -42
View File
@@ -12,6 +12,8 @@ import (
"strings"
"time"
"runtime"
"github.com/google/uuid"
"golang.org/x/time/rate"
)
@@ -60,7 +62,6 @@ type TraefikOidc struct {
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
@@ -80,8 +81,6 @@ var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
var newTicker = time.NewTicker
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
@@ -175,22 +174,48 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
if config == nil {
config = CreateConfig()
}
// Generate default session encryption key if not provided
if config.SessionEncryptionKey == "" {
// Generate a fixed key for Traefik Hub testing
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}
// Initialize logger
logger := NewLogger(config.LogLevel)
// Ensure key meets minimum length requirement
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
if runtime.Compiler == "yaegi" {
// Set default encryption key for Yaegi (Traefik Plugin Analyzer)
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
logger.Infof("Session encryption key is too short; using default key for analyzer")
} else {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
}
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 10 * time.Second,
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
ExpectContinueTimeout: 0,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections
}
var httpClient *http.Client
@@ -198,8 +223,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
httpClient = config.HTTPClient
} else {
httpClient = &http.Client{
Timeout: time.Second * 30,
Timeout: time.Second * 15, // Reduced timeout
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
}
}
@@ -228,14 +260,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
}
// Assign the initialized logger
t.logger = logger
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
@@ -258,21 +291,42 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Debug("Starting provider metadata discovery")
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", err)
} else if metadata != nil {
t.logger.Debug("Using provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
}
close(t.initComplete)
// Keep retrying until successful
backoff := time.Second
maxBackoff := 30 * time.Second
for {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
time.Sleep(backoff)
// Exponential backoff with max
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Only close channel on success
close(t.initComplete)
return
}
t.logger.Error("Received nil metadata, retrying")
time.Sleep(backoff)
}
}
// discoverProviderMetadata fetches the OIDC provider metadata
@@ -342,14 +396,18 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
if t.issuerURL == "" {
t.logger.Debug("OIDC middleware not yet initialized")
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
t.logger.Error("OIDC provider metadata initialization failed")
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
return
}
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
return
case <-time.After(30 * time.Second):
t.logger.Error("Timeout waiting for OIDC initialization")
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
return
}
// Check if URL is excluded
@@ -362,7 +420,18 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
// Obtain a new session and clear any residual session cookies
session, _ = t.sessionManager.GetSession(req)
session.Clear(req, rw)
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Initiate authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -447,7 +516,35 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Set OIDC-specific headers
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetAccessToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
// Set security headers
rw.Header().Set("X-Frame-Options", "DENY")
rw.Header().Set("X-Content-Type-Options", "nosniff")
rw.Header().Set("X-XSS-Protection", "1; mode=block")
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Set CORS headers
origin := req.Header.Get("Origin")
if origin != "" {
rw.Header().Set("Access-Control-Allow-Origin", origin)
rw.Header().Set("Access-Control-Allow-Credentials", "true")
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
// Handle preflight requests
if req.Method == "OPTIONS" {
rw.WriteHeader(http.StatusOK)
return
}
}
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -466,9 +563,6 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
return "https"
}
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
@@ -537,17 +631,20 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
csrfToken := uuid.New().String()
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Set session values
// Clear any existing session data to avoid stale state causing redirect loops
session.Clear(req, rw)
// Set new session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.Path)
session.SetIncomingPath(req.URL.RequestURI())
// Save the session
if err := session.Save(req, rw); err != nil {
@@ -556,7 +653,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Build and redirect to auth URL
// Build and redirect to authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -577,17 +674,60 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " "))
}
// Ensure authURL is absolute
if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
t.authURL,
params.Encode())
}
}
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := newTicker(1 * time.Minute)
ctx, cancel := context.WithCancel(context.Background())
ticker := time.NewTicker(30 * time.Second) // Increased frequency to prevent memory buildup
go func() {
for range ticker.C {
t.logger.Debug("Cleaning up token cache")
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
defer ticker.Stop()
defer cancel()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
t.logger.Debug("Starting token cleanup cycle")
// Run cleanup in a separate goroutine with timeout
cleanupCtx, cleanupCancel := context.WithTimeout(ctx, 10*time.Second)
done := make(chan struct{})
go func() {
defer close(done)
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
}()
// Wait for cleanup to complete or timeout
select {
case <-cleanupCtx.Done():
if cleanupCtx.Err() == context.DeadlineExceeded {
t.logger.Error("Token cleanup cycle timed out")
}
case <-done:
t.logger.Debug("Token cleanup cycle completed successfully")
}
cleanupCancel()
}
}
}()
}
+270 -22
View File
@@ -13,6 +13,7 @@ import (
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@@ -67,21 +68,29 @@ func (ts *TestSuite) Setup() {
}
// Create a test JWT token signed with the RSA private key
// Create timestamps with proper clock skew
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": generateRandomString(16),
})
if err != nil {
ts.t.Fatalf("Failed to create test JWT: %v", err)
}
logger := NewLogger("info")
ts.sessionManager = NewSessionManager("test-secret-key", false, logger)
ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
@@ -611,7 +620,7 @@ func TestHandleCallback(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Create a new instance for each test to avoid state carryover
tOidc := &TraefikOidc{
@@ -916,7 +925,7 @@ func TestHandleLogout(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
tOidc := &TraefikOidc{
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
@@ -1205,7 +1214,7 @@ func TestHandleExpiredToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
tOidc := &TraefikOidc{
sessionManager: sessionManager,
@@ -1362,23 +1371,23 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
// Create base config
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
}
// Create multiple middleware instances
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
var middlewares []*TraefikOidc
for _, route := range routes {
config.CallbackURL = route + "/callback"
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test")
if err != nil {
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
}
@@ -1440,6 +1449,12 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create consistent timestamps for all test cases
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
tests := []struct {
name string
allowedRolesAndGroups map[string]struct{}
@@ -1456,11 +1471,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"admin", "user"},
"groups": []interface{}{"group1"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1480,11 +1497,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"allowed-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1505,11 +1524,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1523,11 +1544,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1545,9 +1568,11 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1622,6 +1647,112 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
}
// Helper function to compare string slices
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
setupServer func() *httptest.Server
expectError bool
errorContains string
}{
{
name: "Successful token exchange with redirects",
setupServer: func() *httptest.Server {
redirectCount := 0
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 3 {
// Set a cookie before redirecting
http.SetCookie(w, &http.Cookie{
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
Value: "test-value",
})
redirectCount++
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
return
}
// Verify all cookies from previous redirects are present
cookies := r.Cookies()
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
for i := 0; i < 3; i++ {
found := false
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
for _, cookie := range cookies {
if cookie.Name == expectedName {
found = true
break
}
}
if !found {
t.Errorf("Cookie %s not found", expectedName)
}
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
expectError: false,
},
{
name: "Too many redirects",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
}))
},
expectError: true,
errorContains: "stopped after 50 redirects",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupServer()
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
// Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
}
})
}
}
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
@@ -1633,3 +1764,120 @@ func stringSliceEqual(a, b []string) bool {
}
return true
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authURL string
issuerURL string
redirectURL string
state string
nonce string
expectedPrefix string
}{
{
name: "Absolute Auth URL",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
},
{
name: "Relative Auth URL",
authURL: "/oidc/auth",
issuerURL: "https://logto.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://logto.example.com/oidc/auth?",
},
{
name: "Relative Auth URL with Different Issuer",
authURL: "/sign-in",
issuerURL: "https://auth.example.com:8443",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure the test instance
tOidc := ts.tOidc
tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL
// Call buildAuthURL
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
// Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) {
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
}
// Parse the resulting URL to verify query parameters
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse resulting URL: %v", err)
}
query := parsedURL.Query()
expectedParams := map[string]string{
"client_id": tOidc.clientID,
"response_type": "code",
"redirect_uri": tc.redirectURL,
"state": tc.state,
"nonce": tc.nonce,
}
for key, expected := range expectedParams {
if got := query.Get(key); got != expected {
t.Errorf("Expected %s=%q, got %q", key, expected, got)
}
}
// Verify scopes are present and correct
if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ")
if got := query.Get("scope"); got != expectedScopes {
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
}
}
})
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create a request with query parameters
req := httptest.NewRequest("GET", "/protected/resource?param1=value1&param2=value2", nil)
rw := httptest.NewRecorder()
// Get session
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Call defaultInitiateAuthentication
redirectURL := "http://example.com/callback"
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
// Verify that the incoming path includes query parameters
incomingPath := session.GetIncomingPath()
expectedPath := "/protected/resource?param1=value1&param2=value2"
if incomingPath != expectedPath {
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
}
}
+10
View File
@@ -0,0 +1,10 @@
version: 1
force:
existing: true
wording:
patch:
- patch-release
minor:
- minor-release
major:
- breaking
+261 -121
View File
@@ -1,27 +1,40 @@
package traefikoidc
import (
"bytes"
"compress/gzip"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
)
// generateSecureRandomString creates a cryptographically secure random string of specified length.
// It returns the generated string or an error if random generation fails.
func generateSecureRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
return hex.EncodeToString(bytes), nil
}
// Cookie names and configuration constants used for session management
const (
// mainCookieName is the name of the main session cookie that stores authentication state
// and basic user information like email and CSRF tokens
mainCookieName = "_raczylo_oidc"
// accessTokenCookie is the name of the cookie that stores the OIDC access token
// This may be split into multiple cookies if the token is large
accessTokenCookie = "_raczylo_oidc_access"
// refreshTokenCookie is the name of the cookie that stores the OIDC refresh token
// This may be split into multiple cookies if the token is large
refreshTokenCookie = "_raczylo_oidc_refresh"
// Using fixed prefixes for consistent cookie naming across restarts
mainCookieName = "_oidc_raczylo_m"
accessTokenCookie = "_oidc_raczylo_a"
refreshTokenCookie = "_oidc_raczylo_r"
)
const (
// maxCookieSize is the maximum size for each cookie chunk.
// This value is calculated to ensure the final cookie size stays within browser limits:
// 1. Browser cookie size limit is typically 4096 bytes
@@ -34,20 +47,64 @@ const (
// - Solving for x: x ≤ 3044
// 4. We use 2000 as a conservative limit to account for cookie metadata
maxCookieSize = 2000
// absoluteSessionTimeout defines the maximum lifetime of a session
// regardless of activity (24 hours)
absoluteSessionTimeout = 24 * time.Hour
// minEncryptionKeyLength defines the minimum length for the encryption key
minEncryptionKeyLength = 32
)
// compressToken compresses a token using gzip and base64 encodes it.
func compressToken(token string) string {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
if _, err := gz.Write([]byte(token)); err != nil {
return token // fallback to uncompressed on error
}
if err := gz.Close(); err != nil {
return token
}
return base64.StdEncoding.EncodeToString(b.Bytes())
}
// decompressToken decompresses a base64 encoded gzipped token.
func decompressToken(compressed string) string {
data, err := base64.StdEncoding.DecodeString(compressed)
if err != nil {
return compressed // return as-is if not base64
}
gz, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return compressed
}
defer gz.Close()
decompressed, err := io.ReadAll(gz)
if err != nil {
return compressed
}
return string(decompressed)
}
// SessionManager handles the management of multiple session cookies for OIDC authentication.
// It provides functionality for storing and retrieving authentication state, tokens,
// and other session-related data across multiple cookies to handle large tokens.
// and other session-related data across multiple cookies.
type SessionManager struct {
// store is the underlying session store for cookie management
// store is the underlying session store for cookie management.
store sessions.Store
// forceHTTPS enforces secure cookie attributes regardless of request scheme
// forceHTTPS enforces secure cookie attributes regardless of request scheme.
forceHTTPS bool
// logger provides structured logging capabilities
// logger provides structured logging capabilities.
logger *Logger
// sessionPool is a sync.Pool for reusing SessionData objects.
sessionPool sync.Pool
}
// NewSessionManager creates a new session manager with the specified configuration.
@@ -55,18 +112,36 @@ type SessionManager struct {
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
// - logger: Logger instance for recording session-related events
// The manager handles session creation, storage, and cookie security settings.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
return &SessionManager{
//
// Returns an error if the encryption key does not meet minimum length requirements.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*SessionManager, error) {
// Validate encryption key length.
if len(encryptionKey) < minEncryptionKeyLength {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
sm := &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
// Initialize session pool.
sm.sessionPool.New = func() interface{} {
return &SessionData{
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
}
}
return sm, nil
}
// getSessionOptions returns secure session options configured for the current request.
// Parameters:
// - isSecure: Whether the current request is using HTTPS
// - isSecure: Whether the current request is using HTTPS.
//
// The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript)
// - Secure when using HTTPS or when forceHTTPS is enabled
@@ -77,7 +152,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
MaxAge: int(absoluteSessionTimeout.Seconds()),
Path: "/",
}
}
@@ -87,33 +162,48 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
// and combines them into a single SessionData structure for easy access.
// Returns an error if any session component cannot be loaded.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
mainSession, err := sm.store.Get(r, mainCookieName)
// Get session from pool.
sessionData := sm.sessionPool.Get().(*SessionData)
sessionData.request = r
var err error
sessionData.mainSession, err = sm.store.Get(r, mainCookieName)
if err != nil {
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("failed to get main session: %w", err)
}
accessSession, err := sm.store.Get(r, accessTokenCookie)
// Check for absolute session timeout.
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
sessionData.Clear(r, nil)
return nil, fmt.Errorf("session expired")
}
}
sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)
if err != nil {
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("failed to get access token session: %w", err)
}
refreshSession, err := sm.store.Get(r, refreshTokenCookie)
sessionData.refreshSession, err = sm.store.Get(r, refreshTokenCookie)
if err != nil {
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
sessionData := &SessionData{
manager: sm,
request: r,
mainSession: mainSession,
accessSession: accessSession,
refreshSession: refreshSession,
// Clear and reuse chunk maps.
for k := range sessionData.accessTokenChunks {
delete(sessionData.accessTokenChunks, k)
}
for k := range sessionData.refreshTokenChunks {
delete(sessionData.refreshTokenChunks, k)
}
// Retrieve chunked access token sessions
sessionData.accessTokenChunks = sm.getTokenChunkSessions(r, accessTokenCookie)
// Retrieve chunked refresh token sessions
sessionData.refreshTokenChunks = sm.getTokenChunkSessions(r, refreshTokenCookie)
// Retrieve chunked token sessions.
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
return sessionData, nil
}
@@ -122,20 +212,16 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
// Parameters:
// - r: The HTTP request
// - baseName: The base name for the token's session cookies
// Returns a map of chunk index to session, used for handling large tokens
// that exceed single cookie size limits.
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string) map[int]*sessions.Session {
chunks := make(map[int]*sessions.Session)
// - chunks: Map to store the chunks in
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
// No more sessions
break
}
chunks[i] = session
}
return chunks
}
// SessionData holds all session information for an authenticated user.
@@ -143,27 +229,27 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
// and potentially large access and refresh tokens that may need to be
// split across multiple cookies due to browser size limitations.
type SessionData struct {
// manager is the SessionManager that created this SessionData
// manager is the SessionManager that created this SessionData.
manager *SessionManager
// request is the current HTTP request associated with this session
// request is the current HTTP request associated with this session.
request *http.Request
// mainSession stores authentication state and basic user info
// mainSession stores authentication state and basic user info.
mainSession *sessions.Session
// accessSession stores the primary access token cookie
// accessSession stores the primary access token cookie.
accessSession *sessions.Session
// refreshSession stores the primary refresh token cookie
// refreshSession stores the primary refresh token cookie.
refreshSession *sessions.Session
// accessTokenChunks stores additional chunks of the access token
// when it exceeds the maximum cookie size
// when it exceeds the maximum cookie size.
accessTokenChunks map[int]*sessions.Session
// refreshTokenChunks stores additional chunks of the refresh token
// when it exceeds the maximum cookie size
// when it exceeds the maximum cookie size.
refreshTokenChunks map[int]*sessions.Session
}
@@ -174,28 +260,28 @@ type SessionData struct {
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
// Set options for all sessions.
options := sd.manager.getSessionOptions(isSecure)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
// Save main session
// Save main session.
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
// Save access token session
// Save access token session.
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
// Save refresh token session
// Save refresh token session.
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
// Save access token chunks
// Save access token chunks.
for _, session := range sd.accessTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
@@ -203,7 +289,7 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
}
}
// Save refresh token chunks
// Save refresh token chunks.
for _, session := range sd.refreshTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
@@ -215,10 +301,8 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
}
// Clear removes all session data by expiring all cookies and clearing their values.
// This is typically used during logout to ensure all session data is properly cleaned up.
// It handles both main session data and any token chunks that may exist.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
// Clear and expire all sessions.
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
@@ -233,16 +317,22 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
delete(sd.refreshSession.Values, k)
}
// Clear chunk sessions
// Clear chunk sessions.
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
return sd.Save(r, w)
var err error
if w != nil {
err = sd.Save(r, w)
}
// Return session to pool.
sd.manager.sessionPool.Put(sd)
return err
}
// clearTokenChunks removes all session chunks for a given token type.
// It expires the cookies and removes all stored values to ensure
// no token data remains after logout or token invalidation.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
@@ -253,30 +343,47 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
}
// GetAuthenticated returns whether the current session is authenticated.
// Returns true if the user has successfully completed OIDC authentication,
// false otherwise or if the authentication status cannot be determined.
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
return auth
if !auth {
return false
}
// Check session expiration.
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
if !ok {
return false
}
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
}
// SetAuthenticated updates the session's authentication status.
// This should be called after successful OIDC authentication or during logout.
func (sd *SessionData) SetAuthenticated(value bool) {
// SetAuthenticated updates the session's authentication status and rotates session ID.
// Returns an error if generating a new session ID fails.
func (sd *SessionData) SetAuthenticated(value bool) error {
if value {
id, err := generateSecureRandomString(32)
if err != nil {
return fmt.Errorf("failed to generate secure session id: %w", err)
}
sd.mainSession.ID = id
sd.mainSession.Values["created_at"] = time.Now().Unix()
}
sd.mainSession.Values["authenticated"] = value
return nil
}
// GetAccessToken retrieves the complete access token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
compressed, _ := sd.accessSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// Reassemble token from chunks
// Reassemble token from chunks.
if len(sd.accessTokenChunks) == 0 {
return ""
}
@@ -291,25 +398,35 @@ func (sd *SessionData) GetAccessToken() string {
chunks = append(chunks, chunk)
}
return strings.Join(chunks, "")
token = strings.Join(chunks, "")
compressed, _ := sd.accessSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// SetAccessToken stores the access token in the session.
// If the token exceeds maxCookieSize, it is automatically split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetAccessToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
// Expire any existing chunk cookies first.
if sd.request != nil {
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called.
}
// Clear and prepare chunks map for new token.
sd.accessTokenChunks = make(map[int]*sessions.Session)
if len(token) <= maxCookieSize {
sd.accessSession.Values["token"] = token
// Compress token.
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.accessSession.Values["token"] = compressed
sd.accessSession.Values["compressed"] = true
} else {
// Split token into chunks
// Split compressed token into chunks.
sd.accessSession.Values["token"] = ""
chunks := splitIntoChunks(token, maxCookieSize)
sd.accessSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
@@ -320,16 +437,17 @@ func (sd *SessionData) SetAccessToken(token string) {
}
// GetRefreshToken retrieves the complete refresh token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// Reassemble token from chunks
// Reassemble token from chunks.
if len(sd.refreshTokenChunks) == 0 {
return ""
}
@@ -344,25 +462,35 @@ func (sd *SessionData) GetRefreshToken() string {
chunks = append(chunks, chunk)
}
return strings.Join(chunks, "")
token = strings.Join(chunks, "")
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// SetRefreshToken stores the refresh token in the session.
// If the token exceeds maxCookieSize, it is automatically split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetRefreshToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
// Expire any existing chunk cookies first.
if sd.request != nil {
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called.
}
// Clear and prepare chunks map for new token.
sd.refreshTokenChunks = make(map[int]*sessions.Session)
if len(token) <= maxCookieSize {
sd.refreshSession.Values["token"] = token
// Compress token.
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.refreshSession.Values["token"] = compressed
sd.refreshSession.Values["compressed"] = true
} else {
// Split token into chunks
// Split compressed token into chunks.
sd.refreshSession.Values["token"] = ""
chunks := splitIntoChunks(token, maxCookieSize)
sd.refreshSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
@@ -372,12 +500,43 @@ func (sd *SessionData) SetRefreshToken(token string) {
}
}
// expireAccessTokenChunks expires any existing access token chunk cookies.
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired access token cookie: %v", err)
}
}
}
}
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired refresh token cookie: %v", err)
}
}
}
}
// splitIntoChunks splits a string into chunks of specified size.
// This is used internally to handle large tokens that exceed cookie size limits.
// Parameters:
// - s: The string to split
// - chunkSize: Maximum size of each chunk
// Returns an array of string chunks, each no larger than chunkSize.
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
@@ -393,64 +552,45 @@ func splitIntoChunks(s string, chunkSize int) []string {
}
// GetCSRF retrieves the CSRF token from the session.
// This token is used to prevent cross-site request forgery attacks
// by ensuring requests originate from the authenticated user.
// Returns an empty string if no CSRF token is found.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores a new CSRF token in the session.
// This should be called when initiating authentication to generate
// a new token for the authentication flow.
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce retrieves the nonce value from the session.
// The nonce is used to prevent replay attacks in the OIDC flow
// by ensuring the token received matches the authentication request.
// Returns an empty string if no nonce is found.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores a new nonce value in the session.
// This should be called when initiating authentication to generate
// a new nonce for the OIDC authentication flow.
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail retrieves the authenticated user's email address from the session.
// The email is typically extracted from the OIDC ID token claims.
// Returns an empty string if no email is found.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the user's email address in the session.
// This should be called after successful authentication when
// processing the OIDC ID token claims.
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath retrieves the original request path that triggered
// the authentication flow. This is used to redirect the user back
// to their intended destination after successful authentication.
// Returns an empty string if no path was stored.
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request path that triggered
// the authentication flow. This should be called before redirecting
// to the OIDC provider to remember where to send the user afterward.
// SetIncomingPath stores the original request path that triggered the authentication flow.
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
+344 -91
View File
@@ -1,129 +1,382 @@
package traefikoidc
import (
"math/rand"
"net/http/httptest"
"strings"
"testing"
)
// generateRandomString creates a random string of specified length
func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
}
return string(b)
}
// TestTokenCompression tests the token compression functionality
func TestTokenCompression(t *testing.T) {
tests := []struct {
name string
token string
wantSize int // Expected size after compression (approximate)
}{
{
name: "Short token",
token: "shorttoken",
wantSize: 50, // Base64 encoded gzip has overhead for small content
},
{
name: "Repeating content",
token: strings.Repeat("abcdef", 1000),
wantSize: 100, // Should compress well due to repetition
},
{
name: "Random content",
token: generateRandomString(1000),
wantSize: 2000, // Random content won't compress much
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compressed := compressToken(tt.token)
decompressed := decompressToken(compressed)
// Only verify compression ratio for non-short tokens
if len(tt.token) > 100 {
compressionRatio := float64(len(compressed)) / float64(len(tt.token))
t.Logf("Compression ratio for %s: %.2f", tt.name, compressionRatio)
if compressionRatio > 1.1 { // Allow up to 10% size increase
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
len(tt.token), len(compressed), compressionRatio)
}
}
// Verify decompression restores original
if decompressed != tt.token {
t.Error("Decompression failed to restore original token")
}
// Verify approximate compression ratio
if len(compressed) > tt.wantSize*2 {
t.Errorf("Compression ratio worse than expected: got=%d, want<%d", len(compressed), tt.wantSize*2)
}
})
}
}
// TestSessionManager tests the SessionManager functionality
func TestCookiePrefix(t *testing.T) {
// Create a session and verify cookie names
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set some data to ensure cookies are created
session.SetAuthenticated(true)
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Set new tokens
session.SetAccessToken("test_token")
session.SetRefreshToken("test_refresh_token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Check cookie prefixes
cookies := rr.Result().Cookies()
for _, cookie := range cookies {
if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") {
t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name)
}
}
}
func TestTokenRefreshCleanup(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set a large token that will be split into chunks
largeToken := strings.Repeat("x", 5000)
session.SetAccessToken(largeToken)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get initial cookies
initialCookies := rr.Result().Cookies()
// Create a new request with the initial cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range initialCookies {
newReq.AddCookie(cookie)
}
newRr := httptest.NewRecorder()
// Get session with cookies and set a new token
newSession, err := sm.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Create a response recorder for expired cookies
expiredRr := httptest.NewRecorder()
// Expire old chunk cookies
newSession.expireAccessTokenChunks(expiredRr)
// Set a smaller token that won't need chunks
newSession.SetAccessToken("small_token")
// Save session with new token
if err := newSession.Save(newReq, newRr); err != nil {
t.Fatalf("Failed to save new session: %v", err)
}
// Check cookies in response where old cookies are expired
intermediateResponse := expiredRr.Result()
intermediateCount := 0
chunkCount := 0
expiredCount := 0
for _, cookie := range intermediateResponse.Cookies() {
if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 {
chunkCount++
if cookie.MaxAge < 0 {
expiredCount++
t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
} else if cookie.MaxAge >= 0 {
intermediateCount++
t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
}
// All chunk cookies should be expired
if chunkCount > 0 && chunkCount != expiredCount {
t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount)
}
// Should have fewer active cookies after setting smaller token
if intermediateCount >= len(initialCookies) {
t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies))
}
}
func TestSessionManager(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authenticated bool
email string
accessToken string
refreshToken string
expectedCookieCount int
name string
authenticated bool
email string
accessToken string
refreshToken string
expectedCookieCount int
wantCompressed bool // Whether tokens should be compressed
}{
{
name: "Short tokens",
authenticated: true,
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh
},
{
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
// Recalculate expected cookies based on new maxCookieSize
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
},
{
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
},
{
name: "Unauthenticated session",
authenticated: false,
email: "",
accessToken: "",
refreshToken: "",
expectedCookieCount: 3, // main, access, refresh
},
{
name: "Short tokens",
authenticated: true,
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: true,
},
{
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
wantCompressed: true,
},
{
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
wantCompressed: true,
},
{
name: "Unauthenticated session",
authenticated: false,
email: "",
accessToken: "",
refreshToken: "",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: false,
},
{
name: "Random content tokens",
authenticated: true,
email: "test@example.com",
accessToken: generateRandomString(5000),
refreshToken: generateRandomString(5000),
expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)),
wantCompressed: true,
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Verify cookies are set
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Set new tokens
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Verify cookies are set and compression is used when appropriate
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
// Verify compression is working by checking token sizes
for _, cookie := range cookies {
if strings.Contains(cookie.Name, accessTokenCookie) {
// Get original and stored sizes
originalSize := len(tc.accessToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
// For large tokens, verify some compression occurred
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
originalSize := len(tc.refreshToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 {
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved")
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved")
}
})
}
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Verify session values
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
}
// Verify session pooling by checking if the session is reused
session2, _ := ts.sessionManager.GetSession(newReq)
if session2 == newSession {
t.Error("Session not properly pooled")
}
})
}
}
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
count := 3 // main, access, refresh
// Calculate number of chunks for access token
accessChunks := len(splitIntoChunks(accessToken, maxCookieSize))
if accessChunks > 1 {
count += accessChunks
// Helper to calculate chunks for compressed token
calculateChunks := func(token string) int {
// Compress token (matching the actual implementation)
compressed := compressToken(token)
// If compressed token fits in one cookie, no additional chunks needed
if len(compressed) <= maxCookieSize {
return 0
}
// Calculate chunks needed for compressed token
return len(splitIntoChunks(compressed, maxCookieSize))
}
// Calculate number of chunks for refresh token
refreshChunks := len(splitIntoChunks(refreshToken, maxCookieSize))
if refreshChunks > 1 {
count += refreshChunks
// Add chunks for access token if needed
accessChunks := calculateChunks(accessToken)
if accessChunks > 0 {
count += accessChunks
}
// Add chunks for refresh token if needed
refreshChunks := calculateChunks(refreshToken)
if refreshChunks > 0 {
count += refreshChunks
}
return count
}
}
+75 -32
View File
@@ -10,10 +10,6 @@ import (
"strings"
)
const (
cookieName = "_raczylo_oidc"
)
// Config holds the configuration for the OIDC middleware.
// It provides all necessary settings to configure OpenID Connect authentication
// with various providers like Auth0, Logto, or any standard OIDC provider.
@@ -85,30 +81,34 @@ type Config struct {
HTTPClient *http.Client
}
// CreateConfig creates a new Config with sensible default values.
const (
// DefaultRateLimit defines the default rate limit for requests per second
DefaultRateLimit = 100
// MinRateLimit defines the minimum allowed rate limit to prevent DOS
MinRateLimit = 10
// DefaultLogLevel defines the default logging level
DefaultLogLevel = "info"
// MinSessionEncryptionKeyLength defines the minimum length for session encryption key
MinSessionEncryptionKeyLength = 32
)
// CreateConfig creates a new Config with secure default values.
// Default values are set for optional fields:
// - Scopes: ["openid", "profile", "email"]
// - LogLevel: "info"
// - LogoutURL: CallbackURL + "/logout"
// - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security)
func CreateConfig() *Config {
c := &Config{}
if c.Scopes == nil {
c.Scopes = []string{"openid", "profile", "email"}
}
if c.LogLevel == "" {
c.LogLevel = "info"
}
if c.LogoutURL == "" {
c.LogoutURL = c.CallbackURL + "/logout"
}
if c.RateLimit == 0 {
c.RateLimit = 100
c := &Config{
Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
}
return c
@@ -118,43 +118,85 @@ func CreateConfig() *Config {
// It ensures all required fields are set and have valid values.
// Returns an error if any validation check fails.
func (c *Config) Validate() error {
// Validate provider URL
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
}
if !isValidURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid URL")
if !isValidSecureURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid HTTPS URL")
}
// Validate callback URL
if c.CallbackURL == "" {
return fmt.Errorf("callbackURL is required")
}
if !strings.HasPrefix(c.CallbackURL, "/") {
return fmt.Errorf("callbackURL must start with /")
}
// Validate client credentials
if c.ClientID == "" {
return fmt.Errorf("clientID is required")
}
if c.ClientSecret == "" {
return fmt.Errorf("clientSecret is required")
}
// Validate session encryption key
if c.SessionEncryptionKey == "" {
return fmt.Errorf("sessionEncryptionKey is required")
}
if len(c.SessionEncryptionKey) < 32 {
return fmt.Errorf("sessionEncryptionKey must be at least 32 characters long")
}
if c.RateLimit < 0 {
return fmt.Errorf("rateLimit must be non-negative")
if len(c.SessionEncryptionKey) < MinSessionEncryptionKeyLength {
return fmt.Errorf("sessionEncryptionKey must be at least %d characters long", MinSessionEncryptionKeyLength)
}
// Validate log level
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
return fmt.Errorf("logLevel must be one of: debug, info, error")
}
// Validate excluded URLs
for _, url := range c.ExcludedURLs {
if !strings.HasPrefix(url, "/") {
return fmt.Errorf("excluded URL must start with /: %s", url)
}
if strings.Contains(url, "..") {
return fmt.Errorf("excluded URL must not contain path traversal: %s", url)
}
if strings.Contains(url, "*") {
return fmt.Errorf("excluded URL must not contain wildcards: %s", url)
}
}
// Validate revocation URL if set
if c.RevocationURL != "" && !isValidSecureURL(c.RevocationURL) {
return fmt.Errorf("revocationURL must be a valid HTTPS URL")
}
// Validate end session URL if set
if c.OIDCEndSessionURL != "" && !isValidSecureURL(c.OIDCEndSessionURL) {
return fmt.Errorf("oidcEndSessionURL must be a valid HTTPS URL")
}
// Validate post-logout redirect URI if set
if c.PostLogoutRedirectURI != "" && c.PostLogoutRedirectURI != "/" {
if !isValidSecureURL(c.PostLogoutRedirectURI) && !strings.HasPrefix(c.PostLogoutRedirectURI, "/") {
return fmt.Errorf("postLogoutRedirectURI must be either a valid HTTPS URL or start with /")
}
}
// Validate rate limit
if c.RateLimit < MinRateLimit {
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
}
return nil
}
// isValidURL checks if the provided string is a valid URL
func isValidURL(s string) bool {
// isValidSecureURL checks if the provided string is a valid HTTPS URL
func isValidSecureURL(s string) bool {
u, err := url.Parse(s)
return err == nil && u.Scheme != "" && u.Host != ""
return err == nil && u.Scheme == "https" && u.Host != ""
}
// isValidLogLevel checks if the provided log level is valid
@@ -179,6 +221,7 @@ type Logger struct {
// - "debug": Outputs all messages (debug, info, error)
// - "info": Outputs info and error messages
// - "error": Outputs only error messages
//
// Error messages are always written to stderr, while info and debug
// messages are written to stdout when enabled.
func NewLogger(logLevel string) *Logger {
@@ -187,7 +230,7 @@ func NewLogger(logLevel string) *Logger {
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logError.SetOutput(os.Stderr)
if logLevel == "debug" || logLevel == "info" {
logInfo.SetOutput(os.Stdout)
}
+49 -14
View File
@@ -23,13 +23,18 @@ func TestCreateConfig(t *testing.T) {
}
// Check default log level
if config.LogLevel != "info" {
t.Errorf("Expected default log level 'info', got '%s'", config.LogLevel)
if config.LogLevel != DefaultLogLevel {
t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel)
}
// Check default rate limit
if config.RateLimit != 100 {
t.Errorf("Expected default rate limit 100, got %d", config.RateLimit)
if config.RateLimit != DefaultRateLimit {
t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit)
}
// Check ForceHTTPS default
if !config.ForceHTTPS {
t.Error("Expected ForceHTTPS to be true by default")
}
})
@@ -38,6 +43,7 @@ func TestCreateConfig(t *testing.T) {
config.Scopes = []string{"custom_scope"}
config.LogLevel = "debug"
config.RateLimit = 50
config.ForceHTTPS = false
// Verify custom values are not overwritten
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
@@ -49,6 +55,9 @@ func TestCreateConfig(t *testing.T) {
if config.RateLimit != 50 {
t.Error("Custom rate limit was overwritten")
}
if config.ForceHTTPS {
t.Error("Custom ForceHTTPS value was overwritten")
}
})
}
@@ -98,15 +107,15 @@ func TestConfigValidate(t *testing.T) {
expectedError: "sessionEncryptionKey is required",
},
{
name: "Invalid ProviderURL",
name: "Non-HTTPS ProviderURL",
config: &Config{
ProviderURL: "not-a-url",
ProviderURL: "http://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "providerURL must be a valid URL",
expectedError: "providerURL must be a valid HTTPS URL",
},
{
name: "Invalid CallbackURL",
@@ -131,16 +140,16 @@ func TestConfigValidate(t *testing.T) {
expectedError: "sessionEncryptionKey must be at least 32 characters long",
},
{
name: "Negative RateLimit",
name: "Low RateLimit",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RateLimit: -1,
RateLimit: 5,
},
expectedError: "rateLimit must be non-negative",
expectedError: "rateLimit must be at least 10",
},
{
name: "Invalid LogLevel",
@@ -154,6 +163,30 @@ func TestConfigValidate(t *testing.T) {
},
expectedError: "logLevel must be one of: debug, info, error",
},
{
name: "Non-HTTPS RevocationURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RevocationURL: "http://revoke.com",
},
expectedError: "revocationURL must be a valid HTTPS URL",
},
{
name: "Non-HTTPS OIDCEndSessionURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
OIDCEndSessionURL: "http://endsession.com",
},
expectedError: "oidcEndSessionURL must be a valid HTTPS URL",
},
{
name: "Valid Config",
config: &Config{
@@ -164,6 +197,8 @@ func TestConfigValidate(t *testing.T) {
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "debug",
RateLimit: 100,
RevocationURL: "https://revoke.com",
OIDCEndSessionURL: "https://endsession.com",
},
expectedError: "",
},
@@ -192,9 +227,9 @@ func TestLogger(t *testing.T) {
var debugBuf, infoBuf, errorBuf bytes.Buffer
tests := []struct {
name string
logLevel string
testFunc func(*Logger)
name string
logLevel string
testFunc func(*Logger)
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
}{
{
@@ -289,7 +324,7 @@ func TestLogger(t *testing.T) {
// Create logger with test buffers
logger := NewLogger(tc.logLevel)
logger.logError.SetOutput(&errorBuf)
if tc.logLevel == "debug" || tc.logLevel == "info" {
logger.logInfo.SetOutput(&infoBuf)
}