From 751933ffa0cdc19fc7b1acaaf6943c5fd9c5270e Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sat, 1 Feb 2025 12:16:50 +0000 Subject: [PATCH] Multiple improvements. * Add todo list. * fixup! Add todo list. * fixup! fixup! Add todo list. * fixup! fixup! fixup! Add todo list. * Improve the session handling and cache. * Fix an issue where expired session can cause infinite redirect loop * fixup! Fix an issue where expired session can cause infinite redirect loop * Add semver setup for automatic releases. * fixup! Add semver setup for automatic releases. * fixup! fixup! Add semver setup for automatic releases. * fixup! fixup! fixup! Add semver setup for automatic releases. --- TODO.txt | 4 + cache.go | 179 ++++++++++++++++------------------- helpers.go | 20 ---- main.go | 19 ++-- main_test.go | 8 +- semver.yaml | 10 ++ session.go | 241 ++++++++++++++++++++---------------------------- session_test.go | 10 +- 8 files changed, 207 insertions(+), 284 deletions(-) create mode 100644 semver.yaml diff --git a/TODO.txt b/TODO.txt index e69de29..07038e0 100644 --- a/TODO.txt +++ b/TODO.txt @@ -0,0 +1,4 @@ +## TODO / wishlist + +- [x] Improve caching mechanism +- [x] Add automatic release and semver generation \ No newline at end of file diff --git a/cache.go b/cache.go index 73794eb..2ae81e3 100644 --- a/cache.go +++ b/cache.go @@ -1,172 +1,153 @@ package traefikoidc import ( + "container/list" "sync" "time" ) // CacheItem represents an item stored in the cache with its associated metadata. type CacheItem struct { - // Value is the cached data of any type + // Value is the cached data of any type. Value interface{} - // ExpiresAt is the timestamp when this item should be considered expired - // and removed from the cache during cleanup operations + // ExpiresAt is the timestamp when this item should be considered expired. ExpiresAt time.Time } -// Cache provides a thread-safe in-memory caching mechanism with expiration support. -// It uses a read-write mutex to ensure safe concurrent access to the cached items. -type Cache struct { - // items stores the cached data with string keys - items map[string]CacheItem - - // mutex protects concurrent access to the items map - // Use RLock/RUnlock for reads and Lock/Unlock for writes - mutex sync.RWMutex - - // maxSize is the maximum number of items allowed in the cache - maxSize int - - // accessList maintains the order of item access for eviction - accessList []string +// lruEntry represents an entry in the LRU list. +type lruEntry struct { + key string } -// DefaultMaxSize is the default maximum number of items in the cache +// Cache provides a thread-safe in-memory caching mechanism with expiration support. +// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency. +type Cache struct { + // items stores the cached data with string keys. + items map[string]CacheItem + + // order maintains the usage order; most recently used items are at the back. + order *list.List + + // elems maps keys to their corresponding list elements for O(1) access. + elems map[string]*list.Element + + // mutex protects concurrent access to the cache. + mutex sync.RWMutex + + // maxSize is the maximum number of items allowed in the cache. + maxSize int +} + +// DefaultMaxSize is the default maximum number of items in the cache. const DefaultMaxSize = 1000 -// NewCache creates a new empty cache instance. -// The cache is immediately ready for use and is thread-safe. +// NewCache creates a new empty cache instance that is ready for use. func NewCache() *Cache { return &Cache{ - items: make(map[string]CacheItem), - maxSize: DefaultMaxSize, - accessList: make([]string, 0, DefaultMaxSize), + items: make(map[string]CacheItem, DefaultMaxSize), + order: list.New(), + elems: make(map[string]*list.Element, DefaultMaxSize), + maxSize: DefaultMaxSize, } } // Set adds or updates an item in the cache with the specified expiration duration. -// Parameters: -// - key: Unique identifier for the cached item -// - value: The data to cache (can be of any type) -// - expiration: How long the item should remain in the cache -// -// Thread-safe: Uses write locking to ensure safe concurrent access. +// It moves the item to the most recently used position. func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { c.mutex.Lock() defer c.mutex.Unlock() - // If key exists, update it + now := time.Now() + expTime := now.Add(expiration) + + // Update existing item. if _, exists := c.items[key]; exists { c.items[key] = CacheItem{ Value: value, - ExpiresAt: time.Now().Add(expiration), + ExpiresAt: expTime, + } + if elem, ok := c.elems[key]; ok { + c.order.MoveToBack(elem) } return } - // If cache is full, remove oldest item + // Evict oldest item if cache is full. if len(c.items) >= c.maxSize { c.evictOldest() } - // Add new item + // Add new item. c.items[key] = CacheItem{ Value: value, - ExpiresAt: time.Now().Add(expiration), + ExpiresAt: expTime, } - c.accessList = append(c.accessList, key) + elem := c.order.PushBack(lruEntry{key: key}) + c.elems[key] = elem } // Get retrieves an item from the cache if it exists and hasn't expired. -// Parameters: -// - key: The identifier of the item to retrieve -// -// Returns: -// - value: The cached data (nil if not found or expired) -// - found: true if the item was found and is valid, false otherwise -// -// Thread-safe: Uses read locking to ensure safe concurrent access. +// Moving the accessed item to the most recently used position. func (c *Cache) Get(key string) (interface{}, bool) { - c.mutex.RLock() - item, found := c.items[key] - c.mutex.RUnlock() - - if !found { - return nil, false - } - - if time.Now().After(item.ExpiresAt) { - c.mutex.Lock() - c.removeItem(key) - c.mutex.Unlock() - return nil, false - } - - // Update access order c.mutex.Lock() - c.updateAccessOrder(key) - c.mutex.Unlock() + defer c.mutex.Unlock() + + item, exists := c.items[key] + if !exists { + return nil, false + } + + // Check for expiration. + if time.Now().After(item.ExpiresAt) { + c.removeItem(key) + return nil, false + } + + // Move item to the back (most recently used). + if elem, ok := c.elems[key]; ok { + c.order.MoveToBack(elem) + } return item.Value, true } -// Delete removes an item from the cache if it exists. -// If the item doesn't exist, this operation is a no-op. -// Thread-safe: Uses write locking to ensure safe concurrent access. +// Delete removes an item from the cache. func (c *Cache) Delete(key string) { c.mutex.Lock() defer c.mutex.Unlock() - delete(c.items, key) + + c.removeItem(key) } -// Cleanup removes all expired items from the cache. -// This should be called periodically to prevent memory leaks from -// expired items that haven't been accessed (and thus not removed during Get operations). -// Thread-safe: Uses write locking to ensure safe concurrent access. +// Cleanup removes all expired items from the cache. This should be called periodically +// to prevent memory bloat from expired entries. func (c *Cache) Cleanup() { c.mutex.Lock() defer c.mutex.Unlock() now := time.Now() - var newAccessList []string - - for _, key := range c.accessList { - if item, exists := c.items[key]; exists && !now.After(item.ExpiresAt) { - newAccessList = append(newAccessList, key) - } else { - delete(c.items, key) + for key, item := range c.items { + if now.After(item.ExpiresAt) { + c.removeItem(key) } } - - c.accessList = newAccessList } -// evictOldest removes the least recently used item from the cache +// evictOldest removes the least recently used item from the cache. func (c *Cache) evictOldest() { - if len(c.accessList) > 0 { - oldest := c.accessList[0] - c.removeItem(oldest) + elem := c.order.Front() + if elem != nil { + entry := elem.Value.(lruEntry) + c.removeItem(entry.key) } } -// removeItem removes an item from both the cache and access list +// removeItem removes an item from both the cache and the LRU tracking structures. func (c *Cache) removeItem(key string) { delete(c.items, key) - for i, k := range c.accessList { - if k == key { - c.accessList = append(c.accessList[:i], c.accessList[i+1:]...) - break - } - } -} - -// updateAccessOrder moves the accessed key to the end of the access list -func (c *Cache) updateAccessOrder(key string) { - for i, k := range c.accessList { - if k == key { - c.accessList = append(append(c.accessList[:i], c.accessList[i+1:]...), key) - break - } + if elem, ok := c.elems[key]; ok { + c.order.Remove(elem) + delete(c.elems, key) } } diff --git a/helpers.go b/helpers.go index d66294a..4475d3e 100644 --- a/helpers.go +++ b/helpers.go @@ -12,28 +12,8 @@ import ( "strings" "sync" "time" - - "github.com/gorilla/sessions" ) -// newSessionOptions creates secure session cookie options. -// Parameters: -// - isSecure: Whether to set the Secure flag on cookies -// -// Returns session options configured for security with: -// - HttpOnly flag to prevent JavaScript access -// - SameSite=Lax for CSRF protection -// - Appropriate timeout and path settings -func newSessionOptions(isSecure bool) *sessions.Options { - return &sessions.Options{ - HttpOnly: true, - Secure: isSecure, - SameSite: http.SameSiteLaxMode, - MaxAge: ConstSessionTimeout, - Path: "/", - } -} - // generateNonce creates a cryptographically secure random nonce // for use in the OIDC authentication flow. The nonce is used to // prevent replay attacks by ensuring the token received matches diff --git a/main.go b/main.go index 9d5fbff..8ee3a25 100644 --- a/main.go +++ b/main.go @@ -62,7 +62,6 @@ type TraefikOidc struct { extractClaimsFunc func(tokenString string) (map[string]interface{}, error) initComplete chan struct{} endSessionURL string - baseURL string postLogoutRedirectURI string sessionManager *SessionManager } @@ -82,8 +81,6 @@ var defaultExcludedURLs = map[string]struct{}{ "/favicon": {}, } -var newTicker = time.NewTicker - // VerifyToken verifies the provided JWT token func (t *TraefikOidc) VerifyToken(token string) error { t.logger.Debugf("Verifying token") @@ -264,7 +261,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h // Assign the initialized logger t.logger = logger - t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) + t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) t.extractClaimsFunc = extractClaims t.exchangeCodeForTokenFunc = t.exchangeCodeForToken t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { @@ -531,9 +528,6 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool { // determineScheme determines the scheme (http or https) of the request func (t *TraefikOidc) determineScheme(req *http.Request) string { - if t.forceHTTPS { - return "https" - } if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" { return scheme } @@ -602,14 +596,17 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo // defaultInitiateAuthentication initiates the authentication process func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { // Generate CSRF token and nonce - csrfToken := uuid.New().String() + csrfToken := uuid.NewString() nonce, err := generateNonce() if err != nil { http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError) return } - // Set session values + // Clear any existing session data to avoid stale state causing redirect loops + session.Clear(req, rw) + + // Set new session values session.SetCSRF(csrfToken) session.SetNonce(nonce) session.SetIncomingPath(req.URL.RequestURI()) @@ -621,7 +618,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req return } - // Build and redirect to auth URL + // Build and redirect to authentication URL authURL := t.buildAuthURL(redirectURL, csrfToken, nonce) http.Redirect(rw, req, authURL, http.StatusFound) } @@ -647,7 +644,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { // startTokenCleanup starts the token cleanup goroutine func (t *TraefikOidc) startTokenCleanup() { - ticker := newTicker(1 * time.Minute) + ticker := time.NewTicker(1 * time.Minute) go func() { for range ticker.C { t.logger.Debug("Cleaning up token cache") diff --git a/main_test.go b/main_test.go index 03be49b..9427e91 100644 --- a/main_test.go +++ b/main_test.go @@ -89,7 +89,7 @@ func (ts *TestSuite) Setup() { } logger := NewLogger("info") - ts.sessionManager = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) + ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) // Common TraefikOidc instance ts.tOidc = &TraefikOidc{ @@ -619,7 +619,7 @@ func TestHandleCallback(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { logger := NewLogger("info") - sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) + sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) // Create a new instance for each test to avoid state carryover tOidc := &TraefikOidc{ @@ -924,7 +924,7 @@ func TestHandleLogout(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { logger := NewLogger("info") - sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) + sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) tOidc := &TraefikOidc{ revocationURL: mockRevocationServer.URL, endSessionURL: tc.endSessionURL, @@ -1213,7 +1213,7 @@ func TestHandleExpiredToken(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { logger := NewLogger("info") - sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) + sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) tOidc := &TraefikOidc{ sessionManager: sessionManager, diff --git a/semver.yaml b/semver.yaml new file mode 100644 index 0000000..0e21b52 --- /dev/null +++ b/semver.yaml @@ -0,0 +1,10 @@ +version: 1 +force: + existing: true +wording: + patch: + - patch-release + minor: + - minor-release + major: + - breaking diff --git a/session.go b/session.go index e1abc61..1c4b728 100644 --- a/session.go +++ b/session.go @@ -16,13 +16,14 @@ import ( "github.com/gorilla/sessions" ) -// generateSecureRandomString creates a cryptographically secure random string of specified length -func generateSecureRandomString(length int) string { +// generateSecureRandomString creates a cryptographically secure random string of specified length. +// It returns the generated string or an error if random generation fails. +func generateSecureRandomString(length int) (string, error) { bytes := make([]byte, length) if _, err := rand.Read(bytes); err != nil { - panic("failed to generate random string") + return "", fmt.Errorf("failed to generate random bytes: %w", err) } - return hex.EncodeToString(bytes) + return hex.EncodeToString(bytes), nil } // Cookie names and configuration constants used for session management @@ -55,7 +56,7 @@ const ( minEncryptionKeyLength = 32 ) -// compressToken compresses a token using gzip and base64 encodes it +// compressToken compresses a token using gzip and base64 encodes it. func compressToken(token string) string { var b bytes.Buffer gz := gzip.NewWriter(&b) @@ -68,7 +69,7 @@ func compressToken(token string) string { return base64.StdEncoding.EncodeToString(b.Bytes()) } -// decompressToken decompresses a base64 encoded gzipped token +// decompressToken decompresses a base64 encoded gzipped token. func decompressToken(compressed string) string { data, err := base64.StdEncoding.DecodeString(compressed) if err != nil { @@ -91,18 +92,18 @@ func decompressToken(compressed string) string { // SessionManager handles the management of multiple session cookies for OIDC authentication. // It provides functionality for storing and retrieving authentication state, tokens, -// and other session-related data across multiple cookies to handle large tokens. +// and other session-related data across multiple cookies. type SessionManager struct { - // store is the underlying session store for cookie management + // store is the underlying session store for cookie management. store sessions.Store - // forceHTTPS enforces secure cookie attributes regardless of request scheme + // forceHTTPS enforces secure cookie attributes regardless of request scheme. forceHTTPS bool - // logger provides structured logging capabilities + // logger provides structured logging capabilities. logger *Logger - // sessionPool is a sync.Pool for reusing SessionData objects + // sessionPool is a sync.Pool for reusing SessionData objects. sessionPool sync.Pool } @@ -112,11 +113,11 @@ type SessionManager struct { // - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme // - logger: Logger instance for recording session-related events // -// The manager handles session creation, storage, and cookie security settings. -func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager { - // Validate encryption key length +// Returns an error if the encryption key does not meet minimum length requirements. +func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*SessionManager, error) { + // Validate encryption key length. if len(encryptionKey) < minEncryptionKeyLength { - panic(fmt.Sprintf("encryption key must be at least %d bytes long", minEncryptionKeyLength)) + return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength) } sm := &SessionManager{ @@ -125,7 +126,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S logger: logger, } - // Initialize session pool + // Initialize session pool. sm.sessionPool.New = func() interface{} { return &SessionData{ manager: sm, @@ -134,12 +135,12 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S } } - return sm + return sm, nil } // getSessionOptions returns secure session options configured for the current request. // Parameters: -// - isSecure: Whether the current request is using HTTPS +// - isSecure: Whether the current request is using HTTPS. // // The options ensure cookies are: // - HTTP-only (not accessible via JavaScript) @@ -161,7 +162,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options { // and combines them into a single SessionData structure for easy access. // Returns an error if any session component cannot be loaded. func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { - // Get session from pool + // Get session from pool. sessionData := sm.sessionPool.Get().(*SessionData) sessionData.request = r @@ -172,11 +173,10 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { return nil, fmt.Errorf("failed to get main session: %w", err) } - // Check for absolute session timeout + // Check for absolute session timeout. if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok { if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout { - // Session has expired - sm.sessionPool.Put(sessionData) + sessionData.Clear(r, nil) return nil, fmt.Errorf("session expired") } } @@ -193,7 +193,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { return nil, fmt.Errorf("failed to get refresh token session: %w", err) } - // Clear and reuse chunk maps + // Clear and reuse chunk maps. for k := range sessionData.accessTokenChunks { delete(sessionData.accessTokenChunks, k) } @@ -201,7 +201,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { delete(sessionData.refreshTokenChunks, k) } - // Retrieve chunked token sessions + // Retrieve chunked token sessions. sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks) sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks) @@ -218,7 +218,6 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string sessionName := fmt.Sprintf("%s_%d", baseName, i) session, err := sm.store.Get(r, sessionName) if err != nil || session.IsNew { - // No more sessions break } chunks[i] = session @@ -230,27 +229,27 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string // and potentially large access and refresh tokens that may need to be // split across multiple cookies due to browser size limitations. type SessionData struct { - // manager is the SessionManager that created this SessionData + // manager is the SessionManager that created this SessionData. manager *SessionManager - // request is the current HTTP request associated with this session + // request is the current HTTP request associated with this session. request *http.Request - // mainSession stores authentication state and basic user info + // mainSession stores authentication state and basic user info. mainSession *sessions.Session - // accessSession stores the primary access token cookie + // accessSession stores the primary access token cookie. accessSession *sessions.Session - // refreshSession stores the primary refresh token cookie + // refreshSession stores the primary refresh token cookie. refreshSession *sessions.Session // accessTokenChunks stores additional chunks of the access token - // when it exceeds the maximum cookie size + // when it exceeds the maximum cookie size. accessTokenChunks map[int]*sessions.Session // refreshTokenChunks stores additional chunks of the refresh token - // when it exceeds the maximum cookie size + // when it exceeds the maximum cookie size. refreshTokenChunks map[int]*sessions.Session } @@ -261,28 +260,28 @@ type SessionData struct { func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS - // Set options for all sessions + // Set options for all sessions. options := sd.manager.getSessionOptions(isSecure) sd.mainSession.Options = options sd.accessSession.Options = options sd.refreshSession.Options = options - // Save main session + // Save main session. if err := sd.mainSession.Save(r, w); err != nil { return fmt.Errorf("failed to save main session: %w", err) } - // Save access token session + // Save access token session. if err := sd.accessSession.Save(r, w); err != nil { return fmt.Errorf("failed to save access token session: %w", err) } - // Save refresh token session + // Save refresh token session. if err := sd.refreshSession.Save(r, w); err != nil { return fmt.Errorf("failed to save refresh token session: %w", err) } - // Save access token chunks + // Save access token chunks. for _, session := range sd.accessTokenChunks { session.Options = options if err := session.Save(r, w); err != nil { @@ -290,7 +289,7 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { } } - // Save refresh token chunks + // Save refresh token chunks. for _, session := range sd.refreshTokenChunks { session.Options = options if err := session.Save(r, w); err != nil { @@ -302,10 +301,8 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { } // Clear removes all session data by expiring all cookies and clearing their values. -// This is typically used during logout to ensure all session data is properly cleaned up. -// It handles both main session data and any token chunks that may exist. func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { - // Clear and expire all sessions + // Clear and expire all sessions. sd.mainSession.Options.MaxAge = -1 sd.accessSession.Options.MaxAge = -1 sd.refreshSession.Options.MaxAge = -1 @@ -320,7 +317,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { delete(sd.refreshSession.Values, k) } - // Clear chunk sessions + // Clear chunk sessions. sd.clearTokenChunks(r, sd.accessTokenChunks) sd.clearTokenChunks(r, sd.refreshTokenChunks) @@ -329,15 +326,13 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { err = sd.Save(r, w) } - // Return session to pool + // Return session to pool. sd.manager.sessionPool.Put(sd) return err } // clearTokenChunks removes all session chunks for a given token type. -// It expires the cookies and removes all stored values to ensure -// no token data remains after logout or token invalidation. func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) { for _, session := range chunks { session.Options.MaxAge = -1 @@ -348,15 +343,13 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session } // GetAuthenticated returns whether the current session is authenticated. -// Returns true if the user has successfully completed OIDC authentication -// and the session hasn't expired, false otherwise. func (sd *SessionData) GetAuthenticated() bool { auth, _ := sd.mainSession.Values["authenticated"].(bool) if !auth { return false } - // Check session expiration + // Check session expiration. createdAt, ok := sd.mainSession.Values["created_at"].(int64) if !ok { return false @@ -365,21 +358,21 @@ func (sd *SessionData) GetAuthenticated() bool { } // SetAuthenticated updates the session's authentication status and rotates session ID. -// This should be called after successful OIDC authentication or during logout. -// Session ID rotation helps prevent session fixation attacks. -func (sd *SessionData) SetAuthenticated(value bool) { +// Returns an error if generating a new session ID fails. +func (sd *SessionData) SetAuthenticated(value bool) error { if value { - // Generate new session ID and set creation time - sd.mainSession.ID = generateSecureRandomString(32) + id, err := generateSecureRandomString(32) + if err != nil { + return fmt.Errorf("failed to generate secure session id: %w", err) + } + sd.mainSession.ID = id sd.mainSession.Values["created_at"] = time.Now().Unix() } sd.mainSession.Values["authenticated"] = value + return nil } // GetAccessToken retrieves the complete access token from the session. -// If the token was split into chunks due to size limitations, it will -// automatically reassemble the complete token from all chunks. -// Returns an empty string if no token is found. func (sd *SessionData) GetAccessToken() string { token, _ := sd.accessSession.Values["token"].(string) if token != "" { @@ -390,7 +383,7 @@ func (sd *SessionData) GetAccessToken() string { return token } - // Reassemble token from chunks + // Reassemble token from chunks. if len(sd.accessTokenChunks) == 0 { return "" } @@ -414,45 +407,23 @@ func (sd *SessionData) GetAccessToken() string { } // SetAccessToken stores the access token in the session. -// If the token exceeds maxCookieSize, it is automatically compressed and split into -// multiple cookie chunks to handle large tokens while staying within -// browser cookie size limits. Any existing token or chunks are cleared -// before setting the new token. -// expireAccessTokenChunks expires any existing access token chunk cookies -func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) { - for i := 0; ; i++ { - sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i) - session, err := sd.manager.store.Get(sd.request, sessionName) - if err != nil || session.IsNew { - break - } - // Expire the cookie - session.Options.MaxAge = -1 - session.Values = make(map[interface{}]interface{}) - // Save expired cookie - if err := session.Save(sd.request, w); err != nil { - sd.manager.logger.Errorf("Failed to save expired cookie: %v", err) - } - } -} - func (sd *SessionData) SetAccessToken(token string) { - // Expire any existing chunk cookies first + // Expire any existing chunk cookies first. if sd.request != nil { - sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called + sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called. } - // Clear and prepare chunks map for new token + // Clear and prepare chunks map for new token. sd.accessTokenChunks = make(map[int]*sessions.Session) - // Compress token + // Compress token. compressed := compressToken(token) if len(compressed) <= maxCookieSize { sd.accessSession.Values["token"] = compressed sd.accessSession.Values["compressed"] = true } else { - // Split compressed token into chunks + // Split compressed token into chunks. sd.accessSession.Values["token"] = "" sd.accessSession.Values["compressed"] = true chunks := splitIntoChunks(compressed, maxCookieSize) @@ -466,9 +437,6 @@ func (sd *SessionData) SetAccessToken(token string) { } // GetRefreshToken retrieves the complete refresh token from the session. -// If the token was split into chunks due to size limitations, it will -// automatically reassemble the complete token from all chunks. -// Returns an empty string if no token is found. func (sd *SessionData) GetRefreshToken() string { token, _ := sd.refreshSession.Values["token"].(string) if token != "" { @@ -479,7 +447,7 @@ func (sd *SessionData) GetRefreshToken() string { return token } - // Reassemble token from chunks + // Reassemble token from chunks. if len(sd.refreshTokenChunks) == 0 { return "" } @@ -503,45 +471,23 @@ func (sd *SessionData) GetRefreshToken() string { } // SetRefreshToken stores the refresh token in the session. -// If the token exceeds maxCookieSize, it is automatically compressed and split into -// multiple cookie chunks to handle large tokens while staying within -// browser cookie size limits. Any existing token or chunks are cleared -// before setting the new token. -// expireRefreshTokenChunks expires any existing refresh token chunk cookies -func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) { - for i := 0; ; i++ { - sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i) - session, err := sd.manager.store.Get(sd.request, sessionName) - if err != nil || session.IsNew { - break - } - // Expire the cookie - session.Options.MaxAge = -1 - session.Values = make(map[interface{}]interface{}) - // Save expired cookie - if err := session.Save(sd.request, w); err != nil { - sd.manager.logger.Errorf("Failed to save expired cookie: %v", err) - } - } -} - func (sd *SessionData) SetRefreshToken(token string) { - // Expire any existing chunk cookies first + // Expire any existing chunk cookies first. if sd.request != nil { - sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called + sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called. } - // Clear and prepare chunks map for new token + // Clear and prepare chunks map for new token. sd.refreshTokenChunks = make(map[int]*sessions.Session) - // Compress token + // Compress token. compressed := compressToken(token) if len(compressed) <= maxCookieSize { sd.refreshSession.Values["token"] = compressed sd.refreshSession.Values["compressed"] = true } else { - // Split compressed token into chunks + // Split compressed token into chunks. sd.refreshSession.Values["token"] = "" sd.refreshSession.Values["compressed"] = true chunks := splitIntoChunks(compressed, maxCookieSize) @@ -554,13 +500,43 @@ func (sd *SessionData) SetRefreshToken(token string) { } } +// expireAccessTokenChunks expires any existing access token chunk cookies. +func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) { + for i := 0; ; i++ { + sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i) + session, err := sd.manager.store.Get(sd.request, sessionName) + if err != nil || session.IsNew { + break + } + session.Options.MaxAge = -1 + session.Values = make(map[interface{}]interface{}) + if w != nil { + if err := session.Save(sd.request, w); err != nil { + sd.manager.logger.Errorf("failed to save expired access token cookie: %v", err) + } + } + } +} + +// expireRefreshTokenChunks expires any existing refresh token chunk cookies. +func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) { + for i := 0; ; i++ { + sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i) + session, err := sd.manager.store.Get(sd.request, sessionName) + if err != nil || session.IsNew { + break + } + session.Options.MaxAge = -1 + session.Values = make(map[interface{}]interface{}) + if w != nil { + if err := session.Save(sd.request, w); err != nil { + sd.manager.logger.Errorf("failed to save expired refresh token cookie: %v", err) + } + } + } +} + // splitIntoChunks splits a string into chunks of specified size. -// This is used internally to handle large tokens that exceed cookie size limits. -// Parameters: -// - s: The string to split -// - chunkSize: Maximum size of each chunk -// -// Returns an array of string chunks, each no larger than chunkSize. func splitIntoChunks(s string, chunkSize int) []string { var chunks []string for len(s) > 0 { @@ -576,64 +552,45 @@ func splitIntoChunks(s string, chunkSize int) []string { } // GetCSRF retrieves the CSRF token from the session. -// This token is used to prevent cross-site request forgery attacks -// by ensuring requests originate from the authenticated user. -// Returns an empty string if no CSRF token is found. func (sd *SessionData) GetCSRF() string { csrf, _ := sd.mainSession.Values["csrf"].(string) return csrf } // SetCSRF stores a new CSRF token in the session. -// This should be called when initiating authentication to generate -// a new token for the authentication flow. func (sd *SessionData) SetCSRF(token string) { sd.mainSession.Values["csrf"] = token } // GetNonce retrieves the nonce value from the session. -// The nonce is used to prevent replay attacks in the OIDC flow -// by ensuring the token received matches the authentication request. -// Returns an empty string if no nonce is found. func (sd *SessionData) GetNonce() string { nonce, _ := sd.mainSession.Values["nonce"].(string) return nonce } // SetNonce stores a new nonce value in the session. -// This should be called when initiating authentication to generate -// a new nonce for the OIDC authentication flow. func (sd *SessionData) SetNonce(nonce string) { sd.mainSession.Values["nonce"] = nonce } // GetEmail retrieves the authenticated user's email address from the session. -// The email is typically extracted from the OIDC ID token claims. -// Returns an empty string if no email is found. func (sd *SessionData) GetEmail() string { email, _ := sd.mainSession.Values["email"].(string) return email } // SetEmail stores the user's email address in the session. -// This should be called after successful authentication when -// processing the OIDC ID token claims. func (sd *SessionData) SetEmail(email string) { sd.mainSession.Values["email"] = email } -// GetIncomingPath retrieves the original request path that triggered -// the authentication flow. This is used to redirect the user back -// to their intended destination after successful authentication. -// Returns an empty string if no path was stored. +// GetIncomingPath retrieves the original request path that triggered the authentication flow. func (sd *SessionData) GetIncomingPath() string { path, _ := sd.mainSession.Values["incoming_path"].(string) return path } -// SetIncomingPath stores the original request path that triggered -// the authentication flow. This should be called before redirecting -// to the OIDC provider to remember where to send the user afterward. +// SetIncomingPath stores the original request path that triggered the authentication flow. func (sd *SessionData) SetIncomingPath(path string) { sd.mainSession.Values["incoming_path"] = path } diff --git a/session_test.go b/session_test.go index 601809b..6ea172a 100644 --- a/session_test.go +++ b/session_test.go @@ -5,14 +5,8 @@ import ( "net/http/httptest" "strings" "testing" - "time" ) -func init() { - // Initialize random seed - rand.Seed(time.Now().UnixNano()) -} - // generateRandomString creates a random string of specified length func generateRandomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" @@ -83,7 +77,7 @@ func TestCookiePrefix(t *testing.T) { req := httptest.NewRequest("GET", "/test", nil) rr := httptest.NewRecorder() - sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug")) + sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug")) session, err := sm.GetSession(req) if err != nil { t.Fatalf("Failed to get session: %v", err) @@ -117,7 +111,7 @@ func TestTokenRefreshCleanup(t *testing.T) { req := httptest.NewRequest("GET", "/test", nil) rr := httptest.NewRecorder() - sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug")) + sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug")) session, err := sm.GetSession(req) if err != nil { t.Fatalf("Failed to get session: %v", err)