Compare commits

..

9 Commits

Author SHA1 Message Date
lukaszraczylo fdb8e3233e Testing (could be unstable) additional headers.
This adds additional headers to control the access origin and control allow headers.
2025-02-06 23:46:08 +00:00
lukaszraczylo 33c71fd6fe Enhance test suite. 2025-02-06 23:38:22 +00:00
lukaszraczylo 241cb1c209 Deal with the memory growth issue.
* TokenBlacklist limit is set to 1000
* Increased token cleanup frequency
2025-02-06 23:34:05 +00:00
lukaszraczylo 09daa1025c Follow multiple redirects during the OIDC flow. 2025-02-06 23:31:13 +00:00
lukaszraczylo c09e7a9228 Add additional test cases to cover it. 2025-02-06 21:50:35 +00:00
lukaszraczylo e5da5d4fe9 Fix redirection to the provider when session expires 2025-02-06 21:48:56 +00:00
lukaszraczylo 31db701dda Trigger build and release. 2025-02-05 19:04:44 +00:00
lukaszraczylo 16481afd36 Add todo: Improve test coverage. 2025-02-01 12:20:01 +00:00
lukaszraczylo 751933ffa0 Multiple improvements.
* Add todo list.

* fixup! Add todo list.

* fixup! fixup! Add todo list.

* fixup! fixup! fixup! Add todo list.

* Improve the session handling and cache.

* Fix an issue where expired session can cause infinite redirect loop

* fixup! Fix an issue where expired session can cause infinite redirect loop

* Add semver setup for automatic releases.

* fixup! Add semver setup for automatic releases.

* fixup! fixup! Add semver setup for automatic releases.

