Files
traefikoidc/session.go
2026-05-27 21:43:20 +01:00

2825 lines
89 KiB
Go

package traefikoidc
import (
"bytes"
"compress/gzip"
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/sessions"
"github.com/lukaszraczylo/traefikoidc/internal/pool"
)
// constantTimeStringCompare reports whether a and b contain the same bytes.
// Inputs of different lengths are rejected up front; equal-length inputs are
// compared with subtle.ConstantTimeCompare so the comparison time does not
// depend on the position of the first differing byte.
func constantTimeStringCompare(a, b string) bool {
	if len(a) == len(b) {
		return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
	}
	return false
}
// min returns the smaller of a and b.
// It is a small helper used by the session management code.
// Parameters:
//   - a: The first integer to compare.
//   - b: The second integer to compare.
//
// Returns:
//   - The smaller of the two integers (either one when they are equal).
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// generateSecureRandomString creates a cryptographically secure random string.
// It reads random bytes from crypto/rand and encodes them as hexadecimal.
// Parameters:
//   - length: The number of random bytes to generate (the output is 2x this
//     length in hex characters). Must not be negative.
//
// Returns:
//   - The hex-encoded random string.
//   - An error if length is negative or reading random bytes fails.
func generateSecureRandomString(length int) (string, error) {
	// make() panics on a negative size; report it as an error instead.
	if length < 0 {
		return "", fmt.Errorf("failed to generate random bytes: invalid length %d", length)
	}
	// Named buf rather than bytes so the imported bytes package is not shadowed.
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	return hex.EncodeToString(buf), nil
}
// Cookie name suffixes used for session management.
// Each suffix is appended to the configured cookie prefix to form a full
// cookie name (see the *CookieName helper methods on SessionManager).
// #nosec G101 -- These are cookie name suffixes, not hardcoded credentials
const (
// mainCookieSuffix is the suffix of the main session cookie.
mainCookieSuffix = "m"
// accessTokenSuffix is the suffix of the access token cookie.
accessTokenSuffix = "a"
// refreshTokenSuffix is the suffix of the refresh token cookie.
refreshTokenSuffix = "r"
// idTokenSuffix is the suffix of the ID token cookie.
idTokenSuffix = "id"
combinedCookieSuffix = "s" // Combined session cookie suffix
// defaultCookiePrefix is used when NewSessionManager receives an empty prefix.
defaultCookiePrefix = "_oidc_raczylo_"
)
const (
// maxCookieSize is the maximum raw data size per chunk before securecookie encoding.
//
// Browser cookie limit: 4096 bytes
// Securecookie overhead: ~2x (gob + AES encryption + MAC + double base64)
// Safety headroom: 1000 bytes for cookie metadata, headers, edge cases
//
// Math: 1400 raw → ~2800 encoded → safely under (4096 - 1000 = 3096) limit
//
// NOTE(review): the ~2x overhead figure is the original author's estimate;
// it is not derived from anything in this file — confirm before changing.
maxCookieSize = 1400
// maxCombinedChunks is the maximum number of chunks allowed for combined session
maxCombinedChunks = 10
// absoluteSessionTimeout is the default session lifetime applied by
// NewSessionManager when the caller passes a zero sessionMaxAge.
absoluteSessionTimeout = 24 * time.Hour
// minEncryptionKeyLength is the minimum accepted length, in bytes, of the
// key passed to NewSessionManager.
minEncryptionKeyLength = 32
)
// combinedSessionPayload is the JSON structure for combined cookie storage.
// Uses short field names to minimize size; every field is omitted from the
// JSON output when it holds its zero value (omitempty).
//
// The per-field meanings below are inferred from the JSON tags and from the
// keys listed in knownSessionKeys — TODO confirm against the code that
// populates and reads the payload.
type combinedSessionPayload struct {
// X presumably holds session values that have no dedicated field.
X map[string]interface{} `json:"x,omitempty"`
// A presumably holds the access token.
A string `json:"a,omitempty"`
// R presumably holds the refresh token.
R string `json:"r,omitempty"`
// I presumably holds the ID token.
I string `json:"i,omitempty"`
// Ui presumably holds the user identifier.
Ui string `json:"ui,omitempty"`
// Cs presumably holds the CSRF value.
Cs string `json:"cs,omitempty"`
// N presumably holds the nonce.
N string `json:"n,omitempty"`
// Cv presumably holds the PKCE code verifier.
Cv string `json:"cv,omitempty"`
// Ip presumably holds the incoming request path.
Ip string `json:"ip,omitempty"`
// Ca presumably holds the session creation time (Unix seconds).
Ca int64 `json:"ca,omitempty"`
// Rc presumably holds the redirect counter.
Rc int `json:"rc,omitempty"`
// Au presumably holds the authenticated flag.
Au bool `json:"au,omitempty"`
}
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
// All other mainSession.Values keys are stored in the X (extra) field.
// The map is used as a set: every listed key maps to true, so a lookup of an
// unlisted key yields false.
var knownSessionKeys = map[string]bool{
"access_token": true,
"refresh_token": true,
"id_token": true,
"user_identifier": true,
"authenticated": true,
"csrf": true,
"nonce": true,
"code_verifier": true,
"incoming_path": true,
"created_at": true,
"redirect_count": true,
}
// compressCombinedPayload serializes the combined session payload to JSON,
// gzips the JSON, and returns the result encoded with standard base64.
// Returns the encoded string, or an error if marshaling or compression fails.
func compressCombinedPayload(payload *combinedSessionPayload) (string, error) {
	serialized, marshalErr := json.Marshal(payload)
	if marshalErr != nil {
		return "", fmt.Errorf("failed to marshal combined payload: %w", marshalErr)
	}

	gzipped := new(bytes.Buffer)
	zw := gzip.NewWriter(gzipped)
	_, writeErr := zw.Write(serialized)
	if writeErr != nil {
		return "", fmt.Errorf("failed to compress combined payload: %w", writeErr)
	}
	// Close flushes the remaining compressed data and the gzip trailer.
	closeErr := zw.Close()
	if closeErr != nil {
		return "", fmt.Errorf("failed to close gzip writer: %w", closeErr)
	}

	return base64.StdEncoding.EncodeToString(gzipped.Bytes()), nil
}
// decompressCombinedPayload reverses compressCombinedPayload: it base64-decodes
// the input, gunzips it, and unmarshals the JSON into a combinedSessionPayload.
// At most 512KB of decompressed data is read. Returns the payload, or an error
// if the input is empty or any decoding step fails.
func decompressCombinedPayload(compressed string) (*combinedSessionPayload, error) {
	if len(compressed) == 0 {
		return nil, fmt.Errorf("empty compressed data")
	}

	raw, decodeErr := base64.StdEncoding.DecodeString(compressed)
	if decodeErr != nil {
		return nil, fmt.Errorf("failed to decode base64: %w", decodeErr)
	}

	zr, openErr := gzip.NewReader(bytes.NewReader(raw))
	if openErr != nil {
		return nil, fmt.Errorf("failed to create gzip reader: %w", openErr)
	}
	defer func() { _ = zr.Close() }()

	// Cap the amount of decompressed data read to guard against zip bombs.
	const maxDecompressedSize = 512 * 1024 // 512KB max
	inflated, readErr := io.ReadAll(io.LimitReader(zr, maxDecompressedSize))
	if readErr != nil {
		return nil, fmt.Errorf("failed to decompress: %w", readErr)
	}

	payload := new(combinedSessionPayload)
	if unmarshalErr := json.Unmarshal(inflated, payload); unmarshalErr != nil {
		return nil, fmt.Errorf("failed to unmarshal combined payload: %w", unmarshalErr)
	}
	return payload, nil
}
// splitCombinedIntoChunks splits data into consecutive pieces of at most
// chunkSize bytes each.
// Parameters:
//   - data: The string to split.
//   - chunkSize: The maximum number of bytes per chunk.
//
// Returns:
//   - The chunks in order. Data that already fits in one chunk (including the
//     empty string) is returned as a single-element slice. A non-positive
//     chunkSize cannot be used for splitting, so data is likewise returned as
//     a single chunk instead of looping forever (0) or panicking (negative).
func splitCombinedIntoChunks(data string, chunkSize int) []string {
	if chunkSize <= 0 || len(data) <= chunkSize {
		return []string{data}
	}
	// The chunk count is known up front: ceil(len(data) / chunkSize).
	chunks := make([]string, 0, (len(data)+chunkSize-1)/chunkSize)
	for start := 0; start < len(data); start += chunkSize {
		end := start + chunkSize
		if end > len(data) {
			end = len(data)
		}
		chunks = append(chunks, data[start:end])
	}
	return chunks
}
// assembleCombinedChunks concatenates the chunk data stored under the "d" key
// of each session, in slice order, and returns the joined string.
// Assembly stops at the first nil session or at the first session whose "d"
// value is missing, not a string, or empty; whatever was collected up to that
// point is returned (an empty string if nothing was collected).
func assembleCombinedChunks(sessions []*sessions.Session) string {
	var assembled strings.Builder
	for _, chunkSession := range sessions {
		if chunkSession == nil {
			break
		}
		data, isString := chunkSession.Values["d"].(string) // "d" for data
		if !isString || data == "" {
			break
		}
		assembled.WriteString(data)
	}
	return assembled.String()
}
// compressToken gzips and base64-encodes a JWT-shaped token when doing so
// makes it shorter.
// The token is returned unchanged when it is empty, does not contain exactly
// two dots, is larger than 50KB, fails to compress, does not shrink, or does
// not survive a decompression round trip.
// Parameters:
//   - token: The JWT token string to potentially compress.
//
// Returns:
//   - The base64 encoded, gzipped string, or the original string if
//     compression is skipped or fails.
func compressToken(token string) string {
	switch {
	case token == "":
		return token
	case strings.Count(token, ".") != 2:
		return token
	case len(token) > 50*1024:
		return token
	}

	bufPool := pool.Get()
	out := bufPool.GetBuffer(4096)
	defer bufPool.PutBuffer(out)

	zw := gzip.NewWriter(out)
	if n, err := zw.Write([]byte(token)); err != nil || n != len(token) {
		return token
	}
	if err := zw.Close(); err != nil {
		return token
	}

	raw := out.Bytes()
	if len(raw) == 0 {
		return token
	}

	encoded := base64.StdEncoding.EncodeToString(raw)
	if len(encoded) >= len(token) {
		// Compression did not help; keep the original.
		return token
	}

	// Verify the compressed form decodes back to exactly the input.
	roundTrip := decompressTokenInternal(encoded)
	if roundTrip != token || strings.Count(roundTrip, ".") != 2 {
		return token
	}
	return encoded
}
// decompressToken decompresses a previously compressed token string.
// It is a thin wrapper that delegates all work to decompressTokenInternal.
// Parameters:
// - compressed: The base64-encoded compressed token string.
//
// Returns:
// - Whatever decompressTokenInternal returns for the same input: the
// decompressed string, or the input string if decompression fails.
func decompressToken(compressed string) string {
return decompressTokenInternal(compressed)
}
// decompressTokenInternal is the internal decompression function.
// It is separate from decompressToken so compressToken can use it to verify
// its own output.
//
// The input is returned unchanged when it is empty, longer than 100KB, not
// valid base64, does not start with the gzip magic bytes, fails to decompress,
// or decompresses to something that is empty or does not contain exactly two
// dots (i.e. is not JWT-shaped). At most 500KB of decompressed data is read.
// Parameters:
//   - compressed: The compressed token string to decompress.
//
// Returns:
//   - The decompressed token or the original string if decompression fails.
func decompressTokenInternal(compressed string) string {
	if compressed == "" || len(compressed) > 100*1024 {
		return compressed
	}
	data, err := base64.StdEncoding.DecodeString(compressed)
	if err != nil {
		return compressed
	}
	// 0x1f 0x8b is the gzip magic number; anything else was never compressed.
	if len(data) < 2 || data[0] != 0x1f || data[1] != 0x8b {
		return compressed
	}

	pm := pool.Get()
	readerBuf := pm.GetHTTPResponseBuffer()
	defer pm.PutHTTPResponseBuffer(readerBuf)

	gz, err := gzip.NewReader(bytes.NewReader(data))
	if err != nil {
		return compressed
	}
	defer func() { _ = gz.Close() }()

	limitedReader := io.LimitReader(gz, 500*1024)

	var decompressedStr string
	if cap(readerBuf) >= 512*1024 {
		// The pooled buffer is large enough to hold the maximum output.
		// A single Read may return only part of the stream, so keep reading
		// until EOF; otherwise large tokens would be silently truncated.
		scratch := readerBuf[:cap(readerBuf)]
		n := 0
		for n < len(scratch) {
			m, readErr := limitedReader.Read(scratch[n:])
			n += m
			if readErr == io.EOF {
				break
			}
			if readErr != nil {
				return compressed
			}
		}
		decompressedStr = string(scratch[:n])
	} else {
		decompressed, readErr := io.ReadAll(limitedReader)
		if readErr != nil {
			return compressed
		}
		decompressedStr = string(decompressed)
	}

	// Apply the same sanity checks on both read paths.
	if decompressedStr == "" || strings.Count(decompressedStr, ".") != 2 {
		return compressed
	}
	return decompressedStr
}
// SessionManager manages OIDC session state and cookie-based storage.
// It provides secure session management with support for token compression,
// chunked storage for large tokens, session pooling for performance,
// session object reuse and supports both HTTP and HTTPS schemes.
type SessionManager struct {
// sessionPool recycles *SessionData objects; its New func is set in NewSessionManager.
sessionPool sync.Pool
// ctx is canceled by Shutdown and stops the background cleanup goroutine.
ctx context.Context
// store is the cookie-backed session store created in NewSessionManager.
store sessions.Store
// logger receives debug output; several methods nil-check it, others do not.
logger *Logger
// chunkManager is created by NewChunkManager and shut down in Shutdown.
chunkManager *ChunkManager
// memoryMonitor is the global task memory monitor obtained in NewSessionManager.
memoryMonitor *TaskMemoryMonitor
// cancel cancels ctx.
cancel context.CancelFunc
// cookieDomain is the configured cookie domain; empty enables auto-detection
// in EnhanceSessionSecurity.
cookieDomain string
// cookiePrefix is prepended to every cookie name suffix.
cookiePrefix string
// cookiePath is the cookie Path used by EnhanceSessionSecurity ("/" when empty).
// NOTE(review): NewSessionManager does not set it; presumably assigned elsewhere.
cookiePath string
// sessionMaxAge is the maximum session lifetime.
sessionMaxAge time.Duration
// activeSessions, poolHits and poolMisses are counters updated with sync/atomic.
activeSessions int64
poolHits int64
poolMisses int64
// cleanupMutex guards cleanupDone.
cleanupMutex sync.RWMutex
// shutdownOnce makes Shutdown run its body at most once.
shutdownOnce sync.Once
// forceHTTPS marks cookies Secure regardless of the request scheme.
forceHTTPS bool
// cleanupDone records that CleanupOldCookies has already processed a request
// carrying this manager's cookies.
cleanupDone bool
}
// NewSessionManager builds a SessionManager, applies defaults for the cookie
// prefix and session lifetime, wires up the SessionData pool, starts memory
// monitoring, and launches the background cleanup goroutine.
// Parameters:
//   - encryptionKey: The key handed to the cookie store (minimum 32 bytes).
//   - forceHTTPS: Whether to force HTTPS-only cookies regardless of request scheme.
//   - cookieDomain: The domain for session cookies (empty for auto-detection).
//   - cookiePrefix: Prefix for session cookie names (empty for default "_oidc_raczylo_").
//   - sessionMaxAge: Maximum session age duration (0 for default 24 hours).
//   - logger: Logger instance for debug and error logging.
//
// Returns:
//   - The configured SessionManager instance.
//   - An error if the encryption key does not meet minimum length requirements.
//
// NOTE(review): only a single key is passed to sessions.NewCookieStore. The
// gorilla/sessions documentation describes keys as (authentication, encryption)
// pairs with the encryption key optional, so this appears to configure cookies
// that are signed but not encrypted — confirm whether that is intended.
func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain string, cookiePrefix string, sessionMaxAge time.Duration, logger *Logger) (*SessionManager, error) {
	if len(encryptionKey) < minEncryptionKeyLength {
		return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
	}

	// Fall back to defaults for unset optional settings.
	prefix := cookiePrefix
	if prefix == "" {
		prefix = defaultCookiePrefix
	}
	maxAge := sessionMaxAge
	if maxAge == 0 {
		// 24 hours, kept for backward compatibility.
		maxAge = absoluteSessionTimeout
	}

	ctx, cancel := context.WithCancel(context.Background())

	sm := &SessionManager{
		store:         sessions.NewCookieStore([]byte(encryptionKey)),
		forceHTTPS:    forceHTTPS,
		cookieDomain:  cookieDomain,
		cookiePrefix:  prefix,
		sessionMaxAge: maxAge,
		logger:        logger,
		chunkManager:  NewChunkManager(logger),
		ctx:           ctx,
		cancel:        cancel,
	}

	// The memory monitor is a process-wide singleton; Start is reported as
	// skipping work when monitoring is already running.
	sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger)
	if startErr := sm.memoryMonitor.Start(30 * time.Second); startErr != nil {
		logger.Debugf("Failed to start memory monitoring: %v", startErr)
	}

	// Every object the pool has to construct counts as a pool miss.
	newSessionData := func() interface{} {
		atomic.AddInt64(&sm.poolMisses, 1)
		fresh := &SessionData{
			manager:            sm,
			accessTokenChunks:  make(map[int]*sessions.Session),
			refreshTokenChunks: make(map[int]*sessions.Session),
			idTokenChunks:      make(map[int]*sessions.Session),
			combinedChunks:     make(map[int]*sessions.Session),
			useCombinedStorage: true, // Use combined storage by default for new sessions
			refreshMutex:       sync.Mutex{},
			sessionMutex:       sync.RWMutex{},
			dirty:              false,
			inUse:              false,
		}
		fresh.Reset()
		return fresh
	}
	sm.sessionPool.New = newSessionData

	// Runs until ctx is canceled by Shutdown.
	go sm.backgroundCleanup()

	return sm, nil
}
// Cookie name helper methods - build cookie names using the configured prefix

