mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
c474bbafd6
* Cleanup excessive comments. * Remove leftovers hanging around from previous refactor * Improve test coverage
2239 lines
69 KiB
Go
2239 lines
69 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/gorilla/sessions"
|
|
"github.com/lukaszraczylo/traefikoidc/internal/pool"
|
|
)
|
|
|
|
// constantTimeStringCompare performs a constant-time comparison of two strings
|
|
// to prevent timing attacks. Returns true if the strings are equal.
|
|
func constantTimeStringCompare(a, b string) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
|
}
|
|
|
|
// min returns the minimum of two integers.
|
|
// This is a utility function used throughout the session management code.
|
|
// Parameters:
|
|
// - a: The first integer to compare.
|
|
// - b: The second integer to compare.
|
|
//
|
|
// Returns:
|
|
// - The smaller of the two integers.
|
|
func min(a, b int) int {
|
|
if a < b {
|
|
return a
|
|
}
|
|
return b
|
|
}
|
|
|
|
// generateSecureRandomString creates a cryptographically secure random string.
|
|
// It generates random bytes using crypto/rand and encodes them as hexadecimal.
|
|
// This is used for session IDs and other security-sensitive random values.
|
|
// Parameters:
|
|
// - length: The number of random bytes to generate (output will be 2x this length in hex).
|
|
//
|
|
// Returns:
|
|
// - The hex-encoded random string.
|
|
// - An error if reading random bytes fails.
|
|
func generateSecureRandomString(length int) (string, error) {
|
|
bytes := make([]byte, length)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
|
}
|
|
return hex.EncodeToString(bytes), nil
|
|
}
|
|
|
|
// Cookie name suffixes used for session management
|
|
// These are appended to the cookiePrefix to create full cookie names
|
|
// #nosec G101 -- These are cookie name suffixes, not hardcoded credentials
|
|
const (
|
|
mainCookieSuffix = "m"
|
|
accessTokenSuffix = "a"
|
|
refreshTokenSuffix = "r"
|
|
idTokenSuffix = "id"
|
|
defaultCookiePrefix = "_oidc_raczylo_"
|
|
)
|
|
|
|
const (
|
|
maxBrowserCookieSize = 3500
|
|
|
|
maxCookieSize = 1200
|
|
|
|
absoluteSessionTimeout = 24 * time.Hour
|
|
|
|
minEncryptionKeyLength = 32
|
|
)
|
|
|
|
// compressToken compresses a JWT token using gzip compression if beneficial.
|
|
// It validates the token format, attempts compression, and verifies the compressed
|
|
// data can be decompressed correctly. Only compresses if it reduces size.
|
|
// Parameters:
|
|
// - token: The JWT token string to potentially compress.
|
|
//
|
|
// Returns:
|
|
// - The base64 encoded, gzipped string, or the original string if compression fails.
|
|
func compressToken(token string) string {
|
|
if token == "" {
|
|
return token
|
|
}
|
|
|
|
dotCount := strings.Count(token, ".")
|
|
if dotCount != 2 {
|
|
return token
|
|
}
|
|
|
|
if len(token) > 50*1024 {
|
|
return token
|
|
}
|
|
|
|
pm := pool.Get()
|
|
b := pm.GetBuffer(4096)
|
|
defer pm.PutBuffer(b)
|
|
|
|
gz := gzip.NewWriter(b)
|
|
|
|
written, err := gz.Write([]byte(token))
|
|
if err != nil || written != len(token) {
|
|
return token
|
|
}
|
|
|
|
if err := gz.Close(); err != nil {
|
|
return token
|
|
}
|
|
|
|
compressedBytes := b.Bytes()
|
|
if len(compressedBytes) == 0 {
|
|
return token
|
|
}
|
|
|
|
compressed := base64.StdEncoding.EncodeToString(compressedBytes)
|
|
|
|
if len(compressed) >= len(token) {
|
|
return token
|
|
}
|
|
|
|
decompressed := decompressTokenInternal(compressed)
|
|
if decompressed != token {
|
|
return token
|
|
}
|
|
|
|
if strings.Count(decompressed, ".") != 2 {
|
|
return token
|
|
}
|
|
|
|
return compressed
|
|
}
|
|
|
|
// decompressToken decompresses a previously compressed token string.
|
|
// It decodes the base64 data, validates gzip headers, and decompresses safely
|
|
// with size limits to prevent compression bombs.
|
|
// Parameters:
|
|
// - compressed: The base64-encoded compressed token string.
|
|
//
|
|
// Returns:
|
|
// - The decompressed original string, or the input string if decompression fails.
|
|
func decompressToken(compressed string) string {
|
|
return decompressTokenInternal(compressed)
|
|
}
|
|
|
|
// decompressTokenInternal is the internal decompression function.
|
|
// Separated internal function for integrity verification during compression.
|
|
// It performs the actual decompression logic with proper resource management.
|
|
// Parameters:
|
|
// - compressed: The compressed token string to decompress.
|
|
//
|
|
// Returns:
|
|
// - The decompressed token or the original string if decompression fails.
|
|
func decompressTokenInternal(compressed string) string {
|
|
if compressed == "" {
|
|
return compressed
|
|
}
|
|
|
|
if len(compressed) > 100*1024 {
|
|
return compressed
|
|
}
|
|
|
|
data, err := base64.StdEncoding.DecodeString(compressed)
|
|
if err != nil {
|
|
return compressed
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return compressed
|
|
}
|
|
|
|
if len(data) < 2 || data[0] != 0x1f || data[1] != 0x8b {
|
|
return compressed
|
|
}
|
|
|
|
pm := pool.Get()
|
|
readerBuf := pm.GetHTTPResponseBuffer()
|
|
defer pm.PutHTTPResponseBuffer(readerBuf)
|
|
|
|
gz, err := gzip.NewReader(bytes.NewReader(data))
|
|
if err != nil {
|
|
return compressed
|
|
}
|
|
|
|
defer func() {
|
|
if closeErr := gz.Close(); closeErr != nil {
|
|
_ = closeErr
|
|
}
|
|
}()
|
|
|
|
limitedReader := io.LimitReader(gz, 500*1024)
|
|
|
|
if cap(readerBuf) >= 512*1024 {
|
|
readerBuf = readerBuf[:cap(readerBuf)]
|
|
n, err := limitedReader.Read(readerBuf)
|
|
if err != nil && err != io.EOF {
|
|
return compressed
|
|
}
|
|
decompressed := readerBuf[:n]
|
|
return string(decompressed)
|
|
}
|
|
|
|
decompressed, err := io.ReadAll(limitedReader)
|
|
if err != nil {
|
|
return compressed
|
|
}
|
|
|
|
if len(decompressed) == 0 {
|
|
return compressed
|
|
}
|
|
|
|
decompressedStr := string(decompressed)
|
|
|
|
if decompressedStr != "" && strings.Count(decompressedStr, ".") != 2 {
|
|
return compressed
|
|
}
|
|
|
|
return decompressedStr
|
|
}
|
|
|
|
// SessionManager manages OIDC session state and cookie-based storage.
|
|
// It provides secure session management with support for token compression,
|
|
// chunked storage for large tokens, session pooling for performance,
|
|
// session object reuse and supports both HTTP and HTTPS schemes.
|
|
type SessionManager struct {
|
|
sessionPool sync.Pool
|
|
store sessions.Store
|
|
logger *Logger
|
|
chunkManager *ChunkManager
|
|
cookieDomain string
|
|
cookiePrefix string // Prefix for cookie names (default: "_oidc_raczylo_")
|
|
sessionMaxAge time.Duration // Maximum session age (default: 24 hours)
|
|
cleanupMutex sync.RWMutex
|
|
forceHTTPS bool
|
|
cleanupDone bool
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
memoryMonitor *TaskMemoryMonitor
|
|
activeSessions int64
|
|
poolHits int64
|
|
poolMisses int64
|
|
shutdownOnce sync.Once
|
|
}
|
|
|
|
// NewSessionManager creates a new SessionManager instance with secure defaults.
|
|
// It initializes the cookie store with encryption, sets up session pooling,
|
|
// and configures chunk management for large tokens.
|
|
// Parameters:
|
|
// - encryptionKey: The key for encrypting session cookies (minimum 32 bytes).
|
|
// - forceHTTPS: Whether to force HTTPS-only cookies regardless of request scheme.
|
|
// - cookieDomain: The domain for session cookies (empty for auto-detection).
|
|
// - cookiePrefix: Prefix for session cookie names (empty for default "_oidc_raczylo_").
|
|
// - sessionMaxAge: Maximum session age duration (0 for default 24 hours).
|
|
// - logger: Logger instance for debug and error logging.
|
|
//
|
|
// Returns:
|
|
// - The configured SessionManager instance.
|
|
// - An error if the encryption key does not meet minimum length requirements.
|
|
func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain string, cookiePrefix string, sessionMaxAge time.Duration, logger *Logger) (*SessionManager, error) {
|
|
if len(encryptionKey) < minEncryptionKeyLength {
|
|
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
|
}
|
|
|
|
// Set default cookie prefix if not provided
|
|
if cookiePrefix == "" {
|
|
cookiePrefix = defaultCookiePrefix
|
|
}
|
|
|
|
// Set default session max age if not provided (24 hours for backward compatibility)
|
|
if sessionMaxAge == 0 {
|
|
sessionMaxAge = absoluteSessionTimeout
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
sm := &SessionManager{
|
|
store: sessions.NewCookieStore([]byte(encryptionKey)),
|
|
forceHTTPS: forceHTTPS,
|
|
cookieDomain: cookieDomain,
|
|
cookiePrefix: cookiePrefix,
|
|
sessionMaxAge: sessionMaxAge,
|
|
logger: logger,
|
|
chunkManager: NewChunkManager(logger),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
// Initialize global memory monitoring (singleton)
|
|
sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger)
|
|
|
|
// Start memory monitoring every 30 seconds (will skip if already started)
|
|
if err := sm.memoryMonitor.Start(30 * time.Second); err != nil {
|
|
logger.Debugf("Failed to start memory monitoring: %v", err)
|
|
}
|
|
|
|
sm.sessionPool.New = func() interface{} {
|
|
atomic.AddInt64(&sm.poolMisses, 1)
|
|
sd := &SessionData{
|
|
manager: sm,
|
|
accessTokenChunks: make(map[int]*sessions.Session),
|
|
refreshTokenChunks: make(map[int]*sessions.Session),
|
|
idTokenChunks: make(map[int]*sessions.Session),
|
|
refreshMutex: sync.Mutex{},
|
|
sessionMutex: sync.RWMutex{},
|
|
dirty: false,
|
|
inUse: false,
|
|
}
|
|
sd.Reset()
|
|
return sd
|
|
}
|
|
|
|
// Start background cleanup routine
|
|
go sm.backgroundCleanup()
|
|
|
|
return sm, nil
|
|
}
|
|
|
|
// Cookie name helper methods - build cookie names using the configured prefix
|
|
|
|
// mainCookieName returns the main session cookie name with the configured prefix
|
|
func (sm *SessionManager) mainCookieName() string {
|
|
return sm.cookiePrefix + mainCookieSuffix
|
|
}
|
|
|
|
// accessTokenCookieName returns the access token cookie name with the configured prefix
|
|
func (sm *SessionManager) accessTokenCookieName() string {
|
|
return sm.cookiePrefix + accessTokenSuffix
|
|
}
|
|
|
|
// refreshTokenCookieName returns the refresh token cookie name with the configured prefix
|
|
func (sm *SessionManager) refreshTokenCookieName() string {
|
|
return sm.cookiePrefix + refreshTokenSuffix
|
|
}
|
|
|
|
// idTokenCookieName returns the ID token cookie name with the configured prefix
|
|
func (sm *SessionManager) idTokenCookieName() string {
|
|
return sm.cookiePrefix + idTokenSuffix
|
|
}
|
|
|
|
// Shutdown gracefully shuts down the SessionManager and all its background tasks
|
|
func (sm *SessionManager) Shutdown() error {
|
|
var shutdownErr error
|
|
sm.shutdownOnce.Do(func() {
|
|
if sm.logger != nil {
|
|
sm.logger.Debug("SessionManager shutdown initiated")
|
|
}
|
|
|
|
// Cancel context to stop all background operations
|
|
if sm.cancel != nil {
|
|
sm.cancel()
|
|
}
|
|
|
|
// Stop memory monitor
|
|
if sm.memoryMonitor != nil {
|
|
sm.memoryMonitor.Stop()
|
|
}
|
|
|
|
// Stop chunk manager
|
|
if sm.chunkManager != nil {
|
|
sm.chunkManager.Shutdown()
|
|
}
|
|
|
|
// Force garbage collection to help cleanup
|
|
runtime.GC()
|
|
|
|
if sm.logger != nil {
|
|
sm.logger.Debug("SessionManager shutdown completed")
|
|
}
|
|
})
|
|
return shutdownErr
|
|
}
|
|
|
|
// backgroundCleanup runs periodic cleanup tasks for session management
|
|
func (sm *SessionManager) backgroundCleanup() {
|
|
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-sm.ctx.Done():
|
|
if sm.logger != nil {
|
|
sm.logger.Debug("Background cleanup routine terminated")
|
|
}
|
|
return
|
|
case <-ticker.C:
|
|
sm.performCleanupCycle()
|
|
}
|
|
}
|
|
}
|
|
|
|
// performCleanupCycle executes a complete cleanup cycle
|
|
func (sm *SessionManager) performCleanupCycle() {
|
|
if sm.logger != nil {
|
|
sm.logger.Debug("Starting background cleanup cycle")
|
|
}
|
|
|
|
startTime := time.Now()
|
|
|
|
// Run periodic chunk cleanup
|
|
sm.PeriodicChunkCleanup()
|
|
|
|
// Clean up session pool by forcing GC on old sessions
|
|
sm.cleanupSessionPool()
|
|
|
|
// Force garbage collection if memory usage is high
|
|
var m runtime.MemStats
|
|
runtime.ReadMemStats(&m)
|
|
if m.HeapAlloc > 50*1024*1024 { // 50MB threshold
|
|
runtime.GC()
|
|
if sm.logger != nil {
|
|
sm.logger.Debug("Forced garbage collection due to high memory usage")
|
|
}
|
|
}
|
|
|
|
duration := time.Since(startTime)
|
|
if sm.logger != nil && sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
|
|
sm.logger.Debugf("Cleanup cycle completed in %v", duration)
|
|
}
|
|
}
|
|
|
|
// cleanupSessionPool performs cleanup on the session pool
|
|
func (sm *SessionManager) cleanupSessionPool() {
|
|
cleaned := 0
|
|
const maxCleanup = 20 // Limit cleanup per cycle to avoid performance impact
|
|
|
|
for i := 0; i < maxCleanup; i++ {
|
|
select {
|
|
case <-sm.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
if poolSession := sm.sessionPool.Get(); poolSession != nil {
|
|
sessionData, ok := poolSession.(*SessionData)
|
|
if ok && sessionData != nil && !sessionData.inUse {
|
|
sessionData.Reset()
|
|
cleaned++
|
|
}
|
|
sm.sessionPool.Put(poolSession)
|
|
} else {
|
|
break // Pool is empty
|
|
}
|
|
}
|
|
|
|
if cleaned > 0 && sm.logger != nil && sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
|
|
sm.logger.Debugf("Cleaned %d session pool objects", cleaned)
|
|
}
|
|
}
|
|
|
|
// GetSessionStats returns statistics about session management
|
|
func (sm *SessionManager) GetSessionStats() map[string]interface{} {
|
|
stats := make(map[string]interface{})
|
|
stats["active_sessions"] = atomic.LoadInt64(&sm.activeSessions)
|
|
stats["pool_hits"] = atomic.LoadInt64(&sm.poolHits)
|
|
stats["pool_misses"] = atomic.LoadInt64(&sm.poolMisses)
|
|
|
|
if sm.memoryMonitor != nil {
|
|
if currentStats, err := sm.memoryMonitor.GetCurrentStats(); err == nil {
|
|
stats["goroutines"] = currentStats.Goroutines
|
|
stats["heap_alloc"] = currentStats.HeapAlloc
|
|
stats["num_gc"] = currentStats.NumGC
|
|
}
|
|
}
|
|
|
|
return stats
|
|
}
|
|
|
|
// PeriodicChunkCleanup performs comprehensive session maintenance and cleanup.
|
|
// It cleans up orphaned token chunks, expired sessions, and unused pool objects.
|
|
// This helps maintain performance and prevent cookie accumulation in client browsers.
|
|
func (sm *SessionManager) PeriodicChunkCleanup() {
|
|
if sm == nil || sm.logger == nil {
|
|
return
|
|
}
|
|
|
|
// Check if context is canceled or we're in test mode to prevent logging after test completion
|
|
if sm.ctx == nil || sm.ctx.Err() != nil || isTestMode() {
|
|
return // Skip logging if context is canceled or in test mode
|
|
}
|
|
|
|
sm.logger.Debug("Starting comprehensive session cleanup cycle")
|
|
|
|
cleanupStart := time.Now()
|
|
var orphanedChunks, expiredSessions, cleanupErrors int
|
|
|
|
if sm.store != nil {
|
|
if cookieStore, ok := sm.store.(*sessions.CookieStore); ok {
|
|
// Check context again before logging
|
|
if sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
|
|
sm.logger.Debug("Running session store cleanup")
|
|
}
|
|
_ = cookieStore
|
|
}
|
|
}
|
|
|
|
// Cleanup expired sessions in chunk manager to prevent memory leaks
|
|
if sm.chunkManager != nil {
|
|
sm.chunkManager.CleanupExpiredSessions()
|
|
}
|
|
|
|
poolCleaned := 0
|
|
for i := 0; i < 10; i++ {
|
|
if poolSession := sm.sessionPool.Get(); poolSession != nil {
|
|
sessionData, ok := poolSession.(*SessionData)
|
|
if ok && sessionData != nil && !sessionData.inUse {
|
|
sessionData.Reset()
|
|
poolCleaned++
|
|
}
|
|
sm.sessionPool.Put(poolSession)
|
|
}
|
|
}
|
|
|
|
// Check context before final logging
|
|
if sm.ctx != nil && sm.ctx.Err() == nil && !isTestMode() {
|
|
cleanupDuration := time.Since(cleanupStart)
|
|
sm.logger.Debugf("Session cleanup completed in %v: pool_cleaned=%d, orphaned_chunks=%d, expired_sessions=%d, errors=%d",
|
|
cleanupDuration, poolCleaned, orphanedChunks, expiredSessions, cleanupErrors)
|
|
}
|
|
}
|
|
|
|
// ValidateSessionHealth performs comprehensive validation of session integrity.
|
|
// It checks authentication state, validates token formats, and detects
|
|
// potential tampering or corruption in session data.
|
|
// Parameters:
|
|
// - sessionData: The session data to validate.
|
|
//
|
|
// Returns:
|
|
// - An error describing any validation failures, nil if session is healthy.
|
|
func (sm *SessionManager) ValidateSessionHealth(sessionData *SessionData) error {
|
|
if sessionData == nil {
|
|
return fmt.Errorf("session data is nil")
|
|
}
|
|
|
|
if !sessionData.GetAuthenticated() {
|
|
return fmt.Errorf("session is not authenticated or has expired")
|
|
}
|
|
|
|
accessToken := sessionData.GetAccessToken()
|
|
refreshToken := sessionData.GetRefreshToken()
|
|
idToken := sessionData.GetIDToken()
|
|
|
|
if accessToken != "" {
|
|
if err := sm.validateTokenFormat(accessToken, "access_token"); err != nil {
|
|
return fmt.Errorf("access token validation failed: %w", err)
|
|
}
|
|
}
|
|
|
|
if refreshToken != "" {
|
|
if err := sm.validateTokenFormat(refreshToken, "refresh_token"); err != nil {
|
|
return fmt.Errorf("refresh token validation failed: %w", err)
|
|
}
|
|
}
|
|
|
|
if idToken != "" {
|
|
if err := sm.validateTokenFormat(idToken, "id_token"); err != nil {
|
|
return fmt.Errorf("ID token validation failed: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := sm.detectSessionTampering(sessionData); err != nil {
|
|
return fmt.Errorf("session tampering detected: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateTokenFormat validates the structure and format of authentication tokens.
|
|
// It checks for corruption markers, validates JWT structure if applicable,
|
|
// and ensures tokens meet format requirements.
|
|
// Parameters:
|
|
// - token: The token string to validate.
|
|
// - tokenType: The type of token being validated (for error messages).
|
|
//
|
|
// Returns:
|
|
// - An error if the token has invalid structure or exceeds size limits.
|
|
func (sm *SessionManager) validateTokenFormat(token, tokenType string) error {
|
|
if token == "" {
|
|
return nil
|
|
}
|
|
|
|
if isCorruptionMarker(token) {
|
|
return fmt.Errorf("%s contains corruption marker", tokenType)
|
|
}
|
|
|
|
if strings.Count(token, ".") == 2 {
|
|
parts := strings.Split(token, ".")
|
|
for i, part := range parts {
|
|
if part == "" {
|
|
return fmt.Errorf("%s has empty part %d in JWT format", tokenType, i)
|
|
}
|
|
if strings.ContainsAny(part, "+/=") && !strings.ContainsAny(part, "-_") {
|
|
sm.logger.Debugf("Token %s part %d uses base64 instead of base64url encoding", tokenType, i)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// detectSessionTampering checks for indicators of session tampering.
|
|
// It examines session values for path traversal attempts, XSS payloads,
|
|
// and suspicious data patterns that might indicate malicious modification.
|
|
// Parameters:
|
|
// - sessionData: The session data to examine for tampering.
|
|
//
|
|
// Returns:
|
|
// - An error if tampering is detected, nil if session appears safe.
|
|
func (sm *SessionManager) detectSessionTampering(sessionData *SessionData) error {
|
|
if sessionData.mainSession == nil {
|
|
return fmt.Errorf("main session is missing")
|
|
}
|
|
|
|
for key, value := range sessionData.mainSession.Values {
|
|
if str, ok := value.(string); ok {
|
|
if strings.Contains(str, "../") || strings.Contains(str, "..\\") {
|
|
return fmt.Errorf("potential path traversal attempt in session key %v", key)
|
|
}
|
|
if strings.Contains(str, "<script") || strings.Contains(str, "javascript:") {
|
|
return fmt.Errorf("potential XSS attempt in session key %v", key)
|
|
}
|
|
if len(str) > 10000 {
|
|
return fmt.Errorf("suspiciously long session value for key %v (length: %d)", key, len(str))
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetSessionMetrics returns metrics about session management for monitoring purposes.
|
|
// It provides information about session configuration, security settings,
|
|
// and internal state for debugging and monitoring.
|
|
// Returns:
|
|
// - A map containing session metrics and configuration information.
|
|
func (sm *SessionManager) GetSessionMetrics() map[string]interface{} {
|
|
metrics := make(map[string]interface{})
|
|
metrics["session_manager_type"] = "CookieStore"
|
|
metrics["force_https"] = sm.forceHTTPS
|
|
metrics["absolute_timeout_hours"] = sm.sessionMaxAge.Hours()
|
|
metrics["max_cookie_size"] = maxCookieSize
|
|
metrics["max_browser_cookie_size"] = maxBrowserCookieSize
|
|
|
|
if cookieStore, ok := sm.store.(*sessions.CookieStore); ok && len(cookieStore.Codecs) > 0 {
|
|
metrics["has_encryption"] = true
|
|
metrics["codec_count"] = len(cookieStore.Codecs)
|
|
} else {
|
|
metrics["has_encryption"] = false
|
|
}
|
|
|
|
metrics["pool_implementation"] = "sync.Pool"
|
|
|
|
return metrics
|
|
}
|
|
|
|
// EnhanceSessionSecurity applies additional security measures to session options.
|
|
// It configures secure cookies, domain detection, SameSite policies, and
|
|
// adapts security settings based on request context and client characteristics.
|
|
// Parameters:
|
|
// - options: The base session options to enhance (can be nil).
|
|
// - r: The HTTP request context for security decisions.
|
|
//
|
|
// Returns:
|
|
// - Enhanced sessions.Options with additional security measures.
|
|
func (sm *SessionManager) EnhanceSessionSecurity(options *sessions.Options, r *http.Request) *sessions.Options {
|
|
if options == nil {
|
|
options = &sessions.Options{}
|
|
}
|
|
|
|
if r != nil {
|
|
userAgent := r.Header.Get("User-Agent")
|
|
if userAgent == "" {
|
|
sm.logger.Debugf("Request from %s missing User-Agent header", r.RemoteAddr)
|
|
options.MaxAge = int((sm.sessionMaxAge / 2).Seconds())
|
|
}
|
|
|
|
if r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS {
|
|
options.Secure = true
|
|
}
|
|
|
|
// Keep SameSite=Lax consistently for OAuth flows
|
|
// Removed dynamic switching based on XMLHttpRequest header to prevent redirect loop
|
|
options.SameSite = http.SameSiteLaxMode
|
|
}
|
|
|
|
options.HttpOnly = true
|
|
options.Path = "/" // Ensure cookies are available on all paths for OAuth flow
|
|
|
|
if sm.cookieDomain != "" {
|
|
options.Domain = sm.cookieDomain
|
|
sm.logger.Debugf("Using configured cookie domain: %s", sm.cookieDomain)
|
|
} else if options.Domain == "" && r != nil {
|
|
host := r.Host
|
|
|
|
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
|
host = forwardedHost
|
|
}
|
|
|
|
if host != "" && !strings.Contains(host, "localhost") && !strings.Contains(host, "127.0.0.1") {
|
|
if colonIndex := strings.Index(host, ":"); colonIndex != -1 {
|
|
host = host[:colonIndex]
|
|
}
|
|
options.Domain = host
|
|
sm.logger.Debugf("Auto-detected cookie domain: %s", host)
|
|
}
|
|
}
|
|
|
|
return options
|
|
}
|
|
|
|
// getSessionOptions creates base session options with security settings.
|
|
// It configures cookie security, lifetime, path, and domain settings
|
|
// based on the HTTPS status and manager configuration.
|
|
// Parameters:
|
|
// - isSecure: Whether the request is over HTTPS or should be treated as secure.
|
|
//
|
|
// Returns:
|
|
// - A pointer to a configured sessions.Options struct.
|
|
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
|
baseOptions := &sessions.Options{
|
|
HttpOnly: true,
|
|
Secure: isSecure || sm.forceHTTPS,
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: int(sm.sessionMaxAge.Seconds()),
|
|
Path: "/",
|
|
Domain: sm.cookieDomain,
|
|
}
|
|
return baseOptions
|
|
}
|
|
|
|
// CleanupOldCookies removes stale session cookies from the client browser.
|
|
// This method handles cleanup of cookies that may exist with different domain
|
|
// configurations, ensuring clean state when domain settings change.
|
|
// It removes cookies with various domain variations to ensure cleanup after configuration changes.
|
|
// Parameters:
|
|
// - w: The HTTP response writer for setting cookie deletion headers.
|
|
// - r: The HTTP request containing cookies to examine and clean up.
|
|
func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Request) {
|
|
cookies := r.Cookies()
|
|
|
|
currentDomain := sm.cookieDomain
|
|
host := r.Host
|
|
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
|
host = forwardedHost
|
|
}
|
|
if colonIndex := strings.Index(host, ":"); colonIndex != -1 {
|
|
host = host[:colonIndex]
|
|
}
|
|
|
|
// This ensures we clean up cookies from various possible domains
|
|
var domainsToClean []string
|
|
|
|
if host != "" && !strings.Contains(host, "localhost") && !strings.Contains(host, "127.0.0.1") {
|
|
domainsToClean = append(domainsToClean,
|
|
host,
|
|
"."+host,
|
|
)
|
|
|
|
parts := strings.Split(host, ".")
|
|
if len(parts) > 2 {
|
|
parentDomain := strings.Join(parts[len(parts)-2:], ".")
|
|
domainsToClean = append(domainsToClean,
|
|
parentDomain,
|
|
"."+parentDomain,
|
|
)
|
|
}
|
|
}
|
|
|
|
processedCookies := make(map[string]bool)
|
|
|
|
for _, cookie := range cookies {
|
|
// Check if cookie belongs to this middleware instance using the configured prefix
|
|
if strings.HasPrefix(cookie.Name, sm.cookiePrefix) ||
|
|
strings.HasPrefix(cookie.Name, "access_token_chunk_") ||
|
|
strings.HasPrefix(cookie.Name, "refresh_token_chunk_") {
|
|
|
|
processedCookies[cookie.Name] = true
|
|
|
|
sm.cleanupMutex.RLock()
|
|
shouldCleanup := currentDomain != "" && !sm.cleanupDone
|
|
sm.cleanupMutex.RUnlock()
|
|
|
|
if shouldCleanup {
|
|
for _, domain := range domainsToClean {
|
|
if domain == currentDomain || domain == "."+currentDomain || "."+domain == currentDomain {
|
|
continue
|
|
}
|
|
|
|
deleteCookie := &http.Cookie{
|
|
Name: cookie.Name,
|
|
Value: "",
|
|
Path: "/",
|
|
Domain: domain,
|
|
MaxAge: -1,
|
|
HttpOnly: true,
|
|
Secure: r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sm.forceHTTPS,
|
|
SameSite: http.SameSiteLaxMode,
|
|
}
|
|
http.SetCookie(w, deleteCookie)
|
|
sm.logger.Debugf("Attempting to clean up cookie %s with domain %s", cookie.Name, domain)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(processedCookies) > 0 {
|
|
sm.cleanupMutex.Lock()
|
|
if !sm.cleanupDone {
|
|
sm.cleanupDone = true
|
|
}
|
|
sm.cleanupMutex.Unlock()
|
|
}
|
|
}
|
|
|
|
// GetSession retrieves or creates session data from the HTTP request.
|
|
// It loads the main session and all token chunk sessions, performing validation
|
|
// and timeout checks. The returned session must be explicitly returned to the pool
|
|
// by calling returnToPoolSafely() to prevent memory leaks.
|
|
// MEMORY LEAK FIX: Session is NOT returned to pool here - caller must call ReturnToPool() when done.
|
|
// Parameters:
|
|
// - r: The HTTP request containing session cookies.
|
|
//
|
|
// Returns:
|
|
// - The loaded SessionData instance.
|
|
// - An error if session loading or validation fails.
|
|
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
|
sessionData, _ := sm.sessionPool.Get().(*SessionData) // Safe to ignore: pool return is best-effort
|
|
atomic.AddInt64(&sm.poolHits, 1)
|
|
atomic.AddInt64(&sm.activeSessions, 1)
|
|
|
|
sessionData.inUse = true
|
|
sessionData.request = r
|
|
sessionData.dirty = false
|
|
|
|
handleError := func(err error, message string) (*SessionData, error) {
|
|
if sessionData != nil {
|
|
sessionData.inUse = false
|
|
sessionData.Reset()
|
|
sm.sessionPool.Put(sessionData)
|
|
atomic.AddInt64(&sm.activeSessions, -1)
|
|
}
|
|
return nil, fmt.Errorf("%s: %w", message, err)
|
|
}
|
|
|
|
var err error
|
|
sessionData.mainSession, err = sm.store.Get(r, sm.mainCookieName())
|
|
if err != nil {
|
|
return handleError(err, "failed to get main session")
|
|
}
|
|
|
|
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
|
|
if time.Since(time.Unix(createdAt, 0)) > sm.sessionMaxAge {
|
|
_ = sessionData.Clear(r, nil) // Safe to ignore: session is being invalidated
|
|
return handleError(fmt.Errorf("session timeout"), "session expired")
|
|
}
|
|
}
|
|
|
|
sessionData.accessSession, err = sm.store.Get(r, sm.accessTokenCookieName())
|
|
if err != nil {
|
|
return handleError(err, "failed to get access token session")
|
|
}
|
|
|
|
sessionData.refreshSession, err = sm.store.Get(r, sm.refreshTokenCookieName())
|
|
if err != nil {
|
|
return handleError(err, "failed to get refresh token session")
|
|
}
|
|
|
|
sessionData.idTokenSession, err = sm.store.Get(r, sm.idTokenCookieName())
|
|
if err != nil {
|
|
return handleError(err, "failed to get ID token session")
|
|
}
|
|
|
|
for k := range sessionData.accessTokenChunks {
|
|
delete(sessionData.accessTokenChunks, k)
|
|
}
|
|
for k := range sessionData.refreshTokenChunks {
|
|
delete(sessionData.refreshTokenChunks, k)
|
|
}
|
|
for k := range sessionData.idTokenChunks {
|
|
delete(sessionData.idTokenChunks, k)
|
|
}
|
|
|
|
sm.getTokenChunkSessions(r, sm.accessTokenCookieName(), sessionData.accessTokenChunks)
|
|
sm.getTokenChunkSessions(r, sm.refreshTokenCookieName(), sessionData.refreshTokenChunks)
|
|
sm.getTokenChunkSessions(r, sm.idTokenCookieName(), sessionData.idTokenChunks)
|
|
|
|
return sessionData, nil
|
|
}
|
|
|
|
// getTokenChunkSessions loads all available token chunk sessions for a given token type.
|
|
// It iterates through numbered chunk sessions until no more are found,
|
|
// populating the provided chunks map with the loaded sessions.
|
|
// Parameters:
|
|
// - r: The HTTP request containing chunk cookies.
|
|
// - baseName: The base cookie name for the token type (e.g., "_oidc_raczylo_a").
|
|
// - chunks: The map (typically SessionData.accessTokenChunks or SessionData.refreshTokenChunks)
|
|
// to populate with the found session chunks.
|
|
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
|
|
for i := 0; ; i++ {
|
|
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
|
session, err := sm.store.Get(r, sessionName)
|
|
if err != nil || session.IsNew {
|
|
break
|
|
}
|
|
chunks[i] = session
|
|
}
|
|
}
|
|
|
|
// SessionData represents a user's authentication session with comprehensive token management.
|
|
// It handles main session data and supports large tokens that need to be
|
|
// split across multiple cookies due to browser size limitations.
|
|
type SessionData struct {
|
|
manager *SessionManager
|
|
|
|
request *http.Request
|
|
|
|
mainSession *sessions.Session
|
|
|
|
accessSession *sessions.Session
|
|
|
|
refreshSession *sessions.Session
|
|
|
|
idTokenSession *sessions.Session
|
|
|
|
accessTokenChunks map[int]*sessions.Session
|
|
|
|
refreshTokenChunks map[int]*sessions.Session
|
|
|
|
idTokenChunks map[int]*sessions.Session
|
|
|
|
refreshMutex sync.Mutex
|
|
|
|
sessionMutex sync.RWMutex
|
|
|
|
dirty bool
|
|
|
|
inUse bool
|
|
}
|
|
|
|
// IsDirty returns true if the session data has been modified since it was last loaded or saved.
|
|
// This is used to optimize session saves by only writing when necessary.
|
|
// Returns:
|
|
// - true if the session has pending changes, false otherwise.
|
|
func (sd *SessionData) IsDirty() bool {
|
|
return sd.dirty
|
|
}
|
|
|
|
// MarkDirty marks the session as having pending changes that need to be saved.
|
|
// This is used when session data hasn't changed in content but should still
|
|
// trigger a session save (e.g., to ensure the cookie is re-issued).
|
|
func (sd *SessionData) MarkDirty() {
|
|
sd.dirty = true
|
|
}
|
|
|
|
// Save persists all session data including main session and token chunks.
|
|
// It applies security options, saves all session components, and handles
|
|
// errors gracefully by continuing to save other components even if one fails.
|
|
// Parameters:
|
|
// - r: The HTTP request context for security option configuration.
|
|
// - w: The HTTP response writer for setting session cookies.
|
|
//
|
|
// Returns:
|
|
// - An error if saving any of the session components fails.
|
|
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
|
isSecure := r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil || sd.manager.forceHTTPS
|
|
|
|
options := sd.manager.getSessionOptions(isSecure)
|
|
options = sd.manager.EnhanceSessionSecurity(options, r)
|
|
|
|
sd.mainSession.Options = options
|
|
sd.accessSession.Options = options
|
|
sd.refreshSession.Options = options
|
|
sd.idTokenSession.Options = options
|
|
|
|
var firstErr error
|
|
saveOrLogError := func(s *sessions.Session, name string) {
|
|
if s == nil {
|
|
sd.manager.logger.Errorf("Attempted to save nil session: %s", name)
|
|
if firstErr == nil {
|
|
firstErr = fmt.Errorf("attempted to save nil session: %s", name)
|
|
}
|
|
return
|
|
}
|
|
if err := s.Save(r, w); err != nil {
|
|
errMsg := fmt.Errorf("failed to save %s session: %w", name, err)
|
|
sd.manager.logger.Error("%s", errMsg.Error())
|
|
if firstErr == nil {
|
|
firstErr = errMsg
|
|
}
|
|
}
|
|
}
|
|
|
|
saveOrLogError(sd.mainSession, "main")
|
|
|
|
saveOrLogError(sd.accessSession, "access token")
|
|
|
|
saveOrLogError(sd.refreshSession, "refresh token")
|
|
|
|
saveOrLogError(sd.idTokenSession, "ID token")
|
|
|
|
for i, sessionChunk := range sd.accessTokenChunks {
|
|
sessionChunk.Options = options
|
|
saveOrLogError(sessionChunk, fmt.Sprintf("access token chunk %d", i))
|
|
}
|
|
|
|
for i, sessionChunk := range sd.refreshTokenChunks {
|
|
sessionChunk.Options = options
|
|
saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i))
|
|
}
|
|
|
|
for i, sessionChunk := range sd.idTokenChunks {
|
|
sessionChunk.Options = options
|
|
saveOrLogError(sessionChunk, fmt.Sprintf("ID token chunk %d", i))
|
|
}
|
|
|
|
if firstErr == nil {
|
|
sd.dirty = false
|
|
}
|
|
return firstErr
|
|
}
|
|
|
|
// clearSessionValues removes all values from a session and optionally expires it.
|
|
// This is used during session cleanup and logout operations.
|
|
// Parameters:
|
|
// - session: The session to clear values from.
|
|
// - expire: If true, sets MaxAge to -1 to expire the cookie.
|
|
func clearSessionValues(session *sessions.Session, expire bool) {
|
|
if session == nil {
|
|
return
|
|
}
|
|
|
|
for k := range session.Values {
|
|
delete(session.Values, k)
|
|
}
|
|
|
|
if expire {
|
|
session.Options.MaxAge = -1
|
|
}
|
|
}
|
|
|
|
// clearAllSessionData clears all session data including main session and token chunks.
|
|
// It removes all session values and optionally expires all associated cookies.
|
|
// Parameters:
|
|
// - r: The HTTP request context (used for chunk cleanup).
|
|
// - expire: Whether to expire the cookies (set MaxAge to -1).
|
|
func (sd *SessionData) clearAllSessionData(r *http.Request, expire bool) {
|
|
clearSessionValues(sd.mainSession, expire)
|
|
clearSessionValues(sd.accessSession, expire)
|
|
clearSessionValues(sd.refreshSession, expire)
|
|
clearSessionValues(sd.idTokenSession, expire)
|
|
|
|
if expire && r != nil {
|
|
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
|
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
|
sd.clearTokenChunks(r, sd.idTokenChunks)
|
|
} else {
|
|
for k := range sd.accessTokenChunks {
|
|
delete(sd.accessTokenChunks, k)
|
|
}
|
|
for k := range sd.refreshTokenChunks {
|
|
delete(sd.refreshTokenChunks, k)
|
|
}
|
|
for k := range sd.idTokenChunks {
|
|
delete(sd.idTokenChunks, k)
|
|
}
|
|
}
|
|
|
|
if expire {
|
|
sd.dirty = true
|
|
}
|
|
}
|
|
|
|
// Clear completely clears all session data and safely returns the session to the pool.
|
|
// It removes all authentication data, expires cookies, and handles panic recovery.
|
|
// This method ensures the SessionData object is always returned to the pool.
|
|
// Parameters:
|
|
// - r: The HTTP request context.
|
|
// - w: The HTTP response writer for cookie expiration (can be nil).
|
|
//
|
|
// Returns:
|
|
// - An error if session saving fails during cleanup.
|
|
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
|
defer func() {
|
|
sd.returnToPoolSafely()
|
|
}()
|
|
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
sd.clearAllSessionData(r, true)
|
|
|
|
// This is primarily for testing - in production w will often be nil
|
|
var err error
|
|
if w != nil {
|
|
if r != nil && r.Header.Get("X-Test-Error") == "true" {
|
|
// Return a test error without trying to save problematic data
|
|
err = fmt.Errorf("test error triggered by X-Test-Error header")
|
|
} else {
|
|
err = sd.Save(r, w)
|
|
}
|
|
}
|
|
|
|
sd.request = nil
|
|
|
|
return err
|
|
}
|
|
|
|
// returnToPoolSafely safely returns the session to the object pool.
|
|
// Add thread-safe helper method to return session to pool.
|
|
// It ensures the session is marked as not in use and properly reset before pooling.
|
|
func (sd *SessionData) returnToPoolSafely() {
|
|
if sd != nil && sd.manager != nil {
|
|
if sd.inUse {
|
|
sd.inUse = false
|
|
sd.Reset()
|
|
sd.manager.sessionPool.Put(sd)
|
|
atomic.AddInt64(&sd.manager.activeSessions, -1)
|
|
}
|
|
}
|
|
}
|
|
|
|
// clearTokenChunks clears and expires all token chunk sessions.
|
|
// This is used during logout and session cleanup to ensure
|
|
// all token data is properly removed from the client.
|
|
// Parameters:
|
|
// - r: The HTTP request context.
|
|
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
|
|
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
|
for _, session := range chunks {
|
|
clearSessionValues(session, true)
|
|
}
|
|
}
|
|
|
|
// GetAuthenticated returns whether the user is currently authenticated.
|
|
// It checks both the authentication flag and session timeout.
|
|
// Returns:
|
|
// - true if the user is authenticated and the session is not expired.
|
|
// - false otherwise.
|
|
func (sd *SessionData) GetAuthenticated() bool {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
return sd.getAuthenticatedUnsafe()
|
|
}
|
|
|
|
// getAuthenticatedUnsafe checks authentication status without acquiring locks.
|
|
// Used when the mutex is already held to avoid deadlocks.
|
|
// It validates both the authentication flag and session creation time.
|
|
// Returns:
|
|
// - true if authenticated and not expired, false otherwise.
|
|
func (sd *SessionData) getAuthenticatedUnsafe() bool {
|
|
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
|
if !auth {
|
|
return false
|
|
}
|
|
|
|
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return time.Since(time.Unix(createdAt, 0)) <= sd.manager.sessionMaxAge
|
|
}
|
|
|
|
// SetAuthenticated sets the authentication status and manages session security.
|
|
// When setting to true, it generates a new secure session ID and updates timestamps.
|
|
// This prevents session fixation attacks by regenerating the session identifier.
|
|
// Parameters:
|
|
// - value: The authentication status to set.
|
|
//
|
|
// Returns:
|
|
// - An error if generating a new session ID fails when setting value to true.
|
|
func (sd *SessionData) SetAuthenticated(value bool) error {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
currentAuth := sd.getAuthenticatedUnsafe()
|
|
changed := false
|
|
|
|
if currentAuth != value {
|
|
changed = true
|
|
}
|
|
|
|
if value {
|
|
id, err := generateSecureRandomString(64)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate secure session id: %w", err)
|
|
}
|
|
|
|
maxRetries := 5
|
|
for retry := 0; retry < maxRetries; retry++ {
|
|
if sd.mainSession.ID != id {
|
|
break
|
|
}
|
|
id, err = generateSecureRandomString(64)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate secure session id on retry %d: %w", retry, err)
|
|
}
|
|
}
|
|
|
|
if sd.mainSession.ID != id {
|
|
changed = true
|
|
}
|
|
sd.mainSession.ID = id
|
|
newCreationTime := time.Now().Unix()
|
|
if oldTime, ok := sd.mainSession.Values["created_at"].(int64); !ok || oldTime != newCreationTime {
|
|
changed = true
|
|
}
|
|
sd.mainSession.Values["created_at"] = newCreationTime
|
|
if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value {
|
|
changed = true
|
|
}
|
|
} else {
|
|
if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value {
|
|
changed = true
|
|
}
|
|
}
|
|
|
|
sd.mainSession.Values["authenticated"] = value
|
|
if changed {
|
|
sd.dirty = true
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// resetSession prepares a session for reuse by clearing its state.
|
|
// This is specifically for pool reuse preparation to ensure
|
|
// no data leaks between different user sessions.
|
|
// Parameters:
|
|
// - session: The session to reset for reuse.
|
|
func resetSession(session *sessions.Session) {
|
|
if session == nil {
|
|
return
|
|
}
|
|
|
|
clearSessionValues(session, false)
|
|
|
|
session.ID = ""
|
|
session.IsNew = true
|
|
}
|
|
|
|
// Reset clears all session data and prepares the SessionData for reuse.
|
|
// It ensures no authentication data persists when the object is reused
|
|
// between different users/sessions.
|
|
func (sd *SessionData) Reset() {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
sd.clearAllSessionData(nil, false)
|
|
|
|
resetSession(sd.mainSession)
|
|
resetSession(sd.accessSession)
|
|
resetSession(sd.refreshSession)
|
|
resetSession(sd.idTokenSession)
|
|
|
|
// Clear redirect count to prevent leaking between sessions
|
|
if sd.mainSession != nil && sd.mainSession.Values != nil {
|
|
delete(sd.mainSession.Values, "redirect_count")
|
|
}
|
|
|
|
sd.dirty = false
|
|
sd.inUse = false
|
|
sd.request = nil
|
|
|
|
// Reset the refresh mutex to ensure clean state
|
|
// Note: We don't need to lock it since sessionMutex is already held
|
|
// and this session is not in use by any request
|
|
}
|
|
|
|
// ReturnToPool manually returns the session to the object pool.
|
|
// This is used in cleanup paths where Clear() is not called, to prevent memory leaks.
|
|
// It only returns the session if it's not currently in use.
|
|
func (sd *SessionData) ReturnToPool() {
|
|
if sd != nil && sd.manager != nil {
|
|
if !sd.inUse {
|
|
sd.Reset()
|
|
sd.manager.sessionPool.Put(sd)
|
|
atomic.AddInt64(&sd.manager.activeSessions, -1)
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetAccessToken retrieves the user's access token from session storage.
|
|
// It handles both single-cookie storage and chunked storage for large tokens,
|
|
// with automatic decompression if the token was compressed.
|
|
// Returns:
|
|
// - The complete, decompressed access token string, or an empty string if not found.
|
|
func (sd *SessionData) GetAccessToken() string {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
return sd.getAccessTokenUnsafe()
|
|
}
|
|
|
|
// getAccessTokenUnsafe retrieves the access token without acquiring locks.
|
|
// Enhanced token retrieval with comprehensive integrity checks and recovery mechanisms.
|
|
// Used when the session mutex is already held to prevent deadlocks.
|
|
// Returns:
|
|
// - The complete access token string or empty string on error.
|
|
func (sd *SessionData) getAccessTokenUnsafe() string {
|
|
token, _ := sd.accessSession.Values["token"].(string)
|
|
compressed, _ := sd.accessSession.Values["compressed"].(bool)
|
|
|
|
// Debug: Check if manager/chunkManager is nil
|
|
if sd.manager == nil || sd.manager.chunkManager == nil {
|
|
// Direct return if no chunk manager (test scenario)
|
|
return token
|
|
}
|
|
|
|
result := sd.manager.chunkManager.GetToken(
|
|
token,
|
|
compressed,
|
|
sd.accessTokenChunks,
|
|
AccessTokenConfig,
|
|
)
|
|
|
|
if result.Error != nil {
|
|
// Check if we have a raw token available
|
|
// This handles cases where the token exists but doesn't validate as JWT
|
|
if token != "" && !compressed && len(sd.accessTokenChunks) == 0 {
|
|
// We have a non-chunked, non-compressed token that failed validation
|
|
// Check if it's an opaque token (doesn't have JWT structure)
|
|
dotCount := strings.Count(token, ".")
|
|
if dotCount != 2 {
|
|
// This is likely an opaque token that failed JWT validation
|
|
// Return the raw token as-is since opaque tokens are valid
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("Returning opaque access token (dots: %d) despite validation error: %v", dotCount, result.Error)
|
|
}
|
|
return token
|
|
}
|
|
}
|
|
|
|
// For JWT validation errors or other issues, log and return empty
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("ChunkManager.GetToken error: %v", result.Error)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
return result.Token
|
|
}
|
|
|
|
// SetAccessToken stores an access token with automatic compression and chunking.
|
|
// It validates token format, compresses if beneficial, and splits into chunks
|
|
// if the token exceeds cookie size limits. Includes integrity verification.
|
|
// Parameters:
|
|
// - token: The access token string to store.
|
|
func (sd *SessionData) SetAccessToken(token string) {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
if token != "" {
|
|
dotCount := strings.Count(token, ".")
|
|
// Reject tokens with exactly 1 dot (invalid format - neither JWT nor opaque)
|
|
if dotCount == 1 {
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debug("Invalid token format during storage (dots: %d) - rejecting", dotCount)
|
|
}
|
|
return
|
|
}
|
|
// For opaque tokens (no dots), ensure minimum length for security
|
|
if dotCount == 0 && len(token) < 20 {
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debug("Token too short for opaque token (length: %d) - rejecting", len(token))
|
|
}
|
|
return
|
|
}
|
|
// Tokens with 2 dots are JWTs, tokens with 0 dots are opaque
|
|
// Both are valid formats
|
|
}
|
|
|
|
currentAccessToken := sd.getAccessTokenUnsafe()
|
|
if constantTimeStringCompare(currentAccessToken, token) {
|
|
return
|
|
}
|
|
sd.dirty = true
|
|
|
|
// Debug: Check if accessSession is properly initialized
|
|
if sd.accessSession == nil {
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Errorf("CRITICAL: accessSession is nil when trying to store token")
|
|
}
|
|
return
|
|
}
|
|
|
|
if sd.request != nil {
|
|
sd.expireAccessTokenChunksEnhanced(nil)
|
|
}
|
|
|
|
for k := range sd.accessTokenChunks {
|
|
delete(sd.accessTokenChunks, k)
|
|
}
|
|
|
|
if token == "" {
|
|
if sd.accessSession != nil {
|
|
sd.accessSession.Values["token"] = ""
|
|
sd.accessSession.Values["compressed"] = false
|
|
}
|
|
return
|
|
}
|
|
|
|
compressed := compressToken(token)
|
|
|
|
// Debug for test
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("Token compression: original %d bytes, compressed %d bytes", len(token), len(compressed))
|
|
}
|
|
|
|
if len(compressed) > 100*1024 {
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Info("Access token too large after compression (%d bytes) - storing uncompressed", len(compressed))
|
|
}
|
|
return
|
|
}
|
|
|
|
if compressed != token {
|
|
testDecompressed := decompressToken(compressed)
|
|
if testDecompressed != token {
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debug("Access token compression verification failed - storing uncompressed")
|
|
}
|
|
compressed = token
|
|
}
|
|
}
|
|
|
|
if len(compressed) <= maxCookieSize {
|
|
if sd.accessSession != nil {
|
|
sd.accessSession.Values["token"] = compressed
|
|
sd.accessSession.Values["compressed"] = (compressed != token)
|
|
// Debug for test
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("Stored token in session: compressed=%v, token_len=%d",
|
|
compressed != token, len(compressed))
|
|
}
|
|
}
|
|
} else {
|
|
if sd.accessSession != nil {
|
|
sd.accessSession.Values["token"] = ""
|
|
sd.accessSession.Values["compressed"] = (compressed != token)
|
|
}
|
|
|
|
chunks := splitIntoChunks(compressed, maxCookieSize)
|
|
|
|
if len(chunks) == 0 {
|
|
sd.manager.logger.Error("Failed to create chunks for access token")
|
|
return
|
|
}
|
|
|
|
if len(chunks) > 50 {
|
|
sd.manager.logger.Info("Too many chunks (%d) for access token", len(chunks))
|
|
return
|
|
}
|
|
|
|
testReassembled := strings.Join(chunks, "")
|
|
if testReassembled != compressed {
|
|
sd.manager.logger.Debug("Access token chunk reassembly test failed")
|
|
return
|
|
}
|
|
|
|
for i, chunkData := range chunks {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), i)
|
|
|
|
if sd.request == nil {
|
|
sd.manager.logger.Error("SetAccessToken: sd.request is nil, cannot create chunk session %s", sessionName)
|
|
return
|
|
}
|
|
|
|
if chunkData == "" {
|
|
sd.manager.logger.Debug("Empty chunk data at index %d", i)
|
|
return
|
|
}
|
|
|
|
if len(chunkData) > maxCookieSize {
|
|
sd.manager.logger.Info("Chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
|
|
return
|
|
}
|
|
|
|
if !validateChunkSize(chunkData) {
|
|
sd.manager.logger.Errorf("CRITICAL: Chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
|
|
return
|
|
}
|
|
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err)
|
|
return
|
|
}
|
|
|
|
session.Values["token_chunk"] = chunkData
|
|
session.Values["compressed"] = (compressed != token)
|
|
session.Values["chunk_created_at"] = time.Now().Unix()
|
|
sd.accessTokenChunks[i] = session
|
|
}
|
|
|
|
sd.manager.logger.Debugf("SUCCESS: Stored access token in %d chunks", len(chunks))
|
|
}
|
|
}
|
|
|
|
// GetRefreshToken retrieves the user's refresh token from session storage.
|
|
// It handles both single-cookie storage and chunked storage for large tokens,
|
|
// with automatic decompression if the token was compressed.
|
|
// Returns:
|
|
// - The complete, decompressed refresh token string, or an empty string if not found.
|
|
func (sd *SessionData) GetRefreshToken() string {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
token, _ := sd.refreshSession.Values["token"].(string)
|
|
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
|
|
|
result := sd.manager.chunkManager.GetToken(
|
|
token,
|
|
compressed,
|
|
sd.refreshTokenChunks,
|
|
RefreshTokenConfig,
|
|
)
|
|
|
|
if result.Error != nil {
|
|
// Check if we have a raw token available
|
|
// This handles cases where the token exists but doesn't validate as JWT
|
|
if token != "" && !compressed && len(sd.refreshTokenChunks) == 0 {
|
|
// We have a non-chunked, non-compressed token that failed validation
|
|
// Check if it's an opaque token (doesn't have JWT structure)
|
|
dotCount := strings.Count(token, ".")
|
|
if dotCount != 2 {
|
|
// This is likely an opaque token that failed JWT validation
|
|
// Return the raw token as-is since opaque tokens are valid
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("Returning opaque refresh token (dots: %d) despite validation error: %v", dotCount, result.Error)
|
|
}
|
|
return token
|
|
}
|
|
}
|
|
// For JWT validation errors or other issues, log and return empty
|
|
if sd.manager != nil && sd.manager.logger != nil {
|
|
sd.manager.logger.Debugf("ChunkManager.GetToken error for refresh token: %v", result.Error)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
return result.Token
|
|
}
|
|
|
|
// SetRefreshToken stores a refresh token with automatic compression and chunking.
|
|
// It validates token size, compresses if beneficial, and splits into chunks
|
|
// if needed. Includes comprehensive error checking and integrity verification.
|
|
// Parameters:
|
|
// - token: The refresh token string to store.
|
|
func (sd *SessionData) SetRefreshToken(token string) {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
if len(token) > 50*1024 {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token too large (%d bytes) - possible corruption, rejecting", len(token))
|
|
return
|
|
}
|
|
|
|
// Get current refresh token without mutex to avoid deadlock since we already hold the lock
|
|
var currentRefreshToken string
|
|
sessionToken, _ := sd.refreshSession.Values["token"].(string)
|
|
if sessionToken != "" {
|
|
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
|
if compressed {
|
|
decompressed := decompressToken(sessionToken)
|
|
currentRefreshToken = decompressed
|
|
} else {
|
|
currentRefreshToken = sessionToken
|
|
}
|
|
} else if len(sd.refreshTokenChunks) > 0 {
|
|
// Simplified chunked token retrieval for deadlock prevention
|
|
var chunks []string
|
|
for i := 0; i < len(sd.refreshTokenChunks); i++ {
|
|
if session, ok := sd.refreshTokenChunks[i]; ok {
|
|
if chunk, chunkOk := session.Values["token_chunk"].(string); chunkOk && chunk != "" {
|
|
chunks = append(chunks, chunk)
|
|
}
|
|
}
|
|
}
|
|
if len(chunks) == len(sd.refreshTokenChunks) {
|
|
reassembled := strings.Join(chunks, "")
|
|
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
|
if compressed {
|
|
currentRefreshToken = decompressToken(reassembled)
|
|
} else {
|
|
currentRefreshToken = reassembled
|
|
}
|
|
}
|
|
}
|
|
if constantTimeStringCompare(currentRefreshToken, token) {
|
|
return
|
|
}
|
|
sd.dirty = true
|
|
|
|
if sd.request != nil {
|
|
sd.expireRefreshTokenChunksEnhanced(nil)
|
|
}
|
|
|
|
for k := range sd.refreshTokenChunks {
|
|
delete(sd.refreshTokenChunks, k)
|
|
}
|
|
|
|
if token == "" {
|
|
sd.refreshSession.Values["token"] = ""
|
|
sd.refreshSession.Values["compressed"] = false
|
|
return
|
|
}
|
|
|
|
compressed := compressToken(token)
|
|
|
|
if compressed != token {
|
|
testDecompressed := decompressToken(compressed)
|
|
if testDecompressed != token {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token compression verification failed - storing uncompressed")
|
|
compressed = token
|
|
}
|
|
}
|
|
|
|
if len(compressed) <= maxCookieSize {
|
|
sd.refreshSession.Values["token"] = compressed
|
|
sd.refreshSession.Values["compressed"] = (compressed != token)
|
|
sd.refreshSession.Values["issued_at"] = time.Now().Unix()
|
|
} else {
|
|
sd.refreshSession.Values["token"] = ""
|
|
sd.refreshSession.Values["compressed"] = (compressed != token)
|
|
sd.refreshSession.Values["issued_at"] = time.Now().Unix()
|
|
|
|
chunks := splitIntoChunks(compressed, maxCookieSize)
|
|
|
|
if len(chunks) == 0 {
|
|
sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for refresh token")
|
|
return
|
|
}
|
|
|
|
if len(chunks) > 50 {
|
|
sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for refresh token - possible corruption", len(chunks))
|
|
return
|
|
}
|
|
|
|
testReassembled := strings.Join(chunks, "")
|
|
if testReassembled != compressed {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token chunk reassembly test failed")
|
|
return
|
|
}
|
|
|
|
for i, chunkData := range chunks {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), i)
|
|
|
|
if sd.request == nil {
|
|
sd.manager.logger.Errorf("CRITICAL: SetRefreshToken: sd.request is nil, cannot create chunk session %s", sessionName)
|
|
return
|
|
}
|
|
|
|
if chunkData == "" {
|
|
sd.manager.logger.Errorf("CRITICAL: Empty refresh token chunk data at index %d", i)
|
|
return
|
|
}
|
|
|
|
if len(chunkData) > maxCookieSize {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
|
|
return
|
|
}
|
|
|
|
if !validateChunkSize(chunkData) {
|
|
sd.manager.logger.Errorf("CRITICAL: Refresh token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
|
|
return
|
|
}
|
|
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
sd.manager.logger.Errorf("CRITICAL: Failed to get refresh token chunk session %s: %v", sessionName, err)
|
|
return
|
|
}
|
|
|
|
session.Values["token_chunk"] = chunkData
|
|
session.Values["compressed"] = (compressed != token)
|
|
session.Values["chunk_created_at"] = time.Now().Unix()
|
|
sd.refreshTokenChunks[i] = session
|
|
}
|
|
|
|
sd.manager.logger.Debugf("SUCCESS: Stored refresh token in %d chunks", len(chunks))
|
|
}
|
|
}
|
|
|
|
// GetRefreshTokenIssuedAt retrieves the timestamp when the refresh token was issued/stored.
|
|
// Returns the time when the current refresh token was obtained, or zero time if not available.
|
|
func (sd *SessionData) GetRefreshTokenIssuedAt() time.Time {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
if issuedAtUnix, ok := sd.refreshSession.Values["issued_at"].(int64); ok {
|
|
return time.Unix(issuedAtUnix, 0)
|
|
}
|
|
|
|
// For chunked tokens, check the first chunk for timestamp
|
|
if len(sd.refreshTokenChunks) > 0 {
|
|
if session, exists := sd.refreshTokenChunks[0]; exists {
|
|
if chunkCreatedAt, ok := session.Values["chunk_created_at"].(int64); ok {
|
|
return time.Unix(chunkCreatedAt, 0)
|
|
}
|
|
}
|
|
}
|
|
|
|
return time.Time{}
|
|
}
|
|
|
|
// expireAccessTokenChunksEnhanced expires all access token chunks and detects orphaned chunks.
|
|
// It searches for all existing chunks, identifies orphaned or expired chunks,
|
|
// and properly expires them to prevent cookie accumulation.
|
|
// Parameters:
|
|
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
|
func (sd *SessionData) expireAccessTokenChunksEnhanced(w http.ResponseWriter) {
|
|
const maxChunkSearchLimit = 50
|
|
orphanedChunks := 0
|
|
|
|
for i := 0; i < maxChunkSearchLimit; i++ {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), i)
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
break
|
|
}
|
|
if session.IsNew {
|
|
break
|
|
}
|
|
|
|
if chunk, exists := session.Values["token_chunk"]; exists {
|
|
if createdAt, ok := session.Values["chunk_created_at"].(int64); ok {
|
|
chunkAge := time.Since(time.Unix(createdAt, 0))
|
|
if chunkAge > 24*time.Hour {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found orphaned access token chunk %d (age: %v)", i, chunkAge)
|
|
}
|
|
} else if chunk != nil {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found access token chunk %d without timestamp, treating as orphaned", i)
|
|
}
|
|
}
|
|
|
|
session.Options.MaxAge = -1
|
|
session.Values = make(map[interface{}]interface{})
|
|
if w != nil {
|
|
if err := session.Save(sd.request, w); err != nil {
|
|
sd.manager.logger.Errorf("failed to save expired access token chunk %d: %v", i, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if orphanedChunks > 0 {
|
|
sd.manager.logger.Infof("Cleaned up %d orphaned access token chunks", orphanedChunks)
|
|
}
|
|
}
|
|
|
|
// expireRefreshTokenChunksEnhanced expires all refresh token chunks and detects orphaned chunks.
|
|
// It searches for all existing chunks, identifies orphaned or expired chunks,
|
|
// and properly expires them to prevent cookie accumulation.
|
|
// Parameters:
|
|
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
|
func (sd *SessionData) expireRefreshTokenChunksEnhanced(w http.ResponseWriter) {
|
|
const maxChunkSearchLimit = 50
|
|
orphanedChunks := 0
|
|
|
|
for i := 0; i < maxChunkSearchLimit; i++ {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), i)
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
break
|
|
}
|
|
if session.IsNew {
|
|
break
|
|
}
|
|
|
|
if chunk, exists := session.Values["token_chunk"]; exists {
|
|
if createdAt, ok := session.Values["chunk_created_at"].(int64); ok {
|
|
chunkAge := time.Since(time.Unix(createdAt, 0))
|
|
if chunkAge > 24*time.Hour {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found orphaned refresh token chunk %d (age: %v)", i, chunkAge)
|
|
}
|
|
} else if chunk != nil {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found refresh token chunk %d without timestamp, treating as orphaned", i)
|
|
}
|
|
}
|
|
|
|
session.Options.MaxAge = -1
|
|
session.Values = make(map[interface{}]interface{})
|
|
if w != nil {
|
|
if err := session.Save(sd.request, w); err != nil {
|
|
sd.manager.logger.Errorf("failed to save expired refresh token chunk %d: %v", i, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if orphanedChunks > 0 {
|
|
sd.manager.logger.Infof("Cleaned up %d orphaned refresh token chunks", orphanedChunks)
|
|
}
|
|
}
|
|
|
|
// expireIDTokenChunksEnhanced expires all ID token chunks and detects orphaned chunks.
|
|
// It searches for all existing chunks, identifies orphaned or expired chunks,
|
|
// and properly expires them to prevent cookie accumulation.
|
|
// Parameters:
|
|
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
|
func (sd *SessionData) expireIDTokenChunksEnhanced(w http.ResponseWriter) {
|
|
const maxChunkSearchLimit = 50
|
|
orphanedChunks := 0
|
|
|
|
for i := 0; i < maxChunkSearchLimit; i++ {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), i)
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
break
|
|
}
|
|
if session.IsNew {
|
|
break
|
|
}
|
|
|
|
if chunk, exists := session.Values["token_chunk"]; exists {
|
|
if createdAt, ok := session.Values["chunk_created_at"].(int64); ok {
|
|
chunkAge := time.Since(time.Unix(createdAt, 0))
|
|
if chunkAge > 24*time.Hour {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found orphaned ID token chunk %d (age: %v)", i, chunkAge)
|
|
}
|
|
} else if chunk != nil {
|
|
orphanedChunks++
|
|
sd.manager.logger.Debugf("Found ID token chunk %d without timestamp, treating as orphaned", i)
|
|
}
|
|
}
|
|
|
|
session.Options.MaxAge = -1
|
|
session.Values = make(map[interface{}]interface{})
|
|
if w != nil {
|
|
if err := session.Save(sd.request, w); err != nil {
|
|
sd.manager.logger.Errorf("failed to save expired ID token chunk %d: %v", i, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if orphanedChunks > 0 {
|
|
sd.manager.logger.Infof("Cleaned up %d orphaned ID token chunks", orphanedChunks)
|
|
}
|
|
}
|
|
|
|
// splitIntoChunks divides a string into chunks of specified maximum size.
|
|
// It ensures chunks don't exceed browser cookie limits and handles
|
|
// the string splitting logic for large token storage.
|
|
// Parameters:
|
|
// - s: The string to split into chunks.
|
|
// - chunkSize: The maximum size for each chunk.
|
|
//
|
|
// Returns:
|
|
// - A slice of strings representing the chunks.
|
|
func splitIntoChunks(s string, chunkSize int) []string {
|
|
effectiveChunkSize := min(chunkSize, maxCookieSize)
|
|
|
|
var chunks []string
|
|
for len(s) > 0 {
|
|
if len(s) > effectiveChunkSize {
|
|
chunks = append(chunks, s[:effectiveChunkSize])
|
|
s = s[effectiveChunkSize:]
|
|
} else {
|
|
chunks = append(chunks, s)
|
|
break
|
|
}
|
|
}
|
|
return chunks
|
|
}
|
|
|
|
// validateChunkSize checks if a chunk will fit within browser cookie limits.
|
|
// It estimates the encoded size including cookie overhead and headers
|
|
// to ensure the chunk won't exceed browser-imposed cookie size limits.
|
|
// Parameters:
|
|
// - chunkData: The chunk data to validate.
|
|
//
|
|
// Returns:
|
|
// - true if the chunk is safe to store, false if it may exceed browser limits.
|
|
func validateChunkSize(chunkData string) bool {
|
|
estimatedEncodedSize := len(chunkData) + (len(chunkData) * 50 / 100)
|
|
return estimatedEncodedSize <= maxBrowserCookieSize
|
|
}
|
|
|
|
// isCorruptionMarker detects if data contains known corruption indicators.
|
|
// It checks for specific corruption markers and invalid characters
|
|
// that indicate the data has been tampered with or corrupted.
|
|
// Parameters:
|
|
// - data: The data string to check for corruption markers.
|
|
//
|
|
// Returns:
|
|
// - true if the data contains corruption markers, false otherwise.
|
|
func isCorruptionMarker(data string) bool {
|
|
if data == "" {
|
|
return false
|
|
}
|
|
|
|
corruptionMarkers := []string{
|
|
"__CORRUPTION_MARKER_TEST__",
|
|
"__INVALID_BASE64_DATA__",
|
|
"__CORRUPTED_CHUNK_DATA__",
|
|
"!@#$%^&*()",
|
|
"<<<CORRUPTED>>>",
|
|
}
|
|
|
|
for _, marker := range corruptionMarkers {
|
|
if data == marker {
|
|
return true
|
|
}
|
|
}
|
|
|
|
if len(data) > 10 {
|
|
invalidChars := "!@#$%^&*(){}[]|\\:;\"'<>?,`~"
|
|
for _, char := range invalidChars {
|
|
if strings.ContainsRune(data, char) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// GetCSRF retrieves the CSRF token for state validation.
|
|
// This token is used to prevent cross-site request forgery attacks
|
|
// during the OIDC authentication flow.
|
|
// Returns:
|
|
// - The CSRF token string, or an empty string if not set.
|
|
func (sd *SessionData) GetCSRF() string {
|
|
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
|
return csrf
|
|
}
|
|
|
|
// SetCSRF stores the CSRF token for state validation.
|
|
// The token is used to validate the state parameter in OAuth callbacks.
|
|
// Parameters:
|
|
// - token: The CSRF token to store.
|
|
func (sd *SessionData) SetCSRF(token string) {
|
|
currentVal, _ := sd.mainSession.Values["csrf"].(string)
|
|
if currentVal != token {
|
|
sd.mainSession.Values["csrf"] = token
|
|
sd.dirty = true
|
|
}
|
|
}
|
|
|
|
// GetNonce retrieves the nonce for ID token validation.
|
|
// The nonce prevents replay attacks by ensuring ID tokens
|
|
// were issued in response to the specific authentication request.
|
|
// Returns:
|
|
// - The nonce string, or an empty string if not set.
|
|
func (sd *SessionData) GetNonce() string {
|
|
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
|
return nonce
|
|
}
|
|
|
|
// SetNonce stores the nonce for ID token validation.
|
|
// The nonce will be validated against the nonce claim in received ID tokens.
|
|
// Parameters:
|
|
// - nonce: The nonce string to store.
|
|
func (sd *SessionData) SetNonce(nonce string) {
|
|
currentVal, _ := sd.mainSession.Values["nonce"].(string)
|
|
if currentVal != nonce {
|
|
sd.mainSession.Values["nonce"] = nonce
|
|
sd.dirty = true
|
|
}
|
|
}
|
|
|
|
// GetCodeVerifier retrieves the PKCE code verifier.
|
|
// This is used in the PKCE (Proof Key for Code Exchange) flow
|
|
// to enhance security for public clients.
|
|
// Returns:
|
|
// - The code verifier string, or an empty string if not set or PKCE is disabled.
|
|
func (sd *SessionData) GetCodeVerifier() string {
|
|
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
|
|
return codeVerifier
|
|
}
|
|
|
|
// SetCodeVerifier stores the PKCE code verifier.
|
|
// The code verifier is used to generate the code challenge sent to the
|
|
// authorization server and validated during token exchange.
|
|
// Parameters:
|
|
// - codeVerifier: The PKCE code verifier string to store.
|
|
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
|
currentVal, _ := sd.mainSession.Values["code_verifier"].(string)
|
|
if currentVal != codeVerifier {
|
|
sd.mainSession.Values["code_verifier"] = codeVerifier
|
|
sd.dirty = true
|
|
}
|
|
}
|
|
|
|
// GetEmail retrieves the authenticated user's email address.
|
|
// The email is extracted from ID token claims and used for
|
|
// authorization decisions and header injection.
|
|
// Returns:
|
|
// - The user's email address string, or an empty string if not set.
|
|
func (sd *SessionData) GetEmail() string {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
email, _ := sd.mainSession.Values["email"].(string)
|
|
return email
|
|
}
|
|
|
|
// SetEmail stores the authenticated user's email address.
|
|
// The email is typically extracted from the 'email' claim in the ID token.
|
|
// Parameters:
|
|
// - email: The user's email address to store.
|
|
func (sd *SessionData) SetEmail(email string) {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
currentVal, _ := sd.mainSession.Values["email"].(string)
|
|
if currentVal != email {
|
|
sd.mainSession.Values["email"] = email
|
|
sd.dirty = true
|
|
}
|
|
}
|
|
|
|
// GetIncomingPath retrieves the original request URI that triggered authentication.
|
|
// This path is used to redirect the user back to their intended destination
|
|
// after successful authentication.
|
|
// Returns:
|
|
// - The original request URI string, or an empty string if not set.
|
|
func (sd *SessionData) GetIncomingPath() string {
|
|
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
|
return path
|
|
}
|
|
|
|
// SetIncomingPath stores the original request URI for post-authentication redirect.
|
|
// This allows the user to be redirected to their originally requested resource
|
|
// after completing the authentication flow.
|
|
// Parameters:
|
|
// - path: The original request URI string (e.g., "/protected/resource?id=123").
|
|
func (sd *SessionData) SetIncomingPath(path string) {
|
|
currentVal, _ := sd.mainSession.Values["incoming_path"].(string)
|
|
if currentVal != path {
|
|
sd.mainSession.Values["incoming_path"] = path
|
|
sd.dirty = true
|
|
}
|
|
}
|
|
|
|
// GetIDToken retrieves the user's ID token from session storage.
|
|
// The ID token contains user claims and is used for user identification
|
|
// and authorization decisions. Handles compression and chunking automatically.
|
|
// Returns:
|
|
// - The complete, decompressed ID token string, or an empty string if not found.
|
|
func (sd *SessionData) GetIDToken() string {
|
|
sd.sessionMutex.RLock()
|
|
defer sd.sessionMutex.RUnlock()
|
|
|
|
return sd.getIDTokenUnsafe()
|
|
}
|
|
|
|
// getIDTokenUnsafe retrieves the ID token without acquiring locks.
|
|
// Enhanced ID token retrieval with comprehensive integrity checks and chunking support.
|
|
// Used when the session mutex is already held to prevent deadlocks.
|
|
// Returns:
|
|
// - The complete ID token string or empty string on error.
|
|
func (sd *SessionData) getIDTokenUnsafe() string {
|
|
token, _ := sd.idTokenSession.Values["token"].(string)
|
|
compressed, _ := sd.idTokenSession.Values["compressed"].(bool)
|
|
|
|
// Debug: Check if manager/chunkManager is nil
|
|
if sd.manager == nil || sd.manager.chunkManager == nil {
|
|
// Direct return if no chunk manager (test scenario)
|
|
return token
|
|
}
|
|
|
|
result := sd.manager.chunkManager.GetToken(
|
|
token,
|
|
compressed,
|
|
sd.idTokenChunks,
|
|
IDTokenConfig,
|
|
)
|
|
|
|
if result.Error != nil {
|
|
return ""
|
|
}
|
|
|
|
return result.Token
|
|
}
|
|
|
|
// SetIDToken stores an ID token with automatic compression and chunking.
|
|
// It validates the JWT format, compresses if beneficial, and splits into chunks
|
|
// if the token exceeds cookie size limits. Includes comprehensive validation.
|
|
// Parameters:
|
|
// - token: The ID token string to store.
|
|
func (sd *SessionData) SetIDToken(token string) {
|
|
sd.sessionMutex.Lock()
|
|
defer sd.sessionMutex.Unlock()
|
|
|
|
if token != "" {
|
|
dotCount := strings.Count(token, ".")
|
|
if dotCount != 2 {
|
|
sd.manager.logger.Errorf("CRITICAL: Attempt to store invalid JWT ID token format (dots: %d) - rejecting", dotCount)
|
|
return
|
|
}
|
|
}
|
|
|
|
if len(token) > 50*1024 {
|
|
sd.manager.logger.Errorf("CRITICAL: ID token too large (%d bytes) - possible corruption, rejecting", len(token))
|
|
return
|
|
}
|
|
currentIDToken := sd.getIDTokenUnsafe()
|
|
if constantTimeStringCompare(currentIDToken, token) {
|
|
return
|
|
}
|
|
sd.dirty = true
|
|
|
|
if sd.request != nil {
|
|
sd.expireIDTokenChunksEnhanced(nil)
|
|
}
|
|
|
|
for k := range sd.idTokenChunks {
|
|
delete(sd.idTokenChunks, k)
|
|
}
|
|
|
|
if token == "" {
|
|
if sd.idTokenSession != nil {
|
|
sd.idTokenSession.Values["token"] = ""
|
|
sd.idTokenSession.Values["compressed"] = false
|
|
}
|
|
return
|
|
}
|
|
|
|
compressed := compressToken(token)
|
|
|
|
if compressed != token {
|
|
testDecompressed := decompressToken(compressed)
|
|
if testDecompressed != token {
|
|
sd.manager.logger.Errorf("CRITICAL: ID token compression verification failed - storing uncompressed")
|
|
compressed = token
|
|
}
|
|
}
|
|
|
|
if len(compressed) <= maxCookieSize {
|
|
if sd.idTokenSession != nil {
|
|
sd.idTokenSession.Values["token"] = compressed
|
|
sd.idTokenSession.Values["compressed"] = (compressed != token)
|
|
}
|
|
} else {
|
|
if sd.idTokenSession != nil {
|
|
sd.idTokenSession.Values["token"] = ""
|
|
sd.idTokenSession.Values["compressed"] = (compressed != token)
|
|
}
|
|
|
|
chunks := splitIntoChunks(compressed, maxCookieSize)
|
|
|
|
if len(chunks) == 0 {
|
|
sd.manager.logger.Errorf("CRITICAL: Failed to create chunks for ID token")
|
|
return
|
|
}
|
|
|
|
if len(chunks) > 50 {
|
|
sd.manager.logger.Errorf("CRITICAL: Too many chunks (%d) for ID token - possible corruption", len(chunks))
|
|
return
|
|
}
|
|
|
|
testReassembled := strings.Join(chunks, "")
|
|
if testReassembled != compressed {
|
|
sd.manager.logger.Errorf("CRITICAL: ID token chunk reassembly test failed")
|
|
return
|
|
}
|
|
|
|
for i, chunkData := range chunks {
|
|
sessionName := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), i)
|
|
|
|
if sd.request == nil {
|
|
sd.manager.logger.Errorf("CRITICAL: SetIDToken: sd.request is nil, cannot create chunk session %s", sessionName)
|
|
return
|
|
}
|
|
|
|
if chunkData == "" {
|
|
sd.manager.logger.Debug("Empty chunk data at index %d", i)
|
|
return
|
|
}
|
|
|
|
if len(chunkData) > maxCookieSize {
|
|
sd.manager.logger.Info("Chunk %d size %d exceeds maxCookieSize %d", i, len(chunkData), maxCookieSize)
|
|
return
|
|
}
|
|
|
|
if !validateChunkSize(chunkData) {
|
|
sd.manager.logger.Errorf("CRITICAL: ID token chunk %d will exceed browser cookie limits after encoding (raw size: %d)", i, len(chunkData))
|
|
return
|
|
}
|
|
|
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
if err != nil {
|
|
sd.manager.logger.Errorf("CRITICAL: Failed to get chunk session %s: %v", sessionName, err)
|
|
return
|
|
}
|
|
|
|
session.Values["token_chunk"] = chunkData
|
|
session.Values["compressed"] = (compressed != token)
|
|
session.Values["chunk_created_at"] = time.Now().Unix()
|
|
sd.idTokenChunks[i] = session
|
|
}
|
|
|
|
sd.manager.logger.Debugf("SUCCESS: Stored ID token in %d chunks", len(chunks))
|
|
}
|
|
}
|
|
|
|
// GetRedirectCount returns the number of redirects in the current authentication flow.
|
|
// STABILITY FIX: Prevents infinite redirect loops by tracking redirect attempts.
|
|
// Returns:
|
|
// - The current redirect count, 0 if not set.
|
|
func (sd *SessionData) GetRedirectCount() int {
|
|
if count, ok := sd.mainSession.Values["redirect_count"].(int); ok {
|
|
return count
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// IncrementRedirectCount increases the redirect counter by one.
|
|
// STABILITY FIX: Prevents infinite redirect loops by tracking successive redirects.
|
|
// Used to detect potential redirect loops and abort authentication if too many occur.
|
|
func (sd *SessionData) IncrementRedirectCount() {
|
|
currentCount := sd.GetRedirectCount()
|
|
sd.mainSession.Values["redirect_count"] = currentCount + 1
|
|
sd.dirty = true
|
|
}
|
|
|
|
// ResetRedirectCount resets the redirect counter to zero.
|
|
// STABILITY FIX: Prevents infinite redirect loops by clearing the counter
|
|
// when authentication completes successfully or when starting a new flow.
|
|
func (sd *SessionData) ResetRedirectCount() {
|
|
sd.mainSession.Values["redirect_count"] = 0
|
|
sd.dirty = true
|
|
}
|