Files
traefikoidc/session.go
lukaszraczylo 1b49e133da Complete rebuild of the plugin
* Fix bug affecting Azure OIDC authentication ( and most likely others )

* Fixes issue #51

* Ensure that appended roles are unique. Update the documentation.

* Improvements targetting possible memory usage spikes.

* Additional fixes and cleanup

* Refactoring code to fix the issues identified by the users.

* Modernize run

* Fieldalignment

* Multiple changes to improve performance and reduce complexity.
- Optimise the errors and recovery.
- Deduplicate code in metadata cache.
- Remove unused performance monitoring code.
- Simplify session management and settings handling.

* Fix claims issue.

* Add ability to overwrite the default scopes in the settings file

* Well.. that escalated quickly.

Completely forgot that Traefik uses outdated Yaegi and requires compatibility with 1.20 ( pre-generic Go code ).

* Bugfix #51: Ensures that user provided scopes overrides work.

* fixup! Bugfix #51: Ensures that user provided scopes overrides work.

* fixup! fixup! Bugfix #51: Ensures that user provided scopes overrides work.

* Abstract the provider logic into a separate package.

* Additional micro fixes and cleanups.

* Simplify all the things.

* fixup! Simplify all the things.

* fixup! fixup! Simplify all the things.

* fixup! fixup! fixup! Simplify all the things.

* fixup! fixup! fixup! fixup! Simplify all the things.

* ...

* Cleanup tests.

* fixup! Cleanup tests.

* fixup! fixup! fixup! Cleanup tests.

* fixup! fixup! fixup! fixup! Cleanup tests.

* fixup! fixup! fixup! fixup! fixup! Cleanup tests.

* Issue #53: Fix CSRF token handling in reverse proxy

1.  HTTPS Detection Fixed (session.go:723)
- Now uses X-Forwarded-Proto header instead of r.URL.Scheme
- Properly detects HTTPS in reverse proxy environments
2.  SameSite Cookie Attribute Fixed
- Removed automatic SameSiteStrictMode for HTTPS (would break OAuth)
- Keeps SameSiteLaxMode to allow OAuth callbacks from external domains
- Only uses Strict for AJAX requests which don't involve OAuth redirects
3.  Cookie Domain Handling Fixed
- Now respects X-Forwarded-Host header for cookie domain
- Ensures cookies are set for the public domain, not internal proxy domain
4.  EnhanceSessionSecurity Properly Integrated
- Function is now actually called during session save
- Applies security enhancements without breaking OAuth flow

Why Issue #53 Failed Before:

