mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Fix for cookie length (#58)
* Enhance session management by adding support for chunked id token in main session * Add test for large ID token chunking in session management
This commit is contained in:
+130
-15
@@ -194,6 +194,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
|
||||
manager: sm,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
idTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshMutex: sync.Mutex{}, // Initialize the mutex
|
||||
sessionMutex: sync.RWMutex{}, // Initialize the session mutex
|
||||
dirty: false, // Initialize dirty flag
|
||||
@@ -280,10 +281,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
for k := range sessionData.refreshTokenChunks {
|
||||
delete(sessionData.refreshTokenChunks, k)
|
||||
}
|
||||
for k := range sessionData.idTokenChunks {
|
||||
delete(sessionData.idTokenChunks, k)
|
||||
}
|
||||
|
||||
// Retrieve chunked token sessions.
|
||||
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
|
||||
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
|
||||
sm.getTokenChunkSessions(r, mainCookieName, sessionData.idTokenChunks)
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
@@ -335,6 +340,10 @@ type SessionData struct {
|
||||
// when it exceeds the maximum cookie size.
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
|
||||
// idTokenChunks stores additional chunks of the ID token
|
||||
// when it exceeds the maximum cookie size.
|
||||
idTokenChunks map[int]*sessions.Session
|
||||
|
||||
// refreshMutex protects refresh token operations within this session instance.
|
||||
refreshMutex sync.Mutex
|
||||
|
||||
@@ -420,6 +429,12 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i))
|
||||
}
|
||||
|
||||
// Save ID token chunks.
|
||||
for i, sessionChunk := range sd.idTokenChunks {
|
||||
sessionChunk.Options = options
|
||||
saveOrLogError(sessionChunk, fmt.Sprintf("ID token chunk %d", i))
|
||||
}
|
||||
|
||||
if firstErr == nil {
|
||||
sd.dirty = false // Reset dirty flag only if all saves were successful
|
||||
}
|
||||
@@ -467,6 +482,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear chunk sessions.
|
||||
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
||||
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
||||
sd.clearTokenChunks(r, sd.idTokenChunks)
|
||||
|
||||
// Create a guaranteed error when the response writer is set
|
||||
// This is primarily for testing - in production w will often be nil
|
||||
@@ -648,6 +664,9 @@ func (sd *SessionData) Reset() {
|
||||
for k := range sd.refreshTokenChunks {
|
||||
delete(sd.refreshTokenChunks, k)
|
||||
}
|
||||
for k := range sd.idTokenChunks {
|
||||
delete(sd.idTokenChunks, k)
|
||||
}
|
||||
|
||||
// Reset state flags
|
||||
sd.dirty = false
|
||||
@@ -926,6 +945,30 @@ func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// expireIDTokenChunks finds all existing ID token chunk cookies (_oidc_raczylo_N)
|
||||
// associated with the current request, clears their values, and sets their MaxAge to -1.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
|
||||
// the expiring Set-Cookie headers. This is used internally when setting a new ID token.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
func (sd *SessionData) expireIDTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", mainCookieName, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
if w != nil {
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("failed to save expired ID token cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks divides a string `s` into a slice of strings, where each element
|
||||
// has a maximum length of `chunkSize`.
|
||||
//
|
||||
@@ -1077,6 +1120,14 @@ func (sd *SessionData) SetIncomingPath(path string) {
|
||||
// Returns:
|
||||
// - The complete, decompressed ID token string, or an empty string if not found.
|
||||
func (sd *SessionData) GetIDToken() string {
|
||||
sd.sessionMutex.RLock()
|
||||
defer sd.sessionMutex.RUnlock()
|
||||
|
||||
return sd.getIDTokenUnsafe()
|
||||
}
|
||||
|
||||
// getIDTokenUnsafe is the internal implementation without mutex protection
|
||||
func (sd *SessionData) getIDTokenUnsafe() string {
|
||||
token, _ := sd.mainSession.Values["id_token"].(string)
|
||||
if token != "" {
|
||||
compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool)
|
||||
@@ -1085,33 +1136,97 @@ func (sd *SessionData) GetIDToken() string {
|
||||
}
|
||||
return token
|
||||
}
|
||||
return ""
|
||||
|
||||
// Reassemble token from chunks.
|
||||
if len(sd.idTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.idTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["id_token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
token = strings.Join(chunks, "")
|
||||
compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// SetIDToken stores the provided ID token in the session.
|
||||
// It first expires any existing ID token chunk cookies.
|
||||
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
|
||||
// it's stored directly in the primary main session. Otherwise, the compressed token
|
||||
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_0, _oidc_raczylo_1, etc.).
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The ID token string to store.
|
||||
func (sd *SessionData) SetIDToken(token string) {
|
||||
currentIDToken := sd.GetIDToken() // Gets fully reassembled, decompressed token
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
currentIDToken := sd.getIDTokenUnsafe()
|
||||
if currentIDToken == token {
|
||||
// This handles cases where token is "" and currentIDToken is also "", no change.
|
||||
// Or token is "abc" and currentIDToken is "abc", no change.
|
||||
// If token is empty, and current is also empty, it's not a change.
|
||||
// This check handles both empty and non-empty identical cases.
|
||||
return
|
||||
}
|
||||
sd.dirty = true
|
||||
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
sd.expireIDTokenChunks(nil) // Will be saved when Save() is called.
|
||||
}
|
||||
|
||||
// Clear and prepare chunks map for new token.
|
||||
sd.idTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
if token == "" { // Clearing the token
|
||||
// STABILITY FIX: Add nil checks before accessing session values
|
||||
if sd.mainSession != nil {
|
||||
sd.mainSession.Values["id_token"] = ""
|
||||
sd.mainSession.Values["id_token_compressed"] = false
|
||||
}
|
||||
// sd.idTokenChunks is already cleared
|
||||
return
|
||||
}
|
||||
|
||||
sd.dirty = true // Mark as dirty because a change is being made
|
||||
|
||||
if token == "" {
|
||||
sd.mainSession.Values["id_token"] = ""
|
||||
sd.mainSession.Values["id_token_compressed"] = false
|
||||
return
|
||||
}
|
||||
|
||||
// Compress token
|
||||
// Compress token.
|
||||
compressed := compressToken(token)
|
||||
sd.mainSession.Values["id_token"] = compressed
|
||||
sd.mainSession.Values["id_token_compressed"] = true
|
||||
|
||||
if len(compressed) <= maxCookieSize {
|
||||
// STABILITY FIX: Add nil checks before accessing session values
|
||||
if sd.mainSession != nil {
|
||||
sd.mainSession.Values["id_token"] = compressed
|
||||
sd.mainSession.Values["id_token_compressed"] = true
|
||||
}
|
||||
} else {
|
||||
// Split compressed token into chunks.
|
||||
if sd.mainSession != nil {
|
||||
sd.mainSession.Values["id_token"] = "" // Main cookie won't hold the token directly
|
||||
sd.mainSession.Values["id_token_compressed"] = true // Data in chunks is compressed
|
||||
}
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
for i, chunkData := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", mainCookieName, i)
|
||||
// Ensure sd.request is available, otherwise log warning or handle error
|
||||
if sd.request == nil {
|
||||
sd.manager.logger.Infof("SetIDToken: sd.request is nil, cannot get/create chunk session %s", sessionName)
|
||||
// Potentially skip this chunk or error out, depending on desired robustness
|
||||
continue
|
||||
}
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["id_token_chunk"] = chunkData
|
||||
sd.idTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRedirectCount retrieves the current redirect count from the session.
|
||||
|
||||
+164
@@ -1,6 +1,9 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
@@ -218,4 +221,165 @@ func TestSessionObjectTracking(t *testing.T) {
|
||||
t.Log("Session pool handling verified")
|
||||
}
|
||||
|
||||
// TestLargeIDTokenChunking tests that large ID tokens are properly chunked across multiple cookies
|
||||
func TestLargeIDTokenChunking(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
// Create a large ID token (>4KB) to force chunking
|
||||
largeIDToken := createLargeIDToken(20000) // 20KB token to ensure chunking after compression
|
||||
t.Logf("Created large ID token with length: %d", len(largeIDToken))
|
||||
|
||||
// Create a request and response recorder
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get session and set large ID token
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set the large ID token
|
||||
session.SetIDToken(largeIDToken)
|
||||
t.Logf("Set large ID token in session")
|
||||
|
||||
// Let's check what the GetIDToken returns to confirm it's set
|
||||
retrievedToken := session.GetIDToken()
|
||||
t.Logf("Retrieved ID token length: %d", len(retrievedToken))
|
||||
if len(retrievedToken) != len(largeIDToken) {
|
||||
t.Errorf("Token length mismatch: expected %d, got %d", len(largeIDToken), len(retrievedToken))
|
||||
}
|
||||
|
||||
// Let's check what's in the main session directly
|
||||
if idToken, ok := session.mainSession.Values["id_token"].(string); ok {
|
||||
t.Logf("Main session id_token length: %d", len(idToken))
|
||||
if compressed, ok := session.mainSession.Values["id_token_compressed"].(bool); ok {
|
||||
t.Logf("Main session id_token_compressed: %v", compressed)
|
||||
}
|
||||
} else {
|
||||
t.Logf("Main session id_token not found or not a string")
|
||||
}
|
||||
|
||||
// Save the session to trigger chunking
|
||||
err = session.Save(req, rr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify that chunked cookies were created
|
||||
cookies := rr.Result().Cookies()
|
||||
t.Logf("Total cookies in response: %d", len(cookies))
|
||||
|
||||
for _, cookie := range cookies {
|
||||
valuePreview := cookie.Value
|
||||
if len(valuePreview) > 50 {
|
||||
valuePreview = valuePreview[:50] + "..."
|
||||
}
|
||||
t.Logf("Cookie: %s = %s (len=%d)", cookie.Name, valuePreview, len(cookie.Value))
|
||||
}
|
||||
|
||||
var mainCookie *http.Cookie
|
||||
var chunkCookies []*http.Cookie
|
||||
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == mainCookieName {
|
||||
mainCookie = cookie
|
||||
} else if strings.HasPrefix(cookie.Name, mainCookieName+"_") {
|
||||
chunkCookies = append(chunkCookies, cookie)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify main cookie exists
|
||||
if mainCookie == nil {
|
||||
t.Fatal("Main cookie not found in response")
|
||||
}
|
||||
|
||||
// Verify chunk cookies exist (should be at least 2 for a 5KB token)
|
||||
if len(chunkCookies) < 2 {
|
||||
t.Fatalf("Expected at least 2 chunk cookies, got %d", len(chunkCookies))
|
||||
}
|
||||
|
||||
// Verify chunk cookie naming convention
|
||||
expectedChunkNames := make(map[string]bool)
|
||||
for i := 0; i < len(chunkCookies); i++ {
|
||||
expectedChunkNames[mainCookieName+"_"+fmt.Sprintf("%d", i)] = true
|
||||
}
|
||||
|
||||
for _, cookie := range chunkCookies {
|
||||
if !expectedChunkNames[cookie.Name] {
|
||||
t.Errorf("Unexpected chunk cookie name: %s", cookie.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Test token retrieval from chunked cookies
|
||||
// Create a new request with all the cookies
|
||||
newReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get session and retrieve the ID token
|
||||
retrievedSession, err := sm.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session from chunked cookies: %v", err)
|
||||
}
|
||||
|
||||
retrievedToken2 := retrievedSession.GetIDToken()
|
||||
|
||||
// Verify the retrieved token matches the original
|
||||
if retrievedToken2 != largeIDToken {
|
||||
t.Errorf("Retrieved ID token doesn't match original. Expected length: %d, got: %d", len(largeIDToken), len(retrievedToken2))
|
||||
}
|
||||
|
||||
// Test clearing the ID token removes all chunks
|
||||
retrievedSession.SetIDToken("")
|
||||
|
||||
clearRR := httptest.NewRecorder()
|
||||
err = retrievedSession.Save(newReq, clearRR)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save session after clearing ID token: %v", err)
|
||||
}
|
||||
|
||||
// Verify chunks are expired (MaxAge = -1)
|
||||
clearCookies := clearRR.Result().Cookies()
|
||||
for _, cookie := range clearCookies {
|
||||
if strings.HasPrefix(cookie.Name, mainCookieName+"_") {
|
||||
if cookie.MaxAge != -1 {
|
||||
t.Errorf("Expected chunk cookie %s to be expired (MaxAge=-1), got MaxAge=%d", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createLargeIDToken creates a JWT-like token of specified size for testing
|
||||
func createLargeIDToken(size int) string {
|
||||
// Create truly random data that won't compress well
|
||||
randomBytes := make([]byte, size*3/4) // base64 encoding increases size by ~4/3
|
||||
_, err := rand.Read(randomBytes)
|
||||
if err != nil {
|
||||
// Fallback to pseudo-random if crypto/rand fails
|
||||
for i := range randomBytes {
|
||||
randomBytes[i] = byte(i % 256)
|
||||
}
|
||||
}
|
||||
|
||||
// Base64 encode the random data to make it look like a JWT
|
||||
encoded := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
|
||||
// Create JWT-like structure with truly random data
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
|
||||
// Truncate or pad to desired size
|
||||
if len(encoded) > size-len(header)-100 {
|
||||
encoded = encoded[:size-len(header)-100]
|
||||
}
|
||||
|
||||
signature := "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
|
||||
|
||||
return header + "." + encoded + "." + signature
|
||||
}
|
||||
|
||||
// This is intentionally left empty to remove unused code
|
||||
|
||||
Reference in New Issue
Block a user