diff --git a/session.go b/session.go index 71a92d8..1388576 100644 --- a/session.go +++ b/session.go @@ -194,6 +194,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (* manager: sm, accessTokenChunks: make(map[int]*sessions.Session), refreshTokenChunks: make(map[int]*sessions.Session), + idTokenChunks: make(map[int]*sessions.Session), refreshMutex: sync.Mutex{}, // Initialize the mutex sessionMutex: sync.RWMutex{}, // Initialize the session mutex dirty: false, // Initialize dirty flag @@ -280,10 +281,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { for k := range sessionData.refreshTokenChunks { delete(sessionData.refreshTokenChunks, k) } + for k := range sessionData.idTokenChunks { + delete(sessionData.idTokenChunks, k) + } // Retrieve chunked token sessions. sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks) sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks) + sm.getTokenChunkSessions(r, mainCookieName, sessionData.idTokenChunks) return sessionData, nil } @@ -335,6 +340,10 @@ type SessionData struct { // when it exceeds the maximum cookie size. refreshTokenChunks map[int]*sessions.Session + // idTokenChunks stores additional chunks of the ID token + // when it exceeds the maximum cookie size. + idTokenChunks map[int]*sessions.Session + // refreshMutex protects refresh token operations within this session instance. refreshMutex sync.Mutex @@ -420,6 +429,12 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i)) } + // Save ID token chunks. + for i, sessionChunk := range sd.idTokenChunks { + sessionChunk.Options = options + saveOrLogError(sessionChunk, fmt.Sprintf("ID token chunk %d", i)) + } + if firstErr == nil { sd.dirty = false // Reset dirty flag only if all saves were successful } @@ -467,6 +482,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { // Clear chunk sessions. sd.clearTokenChunks(r, sd.accessTokenChunks) sd.clearTokenChunks(r, sd.refreshTokenChunks) + sd.clearTokenChunks(r, sd.idTokenChunks) // Create a guaranteed error when the response writer is set // This is primarily for testing - in production w will often be nil @@ -648,6 +664,9 @@ func (sd *SessionData) Reset() { for k := range sd.refreshTokenChunks { delete(sd.refreshTokenChunks, k) } + for k := range sd.idTokenChunks { + delete(sd.idTokenChunks, k) + } // Reset state flags sd.dirty = false @@ -926,6 +945,30 @@ func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) { } } +// expireIDTokenChunks finds all existing ID token chunk cookies (_oidc_raczylo_N) +// associated with the current request, clears their values, and sets their MaxAge to -1. +// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send +// the expiring Set-Cookie headers. This is used internally when setting a new ID token. +// +// Parameters: +// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent. +func (sd *SessionData) expireIDTokenChunks(w http.ResponseWriter) { + for i := 0; ; i++ { + sessionName := fmt.Sprintf("%s_%d", mainCookieName, i) + session, err := sd.manager.store.Get(sd.request, sessionName) + if err != nil || session.IsNew { + break + } + session.Options.MaxAge = -1 + session.Values = make(map[interface{}]interface{}) + if w != nil { + if err := session.Save(sd.request, w); err != nil { + sd.manager.logger.Errorf("failed to save expired ID token cookie: %v", err) + } + } + } +} + // splitIntoChunks divides a string `s` into a slice of strings, where each element // has a maximum length of `chunkSize`. // @@ -1077,6 +1120,14 @@ func (sd *SessionData) SetIncomingPath(path string) { // Returns: // - The complete, decompressed ID token string, or an empty string if not found. func (sd *SessionData) GetIDToken() string { + sd.sessionMutex.RLock() + defer sd.sessionMutex.RUnlock() + + return sd.getIDTokenUnsafe() +} + +// getIDTokenUnsafe is the internal implementation without mutex protection +func (sd *SessionData) getIDTokenUnsafe() string { token, _ := sd.mainSession.Values["id_token"].(string) if token != "" { compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool) @@ -1085,33 +1136,97 @@ func (sd *SessionData) GetIDToken() string { } return token } - return "" + + // Reassemble token from chunks. + if len(sd.idTokenChunks) == 0 { + return "" + } + + var chunks []string + for i := 0; ; i++ { + session, ok := sd.idTokenChunks[i] + if !ok { + break + } + chunk, _ := session.Values["id_token_chunk"].(string) + chunks = append(chunks, chunk) + } + + token = strings.Join(chunks, "") + compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool) + if compressed { + return decompressToken(token) + } + return token } // SetIDToken stores the provided ID token in the session. +// It first expires any existing ID token chunk cookies. +// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize), +// it's stored directly in the primary main session. Otherwise, the compressed token +// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_0, _oidc_raczylo_1, etc.). // // Parameters: // - token: The ID token string to store. func (sd *SessionData) SetIDToken(token string) { - currentIDToken := sd.GetIDToken() // Gets fully reassembled, decompressed token + sd.sessionMutex.Lock() + defer sd.sessionMutex.Unlock() + + currentIDToken := sd.getIDTokenUnsafe() if currentIDToken == token { - // This handles cases where token is "" and currentIDToken is also "", no change. - // Or token is "abc" and currentIDToken is "abc", no change. + // If token is empty, and current is also empty, it's not a change. + // This check handles both empty and non-empty identical cases. + return + } + sd.dirty = true + + // Expire any existing chunk cookies first. + if sd.request != nil { + sd.expireIDTokenChunks(nil) // Will be saved when Save() is called. + } + + // Clear and prepare chunks map for new token. + sd.idTokenChunks = make(map[int]*sessions.Session) + + if token == "" { // Clearing the token + // STABILITY FIX: Add nil checks before accessing session values + if sd.mainSession != nil { + sd.mainSession.Values["id_token"] = "" + sd.mainSession.Values["id_token_compressed"] = false + } + // sd.idTokenChunks is already cleared return } - sd.dirty = true // Mark as dirty because a change is being made - - if token == "" { - sd.mainSession.Values["id_token"] = "" - sd.mainSession.Values["id_token_compressed"] = false - return - } - - // Compress token + // Compress token. compressed := compressToken(token) - sd.mainSession.Values["id_token"] = compressed - sd.mainSession.Values["id_token_compressed"] = true + + if len(compressed) <= maxCookieSize { + // STABILITY FIX: Add nil checks before accessing session values + if sd.mainSession != nil { + sd.mainSession.Values["id_token"] = compressed + sd.mainSession.Values["id_token_compressed"] = true + } + } else { + // Split compressed token into chunks. + if sd.mainSession != nil { + sd.mainSession.Values["id_token"] = "" // Main cookie won't hold the token directly + sd.mainSession.Values["id_token_compressed"] = true // Data in chunks is compressed + } + chunks := splitIntoChunks(compressed, maxCookieSize) + for i, chunkData := range chunks { + sessionName := fmt.Sprintf("%s_%d", mainCookieName, i) + // Ensure sd.request is available, otherwise log warning or handle error + if sd.request == nil { + sd.manager.logger.Infof("SetIDToken: sd.request is nil, cannot get/create chunk session %s", sessionName) + // Potentially skip this chunk or error out, depending on desired robustness + continue + } + session, _ := sd.manager.store.Get(sd.request, sessionName) + session.Values["id_token_chunk"] = chunkData + sd.idTokenChunks[i] = session + } + } } // GetRedirectCount retrieves the current redirect count from the session. diff --git a/session_test.go b/session_test.go index 6d41921..9c3db87 100644 --- a/session_test.go +++ b/session_test.go @@ -1,6 +1,9 @@ package traefikoidc import ( + "crypto/rand" + "encoding/base64" + "fmt" "net/http" "net/http/httptest" "runtime" @@ -218,4 +221,165 @@ func TestSessionObjectTracking(t *testing.T) { t.Log("Session pool handling verified") } +// TestLargeIDTokenChunking tests that large ID tokens are properly chunked across multiple cookies +func TestLargeIDTokenChunking(t *testing.T) { + logger := NewLogger("debug") + sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger) + if err != nil { + t.Fatalf("Failed to create session manager: %v", err) + } + // Create a large ID token (>4KB) to force chunking + largeIDToken := createLargeIDToken(20000) // 20KB token to ensure chunking after compression + t.Logf("Created large ID token with length: %d", len(largeIDToken)) + + // Create a request and response recorder + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + rr := httptest.NewRecorder() + + // Get session and set large ID token + session, err := sm.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Set the large ID token + session.SetIDToken(largeIDToken) + t.Logf("Set large ID token in session") + + // Let's check what the GetIDToken returns to confirm it's set + retrievedToken := session.GetIDToken() + t.Logf("Retrieved ID token length: %d", len(retrievedToken)) + if len(retrievedToken) != len(largeIDToken) { + t.Errorf("Token length mismatch: expected %d, got %d", len(largeIDToken), len(retrievedToken)) + } + + // Let's check what's in the main session directly + if idToken, ok := session.mainSession.Values["id_token"].(string); ok { + t.Logf("Main session id_token length: %d", len(idToken)) + if compressed, ok := session.mainSession.Values["id_token_compressed"].(bool); ok { + t.Logf("Main session id_token_compressed: %v", compressed) + } + } else { + t.Logf("Main session id_token not found or not a string") + } + + // Save the session to trigger chunking + err = session.Save(req, rr) + if err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Verify that chunked cookies were created + cookies := rr.Result().Cookies() + t.Logf("Total cookies in response: %d", len(cookies)) + + for _, cookie := range cookies { + valuePreview := cookie.Value + if len(valuePreview) > 50 { + valuePreview = valuePreview[:50] + "..." + } + t.Logf("Cookie: %s = %s (len=%d)", cookie.Name, valuePreview, len(cookie.Value)) + } + + var mainCookie *http.Cookie + var chunkCookies []*http.Cookie + + for _, cookie := range cookies { + if cookie.Name == mainCookieName { + mainCookie = cookie + } else if strings.HasPrefix(cookie.Name, mainCookieName+"_") { + chunkCookies = append(chunkCookies, cookie) + } + } + + // Verify main cookie exists + if mainCookie == nil { + t.Fatal("Main cookie not found in response") + } + + // Verify chunk cookies exist (should be at least 2 for a 5KB token) + if len(chunkCookies) < 2 { + t.Fatalf("Expected at least 2 chunk cookies, got %d", len(chunkCookies)) + } + + // Verify chunk cookie naming convention + expectedChunkNames := make(map[string]bool) + for i := 0; i < len(chunkCookies); i++ { + expectedChunkNames[mainCookieName+"_"+fmt.Sprintf("%d", i)] = true + } + + for _, cookie := range chunkCookies { + if !expectedChunkNames[cookie.Name] { + t.Errorf("Unexpected chunk cookie name: %s", cookie.Name) + } + } + + // Test token retrieval from chunked cookies + // Create a new request with all the cookies + newReq := httptest.NewRequest("GET", "http://example.com/foo", nil) + for _, cookie := range cookies { + newReq.AddCookie(cookie) + } + + // Get session and retrieve the ID token + retrievedSession, err := sm.GetSession(newReq) + if err != nil { + t.Fatalf("Failed to get session from chunked cookies: %v", err) + } + + retrievedToken2 := retrievedSession.GetIDToken() + + // Verify the retrieved token matches the original + if retrievedToken2 != largeIDToken { + t.Errorf("Retrieved ID token doesn't match original. Expected length: %d, got: %d", len(largeIDToken), len(retrievedToken2)) + } + + // Test clearing the ID token removes all chunks + retrievedSession.SetIDToken("") + + clearRR := httptest.NewRecorder() + err = retrievedSession.Save(newReq, clearRR) + if err != nil { + t.Fatalf("Failed to save session after clearing ID token: %v", err) + } + + // Verify chunks are expired (MaxAge = -1) + clearCookies := clearRR.Result().Cookies() + for _, cookie := range clearCookies { + if strings.HasPrefix(cookie.Name, mainCookieName+"_") { + if cookie.MaxAge != -1 { + t.Errorf("Expected chunk cookie %s to be expired (MaxAge=-1), got MaxAge=%d", cookie.Name, cookie.MaxAge) + } + } + } +} + +// createLargeIDToken creates a JWT-like token of specified size for testing +func createLargeIDToken(size int) string { + // Create truly random data that won't compress well + randomBytes := make([]byte, size*3/4) // base64 encoding increases size by ~4/3 + _, err := rand.Read(randomBytes) + if err != nil { + // Fallback to pseudo-random if crypto/rand fails + for i := range randomBytes { + randomBytes[i] = byte(i % 256) + } + } + + // Base64 encode the random data to make it look like a JWT + encoded := base64.StdEncoding.EncodeToString(randomBytes) + + // Create JWT-like structure with truly random data + header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9" + + // Truncate or pad to desired size + if len(encoded) > size-len(header)-100 { + encoded = encoded[:size-len(header)-100] + } + + signature := "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + + return header + "." + encoded + "." + signature +} + // This is intentionally left empty to remove unused code