* fixup! fixup! fixup! Add semver setup for automatic releases.
2025-02-01 12:16:50 +00:00
9 changed files with 757 additions and 290 deletions
+5
View File
@@ -0,0 +1,5 @@
### TODO / wishlist
- [] Improve test coverage
- [x] Improve caching mechanism
- [x] Add automatic release and semver generation
+80 -99
View File
@@ -1,172 +1,153 @@
package traefikoidc
import (
"container/list"
"sync"
"time"
)
// CacheItem represents an item stored in the cache with its associated metadata.
type CacheItem struct {
// Value is the cached data of any type
// Value is the cached data of any type.
Value interface{}
// ExpiresAt is the timestamp when this item should be considered expired
// and removed from the cache during cleanup operations
// ExpiresAt is the timestamp when this item should be considered expired.
ExpiresAt time.Time
}
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
type Cache struct {
// items stores the cached data with string keys
items map[string]CacheItem
// mutex protects concurrent access to the items map
// Use RLock/RUnlock for reads and Lock/Unlock for writes
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache
maxSize int
// accessList maintains the order of item access for eviction
accessList []string
// lruEntry represents an entry in the LRU list.
type lruEntry struct {
key string
}
// DefaultMaxSize is the default maximum number of items in the cache
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
type Cache struct {
// items stores the cached data with string keys.
items map[string]CacheItem
// order maintains the usage order; most recently used items are at the back.
order *list.List
// elems maps keys to their corresponding list elements for O(1) access.
elems map[string]*list.Element
// mutex protects concurrent access to the cache.
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache.
maxSize int
}
// DefaultMaxSize is the default maximum number of items in the cache.
const DefaultMaxSize = 1000
// NewCache creates a new empty cache instance.
// The cache is immediately ready for use and is thread-safe.
// NewCache creates a new empty cache instance that is ready for use.
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
maxSize: DefaultMaxSize,
accessList: make([]string, 0, DefaultMaxSize),
items: make(map[string]CacheItem, DefaultMaxSize),
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
}
}
// Set adds or updates an item in the cache with the specified expiration duration.
// Parameters:
// - key: Unique identifier for the cached item
// - value: The data to cache (can be of any type)
// - expiration: How long the item should remain in the cache
//
// Thread-safe: Uses write locking to ensure safe concurrent access.
// It moves the item to the most recently used position.
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
// If key exists, update it
now := time.Now()
expTime := now.Add(expiration)
// Update existing item.
if _, exists := c.items[key]; exists {
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: expTime,
}
if elem, ok := c.elems[key]; ok {
c.order.MoveToBack(elem)
}
return
}
// If cache is full, remove oldest item
// Evict oldest item if cache is full.
if len(c.items) >= c.maxSize {
c.evictOldest()
}
// Add new item
// Add new item.
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: expTime,
}
c.accessList = append(c.accessList, key)
elem := c.order.PushBack(lruEntry{key: key})
c.elems[key] = elem
}
// Get retrieves an item from the cache if it exists and hasn't expired.
// Parameters:
// - key: The identifier of the item to retrieve
//
// Returns:
// - value: The cached data (nil if not found or expired)
// - found: true if the item was found and is valid, false otherwise
//
// Thread-safe: Uses read locking to ensure safe concurrent access.
// Moving the accessed item to the most recently used position.
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
item, found := c.items[key]
c.mutex.RUnlock()
if !found {
return nil, false
}
if time.Now().After(item.ExpiresAt) {
c.mutex.Lock()
c.removeItem(key)
c.mutex.Unlock()
return nil, false
}
// Update access order
c.mutex.Lock()
c.updateAccessOrder(key)
c.mutex.Unlock()
defer c.mutex.Unlock()
item, exists := c.items[key]
if !exists {
return nil, false
}
// Check for expiration.
if time.Now().After(item.ExpiresAt) {
c.removeItem(key)
return nil, false
}
// Move item to the back (most recently used).
if elem, ok := c.elems[key]; ok {
c.order.MoveToBack(elem)
}
return item.Value, true
}
// Delete removes an item from the cache if it exists.
// If the item doesn't exist, this operation is a no-op.
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Delete removes an item from the cache.
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
c.removeItem(key)
}
// Cleanup removes all expired items from the cache.
// This should be called periodically to prevent memory leaks from
// expired items that haven't been accessed (and thus not removed during Get operations).
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Cleanup removes all expired items from the cache. This should be called periodically
// to prevent memory bloat from expired entries.
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
var newAccessList []string
for _, key := range c.accessList {
if item, exists := c.items[key]; exists && !now.After(item.ExpiresAt) {
newAccessList = append(newAccessList, key)
} else {
delete(c.items, key)
for key, item := range c.items {
if now.After(item.ExpiresAt) {
c.removeItem(key)
}
}
c.accessList = newAccessList
}
// evictOldest removes the least recently used item from the cache
// evictOldest removes the least recently used item from the cache.
func (c *Cache) evictOldest() {
if len(c.accessList) > 0 {
oldest := c.accessList[0]
c.removeItem(oldest)
elem := c.order.Front()
if elem != nil {
entry := elem.Value.(lruEntry)
c.removeItem(entry.key)
}
}
// removeItem removes an item from both the cache and access list
// removeItem removes an item from both the cache and the LRU tracking structures.
func (c *Cache) removeItem(key string) {
delete(c.items, key)
for i, k := range c.accessList {
if k == key {
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
break
}
}
}
// updateAccessOrder moves the accessed key to the end of the access list
func (c *Cache) updateAccessOrder(key string) {
for i, k := range c.accessList {
if k == key {
c.accessList = append(append(c.accessList[:i], c.accessList[i+1:]...), key)
break
}
if elem, ok := c.elems[key]; ok {
c.order.Remove(elem)
delete(c.elems, key)
}
}
+45 -21
View File
@@ -8,32 +8,13 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
)
// newSessionOptions creates secure session cookie options.
// Parameters:
// - isSecure: Whether to set the Secure flag on cookies
//
// Returns session options configured for security with:
// - HttpOnly flag to prevent JavaScript access
// - SameSite=Lax for CSRF protection
// - Appropriate timeout and path settings
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce creates a cryptographically secure random nonce
// for use in the OIDC authentication flow. The nonce is used to
// prevent replay attacks by ensuring the token received matches
@@ -88,13 +69,28 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
data.Set("refresh_token", codeOrToken)
}
// Create a cookie jar for this request to handle redirects with cookies
jar, _ := cookiejar.New(nil)
client := &http.Client{
Transport: t.httpClient.Transport,
Timeout: t.httpClient.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
Jar: jar,
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := t.httpClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
@@ -296,12 +292,16 @@ type TokenBlacklist struct {
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
// maxSize is the maximum number of tokens in the blacklist
maxSize int
}
// NewTokenBlacklist creates a new TokenBlacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
maxSize: 1000, // Limit the size to prevent unbounded growth
}
}
@@ -309,6 +309,30 @@ func NewTokenBlacklist() *TokenBlacklist {
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
// Clean up expired tokens if we're at capacity
if len(tb.blacklist) >= tb.maxSize {
now := time.Now()
for token, exp := range tb.blacklist {
if now.After(exp) {
delete(tb.blacklist, token)
}
}
// If still at capacity after cleanup, remove oldest token
if len(tb.blacklist) >= tb.maxSize {
var oldestToken string
var oldestTime time.Time
first := true
for token, exp := range tb.blacklist {
if first || exp.Before(oldestTime) {
oldestToken = token
oldestTime = exp
first = false
}
}
delete(tb.blacklist, oldestToken)
}
}
tb.blacklist[tokenID] = expiration
}
+225
View File
@@ -0,0 +1,225 @@
package traefikoidc
import (
"fmt"
"runtime"
"testing"
"time"
)
func TestTokenBlacklistSizeLimit(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens up to maxSize
for i := 0; i < 1000; i++ {
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
}
// Verify size is at max
if len(tb.blacklist) != 1000 {
t.Errorf("Expected blacklist size to be 1000, got %d", len(tb.blacklist))
}
// Add one more token, should trigger cleanup/eviction
tb.Add("newtoken", time.Now().Add(time.Hour))
// Size should still be at max
if len(tb.blacklist) > 1000 {
t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist))
}
}
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
tb := NewTokenBlacklist()
// Add some expired tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
}
// Add some valid tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
}
// Force cleanup
tb.Cleanup()
// Only valid tokens should remain
if len(tb.blacklist) != 500 {
t.Errorf("Expected 500 valid tokens after cleanup, got %d", len(tb.blacklist))
}
// Verify only valid tokens remain
for token, expiry := range tb.blacklist {
if time.Now().After(expiry) {
t.Errorf("Found expired token after cleanup: %s", token)
}
}
}
func TestTokenBlacklistOldestEviction(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens at capacity with different expiration times
baseTime := time.Now()
oldestToken := "oldest"
// Add oldest token first
tb.Add(oldestToken, baseTime.Add(time.Hour))
// Fill up to capacity with newer tokens
for i := 0; i < 999; i++ {
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
}
// Add a new token that should evict the oldest
newToken := "newest"
tb.Add(newToken, baseTime.Add(time.Hour*3))
// Verify oldest token was evicted
if tb.IsBlacklisted(oldestToken) {
t.Error("Oldest token should have been evicted")
}
// Verify newest token is present
if !tb.IsBlacklisted(newToken) {
t.Error("Newest token should be present")
}
}
func TestTokenBlacklistMemoryUsage(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy usage
for i := 0; i < iterations; i++ {
// Add new token
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
// Periodically check blacklisted status
if i%100 == 0 {
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tb.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
}
// Verify size stayed within limits
if len(tb.blacklist) > tb.maxSize {
t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist))
}
}
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 1000
concurrency := 10
done := make(chan bool)
// Start multiple goroutines performing operations
for i := 0; i < concurrency; i++ {
go func(id int) {
for j := 0; j < iterations; j++ {
// Add tokens
token := fmt.Sprintf("token%d-%d", id, j)
tb.Add(token, time.Now().Add(time.Hour))
// Check blacklist status
tb.IsBlacklisted(token)
// Periodic cleanup
if j%100 == 0 {
tb.Cleanup()
}
}
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < concurrency; i++ {
<-done
}
// Verify size constraints were maintained
if len(tb.blacklist) > tb.maxSize {
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", len(tb.blacklist))
}
}
func TestTokenCacheMemoryUsage(t *testing.T) {
tc := NewTokenCache()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy cache usage
for i := 0; i < iterations; i++ {
claims := map[string]interface{}{
"sub": fmt.Sprintf("user%d", i),
"exp": time.Now().Add(time.Hour).Unix(),
}
// Add to cache
tc.Set(fmt.Sprintf("token%d", i), claims, time.Hour)
// Periodically retrieve
if i%100 == 0 {
tc.Get(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tc.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive cache memory growth: %d bytes", memoryGrowth)
}
// Verify cache size stayed within limits
if len(tc.cache.items) > tc.cache.maxSize {
t.Errorf("Cache exceeded max size: %d", len(tc.cache.items))
}
}
+91 -16
View File
@@ -62,7 +62,6 @@ type TraefikOidc struct {
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
@@ -82,8 +81,6 @@ var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
var newTicker = time.NewTicker
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
@@ -228,6 +225,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
httpClient = &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
}
}
@@ -264,7 +268,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// Assign the initialized logger
t.logger = logger
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
@@ -512,7 +516,35 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Set OIDC-specific headers
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetAccessToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
// Set security headers
rw.Header().Set("X-Frame-Options", "DENY")
rw.Header().Set("X-Content-Type-Options", "nosniff")
rw.Header().Set("X-XSS-Protection", "1; mode=block")
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Set CORS headers
origin := req.Header.Get("Origin")
if origin != "" {
rw.Header().Set("Access-Control-Allow-Origin", origin)
rw.Header().Set("Access-Control-Allow-Credentials", "true")
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
// Handle preflight requests
if req.Method == "OPTIONS" {
rw.WriteHeader(http.StatusOK)
return
}
}
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -531,9 +563,6 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
return "https"
}
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
@@ -602,14 +631,17 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
csrfToken := uuid.New().String()
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Set session values
// Clear any existing session data to avoid stale state causing redirect loops
session.Clear(req, rw)
// Set new session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.RequestURI())
@@ -621,7 +653,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Build and redirect to auth URL
// Build and redirect to authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -642,17 +674,60 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " "))
}
// Ensure authURL is absolute
if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
t.authURL,
params.Encode())
}
}
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := newTicker(1 * time.Minute)
ctx, cancel := context.WithCancel(context.Background())
ticker := time.NewTicker(30 * time.Second) // Increased frequency to prevent memory buildup
go func() {
for range ticker.C {
t.logger.Debug("Cleaning up token cache")
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
defer ticker.Stop()
defer cancel()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
t.logger.Debug("Starting token cleanup cycle")
// Run cleanup in a separate goroutine with timeout
cleanupCtx, cleanupCancel := context.WithTimeout(ctx, 10*time.Second)
done := make(chan struct{})
go func() {
defer close(done)
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
}()
// Wait for cleanup to complete or timeout
select {
case <-cleanupCtx.Done():
if cleanupCtx.Err() == context.DeadlineExceeded {
t.logger.Error("Token cleanup cycle timed out")
}
case <-done:
t.logger.Debug("Token cleanup cycle completed successfully")
}
cleanupCancel()
}
}
}()
}
+200 -4
View File
@@ -13,6 +13,7 @@ import (
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@@ -89,7 +90,7 @@ func (ts *TestSuite) Setup() {
}
logger := NewLogger("info")
ts.sessionManager = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
@@ -619,7 +620,7 @@ func TestHandleCallback(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Create a new instance for each test to avoid state carryover
tOidc := &TraefikOidc{
@@ -924,7 +925,7 @@ func TestHandleLogout(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
tOidc := &TraefikOidc{
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
@@ -1213,7 +1214,7 @@ func TestHandleExpiredToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
tOidc := &TraefikOidc{
sessionManager: sessionManager,
@@ -1647,6 +1648,111 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
// Helper function to compare string slices
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
setupServer func() *httptest.Server
expectError bool
errorContains string
}{
{
name: "Successful token exchange with redirects",
setupServer: func() *httptest.Server {
redirectCount := 0
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 3 {
// Set a cookie before redirecting
http.SetCookie(w, &http.Cookie{
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
Value: "test-value",
})
redirectCount++
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
return
}
// Verify all cookies from previous redirects are present
cookies := r.Cookies()
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
for i := 0; i < 3; i++ {
found := false
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
for _, cookie := range cookies {
if cookie.Name == expectedName {
found = true
break
}
}
if !found {
t.Errorf("Cookie %s not found", expectedName)
}
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
expectError: false,
},
{
name: "Too many redirects",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
}))
},
expectError: true,
errorContains: "stopped after 50 redirects",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupServer()
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
// Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
}
})
}
}
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
@@ -1659,6 +1765,96 @@ func stringSliceEqual(a, b []string) bool {
return true
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authURL string
issuerURL string
redirectURL string
state string
nonce string
expectedPrefix string
}{
{
name: "Absolute Auth URL",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
},
{
name: "Relative Auth URL",
authURL: "/oidc/auth",
issuerURL: "https://logto.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://logto.example.com/oidc/auth?",
},
{
name: "Relative Auth URL with Different Issuer",
authURL: "/sign-in",
issuerURL: "https://auth.example.com:8443",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure the test instance
tOidc := ts.tOidc
tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL
// Call buildAuthURL
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
// Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) {
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
}
// Parse the resulting URL to verify query parameters
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse resulting URL: %v", err)
}
query := parsedURL.Query()
expectedParams := map[string]string{
"client_id": tOidc.clientID,
"response_type": "code",
"redirect_uri": tc.redirectURL,
"state": tc.state,
"nonce": tc.nonce,
}
for key, expected := range expectedParams {
if got := query.Get(key); got != expected {
t.Errorf("Expected %s=%q, got %q", key, expected, got)
}
}
// Verify scopes are present and correct
if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ")
if got := query.Get("scope"); got != expectedScopes {
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
}
}
})
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t}
+10
View File
@@ -0,0 +1,10 @@
version: 1
force:
existing: true
wording:
patch:
- patch-release
minor:
- minor-release
major:
- breaking
+99 -142
View File
@@ -16,13 +16,14 @@ import (
"github.com/gorilla/sessions"
)
// generateSecureRandomString creates a cryptographically secure random string of specified length
func generateSecureRandomString(length int) string {
// generateSecureRandomString creates a cryptographically secure random string of specified length.
// It returns the generated string or an error if random generation fails.
func generateSecureRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
panic("failed to generate random string")
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
return hex.EncodeToString(bytes)
return hex.EncodeToString(bytes), nil
}
// Cookie names and configuration constants used for session management
@@ -55,7 +56,7 @@ const (
minEncryptionKeyLength = 32
)
// compressToken compresses a token using gzip and base64 encodes it
// compressToken compresses a token using gzip and base64 encodes it.
func compressToken(token string) string {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
@@ -68,7 +69,7 @@ func compressToken(token string) string {
return base64.StdEncoding.EncodeToString(b.Bytes())
}
// decompressToken decompresses a base64 encoded gzipped token
// decompressToken decompresses a base64 encoded gzipped token.
func decompressToken(compressed string) string {
data, err := base64.StdEncoding.DecodeString(compressed)
if err != nil {
@@ -91,18 +92,18 @@ func decompressToken(compressed string) string {
// SessionManager handles the management of multiple session cookies for OIDC authentication.
// It provides functionality for storing and retrieving authentication state, tokens,
// and other session-related data across multiple cookies to handle large tokens.
// and other session-related data across multiple cookies.
type SessionManager struct {
// store is the underlying session store for cookie management
// store is the underlying session store for cookie management.
store sessions.Store
// forceHTTPS enforces secure cookie attributes regardless of request scheme
// forceHTTPS enforces secure cookie attributes regardless of request scheme.
forceHTTPS bool
// logger provides structured logging capabilities
// logger provides structured logging capabilities.
logger *Logger
// sessionPool is a sync.Pool for reusing SessionData objects
// sessionPool is a sync.Pool for reusing SessionData objects.
sessionPool sync.Pool
}
@@ -112,11 +113,11 @@ type SessionManager struct {
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
// - logger: Logger instance for recording session-related events
//
// The manager handles session creation, storage, and cookie security settings.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
// Validate encryption key length
// Returns an error if the encryption key does not meet minimum length requirements.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*SessionManager, error) {
// Validate encryption key length.
if len(encryptionKey) < minEncryptionKeyLength {
panic(fmt.Sprintf("encryption key must be at least %d bytes long", minEncryptionKeyLength))
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
sm := &SessionManager{
@@ -125,7 +126,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S
logger: logger,
}
// Initialize session pool
// Initialize session pool.
sm.sessionPool.New = func() interface{} {
return &SessionData{
manager: sm,
@@ -134,12 +135,12 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S
}
}
return sm
return sm, nil
}
// getSessionOptions returns secure session options configured for the current request.
// Parameters:
// - isSecure: Whether the current request is using HTTPS
// - isSecure: Whether the current request is using HTTPS.
//
// The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript)
@@ -161,7 +162,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
// and combines them into a single SessionData structure for easy access.
// Returns an error if any session component cannot be loaded.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
// Get session from pool
// Get session from pool.
sessionData := sm.sessionPool.Get().(*SessionData)
sessionData.request = r
@@ -172,11 +173,10 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
return nil, fmt.Errorf("failed to get main session: %w", err)
}
// Check for absolute session timeout
// Check for absolute session timeout.
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
// Session has expired
sm.sessionPool.Put(sessionData)
sessionData.Clear(r, nil)
return nil, fmt.Errorf("session expired")
}
}
@@ -193,7 +193,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
// Clear and reuse chunk maps
// Clear and reuse chunk maps.
for k := range sessionData.accessTokenChunks {
delete(sessionData.accessTokenChunks, k)
}
@@ -201,7 +201,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
delete(sessionData.refreshTokenChunks, k)
}
// Retrieve chunked token sessions
// Retrieve chunked token sessions.
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
@@ -218,7 +218,6 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
// No more sessions
break
}
chunks[i] = session
@@ -230,27 +229,27 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
// and potentially large access and refresh tokens that may need to be
// split across multiple cookies due to browser size limitations.
type SessionData struct {
// manager is the SessionManager that created this SessionData
// manager is the SessionManager that created this SessionData.
manager *SessionManager
// request is the current HTTP request associated with this session
// request is the current HTTP request associated with this session.
request *http.Request
// mainSession stores authentication state and basic user info
// mainSession stores authentication state and basic user info.
mainSession *sessions.Session
// accessSession stores the primary access token cookie
// accessSession stores the primary access token cookie.
accessSession *sessions.Session
// refreshSession stores the primary refresh token cookie
// refreshSession stores the primary refresh token cookie.
refreshSession *sessions.Session
// accessTokenChunks stores additional chunks of the access token
// when it exceeds the maximum cookie size
// when it exceeds the maximum cookie size.
accessTokenChunks map[int]*sessions.Session
// refreshTokenChunks stores additional chunks of the refresh token
// when it exceeds the maximum cookie size
// when it exceeds the maximum cookie size.
refreshTokenChunks map[int]*sessions.Session
}
@@ -261,28 +260,28 @@ type SessionData struct {
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
// Set options for all sessions.
options := sd.manager.getSessionOptions(isSecure)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
// Save main session
// Save main session.
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
// Save access token session
// Save access token session.
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
// Save refresh token session
// Save refresh token session.
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
// Save access token chunks
// Save access token chunks.
for _, session := range sd.accessTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
@@ -290,7 +289,7 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
}
}
// Save refresh token chunks
// Save refresh token chunks.
for _, session := range sd.refreshTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
@@ -302,10 +301,8 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
}
// Clear removes all session data by expiring all cookies and clearing their values.
// This is typically used during logout to ensure all session data is properly cleaned up.
// It handles both main session data and any token chunks that may exist.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
// Clear and expire all sessions.
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
@@ -320,7 +317,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
delete(sd.refreshSession.Values, k)
}
// Clear chunk sessions
// Clear chunk sessions.
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
@@ -329,15 +326,13 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
err = sd.Save(r, w)
}
// Return session to pool
// Return session to pool.
sd.manager.sessionPool.Put(sd)
return err
}
// clearTokenChunks removes all session chunks for a given token type.
// It expires the cookies and removes all stored values to ensure
// no token data remains after logout or token invalidation.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
@@ -348,15 +343,13 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
}
// GetAuthenticated returns whether the current session is authenticated.
// Returns true if the user has successfully completed OIDC authentication
// and the session hasn't expired, false otherwise.
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
if !auth {
return false
}
// Check session expiration
// Check session expiration.
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
if !ok {
return false
@@ -365,21 +358,21 @@ func (sd *SessionData) GetAuthenticated() bool {
}
// SetAuthenticated updates the session's authentication status and rotates session ID.
// This should be called after successful OIDC authentication or during logout.
// Session ID rotation helps prevent session fixation attacks.
func (sd *SessionData) SetAuthenticated(value bool) {
// Returns an error if generating a new session ID fails.
func (sd *SessionData) SetAuthenticated(value bool) error {
if value {
// Generate new session ID and set creation time
sd.mainSession.ID = generateSecureRandomString(32)
id, err := generateSecureRandomString(32)
if err != nil {
return fmt.Errorf("failed to generate secure session id: %w", err)
}
sd.mainSession.ID = id
sd.mainSession.Values["created_at"] = time.Now().Unix()
}
sd.mainSession.Values["authenticated"] = value
return nil
}
// GetAccessToken retrieves the complete access token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
@@ -390,7 +383,7 @@ func (sd *SessionData) GetAccessToken() string {
return token
}
// Reassemble token from chunks
// Reassemble token from chunks.
if len(sd.accessTokenChunks) == 0 {
return ""
}
@@ -414,45 +407,23 @@ func (sd *SessionData) GetAccessToken() string {
}
// SetAccessToken stores the access token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
// expireAccessTokenChunks expires any existing access token chunk cookies
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
// Expire the cookie
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
// Save expired cookie
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
}
}
}
func (sd *SessionData) SetAccessToken(token string) {
// Expire any existing chunk cookies first
// Expire any existing chunk cookies first.
if sd.request != nil {
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called.
}
// Clear and prepare chunks map for new token
// Clear and prepare chunks map for new token.
sd.accessTokenChunks = make(map[int]*sessions.Session)
// Compress token
// Compress token.
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.accessSession.Values["token"] = compressed
sd.accessSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
// Split compressed token into chunks.
sd.accessSession.Values["token"] = ""
sd.accessSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
@@ -466,9 +437,6 @@ func (sd *SessionData) SetAccessToken(token string) {
}
// GetRefreshToken retrieves the complete refresh token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
@@ -479,7 +447,7 @@ func (sd *SessionData) GetRefreshToken() string {
return token
}
// Reassemble token from chunks
// Reassemble token from chunks.
if len(sd.refreshTokenChunks) == 0 {
return ""
}
@@ -503,45 +471,23 @@ func (sd *SessionData) GetRefreshToken() string {
}
// SetRefreshToken stores the refresh token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
// expireRefreshTokenChunks expires any existing refresh token chunk cookies
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
// Expire the cookie
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
// Save expired cookie
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
}
}
}
func (sd *SessionData) SetRefreshToken(token string) {
// Expire any existing chunk cookies first
// Expire any existing chunk cookies first.
if sd.request != nil {
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called.
}
// Clear and prepare chunks map for new token
// Clear and prepare chunks map for new token.
sd.refreshTokenChunks = make(map[int]*sessions.Session)
// Compress token
// Compress token.
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.refreshSession.Values["token"] = compressed
sd.refreshSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
// Split compressed token into chunks.
sd.refreshSession.Values["token"] = ""
sd.refreshSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
@@ -554,13 +500,43 @@ func (sd *SessionData) SetRefreshToken(token string) {
}
}
// expireAccessTokenChunks expires any existing access token chunk cookies.
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired access token cookie: %v", err)
}
}
}
}
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired refresh token cookie: %v", err)
}
}
}
}
// splitIntoChunks splits a string into chunks of specified size.
// This is used internally to handle large tokens that exceed cookie size limits.
// Parameters:
// - s: The string to split
// - chunkSize: Maximum size of each chunk
//
// Returns an array of string chunks, each no larger than chunkSize.
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
@@ -576,64 +552,45 @@ func splitIntoChunks(s string, chunkSize int) []string {
}
// GetCSRF retrieves the CSRF token from the session.
// This token is used to prevent cross-site request forgery attacks
// by ensuring requests originate from the authenticated user.
// Returns an empty string if no CSRF token is found.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores a new CSRF token in the session.
// This should be called when initiating authentication to generate
// a new token for the authentication flow.
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce retrieves the nonce value from the session.
// The nonce is used to prevent replay attacks in the OIDC flow
// by ensuring the token received matches the authentication request.
// Returns an empty string if no nonce is found.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores a new nonce value in the session.
// This should be called when initiating authentication to generate
// a new nonce for the OIDC authentication flow.
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail retrieves the authenticated user's email address from the session.
// The email is typically extracted from the OIDC ID token claims.
// Returns an empty string if no email is found.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the user's email address in the session.
// This should be called after successful authentication when
// processing the OIDC ID token claims.
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath retrieves the original request path that triggered
// the authentication flow. This is used to redirect the user back
// to their intended destination after successful authentication.
// Returns an empty string if no path was stored.
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request path that triggered
// the authentication flow. This should be called before redirecting
// to the OIDC provider to remember where to send the user afterward.
// SetIncomingPath stores the original request path that triggered the authentication flow.
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
+2 -8
View File
@@ -5,14 +5,8 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
)
func init() {
// Initialize random seed
rand.Seed(time.Now().UnixNano())
}
// generateRandomString creates a random string of specified length
func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
@@ -83,7 +77,7 @@ func TestCookiePrefix(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
@@ -117,7 +111,7 @@ func TestTokenRefreshCleanup(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)