1. Cookies were not marked Secure in HTTPS environments (browser wouldn't send them back)
2. If they had been Secure with SameSite=Strict, Azure callbacks would still fail
3. Cookie domain might have been wrong (internal vs public domain)

Why It Works Now:

1. Cookies are properly marked Secure for HTTPS
2. Uses SameSite=Lax to allow OAuth provider callbacks
3. Cookie domain uses public domain from X-Forwarded-Host
4. CSRF token persists through the entire OAuth flow

* Next set of enhancements together with memory usage improvements.

* Memory leak fixes and optimisations.

* CSRF and Cookie Domain fixes

* fixup! CSRF and Cookie Domain fixes

* Metadata cache leak fix + profiling

* fixup! Metadata cache leak fix + profiling

* Memory leaks hunting, part 1337.

* Further pursue of perfection.

* fixup! Further pursue of perfection.

* fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection.

* Clear race conditions

* fixup! Clear race conditions

* Weekend fun with memory leaks

* Splitting code into multiple files with reasonable testing coverage.

```
ok      github.com/lukaszraczylo/traefikoidc    117.017s        coverage: 72.6% of statements
ok      github.com/lukaszraczylo/traefikoidc/auth       0.505s  coverage: 87.1% of statements
ok      github.com/lukaszraczylo/traefikoidc/circuit_breaker    0.283s  coverage: 99.0% of statements
        github.com/lukaszraczylo/traefikoidc/config             coverage: 0.0% of statements
ok      github.com/lukaszraczylo/traefikoidc/handlers   0.349s  coverage: 98.2% of statements
ok      github.com/lukaszraczylo/traefikoidc/internal/providers (cached)        coverage: 94.3% of statements
ok      github.com/lukaszraczylo/traefikoidc/middleware 0.808s  coverage: 78.0% of statements
ok      github.com/lukaszraczylo/traefikoidc/recovery   0.653s  coverage: 100.0% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/chunking   (cached)        coverage: 87.8% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/core       (cached)        coverage: 85.6% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/crypto     (cached)        coverage: 81.8% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/storage    (cached)        coverage: 93.5% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/validators (cached)        coverage: 98.8% of statements
````

* fixup! Splitting code into multiple files with reasonable testing coverage.

* fixup! fixup! Splitting code into multiple files with reasonable testing coverage.

* Weekend fun with further optimisations.

* fixup! Weekend fun with further optimisations.

* fixup! fixup! Weekend fun with further optimisations.

* fixup! fixup! fixup! Weekend fun with further optimisations.

* fixup! fixup! fixup! fixup! Weekend fun with further optimisations.

* fixup! fixup! fixup! fixup! fixup! Weekend fun with further optimisations.

* Pre-release cleanup.

* Enhance test coverage.

* fixup! Enhance test coverage.

* fixup! fixup! Enhance test coverage.

* fixup! fixup! fixup! Enhance test coverage.
2025-09-18 11:01:30 +01:00

2147 lines
65 KiB
Go

package traefikoidc
import (
"bytes"
"compress/gzip"
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"net/http"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/sessions"
)
// min returns the minimum of two integers.
// This is a utility function used throughout the session management code.
// Parameters:
// - a: The first integer to compare.
// - b: The second integer to compare.
//
// Returns:
// - The smaller of the two integers.
func min(a, b int) int {
if a < b {
return a
}
return b
}
// generateSecureRandomString creates a cryptographically secure random string.
// It generates random bytes using crypto/rand and encodes them as hexadecimal.
// This is used for session IDs and other security-sensitive random values.
// Parameters:
// - length: The number of random bytes to generate (output will be 2x this length in hex).
//
// Returns:
// - The hex-encoded random string.
// - An error if reading random bytes fails.
func generateSecureRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
return hex.EncodeToString(bytes), nil
}
// Cookie names and configuration constants used for session management
const (
mainCookieName = "_oidc_raczylo_m"
accessTokenCookie = "_oidc_raczylo_a"
refreshTokenCookie = "_oidc_raczylo_r"
idTokenCookie = "_oidc_raczylo_id"
)
const (
maxBrowserCookieSize = 3500
maxCookieSize = 1200
absoluteSessionTimeout = 24 * time.Hour
minEncryptionKeyLength = 32
)
// compressToken compresses a JWT token using gzip compression if beneficial.
// It validates the token format, attempts compression, and verifies the compressed
// data can be decompressed correctly. Only compresses if it reduces size.
// Parameters:
// - token: The JWT token string to potentially compress.
//
// Returns:
// - The base64 encoded, gzipped string, or the original string if compression fails.
func compressToken(token string) string {
if token == "" {
return token
}
dotCount := strings.Count(token, ".")
if dotCount != 2 {
return token
}
if len(token) > 50*1024 {
return token
}
pools := GetGlobalMemoryPools()
b := pools.GetCompressionBuffer()
defer pools.PutCompressionBuffer(b)
gz := gzip.NewWriter(b)
written, err := gz.Write([]byte(token))
if err != nil || written != len(token) {
return token
}
if err := gz.Close(); err != nil {
return token
}
compressedBytes := b.Bytes()
if len(compressedBytes) == 0 {
return token
}
compressed := base64.StdEncoding.EncodeToString(compressedBytes)
if len(compressed) >= len(token) {
return token
}
decompressed := decompressTokenInternal(compressed)
if decompressed != token {
return token
}
if strings.Count(decompressed, ".") != 2 {
return token
}
return compressed
}
// decompressToken decompresses a previously compressed token string.
// It decodes the base64 data, validates gzip headers, and decompresses safely
// with size limits to prevent compression bombs.
// Parameters:
// - compressed: The base64-encoded compressed token string.
//
// Returns:
// - The decompressed original string, or the input string if decompression fails.
func decompressToken(compressed string) string {
return decompressTokenInternal(compressed)
}
// decompressTokenInternal is the internal decompression function.
// Separated internal function for integrity verification during compression.
// It performs the actual decompression logic with proper resource management.
// Parameters:
// - compressed: The compressed token string to decompress.
//
// Returns:
// - The decompressed token or the original string if decompression fails.
func decompressTokenInternal(compressed string) string {
if compressed == "" {
return compressed
}
if len(compressed) > 100*1024 {
return compressed
}
data, err := base64.StdEncoding.DecodeString(compressed)
if err != nil {
return compressed
}
if len(data) == 0 {
return compressed
}
if len(data) < 2 || data[0] != 0x1f || data[1] != 0x8b {
return compressed
}
pools := GetGlobalMemoryPools()
readerBuf := pools.GetHTTPResponseBuffer()
defer pools.PutHTTPResponseBuffer(readerBuf)
gz, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return compressed
}
defer func() {
if closeErr := gz.Close(); closeErr != nil {
_ = closeErr
}
}()
limitedReader := io.LimitReader(gz, 500*1024)
if cap(readerBuf) >= 512*1024 {
readerBuf = readerBuf[:cap(readerBuf)]
n, err := limitedReader.Read(readerBuf)
if err != nil && err != io.EOF {
return compressed
}
decompressed := readerBuf[:n]
return string(decompressed)
}
decompressed, err := io.ReadAll(limitedReader)
if err != nil {
return compressed
}
if len(decompressed) == 0 {
return compressed
}
decompressedStr := string(decompressed)
if decompressedStr != "" && strings.Count(decompressedStr, ".") != 2 {
return compressed
}
return decompressedStr
}
// SessionManager manages OIDC session state and cookie-based storage.
// It provides secure session management with support for token compression,
// chunked storage for large tokens, session pooling for performance,
// session object reuse and supports both HTTP and HTTPS schemes.
type SessionManager struct {
sessionPool sync.Pool
store sessions.Store
logger *Logger
chunkManager *ChunkManager
cookieDomain string
cleanupMutex sync.RWMutex
forceHTTPS bool
cleanupDone bool
ctx context.Context
cancel context.CancelFunc
memoryMonitor *TaskMemoryMonitor
activeSessions int64
poolHits int64
poolMisses int64
shutdownOnce sync.Once
}
// NewSessionManager creates a new SessionManager instance with secure defaults.
// It initializes the cookie store with encryption, sets up session pooling,
// and configures chunk management for large tokens.
// Parameters:
// - encryptionKey: The key for encrypting session cookies (minimum 32 bytes).
// - forceHTTPS: Whether to force HTTPS-only cookies regardless of request scheme.
// - cookieDomain: The domain for session cookies (empty for auto-detection).
// - logger: Logger instance for debug and error logging.
//
// Returns:
// - The configured SessionManager instance.
// - An error if the encryption key does not meet minimum length requirements.
func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain string, logger *Logger) (*SessionManager, error) {
if len(encryptionKey) < minEncryptionKeyLength {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
ctx, cancel := context.WithCancel(context.Background())
sm := &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
cookieDomain: cookieDomain,
logger: logger,
chunkManager: NewChunkManager(logger),
ctx: ctx,
cancel: cancel,
}
// Initialize global memory monitoring (singleton)
sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger)
// Start memory monitoring every 30 seconds (will skip if already started)
if err := sm.memoryMonitor.Start(30 * time.Second); err != nil {
logger.Infof("Failed to start memory monitoring: %v", err)
}
sm.sessionPool.New = func() interface{} {
atomic.AddInt64(&sm.poolMisses, 1)
sd := &SessionData{
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
idTokenChunks: make(map[int]*sessions.Session),
refreshMutex: sync.Mutex{},
sessionMutex: sync.RWMutex{},
dirty: false,
inUse: false,
}
sd.Reset()
return sd
}
// Start background cleanup routine
go sm.backgroundCleanup()
return sm, nil
}
// Shutdown gracefully shuts down the SessionManager and all its background tasks
func (sm *SessionManager) Shutdown() error {
var shutdownErr error
sm.shutdownOnce.Do(func() {
if sm.logger != nil {
sm.logger.Info("SessionManager shutdown initiated")
}
// Cancel context to stop all background operations
if sm.cancel != nil {
sm.cancel()
}
// Stop memory monitor
if sm.memoryMonitor != nil {
sm.memoryMonitor.Stop()
}
// Stop chunk manager
if sm.chunkManager != nil {
sm.chunkManager.Shutdown()
}
// Force garbage collection to help cleanup
runtime.GC()
if sm.logger != nil {
sm.logger.Info("SessionManager shutdown completed")
}
})
return shutdownErr
}
// backgroundCleanup runs periodic cleanup tasks for session management
func (sm *SessionManager) backgroundCleanup() {
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
defer ticker.Stop()
for {
select {
case <-sm.ctx.Done():
if sm.logger != nil {
sm.logger.Debug("Background cleanup routine terminated")
}
return
case <-ticker.C:
sm.performCleanupCycle()
}
}
}
// performCleanupCycle executes a complete cleanup cycle
func (sm *SessionManager) performCleanupCycle() {
if sm.logger != nil {
sm.logger.Debug("Starting background cleanup cycle")
}
startTime := time.Now()
// Run periodic chunk cleanup
sm.PeriodicChunkCleanup()
// Clean up session pool by forcing GC on old sessions
sm.cleanupSessionPool()
// Force garbage collection if memory usage is high
var m runtime.MemStats
runtime.ReadMemStats(&m)
if m.HeapAlloc > 50*1024*1024 { // 50MB threshold
runtime.GC()
if sm.logger != nil {
sm.logger.Debug("Forced garbage collection due to high memory usage")
}
}
duration := time.Since(startTime)
if sm.logger != nil && sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
sm.logger.Debugf("Cleanup cycle completed in %v", duration)
}
}
// cleanupSessionPool performs cleanup on the session pool
func (sm *SessionManager) cleanupSessionPool() {
cleaned := 0
const maxCleanup = 20 // Limit cleanup per cycle to avoid performance impact
for i := 0; i < maxCleanup; i++ {
select {
case <-sm.ctx.Done():
return
default:
}
if poolSession := sm.sessionPool.Get(); poolSession != nil {
sessionData, ok := poolSession.(*SessionData)
if ok && sessionData != nil && !sessionData.inUse {
sessionData.Reset()
cleaned++
}
sm.sessionPool.Put(poolSession)
} else {
break // Pool is empty
}
}
if cleaned > 0 && sm.logger != nil && sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
sm.logger.Debugf("Cleaned %d session pool objects", cleaned)
}
}
// GetSessionStats returns statistics about session management
func (sm *SessionManager) GetSessionStats() map[string]interface{} {
stats := make(map[string]interface{})
stats["active_sessions"] = atomic.LoadInt64(&sm.activeSessions)
stats["pool_hits"] = atomic.LoadInt64(&sm.poolHits)
stats["pool_misses"] = atomic.LoadInt64(&sm.poolMisses)
if sm.memoryMonitor != nil {
if currentStats, err := sm.memoryMonitor.GetCurrentStats(); err == nil {
stats["goroutines"] = currentStats.Goroutines
stats["heap_alloc"] = currentStats.HeapAlloc
stats["num_gc"] = currentStats.NumGC
}
}
return stats
}
// PeriodicChunkCleanup performs comprehensive session maintenance and cleanup.
// It cleans up orphaned token chunks, expired sessions, and unused pool objects.
// This helps maintain performance and prevent cookie accumulation in client browsers.
func (sm *SessionManager) PeriodicChunkCleanup() {
if sm == nil || sm.logger == nil {
return
}
// Check if context is cancelled or we're in test mode to prevent logging after test completion
if sm.ctx == nil || sm.ctx.Err() != nil || isTestMode() {
return // Skip logging if context is cancelled or in test mode
}
sm.logger.Debug("Starting comprehensive session cleanup cycle")
cleanupStart := time.Now()
var orphanedChunks, expiredSessions, cleanupErrors int
if sm.store != nil {
if cookieStore, ok := sm.store.(*sessions.CookieStore); ok {
// Check context again before logging
if sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
sm.logger.Debug("Running session store cleanup")
}
_ = cookieStore
}
}
// Cleanup expired sessions in chunk manager to prevent memory leaks
if sm.chunkManager != nil {
sm.chunkManager.CleanupExpiredSessions()
}
poolCleaned := 0
for i := 0; i < 10; i++ {
if poolSession := sm.sessionPool.Get(); poolSession != nil {
sessionData, ok := poolSession.(*SessionData)
if ok && sessionData != nil && !sessionData.inUse {
sessionData.Reset()
poolCleaned++
}
sm.sessionPool.Put(poolSession)
}
}
// Check context before final logging
if sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
cleanupDuration := time.Since(cleanupStart)
sm.logger.Debugf("Session cleanup completed in %v: pool_cleaned=%d, orphaned_chunks=%d, expired_sessions=%d, errors=%d",
cleanupDuration, poolCleaned, orphanedChunks, expiredSessions, cleanupErrors)
}
}
// ValidateSessionHealth performs comprehensive validation of session integrity.
// It checks authentication state, validates token formats, and detects
// potential tampering or corruption in session data.
// Parameters:
// - sessionData: The session data to validate.
//
// Returns:
// - An error describing any validation failures, nil if session is healthy.
func (sm *SessionManager) ValidateSessionHealth(sessionData *SessionData) error {
if sessionData == nil {
return fmt.Errorf("session data is nil")
}
if !sessionData.GetAuthenticated() {
return fmt.Errorf("session is not authenticated or has expired")
}
accessToken := sessionData.GetAccessToken()
refreshToken := sessionData.GetRefreshToken()
idToken := sessionData.GetIDToken()
if accessToken != "" {
if err := sm.validateTokenFormat(accessToken, "access_token"); err != nil {
return fmt.Errorf("access token validation failed: %w", err)
}
}
if refreshToken != "" {
if err := sm.validateTokenFormat(refreshToken, "refresh_token"); err != nil {
return fmt.Errorf("refresh token validation failed: %w", err)
}
}
if idToken != "" {
if err := sm.validateTokenFormat(idToken, "id_token"); err != nil {
return fmt.Errorf("ID token validation failed: %w", err)
}
}
if err := sm.detectSessionTampering(sessionData); err != nil {
return fmt.Errorf("session tampering detected: %w", err)
}
return nil
}
// validateTokenFormat validates the structure and format of authentication tokens.
// It checks for corruption markers, validates JWT structure if applicable,
// and ensures tokens meet format requirements.
// Parameters:
// - token: The token string to validate.
// - tokenType: The type of token being validated (for error messages).
//
// Returns:
// - An error if the token has invalid structure or exceeds size limits.
func (sm *SessionManager) validateTokenFormat(token, tokenType string) error {
if token == "" {
return nil
}
if isCorruptionMarker(token) {
return fmt.Errorf("%s contains corruption marker", tokenType)
}
if strings.Count(token, ".") == 2 {
parts := strings.Split(token, ".")
for i, part := range parts {
if part == "" {
return fmt.Errorf("%s has empty part %d in JWT format", tokenType, i)
}
if strings.ContainsAny(part, "+/=") && !strings.ContainsAny(part, "-_") {
sm.logger.Debugf("Token %s part %d uses base64 instead of base64url encoding", tokenType, i)
}
}
}
return nil
}
// detectSessionTampering checks for indicators of session tampering.
// It examines session values for path traversal attempts, XSS payloads,
// and suspicious data patterns that might indicate malicious modification.
// Parameters:
// - sessionData: The session data to examine for tampering.
//
// Returns:
// - An error if tampering is detected, nil if session appears safe.
func (sm *SessionManager) detectSessionTampering(sessionData *SessionData) error {
if sessionData.mainSession == nil {
return fmt.Errorf("main session is missing")
}
for key, value := range sessionData.mainSession.Values {
if str, ok := value.(string); ok {
if strings.Contains(str, "../") || strings.Contains(str, "..\\") {
return fmt.Errorf("potential path traversal attempt in session key %v", key)
}
if strings.Contains(str, "<script") || strings.Contains(str, "javascript:") {
return fmt.Errorf("potential XSS attempt in session key %v", key)
}
if len(str) > 10000 {
return fmt.Errorf("suspiciously long session value for key %v (length: %d)", key, len(str))
}
}
}
return nil
}
// GetSessionMetrics returns metrics about session management for monitoring purposes.
// It provides information about session configuration, security settings,
// and internal state for debugging and monitoring.
// Returns:
// - A map containing session metrics and configuration information.
func (sm *SessionManager) GetSessionMetrics() map[string]interface{} {
metrics := make(map[string]interface{})
metrics["session_manager_type"] = "CookieStore"
metrics["force_https"] = sm.forceHTTPS
metrics["absolute_timeout_hours"] = absoluteSessionTimeout.Hours()
metrics["max_cookie_size"] = maxCookieSize
metrics["max_browser_cookie_size"] = maxBrowserCookieSize
if cookieStore, ok := sm.store.(*sessions.CookieStore); ok && len(cookieStore.Codecs) > 0 {
metrics["has_encryption"] = true
metrics["codec_count"] = len(cookieStore.Codecs)
} else {
metrics["has_encryption"] = false
}
metrics["pool_implementation"] = "sync.Pool"
return metrics
}
// EnhanceSessionSecurity applies additional security measures to session options.
// It configures secure cookies, domain detection, SameSite policies, and
// adapts security settings based on request context and client characteristics.
// Parameters:
// - options: The base session options to enhance (can be nil).
// - r: The HTTP request context for security decisions.
//
// Returns:
// - Enhanced sessions.Options with additional security measures.
func (sm *SessionManager) EnhanceSessionSecurity(options *sessions.Options, r *http.Request) *sessions.Options {
if options == nil {
options = &sessions.Options{}
}
if r != nil {
userAgent := r.Header.Get("User-Agent")
if userAgent == "" {
sm.logger.Debugf("Request from %s missing User-Agent header", r.RemoteAddr)
options.MaxAge = int((absoluteSessionTimeout / 2).Seconds())
}
if r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS {
options.Secure = true
}
if r.Header.Get("X-Requested-With") == "XMLHttpRequest" {
options.SameSite = http.SameSiteStrictMode
}
}
options.HttpOnly = true
if sm.cookieDomain != "" {
options.Domain = sm.cookieDomain
sm.logger.Debugf("Using configured cookie domain: %s", sm.cookieDomain)
} else if options.Domain == "" && r != nil {
host := r.Host
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
if host != "" && !strings.Contains(host, "localhost") && !strings.Contains(host, "127.0.0.1") {
if colonIndex := strings.Index(host, ":"); colonIndex != -1 {
host = host[:colonIndex]
}
options.Domain = host
sm.logger.Debugf("Auto-detected cookie domain: %s", host)
}
}
return options
}
// getSessionOptions creates base session options with security settings.
// It configures cookie security, lifetime, path, and domain settings
// based on the HTTPS status and manager configuration.
// Parameters:
// - isSecure: Whether the request is over HTTPS or should be treated as secure.
//
// Returns:
// - A pointer to a configured sessions.Options struct.
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
baseOptions := &sessions.Options{
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: int(absoluteSessionTimeout.Seconds()),
Path: "/",
Domain: sm.cookieDomain,
}
return baseOptions
}
// CleanupOldCookies removes stale session cookies from the client browser.
// This method handles cleanup of cookies that may exist with different domain
// configurations, ensuring clean state when domain settings change.
// It removes cookies with various domain variations to ensure cleanup after configuration changes.
// Parameters:
// - w: The HTTP response writer for setting cookie deletion headers.
// - r: The HTTP request containing cookies to examine and clean up.
func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Request) {
cookies := r.Cookies()
currentDomain := sm.cookieDomain
host := r.Host
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
if colonIndex := strings.Index(host, ":"); colonIndex != -1 {
host = host[:colonIndex]
}
// This ensures we clean up cookies from various possible domains
var domainsToClean []string
if host != "" && !strings.Contains(host, "localhost") && !strings.Contains(host, "127.0.0.1") {
domainsToClean = append(domainsToClean,
host,
"."+host,
)
parts := strings.Split(host, ".")
if len(parts) > 2 {
parentDomain := strings.Join(parts[len(parts)-2:], ".")
domainsToClean = append(domainsToClean,
parentDomain,
"."+parentDomain,
)
}
}
processedCookies := make(map[string]bool)
for _, cookie := range cookies {
if strings.HasPrefix(cookie.Name, mainCookieName) ||
strings.HasPrefix(cookie.Name, accessTokenCookie) ||
strings.HasPrefix(cookie.Name, refreshTokenCookie) ||
strings.HasPrefix(cookie.Name, "_oidc_raczylo_id") ||
strings.HasPrefix(cookie.Name, "access_token_chunk_") ||
strings.HasPrefix(cookie.Name, "refresh_token_chunk_") {
processedCookies[cookie.Name] = true
sm.cleanupMutex.RLock()
shouldCleanup := currentDomain != "" && !sm.cleanupDone
sm.cleanupMutex.RUnlock()
if shouldCleanup {
for _, domain := range domainsToClean {
if domain == currentDomain || domain == "."+currentDomain || "."+domain == currentDomain {
continue
}
deleteCookie := &http.Cookie{
Name: cookie.Name,
Value: "",
Path: "/",
Domain: domain,
MaxAge: -1,
HttpOnly: true,
Secure: r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(w, deleteCookie)
sm.logger.Debugf("Attempting to clean up cookie %s with domain %s", cookie.Name, domain)
}
}
}
}
if len(processedCookies) > 0 {
sm.cleanupMutex.Lock()
if !sm.cleanupDone {
sm.cleanupDone = true
}
sm.cleanupMutex.Unlock()
}
}
// GetSession retrieves or creates session data from the HTTP request.
// It loads the main session and all token chunk sessions, performing validation
// and timeout checks. The returned session must be explicitly returned to the pool
// by calling returnToPoolSafely() to prevent memory leaks.
// MEMORY LEAK FIX: Session is NOT returned to pool here - caller must call ReturnToPool() when done.
// Parameters:
// - r: The HTTP request containing session cookies.
//
// Returns:
// - The loaded SessionData instance.
// - An error if session loading or validation fails.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
sessionData := sm.sessionPool.Get().(*SessionData)
atomic.AddInt64(&sm.poolHits, 1)
atomic.AddInt64(&sm.activeSessions, 1)
sessionData.inUse = true
sessionData.request = r
sessionData.dirty = false
handleError := func(err error, message string) (*SessionData, error) {
if sessionData != nil {
sessionData.inUse = false
sessionData.Reset()
sm.sessionPool.Put(sessionData)
atomic.AddInt64(&sm.activeSessions, -1)
}
return nil, fmt.Errorf("%s: %w", message, err)
}
var err error
sessionData.mainSession, err = sm.store.Get(r, mainCookieName)
if err != nil {
return handleError(err, "failed to get main session")
}
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
sessionData.Clear(r, nil)
return handleError(fmt.Errorf("session timeout"), "session expired")
}
}
sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)
if err != nil {
return handleError(err, "failed to get access token session")
}
sessionData.refreshSession, err = sm.store.Get(r, refreshTokenCookie)
if err != nil {
return handleError(err, "failed to get refresh token session")
}
sessionData.idTokenSession, err = sm.store.Get(r, idTokenCookie)
if err != nil {
return handleError(err, "failed to get ID token session")
}
for k := range sessionData.accessTokenChunks {
delete(sessionData.accessTokenChunks, k)
}
for k := range sessionData.refreshTokenChunks {
delete(sessionData.refreshTokenChunks, k)
}
for k := range sessionData.idTokenChunks {
delete(sessionData.idTokenChunks, k)
}
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
sm.getTokenChunkSessions(r, idTokenCookie, sessionData.idTokenChunks)
return sessionData, nil
}
// getTokenChunkSessions loads all available token chunk sessions for a given token type.
// It iterates through numbered chunk sessions until no more are found,
// populating the provided chunks map with the loaded sessions.
// Parameters:
// - r: The HTTP request containing chunk cookies.
// - baseName: The base cookie name for the token type (e.g., "_oidc_raczylo_a").
// - chunks: The map (typically SessionData.accessTokenChunks or SessionData.refreshTokenChunks)
// to populate with the found session chunks.
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
break
}
chunks[i] = session
}
}
// SessionData represents a user's authentication session with comprehensive token management.
// It handles main session data and supports large tokens that need to be
// split across multiple cookies due to browser size limitations.
type SessionData struct {
manager *SessionManager
request *http.Request
mainSession *sessions.Session
accessSession *sessions.Session
refreshSession *sessions.Session
idTokenSession *sessions.Session
accessTokenChunks map[int]*sessions.Session
refreshTokenChunks map[int]*sessions.Session
idTokenChunks map[int]*sessions.Session
refreshMutex sync.Mutex
sessionMutex sync.RWMutex
dirty bool
inUse bool
}
// IsDirty returns true if the session data has been modified since it was last loaded or saved.
// This is used to optimize session saves by only writing when necessary.
// Returns:
// - true if the session has pending changes, false otherwise.
func (sd *SessionData) IsDirty() bool {
return sd.dirty
}
// MarkDirty marks the session as having pending changes that need to be saved.
// This is used when session data hasn't changed in content but should still
// trigger a session save (e.g., to ensure the cookie is re-issued).
func (sd *SessionData) MarkDirty() {
sd.dirty = true
}
// Save persists all session data including main session and token chunks.
// It applies security options, saves all session components, and handles
// errors gracefully by continuing to save other components even if one fails.
// Parameters:
// - r: The HTTP request context for security option configuration.
// - w: The HTTP response writer for setting session cookies.
//
// Returns:
// - An error if saving any of the session components fails.
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sd.manager.forceHTTPS
options := sd.manager.getSessionOptions(isSecure)
options = sd.manager.EnhanceSessionSecurity(options, r)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
var firstErr error
saveOrLogError := func(s *sessions.Session, name string) {
if s == nil {
sd.manager.logger.Errorf("Attempted to save nil session: %s", name)
if firstErr == nil {
firstErr = fmt.Errorf("attempted to save nil session: %s", name)
}
return
}
if err := s.Save(r, w); err != nil {
errMsg := fmt.Errorf("failed to save %s session: %w", name, err)
sd.manager.logger.Error("%s", errMsg.Error())
if firstErr == nil {
firstErr = errMsg
}
}
}
saveOrLogError(sd.mainSession, "main")
saveOrLogError(sd.accessSession, "access token")
saveOrLogError(sd.refreshSession, "refresh token")
saveOrLogError(sd.idTokenSession, "ID token")
for i, sessionChunk := range sd.accessTokenChunks {
sessionChunk.Options = options
saveOrLogError(sessionChunk, fmt.Sprintf("access token chunk %d", i))
}
for i, sessionChunk := range sd.refreshTokenChunks {
sessionChunk.Options = options
saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i))
}
for i, sessionChunk := range sd.idTokenChunks {
sessionChunk.Options = options
saveOrLogError(sessionChunk, fmt.Sprintf("ID token chunk %d", i))
}
if firstErr == nil {
sd.dirty = false
}
return firstErr
}
// clearSessionValues removes all values from a session and optionally expires it.
// This is used during session cleanup and logout operations.
// Parameters:
// - session: The session to clear values from.
// - expire: If true, sets MaxAge to -1 to expire the cookie.
func clearSessionValues(session *sessions.Session, expire bool) {
if session == nil {
return
}
for k := range session.Values {
delete(session.Values, k)
}
if expire {
session.Options.MaxAge = -1
}
}
// clearAllSessionData clears all session data including main session and token chunks.
// It removes all session values and optionally expires all associated cookies.
// Parameters:
// - r: The HTTP request context (used for chunk cleanup).
// - expire: Whether to expire the cookies (set MaxAge to -1).
func (sd *SessionData) clearAllSessionData(r *http.Request, expire bool) {
clearSessionValues(sd.mainSession, expire)
clearSessionValues(sd.accessSession, expire)
clearSessionValues(sd.refreshSession, expire)
clearSessionValues(sd.idTokenSession, expire)
if expire && r != nil {
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
sd.clearTokenChunks(r, sd.idTokenChunks)
} else {
for k := range sd.accessTokenChunks {
delete(sd.accessTokenChunks, k)
}
for k := range sd.refreshTokenChunks {
delete(sd.refreshTokenChunks, k)
}
for k := range sd.idTokenChunks {
delete(sd.idTokenChunks, k)
}
}
if expire {
sd.dirty = true
}
}
// Clear completely clears all session data and safely returns the session to the pool.
// It removes all authentication data, expires cookies, and handles panic recovery.
// This method ensures the SessionData object is always returned to the pool.
// Parameters:
// - r: The HTTP request context.
// - w: The HTTP response writer for cookie expiration (can be nil).
//
// Returns:
// - An error if session saving fails during cleanup.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
defer func() {
sd.returnToPoolSafely()
}()
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
sd.clearAllSessionData(r, true)
// This is primarily for testing - in production w will often be nil
var err error
if w != nil {
if r != nil && r.Header.Get("X-Test-Error") == "true" {
// Return a test error without trying to save problematic data
err = fmt.Errorf("test error triggered by X-Test-Error header")
} else {
err = sd.Save(r, w)
}
}
sd.request = nil
return err
}
// returnToPoolSafely safely returns the session to the object pool.
// Add thread-safe helper method to return session to pool.
// It ensures the session is marked as not in use and properly reset before pooling.
func (sd *SessionData) returnToPoolSafely() {
if sd != nil && sd.manager != nil {
if sd.inUse {
sd.inUse = false
sd.Reset()
sd.manager.sessionPool.Put(sd)
atomic.AddInt64(&sd.manager.activeSessions, -1)
}
}
}
// clearTokenChunks clears and expires all token chunk sessions.
// This is used during logout and session cleanup to ensure
// all token data is properly removed from the client.
// Parameters:
// - r: The HTTP request context.
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
clearSessionValues(session, true)
}
}
// GetAuthenticated returns whether the user is currently authenticated.
// It checks both the authentication flag and session timeout.
// Returns:
// - true if the user is authenticated and the session is not expired.
// - false otherwise.
func (sd *SessionData) GetAuthenticated() bool {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
return sd.getAuthenticatedUnsafe()
}
// getAuthenticatedUnsafe checks authentication status without acquiring locks.
// Used when the mutex is already held to avoid deadlocks.
// It validates both the authentication flag and session creation time.
// Returns:
// - true if authenticated and not expired, false otherwise.
func (sd *SessionData) getAuthenticatedUnsafe() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
if !auth {
return false
}
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
if !ok {
return false
}
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
}
// SetAuthenticated sets the authentication status and manages session security.
// When setting to true, it generates a new secure session ID and updates timestamps.
// This prevents session fixation attacks by regenerating the session identifier.
// Parameters:
// - value: The authentication status to set.
//
// Returns:
// - An error if generating a new session ID fails when setting value to true.
func (sd *SessionData) SetAuthenticated(value bool) error {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
currentAuth := sd.getAuthenticatedUnsafe()
changed := false
if currentAuth != value {
changed = true
}
if value {
id, err := generateSecureRandomString(64)
if err != nil {
return fmt.Errorf("failed to generate secure session id: %w", err)
}
maxRetries := 5
for retry := 0; retry < maxRetries; retry++ {
if sd.mainSession.ID != id {
break
}
id, err = generateSecureRandomString(64)
if err != nil {
return fmt.Errorf("failed to generate secure session id on retry %d: %w", retry, err)
}
}
if sd.mainSession.ID != id {
changed = true
}
sd.mainSession.ID = id
newCreationTime := time.Now().Unix()
if oldTime, ok := sd.mainSession.Values["created_at"].(int64); !ok || oldTime != newCreationTime {
changed = true
}
sd.mainSession.Values["created_at"] = newCreationTime
if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value {
changed = true
}
} else {
if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value {
changed = true
}
}
sd.mainSession.Values["authenticated"] = value
if changed {
sd.dirty = true
}
return nil
}
// resetSession prepares a session for reuse by clearing its state.
// This is specifically for pool reuse preparation to ensure
// no data leaks between different user sessions.
// Parameters:
// - session: The session to reset for reuse.
func resetSession(session *sessions.Session) {
if session == nil {
return
}
clearSessionValues(session, false)
session.ID = ""
session.IsNew = true
}
// Reset clears all session data and prepares the SessionData for reuse.
// It ensures no authentication data persists when the object is reused
// between different users/sessions.
func (sd *SessionData) Reset() {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
sd.clearAllSessionData(nil, false)
resetSession(sd.mainSession)
resetSession(sd.accessSession)
resetSession(sd.refreshSession)
resetSession(sd.idTokenSession)
// Clear redirect count to prevent leaking between sessions
if sd.mainSession != nil && sd.mainSession.Values != nil {
delete(sd.mainSession.Values, "redirect_count")
}
sd.dirty = false
sd.inUse = false
sd.request = nil
// Reset the refresh mutex to ensure clean state
// Note: We don't need to lock it since sessionMutex is already held
// and this session is not in use by any request
}
// ReturnToPool manually returns the session to the object pool.
// This is used in cleanup paths where Clear() is not called, to prevent memory leaks.
// It only returns the session if it's not currently in use.
func (sd *SessionData) ReturnToPool() {
if sd != nil && sd.manager != nil {
if !sd.inUse {
sd.Reset()
sd.manager.sessionPool.Put(sd)
atomic.AddInt64(&sd.manager.activeSessions, -1)
}
}
}
// GetAccessToken retrieves the user's access token from session storage.
// It handles both single-cookie storage and chunked storage for large tokens,
// with automatic decompression if the token was compressed.
// Returns:
// - The complete, decompressed access token string, or an empty string if not found.
func (sd *SessionData) GetAccessToken() string {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
return sd.getAccessTokenUnsafe()
}
// getAccessTokenUnsafe retrieves the access token without acquiring locks.
// Enhanced token retrieval with comprehensive integrity checks and recovery mechanisms.
// Used when the session mutex is already held to prevent deadlocks.
// Returns:
// - The complete access token string or empty string on error.
func (sd *SessionData) getAccessTokenUnsafe() string {
token, _ := sd.accessSession.Values["token"].(string)
compressed, _ := sd.accessSession.Values["compressed"].(bool)
// Debug: Check if manager/chunkManager is nil
if sd.manager == nil || sd.manager.chunkManager == nil {
// Direct return if no chunk manager (test scenario)
return token
}
result := sd.manager.chunkManager.GetToken(
token,
compressed,
sd.accessTokenChunks,
AccessTokenConfig,
)
if result.Error != nil {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debugf("ChunkManager.GetToken error: %v", result.Error)
}
return ""
}
return result.Token
}
// SetAccessToken stores an access token with automatic compression and chunking.
// It validates token format, compresses if beneficial, and splits into chunks
// if the token exceeds cookie size limits. Includes integrity verification.
// Parameters:
// - token: The access token string to store.
func (sd *SessionData) SetAccessToken(token string) {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
if token != "" {
dotCount := strings.Count(token, ".")
if dotCount == 1 {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debug("Invalid token format during storage (dots: %d) - rejecting", dotCount)
}
return
}
if dotCount == 0 && len(token) < 20 {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debug("Token too short for opaque token (length: %d) - rejecting", len(token))
}
return
}
}
currentAccessToken := sd.getAccessTokenUnsafe()
if currentAccessToken == token {
return
}
sd.dirty = true
// Debug: Check if accessSession is properly initialized
if sd.accessSession == nil {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Errorf("CRITICAL: accessSession is nil when trying to store token")
}
return
}
if sd.request != nil {
sd.expireAccessTokenChunksEnhanced(nil)
}
for k := range sd.accessTokenChunks {
delete(sd.accessTokenChunks, k)
}
if token == "" {
if sd.accessSession != nil {
sd.accessSession.Values["token"] = ""
sd.accessSession.Values["compressed"] = false
}
return
}
compressed := compressToken(token)
// Debug for test
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debugf("Token compression: original %d bytes, compressed %d bytes", len(token), len(compressed))
}
if len(compressed) > 100*1024 {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Info("Access token too large after compression (%d bytes) - storing uncompressed", len(compressed))
}
return
}
if compressed != token {
testDecompressed := decompressToken(compressed)
if testDecompressed != token {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debug("Access token compression verification failed - storing uncompressed")
}
compressed = token
}
}
if len(compressed) <= maxCookieSize {
if sd.accessSession != nil {
sd.accessSession.Values["token"] = compressed
sd.accessSession.Values["compressed"] = (compressed != token)
// Debug for test
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debugf("Stored token in session: compressed=%v, token_len=%d",
compressed != token, len(compressed))
}
}
} else {
if sd.accessSession != nil {
sd.accessSession.Values["token"] = ""
sd.accessSession.Values["compressed"] = (compressed != token)
}
chunks := splitIntoChunks(compressed, maxCookieSize)
if len(chunks) == 0 {
sd.manager.logger.Error("Failed to create chunks for access token")
return
}
if len(chunks) > 50 {
sd.manager.logger.Info("Too many chunks (%d) for access token", len(chunks))
return
}
testReassembled := strings.Join(chunks, "")
if testReassembled != compressed {
sd.manager.logger.Debug("Access token chunk reassembly test failed")
return
}
for i, chunkData := range chunks {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
if sd.request == nil {
sd.manager.logger.Error("SetAccessToken: sd.request is nil, cannot create chunk session %s", sessionName)
return
}
if chunkData == "" {
sd.manager.logger.Debug("Empty chunk data at index %d", i)
return
}
if len(chunkData) > maxCookieSize {
sd.manager.logger.Info("Chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
return
}
if !validateChunkSize(chunkData) {
sd.manager.logger.Errorf("CRITICAL: Chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
return
}
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil {
sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err)
return
}
session.Values["token_chunk"] = chunkData
session.Values["compressed"] = (compressed != token)
session.Values["chunk_created_at"] = time.Now().Unix()
sd.accessTokenChunks[i] = session
}
sd.manager.logger.Debugf("SUCCESS: Stored access token in %d chunks", len(chunks))
}
}
// GetRefreshToken retrieves the user's refresh token from session storage.
// It handles both single-cookie storage and chunked storage for large tokens,
// with automatic decompression if the token was compressed.
// Returns:
// - The complete, decompressed refresh token string, or an empty string if not found.
func (sd *SessionData) GetRefreshToken() string {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
token, _ := sd.refreshSession.Values["token"].(string)
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
result := sd.manager.chunkManager.GetToken(
token,
compressed,
sd.refreshTokenChunks,
RefreshTokenConfig,
)
if result.Error != nil {
return ""
}
return result.Token
}
// SetRefreshToken stores a refresh token with automatic compression and chunking.
// It validates token size, compresses if beneficial, and splits into chunks
// if needed. Includes comprehensive error checking and integrity verification.
// Parameters:
// - token: The refresh token string to store.
func (sd *SessionData) SetRefreshToken(token string) {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
if len(token) > 50*1024 {
sd.manager.logger.Errorf("CRITICAL: Refresh token too large (%d bytes) - possible corruption, rejecting", len(token))
return
}
// Get current refresh token without mutex to avoid deadlock since we already hold the lock
var currentRefreshToken string
sessionToken, _ := sd.refreshSession.Values["token"].(string)
if sessionToken != "" {
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
if compressed {
decompressed := decompressToken(sessionToken)
currentRefreshToken = decompressed
} else {
currentRefreshToken = sessionToken
}
} else if len(sd.refreshTokenChunks) > 0 {
// Simplified chunked token retrieval for deadlock prevention
var chunks []string
for i := 0; i < len(sd.refreshTokenChunks); i++ {
if session, ok := sd.refreshTokenChunks[i]; ok {
if chunk, chunkOk := session.Values["token_chunk"].(string); chunkOk && chunk != "" {
chunks = append(chunks, chunk)
}
}
}
if len(chunks) == len(sd.refreshTokenChunks) {
reassembled := strings.Join(chunks, "")
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
if compressed {
currentRefreshToken = decompressToken(reassembled)
} else {
currentRefreshToken = reassembled
}
}
}
if currentRefreshToken == token {
return
}
sd.dirty = true
if sd.request != nil {
sd.expireRefreshTokenChunksEnhanced(nil)
}
for k := range sd.refreshTokenChunks {
delete(sd.refreshTokenChunks, k)
}
if token == "" {
sd.refreshSession.Values["token"] = ""
sd.refreshSession.Values["compressed"] = false
return
}
compressed := compressToken(token)
if compressed != token {
testDecompressed := decompressToken(compressed)
if testDecompressed != token {
sd.manager.logger.Errorf("CRITICAL: Refresh token compression verification failed - storing uncompressed")
compressed = token
}
}
if len(compressed) <= maxCookieSize {
sd.refreshSession.Values["token"] = compressed
sd.refreshSession.Values["compressed"] = (compressed != token)
sd.refreshSession.Values["issued_at"] = time.Now().Unix()
} else {
sd.refreshSession.Values["token"] = ""
sd.refreshSession.Values["compressed"] = (compressed != token)
sd.refreshSession.Values["issued_at"] = time.Now().Unix()
chunks := splitIntoChunks(compressed, maxCookieSize)
if len(chunks) == 0 {
sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for refresh token")
return
}
if len(chunks) > 50 {
sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for refresh token - possible corruption", len(chunks))
return
}
testReassembled := strings.Join(chunks, "")
if testReassembled != compressed {
sd.manager.logger.Errorf("CRITICAL: Refresh token chunk reassembly test failed")
return
}
for i, chunkData := range chunks {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
if sd.request == nil {
sd.manager.logger.Errorf("CRITICAL: SetRefreshToken: sd.request is nil, cannot create chunk session %s", sessionName)
return
}
if chunkData == "" {
sd.manager.logger.Errorf("CRITICAL: Empty refresh token chunk data at index %d", i)
return
}
if len(chunkData) > maxCookieSize {
sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
return
}
if !validateChunkSize(chunkData) {
sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
return
}
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil {
sd.manager.logger.Errorf("CRITICAL: Failed to get refresh token chunk session %s: %v", sessionName, err)
return
}
session.Values["token_chunk"] = chunkData
session.Values["compressed"] = (compressed != token)
session.Values["chunk_created_at"] = time.Now().Unix()
sd.refreshTokenChunks[i] = session
}
sd.manager.logger.Debugf("SUCCESS: Stored refresh token in %d chunks", len(chunks))
}
}
// GetRefreshTokenIssuedAt retrieves the timestamp when the refresh token was issued/stored.
// Returns the time when the current refresh token was obtained, or zero time if not available.
func (sd *SessionData) GetRefreshTokenIssuedAt() time.Time {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
if issuedAtUnix, ok := sd.refreshSession.Values["issued_at"].(int64); ok {
return time.Unix(issuedAtUnix, 0)
}
// For chunked tokens, check the first chunk for timestamp
if len(sd.refreshTokenChunks) > 0 {
if session, exists := sd.refreshTokenChunks[0]; exists {
if chunkCreatedAt, ok := session.Values["chunk_created_at"].(int64); ok {
return time.Unix(chunkCreatedAt, 0)
}
}
}
return time.Time{}
}
// expireAccessTokenChunksEnhanced expires all access token chunks and detects orphaned chunks.
// It searches for all existing chunks, identifies orphaned or expired chunks,
// and properly expires them to prevent cookie accumulation.
// Parameters:
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
func (sd *SessionData) expireAccessTokenChunksEnhanced(w http.ResponseWriter) {
const maxChunkSearchLimit = 50
orphanedChunks := 0
for i := 0; i < maxChunkSearchLimit; i++ {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil {
break
}
if session.IsNew {
break
}
if chunk, exists := session.Values["token_chunk"]; exists {
if createdAt, ok := session.Values["chunk_created_at"].(int64); ok {
chunkAge := time.Since(time.Unix(createdAt, 0))
if chunkAge > 24*time.Hour {
orphanedChunks++
sd.manager.logger.Debugf("Found orphaned access token chunk %d (age: %v)", i, chunkAge)
}
} else if chunk != nil {
orphanedChunks++
sd.manager.logger.Debugf("Found access token chunk %d without timestamp, treating as orphaned", i)
}
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired access token chunk %d: %v", i, err)
}
}
}
if orphanedChunks > 0 {
sd.manager.logger.Infof("Cleaned up %d orphaned access token chunks", orphanedChunks)
}
}
// expireRefreshTokenChunksEnhanced expires all refresh token chunks and detects orphaned chunks.
// It searches for all existing chunks, identifies orphaned or expired chunks,
// and properly expires them to prevent cookie accumulation.
// Parameters:
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
func (sd *SessionData) expireRefreshTokenChunksEnhanced(w http.ResponseWriter) {
const maxChunkSearchLimit = 50
orphanedChunks := 0
for i := 0; i < maxChunkSearchLimit; i++ {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil {
break
}
if session.IsNew {
break
}
if chunk, exists := session.Values["token_chunk"]; exists {
if createdAt, ok := session.Values["chunk_created_at"].(int64); ok {
chunkAge := time.Since(time.Unix(createdAt, 0))
if chunkAge > 24*time.Hour {
orphanedChunks++
sd.manager.logger.Debugf("Found orphaned refresh token chunk %d (age: %v)", i, chunkAge)
}
} else if chunk != nil {
orphanedChunks++
sd.manager.logger.Debugf("Found refresh token chunk %d without timestamp, treating as orphaned", i)
}
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired refresh token chunk %d: %v", i, err)
}
}
}
if orphanedChunks > 0 {
sd.manager.logger.Infof("Cleaned up %d orphaned refresh token chunks", orphanedChunks)
}
}
// expireIDTokenChunksEnhanced expires all ID token chunks and detects orphaned chunks.
// It searches for all existing chunks, identifies orphaned or expired chunks,
// and properly expires them to prevent cookie accumulation.
// Parameters:
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
func (sd *SessionData) expireIDTokenChunksEnhanced(w http.ResponseWriter) {
const maxChunkSearchLimit = 50
orphanedChunks := 0
for i := 0; i < maxChunkSearchLimit; i++ {
sessionName := fmt.Sprintf("%s_%d", idTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil {
break
}
if session.IsNew {
break
}
if chunk, exists := session.Values["token_chunk"]; exists {
if createdAt, ok := session.Values["chunk_created_at"].(int64); ok {
chunkAge := time.Since(time.Unix(createdAt, 0))
if chunkAge > 24*time.Hour {
orphanedChunks++
sd.manager.logger.Debugf("Found orphaned ID token chunk %d (age: %v)", i, chunkAge)
}
} else if chunk != nil {
orphanedChunks++
sd.manager.logger.Debugf("Found ID token chunk %d without timestamp, treating as orphaned", i)
}
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired ID token chunk %d: %v", i, err)
}
}
}
if orphanedChunks > 0 {
sd.manager.logger.Infof("Cleaned up %d orphaned ID token chunks", orphanedChunks)
}
}
// splitIntoChunks divides a string into chunks of specified maximum size.
// It ensures chunks don't exceed browser cookie limits and handles
// the string splitting logic for large token storage.
// Parameters:
// - s: The string to split into chunks.
// - chunkSize: The maximum size for each chunk.
//
// Returns:
// - A slice of strings representing the chunks.
func splitIntoChunks(s string, chunkSize int) []string {
effectiveChunkSize := min(chunkSize, maxCookieSize)
var chunks []string
for len(s) > 0 {
if len(s) > effectiveChunkSize {
chunks = append(chunks, s[:effectiveChunkSize])
s = s[effectiveChunkSize:]
} else {
chunks = append(chunks, s)
break
}
}
return chunks
}
// validateChunkSize checks if a chunk will fit within browser cookie limits.
// It estimates the encoded size including cookie overhead and headers
// to ensure the chunk won't exceed browser-imposed cookie size limits.
// Parameters:
// - chunkData: The chunk data to validate.
//
// Returns:
// - true if the chunk is safe to store, false if it may exceed browser limits.
func validateChunkSize(chunkData string) bool {
estimatedEncodedSize := len(chunkData) + (len(chunkData) * 50 / 100)
return estimatedEncodedSize <= maxBrowserCookieSize
}
// isCorruptionMarker detects if data contains known corruption indicators.
// It checks for specific corruption markers and invalid characters
// that indicate the data has been tampered with or corrupted.
// Parameters:
// - data: The data string to check for corruption markers.
//
// Returns:
// - true if the data contains corruption markers, false otherwise.
func isCorruptionMarker(data string) bool {
if data == "" {
return false
}
corruptionMarkers := []string{
"__CORRUPTION_MARKER_TEST__",
"__INVALID_BASE64_DATA__",
"__CORRUPTED_CHUNK_DATA__",
"!@#$%^&*()",
"<<<CORRUPTED>>>",
}
for _, marker := range corruptionMarkers {
if data == marker {
return true
}
}
if len(data) > 10 {
invalidChars := "!@#$%^&*(){}[]|\\:;\"'<>?,`~"
for _, char := range invalidChars {
if strings.ContainsRune(data, char) {
return true
}
}
}
return false
}
// GetCSRF retrieves the CSRF token for state validation.
// This token is used to prevent cross-site request forgery attacks
// during the OIDC authentication flow.
// Returns:
// - The CSRF token string, or an empty string if not set.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores the CSRF token for state validation.
// The token is used to validate the state parameter in OAuth callbacks.
// Parameters:
// - token: The CSRF token to store.
func (sd *SessionData) SetCSRF(token string) {
currentVal, _ := sd.mainSession.Values["csrf"].(string)
if currentVal != token {
sd.mainSession.Values["csrf"] = token
sd.dirty = true
}
}
// GetNonce retrieves the nonce for ID token validation.
// The nonce prevents replay attacks by ensuring ID tokens
// were issued in response to the specific authentication request.
// Returns:
// - The nonce string, or an empty string if not set.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores the nonce for ID token validation.
// The nonce will be validated against the nonce claim in received ID tokens.
// Parameters:
// - nonce: The nonce string to store.
func (sd *SessionData) SetNonce(nonce string) {
currentVal, _ := sd.mainSession.Values["nonce"].(string)
if currentVal != nonce {
sd.mainSession.Values["nonce"] = nonce
sd.dirty = true
}
}
// GetCodeVerifier retrieves the PKCE code verifier.
// This is used in the PKCE (Proof Key for Code Exchange) flow
// to enhance security for public clients.
// Returns:
// - The code verifier string, or an empty string if not set or PKCE is disabled.
func (sd *SessionData) GetCodeVerifier() string {
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
return codeVerifier
}
// SetCodeVerifier stores the PKCE code verifier.
// The code verifier is used to generate the code challenge sent to the
// authorization server and validated during token exchange.
// Parameters:
// - codeVerifier: The PKCE code verifier string to store.
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
currentVal, _ := sd.mainSession.Values["code_verifier"].(string)
if currentVal != codeVerifier {
sd.mainSession.Values["code_verifier"] = codeVerifier
sd.dirty = true
}
}
// GetEmail retrieves the authenticated user's email address.
// The email is extracted from ID token claims and used for
// authorization decisions and header injection.
// Returns:
// - The user's email address string, or an empty string if not set.
func (sd *SessionData) GetEmail() string {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the authenticated user's email address.
// The email is typically extracted from the 'email' claim in the ID token.
// Parameters:
// - email: The user's email address to store.
func (sd *SessionData) SetEmail(email string) {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
currentVal, _ := sd.mainSession.Values["email"].(string)
if currentVal != email {
sd.mainSession.Values["email"] = email
sd.dirty = true
}
}
// GetIncomingPath retrieves the original request URI that triggered authentication.
// This path is used to redirect the user back to their intended destination
// after successful authentication.
// Returns:
// - The original request URI string, or an empty string if not set.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request URI for post-authentication redirect.
// This allows the user to be redirected to their originally requested resource
// after completing the authentication flow.
// Parameters:
// - path: The original request URI string (e.g., "/protected/resource?id=123").
func (sd *SessionData) SetIncomingPath(path string) {
currentVal, _ := sd.mainSession.Values["incoming_path"].(string)
if currentVal != path {
sd.mainSession.Values["incoming_path"] = path
sd.dirty = true
}
}
// GetIDToken retrieves the user's ID token from session storage.
// The ID token contains user claims and is used for user identification
// and authorization decisions. Handles compression and chunking automatically.
// Returns:
// - The complete, decompressed ID token string, or an empty string if not found.
func (sd *SessionData) GetIDToken() string {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
return sd.getIDTokenUnsafe()
}
// getIDTokenUnsafe retrieves the ID token without acquiring locks.
// Enhanced ID token retrieval with comprehensive integrity checks and chunking support.
// Used when the session mutex is already held to prevent deadlocks.
// Returns:
// - The complete ID token string or empty string on error.
func (sd *SessionData) getIDTokenUnsafe() string {
token, _ := sd.idTokenSession.Values["token"].(string)
compressed, _ := sd.idTokenSession.Values["compressed"].(bool)
// Debug: Check if manager/chunkManager is nil
if sd.manager == nil || sd.manager.chunkManager == nil {
// Direct return if no chunk manager (test scenario)
return token
}
result := sd.manager.chunkManager.GetToken(
token,
compressed,
sd.idTokenChunks,
IDTokenConfig,
)
if result.Error != nil {
return ""
}
return result.Token
}
// SetIDToken stores an ID token with automatic compression and chunking.
// It validates the JWT format, compresses if beneficial, and splits into chunks
// if the token exceeds cookie size limits. Includes comprehensive validation.
// Parameters:
// - token: The ID token string to store.
func (sd *SessionData) SetIDToken(token string) {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
if token != "" {
dotCount := strings.Count(token, ".")
if dotCount != 2 {
sd.manager.logger.Errorf("CRITICAL: Attempt to store invalid JWT ID token format (dots: %d) - rejecting", dotCount)
return
}
}
if len(token) > 50*1024 {
sd.manager.logger.Errorf("CRITICAL: ID token too large (%d bytes) - possible corruption, rejecting", len(token))
return
}
currentIDToken := sd.getIDTokenUnsafe()
if currentIDToken == token {
return
}
sd.dirty = true
if sd.request != nil {
sd.expireIDTokenChunksEnhanced(nil)
}
for k := range sd.idTokenChunks {
delete(sd.idTokenChunks, k)
}
if token == "" {
if sd.idTokenSession != nil {
sd.idTokenSession.Values["token"] = ""
sd.idTokenSession.Values["compressed"] = false
}
return
}
compressed := compressToken(token)
if compressed != token {
testDecompressed := decompressToken(compressed)
if testDecompressed != token {
sd.manager.logger.Errorf("CRITICAL: ID token compression verification failed - storing uncompressed")
compressed = token
}
}
if len(compressed) <= maxCookieSize {
if sd.idTokenSession != nil {
sd.idTokenSession.Values["token"] = compressed
sd.idTokenSession.Values["compressed"] = (compressed != token)
}
} else {
if sd.idTokenSession != nil {
sd.idTokenSession.Values["token"] = ""
sd.idTokenSession.Values["compressed"] = (compressed != token)
}
chunks := splitIntoChunks(compressed, maxCookieSize)
if len(chunks) == 0 {
sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for ID token")
return
}
if len(chunks) > 50 {
sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for ID token - possible corruption", len(chunks))
return
}
testReassembled := strings.Join(chunks, "")
if testReassembled != compressed {
sd.manager.logger.Errorf("CRITICAL: ID token chunk reassembly test failed")
return
}
for i, chunkData := range chunks {
sessionName := fmt.Sprintf("%s_%d", idTokenCookie, i)
if sd.request == nil {
sd.manager.logger.Errorf("CRITICAL: SetIDToken: sd.request is nil, cannot create chunk session %s", sessionName)
return
}
if chunkData == "" {
sd.manager.logger.Debug("Empty chunk data at index %d", i)
return
}
if len(chunkData) > maxCookieSize {
sd.manager.logger.Info("Chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
return
}
if !validateChunkSize(chunkData) {
sd.manager.logger.Errorf("CRITICAL: ID token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
return
}
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil {
sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err)
return
}
session.Values["token_chunk"] = chunkData
session.Values["compressed"] = (compressed != token)
session.Values["chunk_created_at"] = time.Now().Unix()
sd.idTokenChunks[i] = session
}
sd.manager.logger.Debugf("SUCCESS: Stored ID token in %d chunks", len(chunks))
}
}
// GetRedirectCount returns the number of redirects in the current authentication flow.
// STABILITY FIX: Prevents infinite redirect loops by tracking redirect attempts.
// Returns:
// - The current redirect count, 0 if not set.
func (sd *SessionData) GetRedirectCount() int {
if count, ok := sd.mainSession.Values["redirect_count"].(int); ok {
return count
}
return 0
}
// IncrementRedirectCount increases the redirect counter by one.
// STABILITY FIX: Prevents infinite redirect loops by tracking successive redirects.
// Used to detect potential redirect loops and abort authentication if too many occur.
func (sd *SessionData) IncrementRedirectCount() {
currentCount := sd.GetRedirectCount()
sd.mainSession.Values["redirect_count"] = currentCount + 1
sd.dirty = true
}
// ResetRedirectCount resets the redirect counter to zero.
// STABILITY FIX: Prevents infinite redirect loops by clearing the counter
// when authentication completes successfully or when starting a new flow.
func (sd *SessionData) ResetRedirectCount() {
sd.mainSession.Values["redirect_count"] = 0
sd.dirty = true
}