// mainCookieName returns the main session cookie name: the configured
// cookie prefix followed by mainCookieSuffix.
func (sm *SessionManager) mainCookieName() string {
return sm.cookiePrefix + mainCookieSuffix
}
// accessTokenCookieName returns the access token cookie name: the configured
// cookie prefix followed by accessTokenSuffix.
func (sm *SessionManager) accessTokenCookieName() string {
return sm.cookiePrefix + accessTokenSuffix
}
// refreshTokenCookieName returns the refresh token cookie name: the configured
// cookie prefix followed by refreshTokenSuffix.
func (sm *SessionManager) refreshTokenCookieName() string {
return sm.cookiePrefix + refreshTokenSuffix
}
// idTokenCookieName returns the ID token cookie name: the configured
// cookie prefix followed by idTokenSuffix.
func (sm *SessionManager) idTokenCookieName() string {
return sm.cookiePrefix + idTokenSuffix
}
// combinedCookieName returns the combined session cookie base name: the
// configured cookie prefix followed by combinedCookieSuffix.
// Chunk cookies are named: prefix + "s" + "_" + chunkIndex (e.g., "_oidc_raczylo_s_0")
func (sm *SessionManager) combinedCookieName() string {
return sm.cookiePrefix + combinedCookieSuffix
}
// combinedChunkCookieName returns the name for a specific combined session
// chunk: the combined cookie base name, an underscore, and the decimal
// chunk index.
func (sm *SessionManager) combinedChunkCookieName(chunkIndex int) string {
return fmt.Sprintf("%s_%d", sm.combinedCookieName(), chunkIndex)
}
// GetCookiePrefix returns the cookie prefix used for all OIDC session cookies.
// This is the value configured at construction time (or the default prefix
// when an empty one was supplied).
func (sm *SessionManager) GetCookiePrefix() string {
return sm.cookiePrefix
}
// Shutdown stops the SessionManager's background work. The teardown runs at
// most once, however many times Shutdown is called: it cancels the manager's
// context, stops the memory monitor and the chunk manager, and triggers a
// garbage collection. The returned error is always nil.
func (sm *SessionManager) Shutdown() error {
	debug := func(msg string) {
		if sm.logger != nil {
			sm.logger.Debug(msg)
		}
	}

	sm.shutdownOnce.Do(func() {
		debug("SessionManager shutdown initiated")

		// Canceling the context ends the background cleanup goroutine.
		if sm.cancel != nil {
			sm.cancel()
		}
		if sm.memoryMonitor != nil {
			sm.memoryMonitor.Stop()
		}
		if sm.chunkManager != nil {
			sm.chunkManager.Shutdown()
		}

		// Force garbage collection to help cleanup
		runtime.GC()

		debug("SessionManager shutdown completed")
	})

	return nil
}
// backgroundCleanup runs a cleanup cycle every five minutes until the
// manager's context is canceled. It blocks, so it is meant to run in its own
// goroutine.
func (sm *SessionManager) backgroundCleanup() {
	const cleanupInterval = 5 * time.Minute

	ticker := time.NewTicker(cleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			sm.performCleanupCycle()
		case <-sm.ctx.Done():
			if sm.logger != nil {
				sm.logger.Debug("Background cleanup routine terminated")
			}
			return
		}
	}
}
// performCleanupCycle runs one full maintenance pass: chunk cleanup, session
// pool cleanup, and a forced garbage collection when the heap exceeds 50MB.
// The completion message is only logged while the context is live and the
// process is not in test mode.
func (sm *SessionManager) performCleanupCycle() {
	const heapGCThreshold = 50 * 1024 * 1024 // 50MB threshold

	if sm.logger != nil {
		sm.logger.Debug("Starting background cleanup cycle")
	}
	began := time.Now()

	sm.PeriodicChunkCleanup()
	sm.cleanupSessionPool()

	var mem runtime.MemStats
	runtime.ReadMemStats(&mem)
	if mem.HeapAlloc > heapGCThreshold {
		runtime.GC()
		if sm.logger != nil {
			sm.logger.Debug("Forced garbage collection due to high memory usage")
		}
	}

	elapsed := time.Since(began)
	if sm.logger == nil || sm.ctx == nil || sm.ctx.Err() != nil || isTestMode() {
		return
	}
	sm.logger.Debugf("Cleanup cycle completed in %v", elapsed)
}
// cleanupSessionPool takes up to 20 objects from the session pool, resets the
// ones that are not marked in use, and puts each one straight back. It stops
// early when the manager's context is canceled.
//
// NOTE(review): sync.Pool.Get calls the pool's New func when the pool is
// empty, and New is set in NewSessionManager, so the nil ("pool is empty")
// branch is presumably never taken and each Get on an empty pool allocates a
// new SessionData and counts a pool miss — confirm this is acceptable.
func (sm *SessionManager) cleanupSessionPool() {
	const maxCleanup = 20 // Limit cleanup per cycle to avoid performance impact

	resetCount := 0
	for attempt := 0; attempt < maxCleanup; attempt++ {
		select {
		case <-sm.ctx.Done():
			return
		default:
		}

		pooled := sm.sessionPool.Get()
		if pooled == nil {
			break // Pool is empty
		}
		if sd, isSessionData := pooled.(*SessionData); isSessionData && sd != nil && !sd.inUse {
			sd.Reset()
			resetCount++
		}
		sm.sessionPool.Put(pooled)
	}

	if resetCount > 0 && sm.logger != nil && sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
		sm.logger.Debugf("Cleaned %d session pool objects", resetCount)
	}
}
// GetSessionStats returns a snapshot of session counters: active sessions,
// pool hits and pool misses. When a memory monitor is attached and reports
// current stats without error, goroutine count, heap allocation and GC count
// are included as well.
func (sm *SessionManager) GetSessionStats() map[string]interface{} {
	stats := map[string]interface{}{
		"active_sessions": atomic.LoadInt64(&sm.activeSessions),
		"pool_hits":       atomic.LoadInt64(&sm.poolHits),
		"pool_misses":     atomic.LoadInt64(&sm.poolMisses),
	}

	if sm.memoryMonitor == nil {
		return stats
	}
	currentStats, err := sm.memoryMonitor.GetCurrentStats()
	if err != nil {
		return stats
	}

	stats["goroutines"] = currentStats.Goroutines
	stats["heap_alloc"] = currentStats.HeapAlloc
	stats["num_gc"] = currentStats.NumGC
	return stats
}
// PeriodicChunkCleanup runs one session maintenance pass. It asks the chunk
// manager to drop expired sessions and resets up to 10 idle objects from the
// session pool, logging progress along the way.
// The whole pass is skipped when the manager or its logger is nil, when the
// context is missing or canceled, or when running in test mode.
// The orphaned_chunks, expired_sessions and errors figures in the final log
// line are never updated by this function and are always reported as 0.
func (sm *SessionManager) PeriodicChunkCleanup() {
	if sm == nil || sm.logger == nil {
		return
	}

	// Logging is only allowed while the context is live and outside test
	// mode, to avoid output after a test has completed.
	canLog := func() bool {
		return sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode()
	}
	if !canLog() {
		return
	}

	sm.logger.Debug("Starting comprehensive session cleanup cycle")
	began := time.Now()

	if sm.store != nil {
		if _, isCookieStore := sm.store.(*sessions.CookieStore); isCookieStore && canLog() {
			sm.logger.Debug("Running session store cleanup")
		}
	}

	// Cleanup expired sessions in chunk manager to prevent memory leaks
	if sm.chunkManager != nil {
		sm.chunkManager.CleanupExpiredSessions()
	}

	const poolPasses = 10
	poolCleaned := 0
	for pass := 0; pass < poolPasses; pass++ {
		pooled := sm.sessionPool.Get()
		if pooled == nil {
			continue
		}
		if sd, isSessionData := pooled.(*SessionData); isSessionData && sd != nil && !sd.inUse {
			sd.Reset()
			poolCleaned++
		}
		sm.sessionPool.Put(pooled)
	}

	if canLog() {
		// Placeholder counters: nothing in this function updates them.
		var orphanedChunks, expiredSessions, cleanupErrors int
		sm.logger.Debugf("Session cleanup completed in %v: pool_cleaned=%d, orphaned_chunks=%d, expired_sessions=%d, errors=%d",
			time.Since(began), poolCleaned, orphanedChunks, expiredSessions, cleanupErrors)
	}
}
// ValidateSessionHealth checks that a session is usable. The session must be
// non-nil and authenticated, each non-empty token (access, refresh, ID) must
// pass validateTokenFormat, and detectSessionTampering must find nothing
// suspicious.
// Parameters:
//   - sessionData: The session data to validate.
//
// Returns:
//   - An error describing the first failed check, nil if session is healthy.
func (sm *SessionManager) ValidateSessionHealth(sessionData *SessionData) error {
	if sessionData == nil {
		return fmt.Errorf("session data is nil")
	}
	if !sessionData.GetAuthenticated() {
		return fmt.Errorf("session is not authenticated or has expired")
	}

	// All three tokens are fetched first, then validated in this order.
	tokenChecks := []struct {
		token string
		kind  string // passed to validateTokenFormat
		label string // used in the wrapping error message
	}{
		{sessionData.GetAccessToken(), "access_token", "access token"},
		{sessionData.GetRefreshToken(), "refresh_token", "refresh token"},
		{sessionData.GetIDToken(), "id_token", "ID token"},
	}
	for _, check := range tokenChecks {
		if check.token == "" {
			continue
		}
		if err := sm.validateTokenFormat(check.token, check.kind); err != nil {
			return fmt.Errorf("%s validation failed: %w", check.label, err)
		}
	}

	if err := sm.detectSessionTampering(sessionData); err != nil {
		return fmt.Errorf("session tampering detected: %w", err)
	}
	return nil
}
// validateTokenFormat performs structural checks on a token.
// An empty token is accepted. A token flagged by isCorruptionMarker is
// rejected. A token with exactly three dot-separated segments (JWT shape) is
// rejected if any segment is empty; segments that look like standard base64
// rather than base64url are only logged, not rejected.
// Parameters:
//   - token: The token string to validate.
//   - tokenType: The type of token being validated (for error messages).
//
// Returns:
//   - An error if the token carries a corruption marker or has an empty JWT
//     segment, nil otherwise.
func (sm *SessionManager) validateTokenFormat(token, tokenType string) error {
	if token == "" {
		return nil
	}
	if isCorruptionMarker(token) {
		return fmt.Errorf("%s contains corruption marker", tokenType)
	}

	segments := strings.Split(token, ".")
	if len(segments) != 3 {
		// Not JWT-shaped; nothing further to check.
		return nil
	}

	for idx, segment := range segments {
		if segment == "" {
			return fmt.Errorf("%s has empty part %d in JWT format", tokenType, idx)
		}
		hasStdChars := strings.ContainsAny(segment, "+/=")
		hasURLChars := strings.ContainsAny(segment, "-_")
		if hasStdChars && !hasURLChars {
			sm.logger.Debugf("Token %s part %d uses base64 instead of base64url encoding", tokenType, idx)
		}
	}
	return nil
}
// detectSessionTampering scans the string values of the main session for
// suspicious content: path traversal sequences ("../" or "..\"), script
// markers ("<script" or "javascript:"), and values longer than 10000 bytes.
// Non-string values are ignored.
// Parameters:
//   - sessionData: The session data to examine for tampering.
//
// Returns:
//   - An error for the first suspicious value found, or if the main session
//     is missing; nil if nothing suspicious is found.
func (sm *SessionManager) detectSessionTampering(sessionData *SessionData) error {
	if sessionData.mainSession == nil {
		return fmt.Errorf("main session is missing")
	}

	for key, raw := range sessionData.mainSession.Values {
		text, isString := raw.(string)
		if !isString {
			continue
		}
		switch {
		case strings.Contains(text, "../"), strings.Contains(text, "..\\"):
			return fmt.Errorf("potential path traversal attempt in session key %v", key)
		case strings.Contains(text, "<script"), strings.Contains(text, "javascript:"):
			return fmt.Errorf("potential XSS attempt in session key %v", key)
		case len(text) > 10000:
			return fmt.Errorf("suspiciously long session value for key %v (length: %d)", key, len(text))
		}
	}
	return nil
}
// GetSessionMetrics returns configuration-level metrics for monitoring:
// store type, HTTPS enforcement, session lifetime in hours, cookie size
// limits, codec information and the pool implementation name.
// Returns:
//   - A map containing session metrics and configuration information.
//
// NOTE(review): "has_encryption" only reports that the cookie store has at
// least one codec; it does not establish that the codec encrypts — confirm.
func (sm *SessionManager) GetSessionMetrics() map[string]interface{} {
	cookieStore, isCookieStore := sm.store.(*sessions.CookieStore)
	hasCodecs := isCookieStore && len(cookieStore.Codecs) > 0

	metrics := map[string]interface{}{
		"session_manager_type":    "CookieStore",
		"force_https":             sm.forceHTTPS,
		"absolute_timeout_hours":  sm.sessionMaxAge.Hours(),
		"max_cookie_size":         maxCookieSize,
		"max_encoded_cookie_size": maxCookieSize * 2, // ~2x encoding overhead
		"has_encryption":          hasCodecs,
		"pool_implementation":     "sync.Pool",
	}
	if hasCodecs {
		metrics["codec_count"] = len(cookieStore.Codecs)
	}
	return metrics
}
// EnhanceSessionSecurity hardens a set of session options in place and
// returns it. It always sets HttpOnly and the cookie path. When a request is
// supplied it also forces SameSite=Lax, marks the cookie Secure for HTTPS
// requests (or when forceHTTPS is set), and halves the lifetime for requests
// without a User-Agent header. The cookie domain is the configured one, or is
// derived from the request host when none is configured or already set.
// Parameters:
//   - options: The base session options to enhance (can be nil).
//   - r: The HTTP request context for security decisions (can be nil).
//
// Returns:
//   - Enhanced sessions.Options with additional security measures.
func (sm *SessionManager) EnhanceSessionSecurity(options *sessions.Options, r *http.Request) *sessions.Options {
	if options == nil {
		options = new(sessions.Options)
	}

	if r != nil {
		if r.Header.Get("User-Agent") == "" {
			sm.logger.Debugf("Request from %s missing User-Agent header", r.RemoteAddr)
			options.MaxAge = int((sm.sessionMaxAge / 2).Seconds())
		}

		overHTTPS := r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS
		if overHTTPS {
			options.Secure = true
		}

		// SameSite stays Lax for OAuth flows; it is deliberately not switched
		// per request, which previously caused redirect loops.
		options.SameSite = http.SameSiteLaxMode
	}

	options.HttpOnly = true

	// Configured cookie path, with "/" as the backward-compatible default.
	options.Path = "/"
	if sm.cookiePath != "" {
		options.Path = sm.cookiePath
	}

	if sm.cookieDomain != "" {
		options.Domain = sm.cookieDomain
		sm.logger.Debugf("Using configured cookie domain: %s", sm.cookieDomain)
		return options
	}
	if options.Domain != "" || r == nil {
		return options
	}

	// No domain configured or preset: derive it from the request host.
	host := r.Host
	if forwarded := r.Header.Get("X-Forwarded-Host"); forwarded != "" {
		host = forwarded
	}
	isLocal := strings.Contains(host, "localhost") || strings.Contains(host, "127.0.0.1")
	if host == "" || isLocal {
		return options
	}
	if portSep := strings.Index(host, ":"); portSep != -1 {
		host = host[:portSep]
	}
	options.Domain = host
	sm.logger.Debugf("Auto-detected cookie domain: %s", host)
	return options
}
// getSessionOptions creates base session options with security settings.
// Cookies are HttpOnly, SameSite=Lax, live for the configured session max age
// and use the configured cookie domain.
// The cookie path is the manager's configured cookiePath, falling back to "/"
// when none is set, which matches what EnhanceSessionSecurity applies. A
// cookie is identified by name, domain and path, so both code paths must use
// the same path to address the same cookie.
// Parameters:
//   - isSecure: Whether the request is over HTTPS or should be treated as secure.
//
// Returns:
//   - A pointer to a configured sessions.Options struct.
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
	// Default "/" for backward compatibility when no path is configured.
	cookiePath := sm.cookiePath
	if cookiePath == "" {
		cookiePath = "/"
	}
	return &sessions.Options{
		HttpOnly: true,
		Secure:   isSecure || sm.forceHTTPS,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(sm.sessionMaxAge.Seconds()),
		Path:     cookiePath,
		Domain:   sm.cookieDomain,
	}
}
// CleanupOldCookies expires this middleware's cookies under domains other
// than the configured one, so stale cookies left over from a previous domain
// configuration are removed from the browser.
// Candidate domains are derived from the request host: the host itself, its
// dot-prefixed form and, for hosts with more than two labels, the last two
// labels with and without a leading dot. Deletion cookies are only emitted
// when a cookie domain is configured and cleanup has not yet been marked
// done. Once a request carrying any of this middleware's cookies has been
// processed, cleanup is marked done.
// Parameters:
//   - w: The HTTP response writer for setting cookie deletion headers.
//   - r: The HTTP request containing cookies to examine and clean up.
func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Request) {
	requestCookies := r.Cookies()
	configuredDomain := sm.cookieDomain

	host := r.Host
	if forwarded := r.Header.Get("X-Forwarded-Host"); forwarded != "" {
		host = forwarded
	}
	if portSep := strings.Index(host, ":"); portSep != -1 {
		host = host[:portSep]
	}

	// Domains under which stale cookies may have been set.
	var candidateDomains []string
	if host != "" && !strings.Contains(host, "localhost") && !strings.Contains(host, "127.0.0.1") {
		candidateDomains = append(candidateDomains, host, "."+host)
		if labels := strings.Split(host, "."); len(labels) > 2 {
			parent := strings.Join(labels[len(labels)-2:], ".")
			candidateDomains = append(candidateDomains, parent, "."+parent)
		}
	}

	// A cookie is ours if it carries the configured prefix or one of the
	// fixed token chunk prefixes.
	ownsCookie := func(name string) bool {
		return strings.HasPrefix(name, sm.cookiePrefix) ||
			strings.HasPrefix(name, "access_token_chunk_") ||
			strings.HasPrefix(name, "refresh_token_chunk_")
	}

	secure := r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS

	sawOwnCookie := false
	for _, cookie := range requestCookies {
		if !ownsCookie(cookie.Name) {
			continue
		}
		sawOwnCookie = true

		sm.cleanupMutex.RLock()
		pending := configuredDomain != "" && !sm.cleanupDone
		sm.cleanupMutex.RUnlock()
		if !pending {
			continue
		}

		for _, domain := range candidateDomains {
			// Never expire the cookie under the currently configured domain.
			if domain == configuredDomain || domain == "."+configuredDomain || "."+domain == configuredDomain {
				continue
			}
			http.SetCookie(w, &http.Cookie{
				Name:     cookie.Name,
				Value:    "",
				Path:     "/",
				Domain:   domain,
				MaxAge:   -1,
				HttpOnly: true,
				Secure:   secure,
				SameSite: http.SameSiteLaxMode,
			})
			sm.logger.Debugf("Attempting to clean up cookie %s with domain %s", cookie.Name, domain)
		}
	}

	if !sawOwnCookie {
		return
	}
	sm.cleanupMutex.Lock()
	sm.cleanupDone = true
	sm.cleanupMutex.Unlock()
}
// GetSession retrieves or creates session data from the HTTP request.
// It takes a SessionData object from the pool, then first tries to load from
// combined cookies (new format), falling back to the legacy per-token cookies
// if that fails. In both paths a session older than sessionMaxAge is cleared
// and rejected.
// The returned session is NOT returned to the pool here; the caller must
// return it when done to prevent leaks. On error the object is reset and put
// back into the pool by this function.
// NOTE(review): earlier documentation named both returnToPoolSafely() and
// ReturnToPool() as the release method; neither is visible here — confirm.
// Parameters:
// - r: The HTTP request containing session cookies.
//
// Returns:
// - The loaded SessionData instance.
// - An error if session loading or validation fails.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
// NOTE(review): the type assertion result is discarded; this relies on the
// pool only ever holding *SessionData (its New func creates that type).
sessionData, _ := sm.sessionPool.Get().(*SessionData) // Safe to ignore: pool return is best-effort
// NOTE(review): poolHits is incremented on every call, including calls where
// the pool had to construct a new object (which also counts a pool miss).
atomic.AddInt64(&sm.poolHits, 1)
atomic.AddInt64(&sm.activeSessions, 1)
sessionData.inUse = true
sessionData.request = r
sessionData.dirty = false
// handleError releases the session back to the pool, undoes the active
// session count, and wraps err with message.
handleError := func(err error, message string) (*SessionData, error) {
if sessionData != nil {
sessionData.inUse = false
sessionData.Reset()
sm.sessionPool.Put(sessionData)
atomic.AddInt64(&sm.activeSessions, -1)
}
return nil, fmt.Errorf("%s: %w", message, err)
}
// Try to load from combined cookies first (new format)
if sm.loadFromCombinedCookies(r, sessionData) {
sessionData.useCombinedStorage = true
sm.logger.Debug("Loaded session from combined cookies")
// Check session timeout; a non-positive creation time skips the check
if sessionData.getCreatedAtUnsafe() > 0 {
if time.Since(time.Unix(sessionData.getCreatedAtUnsafe(), 0)) > sm.sessionMaxAge {
_ = sessionData.Clear(r, nil) // Safe to ignore: session is being invalidated
return handleError(fmt.Errorf("session timeout"), "session expired")
}
}
return sessionData, nil
}
// Fall back to legacy cookies
sessionData.useCombinedStorage = false
sm.logger.Debug("Loading session from legacy cookies")
var err error
sessionData.mainSession, err = sm.store.Get(r, sm.mainCookieName())
if err != nil {
return handleError(err, "failed to get main session")
}
// The timeout check only applies when created_at is present as an int64
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
if time.Since(time.Unix(createdAt, 0)) > sm.sessionMaxAge {
_ = sessionData.Clear(r, nil) // Safe to ignore: session is being invalidated
return handleError(fmt.Errorf("session timeout"), "session expired")
}
}
sessionData.accessSession, err = sm.store.Get(r, sm.accessTokenCookieName())
if err != nil {
return handleError(err, "failed to get access token session")
}
sessionData.refreshSession, err = sm.store.Get(r, sm.refreshTokenCookieName())
if err != nil {
return handleError(err, "failed to get refresh token session")
}
sessionData.idTokenSession, err = sm.store.Get(r, sm.idTokenCookieName())
if err != nil {
return handleError(err, "failed to get ID token session")
}
// Drop chunk entries left over from a previous use of this pooled object
// before loading the chunks for this request.
for k := range sessionData.accessTokenChunks {
delete(sessionData.accessTokenChunks, k)
}
for k := range sessionData.refreshTokenChunks {
delete(sessionData.refreshTokenChunks, k)
}
for k := range sessionData.idTokenChunks {
delete(sessionData.idTokenChunks, k)
}
sm.getTokenChunkSessions(r, sm.accessTokenCookieName(), sessionData.accessTokenChunks)
sm.getTokenChunkSessions(r, sm.refreshTokenCookieName(), sessionData.refreshTokenChunks)
sm.getTokenChunkSessions(r, sm.idTokenCookieName(), sessionData.idTokenChunks)
// If legacy session has data, migrate to combined storage on next save
if !sessionData.mainSession.IsNew {
sessionData.useCombinedStorage = true
sessionData.dirty = true // Mark dirty to trigger migration on save
sm.logger.Debug("Legacy session found, will migrate to combined storage on save")
}
return sessionData, nil
}
// loadFromCombinedCookies tries to rebuild the session from the combined
// (compressed, chunked) cookie format.
//
// Chunk 0 carries the total chunk count under the "n" key; every chunk up to
// that count has to be present. The chunks are reassembled and decompressed,
// and the decoded payload is copied into the legacy session objects so the
// existing getter methods keep working unchanged.
//
// Parameters:
//   - r: The HTTP request carrying the cookies.
//   - sessionData: The session to hydrate. Its combinedChunks map is written to
//     and therefore must already be initialized.
//
// Returns:
//   - true when the combined cookies were found, complete and decodable,
//     false otherwise.
func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *SessionData) bool {
	head, err := sm.store.Get(r, sm.combinedChunkCookieName(0))
	if err != nil || head.IsNew {
		return false
	}

	// The first chunk announces how many chunks make up the payload.
	count, isInt := head.Values["n"].(int)
	if !isInt || count < 1 || count > maxCombinedChunks {
		sm.logger.Debugf("Invalid combined cookie chunk count: %v", head.Values["n"])
		return false
	}

	ordered := make([]*sessions.Session, count)
	ordered[0] = head
	sessionData.combinedChunks[0] = head
	for idx := 1; idx < count; idx++ {
		part, getErr := sm.store.Get(r, sm.combinedChunkCookieName(idx))
		if getErr == nil && !part.IsNew {
			ordered[idx] = part
			sessionData.combinedChunks[idx] = part
			continue
		}
		sm.logger.Debugf("Missing combined cookie chunk %d", idx)
		return false
	}

	assembled := assembleCombinedChunks(ordered)
	if assembled == "" {
		sm.logger.Debug("Failed to assemble combined chunks")
		return false
	}
	payload, err := decompressCombinedPayload(assembled)
	if err != nil {
		sm.logger.Debugf("Failed to decompress combined payload: %v", err)
		return false
	}

	// The getters read from the legacy session objects, so those must exist
	// even though the data came from the combined cookies.
	sessionData.mainSession, _ = sm.store.Get(r, sm.mainCookieName())
	sessionData.accessSession, _ = sm.store.Get(r, sm.accessTokenCookieName())
	sessionData.refreshSession, _ = sm.store.Get(r, sm.refreshTokenCookieName())
	sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())

	mainValues := sessionData.mainSession.Values
	mainValues["user_identifier"] = payload.Ui
	mainValues["authenticated"] = payload.Au
	mainValues["csrf"] = payload.Cs
	mainValues["nonce"] = payload.N
	mainValues["code_verifier"] = payload.Cv
	mainValues["incoming_path"] = payload.Ip
	mainValues["created_at"] = payload.Ca
	mainValues["redirect_count"] = payload.Rc
	// Custom values are applied last, exactly as they were stored.
	for name, custom := range payload.X {
		mainValues[name] = custom
	}

	// Tokens inside the combined payload are already plain text, so each
	// token session is marked as uncompressed.
	tokenSlots := []struct {
		session *sessions.Session
		token   interface{}
	}{
		{sessionData.accessSession, payload.A},
		{sessionData.refreshSession, payload.R},
		{sessionData.idTokenSession, payload.I},
	}
	for _, slot := range tokenSlots {
		slot.session.Values["token"] = slot.token
		slot.session.Values["compressed"] = false
	}
	return true
}
// getTokenChunkSessions collects the chunk sessions that exist for one token type.
// Chunk cookies are named "<baseName>_<index>" starting at index 0; the scan
// stops at the first index whose session cannot be fetched or is new.
// Parameters:
//   - r: The HTTP request containing chunk cookies.
//   - baseName: The base cookie name for the token type.
//   - chunks: The destination map, keyed by chunk index, that receives every
//     chunk session found.
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
	for idx := 0; ; idx++ {
		chunkSession, err := sm.store.Get(r, fmt.Sprintf("%s_%d", baseName, idx))
		if err == nil && !chunkSession.IsNew {
			chunks[idx] = chunkSession
			continue
		}
		return
	}
}
// SessionData represents a user's authentication session with comprehensive token management.
// It handles main session data and supports large tokens that need to be
// split across multiple cookies due to browser size limitations.
// Supports both legacy (separate cookies) and combined (single compressed cookie) storage.
// Instances are recycled through the manager's session pool (see Reset and
// ReturnToPool), so every field must be cleared before reuse.
type SessionData struct {
// manager is the owning SessionManager; it supplies the cookie store, logger,
// cookie names and the pool this object is returned to.
manager *SessionManager
// request is the HTTP request this session is bound to. SetAccessToken uses
// it to create chunk sessions; Clear and Reset set it back to nil.
request *http.Request
// Legacy storage (kept for backward compatibility during migration)
// mainSession holds the non-token values (e.g. "authenticated", "created_at").
mainSession *sessions.Session
// accessSession, refreshSession and idTokenSession each hold a "token" value
// and a "compressed" flag for their token type.
accessSession *sessions.Session
refreshSession *sessions.Session
idTokenSession *sessions.Session
// The *Chunks maps hold the extra chunk sessions of a token that did not fit
// into a single cookie, keyed by chunk index.
accessTokenChunks map[int]*sessions.Session
refreshTokenChunks map[int]*sessions.Session
idTokenChunks map[int]*sessions.Session
// Combined storage (new approach - single compressed cookie)
// combinedChunks maps chunk index to the session backing that combined chunk.
combinedChunks map[int]*sessions.Session
// useCombinedStorage indicates whether to use the new combined storage format
// Save dispatches on it; Reset restores it to true.
useCombinedStorage bool
// refreshMutex is not used by the methods in this part of the file;
// presumably it serializes token refreshes — TODO confirm against callers.
refreshMutex sync.Mutex
// sessionMutex guards the session values: exported getters take the read
// lock, setters take the write lock, and the *Unsafe helpers expect the
// caller to already hold it.
sessionMutex sync.RWMutex
// dirty records pending changes; setters raise it and a fully successful
// save clears it.
dirty bool
// inUse is consulted by returnToPoolSafely and ReturnToPool when deciding
// whether the object may be handed back to the pool.
inUse bool
// cachedClaimsToken is the ID token string whose claims were last parsed and
// cached. A lazy, per-request cache to avoid re-parsing the JWT on every
// authenticated request (e.g. for headerTemplates). Protected by sessionMutex.
cachedClaimsToken string
// cachedClaims holds the parsed claims for cachedClaimsToken.
cachedClaims map[string]interface{}
// cachedClaimsErr holds the parse error (if any) for cachedClaimsToken so
// failures are not retried within the same request.
cachedClaimsErr error
}
// IsDirty reports whether the session carries changes that have not been saved yet.
// Callers use it to skip writing cookies when nothing changed.
// The flag is read without taking sessionMutex.
// Returns:
//   - true if a save is pending, false otherwise.
func (sd *SessionData) IsDirty() bool {
	return sd.dirty
}
// MarkDirty flags the session as needing a save even though no value changed,
// for example to force the cookie to be re-issued.
// The flag is written without taking sessionMutex.
func (sd *SessionData) MarkDirty() {
	sd.dirty = true
}
// Save writes the session to the response as cookies.
// It derives the cookie options from the request (secure when the request is
// HTTPS, was forwarded as HTTPS, or HTTPS is forced by configuration), lets the
// manager harden them, and then delegates to the combined or the legacy writer
// depending on useCombinedStorage.
// Parameters:
//   - r: The HTTP request used to derive the cookie security options.
//   - w: The HTTP response writer that receives the session cookies.
//
// Returns:
//   - The error reported by the selected writer, or nil on success.
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
	secure := false
	switch {
	case r.Header.Get("X-Forwarded-Proto") == "https",
		r.TLS != nil,
		sd.manager.forceHTTPS:
		secure = true
	}

	opts := sd.manager.EnhanceSessionSecurity(sd.manager.getSessionOptions(secure), r)

	if !sd.useCombinedStorage {
		// Legacy layout: one cookie per session component.
		return sd.saveLegacy(r, w, opts)
	}
	return sd.saveCombined(r, w, opts)
}
// saveCombined stores the whole session as one compressed payload that is
// split across numbered chunk cookies.
//
// The payload consists of the standard session fields plus any extra
// string-keyed values found in the main session. If compression fails the
// legacy writer is used instead. After the chunks are written, combined chunks
// beyond the new count and all legacy cookies are expired. The dirty flag is
// cleared only when every chunk was fetched and saved without error.
//
// Parameters:
//   - r: The HTTP request used to look up the chunk sessions.
//   - w: The HTTP response writer that receives the cookies.
//   - options: The cookie options applied to every chunk.
//
// Returns:
//   - The first error hit while fetching or saving a chunk, an error when the
//     payload needs more than maxCombinedChunks chunks, or nil.
func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, options *sessions.Options) error {
	payload := &combinedSessionPayload{
		A:  sd.getAccessTokenUnsafe(),
		R:  sd.getRefreshTokenUnsafe(),
		I:  sd.getIDTokenUnsafe(),
		Ui: sd.getUserIdentifierUnsafe(),
		Au: sd.getAuthenticatedUnsafe(),
		Cs: sd.getCSRFUnsafe(),
		N:  sd.getNonceUnsafe(),
		Cv: sd.getCodeVerifierUnsafe(),
		Ip: sd.getIncomingPathUnsafe(),
		Ca: sd.getCreatedAtUnsafe(),
		Rc: sd.getRedirectCountUnsafe(),
	}

	// Carry over custom values: anything string-keyed in the main session that
	// is not one of the standard keys already present in the payload. These
	// values have to be JSON-serializable.
	sd.sessionMutex.RLock()
	if sd.mainSession != nil && len(sd.mainSession.Values) > 0 {
		custom := make(map[string]interface{})
		for rawKey, value := range sd.mainSession.Values {
			if name, isString := rawKey.(string); isString && !knownSessionKeys[name] {
				custom[name] = value
			}
		}
		if len(custom) != 0 {
			payload.X = custom
		}
	}
	sd.sessionMutex.RUnlock()

	compressed, err := compressCombinedPayload(payload)
	if err != nil {
		sd.manager.logger.Errorf("Failed to compress combined payload: %v", err)
		// Compression failed: write the legacy cookie layout instead.
		return sd.saveLegacy(r, w, options)
	}
	sd.manager.logger.Debugf("Combined session: raw payload compressed to %d bytes", len(compressed))

	chunks := splitCombinedIntoChunks(compressed, maxCookieSize)
	if total := len(chunks); total > maxCombinedChunks {
		sd.manager.logger.Errorf("Combined session requires %d chunks, exceeds max %d", total, maxCombinedChunks)
		return fmt.Errorf("session data too large: requires %d chunks, max is %d", total, maxCombinedChunks)
	}
	sd.manager.logger.Debugf("Combined session split into %d chunks", len(chunks))

	// Only the first failure is reported; later chunks are still attempted.
	var firstErr error
	remember := func(e error) {
		if firstErr == nil {
			firstErr = e
		}
	}

	for idx, data := range chunks {
		name := sd.manager.combinedChunkCookieName(idx)
		chunkSession, getErr := sd.manager.store.Get(r, name)
		if getErr != nil {
			sd.manager.logger.Errorf("Failed to get combined chunk session %s: %v", name, getErr)
			remember(getErr)
			continue
		}
		chunkSession.Values["d"] = data        // chunk data
		chunkSession.Values["n"] = len(chunks) // total number of chunks
		chunkSession.Values["i"] = idx         // index of this chunk
		chunkSession.Options = options
		if saveErr := chunkSession.Save(r, w); saveErr != nil {
			sd.manager.logger.Errorf("Failed to save combined chunk %d: %v", idx, saveErr)
			remember(saveErr)
		}
		sd.combinedChunks[idx] = chunkSession
	}

	// Drop combined chunks left over from a previously larger payload, then
	// the legacy cookies (migration).
	sd.expireOldCombinedChunks(r, w, options, len(chunks))
	sd.expireLegacyCookies(r, w, options)

	if firstErr != nil {
		return firstErr
	}
	sd.dirty = false
	return nil
}
// saveLegacy saves session data using the old separate cookie approach.
// Kept for backward compatibility during migration.
//
// The main, access, refresh and ID token sessions are saved first, followed by
// every token chunk session. A failing or missing component does not stop the
// remaining ones from being saved; the first problem encountered is returned.
// The dirty flag is cleared only when everything was saved.
//
// The cookie options are applied inside the per-session helper, after its nil
// check, so a nil session is reported as an error instead of causing a nil
// pointer dereference.
//
// Parameters:
//   - r: The HTTP request passed to each session save.
//   - w: The HTTP response writer that receives the cookies.
//   - options: The cookie options applied to every saved session.
//
// Returns:
//   - The first error encountered, or nil if every component was saved.
func (sd *SessionData) saveLegacy(r *http.Request, w http.ResponseWriter, options *sessions.Options) error {
	var firstErr error
	saveOrLogError := func(s *sessions.Session, name string) {
		if s == nil {
			sd.manager.logger.Errorf("Attempted to save nil session: %s", name)
			if firstErr == nil {
				firstErr = fmt.Errorf("attempted to save nil session: %s", name)
			}
			return
		}
		// Assign options only once the session is known to be non-nil.
		s.Options = options
		if err := s.Save(r, w); err != nil {
			errMsg := fmt.Errorf("failed to save %s session: %w", name, err)
			sd.manager.logger.Error("%s", errMsg.Error())
			if firstErr == nil {
				firstErr = errMsg
			}
		}
	}

	saveOrLogError(sd.mainSession, "main")
	saveOrLogError(sd.accessSession, "access token")
	saveOrLogError(sd.refreshSession, "refresh token")
	saveOrLogError(sd.idTokenSession, "ID token")

	for i, sessionChunk := range sd.accessTokenChunks {
		saveOrLogError(sessionChunk, fmt.Sprintf("access token chunk %d", i))
	}
	for i, sessionChunk := range sd.refreshTokenChunks {
		saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i))
	}
	for i, sessionChunk := range sd.idTokenChunks {
		saveOrLogError(sessionChunk, fmt.Sprintf("ID token chunk %d", i))
	}

	if firstErr == nil {
		sd.dirty = false
	}
	return firstErr
}
// expireOldCombinedChunks removes combined chunk cookies left over from an
// earlier, larger payload.
// Starting at index currentChunks it expires each existing chunk (MaxAge -1,
// values emptied) and stops at the first index that has no cookie or cannot be
// read. Save failures are only logged.
// Parameters:
//   - r: The HTTP request used to look up the chunk sessions.
//   - w: The HTTP response writer that receives the expiring cookies.
//   - options: The base cookie options; a copy with MaxAge -1 is used per chunk.
//   - currentChunks: The number of chunks that are still in use.
func (sd *SessionData) expireOldCombinedChunks(r *http.Request, w http.ResponseWriter, options *sessions.Options, currentChunks int) {
	for idx := currentChunks; idx < maxCombinedChunks; idx++ {
		stale, err := sd.manager.store.Get(r, sd.manager.combinedChunkCookieName(idx))
		if err != nil || stale.IsNew {
			// First gap reached: nothing further to expire.
			return
		}

		// Each chunk gets its own copy so the caller's options stay untouched.
		expired := *options
		expired.MaxAge = -1
		stale.Options = &expired

		for key := range stale.Values {
			delete(stale.Values, key)
		}
		if saveErr := stale.Save(r, w); saveErr != nil {
			sd.manager.logger.Debugf("Failed to expire old combined chunk %d: %v", idx, saveErr)
		}
	}
}
// expireLegacyCookies expires old legacy format cookies during migration.
//
// The four primary legacy cookies (main, access, refresh, ID token) are expired
// without touching their in-memory values, because sd keeps reading from those
// same session objects (store.Get is expected to hand back the per-request
// instance that sd already holds). Legacy chunk cookies are expired and their
// values are emptied.
//
// Chunk cookies are probed independently for each token type: the scan only
// stops at the first index where none of the access, refresh and ID token
// chunk cookies exists. This matters when, for example, the refresh token was
// chunked but the access token was not.
//
// Parameters:
//   - r: The HTTP request used to look up the legacy sessions.
//   - w: The HTTP response writer that receives the expiring cookies.
//   - options: The base cookie options; a copy with MaxAge -1 is applied.
func (sd *SessionData) expireLegacyCookies(r *http.Request, w http.ResponseWriter, options *sessions.Options) {
	// Upper bound on the number of chunk cookies the legacy format could create.
	const maxLegacyChunks = 50

	expireOptions := *options
	expireOptions.MaxAge = -1

	// expire marks the named cookie as expired if it exists and reports whether
	// it did. Values are emptied only when clearValues is set.
	expire := func(cookieName string, clearValues bool) bool {
		session, err := sd.manager.store.Get(r, cookieName)
		if err != nil || session.IsNew {
			return false // Cookie doesn't exist
		}
		session.Options = &expireOptions
		if clearValues {
			for k := range session.Values {
				delete(session.Values, k)
			}
		}
		_ = session.Save(r, w) // Best effort: the cookie is being discarded anyway
		return true
	}

	// Primary legacy cookies: keep the in-memory values, they are still needed.
	expire(sd.manager.mainCookieName(), false)
	expire(sd.manager.accessTokenCookieName(), false)
	expire(sd.manager.refreshTokenCookieName(), false)
	expire(sd.manager.idTokenCookieName(), false)

	chunkBases := [...]string{
		sd.manager.accessTokenCookieName(),
		sd.manager.refreshTokenCookieName(),
		sd.manager.idTokenCookieName(),
	}
	for i := 0; i < maxLegacyChunks; i++ {
		found := false
		for _, base := range chunkBases {
			if expire(fmt.Sprintf("%s_%d", base, i), true) {
				found = true
			}
		}
		if !found {
			break // No token type has a chunk at this index
		}
	}
}
// clearSessionValues removes all values from a session and optionally expires it.
// This is used during session cleanup and logout operations.
// A nil session is ignored. When expiring a session that has no Options yet,
// an empty Options value is created first so that setting MaxAge cannot
// dereference a nil pointer.
// Parameters:
//   - session: The session to clear values from.
//   - expire: If true, sets MaxAge to -1 to expire the cookie.
func clearSessionValues(session *sessions.Session, expire bool) {
	if session == nil {
		return
	}
	for k := range session.Values {
		delete(session.Values, k)
	}
	if !expire {
		return
	}
	if session.Options == nil {
		session.Options = &sessions.Options{}
	}
	session.Options.MaxAge = -1
}
// clearAllSessionData wipes the values of the four primary sessions and deals
// with the token chunk sessions.
// When expire is set and a request is available, the chunk sessions are cleared
// and expired but stay in their maps; in every other case the chunk maps are
// simply emptied. Expiring also marks the session dirty.
// The caller is expected to hold sessionMutex.
// Parameters:
//   - r: The HTTP request context (may be nil).
//   - expire: Whether to expire the cookies (set MaxAge to -1).
func (sd *SessionData) clearAllSessionData(r *http.Request, expire bool) {
	primary := []*sessions.Session{
		sd.mainSession,
		sd.accessSession,
		sd.refreshSession,
		sd.idTokenSession,
	}
	for _, s := range primary {
		clearSessionValues(s, expire)
	}

	expireChunks := expire && r != nil
	chunkMaps := []map[int]*sessions.Session{
		sd.accessTokenChunks,
		sd.refreshTokenChunks,
		sd.idTokenChunks,
	}
	for _, chunks := range chunkMaps {
		if expireChunks {
			sd.clearTokenChunks(r, chunks)
			continue
		}
		for idx := range chunks {
			delete(chunks, idx)
		}
	}

	if expire {
		sd.dirty = true
	}
}
// Clear wipes the session, optionally writes the result to the response, and
// hands the SessionData back to the pool.
// All values are cleared and marked for expiry under the write lock. The lock
// is released before saving, since Save must not run while it is held. The
// pool hand-off is deferred so it happens on every return path.
// Parameters:
//   - r: The HTTP request context.
//   - w: The HTTP response writer for cookie expiration. When nil, nothing is
//     written and only the in-memory state is cleared.
//
// Returns:
//   - The error from saving the cleared session, a synthetic error when the
//     request carries the header "X-Test-Error: true" (a hook for tests), or nil.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
	defer sd.returnToPoolSafely()

	sd.sessionMutex.Lock()
	sd.clearAllSessionData(r, true)
	sd.sessionMutex.Unlock()

	var saveErr error
	switch {
	case w == nil:
		// No response writer: there is nowhere to send expiring cookies.
	case r != nil && r.Header.Get("X-Test-Error") == "true":
		// Test hook: report a failure without attempting the save.
		saveErr = fmt.Errorf("test error triggered by X-Test-Error header")
	default:
		saveErr = sd.Save(r, w)
	}

	sd.request = nil
	return saveErr
}
// returnToPoolSafely hands the session back to the manager's pool if, and only
// if, it is currently marked as in use.
// The session is flagged as free, reset, put into the pool, and the manager's
// active session counter is decremented. A nil receiver, a missing manager, or
// a session that is not in use makes this a no-op.
func (sd *SessionData) returnToPoolSafely() {
	if sd == nil || sd.manager == nil || !sd.inUse {
		return
	}
	sd.inUse = false
	sd.Reset()
	sd.manager.sessionPool.Put(sd)
	atomic.AddInt64(&sd.manager.activeSessions, -1)
}
// clearTokenChunks empties every chunk session in the given map and marks it
// for expiry. The map entries themselves are left in place.
// Parameters:
//   - the request argument is accepted for signature compatibility and is unused.
//   - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
func (sd *SessionData) clearTokenChunks(_ *http.Request, chunks map[int]*sessions.Session) {
	for idx := range chunks {
		clearSessionValues(chunks[idx], true)
	}
}
// GetAuthenticated reports whether the user is authenticated and the session
// has not outlived the configured maximum age.
// The check runs under the session read lock.
// Returns:
//   - true if the user is authenticated and the session is not expired.
//   - false otherwise.
func (sd *SessionData) GetAuthenticated() bool {
	sd.sessionMutex.RLock()
	defer sd.sessionMutex.RUnlock()

	authenticated := sd.getAuthenticatedUnsafe()
	return authenticated
}
// getAuthenticatedUnsafe evaluates the authentication state without locking;
// the caller must already hold sessionMutex.
// A session counts as authenticated only when the "authenticated" flag is
// true, a "created_at" timestamp is present, and the time since that
// timestamp does not exceed the manager's sessionMaxAge.
// Returns:
//   - true if authenticated and not expired, false otherwise.
func (sd *SessionData) getAuthenticatedUnsafe() bool {
	values := sd.mainSession.Values

	if flagged, _ := values["authenticated"].(bool); !flagged {
		return false
	}

	created, hasTimestamp := values["created_at"].(int64)
	if !hasTimestamp {
		// Without a creation time the age cannot be verified.
		return false
	}

	age := time.Since(time.Unix(created, 0))
	return age <= sd.manager.sessionMaxAge
}
// SetAuthenticated records the authentication status of the session.
// When the status is set to true, the main session receives a freshly
// generated random ID (regenerated up to five times should it equal the
// current one) and a new creation timestamp, which guards against session
// fixation. The session is marked dirty whenever anything actually changed.
// Parameters:
//   - value: The authentication status to set.
//
// Returns:
//   - An error if generating a new session ID fails when setting value to true.
func (sd *SessionData) SetAuthenticated(value bool) error {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()

	// Track changes both against the effective state (which also considers
	// expiry) and against the raw stored flag.
	modified := sd.getAuthenticatedUnsafe() != value
	values := sd.mainSession.Values
	if stored, present := values["authenticated"].(bool); !present || stored != value {
		modified = true
	}

	if value {
		const maxRetries = 5

		sessionID, err := generateSecureRandomString(64)
		if err != nil {
			return fmt.Errorf("failed to generate secure session id: %w", err)
		}
		// A collision with the current ID is retried a bounded number of times.
		for attempt := 0; attempt < maxRetries && sd.mainSession.ID == sessionID; attempt++ {
			sessionID, err = generateSecureRandomString(64)
			if err != nil {
				return fmt.Errorf("failed to generate secure session id on retry %d: %w", attempt, err)
			}
		}
		if sd.mainSession.ID != sessionID {
			modified = true
		}
		sd.mainSession.ID = sessionID

		now := time.Now().Unix()
		if previous, had := values["created_at"].(int64); !had || previous != now {
			modified = true
		}
		values["created_at"] = now
	}

	values["authenticated"] = value
	if modified {
		sd.dirty = true
	}
	return nil
}
// resetSession returns a session to a blank state so a pooled SessionData can
// be reused without carrying data from its previous owner.
// All values are removed, the ID is emptied, and the session is flagged as
// new. The cookie options are left untouched. A nil session is ignored.
// Parameters:
//   - session: The session to reset for reuse.
func resetSession(session *sessions.Session) {
	if session != nil {
		for key := range session.Values {
			delete(session.Values, key)
		}
		session.IsNew = true
		session.ID = ""
	}
}
// Reset returns the SessionData to a pristine state so it can be reused for a
// different user or request.
// Under the write lock it clears every session value and chunk map, resets the
// four primary sessions, drops all combined chunks, discards cached claims,
// detaches the request, and restores the default flags (not dirty, not in use,
// combined storage enabled).
func (sd *SessionData) Reset() {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()

	sd.clearAllSessionData(nil, false)

	for _, s := range []*sessions.Session{
		sd.mainSession,
		sd.accessSession,
		sd.refreshSession,
		sd.idTokenSession,
	} {
		resetSession(s)
	}

	for idx, chunk := range sd.combinedChunks {
		resetSession(chunk)
		delete(sd.combinedChunks, idx)
	}

	// Make sure the redirect counter cannot carry over to the next user.
	if mainSession := sd.mainSession; mainSession != nil && mainSession.Values != nil {
		delete(mainSession.Values, "redirect_count")
	}

	// Cached claims belong to the previous ID token and must not be reused.
	sd.cachedClaimsToken, sd.cachedClaims, sd.cachedClaimsErr = "", nil, nil

	sd.request = nil
	sd.useCombinedStorage = true // combined storage is the default format
	sd.inUse = false
	sd.dirty = false
}
// ReturnToPool hands the session back to the manager's pool on cleanup paths
// that do not go through Clear.
// It acts only when the session is not marked as in use: the session is reset,
// put into the pool, and the active session counter is decremented. A nil
// receiver or a missing manager makes this a no-op.
// NOTE(review): nothing in this method records that the object has already
// been pooled, so a second call would pool it and decrement the counter
// again — confirm callers invoke it at most once per session.
func (sd *SessionData) ReturnToPool() {
	if sd == nil || sd.manager == nil || sd.inUse {
		return
	}
	sd.Reset()
	sd.manager.sessionPool.Put(sd)
	atomic.AddInt64(&sd.manager.activeSessions, -1)
}
// GetAccessToken returns the access token held by the session.
// It takes the session read lock and delegates to getAccessTokenUnsafe, which
// deals with single-cookie, chunked and compressed storage.
// Returns:
//   - The complete access token string, or an empty string if none is available.
func (sd *SessionData) GetAccessToken() string {
	sd.sessionMutex.RLock()
	defer sd.sessionMutex.RUnlock()

	accessToken := sd.getAccessTokenUnsafe()
	return accessToken
}
// getAccessTokenUnsafe retrieves the access token without acquiring locks.
// Used when the session mutex is already held to prevent deadlocks.
//
// Resolution order:
//   - No access session: nothing is stored, so an empty string is returned.
//   - No manager or chunk manager (test scenario): the raw stored token is
//     returned as-is.
//   - Otherwise the chunk manager assembles and validates the token. If that
//     fails for a plain (non-chunked, uncompressed) token that does not have
//     the three-part JWT shape, the token is treated as opaque and returned
//     unchanged; any other failure yields an empty string.
//
// Returns:
//   - The complete access token string or empty string on error.
func (sd *SessionData) getAccessTokenUnsafe() string {
	if sd.accessSession == nil {
		// Guard against a session that was never initialized.
		return ""
	}
	token, _ := sd.accessSession.Values["token"].(string)
	compressed, _ := sd.accessSession.Values["compressed"].(bool)

	if sd.manager == nil || sd.manager.chunkManager == nil {
		// Direct return if no chunk manager (test scenario)
		return token
	}

	result := sd.manager.chunkManager.GetToken(
		token,
		compressed,
		sd.accessTokenChunks,
		AccessTokenConfig,
	)
	if result.Error == nil {
		return result.Token
	}

	// The manager is known to be non-nil from here on; only the logger may be missing.
	logger := sd.manager.logger

	if token != "" && !compressed && len(sd.accessTokenChunks) == 0 {
		// A raw token that failed validation: opaque tokens lack the
		// header.payload.signature structure and are still valid.
		if dotCount := strings.Count(token, "."); dotCount != 2 {
			if logger != nil {
				logger.Debugf("Returning opaque access token (dots: %d) despite validation error: %v", dotCount, result.Error)
			}
			return token
		}
	}

	// For JWT validation errors or other issues, log and return empty
	if logger != nil {
		logger.Debugf("ChunkManager.GetToken error: %v", result.Error)
	}
	return ""
}
// SetAccessToken stores an access token with automatic compression and chunking.
//
// Behavior:
//   - A non-empty token shorter than 20 characters is rejected.
//   - If the access session is missing, the call is logged and ignored. This
//     check runs before the current token is read, so it cannot be bypassed by
//     a nil pointer dereference.
//   - If the token equals the currently stored one, nothing happens.
//   - Otherwise the session is marked dirty, existing chunks are expired and
//     dropped, and the token is stored: compressed when that round-trips
//     correctly, in a single cookie when it fits into maxCookieSize, and split
//     into at most 50 chunk cookies when it does not.
//   - A token that is still larger than 100 KiB after compression is rejected.
//   - An empty token clears the stored value.
//
// Chunked storage requires sd.request, because chunk sessions are obtained
// from the store for that request.
//
// Parameters:
//   - token: The access token string to store.
func (sd *SessionData) SetAccessToken(token string) {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()

	canLog := sd.manager != nil && sd.manager.logger != nil

	if token != "" && len(token) < 20 {
		if canLog {
			sd.manager.logger.Debug("Token too short for opaque token (length: %d) - rejecting", len(token))
		}
		return
	}

	// Must come before getAccessTokenUnsafe, which reads from accessSession.
	if sd.accessSession == nil {
		if canLog {
			sd.manager.logger.Errorf("CRITICAL: accessSession is nil when trying to store token")
		}
		return
	}

	if constantTimeStringCompare(sd.getAccessTokenUnsafe(), token) {
		return
	}
	sd.dirty = true

	// Discard whatever chunks belonged to the previous token.
	if sd.request != nil {
		sd.expireAccessTokenChunksEnhanced(nil)
	}
	for k := range sd.accessTokenChunks {
		delete(sd.accessTokenChunks, k)
	}

	if token == "" {
		sd.accessSession.Values["token"] = ""
		sd.accessSession.Values["compressed"] = false
		return
	}

	compressed := compressToken(token)
	if canLog {
		sd.manager.logger.Debugf("Token compression: original %d bytes, compressed %d bytes", len(token), len(compressed))
	}

	if len(compressed) > 100*1024 {
		// The token is not stored at all in this case.
		if canLog {
			sd.manager.logger.Info("Access token too large after compression (%d bytes) - rejecting", len(compressed))
		}
		return
	}

	// Only keep the compressed form if it decompresses back to the original.
	if compressed != token && decompressToken(compressed) != token {
		if canLog {
			sd.manager.logger.Debug("Access token compression verification failed - storing uncompressed")
		}
		compressed = token
	}
	isCompressed := compressed != token

	if len(compressed) <= maxCookieSize {
		sd.accessSession.Values["token"] = compressed
		sd.accessSession.Values["compressed"] = isCompressed
		if canLog {
			sd.manager.logger.Debugf("Stored token in session: compressed=%v, token_len=%d",
				isCompressed, len(compressed))
		}
		return
	}

	// Too large for one cookie: the main cookie keeps only the compression
	// flag and the data goes into numbered chunk cookies.
	sd.accessSession.Values["token"] = ""
	sd.accessSession.Values["compressed"] = isCompressed

	chunks := splitIntoChunks(compressed, maxCookieSize)
	if len(chunks) == 0 {
		sd.manager.logger.Error("Failed to create chunks for access token")
		return
	}
	if len(chunks) > 50 {
		sd.manager.logger.Info("Too many chunks (%d) for access token", len(chunks))
		return
	}
	if strings.Join(chunks, "") != compressed {
		sd.manager.logger.Debug("Access token chunk reassembly test failed")
		return
	}
	if sd.request == nil {
		// Checked once up front; chunk sessions cannot be created without a request.
		sd.manager.logger.Error("SetAccessToken: sd.request is nil, cannot create chunk session %s",
			fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), 0))
		return
	}

	for i, chunkData := range chunks {
		sessionName := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), i)

		if chunkData == "" {
			sd.manager.logger.Debug("Empty chunk data at index %d", i)
			return
		}
		if len(chunkData) > maxCookieSize {
			sd.manager.logger.Info("Chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
			return
		}
		if !validateChunkSize(chunkData) {
			sd.manager.logger.Errorf("CRITICAL: Chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
			return
		}

		session, err := sd.manager.store.Get(sd.request, sessionName)
		if err != nil {
			sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err)
			return
		}
		session.Values["token_chunk"] = chunkData
		session.Values["compressed"] = isCompressed
		session.Values["chunk_created_at"] = time.Now().Unix()
		sd.accessTokenChunks[i] = session
	}
	sd.manager.logger.Debugf("SUCCESS: Stored access token in %d chunks", len(chunks))
}
// GetRefreshToken retrieves the user's refresh token from session storage.
// It handles both single-cookie storage and chunked storage for large tokens,
// with automatic decompression if the token was compressed.
//
// Like the access token getter, it returns the raw stored token when no
// manager or chunk manager is available instead of dereferencing a nil
// pointer. When the chunk manager rejects a plain (non-chunked, uncompressed)
// token that does not have the three-part JWT shape, the token is treated as
// opaque and returned unchanged; any other failure yields an empty string.
//
// Returns:
//   - The complete, decompressed refresh token string, or an empty string if not found.
func (sd *SessionData) GetRefreshToken() string {
	sd.sessionMutex.RLock()
	defer sd.sessionMutex.RUnlock()

	token, _ := sd.refreshSession.Values["token"].(string)
	compressed, _ := sd.refreshSession.Values["compressed"].(bool)

	if sd.manager == nil || sd.manager.chunkManager == nil {
		// Direct return if no chunk manager, matching getAccessTokenUnsafe.
		return token
	}

	result := sd.manager.chunkManager.GetToken(
		token,
		compressed,
		sd.refreshTokenChunks,
		RefreshTokenConfig,
	)
	if result.Error == nil {
		return result.Token
	}

	// The manager is known to be non-nil from here on; only the logger may be missing.
	logger := sd.manager.logger

	if token != "" && !compressed && len(sd.refreshTokenChunks) == 0 {
		// A raw token that failed validation: opaque tokens lack the
		// header.payload.signature structure and are still valid.
		if dotCount := strings.Count(token, "."); dotCount != 2 {
			if logger != nil {
				logger.Debugf("Returning opaque refresh token (dots: %d) despite validation error: %v", dotCount, result.Error)
			}
			return token
		}
	}

	// For JWT validation errors or other issues, log and return empty
	if logger != nil {
		logger.Debugf("ChunkManager.GetToken error for refresh token: %v", result.Error)
	}
	return ""
}
// SetRefreshToken stores a refresh token in the session, compressing it and
// splitting it into numbered chunk sessions when the (possibly compressed)
// value does not fit into a single cookie of maxCookieSize bytes.
//
// The call is a no-op when the supplied token equals the token currently held
// by the session. Otherwise the session is marked dirty, existing chunk
// sessions are expired and forgotten, and the new value is written. Failures
// are logged and abort the call; nothing is reported back to the caller.
//
// NOTE(review): when chunking aborts part-way, the previous token has already
// been cleared from the main refresh session and the old chunk map emptied.
//
// Parameters:
//   - token: The refresh token string to store. An empty string clears the
//     stored token. Tokens larger than 50 KiB are rejected.
func (sd *SessionData) SetRefreshToken(token string) {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()
	// Refuse implausibly large input (more than 50 KiB) before touching any state.
	if len(token) > 50*1024 {
		sd.manager.logger.Errorf("CRITICAL: Refresh token too large (%d bytes) - possible corruption, rejecting", len(token))
		return
	}
	// Reconstruct the currently stored token inline: the write lock is already
	// held here, so a locking getter must not be called.
	var currentRefreshToken string
	sessionToken, _ := sd.refreshSession.Values["token"].(string)
	if sessionToken != "" {
		// Single-cookie storage: the value lives directly in the refresh session.
		compressed, _ := sd.refreshSession.Values["compressed"].(bool)
		if compressed {
			decompressed := decompressToken(sessionToken)
			currentRefreshToken = decompressed
		} else {
			currentRefreshToken = sessionToken
		}
	} else if len(sd.refreshTokenChunks) > 0 {
		// Chunked storage: concatenate chunks 0..n-1 in index order.
		var chunks []string
		for i := 0; i < len(sd.refreshTokenChunks); i++ {
			if session, ok := sd.refreshTokenChunks[i]; ok {
				if chunk, chunkOk := session.Values["token_chunk"].(string); chunkOk && chunk != "" {
					chunks = append(chunks, chunk)
				}
			}
		}
		// Only trust the reassembly when every expected chunk was present and non-empty.
		if len(chunks) == len(sd.refreshTokenChunks) {
			reassembled := strings.Join(chunks, "")
			compressed, _ := sd.refreshSession.Values["compressed"].(bool)
			if compressed {
				currentRefreshToken = decompressToken(reassembled)
			} else {
				currentRefreshToken = reassembled
			}
		}
	}
	// Nothing to do when the token is unchanged; the session stays clean.
	if constantTimeStringCompare(currentRefreshToken, token) {
		return
	}
	sd.dirty = true
	// Mark existing chunk sessions for this request as expired. No response
	// writer is passed, so no Set-Cookie headers are emitted at this point.
	if sd.request != nil {
		sd.expireRefreshTokenChunksEnhanced(nil)
	}
	// Forget all previously tracked chunk sessions.
	for k := range sd.refreshTokenChunks {
		delete(sd.refreshTokenChunks, k)
	}
	// An empty token clears the stored value and stops here.
	if token == "" {
		sd.refreshSession.Values["token"] = ""
		sd.refreshSession.Values["compressed"] = false
		return
	}
	// compressToken returning its input unchanged is treated as "not compressed".
	compressed := compressToken(token)
	if compressed != token {
		// Round-trip check: fall back to the raw token if decompression does not restore it.
		testDecompressed := decompressToken(compressed)
		if testDecompressed != token {
			sd.manager.logger.Errorf("CRITICAL: Refresh token compression verification failed - storing uncompressed")
			compressed = token
		}
	}
	if len(compressed) <= maxCookieSize {
		// Fits in a single cookie: store directly in the refresh session.
		sd.refreshSession.Values["token"] = compressed
		sd.refreshSession.Values["compressed"] = (compressed != token)
		sd.refreshSession.Values["issued_at"] = time.Now().Unix()
	} else {
		// Too large: the main session keeps only metadata, the payload goes into chunks.
		sd.refreshSession.Values["token"] = ""
		sd.refreshSession.Values["compressed"] = (compressed != token)
		sd.refreshSession.Values["issued_at"] = time.Now().Unix()
		chunks := splitIntoChunks(compressed, maxCookieSize)
		if len(chunks) == 0 {
			sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for refresh token")
			return
		}
		// Hard cap of 50 chunks per token.
		if len(chunks) > 50 {
			sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for refresh token - possible corruption", len(chunks))
			return
		}
		// Verify that joining the chunks reproduces the value that was split.
		testReassembled := strings.Join(chunks, "")
		if testReassembled != compressed {
			sd.manager.logger.Errorf("CRITICAL: Refresh token chunk reassembly test failed")
			return
		}
		for i, chunkData := range chunks {
			// Chunk sessions are named "<refresh cookie name>_<index>".
			sessionName := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), i)
			// A request is required to obtain a chunk session from the store.
			if sd.request == nil {
				sd.manager.logger.Errorf("CRITICAL: SetRefreshToken: sd.request is nil, cannot create chunk session %s", sessionName)
				return
			}
			if chunkData == "" {
				sd.manager.logger.Errorf("CRITICAL: Empty refresh token chunk data at index %d", i)
				return
			}
			if len(chunkData) > maxCookieSize {
				sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
				return
			}
			if !validateChunkSize(chunkData) {
				sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
				return
			}
			session, err := sd.manager.store.Get(sd.request, sessionName)
			if err != nil {
				sd.manager.logger.Errorf("CRITICAL: Failed to get refresh token chunk session %s: %v", sessionName, err)
				return
			}
			// Each chunk records its payload, the compression flag and a creation timestamp.
			session.Values["token_chunk"] = chunkData
			session.Values["compressed"] = (compressed != token)
			session.Values["chunk_created_at"] = time.Now().Unix()
			sd.refreshTokenChunks[i] = session
		}
		sd.manager.logger.Debugf("SUCCESS: Stored refresh token in %d chunks", len(chunks))
	}
}
// GetRefreshTokenIssuedAt reports when the current refresh token was stored.
// It prefers the "issued_at" value of the refresh session and falls back to
// the "chunk_created_at" value of chunk 0 when the token is chunked.
// Returns:
//   - The stored timestamp, or the zero time.Time when neither value is available.
func (sd *SessionData) GetRefreshTokenIssuedAt() time.Time {
	sd.sessionMutex.RLock()
	defer sd.sessionMutex.RUnlock()

	unixSeconds, found := sd.refreshSession.Values["issued_at"].(int64)
	if !found && len(sd.refreshTokenChunks) > 0 {
		// Chunked tokens carry their timestamp on the first chunk.
		if firstChunk, present := sd.refreshTokenChunks[0]; present {
			unixSeconds, found = firstChunk.Values["chunk_created_at"].(int64)
		}
	}
	if !found {
		return time.Time{}
	}
	return time.Unix(unixSeconds, 0)
}
// expireAccessTokenChunksEnhanced walks the numbered access token chunk
// sessions starting at index 0 (at most 50), marks each existing one as
// expired (MaxAge = -1) and wipes its values. The walk stops at the first
// index whose session cannot be fetched or is new. Chunks older than 24 hours,
// or carrying data without a creation timestamp, are counted as orphaned and
// reported in the log.
// Parameters:
//   - w: The HTTP response writer (optional). When non-nil, every expired chunk
//     session is saved through it; when nil the sessions are only modified in memory.
func (sd *SessionData) expireAccessTokenChunksEnhanced(w http.ResponseWriter) {
	const (
		searchLimit = 50
		staleAfter  = 24 * time.Hour
	)
	var orphaned int
	for idx := 0; idx < searchLimit; idx++ {
		name := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), idx)
		chunkSession, err := sd.manager.store.Get(sd.request, name)
		if err != nil || chunkSession.IsNew {
			break
		}

		payload, hasPayload := chunkSession.Values["token_chunk"]
		createdAt, hasStamp := chunkSession.Values["chunk_created_at"].(int64)
		switch {
		case !hasPayload:
			// Nothing stored under this index; it is still expired below.
		case hasStamp:
			if age := time.Since(time.Unix(createdAt, 0)); age > staleAfter {
				orphaned++
				sd.manager.logger.Debugf("Found orphaned access token chunk %d (age: %v)", idx, age)
			}
		case payload != nil:
			orphaned++
			sd.manager.logger.Debugf("Found access token chunk %d without timestamp, treating as orphaned", idx)
		}

		chunkSession.Options.MaxAge = -1
		chunkSession.Values = make(map[interface{}]interface{})
		if w == nil {
			continue
		}
		if saveErr := chunkSession.Save(sd.request, w); saveErr != nil {
			sd.manager.logger.Errorf("failed to save expired access token chunk %d: %v", idx, saveErr)
		}
	}
	if orphaned > 0 {
		sd.manager.logger.Infof("Cleaned up %d orphaned access token chunks", orphaned)
	}
}
// expireRefreshTokenChunksEnhanced walks the numbered refresh token chunk
// sessions starting at index 0 (at most 50), marks each existing one as
// expired (MaxAge = -1) and wipes its values. The walk stops at the first
// index whose session cannot be fetched or is new. Chunks older than 24 hours,
// or carrying data without a creation timestamp, are counted as orphaned and
// reported in the log.
// Parameters:
//   - w: The HTTP response writer (optional). When non-nil, every expired chunk
//     session is saved through it; when nil the sessions are only modified in memory.
func (sd *SessionData) expireRefreshTokenChunksEnhanced(w http.ResponseWriter) {
	const (
		searchLimit = 50
		staleAfter  = 24 * time.Hour
	)
	var orphaned int
	for idx := 0; idx < searchLimit; idx++ {
		name := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), idx)
		chunkSession, err := sd.manager.store.Get(sd.request, name)
		if err != nil || chunkSession.IsNew {
			break
		}

		payload, hasPayload := chunkSession.Values["token_chunk"]
		createdAt, hasStamp := chunkSession.Values["chunk_created_at"].(int64)
		switch {
		case !hasPayload:
			// Nothing stored under this index; it is still expired below.
		case hasStamp:
			if age := time.Since(time.Unix(createdAt, 0)); age > staleAfter {
				orphaned++
				sd.manager.logger.Debugf("Found orphaned refresh token chunk %d (age: %v)", idx, age)
			}
		case payload != nil:
			orphaned++
			sd.manager.logger.Debugf("Found refresh token chunk %d without timestamp, treating as orphaned", idx)
		}

		chunkSession.Options.MaxAge = -1
		chunkSession.Values = make(map[interface{}]interface{})
		if w == nil {
			continue
		}
		if saveErr := chunkSession.Save(sd.request, w); saveErr != nil {
			sd.manager.logger.Errorf("failed to save expired refresh token chunk %d: %v", idx, saveErr)
		}
	}
	if orphaned > 0 {
		sd.manager.logger.Infof("Cleaned up %d orphaned refresh token chunks", orphaned)
	}
}
// expireIDTokenChunksEnhanced walks the numbered ID token chunk sessions
// starting at index 0 (at most 50), marks each existing one as expired
// (MaxAge = -1) and wipes its values. The walk stops at the first index whose
// session cannot be fetched or is new. Chunks older than 24 hours, or carrying
// data without a creation timestamp, are counted as orphaned and reported in
// the log.
// Parameters:
//   - w: The HTTP response writer (optional). When non-nil, every expired chunk
//     session is saved through it; when nil the sessions are only modified in memory.
func (sd *SessionData) expireIDTokenChunksEnhanced(w http.ResponseWriter) {
	const (
		searchLimit = 50
		staleAfter  = 24 * time.Hour
	)
	var orphaned int
	for idx := 0; idx < searchLimit; idx++ {
		name := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), idx)
		chunkSession, err := sd.manager.store.Get(sd.request, name)
		if err != nil || chunkSession.IsNew {
			break
		}

		payload, hasPayload := chunkSession.Values["token_chunk"]
		createdAt, hasStamp := chunkSession.Values["chunk_created_at"].(int64)
		switch {
		case !hasPayload:
			// Nothing stored under this index; it is still expired below.
		case hasStamp:
			if age := time.Since(time.Unix(createdAt, 0)); age > staleAfter {
				orphaned++
				sd.manager.logger.Debugf("Found orphaned ID token chunk %d (age: %v)", idx, age)
			}
		case payload != nil:
			orphaned++
			sd.manager.logger.Debugf("Found ID token chunk %d without timestamp, treating as orphaned", idx)
		}

		chunkSession.Options.MaxAge = -1
		chunkSession.Values = make(map[interface{}]interface{})
		if w == nil {
			continue
		}
		if saveErr := chunkSession.Save(sd.request, w); saveErr != nil {
			sd.manager.logger.Errorf("failed to save expired ID token chunk %d: %v", idx, saveErr)
		}
	}
	if orphaned > 0 {
		sd.manager.logger.Infof("Cleaned up %d orphaned ID token chunks", orphaned)
	}
}
// splitIntoChunks divides a string into consecutive pieces of at most
// chunkSize bytes each. The requested size is capped at maxCookieSize, so no
// piece ever exceeds the per-cookie budget. Joining the returned pieces in
// order reproduces the input.
// Parameters:
//   - s: The string to split into chunks.
//   - chunkSize: The maximum size for each chunk, in bytes.
//
// Returns:
//   - A slice of strings representing the chunks. The result is nil when s is
//     empty or when the effective chunk size is not positive.
func splitIntoChunks(s string, chunkSize int) []string {
	// Never exceed the per-cookie budget, whatever the caller asked for.
	limit := chunkSize
	if limit > maxCookieSize {
		limit = maxCookieSize
	}
	// A non-positive limit can never make progress: zero would append empty
	// chunks forever and a negative value would panic when slicing. Returning
	// no chunks lets callers handle it through their existing empty-result check.
	if limit <= 0 || s == "" {
		return nil
	}
	chunks := make([]string, 0, (len(s)+limit-1)/limit)
	for len(s) > limit {
		chunks = append(chunks, s[:limit])
		s = s[limit:]
	}
	// Whatever remains (between 1 and limit bytes) forms the final chunk.
	return append(chunks, s)
}
// validateChunkSize estimates whether a chunk will still fit once encoded.
// The estimate adds 50% (integer arithmetic) to the raw length and compares
// the sum against twice maxCookieSize.
// Parameters:
//   - chunkData: The chunk data to validate.
//
// Returns:
//   - true if the estimated encoded size is within the limit, false otherwise.
func validateChunkSize(chunkData string) bool {
	rawLen := len(chunkData)
	encodingOverhead := rawLen * 50 / 100
	limit := maxCookieSize * 2
	return rawLen+encodingOverhead <= limit
}
// isCorruptionMarker reports whether data looks like a known corruption
// indicator. A value matches when it is exactly one of the fixed marker
// strings, or when it is longer than 10 bytes and contains at least one
// character from the suspicious punctuation set.
// Parameters:
//   - data: The data string to check for corruption markers.
//
// Returns:
//   - true if the data matches a marker or contains a suspicious character,
//     false otherwise (including for the empty string).
func isCorruptionMarker(data string) bool {
	switch data {
	case "":
		return false
	case "__CORRUPTION_MARKER_TEST__",
		"__INVALID_BASE64_DATA__",
		"__CORRUPTED_CHUNK_DATA__",
		"!@#$%^&*()",
		"<<<CORRUPTED>>>":
		return true
	}
	// Short values are only ever matched against the exact markers above.
	const suspiciousChars = "!@#$%^&*(){}[]|\\:;\"'<>?,`~"
	return len(data) > 10 && strings.ContainsAny(data, suspiciousChars)
}
// GetCSRF returns the CSRF token kept in the main session under the "csrf"
// key. The token is used for state validation during the OIDC
// authentication flow.
// Returns:
//   - The stored CSRF token, or an empty string if none is set.
func (sd *SessionData) GetCSRF() string {
	if stored, ok := sd.mainSession.Values["csrf"].(string); ok {
		return stored
	}
	return ""
}
// SetCSRF records the CSRF token in the main session under the "csrf" key.
// The session is marked dirty only when the value actually changes.
// Parameters:
//   - token: The CSRF token to store.
func (sd *SessionData) SetCSRF(token string) {
	if stored, _ := sd.mainSession.Values["csrf"].(string); stored == token {
		return
	}
	sd.mainSession.Values["csrf"] = token
	sd.dirty = true
}
// GetNonce returns the nonce kept in the main session under the "nonce" key.
// The nonce ties an ID token to the authentication request that caused it.
// Returns:
//   - The stored nonce, or an empty string if none is set.
func (sd *SessionData) GetNonce() string {
	if stored, ok := sd.mainSession.Values["nonce"].(string); ok {
		return stored
	}
	return ""
}
// SetNonce records the nonce in the main session under the "nonce" key.
// The session is marked dirty only when the value actually changes.
// Parameters:
//   - nonce: The nonce string to store.
func (sd *SessionData) SetNonce(nonce string) {
	if stored, _ := sd.mainSession.Values["nonce"].(string); stored == nonce {
		return
	}
	sd.mainSession.Values["nonce"] = nonce
	sd.dirty = true
}
// GetCodeVerifier returns the PKCE (Proof Key for Code Exchange) code verifier
// kept in the main session under the "code_verifier" key.
// Returns:
//   - The stored code verifier, or an empty string if none is set.
func (sd *SessionData) GetCodeVerifier() string {
	if stored, ok := sd.mainSession.Values["code_verifier"].(string); ok {
		return stored
	}
	return ""
}
// SetCodeVerifier records the PKCE code verifier in the main session under the
// "code_verifier" key. The session is marked dirty only when the value
// actually changes.
// Parameters:
//   - codeVerifier: The PKCE code verifier string to store.
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
	if stored, _ := sd.mainSession.Values["code_verifier"].(string); stored == codeVerifier {
		return
	}
	sd.mainSession.Values["code_verifier"] = codeVerifier
	sd.dirty = true
}
// GetUserIdentifier returns the authenticated user's identifier kept in the
// main session under the "user_identifier" key. The read is performed under
// the session read lock.
// Returns:
//   - The user identifier string, or an empty string if not set.
func (sd *SessionData) GetUserIdentifier() string {
	sd.sessionMutex.RLock()
	defer sd.sessionMutex.RUnlock()
	if stored, ok := sd.mainSession.Values["user_identifier"].(string); ok {
		return stored
	}
	return ""
}
// SetUserIdentifier records the authenticated user's identifier in the main
// session under the "user_identifier" key while holding the session write
// lock. The session is marked dirty only when the value actually changes.
// Parameters:
//   - userIdentifier: The user identifier to store (email, sub, or other claim value).
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()
	if stored, _ := sd.mainSession.Values["user_identifier"].(string); stored == userIdentifier {
		return
	}
	sd.mainSession.Values["user_identifier"] = userIdentifier
	sd.dirty = true
}
// GetIncomingPath returns the original request URI kept in the main session
// under the "incoming_path" key, used to send the user back to the intended
// destination after authentication.
// Returns:
//   - The stored request URI, or an empty string if not set.
func (sd *SessionData) GetIncomingPath() string {
	if stored, ok := sd.mainSession.Values["incoming_path"].(string); ok {
		return stored
	}
	return ""
}
// SetIncomingPath records the original request URI in the main session under
// the "incoming_path" key so the user can be redirected there after
// authentication. The session is marked dirty only when the value changes.
// Parameters:
//   - path: The original request URI string (e.g., "/protected/resource?id=123").
func (sd *SessionData) SetIncomingPath(path string) {
	if stored, _ := sd.mainSession.Values["incoming_path"].(string); stored == path {
		return
	}
	sd.mainSession.Values["incoming_path"] = path
	sd.dirty = true
}
// GetIDToken returns the user's ID token from session storage. It takes the
// session read lock and delegates the actual lookup to getIDTokenUnsafe.
// Returns:
//   - The ID token string as produced by getIDTokenUnsafe, or an empty string if not found.
func (sd *SessionData) GetIDToken() string {
	sd.sessionMutex.RLock()
	defer sd.sessionMutex.RUnlock()
	idToken := sd.getIDTokenUnsafe()
	return idToken
}
// GetIDTokenClaims returns the claims of the current ID token, memoizing the
// parser result on the SessionData so repeated calls for the same token do not
// parse it again. The cache is keyed on the ID token string: a different token
// triggers a fresh parse, and a missing token clears the cache.
//
// A cached entry is reused only when it holds non-nil claims or a non-nil
// error; a parser result of (nil, nil) is therefore re-computed on every call.
// The parser is invoked while the session write lock is held.
//
// Parameters:
//   - parser: Function that turns the raw ID token into a claims map.
//
// Returns:
//   - The claims map and error produced by parser (possibly from cache), or
//     a nil map and nil error when the session has no ID token.
func (sd *SessionData) GetIDTokenClaims(parser func(string) (map[string]interface{}, error)) (map[string]interface{}, error) {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()

	idToken := sd.getIDTokenUnsafe()
	cacheFilled := sd.cachedClaims != nil || sd.cachedClaimsErr != nil
	switch {
	case idToken == "":
		// No token: drop whatever was cached and skip the parser entirely.
		sd.cachedClaimsToken, sd.cachedClaims, sd.cachedClaimsErr = "", nil, nil
		return nil, nil
	case idToken == sd.cachedClaimsToken && cacheFilled:
		return sd.cachedClaims, sd.cachedClaimsErr
	}

	parsed, parseErr := parser(idToken)
	sd.cachedClaimsToken, sd.cachedClaims, sd.cachedClaimsErr = idToken, parsed, parseErr
	return parsed, parseErr
}
// getIDTokenUnsafe retrieves the ID token without acquiring locks; callers
// are expected to hold the session mutex already. When no chunk manager is
// configured the raw session value is returned as stored; otherwise the
// lookup is delegated to the chunk manager's GetToken.
// Returns:
//   - The ID token string, or an empty string when the chunk manager reports an error.
func (sd *SessionData) getIDTokenUnsafe() string {
	values := sd.idTokenSession.Values
	rawToken, _ := values["token"].(string)
	isCompressed, _ := values["compressed"].(bool)

	// Without a manager or chunk manager, hand back the stored value untouched.
	if sd.manager == nil || sd.manager.chunkManager == nil {
		return rawToken
	}

	res := sd.manager.chunkManager.GetToken(rawToken, isCompressed, sd.idTokenChunks, IDTokenConfig)
	if res.Error == nil {
		return res.Token
	}
	return ""
}
// getRefreshTokenUnsafe retrieves the refresh token without acquiring locks;
// callers are expected to hold the session mutex already. When the chunk
// manager reports an error, a single, uncompressed, non-chunked value that
// does not have exactly two dots is treated as an opaque token and returned as is.
// Returns:
//   - The refresh token string, or an empty string when it cannot be retrieved.
func (sd *SessionData) getRefreshTokenUnsafe() string {
	values := sd.refreshSession.Values
	rawToken, _ := values["token"].(string)
	isCompressed, _ := values["compressed"].(bool)

	if sd.manager == nil || sd.manager.chunkManager == nil {
		return rawToken
	}

	res := sd.manager.chunkManager.GetToken(rawToken, isCompressed, sd.refreshTokenChunks, RefreshTokenConfig)
	if res.Error == nil {
		return res.Token
	}

	// Opaque (non-JWT) refresh tokens are valid even though the lookup failed.
	storedPlain := rawToken != "" && !isCompressed && len(sd.refreshTokenChunks) == 0
	if storedPlain && strings.Count(rawToken, ".") != 2 {
		return rawToken
	}
	return ""
}
// getUserIdentifierUnsafe reads the "user_identifier" value of the main
// session without acquiring locks. Returns an empty string if it is not set.
func (sd *SessionData) getUserIdentifierUnsafe() string {
	if stored, ok := sd.mainSession.Values["user_identifier"].(string); ok {
		return stored
	}
	return ""
}
// getCSRFUnsafe reads the "csrf" value of the main session without acquiring
// locks. Returns an empty string if it is not set.
func (sd *SessionData) getCSRFUnsafe() string {
	if stored, ok := sd.mainSession.Values["csrf"].(string); ok {
		return stored
	}
	return ""
}
// getNonceUnsafe reads the "nonce" value of the main session without
// acquiring locks. Returns an empty string if it is not set.
func (sd *SessionData) getNonceUnsafe() string {
	if stored, ok := sd.mainSession.Values["nonce"].(string); ok {
		return stored
	}
	return ""
}
// getCodeVerifierUnsafe reads the "code_verifier" value of the main session
// without acquiring locks. Returns an empty string if it is not set.
func (sd *SessionData) getCodeVerifierUnsafe() string {
	if stored, ok := sd.mainSession.Values["code_verifier"].(string); ok {
		return stored
	}
	return ""
}
// getIncomingPathUnsafe reads the "incoming_path" value of the main session
// without acquiring locks. Returns an empty string if it is not set.
func (sd *SessionData) getIncomingPathUnsafe() string {
	if stored, ok := sd.mainSession.Values["incoming_path"].(string); ok {
		return stored
	}
	return ""
}
// getCreatedAtUnsafe reads the "created_at" value of the main session without
// acquiring locks. Returns 0 if it is not set or is not an int64.
func (sd *SessionData) getCreatedAtUnsafe() int64 {
	if stamp, ok := sd.mainSession.Values["created_at"].(int64); ok {
		return stamp
	}
	return 0
}
// getRedirectCountUnsafe reads the "redirect_count" value of the main session
// without acquiring locks. Returns 0 if it is not set or is not an int.
func (sd *SessionData) getRedirectCountUnsafe() int {
	if redirects, ok := sd.mainSession.Values["redirect_count"].(int); ok {
		return redirects
	}
	return 0
}
// SetIDToken stores an ID token in the session, compressing it and splitting
// it into numbered chunk sessions when the (possibly compressed) value does
// not fit into a single cookie of maxCookieSize bytes.
//
// A non-empty token must contain exactly two dots (JWT shape) and must not
// exceed 50 KiB; otherwise it is rejected. The call is a no-op when the token
// equals the one currently stored. Failures are logged at error level and
// abort the call; nothing is reported back to the caller.
//
// NOTE(review): when chunking aborts part-way, the previous token has already
// been cleared from the ID token session and the old chunk map emptied.
//
// Parameters:
//   - token: The ID token string to store. An empty string clears the stored token.
func (sd *SessionData) SetIDToken(token string) {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()
	// Only a three-segment (two dots) value is accepted as an ID token.
	if token != "" {
		dotCount := strings.Count(token, ".")
		if dotCount != 2 {
			sd.manager.logger.Errorf("CRITICAL: Attempt to store invalid JWT ID token format (dots: %d) - rejecting", dotCount)
			return
		}
	}
	// Refuse implausibly large input (more than 50 KiB) before touching any state.
	if len(token) > 50*1024 {
		sd.manager.logger.Errorf("CRITICAL: ID token too large (%d bytes) - possible corruption, rejecting", len(token))
		return
	}
	// The write lock is already held, so use the lock-free getter.
	currentIDToken := sd.getIDTokenUnsafe()
	if constantTimeStringCompare(currentIDToken, token) {
		return
	}
	sd.dirty = true
	// Mark existing chunk sessions for this request as expired. No response
	// writer is passed, so no Set-Cookie headers are emitted at this point.
	if sd.request != nil {
		sd.expireIDTokenChunksEnhanced(nil)
	}
	// Forget all previously tracked chunk sessions.
	for k := range sd.idTokenChunks {
		delete(sd.idTokenChunks, k)
	}
	// An empty token clears the stored value and stops here.
	if token == "" {
		if sd.idTokenSession != nil {
			sd.idTokenSession.Values["token"] = ""
			sd.idTokenSession.Values["compressed"] = false
		}
		return
	}
	// compressToken returning its input unchanged is treated as "not compressed".
	compressed := compressToken(token)
	if compressed != token {
		// Round-trip check: fall back to the raw token if decompression does not restore it.
		testDecompressed := decompressToken(compressed)
		if testDecompressed != token {
			sd.manager.logger.Errorf("CRITICAL: ID token compression verification failed - storing uncompressed")
			compressed = token
		}
	}
	if len(compressed) <= maxCookieSize {
		// Fits in a single cookie: store directly in the ID token session.
		if sd.idTokenSession != nil {
			sd.idTokenSession.Values["token"] = compressed
			sd.idTokenSession.Values["compressed"] = (compressed != token)
		}
	} else {
		// Too large: the main session keeps only metadata, the payload goes into chunks.
		if sd.idTokenSession != nil {
			sd.idTokenSession.Values["token"] = ""
			sd.idTokenSession.Values["compressed"] = (compressed != token)
		}
		chunks := splitIntoChunks(compressed, maxCookieSize)
		if len(chunks) == 0 {
			sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for ID token")
			return
		}
		// Hard cap of 50 chunks per token.
		if len(chunks) > 50 {
			sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for ID token - possible corruption", len(chunks))
			return
		}
		// Verify that joining the chunks reproduces the value that was split.
		testReassembled := strings.Join(chunks, "")
		if testReassembled != compressed {
			sd.manager.logger.Errorf("CRITICAL: ID token chunk reassembly test failed")
			return
		}
		for i, chunkData := range chunks {
			// Chunk sessions are named "<ID token cookie name>_<index>".
			sessionName := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), i)
			// A request is required to obtain a chunk session from the store.
			if sd.request == nil {
				sd.manager.logger.Errorf("CRITICAL: SetIDToken: sd.request is nil, cannot create chunk session %s", sessionName)
				return
			}
			// These two checks abort the store just like every other failure
			// in this loop, so they are reported through the formatting
			// error-level logger rather than Debug/Info.
			if chunkData == "" {
				sd.manager.logger.Errorf("CRITICAL: Empty ID token chunk data at index %d", i)
				return
			}
			if len(chunkData) > maxCookieSize {
				sd.manager.logger.Errorf("CRITICAL: ID token chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
				return
			}
			if !validateChunkSize(chunkData) {
				sd.manager.logger.Errorf("CRITICAL: ID token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
				return
			}
			session, err := sd.manager.store.Get(sd.request, sessionName)
			if err != nil {
				sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err)
				return
			}
			// Each chunk records its payload, the compression flag and a creation timestamp.
			session.Values["token_chunk"] = chunkData
			session.Values["compressed"] = (compressed != token)
			session.Values["chunk_created_at"] = time.Now().Unix()
			sd.idTokenChunks[i] = session
		}
		sd.manager.logger.Debugf("SUCCESS: Stored ID token in %d chunks", len(chunks))
	}
}
// GetRedirectCount returns the number of redirects recorded for the current
// authentication flow, read from the "redirect_count" value of the main
// session. The counter exists to detect redirect loops.
// Returns:
//   - The current redirect count, 0 if not set.
func (sd *SessionData) GetRedirectCount() int {
	redirects, _ := sd.mainSession.Values["redirect_count"].(int)
	return redirects
}
// IncrementRedirectCount adds one to the "redirect_count" value of the main
// session (a missing value counts as 0) and marks the session dirty. The
// counter is used to detect redirect loops during authentication.
func (sd *SessionData) IncrementRedirectCount() {
	previous, _ := sd.mainSession.Values["redirect_count"].(int)
	sd.mainSession.Values["redirect_count"] = previous + 1
	sd.dirty = true
}
// ResetRedirectCount sets the "redirect_count" value of the main session back
// to zero and marks the session dirty, unconditionally.
func (sd *SessionData) ResetRedirectCount() {
	const noRedirects = 0
	sd.mainSession.Values["redirect_count"] = noRedirects
	sd.dirty = true
}