From 81000a824d4095e6e788345bc28ab39adf476fe2 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Wed, 7 May 2025 02:33:34 +0100 Subject: [PATCH] Fix dirty session handling. --- main.go | 13 +++- session.go | 221 ++++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 186 insertions(+), 48 deletions(-) diff --git a/main.go b/main.go index 5c0c568..a94a6fa 100644 --- a/main.go +++ b/main.go @@ -919,15 +919,22 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http req.Header.Set(headerName, headerValue) t.logger.Debugf("Set templated header %s = %s", headerName, headerValue) } + // Mark session as dirty after processing templated headers to ensure cookie is re-issued + session.MarkDirty() + t.logger.Debugf("Session marked dirty after templated header processing.") } } // Always save session after processing claims and before proceeding // This is especially important for opaque tokens where we need to ensure // authentication state and user information are preserved - if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save session after processing headers: %v", err) - // Continue anyway since we have valid tokens + if session.IsDirty() { + if err := session.Save(req, rw); err != nil { + t.logger.Errorf("Failed to save session after processing headers: %v", err) + // Continue anyway since we have valid tokens + } + } else { + t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest") } // Set security headers diff --git a/session.go b/session.go index de41b4d..4919c92 100644 --- a/session.go +++ b/session.go @@ -156,6 +156,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (* accessTokenChunks: make(map[int]*sessions.Session), refreshTokenChunks: make(map[int]*sessions.Session), refreshMutex: sync.Mutex{}, // Initialize the mutex + dirty: false, // Initialize dirty flag } } @@ -189,6 +190,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { // Get session from pool. sessionData := sm.sessionPool.Get().(*SessionData) sessionData.request = r + sessionData.dirty = false // Reset dirty flag when getting a session var err error sessionData.mainSession, err = sm.store.Get(r, mainCookieName) @@ -281,6 +283,21 @@ type SessionData struct { // refreshMutex protects refresh token operations within this session instance. refreshMutex sync.Mutex + + // dirty indicates whether the session data has changed and needs to be saved. + dirty bool +} + +// IsDirty returns true if the session data has been modified since it was last loaded or saved. +func (sd *SessionData) IsDirty() bool { + return sd.dirty +} + +// MarkDirty explicitly sets the dirty flag to true. +// This can be used when an operation doesn't change session data +// but should still trigger a session save (e.g., to ensure the cookie is re-issued). +func (sd *SessionData) MarkDirty() { + sd.dirty = true } // Save persists all parts of the session (main, access token, refresh token, and any chunks) @@ -302,38 +319,50 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { sd.accessSession.Options = options sd.refreshSession.Options = options - // Save main session. - if err := sd.mainSession.Save(r, w); err != nil { - return fmt.Errorf("failed to save main session: %w", err) + var firstErr error + // Helper to record first error and log subsequent ones + saveOrLogError := func(s *sessions.Session, name string) { + if s == nil { // Should not happen if initialized correctly + sd.manager.logger.Errorf("Attempted to save nil session: %s", name) + if firstErr == nil { + firstErr = fmt.Errorf("attempted to save nil session: %s", name) + } + return + } + if err := s.Save(r, w); err != nil { + errMsg := fmt.Errorf("failed to save %s session: %w", name, err) + sd.manager.logger.Error(errMsg.Error()) + if firstErr == nil { + firstErr = errMsg + } + } } + // Save main session. + saveOrLogError(sd.mainSession, "main") + // Save access token session. - if err := sd.accessSession.Save(r, w); err != nil { - return fmt.Errorf("failed to save access token session: %w", err) - } + saveOrLogError(sd.accessSession, "access token") // Save refresh token session. - if err := sd.refreshSession.Save(r, w); err != nil { - return fmt.Errorf("failed to save refresh token session: %w", err) - } + saveOrLogError(sd.refreshSession, "refresh token") // Save access token chunks. - for _, session := range sd.accessTokenChunks { - session.Options = options - if err := session.Save(r, w); err != nil { - return fmt.Errorf("failed to save access token chunk session: %w", err) - } + for i, sessionChunk := range sd.accessTokenChunks { + sessionChunk.Options = options + saveOrLogError(sessionChunk, fmt.Sprintf("access token chunk %d", i)) } // Save refresh token chunks. - for _, session := range sd.refreshTokenChunks { - session.Options = options - if err := session.Save(r, w); err != nil { - return fmt.Errorf("failed to save refresh token chunk session: %w", err) - } + for i, sessionChunk := range sd.refreshTokenChunks { + sessionChunk.Options = options + saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i)) } - return nil + if firstErr == nil { + sd.dirty = false // Reset dirty flag only if all saves were successful + } + return firstErr } // Clear removes all session data associated with this SessionData instance. @@ -350,19 +379,26 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { // Returns: // - An error if saving the expired sessions fails (only if w is not nil). func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { - // Clear and expire all sessions. - sd.mainSession.Options.MaxAge = -1 - sd.accessSession.Options.MaxAge = -1 - sd.refreshSession.Options.MaxAge = -1 + sd.dirty = true // Clearing the session means its state is changing and needs to be saved. - for k := range sd.mainSession.Values { - delete(sd.mainSession.Values, k) + // Clear and expire all sessions. + if sd.mainSession != nil { + sd.mainSession.Options.MaxAge = -1 + for k := range sd.mainSession.Values { + delete(sd.mainSession.Values, k) + } } - for k := range sd.accessSession.Values { - delete(sd.accessSession.Values, k) + if sd.accessSession != nil { + sd.accessSession.Options.MaxAge = -1 + for k := range sd.accessSession.Values { + delete(sd.accessSession.Values, k) + } } - for k := range sd.refreshSession.Values { - delete(sd.refreshSession.Values, k) + if sd.refreshSession != nil { + sd.refreshSession.Options.MaxAge = -1 + for k := range sd.refreshSession.Values { + delete(sd.refreshSession.Values, k) + } } // Clear chunk sessions. @@ -428,15 +464,44 @@ func (sd *SessionData) GetAuthenticated() bool { // Returns: // - An error if generating a new session ID fails when setting value to true. func (sd *SessionData) SetAuthenticated(value bool) error { + currentAuth := sd.GetAuthenticated() // This checks flag and expiry + changed := false + + if currentAuth != value { + changed = true + } + if value { + // If we are setting to true, and either it wasn't true before, + // or if the session ID needs regeneration (e.g. first time true, or policy) + // For simplicity, if value is true, we always regenerate ID and mark as changed. + // This ensures session ID regeneration is always saved. id, err := generateSecureRandomString(32) if err != nil { return fmt.Errorf("failed to generate secure session id: %w", err) } + if sd.mainSession.ID != id { // ID actually changed + changed = true + } sd.mainSession.ID = id - sd.mainSession.Values["created_at"] = time.Now().Unix() + newCreationTime := time.Now().Unix() + if oldTime, ok := sd.mainSession.Values["created_at"].(int64); !ok || oldTime != newCreationTime { + changed = true + } + sd.mainSession.Values["created_at"] = newCreationTime + if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value { + changed = true + } + } else { // value is false + if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value { + changed = true + } } + sd.mainSession.Values["authenticated"] = value + if changed { + sd.dirty = true + } return nil } @@ -488,6 +553,14 @@ func (sd *SessionData) GetAccessToken() string { // Parameters: // - token: The access token string to store. func (sd *SessionData) SetAccessToken(token string) { + currentAccessToken := sd.GetAccessToken() + if currentAccessToken == token { + // If token is empty, and current is also empty, it's not a change. + // This check handles both empty and non-empty identical cases. + return + } + sd.dirty = true + // Expire any existing chunk cookies first. if sd.request != nil { sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called. @@ -496,6 +569,13 @@ func (sd *SessionData) SetAccessToken(token string) { // Clear and prepare chunks map for new token. sd.accessTokenChunks = make(map[int]*sessions.Session) + if token == "" { // Clearing the token + sd.accessSession.Values["token"] = "" + sd.accessSession.Values["compressed"] = false + // sd.accessTokenChunks is already cleared + return + } + // Compress token. compressed := compressToken(token) @@ -504,13 +584,19 @@ func (sd *SessionData) SetAccessToken(token string) { sd.accessSession.Values["compressed"] = true } else { // Split compressed token into chunks. - sd.accessSession.Values["token"] = "" - sd.accessSession.Values["compressed"] = true + sd.accessSession.Values["token"] = "" // Main cookie won't hold the token directly + sd.accessSession.Values["compressed"] = true // Data in chunks is compressed chunks := splitIntoChunks(compressed, maxCookieSize) - for i, chunk := range chunks { + for i, chunkData := range chunks { sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i) + // Ensure sd.request is available, otherwise log warning or handle error + if sd.request == nil { + sd.manager.logger.Infof("SetAccessToken: sd.request is nil, cannot get/create chunk session %s", sessionName) + // Potentially skip this chunk or error out, depending on desired robustness + continue + } session, _ := sd.manager.store.Get(sd.request, sessionName) - session.Values["token_chunk"] = chunk + session.Values["token_chunk"] = chunkData sd.accessTokenChunks[i] = session } } @@ -564,6 +650,12 @@ func (sd *SessionData) GetRefreshToken() string { // Parameters: // - token: The refresh token string to store. func (sd *SessionData) SetRefreshToken(token string) { + currentRefreshToken := sd.GetRefreshToken() + if currentRefreshToken == token { + return + } + sd.dirty = true + // Expire any existing chunk cookies first. if sd.request != nil { sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called. @@ -572,6 +664,13 @@ func (sd *SessionData) SetRefreshToken(token string) { // Clear and prepare chunks map for new token. sd.refreshTokenChunks = make(map[int]*sessions.Session) + if token == "" { // Clearing the token + sd.refreshSession.Values["token"] = "" + sd.refreshSession.Values["compressed"] = false + // sd.refreshTokenChunks is already cleared + return + } + // Compress token. compressed := compressToken(token) @@ -580,13 +679,17 @@ func (sd *SessionData) SetRefreshToken(token string) { sd.refreshSession.Values["compressed"] = true } else { // Split compressed token into chunks. - sd.refreshSession.Values["token"] = "" - sd.refreshSession.Values["compressed"] = true + sd.refreshSession.Values["token"] = "" // Main cookie won't hold the token directly + sd.refreshSession.Values["compressed"] = true // Data in chunks is compressed chunks := splitIntoChunks(compressed, maxCookieSize) - for i, chunk := range chunks { + for i, chunkData := range chunks { sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i) + if sd.request == nil { + sd.manager.logger.Infof("SetRefreshToken: sd.request is nil, cannot get/create chunk session %s", sessionName) + continue + } session, _ := sd.manager.store.Get(sd.request, sessionName) - session.Values["token_chunk"] = chunk + session.Values["token_chunk"] = chunkData sd.refreshTokenChunks[i] = session } } @@ -678,7 +781,11 @@ func (sd *SessionData) GetCSRF() string { // Parameters: // - token: The CSRF token to store. func (sd *SessionData) SetCSRF(token string) { - sd.mainSession.Values["csrf"] = token + currentVal, _ := sd.mainSession.Values["csrf"].(string) + if currentVal != token { + sd.mainSession.Values["csrf"] = token + sd.dirty = true + } } // GetNonce retrieves the OIDC nonce value stored in the main session. @@ -697,7 +804,11 @@ func (sd *SessionData) GetNonce() string { // Parameters: // - nonce: The nonce string to store. func (sd *SessionData) SetNonce(nonce string) { - sd.mainSession.Values["nonce"] = nonce + currentVal, _ := sd.mainSession.Values["nonce"].(string) + if currentVal != nonce { + sd.mainSession.Values["nonce"] = nonce + sd.dirty = true + } } // GetCodeVerifier retrieves the PKCE (Proof Key for Code Exchange) code verifier @@ -716,7 +827,11 @@ func (sd *SessionData) GetCodeVerifier() string { // Parameters: // - codeVerifier: The PKCE code verifier string to store. func (sd *SessionData) SetCodeVerifier(codeVerifier string) { - sd.mainSession.Values["code_verifier"] = codeVerifier + currentVal, _ := sd.mainSession.Values["code_verifier"].(string) + if currentVal != codeVerifier { + sd.mainSession.Values["code_verifier"] = codeVerifier + sd.dirty = true + } } // GetEmail retrieves the authenticated user's email address stored in the main session. @@ -735,7 +850,11 @@ func (sd *SessionData) GetEmail() string { // Parameters: // - email: The user's email address to store. func (sd *SessionData) SetEmail(email string) { - sd.mainSession.Values["email"] = email + currentVal, _ := sd.mainSession.Values["email"].(string) + if currentVal != email { + sd.mainSession.Values["email"] = email + sd.dirty = true + } } // GetIncomingPath retrieves the original request URI (including query parameters) @@ -755,7 +874,11 @@ func (sd *SessionData) GetIncomingPath() string { // Parameters: // - path: The original request URI string (e.g., "/protected/resource?id=123"). func (sd *SessionData) SetIncomingPath(path string) { - sd.mainSession.Values["incoming_path"] = path + currentVal, _ := sd.mainSession.Values["incoming_path"].(string) + if currentVal != path { + sd.mainSession.Values["incoming_path"] = path + sd.dirty = true + } } // GetIDToken retrieves the ID token stored in the session. @@ -781,6 +904,15 @@ func (sd *SessionData) GetIDToken() string { // Parameters: // - token: The ID token string to store. func (sd *SessionData) SetIDToken(token string) { + currentIDToken := sd.GetIDToken() // Gets fully reassembled, decompressed token + if currentIDToken == token { + // This handles cases where token is "" and currentIDToken is also "", no change. + // Or token is "abc" and currentIDToken is "abc", no change. + return + } + + sd.dirty = true // Mark as dirty because a change is being made + if token == "" { sd.mainSession.Values["id_token"] = "" sd.mainSession.Values["id_token_compressed"] = false @@ -789,7 +921,6 @@ func (sd *SessionData) SetIDToken(token string) { // Compress token compressed := compressToken(token) - sd.mainSession.Values["id_token"] = compressed sd.mainSession.Values["id_token_compressed"] = true }