diff --git a/session.go b/session.go index d32ec12..99524cc 100644 --- a/session.go +++ b/session.go @@ -1,9 +1,14 @@ package traefikoidc import ( + "bytes" + "compress/gzip" + "encoding/base64" "fmt" + "io" "net/http" "strings" + "sync" "github.com/gorilla/sessions" ) @@ -36,6 +41,40 @@ const ( maxCookieSize = 2000 ) +// compressToken compresses a token using gzip and base64 encodes it +func compressToken(token string) string { + var b bytes.Buffer + gz := gzip.NewWriter(&b) + if _, err := gz.Write([]byte(token)); err != nil { + return token // fallback to uncompressed on error + } + if err := gz.Close(); err != nil { + return token + } + return base64.StdEncoding.EncodeToString(b.Bytes()) +} + +// decompressToken decompresses a base64 encoded gzipped token +func decompressToken(compressed string) string { + data, err := base64.StdEncoding.DecodeString(compressed) + if err != nil { + return compressed // return as-is if not base64 + } + + gz, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return compressed + } + defer gz.Close() + + decompressed, err := io.ReadAll(gz) + if err != nil { + return compressed + } + + return string(decompressed) +} + // SessionManager handles the management of multiple session cookies for OIDC authentication. // It provides functionality for storing and retrieving authentication state, tokens, // and other session-related data across multiple cookies to handle large tokens. @@ -48,6 +87,9 @@ type SessionManager struct { // logger provides structured logging capabilities logger *Logger + + // sessionPool is a sync.Pool for reusing SessionData objects + sessionPool sync.Pool } // NewSessionManager creates a new session manager with the specified configuration. @@ -57,11 +99,22 @@ type SessionManager struct { // - logger: Logger instance for recording session-related events // The manager handles session creation, storage, and cookie security settings. func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager { - return &SessionManager{ + sm := &SessionManager{ store: sessions.NewCookieStore([]byte(encryptionKey)), forceHTTPS: forceHTTPS, logger: logger, } + + // Initialize session pool + sm.sessionPool.New = func() interface{} { + return &SessionData{ + manager: sm, + accessTokenChunks: make(map[int]*sessions.Session), + refreshTokenChunks: make(map[int]*sessions.Session), + } + } + + return sm } // getSessionOptions returns secure session options configured for the current request. @@ -87,33 +140,40 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options { // and combines them into a single SessionData structure for easy access. // Returns an error if any session component cannot be loaded. func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { - mainSession, err := sm.store.Get(r, mainCookieName) + // Get session from pool + sessionData := sm.sessionPool.Get().(*SessionData) + sessionData.request = r + + var err error + sessionData.mainSession, err = sm.store.Get(r, mainCookieName) if err != nil { + sm.sessionPool.Put(sessionData) return nil, fmt.Errorf("failed to get main session: %w", err) } - accessSession, err := sm.store.Get(r, accessTokenCookie) + sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie) if err != nil { + sm.sessionPool.Put(sessionData) return nil, fmt.Errorf("failed to get access token session: %w", err) } - refreshSession, err := sm.store.Get(r, refreshTokenCookie) + sessionData.refreshSession, err = sm.store.Get(r, refreshTokenCookie) if err != nil { + sm.sessionPool.Put(sessionData) return nil, fmt.Errorf("failed to get refresh token session: %w", err) } - sessionData := &SessionData{ - manager: sm, - request: r, - mainSession: mainSession, - accessSession: accessSession, - refreshSession: refreshSession, + // Clear and reuse chunk maps + for k := range sessionData.accessTokenChunks { + delete(sessionData.accessTokenChunks, k) + } + for k := range sessionData.refreshTokenChunks { + delete(sessionData.refreshTokenChunks, k) } - // Retrieve chunked access token sessions - sessionData.accessTokenChunks = sm.getTokenChunkSessions(r, accessTokenCookie) - // Retrieve chunked refresh token sessions - sessionData.refreshTokenChunks = sm.getTokenChunkSessions(r, refreshTokenCookie) + // Retrieve chunked token sessions + sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks) + sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks) return sessionData, nil } @@ -122,10 +182,8 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { // Parameters: // - r: The HTTP request // - baseName: The base name for the token's session cookies -// Returns a map of chunk index to session, used for handling large tokens -// that exceed single cookie size limits. -func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string) map[int]*sessions.Session { - chunks := make(map[int]*sessions.Session) +// - chunks: Map to store the chunks in +func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) { for i := 0; ; i++ { sessionName := fmt.Sprintf("%s_%d", baseName, i) session, err := sm.store.Get(r, sessionName) @@ -135,7 +193,6 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string } chunks[i] = session } - return chunks } // SessionData holds all session information for an authenticated user. @@ -237,7 +294,12 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { sd.clearTokenChunks(r, sd.accessTokenChunks) sd.clearTokenChunks(r, sd.refreshTokenChunks) - return sd.Save(r, w) + err := sd.Save(r, w) + + // Return session to pool + sd.manager.sessionPool.Put(sd) + + return err } // clearTokenChunks removes all session chunks for a given token type. @@ -273,6 +335,10 @@ func (sd *SessionData) SetAuthenticated(value bool) { func (sd *SessionData) GetAccessToken() string { token, _ := sd.accessSession.Values["token"].(string) if token != "" { + compressed, _ := sd.accessSession.Values["compressed"].(bool) + if compressed { + return decompressToken(token) + } return token } @@ -291,11 +357,16 @@ func (sd *SessionData) GetAccessToken() string { chunks = append(chunks, chunk) } - return strings.Join(chunks, "") + token = strings.Join(chunks, "") + compressed, _ := sd.accessSession.Values["compressed"].(bool) + if compressed { + return decompressToken(token) + } + return token } // SetAccessToken stores the access token in the session. -// If the token exceeds maxCookieSize, it is automatically split into +// If the token exceeds maxCookieSize, it is automatically compressed and split into // multiple cookie chunks to handle large tokens while staying within // browser cookie size limits. Any existing token or chunks are cleared // before setting the new token. @@ -304,12 +375,17 @@ func (sd *SessionData) SetAccessToken(token string) { sd.clearTokenChunks(sd.request, sd.accessTokenChunks) sd.accessTokenChunks = make(map[int]*sessions.Session) - if len(token) <= maxCookieSize { - sd.accessSession.Values["token"] = token + // Compress token + compressed := compressToken(token) + + if len(compressed) <= maxCookieSize { + sd.accessSession.Values["token"] = compressed + sd.accessSession.Values["compressed"] = true } else { - // Split token into chunks + // Split compressed token into chunks sd.accessSession.Values["token"] = "" - chunks := splitIntoChunks(token, maxCookieSize) + sd.accessSession.Values["compressed"] = true + chunks := splitIntoChunks(compressed, maxCookieSize) for i, chunk := range chunks { sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i) session, _ := sd.manager.store.Get(sd.request, sessionName) @@ -326,6 +402,10 @@ func (sd *SessionData) SetAccessToken(token string) { func (sd *SessionData) GetRefreshToken() string { token, _ := sd.refreshSession.Values["token"].(string) if token != "" { + compressed, _ := sd.refreshSession.Values["compressed"].(bool) + if compressed { + return decompressToken(token) + } return token } @@ -344,11 +424,16 @@ func (sd *SessionData) GetRefreshToken() string { chunks = append(chunks, chunk) } - return strings.Join(chunks, "") + token = strings.Join(chunks, "") + compressed, _ := sd.refreshSession.Values["compressed"].(bool) + if compressed { + return decompressToken(token) + } + return token } // SetRefreshToken stores the refresh token in the session. -// If the token exceeds maxCookieSize, it is automatically split into +// If the token exceeds maxCookieSize, it is automatically compressed and split into // multiple cookie chunks to handle large tokens while staying within // browser cookie size limits. Any existing token or chunks are cleared // before setting the new token. @@ -357,12 +442,17 @@ func (sd *SessionData) SetRefreshToken(token string) { sd.clearTokenChunks(sd.request, sd.refreshTokenChunks) sd.refreshTokenChunks = make(map[int]*sessions.Session) - if len(token) <= maxCookieSize { - sd.refreshSession.Values["token"] = token + // Compress token + compressed := compressToken(token) + + if len(compressed) <= maxCookieSize { + sd.refreshSession.Values["token"] = compressed + sd.refreshSession.Values["compressed"] = true } else { - // Split token into chunks + // Split compressed token into chunks sd.refreshSession.Values["token"] = "" - chunks := splitIntoChunks(token, maxCookieSize) + sd.refreshSession.Values["compressed"] = true + chunks := splitIntoChunks(compressed, maxCookieSize) for i, chunk := range chunks { sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i) session, _ := sd.manager.store.Get(sd.request, sessionName) diff --git a/session_test.go b/session_test.go index 066ce6d..6b795ff 100644 --- a/session_test.go +++ b/session_test.go @@ -113,17 +113,31 @@ func TestSessionManager(t *testing.T) { func calculateExpectedCookieCount(accessToken, refreshToken string) int { count := 3 // main, access, refresh - // Calculate number of chunks for access token - accessChunks := len(splitIntoChunks(accessToken, maxCookieSize)) - if accessChunks > 1 { - count += accessChunks + // Helper to calculate chunks for compressed token + calculateChunks := func(token string) int { + // Compress token (matching the actual implementation) + compressed := compressToken(token) + + // If compressed token fits in one cookie, no additional chunks needed + if len(compressed) <= maxCookieSize { + return 0 + } + + // Calculate chunks needed for compressed token + return len(splitIntoChunks(compressed, maxCookieSize)) } - // Calculate number of chunks for refresh token - refreshChunks := len(splitIntoChunks(refreshToken, maxCookieSize)) - if refreshChunks > 1 { - count += refreshChunks + // Add chunks for access token if needed + accessChunks := calculateChunks(accessToken) + if accessChunks > 0 { + count += accessChunks + } + + // Add chunks for refresh token if needed + refreshChunks := calculateChunks(refreshToken) + if refreshChunks > 0 { + count += refreshChunks } return count -} \ No newline at end of file +}