package traefikoidc import ( "bytes" "compress/gzip" "context" "crypto/rand" "crypto/subtle" "encoding/base64" "encoding/hex" "encoding/json" "fmt" "io" "net/http" "runtime" "strings" "sync" "sync/atomic" "time" "github.com/gorilla/sessions" "github.com/lukaszraczylo/traefikoidc/internal/pool" ) // constantTimeStringCompare performs a constant-time comparison of two strings // to prevent timing attacks. Returns true if the strings are equal. func constantTimeStringCompare(a, b string) bool { if len(a) != len(b) { return false } return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 } // min returns the minimum of two integers. // This is a utility function used throughout the session management code. // Parameters: // - a: The first integer to compare. // - b: The second integer to compare. // // Returns: // - The smaller of the two integers. func min(a, b int) int { if a < b { return a } return b } // generateSecureRandomString creates a cryptographically secure random string. // It generates random bytes using crypto/rand and encodes them as hexadecimal. // This is used for session IDs and other security-sensitive random values. // Parameters: // - length: The number of random bytes to generate (output will be 2x this length in hex). // // Returns: // - The hex-encoded random string. // - An error if reading random bytes fails. func generateSecureRandomString(length int) (string, error) { bytes := make([]byte, length) if _, err := rand.Read(bytes); err != nil { return "", fmt.Errorf("failed to generate random bytes: %w", err) } return hex.EncodeToString(bytes), nil } // Cookie name suffixes used for session management // These are appended to the cookiePrefix to create full cookie names // #nosec G101 -- These are cookie name suffixes, not hardcoded credentials const ( mainCookieSuffix = "m" accessTokenSuffix = "a" refreshTokenSuffix = "r" idTokenSuffix = "id" combinedCookieSuffix = "s" // Combined session cookie suffix defaultCookiePrefix = "_oidc_raczylo_" ) const ( // maxCookieSize is the maximum raw data size per chunk before securecookie encoding. // // Browser cookie limit: 4096 bytes // Securecookie overhead: ~2x (gob + AES encryption + MAC + double base64) // Safety headroom: 1000 bytes for cookie metadata, headers, edge cases // // Math: 1400 raw → ~2800 encoded → safely under (4096 - 1000 = 3096) limit maxCookieSize = 1400 // maxCombinedChunks is the maximum number of chunks allowed for combined session maxCombinedChunks = 10 absoluteSessionTimeout = 24 * time.Hour minEncryptionKeyLength = 32 ) // combinedSessionPayload is the JSON structure for combined cookie storage. // Uses short field names to minimize size. type combinedSessionPayload struct { X map[string]interface{} `json:"x,omitempty"` A string `json:"a,omitempty"` R string `json:"r,omitempty"` I string `json:"i,omitempty"` E string `json:"e,omitempty"` Cs string `json:"cs,omitempty"` N string `json:"n,omitempty"` Cv string `json:"cv,omitempty"` Ip string `json:"ip,omitempty"` Ca int64 `json:"ca,omitempty"` Rc int `json:"rc,omitempty"` Au bool `json:"au,omitempty"` } // knownSessionKeys are the standard keys that are handled explicitly in the combined payload. // All other mainSession.Values keys are stored in the X (extra) field. var knownSessionKeys = map[string]bool{ "access_token": true, "refresh_token": true, "id_token": true, "email": true, "authenticated": true, "csrf": true, "nonce": true, "code_verifier": true, "incoming_path": true, "created_at": true, "redirect_count": true, } // compressCombinedPayload compresses the combined session payload using gzip. // It serializes the payload to JSON, compresses it, and returns base64-encoded data. // Returns the compressed string and any error encountered. func compressCombinedPayload(payload *combinedSessionPayload) (string, error) { jsonData, err := json.Marshal(payload) if err != nil { return "", fmt.Errorf("failed to marshal combined payload: %w", err) } var buf bytes.Buffer gz := gzip.NewWriter(&buf) if _, err := gz.Write(jsonData); err != nil { return "", fmt.Errorf("failed to compress combined payload: %w", err) } if err := gz.Close(); err != nil { return "", fmt.Errorf("failed to close gzip writer: %w", err) } compressed := base64.StdEncoding.EncodeToString(buf.Bytes()) return compressed, nil } // decompressCombinedPayload decompresses a base64+gzip encoded combined session payload. // Returns the deserialized payload and any error encountered. func decompressCombinedPayload(compressed string) (*combinedSessionPayload, error) { if compressed == "" { return nil, fmt.Errorf("empty compressed data") } data, err := base64.StdEncoding.DecodeString(compressed) if err != nil { return nil, fmt.Errorf("failed to decode base64: %w", err) } gr, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { return nil, fmt.Errorf("failed to create gzip reader: %w", err) } defer func() { _ = gr.Close() }() // Limit decompressed size to prevent zip bombs limitedReader := io.LimitReader(gr, 512*1024) // 512KB max decompressed, err := io.ReadAll(limitedReader) if err != nil { return nil, fmt.Errorf("failed to decompress: %w", err) } var payload combinedSessionPayload if err := json.Unmarshal(decompressed, &payload); err != nil { return nil, fmt.Errorf("failed to unmarshal combined payload: %w", err) } return &payload, nil } // splitCombinedIntoChunks splits compressed data into chunks of maxCookieSize. // Returns the chunks and the total number of chunks. func splitCombinedIntoChunks(data string, chunkSize int) []string { if len(data) <= chunkSize { return []string{data} } var chunks []string for i := 0; i < len(data); i += chunkSize { end := i + chunkSize if end > len(data) { end = len(data) } chunks = append(chunks, data[i:end]) } return chunks } // assembleCombinedChunks reassembles chunks back into the original compressed data. // It reads chunks from session values in order (chunk_0, chunk_1, etc.). func assembleCombinedChunks(sessions []*sessions.Session) string { if len(sessions) == 0 { return "" } var parts []string for i := 0; i < len(sessions); i++ { session := sessions[i] if session == nil { break } chunk, ok := session.Values["d"].(string) // "d" for data if !ok || chunk == "" { break } parts = append(parts, chunk) } return strings.Join(parts, "") } // compressToken compresses a JWT token using gzip compression if beneficial. // It validates the token format, attempts compression, and verifies the compressed // data can be decompressed correctly. Only compresses if it reduces size. // Parameters: // - token: The JWT token string to potentially compress. // // Returns: // - The base64 encoded, gzipped string, or the original string if compression fails. func compressToken(token string) string { if token == "" { return token } dotCount := strings.Count(token, ".") if dotCount != 2 { return token } if len(token) > 50*1024 { return token } pm := pool.Get() b := pm.GetBuffer(4096) defer pm.PutBuffer(b) gz := gzip.NewWriter(b) written, err := gz.Write([]byte(token)) if err != nil || written != len(token) { return token } if err := gz.Close(); err != nil { return token } compressedBytes := b.Bytes() if len(compressedBytes) == 0 { return token } compressed := base64.StdEncoding.EncodeToString(compressedBytes) if len(compressed) >= len(token) { return token } decompressed := decompressTokenInternal(compressed) if decompressed != token { return token } if strings.Count(decompressed, ".") != 2 { return token } return compressed } // decompressToken decompresses a previously compressed token string. // It decodes the base64 data, validates gzip headers, and decompresses safely // with size limits to prevent compression bombs. // Parameters: // - compressed: The base64-encoded compressed token string. // // Returns: // - The decompressed original string, or the input string if decompression fails. func decompressToken(compressed string) string { return decompressTokenInternal(compressed) } // decompressTokenInternal is the internal decompression function. // Separated internal function for integrity verification during compression. // It performs the actual decompression logic with proper resource management. // Parameters: // - compressed: The compressed token string to decompress. // // Returns: // - The decompressed token or the original string if decompression fails. func decompressTokenInternal(compressed string) string { if compressed == "" { return compressed } if len(compressed) > 100*1024 { return compressed } data, err := base64.StdEncoding.DecodeString(compressed) if err != nil { return compressed } if len(data) == 0 { return compressed } if len(data) < 2 || data[0] != 0x1f || data[1] != 0x8b { return compressed } pm := pool.Get() readerBuf := pm.GetHTTPResponseBuffer() defer pm.PutHTTPResponseBuffer(readerBuf) gz, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { return compressed } defer func() { if closeErr := gz.Close(); closeErr != nil { _ = closeErr } }() limitedReader := io.LimitReader(gz, 500*1024) if cap(readerBuf) >= 512*1024 { readerBuf = readerBuf[:cap(readerBuf)] n, err := limitedReader.Read(readerBuf) if err != nil && err != io.EOF { return compressed } decompressed := readerBuf[:n] return string(decompressed) } decompressed, err := io.ReadAll(limitedReader) if err != nil { return compressed } if len(decompressed) == 0 { return compressed } decompressedStr := string(decompressed) if decompressedStr != "" && strings.Count(decompressedStr, ".") != 2 { return compressed } return decompressedStr } // SessionManager manages OIDC session state and cookie-based storage. // It provides secure session management with support for token compression, // chunked storage for large tokens, session pooling for performance, // session object reuse and supports both HTTP and HTTPS schemes. type SessionManager struct { sessionPool sync.Pool ctx context.Context store sessions.Store logger *Logger chunkManager *ChunkManager memoryMonitor *TaskMemoryMonitor cancel context.CancelFunc cookieDomain string cookiePrefix string sessionMaxAge time.Duration activeSessions int64 poolHits int64 poolMisses int64 cleanupMutex sync.RWMutex shutdownOnce sync.Once forceHTTPS bool cleanupDone bool } // NewSessionManager creates a new SessionManager instance with secure defaults. // It initializes the cookie store with encryption, sets up session pooling, // and configures chunk management for large tokens. // Parameters: // - encryptionKey: The key for encrypting session cookies (minimum 32 bytes). // - forceHTTPS: Whether to force HTTPS-only cookies regardless of request scheme. // - cookieDomain: The domain for session cookies (empty for auto-detection). // - cookiePrefix: Prefix for session cookie names (empty for default "_oidc_raczylo_"). // - sessionMaxAge: Maximum session age duration (0 for default 24 hours). // - logger: Logger instance for debug and error logging. // // Returns: // - The configured SessionManager instance. // - An error if the encryption key does not meet minimum length requirements. func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain string, cookiePrefix string, sessionMaxAge time.Duration, logger *Logger) (*SessionManager, error) { if len(encryptionKey) < minEncryptionKeyLength { return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength) } // Set default cookie prefix if not provided if cookiePrefix == "" { cookiePrefix = defaultCookiePrefix } // Set default session max age if not provided (24 hours for backward compatibility) if sessionMaxAge == 0 { sessionMaxAge = absoluteSessionTimeout } ctx, cancel := context.WithCancel(context.Background()) sm := &SessionManager{ store: sessions.NewCookieStore([]byte(encryptionKey)), forceHTTPS: forceHTTPS, cookieDomain: cookieDomain, cookiePrefix: cookiePrefix, sessionMaxAge: sessionMaxAge, logger: logger, chunkManager: NewChunkManager(logger), ctx: ctx, cancel: cancel, } // Initialize global memory monitoring (singleton) sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger) // Start memory monitoring every 30 seconds (will skip if already started) if err := sm.memoryMonitor.Start(30 * time.Second); err != nil { logger.Debugf("Failed to start memory monitoring: %v", err) } sm.sessionPool.New = func() interface{} { atomic.AddInt64(&sm.poolMisses, 1) sd := &SessionData{ manager: sm, accessTokenChunks: make(map[int]*sessions.Session), refreshTokenChunks: make(map[int]*sessions.Session), idTokenChunks: make(map[int]*sessions.Session), combinedChunks: make(map[int]*sessions.Session), useCombinedStorage: true, // Use combined storage by default for new sessions refreshMutex: sync.Mutex{}, sessionMutex: sync.RWMutex{}, dirty: false, inUse: false, } sd.Reset() return sd } // Start background cleanup routine go sm.backgroundCleanup() return sm, nil } // Cookie name helper methods - build cookie names using the configured prefix // mainCookieName returns the main session cookie name with the configured prefix func (sm *SessionManager) mainCookieName() string { return sm.cookiePrefix + mainCookieSuffix } // accessTokenCookieName returns the access token cookie name with the configured prefix func (sm *SessionManager) accessTokenCookieName() string { return sm.cookiePrefix + accessTokenSuffix } // refreshTokenCookieName returns the refresh token cookie name with the configured prefix func (sm *SessionManager) refreshTokenCookieName() string { return sm.cookiePrefix + refreshTokenSuffix } // idTokenCookieName returns the ID token cookie name with the configured prefix func (sm *SessionManager) idTokenCookieName() string { return sm.cookiePrefix + idTokenSuffix } // combinedCookieName returns the combined session cookie base name with the configured prefix // Chunk cookies are named: prefix + "s" + "_" + chunkIndex (e.g., "_oidc_raczylo_s_0") func (sm *SessionManager) combinedCookieName() string { return sm.cookiePrefix + combinedCookieSuffix } // combinedChunkCookieName returns the name for a specific combined session chunk func (sm *SessionManager) combinedChunkCookieName(chunkIndex int) string { return fmt.Sprintf("%s_%d", sm.combinedCookieName(), chunkIndex) } // GetCookiePrefix returns the cookie prefix used for all OIDC session cookies. func (sm *SessionManager) GetCookiePrefix() string { return sm.cookiePrefix } // Shutdown gracefully shuts down the SessionManager and all its background tasks func (sm *SessionManager) Shutdown() error { var shutdownErr error sm.shutdownOnce.Do(func() { if sm.logger != nil { sm.logger.Debug("SessionManager shutdown initiated") } // Cancel context to stop all background operations if sm.cancel != nil { sm.cancel() } // Stop memory monitor if sm.memoryMonitor != nil { sm.memoryMonitor.Stop() } // Stop chunk manager if sm.chunkManager != nil { sm.chunkManager.Shutdown() } // Force garbage collection to help cleanup runtime.GC() if sm.logger != nil { sm.logger.Debug("SessionManager shutdown completed") } }) return shutdownErr } // backgroundCleanup runs periodic cleanup tasks for session management func (sm *SessionManager) backgroundCleanup() { ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes defer ticker.Stop() for { select { case <-sm.ctx.Done(): if sm.logger != nil { sm.logger.Debug("Background cleanup routine terminated") } return case <-ticker.C: sm.performCleanupCycle() } } } // performCleanupCycle executes a complete cleanup cycle func (sm *SessionManager) performCleanupCycle() { if sm.logger != nil { sm.logger.Debug("Starting background cleanup cycle") } startTime := time.Now() // Run periodic chunk cleanup sm.PeriodicChunkCleanup() // Clean up session pool by forcing GC on old sessions sm.cleanupSessionPool() // Force garbage collection if memory usage is high var m runtime.MemStats runtime.ReadMemStats(&m) if m.HeapAlloc > 50*1024*1024 { // 50MB threshold runtime.GC() if sm.logger != nil { sm.logger.Debug("Forced garbage collection due to high memory usage") } } duration := time.Since(startTime) if sm.logger != nil && sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() { sm.logger.Debugf("Cleanup cycle completed in %v", duration) } } // cleanupSessionPool performs cleanup on the session pool func (sm *SessionManager) cleanupSessionPool() { cleaned := 0 const maxCleanup = 20 // Limit cleanup per cycle to avoid performance impact for i := 0; i < maxCleanup; i++ { select { case <-sm.ctx.Done(): return default: } if poolSession := sm.sessionPool.Get(); poolSession != nil { sessionData, ok := poolSession.(*SessionData) if ok && sessionData != nil && !sessionData.inUse { sessionData.Reset() cleaned++ } sm.sessionPool.Put(poolSession) } else { break // Pool is empty } } if cleaned > 0 && sm.logger != nil && sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() { sm.logger.Debugf("Cleaned %d session pool objects", cleaned) } } // GetSessionStats returns statistics about session management func (sm *SessionManager) GetSessionStats() map[string]interface{} { stats := make(map[string]interface{}) stats["active_sessions"] = atomic.LoadInt64(&sm.activeSessions) stats["pool_hits"] = atomic.LoadInt64(&sm.poolHits) stats["pool_misses"] = atomic.LoadInt64(&sm.poolMisses) if sm.memoryMonitor != nil { if currentStats, err := sm.memoryMonitor.GetCurrentStats(); err == nil { stats["goroutines"] = currentStats.Goroutines stats["heap_alloc"] = currentStats.HeapAlloc stats["num_gc"] = currentStats.NumGC } } return stats } // PeriodicChunkCleanup performs comprehensive session maintenance and cleanup. // It cleans up orphaned token chunks, expired sessions, and unused pool objects. // This helps maintain performance and prevent cookie accumulation in client browsers. func (sm *SessionManager) PeriodicChunkCleanup() { if sm == nil || sm.logger == nil { return } // Check if context is canceled or we're in test mode to prevent logging after test completion if sm.ctx == nil || sm.ctx.Err() != nil || isTestMode() { return // Skip logging if context is canceled or in test mode } sm.logger.Debug("Starting comprehensive session cleanup cycle") cleanupStart := time.Now() var orphanedChunks, expiredSessions, cleanupErrors int if sm.store != nil { if cookieStore, ok := sm.store.(*sessions.CookieStore); ok { // Check context again before logging if sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() { sm.logger.Debug("Running session store cleanup") } _ = cookieStore } } // Cleanup expired sessions in chunk manager to prevent memory leaks if sm.chunkManager != nil { sm.chunkManager.CleanupExpiredSessions() } poolCleaned := 0 for i := 0; i < 10; i++ { if poolSession := sm.sessionPool.Get(); poolSession != nil { sessionData, ok := poolSession.(*SessionData) if ok && sessionData != nil && !sessionData.inUse { sessionData.Reset() poolCleaned++ } sm.sessionPool.Put(poolSession) } } // Check context before final logging if sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() { cleanupDuration := time.Since(cleanupStart) sm.logger.Debugf("Session cleanup completed in %v: pool_cleaned=%d, orphaned_chunks=%d, expired_sessions=%d, errors=%d", cleanupDuration, poolCleaned, orphanedChunks, expiredSessions, cleanupErrors) } } // ValidateSessionHealth performs comprehensive validation of session integrity. // It checks authentication state, validates token formats, and detects // potential tampering or corruption in session data. // Parameters: // - sessionData: The session data to validate. // // Returns: // - An error describing any validation failures, nil if session is healthy. func (sm *SessionManager) ValidateSessionHealth(sessionData *SessionData) error { if sessionData == nil { return fmt.Errorf("session data is nil") } if !sessionData.GetAuthenticated() { return fmt.Errorf("session is not authenticated or has expired") } accessToken := sessionData.GetAccessToken() refreshToken := sessionData.GetRefreshToken() idToken := sessionData.GetIDToken() if accessToken != "" { if err := sm.validateTokenFormat(accessToken, "access_token"); err != nil { return fmt.Errorf("access token validation failed: %w", err) } } if refreshToken != "" { if err := sm.validateTokenFormat(refreshToken, "refresh_token"); err != nil { return fmt.Errorf("refresh token validation failed: %w", err) } } if idToken != "" { if err := sm.validateTokenFormat(idToken, "id_token"); err != nil { return fmt.Errorf("ID token validation failed: %w", err) } } if err := sm.detectSessionTampering(sessionData); err != nil { return fmt.Errorf("session tampering detected: %w", err) } return nil } // validateTokenFormat validates the structure and format of authentication tokens. // It checks for corruption markers, validates JWT structure if applicable, // and ensures tokens meet format requirements. // Parameters: // - token: The token string to validate. // - tokenType: The type of token being validated (for error messages). // // Returns: // - An error if the token has invalid structure or exceeds size limits. func (sm *SessionManager) validateTokenFormat(token, tokenType string) error { if token == "" { return nil } if isCorruptionMarker(token) { return fmt.Errorf("%s contains corruption marker", tokenType) } if strings.Count(token, ".") == 2 { parts := strings.Split(token, ".") for i, part := range parts { if part == "" { return fmt.Errorf("%s has empty part %d in JWT format", tokenType, i) } if strings.ContainsAny(part, "+/=") && !strings.ContainsAny(part, "-_") { sm.logger.Debugf("Token %s part %d uses base64 instead of base64url encoding", tokenType, i) } } } return nil } // detectSessionTampering checks for indicators of session tampering. // It examines session values for path traversal attempts, XSS payloads, // and suspicious data patterns that might indicate malicious modification. // Parameters: // - sessionData: The session data to examine for tampering. // // Returns: // - An error if tampering is detected, nil if session appears safe. func (sm *SessionManager) detectSessionTampering(sessionData *SessionData) error { if sessionData.mainSession == nil { return fmt.Errorf("main session is missing") } for key, value := range sessionData.mainSession.Values { if str, ok := value.(string); ok { if strings.Contains(str, "../") || strings.Contains(str, "..\\") { return fmt.Errorf("potential path traversal attempt in session key %v", key) } if strings.Contains(str, " 10000 { return fmt.Errorf("suspiciously long session value for key %v (length: %d)", key, len(str)) } } } return nil } // GetSessionMetrics returns metrics about session management for monitoring purposes. // It provides information about session configuration, security settings, // and internal state for debugging and monitoring. // Returns: // - A map containing session metrics and configuration information. func (sm *SessionManager) GetSessionMetrics() map[string]interface{} { metrics := make(map[string]interface{}) metrics["session_manager_type"] = "CookieStore" metrics["force_https"] = sm.forceHTTPS metrics["absolute_timeout_hours"] = sm.sessionMaxAge.Hours() metrics["max_cookie_size"] = maxCookieSize metrics["max_encoded_cookie_size"] = maxCookieSize * 2 // ~2x encoding overhead if cookieStore, ok := sm.store.(*sessions.CookieStore); ok && len(cookieStore.Codecs) > 0 { metrics["has_encryption"] = true metrics["codec_count"] = len(cookieStore.Codecs) } else { metrics["has_encryption"] = false } metrics["pool_implementation"] = "sync.Pool" return metrics } // EnhanceSessionSecurity applies additional security measures to session options. // It configures secure cookies, domain detection, SameSite policies, and // adapts security settings based on request context and client characteristics. // Parameters: // - options: The base session options to enhance (can be nil). // - r: The HTTP request context for security decisions. // // Returns: // - Enhanced sessions.Options with additional security measures. func (sm *SessionManager) EnhanceSessionSecurity(options *sessions.Options, r *http.Request) *sessions.Options { if options == nil { options = &sessions.Options{} } if r != nil { userAgent := r.Header.Get("User-Agent") if userAgent == "" { sm.logger.Debugf("Request from %s missing User-Agent header", r.RemoteAddr) options.MaxAge = int((sm.sessionMaxAge / 2).Seconds()) } if r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS { options.Secure = true } // Keep SameSite=Lax consistently for OAuth flows // Removed dynamic switching based on XMLHttpRequest header to prevent redirect loop options.SameSite = http.SameSiteLaxMode } options.HttpOnly = true options.Path = "/" // Ensure cookies are available on all paths for OAuth flow if sm.cookieDomain != "" { options.Domain = sm.cookieDomain sm.logger.Debugf("Using configured cookie domain: %s", sm.cookieDomain) } else if options.Domain == "" && r != nil { host := r.Host if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { host = forwardedHost } if host != "" && !strings.Contains(host, "localhost") && !strings.Contains(host, "127.0.0.1") { if colonIndex := strings.Index(host, ":"); colonIndex != -1 { host = host[:colonIndex] } options.Domain = host sm.logger.Debugf("Auto-detected cookie domain: %s", host) } } return options } // getSessionOptions creates base session options with security settings. // It configures cookie security, lifetime, path, and domain settings // based on the HTTPS status and manager configuration. // Parameters: // - isSecure: Whether the request is over HTTPS or should be treated as secure. // // Returns: // - A pointer to a configured sessions.Options struct. func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options { baseOptions := &sessions.Options{ HttpOnly: true, Secure: isSecure || sm.forceHTTPS, SameSite: http.SameSiteLaxMode, MaxAge: int(sm.sessionMaxAge.Seconds()), Path: "/", Domain: sm.cookieDomain, } return baseOptions } // CleanupOldCookies removes stale session cookies from the client browser. // This method handles cleanup of cookies that may exist with different domain // configurations, ensuring clean state when domain settings change. // It removes cookies with various domain variations to ensure cleanup after configuration changes. // Parameters: // - w: The HTTP response writer for setting cookie deletion headers. // - r: The HTTP request containing cookies to examine and clean up. func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Request) { cookies := r.Cookies() currentDomain := sm.cookieDomain host := r.Host if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { host = forwardedHost } if colonIndex := strings.Index(host, ":"); colonIndex != -1 { host = host[:colonIndex] } // This ensures we clean up cookies from various possible domains var domainsToClean []string if host != "" && !strings.Contains(host, "localhost") && !strings.Contains(host, "127.0.0.1") { domainsToClean = append(domainsToClean, host, "."+host, ) parts := strings.Split(host, ".") if len(parts) > 2 { parentDomain := strings.Join(parts[len(parts)-2:], ".") domainsToClean = append(domainsToClean, parentDomain, "."+parentDomain, ) } } processedCookies := make(map[string]bool) for _, cookie := range cookies { // Check if cookie belongs to this middleware instance using the configured prefix if strings.HasPrefix(cookie.Name, sm.cookiePrefix) || strings.HasPrefix(cookie.Name, "access_token_chunk_") || strings.HasPrefix(cookie.Name, "refresh_token_chunk_") { processedCookies[cookie.Name] = true sm.cleanupMutex.RLock() shouldCleanup := currentDomain != "" && !sm.cleanupDone sm.cleanupMutex.RUnlock() if shouldCleanup { for _, domain := range domainsToClean { if domain == currentDomain || domain == "."+currentDomain || "."+domain == currentDomain { continue } deleteCookie := &http.Cookie{ Name: cookie.Name, Value: "", Path: "/", Domain: domain, MaxAge: -1, HttpOnly: true, Secure: r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS, SameSite: http.SameSiteLaxMode, } http.SetCookie(w, deleteCookie) sm.logger.Debugf("Attempting to clean up cookie %s with domain %s", cookie.Name, domain) } } } } if len(processedCookies) > 0 { sm.cleanupMutex.Lock() if !sm.cleanupDone { sm.cleanupDone = true } sm.cleanupMutex.Unlock() } } // GetSession retrieves or creates session data from the HTTP request. // It first tries to load from combined cookies (new format), falling back to legacy // cookies if combined cookies don't exist. Performs validation and timeout checks. // The returned session must be explicitly returned to the pool by calling // returnToPoolSafely() to prevent memory leaks. // MEMORY LEAK FIX: Session is NOT returned to pool here - caller must call ReturnToPool() when done. // Parameters: // - r: The HTTP request containing session cookies. // // Returns: // - The loaded SessionData instance. // - An error if session loading or validation fails. func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { sessionData, _ := sm.sessionPool.Get().(*SessionData) // Safe to ignore: pool return is best-effort atomic.AddInt64(&sm.poolHits, 1) atomic.AddInt64(&sm.activeSessions, 1) sessionData.inUse = true sessionData.request = r sessionData.dirty = false handleError := func(err error, message string) (*SessionData, error) { if sessionData != nil { sessionData.inUse = false sessionData.Reset() sm.sessionPool.Put(sessionData) atomic.AddInt64(&sm.activeSessions, -1) } return nil, fmt.Errorf("%s: %w", message, err) } // Try to load from combined cookies first (new format) if sm.loadFromCombinedCookies(r, sessionData) { sessionData.useCombinedStorage = true sm.logger.Debug("Loaded session from combined cookies") // Check session timeout if sessionData.getCreatedAtUnsafe() > 0 { if time.Since(time.Unix(sessionData.getCreatedAtUnsafe(), 0)) > sm.sessionMaxAge { _ = sessionData.Clear(r, nil) // Safe to ignore: session is being invalidated return handleError(fmt.Errorf("session timeout"), "session expired") } } return sessionData, nil } // Fall back to legacy cookies sessionData.useCombinedStorage = false sm.logger.Debug("Loading session from legacy cookies") var err error sessionData.mainSession, err = sm.store.Get(r, sm.mainCookieName()) if err != nil { return handleError(err, "failed to get main session") } if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok { if time.Since(time.Unix(createdAt, 0)) > sm.sessionMaxAge { _ = sessionData.Clear(r, nil) // Safe to ignore: session is being invalidated return handleError(fmt.Errorf("session timeout"), "session expired") } } sessionData.accessSession, err = sm.store.Get(r, sm.accessTokenCookieName()) if err != nil { return handleError(err, "failed to get access token session") } sessionData.refreshSession, err = sm.store.Get(r, sm.refreshTokenCookieName()) if err != nil { return handleError(err, "failed to get refresh token session") } sessionData.idTokenSession, err = sm.store.Get(r, sm.idTokenCookieName()) if err != nil { return handleError(err, "failed to get ID token session") } for k := range sessionData.accessTokenChunks { delete(sessionData.accessTokenChunks, k) } for k := range sessionData.refreshTokenChunks { delete(sessionData.refreshTokenChunks, k) } for k := range sessionData.idTokenChunks { delete(sessionData.idTokenChunks, k) } sm.getTokenChunkSessions(r, sm.accessTokenCookieName(), sessionData.accessTokenChunks) sm.getTokenChunkSessions(r, sm.refreshTokenCookieName(), sessionData.refreshTokenChunks) sm.getTokenChunkSessions(r, sm.idTokenCookieName(), sessionData.idTokenChunks) // If legacy session has data, migrate to combined storage on next save if !sessionData.mainSession.IsNew { sessionData.useCombinedStorage = true sessionData.dirty = true // Mark dirty to trigger migration on save sm.logger.Debug("Legacy session found, will migrate to combined storage on save") } return sessionData, nil } // loadFromCombinedCookies attempts to load session data from combined cookies. // Returns true if combined cookies were found and successfully loaded. func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *SessionData) bool { // Check if first combined chunk exists firstChunk, err := sm.store.Get(r, sm.combinedChunkCookieName(0)) if err != nil || firstChunk.IsNew { return false } // Get total chunk count from first chunk totalChunks, ok := firstChunk.Values["n"].(int) if !ok || totalChunks < 1 || totalChunks > maxCombinedChunks { sm.logger.Debugf("Invalid combined cookie chunk count: %v", firstChunk.Values["n"]) return false } // Load all chunks chunkSessions := make([]*sessions.Session, totalChunks) chunkSessions[0] = firstChunk sessionData.combinedChunks[0] = firstChunk for i := 1; i < totalChunks; i++ { chunk, err := sm.store.Get(r, sm.combinedChunkCookieName(i)) if err != nil || chunk.IsNew { sm.logger.Debugf("Missing combined cookie chunk %d", i) return false } chunkSessions[i] = chunk sessionData.combinedChunks[i] = chunk } // Assemble and decompress compressed := assembleCombinedChunks(chunkSessions) if compressed == "" { sm.logger.Debug("Failed to assemble combined chunks") return false } payload, err := decompressCombinedPayload(compressed) if err != nil { sm.logger.Debugf("Failed to decompress combined payload: %v", err) return false } // Hydrate the legacy session objects for compatibility with existing getter methods // We need to initialize them even though we're using combined storage sessionData.mainSession, _ = sm.store.Get(r, sm.mainCookieName()) sessionData.accessSession, _ = sm.store.Get(r, sm.accessTokenCookieName()) sessionData.refreshSession, _ = sm.store.Get(r, sm.refreshTokenCookieName()) sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName()) // Populate legacy session values from combined payload sessionData.mainSession.Values["email"] = payload.E sessionData.mainSession.Values["authenticated"] = payload.Au sessionData.mainSession.Values["csrf"] = payload.Cs sessionData.mainSession.Values["nonce"] = payload.N sessionData.mainSession.Values["code_verifier"] = payload.Cv sessionData.mainSession.Values["incoming_path"] = payload.Ip sessionData.mainSession.Values["created_at"] = payload.Ca sessionData.mainSession.Values["redirect_count"] = payload.Rc // Restore extra custom session values for key, val := range payload.X { sessionData.mainSession.Values[key] = val } sessionData.accessSession.Values["token"] = payload.A sessionData.accessSession.Values["compressed"] = false sessionData.refreshSession.Values["token"] = payload.R sessionData.refreshSession.Values["compressed"] = false sessionData.idTokenSession.Values["token"] = payload.I sessionData.idTokenSession.Values["compressed"] = false return true } // getTokenChunkSessions loads all available token chunk sessions for a given token type. // It iterates through numbered chunk sessions until no more are found, // populating the provided chunks map with the loaded sessions. // Parameters: // - r: The HTTP request containing chunk cookies. // - baseName: The base cookie name for the token type (e.g., "_oidc_raczylo_a"). // - chunks: The map (typically SessionData.accessTokenChunks or SessionData.refreshTokenChunks) // to populate with the found session chunks. func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) { for i := 0; ; i++ { sessionName := fmt.Sprintf("%s_%d", baseName, i) session, err := sm.store.Get(r, sessionName) if err != nil || session.IsNew { break } chunks[i] = session } } // SessionData represents a user's authentication session with comprehensive token management. // It handles main session data and supports large tokens that need to be // split across multiple cookies due to browser size limitations. // Supports both legacy (separate cookies) and combined (single compressed cookie) storage. type SessionData struct { manager *SessionManager request *http.Request // Legacy storage (kept for backward compatibility during migration) mainSession *sessions.Session accessSession *sessions.Session refreshSession *sessions.Session idTokenSession *sessions.Session accessTokenChunks map[int]*sessions.Session refreshTokenChunks map[int]*sessions.Session idTokenChunks map[int]*sessions.Session // Combined storage (new approach - single compressed cookie) combinedChunks map[int]*sessions.Session // useCombinedStorage indicates whether to use the new combined storage format useCombinedStorage bool refreshMutex sync.Mutex sessionMutex sync.RWMutex dirty bool inUse bool // cachedClaimsToken is the ID token string whose claims were last parsed and // cached. A lazy, per-request cache to avoid re-parsing the JWT on every // authenticated request (e.g. for headerTemplates). Protected by sessionMutex. cachedClaimsToken string // cachedClaims holds the parsed claims for cachedClaimsToken. cachedClaims map[string]interface{} // cachedClaimsErr holds the parse error (if any) for cachedClaimsToken so // failures are not retried within the same request. cachedClaimsErr error } // IsDirty returns true if the session data has been modified since it was last loaded or saved. // This is used to optimize session saves by only writing when necessary. // Returns: // - true if the session has pending changes, false otherwise. func (sd *SessionData) IsDirty() bool { return sd.dirty } // MarkDirty marks the session as having pending changes that need to be saved. // This is used when session data hasn't changed in content but should still // trigger a session save (e.g., to ensure the cookie is re-issued). func (sd *SessionData) MarkDirty() { sd.dirty = true } // Save persists all session data including main session and token chunks. // It applies security options, saves all session components, and handles // errors gracefully by continuing to save other components even if one fails. // Uses combined cookie storage for efficiency when useCombinedStorage is true. // Parameters: // - r: The HTTP request context for security option configuration. // - w: The HTTP response writer for setting session cookies. // // Returns: // - An error if saving any of the session components fails. func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { isSecure := r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sd.manager.forceHTTPS options := sd.manager.getSessionOptions(isSecure) options = sd.manager.EnhanceSessionSecurity(options, r) // Use combined storage for new sessions if sd.useCombinedStorage { return sd.saveCombined(r, w, options) } // Legacy storage path (for backward compatibility) return sd.saveLegacy(r, w, options) } // saveCombined saves all session data in a single compressed, chunked cookie. // This reduces cookie count and total size through combined compression. func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, options *sessions.Options) error { // Build the combined payload payload := &combinedSessionPayload{ A: sd.getAccessTokenUnsafe(), R: sd.getRefreshTokenUnsafe(), I: sd.getIDTokenUnsafe(), E: sd.getEmailUnsafe(), Au: sd.getAuthenticatedUnsafe(), Cs: sd.getCSRFUnsafe(), N: sd.getNonceUnsafe(), Cv: sd.getCodeVerifierUnsafe(), Ip: sd.getIncomingPathUnsafe(), Ca: sd.getCreatedAtUnsafe(), Rc: sd.getRedirectCountUnsafe(), } // Collect extra session values not handled by the standard fields sd.sessionMutex.RLock() if sd.mainSession != nil && len(sd.mainSession.Values) > 0 { extra := make(map[string]interface{}) for key, val := range sd.mainSession.Values { keyStr, ok := key.(string) if !ok { continue } // Skip known session keys that are already in the payload if knownSessionKeys[keyStr] { continue } // Store the extra value (must be JSON-serializable) extra[keyStr] = val } if len(extra) > 0 { payload.X = extra } } sd.sessionMutex.RUnlock() // Compress the payload compressed, err := compressCombinedPayload(payload) if err != nil { sd.manager.logger.Errorf("Failed to compress combined payload: %v", err) // Fall back to legacy storage on compression failure return sd.saveLegacy(r, w, options) } sd.manager.logger.Debugf("Combined session: raw payload compressed to %d bytes", len(compressed)) // Split into chunks chunks := splitCombinedIntoChunks(compressed, maxCookieSize) if len(chunks) > maxCombinedChunks { sd.manager.logger.Errorf("Combined session requires %d chunks, exceeds max %d", len(chunks), maxCombinedChunks) return fmt.Errorf("session data too large: requires %d chunks, max is %d", len(chunks), maxCombinedChunks) } sd.manager.logger.Debugf("Combined session split into %d chunks", len(chunks)) var firstErr error // Save each chunk for i, chunkData := range chunks { cookieName := sd.manager.combinedChunkCookieName(i) session, err := sd.manager.store.Get(r, cookieName) if err != nil { sd.manager.logger.Errorf("Failed to get combined chunk session %s: %v", cookieName, err) if firstErr == nil { firstErr = err } continue } session.Values["d"] = chunkData // "d" for data session.Values["n"] = len(chunks) // "n" for total number of chunks session.Values["i"] = i // "i" for index session.Options = options if err := session.Save(r, w); err != nil { sd.manager.logger.Errorf("Failed to save combined chunk %d: %v", i, err) if firstErr == nil { firstErr = err } } sd.combinedChunks[i] = session } // Expire old combined chunks that are no longer needed sd.expireOldCombinedChunks(r, w, options, len(chunks)) // Expire legacy cookies if they exist (migration) sd.expireLegacyCookies(r, w, options) if firstErr == nil { sd.dirty = false } return firstErr } // saveLegacy saves session data using the old separate cookie approach. // Kept for backward compatibility during migration. func (sd *SessionData) saveLegacy(r *http.Request, w http.ResponseWriter, options *sessions.Options) error { sd.mainSession.Options = options sd.accessSession.Options = options sd.refreshSession.Options = options sd.idTokenSession.Options = options var firstErr error saveOrLogError := func(s *sessions.Session, name string) { if s == nil { sd.manager.logger.Errorf("Attempted to save nil session: %s", name) if firstErr == nil { firstErr = fmt.Errorf("attempted to save nil session: %s", name) } return } if err := s.Save(r, w); err != nil { errMsg := fmt.Errorf("failed to save %s session: %w", name, err) sd.manager.logger.Error("%s", errMsg.Error()) if firstErr == nil { firstErr = errMsg } } } saveOrLogError(sd.mainSession, "main") saveOrLogError(sd.accessSession, "access token") saveOrLogError(sd.refreshSession, "refresh token") saveOrLogError(sd.idTokenSession, "ID token") for i, sessionChunk := range sd.accessTokenChunks { sessionChunk.Options = options saveOrLogError(sessionChunk, fmt.Sprintf("access token chunk %d", i)) } for i, sessionChunk := range sd.refreshTokenChunks { sessionChunk.Options = options saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i)) } for i, sessionChunk := range sd.idTokenChunks { sessionChunk.Options = options saveOrLogError(sessionChunk, fmt.Sprintf("ID token chunk %d", i)) } if firstErr == nil { sd.dirty = false } return firstErr } // expireOldCombinedChunks expires combined cookie chunks that are no longer needed. func (sd *SessionData) expireOldCombinedChunks(r *http.Request, w http.ResponseWriter, options *sessions.Options, currentChunks int) { // Expire chunks beyond the current count for i := currentChunks; i < maxCombinedChunks; i++ { cookieName := sd.manager.combinedChunkCookieName(i) session, err := sd.manager.store.Get(r, cookieName) if err != nil || session.IsNew { // No more old chunks break } // Expire this chunk expireOptions := *options expireOptions.MaxAge = -1 session.Options = &expireOptions for k := range session.Values { delete(session.Values, k) } if err := session.Save(r, w); err != nil { sd.manager.logger.Debugf("Failed to expire old combined chunk %d: %v", i, err) } } } // expireLegacyCookies expires old legacy format cookies during migration. func (sd *SessionData) expireLegacyCookies(r *http.Request, w http.ResponseWriter, options *sessions.Options) { expireOptions := *options expireOptions.MaxAge = -1 // Helper to expire a legacy session cookie without clearing in-memory values // IMPORTANT: We must NOT clear values from sessions that sd is holding, // as store.Get() returns the same cached session object expireLegacyChunk := func(cookieName string) { session, err := sd.manager.store.Get(r, cookieName) if err != nil || session.IsNew { return // Cookie doesn't exist } session.Options = &expireOptions // Clear values from chunk cookies (not the main session objects) for k := range session.Values { delete(session.Values, k) } _ = session.Save(r, w) // Best effort } // For main session cookies, only set expiration WITHOUT clearing values // because sd.mainSession, sd.accessSession, etc. point to these same objects expireLegacyMain := func(cookieName string) { session, err := sd.manager.store.Get(r, cookieName) if err != nil || session.IsNew { return // Cookie doesn't exist } // Just expire the cookie, don't clear values (they're still needed in memory) session.Options = &expireOptions _ = session.Save(r, w) // Best effort } // Expire main legacy cookies (don't clear in-memory values) expireLegacyMain(sd.manager.mainCookieName()) expireLegacyMain(sd.manager.accessTokenCookieName()) expireLegacyMain(sd.manager.refreshTokenCookieName()) expireLegacyMain(sd.manager.idTokenCookieName()) // Expire legacy chunk cookies (safe to clear values, they're separate from main sessions) for i := 0; i < 50; i++ { // Max legacy chunks was 50 accessChunk := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), i) refreshChunk := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), i) idChunk := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), i) session, err := sd.manager.store.Get(r, accessChunk) if err != nil || session.IsNew { break // No more chunks } expireLegacyChunk(accessChunk) expireLegacyChunk(refreshChunk) expireLegacyChunk(idChunk) } } // clearSessionValues removes all values from a session and optionally expires it. // This is used during session cleanup and logout operations. // Parameters: // - session: The session to clear values from. // - expire: If true, sets MaxAge to -1 to expire the cookie. func clearSessionValues(session *sessions.Session, expire bool) { if session == nil { return } for k := range session.Values { delete(session.Values, k) } if expire { session.Options.MaxAge = -1 } } // clearAllSessionData clears all session data including main session and token chunks. // It removes all session values and optionally expires all associated cookies. // Parameters: // - r: The HTTP request context (used for chunk cleanup). // - expire: Whether to expire the cookies (set MaxAge to -1). func (sd *SessionData) clearAllSessionData(r *http.Request, expire bool) { clearSessionValues(sd.mainSession, expire) clearSessionValues(sd.accessSession, expire) clearSessionValues(sd.refreshSession, expire) clearSessionValues(sd.idTokenSession, expire) if expire && r != nil { sd.clearTokenChunks(r, sd.accessTokenChunks) sd.clearTokenChunks(r, sd.refreshTokenChunks) sd.clearTokenChunks(r, sd.idTokenChunks) } else { for k := range sd.accessTokenChunks { delete(sd.accessTokenChunks, k) } for k := range sd.refreshTokenChunks { delete(sd.refreshTokenChunks, k) } for k := range sd.idTokenChunks { delete(sd.idTokenChunks, k) } } if expire { sd.dirty = true } } // Clear completely clears all session data and safely returns the session to the pool. // It removes all authentication data, expires cookies, and handles panic recovery. // This method ensures the SessionData object is always returned to the pool. // Parameters: // - r: The HTTP request context. // - w: The HTTP response writer for cookie expiration (can be nil). // // Returns: // - An error if session saving fails during cleanup. func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { defer func() { sd.returnToPoolSafely() }() sd.sessionMutex.Lock() sd.clearAllSessionData(r, true) // Release the lock before calling Save to prevent deadlock sd.sessionMutex.Unlock() // This is primarily for testing - in production w will often be nil var err error if w != nil { if r != nil && r.Header.Get("X-Test-Error") == "true" { // Return a test error without trying to save problematic data err = fmt.Errorf("test error triggered by X-Test-Error header") } else { err = sd.Save(r, w) } } sd.request = nil return err } // returnToPoolSafely safely returns the session to the object pool. // Add thread-safe helper method to return session to pool. // It ensures the session is marked as not in use and properly reset before pooling. func (sd *SessionData) returnToPoolSafely() { if sd != nil && sd.manager != nil { if sd.inUse { sd.inUse = false sd.Reset() sd.manager.sessionPool.Put(sd) atomic.AddInt64(&sd.manager.activeSessions, -1) } } } // clearTokenChunks clears and expires all token chunk sessions. // This is used during logout and session cleanup to ensure // all token data is properly removed from the client. // Parameters: // - r: The HTTP request context. // - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire. func (sd *SessionData) clearTokenChunks(_ *http.Request, chunks map[int]*sessions.Session) { for _, session := range chunks { clearSessionValues(session, true) } } // GetAuthenticated returns whether the user is currently authenticated. // It checks both the authentication flag and session timeout. // Returns: // - true if the user is authenticated and the session is not expired. // - false otherwise. func (sd *SessionData) GetAuthenticated() bool { sd.sessionMutex.RLock() defer sd.sessionMutex.RUnlock() return sd.getAuthenticatedUnsafe() } // getAuthenticatedUnsafe checks authentication status without acquiring locks. // Used when the mutex is already held to avoid deadlocks. // It validates both the authentication flag and session creation time. // Returns: // - true if authenticated and not expired, false otherwise. func (sd *SessionData) getAuthenticatedUnsafe() bool { auth, _ := sd.mainSession.Values["authenticated"].(bool) if !auth { return false } createdAt, ok := sd.mainSession.Values["created_at"].(int64) if !ok { return false } return time.Since(time.Unix(createdAt, 0)) <= sd.manager.sessionMaxAge } // SetAuthenticated sets the authentication status and manages session security. // When setting to true, it generates a new secure session ID and updates timestamps. // This prevents session fixation attacks by regenerating the session identifier. // Parameters: // - value: The authentication status to set. // // Returns: // - An error if generating a new session ID fails when setting value to true. func (sd *SessionData) SetAuthenticated(value bool) error { sd.sessionMutex.Lock() defer sd.sessionMutex.Unlock() currentAuth := sd.getAuthenticatedUnsafe() changed := false if currentAuth != value { changed = true } if value { id, err := generateSecureRandomString(64) if err != nil { return fmt.Errorf("failed to generate secure session id: %w", err) } maxRetries := 5 for retry := 0; retry < maxRetries; retry++ { if sd.mainSession.ID != id { break } id, err = generateSecureRandomString(64) if err != nil { return fmt.Errorf("failed to generate secure session id on retry %d: %w", retry, err) } } if sd.mainSession.ID != id { changed = true } sd.mainSession.ID = id newCreationTime := time.Now().Unix() if oldTime, ok := sd.mainSession.Values["created_at"].(int64); !ok || oldTime != newCreationTime { changed = true } sd.mainSession.Values["created_at"] = newCreationTime if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value { changed = true } } else { if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value { changed = true } } sd.mainSession.Values["authenticated"] = value if changed { sd.dirty = true } return nil } // resetSession prepares a session for reuse by clearing its state. // This is specifically for pool reuse preparation to ensure // no data leaks between different user sessions. // Parameters: // - session: The session to reset for reuse. func resetSession(session *sessions.Session) { if session == nil { return } clearSessionValues(session, false) session.ID = "" session.IsNew = true } // Reset clears all session data and prepares the SessionData for reuse. // It ensures no authentication data persists when the object is reused // between different users/sessions. func (sd *SessionData) Reset() { sd.sessionMutex.Lock() defer sd.sessionMutex.Unlock() sd.clearAllSessionData(nil, false) resetSession(sd.mainSession) resetSession(sd.accessSession) resetSession(sd.refreshSession) resetSession(sd.idTokenSession) // Clear combined chunks for k, session := range sd.combinedChunks { resetSession(session) delete(sd.combinedChunks, k) } // Clear redirect count to prevent leaking between sessions if sd.mainSession != nil && sd.mainSession.Values != nil { delete(sd.mainSession.Values, "redirect_count") } sd.dirty = false sd.inUse = false sd.request = nil sd.useCombinedStorage = true // Reset to use combined storage by default // Drop any cached claims so pooled SessionData does not leak claim data // between requests/users. sd.cachedClaimsToken = "" sd.cachedClaims = nil sd.cachedClaimsErr = nil // Reset the refresh mutex to ensure clean state // Note: We don't need to lock it since sessionMutex is already held // and this session is not in use by any request } // ReturnToPool manually returns the session to the object pool. // This is used in cleanup paths where Clear() is not called, to prevent memory leaks. // It only returns the session if it's not currently in use. func (sd *SessionData) ReturnToPool() { if sd != nil && sd.manager != nil { if !sd.inUse { sd.Reset() sd.manager.sessionPool.Put(sd) atomic.AddInt64(&sd.manager.activeSessions, -1) } } } // GetAccessToken retrieves the user's access token from session storage. // It handles both single-cookie storage and chunked storage for large tokens, // with automatic decompression if the token was compressed. // Returns: // - The complete, decompressed access token string, or an empty string if not found. func (sd *SessionData) GetAccessToken() string { sd.sessionMutex.RLock() defer sd.sessionMutex.RUnlock() return sd.getAccessTokenUnsafe() } // getAccessTokenUnsafe retrieves the access token without acquiring locks. // Enhanced token retrieval with comprehensive integrity checks and recovery mechanisms. // Used when the session mutex is already held to prevent deadlocks. // Returns: // - The complete access token string or empty string on error. func (sd *SessionData) getAccessTokenUnsafe() string { token, _ := sd.accessSession.Values["token"].(string) compressed, _ := sd.accessSession.Values["compressed"].(bool) // Debug: Check if manager/chunkManager is nil if sd.manager == nil || sd.manager.chunkManager == nil { // Direct return if no chunk manager (test scenario) return token } result := sd.manager.chunkManager.GetToken( token, compressed, sd.accessTokenChunks, AccessTokenConfig, ) if result.Error != nil { // Check if we have a raw token available // This handles cases where the token exists but doesn't validate as JWT if token != "" && !compressed && len(sd.accessTokenChunks) == 0 { // We have a non-chunked, non-compressed token that failed validation // Check if it's an opaque token (doesn't have JWT structure) dotCount := strings.Count(token, ".") if dotCount != 2 { // This is likely an opaque token that failed JWT validation // Return the raw token as-is since opaque tokens are valid if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Debugf("Returning opaque access token (dots: %d) despite validation error: %v", dotCount, result.Error) } return token } } // For JWT validation errors or other issues, log and return empty if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Debugf("ChunkManager.GetToken error: %v", result.Error) } return "" } return result.Token } // SetAccessToken stores an access token with automatic compression and chunking. // It validates token format, compresses if beneficial, and splits into chunks // if the token exceeds cookie size limits. Includes integrity verification. // Parameters: // - token: The access token string to store. func (sd *SessionData) SetAccessToken(token string) { sd.sessionMutex.Lock() defer sd.sessionMutex.Unlock() if token != "" { if len(token) < 20 { if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Debug("Token too short for opaque token (length: %d) - rejecting", len(token)) } return } } currentAccessToken := sd.getAccessTokenUnsafe() if constantTimeStringCompare(currentAccessToken, token) { return } sd.dirty = true // Debug: Check if accessSession is properly initialized if sd.accessSession == nil { if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Errorf("CRITICAL: accessSession is nil when trying to store token") } return } if sd.request != nil { sd.expireAccessTokenChunksEnhanced(nil) } for k := range sd.accessTokenChunks { delete(sd.accessTokenChunks, k) } if token == "" { if sd.accessSession != nil { sd.accessSession.Values["token"] = "" sd.accessSession.Values["compressed"] = false } return } compressed := compressToken(token) // Debug for test if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Debugf("Token compression: original %d bytes, compressed %d bytes", len(token), len(compressed)) } if len(compressed) > 100*1024 { if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Info("Access token too large after compression (%d bytes) - storing uncompressed", len(compressed)) } return } if compressed != token { testDecompressed := decompressToken(compressed) if testDecompressed != token { if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Debug("Access token compression verification failed - storing uncompressed") } compressed = token } } if len(compressed) <= maxCookieSize { if sd.accessSession != nil { sd.accessSession.Values["token"] = compressed sd.accessSession.Values["compressed"] = (compressed != token) // Debug for test if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Debugf("Stored token in session: compressed=%v, token_len=%d", compressed != token, len(compressed)) } } } else { if sd.accessSession != nil { sd.accessSession.Values["token"] = "" sd.accessSession.Values["compressed"] = (compressed != token) } chunks := splitIntoChunks(compressed, maxCookieSize) if len(chunks) == 0 { sd.manager.logger.Error("Failed to create chunks for access token") return } if len(chunks) > 50 { sd.manager.logger.Info("Too many chunks (%d) for access token", len(chunks)) return } testReassembled := strings.Join(chunks, "") if testReassembled != compressed { sd.manager.logger.Debug("Access token chunk reassembly test failed") return } for i, chunkData := range chunks { sessionName := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), i) if sd.request == nil { sd.manager.logger.Error("SetAccessToken: sd.request is nil, cannot create chunk session %s", sessionName) return } if chunkData == "" { sd.manager.logger.Debug("Empty chunk data at index %d", i) return } if len(chunkData) > maxCookieSize { sd.manager.logger.Info("Chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize) return } if !validateChunkSize(chunkData) { sd.manager.logger.Errorf("CRITICAL: Chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData)) return } session, err := sd.manager.store.Get(sd.request, sessionName) if err != nil { sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err) return } session.Values["token_chunk"] = chunkData session.Values["compressed"] = (compressed != token) session.Values["chunk_created_at"] = time.Now().Unix() sd.accessTokenChunks[i] = session } sd.manager.logger.Debugf("SUCCESS: Stored access token in %d chunks", len(chunks)) } } // GetRefreshToken retrieves the user's refresh token from session storage. // It handles both single-cookie storage and chunked storage for large tokens, // with automatic decompression if the token was compressed. // Returns: // - The complete, decompressed refresh token string, or an empty string if not found. func (sd *SessionData) GetRefreshToken() string { sd.sessionMutex.RLock() defer sd.sessionMutex.RUnlock() token, _ := sd.refreshSession.Values["token"].(string) compressed, _ := sd.refreshSession.Values["compressed"].(bool) result := sd.manager.chunkManager.GetToken( token, compressed, sd.refreshTokenChunks, RefreshTokenConfig, ) if result.Error != nil { // Check if we have a raw token available // This handles cases where the token exists but doesn't validate as JWT if token != "" && !compressed && len(sd.refreshTokenChunks) == 0 { // We have a non-chunked, non-compressed token that failed validation // Check if it's an opaque token (doesn't have JWT structure) dotCount := strings.Count(token, ".") if dotCount != 2 { // This is likely an opaque token that failed JWT validation // Return the raw token as-is since opaque tokens are valid if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Debugf("Returning opaque refresh token (dots: %d) despite validation error: %v", dotCount, result.Error) } return token } } // For JWT validation errors or other issues, log and return empty if sd.manager != nil && sd.manager.logger != nil { sd.manager.logger.Debugf("ChunkManager.GetToken error for refresh token: %v", result.Error) } return "" } return result.Token } // SetRefreshToken stores a refresh token with automatic compression and chunking. // It validates token size, compresses if beneficial, and splits into chunks // if needed. Includes comprehensive error checking and integrity verification. // Parameters: // - token: The refresh token string to store. func (sd *SessionData) SetRefreshToken(token string) { sd.sessionMutex.Lock() defer sd.sessionMutex.Unlock() if len(token) > 50*1024 { sd.manager.logger.Errorf("CRITICAL: Refresh token too large (%d bytes) - possible corruption, rejecting", len(token)) return } // Get current refresh token without mutex to avoid deadlock since we already hold the lock var currentRefreshToken string sessionToken, _ := sd.refreshSession.Values["token"].(string) if sessionToken != "" { compressed, _ := sd.refreshSession.Values["compressed"].(bool) if compressed { decompressed := decompressToken(sessionToken) currentRefreshToken = decompressed } else { currentRefreshToken = sessionToken } } else if len(sd.refreshTokenChunks) > 0 { // Simplified chunked token retrieval for deadlock prevention var chunks []string for i := 0; i < len(sd.refreshTokenChunks); i++ { if session, ok := sd.refreshTokenChunks[i]; ok { if chunk, chunkOk := session.Values["token_chunk"].(string); chunkOk && chunk != "" { chunks = append(chunks, chunk) } } } if len(chunks) == len(sd.refreshTokenChunks) { reassembled := strings.Join(chunks, "") compressed, _ := sd.refreshSession.Values["compressed"].(bool) if compressed { currentRefreshToken = decompressToken(reassembled) } else { currentRefreshToken = reassembled } } } if constantTimeStringCompare(currentRefreshToken, token) { return } sd.dirty = true if sd.request != nil { sd.expireRefreshTokenChunksEnhanced(nil) } for k := range sd.refreshTokenChunks { delete(sd.refreshTokenChunks, k) } if token == "" { sd.refreshSession.Values["token"] = "" sd.refreshSession.Values["compressed"] = false return } compressed := compressToken(token) if compressed != token { testDecompressed := decompressToken(compressed) if testDecompressed != token { sd.manager.logger.Errorf("CRITICAL: Refresh token compression verification failed - storing uncompressed") compressed = token } } if len(compressed) <= maxCookieSize { sd.refreshSession.Values["token"] = compressed sd.refreshSession.Values["compressed"] = (compressed != token) sd.refreshSession.Values["issued_at"] = time.Now().Unix() } else { sd.refreshSession.Values["token"] = "" sd.refreshSession.Values["compressed"] = (compressed != token) sd.refreshSession.Values["issued_at"] = time.Now().Unix() chunks := splitIntoChunks(compressed, maxCookieSize) if len(chunks) == 0 { sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for refresh token") return } if len(chunks) > 50 { sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for refresh token - possible corruption", len(chunks)) return } testReassembled := strings.Join(chunks, "") if testReassembled != compressed { sd.manager.logger.Errorf("CRITICAL: Refresh token chunk reassembly test failed") return } for i, chunkData := range chunks { sessionName := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), i) if sd.request == nil { sd.manager.logger.Errorf("CRITICAL: SetRefreshToken: sd.request is nil, cannot create chunk session %s", sessionName) return } if chunkData == "" { sd.manager.logger.Errorf("CRITICAL: Empty refresh token chunk data at index %d", i) return } if len(chunkData) > maxCookieSize { sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize) return } if !validateChunkSize(chunkData) { sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData)) return } session, err := sd.manager.store.Get(sd.request, sessionName) if err != nil { sd.manager.logger.Errorf("CRITICAL: Failed to get refresh token chunk session %s: %v", sessionName, err) return } session.Values["token_chunk"] = chunkData session.Values["compressed"] = (compressed != token) session.Values["chunk_created_at"] = time.Now().Unix() sd.refreshTokenChunks[i] = session } sd.manager.logger.Debugf("SUCCESS: Stored refresh token in %d chunks", len(chunks)) } } // GetRefreshTokenIssuedAt retrieves the timestamp when the refresh token was issued/stored. // Returns the time when the current refresh token was obtained, or zero time if not available. func (sd *SessionData) GetRefreshTokenIssuedAt() time.Time { sd.sessionMutex.RLock() defer sd.sessionMutex.RUnlock() if issuedAtUnix, ok := sd.refreshSession.Values["issued_at"].(int64); ok { return time.Unix(issuedAtUnix, 0) } // For chunked tokens, check the first chunk for timestamp if len(sd.refreshTokenChunks) > 0 { if session, exists := sd.refreshTokenChunks[0]; exists { if chunkCreatedAt, ok := session.Values["chunk_created_at"].(int64); ok { return time.Unix(chunkCreatedAt, 0) } } } return time.Time{} } // expireAccessTokenChunksEnhanced expires all access token chunks and detects orphaned chunks. // It searches for all existing chunks, identifies orphaned or expired chunks, // and properly expires them to prevent cookie accumulation. // Parameters: // - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent. func (sd *SessionData) expireAccessTokenChunksEnhanced(w http.ResponseWriter) { const maxChunkSearchLimit = 50 orphanedChunks := 0 for i := 0; i < maxChunkSearchLimit; i++ { sessionName := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), i) session, err := sd.manager.store.Get(sd.request, sessionName) if err != nil { break } if session.IsNew { break } if chunk, exists := session.Values["token_chunk"]; exists { if createdAt, ok := session.Values["chunk_created_at"].(int64); ok { chunkAge := time.Since(time.Unix(createdAt, 0)) if chunkAge > 24*time.Hour { orphanedChunks++ sd.manager.logger.Debugf("Found orphaned access token chunk %d (age: %v)", i, chunkAge) } } else if chunk != nil { orphanedChunks++ sd.manager.logger.Debugf("Found access token chunk %d without timestamp, treating as orphaned", i) } } session.Options.MaxAge = -1 session.Values = make(map[interface{}]interface{}) if w != nil { if err := session.Save(sd.request, w); err != nil { sd.manager.logger.Errorf("failed to save expired access token chunk %d: %v", i, err) } } } if orphanedChunks > 0 { sd.manager.logger.Infof("Cleaned up %d orphaned access token chunks", orphanedChunks) } } // expireRefreshTokenChunksEnhanced expires all refresh token chunks and detects orphaned chunks. // It searches for all existing chunks, identifies orphaned or expired chunks, // and properly expires them to prevent cookie accumulation. // Parameters: // - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent. func (sd *SessionData) expireRefreshTokenChunksEnhanced(w http.ResponseWriter) { const maxChunkSearchLimit = 50 orphanedChunks := 0 for i := 0; i < maxChunkSearchLimit; i++ { sessionName := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), i) session, err := sd.manager.store.Get(sd.request, sessionName) if err != nil { break } if session.IsNew { break } if chunk, exists := session.Values["token_chunk"]; exists { if createdAt, ok := session.Values["chunk_created_at"].(int64); ok { chunkAge := time.Since(time.Unix(createdAt, 0)) if chunkAge > 24*time.Hour { orphanedChunks++ sd.manager.logger.Debugf("Found orphaned refresh token chunk %d (age: %v)", i, chunkAge) } } else if chunk != nil { orphanedChunks++ sd.manager.logger.Debugf("Found refresh token chunk %d without timestamp, treating as orphaned", i) } } session.Options.MaxAge = -1 session.Values = make(map[interface{}]interface{}) if w != nil { if err := session.Save(sd.request, w); err != nil { sd.manager.logger.Errorf("failed to save expired refresh token chunk %d: %v", i, err) } } } if orphanedChunks > 0 { sd.manager.logger.Infof("Cleaned up %d orphaned refresh token chunks", orphanedChunks) } } // expireIDTokenChunksEnhanced expires all ID token chunks and detects orphaned chunks. // It searches for all existing chunks, identifies orphaned or expired chunks, // and properly expires them to prevent cookie accumulation. // Parameters: // - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent. func (sd *SessionData) expireIDTokenChunksEnhanced(w http.ResponseWriter) { const maxChunkSearchLimit = 50 orphanedChunks := 0 for i := 0; i < maxChunkSearchLimit; i++ { sessionName := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), i) session, err := sd.manager.store.Get(sd.request, sessionName) if err != nil { break } if session.IsNew { break } if chunk, exists := session.Values["token_chunk"]; exists { if createdAt, ok := session.Values["chunk_created_at"].(int64); ok { chunkAge := time.Since(time.Unix(createdAt, 0)) if chunkAge > 24*time.Hour { orphanedChunks++ sd.manager.logger.Debugf("Found orphaned ID token chunk %d (age: %v)", i, chunkAge) } } else if chunk != nil { orphanedChunks++ sd.manager.logger.Debugf("Found ID token chunk %d without timestamp, treating as orphaned", i) } } session.Options.MaxAge = -1 session.Values = make(map[interface{}]interface{}) if w != nil { if err := session.Save(sd.request, w); err != nil { sd.manager.logger.Errorf("failed to save expired ID token chunk %d: %v", i, err) } } } if orphanedChunks > 0 { sd.manager.logger.Infof("Cleaned up %d orphaned ID token chunks", orphanedChunks) } } // splitIntoChunks divides a string into chunks of specified maximum size. // It ensures chunks don't exceed browser cookie limits and handles // the string splitting logic for large token storage. // Parameters: // - s: The string to split into chunks. // - chunkSize: The maximum size for each chunk. // // Returns: // - A slice of strings representing the chunks. func splitIntoChunks(s string, chunkSize int) []string { effectiveChunkSize := min(chunkSize, maxCookieSize) var chunks []string for len(s) > 0 { if len(s) > effectiveChunkSize { chunks = append(chunks, s[:effectiveChunkSize]) s = s[effectiveChunkSize:] } else { chunks = append(chunks, s) break } } return chunks } // validateChunkSize checks if a chunk will fit within browser cookie limits. // It estimates the encoded size including cookie overhead and headers // to ensure the chunk won't exceed browser-imposed cookie size limits. // Parameters: // - chunkData: The chunk data to validate. // // Returns: // - true if the chunk is safe to store, false if it may exceed browser limits. func validateChunkSize(chunkData string) bool { // Estimate ~50% overhead for encoding, compare against ~2x maxCookieSize limit estimatedEncodedSize := len(chunkData) + (len(chunkData) * 50 / 100) return estimatedEncodedSize <= maxCookieSize*2 } // isCorruptionMarker detects if data contains known corruption indicators. // It checks for specific corruption markers and invalid characters // that indicate the data has been tampered with or corrupted. // Parameters: // - data: The data string to check for corruption markers. // // Returns: // - true if the data contains corruption markers, false otherwise. func isCorruptionMarker(data string) bool { if data == "" { return false } corruptionMarkers := []string{ "__CORRUPTION_MARKER_TEST__", "__INVALID_BASE64_DATA__", "__CORRUPTED_CHUNK_DATA__", "!@#$%^&*()", "<<>>", } for _, marker := range corruptionMarkers { if data == marker { return true } } if len(data) > 10 { invalidChars := "!@#$%^&*(){}[]|\\:;\"'<>?,`~" for _, char := range invalidChars { if strings.ContainsRune(data, char) { return true } } } return false } // GetCSRF retrieves the CSRF token for state validation. // This token is used to prevent cross-site request forgery attacks // during the OIDC authentication flow. // Returns: // - The CSRF token string, or an empty string if not set. func (sd *SessionData) GetCSRF() string { csrf, _ := sd.mainSession.Values["csrf"].(string) return csrf } // SetCSRF stores the CSRF token for state validation. // The token is used to validate the state parameter in OAuth callbacks. // Parameters: // - token: The CSRF token to store. func (sd *SessionData) SetCSRF(token string) { currentVal, _ := sd.mainSession.Values["csrf"].(string) if currentVal != token { sd.mainSession.Values["csrf"] = token sd.dirty = true } } // GetNonce retrieves the nonce for ID token validation. // The nonce prevents replay attacks by ensuring ID tokens // were issued in response to the specific authentication request. // Returns: // - The nonce string, or an empty string if not set. func (sd *SessionData) GetNonce() string { nonce, _ := sd.mainSession.Values["nonce"].(string) return nonce } // SetNonce stores the nonce for ID token validation. // The nonce will be validated against the nonce claim in received ID tokens. // Parameters: // - nonce: The nonce string to store. func (sd *SessionData) SetNonce(nonce string) { currentVal, _ := sd.mainSession.Values["nonce"].(string) if currentVal != nonce { sd.mainSession.Values["nonce"] = nonce sd.dirty = true } } // GetCodeVerifier retrieves the PKCE code verifier. // This is used in the PKCE (Proof Key for Code Exchange) flow // to enhance security for public clients. // Returns: // - The code verifier string, or an empty string if not set or PKCE is disabled. func (sd *SessionData) GetCodeVerifier() string { codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string) return codeVerifier } // SetCodeVerifier stores the PKCE code verifier. // The code verifier is used to generate the code challenge sent to the // authorization server and validated during token exchange. // Parameters: // - codeVerifier: The PKCE code verifier string to store. func (sd *SessionData) SetCodeVerifier(codeVerifier string) { currentVal, _ := sd.mainSession.Values["code_verifier"].(string) if currentVal != codeVerifier { sd.mainSession.Values["code_verifier"] = codeVerifier sd.dirty = true } } // GetEmail retrieves the authenticated user's email address. // The email is extracted from ID token claims and used for // authorization decisions and header injection. // Returns: // - The user's email address string, or an empty string if not set. func (sd *SessionData) GetEmail() string { sd.sessionMutex.RLock() defer sd.sessionMutex.RUnlock() email, _ := sd.mainSession.Values["email"].(string) return email } // SetEmail stores the authenticated user's email address. // The email is typically extracted from the 'email' claim in the ID token. // Parameters: // - email: The user's email address to store. func (sd *SessionData) SetEmail(email string) { sd.sessionMutex.Lock() defer sd.sessionMutex.Unlock() currentVal, _ := sd.mainSession.Values["email"].(string) if currentVal != email { sd.mainSession.Values["email"] = email sd.dirty = true } } // GetIncomingPath retrieves the original request URI that triggered authentication. // This path is used to redirect the user back to their intended destination // after successful authentication. // Returns: // - The original request URI string, or an empty string if not set. func (sd *SessionData) GetIncomingPath() string { path, _ := sd.mainSession.Values["incoming_path"].(string) return path } // SetIncomingPath stores the original request URI for post-authentication redirect. // This allows the user to be redirected to their originally requested resource // after completing the authentication flow. // Parameters: // - path: The original request URI string (e.g., "/protected/resource?id=123"). func (sd *SessionData) SetIncomingPath(path string) { currentVal, _ := sd.mainSession.Values["incoming_path"].(string) if currentVal != path { sd.mainSession.Values["incoming_path"] = path sd.dirty = true } } // GetIDToken retrieves the user's ID token from session storage. // The ID token contains user claims and is used for user identification // and authorization decisions. Handles compression and chunking automatically. // Returns: // - The complete, decompressed ID token string, or an empty string if not found. func (sd *SessionData) GetIDToken() string { sd.sessionMutex.RLock() defer sd.sessionMutex.RUnlock() return sd.getIDTokenUnsafe() } // GetIDTokenClaims returns claims parsed from the current ID token, caching // the result on the SessionData so repeated callers within the same request // do not re-parse the JWT. The cache is keyed on the ID token string and is // cleared when the SessionData is reset (see Reset) or when the ID token // changes (e.g. after a refresh). // // The parser parameter is typically the TraefikOidc.extractClaimsFunc, which // lets tests inject mocks just like the direct call it replaces. // // Returns an empty claims map and a nil error when the session has no ID // token, matching the existing "no-op" behavior of the caller sites. func (sd *SessionData) GetIDTokenClaims(parser func(string) (map[string]interface{}, error)) (map[string]interface{}, error) { sd.sessionMutex.Lock() defer sd.sessionMutex.Unlock() token := sd.getIDTokenUnsafe() if token == "" { // Invalidate any stale cache without running the parser. sd.cachedClaimsToken = "" sd.cachedClaims = nil sd.cachedClaimsErr = nil return nil, nil } if sd.cachedClaimsToken == token && (sd.cachedClaims != nil || sd.cachedClaimsErr != nil) { return sd.cachedClaims, sd.cachedClaimsErr } claims, err := parser(token) sd.cachedClaimsToken = token sd.cachedClaims = claims sd.cachedClaimsErr = err return claims, err } // getIDTokenUnsafe retrieves the ID token without acquiring locks. // Enhanced ID token retrieval with comprehensive integrity checks and chunking support. // Used when the session mutex is already held to prevent deadlocks. // Returns: // - The complete ID token string or empty string on error. func (sd *SessionData) getIDTokenUnsafe() string { token, _ := sd.idTokenSession.Values["token"].(string) compressed, _ := sd.idTokenSession.Values["compressed"].(bool) // Debug: Check if manager/chunkManager is nil if sd.manager == nil || sd.manager.chunkManager == nil { // Direct return if no chunk manager (test scenario) return token } result := sd.manager.chunkManager.GetToken( token, compressed, sd.idTokenChunks, IDTokenConfig, ) if result.Error != nil { return "" } return result.Token } // getRefreshTokenUnsafe retrieves the refresh token without acquiring locks. // Used when the session mutex is already held to prevent deadlocks. func (sd *SessionData) getRefreshTokenUnsafe() string { token, _ := sd.refreshSession.Values["token"].(string) compressed, _ := sd.refreshSession.Values["compressed"].(bool) if sd.manager == nil || sd.manager.chunkManager == nil { return token } result := sd.manager.chunkManager.GetToken( token, compressed, sd.refreshTokenChunks, RefreshTokenConfig, ) if result.Error != nil { // Handle opaque tokens if token != "" && !compressed && len(sd.refreshTokenChunks) == 0 { if strings.Count(token, ".") != 2 { return token } } return "" } return result.Token } // getEmailUnsafe retrieves the email without acquiring locks. func (sd *SessionData) getEmailUnsafe() string { email, _ := sd.mainSession.Values["email"].(string) return email } // getCSRFUnsafe retrieves the CSRF token without acquiring locks. func (sd *SessionData) getCSRFUnsafe() string { csrf, _ := sd.mainSession.Values["csrf"].(string) return csrf } // getNonceUnsafe retrieves the nonce without acquiring locks. func (sd *SessionData) getNonceUnsafe() string { nonce, _ := sd.mainSession.Values["nonce"].(string) return nonce } // getCodeVerifierUnsafe retrieves the code verifier without acquiring locks. func (sd *SessionData) getCodeVerifierUnsafe() string { codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string) return codeVerifier } // getIncomingPathUnsafe retrieves the incoming path without acquiring locks. func (sd *SessionData) getIncomingPathUnsafe() string { path, _ := sd.mainSession.Values["incoming_path"].(string) return path } // getCreatedAtUnsafe retrieves the created_at timestamp without acquiring locks. func (sd *SessionData) getCreatedAtUnsafe() int64 { createdAt, _ := sd.mainSession.Values["created_at"].(int64) return createdAt } // getRedirectCountUnsafe retrieves the redirect count without acquiring locks. func (sd *SessionData) getRedirectCountUnsafe() int { count, _ := sd.mainSession.Values["redirect_count"].(int) return count } // SetIDToken stores an ID token with automatic compression and chunking. // It validates the JWT format, compresses if beneficial, and splits into chunks // if the token exceeds cookie size limits. Includes comprehensive validation. // Parameters: // - token: The ID token string to store. func (sd *SessionData) SetIDToken(token string) { sd.sessionMutex.Lock() defer sd.sessionMutex.Unlock() if token != "" { dotCount := strings.Count(token, ".") if dotCount != 2 { sd.manager.logger.Errorf("CRITICAL: Attempt to store invalid JWT ID token format (dots: %d) - rejecting", dotCount) return } } if len(token) > 50*1024 { sd.manager.logger.Errorf("CRITICAL: ID token too large (%d bytes) - possible corruption, rejecting", len(token)) return } currentIDToken := sd.getIDTokenUnsafe() if constantTimeStringCompare(currentIDToken, token) { return } sd.dirty = true if sd.request != nil { sd.expireIDTokenChunksEnhanced(nil) } for k := range sd.idTokenChunks { delete(sd.idTokenChunks, k) } if token == "" { if sd.idTokenSession != nil { sd.idTokenSession.Values["token"] = "" sd.idTokenSession.Values["compressed"] = false } return } compressed := compressToken(token) if compressed != token { testDecompressed := decompressToken(compressed) if testDecompressed != token { sd.manager.logger.Errorf("CRITICAL: ID token compression verification failed - storing uncompressed") compressed = token } } if len(compressed) <= maxCookieSize { if sd.idTokenSession != nil { sd.idTokenSession.Values["token"] = compressed sd.idTokenSession.Values["compressed"] = (compressed != token) } } else { if sd.idTokenSession != nil { sd.idTokenSession.Values["token"] = "" sd.idTokenSession.Values["compressed"] = (compressed != token) } chunks := splitIntoChunks(compressed, maxCookieSize) if len(chunks) == 0 { sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for ID token") return } if len(chunks) > 50 { sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for ID token - possible corruption", len(chunks)) return } testReassembled := strings.Join(chunks, "") if testReassembled != compressed { sd.manager.logger.Errorf("CRITICAL: ID token chunk reassembly test failed") return } for i, chunkData := range chunks { sessionName := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), i) if sd.request == nil { sd.manager.logger.Errorf("CRITICAL: SetIDToken: sd.request is nil, cannot create chunk session %s", sessionName) return } if chunkData == "" { sd.manager.logger.Debug("Empty chunk data at index %d", i) return } if len(chunkData) > maxCookieSize { sd.manager.logger.Info("Chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize) return } if !validateChunkSize(chunkData) { sd.manager.logger.Errorf("CRITICAL: ID token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData)) return } session, err := sd.manager.store.Get(sd.request, sessionName) if err != nil { sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err) return } session.Values["token_chunk"] = chunkData session.Values["compressed"] = (compressed != token) session.Values["chunk_created_at"] = time.Now().Unix() sd.idTokenChunks[i] = session } sd.manager.logger.Debugf("SUCCESS: Stored ID token in %d chunks", len(chunks)) } } // GetRedirectCount returns the number of redirects in the current authentication flow. // STABILITY FIX: Prevents infinite redirect loops by tracking redirect attempts. // Returns: // - The current redirect count, 0 if not set. func (sd *SessionData) GetRedirectCount() int { if count, ok := sd.mainSession.Values["redirect_count"].(int); ok { return count } return 0 } // IncrementRedirectCount increases the redirect counter by one. // STABILITY FIX: Prevents infinite redirect loops by tracking successive redirects. // Used to detect potential redirect loops and abort authentication if too many occur. func (sd *SessionData) IncrementRedirectCount() { currentCount := sd.GetRedirectCount() sd.mainSession.Values["redirect_count"] = currentCount + 1 sd.dirty = true } // ResetRedirectCount resets the redirect counter to zero. // STABILITY FIX: Prevents infinite redirect loops by clearing the counter // when authentication completes successfully or when starting a new flow. func (sd *SessionData) ResetRedirectCount() { sd.mainSession.Values["redirect_count"] = 0 sd.dirty = true }