mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Fix dirty session handling.
This commit is contained in:
@@ -919,15 +919,22 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
req.Header.Set(headerName, headerValue)
|
||||
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
|
||||
}
|
||||
// Mark session as dirty after processing templated headers to ensure cookie is re-issued
|
||||
session.MarkDirty()
|
||||
t.logger.Debugf("Session marked dirty after templated header processing.")
|
||||
}
|
||||
}
|
||||
|
||||
// Always save session after processing claims and before proceeding
|
||||
// This is especially important for opaque tokens where we need to ensure
|
||||
// authentication state and user information are preserved
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session after processing headers: %v", err)
|
||||
// Continue anyway since we have valid tokens
|
||||
if session.IsDirty() {
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session after processing headers: %v", err)
|
||||
// Continue anyway since we have valid tokens
|
||||
}
|
||||
} else {
|
||||
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
|
||||
}
|
||||
|
||||
// Set security headers
|
||||
|
||||
+176
-45
@@ -156,6 +156,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshMutex: sync.Mutex{}, // Initialize the mutex
|
||||
dirty: false, // Initialize dirty flag
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,6 +190,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
// Get session from pool.
|
||||
sessionData := sm.sessionPool.Get().(*SessionData)
|
||||
sessionData.request = r
|
||||
sessionData.dirty = false // Reset dirty flag when getting a session
|
||||
|
||||
var err error
|
||||
sessionData.mainSession, err = sm.store.Get(r, mainCookieName)
|
||||
@@ -281,6 +283,21 @@ type SessionData struct {
|
||||
|
||||
// refreshMutex protects refresh token operations within this session instance.
|
||||
refreshMutex sync.Mutex
|
||||
|
||||
// dirty indicates whether the session data has changed and needs to be saved.
|
||||
dirty bool
|
||||
}
|
||||
|
||||
// IsDirty returns true if the session data has been modified since it was last loaded or saved.
|
||||
func (sd *SessionData) IsDirty() bool {
|
||||
return sd.dirty
|
||||
}
|
||||
|
||||
// MarkDirty explicitly sets the dirty flag to true.
|
||||
// This can be used when an operation doesn't change session data
|
||||
// but should still trigger a session save (e.g., to ensure the cookie is re-issued).
|
||||
func (sd *SessionData) MarkDirty() {
|
||||
sd.dirty = true
|
||||
}
|
||||
|
||||
// Save persists all parts of the session (main, access token, refresh token, and any chunks)
|
||||
@@ -302,38 +319,50 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
sd.accessSession.Options = options
|
||||
sd.refreshSession.Options = options
|
||||
|
||||
// Save main session.
|
||||
if err := sd.mainSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save main session: %w", err)
|
||||
var firstErr error
|
||||
// Helper to record first error and log subsequent ones
|
||||
saveOrLogError := func(s *sessions.Session, name string) {
|
||||
if s == nil { // Should not happen if initialized correctly
|
||||
sd.manager.logger.Errorf("Attempted to save nil session: %s", name)
|
||||
if firstErr == nil {
|
||||
firstErr = fmt.Errorf("attempted to save nil session: %s", name)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := s.Save(r, w); err != nil {
|
||||
errMsg := fmt.Errorf("failed to save %s session: %w", name, err)
|
||||
sd.manager.logger.Error(errMsg.Error())
|
||||
if firstErr == nil {
|
||||
firstErr = errMsg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save main session.
|
||||
saveOrLogError(sd.mainSession, "main")
|
||||
|
||||
// Save access token session.
|
||||
if err := sd.accessSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token session: %w", err)
|
||||
}
|
||||
saveOrLogError(sd.accessSession, "access token")
|
||||
|
||||
// Save refresh token session.
|
||||
if err := sd.refreshSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token session: %w", err)
|
||||
}
|
||||
saveOrLogError(sd.refreshSession, "refresh token")
|
||||
|
||||
// Save access token chunks.
|
||||
for _, session := range sd.accessTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token chunk session: %w", err)
|
||||
}
|
||||
for i, sessionChunk := range sd.accessTokenChunks {
|
||||
sessionChunk.Options = options
|
||||
saveOrLogError(sessionChunk, fmt.Sprintf("access token chunk %d", i))
|
||||
}
|
||||
|
||||
// Save refresh token chunks.
|
||||
for _, session := range sd.refreshTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
|
||||
}
|
||||
for i, sessionChunk := range sd.refreshTokenChunks {
|
||||
sessionChunk.Options = options
|
||||
saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i))
|
||||
}
|
||||
|
||||
return nil
|
||||
if firstErr == nil {
|
||||
sd.dirty = false // Reset dirty flag only if all saves were successful
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// Clear removes all session data associated with this SessionData instance.
|
||||
@@ -350,19 +379,26 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
// Returns:
|
||||
// - An error if saving the expired sessions fails (only if w is not nil).
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions.
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
sd.accessSession.Options.MaxAge = -1
|
||||
sd.refreshSession.Options.MaxAge = -1
|
||||
sd.dirty = true // Clearing the session means its state is changing and needs to be saved.
|
||||
|
||||
for k := range sd.mainSession.Values {
|
||||
delete(sd.mainSession.Values, k)
|
||||
// Clear and expire all sessions.
|
||||
if sd.mainSession != nil {
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
for k := range sd.mainSession.Values {
|
||||
delete(sd.mainSession.Values, k)
|
||||
}
|
||||
}
|
||||
for k := range sd.accessSession.Values {
|
||||
delete(sd.accessSession.Values, k)
|
||||
if sd.accessSession != nil {
|
||||
sd.accessSession.Options.MaxAge = -1
|
||||
for k := range sd.accessSession.Values {
|
||||
delete(sd.accessSession.Values, k)
|
||||
}
|
||||
}
|
||||
for k := range sd.refreshSession.Values {
|
||||
delete(sd.refreshSession.Values, k)
|
||||
if sd.refreshSession != nil {
|
||||
sd.refreshSession.Options.MaxAge = -1
|
||||
for k := range sd.refreshSession.Values {
|
||||
delete(sd.refreshSession.Values, k)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear chunk sessions.
|
||||
@@ -428,15 +464,44 @@ func (sd *SessionData) GetAuthenticated() bool {
|
||||
// Returns:
|
||||
// - An error if generating a new session ID fails when setting value to true.
|
||||
func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
currentAuth := sd.GetAuthenticated() // This checks flag and expiry
|
||||
changed := false
|
||||
|
||||
if currentAuth != value {
|
||||
changed = true
|
||||
}
|
||||
|
||||
if value {
|
||||
// If we are setting to true, and either it wasn't true before,
|
||||
// or if the session ID needs regeneration (e.g. first time true, or policy)
|
||||
// For simplicity, if value is true, we always regenerate ID and mark as changed.
|
||||
// This ensures session ID regeneration is always saved.
|
||||
id, err := generateSecureRandomString(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate secure session id: %w", err)
|
||||
}
|
||||
if sd.mainSession.ID != id { // ID actually changed
|
||||
changed = true
|
||||
}
|
||||
sd.mainSession.ID = id
|
||||
sd.mainSession.Values["created_at"] = time.Now().Unix()
|
||||
newCreationTime := time.Now().Unix()
|
||||
if oldTime, ok := sd.mainSession.Values["created_at"].(int64); !ok || oldTime != newCreationTime {
|
||||
changed = true
|
||||
}
|
||||
sd.mainSession.Values["created_at"] = newCreationTime
|
||||
if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value {
|
||||
changed = true
|
||||
}
|
||||
} else { // value is false
|
||||
if oldAuth, ok := sd.mainSession.Values["authenticated"].(bool); !ok || oldAuth != value {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
sd.mainSession.Values["authenticated"] = value
|
||||
if changed {
|
||||
sd.dirty = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -488,6 +553,14 @@ func (sd *SessionData) GetAccessToken() string {
|
||||
// Parameters:
|
||||
// - token: The access token string to store.
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
currentAccessToken := sd.GetAccessToken()
|
||||
if currentAccessToken == token {
|
||||
// If token is empty, and current is also empty, it's not a change.
|
||||
// This check handles both empty and non-empty identical cases.
|
||||
return
|
||||
}
|
||||
sd.dirty = true
|
||||
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called.
|
||||
@@ -496,6 +569,13 @@ func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Clear and prepare chunks map for new token.
|
||||
sd.accessTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
if token == "" { // Clearing the token
|
||||
sd.accessSession.Values["token"] = ""
|
||||
sd.accessSession.Values["compressed"] = false
|
||||
// sd.accessTokenChunks is already cleared
|
||||
return
|
||||
}
|
||||
|
||||
// Compress token.
|
||||
compressed := compressToken(token)
|
||||
|
||||
@@ -504,13 +584,19 @@ func (sd *SessionData) SetAccessToken(token string) {
|
||||
sd.accessSession.Values["compressed"] = true
|
||||
} else {
|
||||
// Split compressed token into chunks.
|
||||
sd.accessSession.Values["token"] = ""
|
||||
sd.accessSession.Values["compressed"] = true
|
||||
sd.accessSession.Values["token"] = "" // Main cookie won't hold the token directly
|
||||
sd.accessSession.Values["compressed"] = true // Data in chunks is compressed
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
for i, chunkData := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
// Ensure sd.request is available, otherwise log warning or handle error
|
||||
if sd.request == nil {
|
||||
sd.manager.logger.Infof("SetAccessToken: sd.request is nil, cannot get/create chunk session %s", sessionName)
|
||||
// Potentially skip this chunk or error out, depending on desired robustness
|
||||
continue
|
||||
}
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
session.Values["token_chunk"] = chunkData
|
||||
sd.accessTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
@@ -564,6 +650,12 @@ func (sd *SessionData) GetRefreshToken() string {
|
||||
// Parameters:
|
||||
// - token: The refresh token string to store.
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
currentRefreshToken := sd.GetRefreshToken()
|
||||
if currentRefreshToken == token {
|
||||
return
|
||||
}
|
||||
sd.dirty = true
|
||||
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called.
|
||||
@@ -572,6 +664,13 @@ func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Clear and prepare chunks map for new token.
|
||||
sd.refreshTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
if token == "" { // Clearing the token
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
sd.refreshSession.Values["compressed"] = false
|
||||
// sd.refreshTokenChunks is already cleared
|
||||
return
|
||||
}
|
||||
|
||||
// Compress token.
|
||||
compressed := compressToken(token)
|
||||
|
||||
@@ -580,13 +679,17 @@ func (sd *SessionData) SetRefreshToken(token string) {
|
||||
sd.refreshSession.Values["compressed"] = true
|
||||
} else {
|
||||
// Split compressed token into chunks.
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
sd.refreshSession.Values["compressed"] = true
|
||||
sd.refreshSession.Values["token"] = "" // Main cookie won't hold the token directly
|
||||
sd.refreshSession.Values["compressed"] = true // Data in chunks is compressed
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
for i, chunkData := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
if sd.request == nil {
|
||||
sd.manager.logger.Infof("SetRefreshToken: sd.request is nil, cannot get/create chunk session %s", sessionName)
|
||||
continue
|
||||
}
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
session.Values["token_chunk"] = chunkData
|
||||
sd.refreshTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
@@ -678,7 +781,11 @@ func (sd *SessionData) GetCSRF() string {
|
||||
// Parameters:
|
||||
// - token: The CSRF token to store.
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
currentVal, _ := sd.mainSession.Values["csrf"].(string)
|
||||
if currentVal != token {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// GetNonce retrieves the OIDC nonce value stored in the main session.
|
||||
@@ -697,7 +804,11 @@ func (sd *SessionData) GetNonce() string {
|
||||
// Parameters:
|
||||
// - nonce: The nonce string to store.
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
currentVal, _ := sd.mainSession.Values["nonce"].(string)
|
||||
if currentVal != nonce {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// GetCodeVerifier retrieves the PKCE (Proof Key for Code Exchange) code verifier
|
||||
@@ -716,7 +827,11 @@ func (sd *SessionData) GetCodeVerifier() string {
|
||||
// Parameters:
|
||||
// - codeVerifier: The PKCE code verifier string to store.
|
||||
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
sd.mainSession.Values["code_verifier"] = codeVerifier
|
||||
currentVal, _ := sd.mainSession.Values["code_verifier"].(string)
|
||||
if currentVal != codeVerifier {
|
||||
sd.mainSession.Values["code_verifier"] = codeVerifier
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address stored in the main session.
|
||||
@@ -735,7 +850,11 @@ func (sd *SessionData) GetEmail() string {
|
||||
// Parameters:
|
||||
// - email: The user's email address to store.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
currentVal, _ := sd.mainSession.Values["email"].(string)
|
||||
if currentVal != email {
|
||||
sd.mainSession.Values["email"] = email
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// GetIncomingPath retrieves the original request URI (including query parameters)
|
||||
@@ -755,7 +874,11 @@ func (sd *SessionData) GetIncomingPath() string {
|
||||
// Parameters:
|
||||
// - path: The original request URI string (e.g., "/protected/resource?id=123").
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
currentVal, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
if currentVal != path {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// GetIDToken retrieves the ID token stored in the session.
|
||||
@@ -781,6 +904,15 @@ func (sd *SessionData) GetIDToken() string {
|
||||
// Parameters:
|
||||
// - token: The ID token string to store.
|
||||
func (sd *SessionData) SetIDToken(token string) {
|
||||
currentIDToken := sd.GetIDToken() // Gets fully reassembled, decompressed token
|
||||
if currentIDToken == token {
|
||||
// This handles cases where token is "" and currentIDToken is also "", no change.
|
||||
// Or token is "abc" and currentIDToken is "abc", no change.
|
||||
return
|
||||
}
|
||||
|
||||
sd.dirty = true // Mark as dirty because a change is being made
|
||||
|
||||
if token == "" {
|
||||
sd.mainSession.Values["id_token"] = ""
|
||||
sd.mainSession.Values["id_token_compressed"] = false
|
||||
@@ -789,7 +921,6 @@ func (sd *SessionData) SetIDToken(token string) {
|
||||
|
||||
// Compress token
|
||||
compressed := compressToken(token)
|
||||
|
||||
sd.mainSession.Values["id_token"] = compressed
|
||||
sd.mainSession.Values["id_token_compressed"] = true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user