From a462e4489624ed4fdbb07e705e78a6d0210cf817 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Tue, 21 Jan 2025 00:18:10 +0000 Subject: [PATCH] Fix remaining issues with session handling and add additional tests. --- main.go | 13 +++--- session.go | 26 ++++++++--- session_test.go | 120 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 12 deletions(-) diff --git a/main.go b/main.go index 85acf09..ec18114 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package traefikoidc import ( "context" - "crypto/rand" "encoding/json" "fmt" "io" @@ -182,11 +181,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h // Generate default session encryption key if not provided if config.SessionEncryptionKey == "" { - key := make([]byte, 32) - if _, err := rand.Read(key); err != nil { - return nil, fmt.Errorf("failed to generate session encryption key: %w", err) - } - config.SessionEncryptionKey = fmt.Sprintf("%x", key) // Convert to hex string + // Generate a fixed key for Traefik Hub testing + config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + } + + // Ensure key meets minimum length requirement + if len(config.SessionEncryptionKey) < minEncryptionKeyLength { + return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength) } // Setup HTTP client diff --git a/session.go b/session.go index 1739290..14a7b4b 100644 --- a/session.go +++ b/session.go @@ -416,19 +416,26 @@ func (sd *SessionData) GetAccessToken() string { // multiple cookie chunks to handle large tokens while staying within // browser cookie size limits. Any existing token or chunks are cleared // before setting the new token. -func (sd *SessionData) SetAccessToken(token string) { - // Expire any existing chunk cookies +// expireAccessTokenChunks expires any existing access token chunk cookies +func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) { for i := 0; ; i++ { sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i) session, err := sd.manager.store.Get(sd.request, sessionName) if err != nil || session.IsNew { break } + // Expire the cookie session.Options.MaxAge = -1 session.Values = make(map[interface{}]interface{}) + // Save expired cookie + if err := session.Save(sd.request, w); err != nil { + sd.manager.logger.Errorf("Failed to save expired cookie: %v", err) + } } +} - // Clear existing chunks from memory +func (sd *SessionData) SetAccessToken(token string) { + // Clear and prepare chunks map for new token sd.accessTokenChunks = make(map[int]*sessions.Session) // Compress token @@ -493,19 +500,26 @@ func (sd *SessionData) GetRefreshToken() string { // multiple cookie chunks to handle large tokens while staying within // browser cookie size limits. Any existing token or chunks are cleared // before setting the new token. -func (sd *SessionData) SetRefreshToken(token string) { - // Expire any existing chunk cookies +// expireRefreshTokenChunks expires any existing refresh token chunk cookies +func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) { for i := 0; ; i++ { sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i) session, err := sd.manager.store.Get(sd.request, sessionName) if err != nil || session.IsNew { break } + // Expire the cookie session.Options.MaxAge = -1 session.Values = make(map[interface{}]interface{}) + // Save expired cookie + if err := session.Save(sd.request, w); err != nil { + sd.manager.logger.Errorf("Failed to save expired cookie: %v", err) + } } +} - // Clear existing chunks from memory +func (sd *SessionData) SetRefreshToken(token string) { + // Clear and prepare chunks map for new token sd.refreshTokenChunks = make(map[int]*sessions.Session) // Compress token diff --git a/session_test.go b/session_test.go index 46f4247..f5e35a9 100644 --- a/session_test.go +++ b/session_test.go @@ -77,6 +77,120 @@ func TestTokenCompression(t *testing.T) { } // TestSessionManager tests the SessionManager functionality + +func TestCookiePrefix(t *testing.T) { + // Create a session and verify cookie names + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug")) + session, err := sm.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Set some data to ensure cookies are created + session.SetAuthenticated(true) + + // Expire any existing cookies + session.expireAccessTokenChunks(rr) + session.expireRefreshTokenChunks(rr) + + // Set new tokens + session.SetAccessToken("test_token") + session.SetRefreshToken("test_refresh_token") + + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Check cookie prefixes + cookies := rr.Result().Cookies() + for _, cookie := range cookies { + if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") { + t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name) + } + } +} + +func TestTokenRefreshCleanup(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug")) + session, err := sm.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Set a large token that will be split into chunks + largeToken := strings.Repeat("x", 5000) + session.SetAccessToken(largeToken) + + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Get initial cookies + initialCookies := rr.Result().Cookies() + + // Create a new request with the initial cookies + newReq := httptest.NewRequest("GET", "/test", nil) + for _, cookie := range initialCookies { + newReq.AddCookie(cookie) + } + newRr := httptest.NewRecorder() + + // Get session with cookies and set a new token + newSession, err := sm.GetSession(newReq) + if err != nil { + t.Fatalf("Failed to get new session: %v", err) + } + + // Create a response recorder for expired cookies + expiredRr := httptest.NewRecorder() + + // Expire old chunk cookies + newSession.expireAccessTokenChunks(expiredRr) + + // Set a smaller token that won't need chunks + newSession.SetAccessToken("small_token") + + // Save session with new token + if err := newSession.Save(newReq, newRr); err != nil { + t.Fatalf("Failed to save new session: %v", err) + } + + // Check cookies in response where old cookies are expired + intermediateResponse := expiredRr.Result() + intermediateCount := 0 + chunkCount := 0 + expiredCount := 0 + + for _, cookie := range intermediateResponse.Cookies() { + if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 { + chunkCount++ + if cookie.MaxAge < 0 { + expiredCount++ + t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge) + } + } else if cookie.MaxAge >= 0 { + intermediateCount++ + t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge) + } + } + + // All chunk cookies should be expired + if chunkCount > 0 && chunkCount != expiredCount { + t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount) + } + + // Should have fewer active cookies after setting smaller token + if intermediateCount >= len(initialCookies) { + t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies)) + } +} + func TestSessionManager(t *testing.T) { ts := &TestSuite{t: t} ts.Setup() @@ -151,6 +265,12 @@ func TestSessionManager(t *testing.T) { // Set session values session.SetAuthenticated(tc.authenticated) session.SetEmail(tc.email) + + // Expire any existing cookies + session.expireAccessTokenChunks(rr) + session.expireRefreshTokenChunks(rr) + + // Set new tokens session.SetAccessToken(tc.accessToken) session.SetRefreshToken(tc.refreshToken)