Fix remaining issues with session handling and add additional tests.

This commit is contained in:
2025-01-21 00:18:10 +00:00
parent 5eff0dc866
commit a462e44896
3 changed files with 147 additions and 12 deletions
+7 -6
View File
@@ -2,7 +2,6 @@ package traefikoidc
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"io"
@@ -182,11 +181,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// Generate default session encryption key if not provided
if config.SessionEncryptionKey == "" {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("failed to generate session encryption key: %w", err)
}
config.SessionEncryptionKey = fmt.Sprintf("%x", key) // Convert to hex string
// Generate a fixed key for Traefik Hub testing
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}
// Ensure key meets minimum length requirement
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
// Setup HTTP client
+20 -6
View File
@@ -416,19 +416,26 @@ func (sd *SessionData) GetAccessToken() string {
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetAccessToken(token string) {
// Expire any existing chunk cookies
// expireAccessTokenChunks expires any existing access token chunk cookies
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
// Expire the cookie
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
// Save expired cookie
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
}
}
}
// Clear existing chunks from memory
func (sd *SessionData) SetAccessToken(token string) {
// Clear and prepare chunks map for new token
sd.accessTokenChunks = make(map[int]*sessions.Session)
// Compress token
@@ -493,19 +500,26 @@ func (sd *SessionData) GetRefreshToken() string {
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetRefreshToken(token string) {
// Expire any existing chunk cookies
// expireRefreshTokenChunks expires any existing refresh token chunk cookies
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
// Expire the cookie
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
// Save expired cookie
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
}
}
}
// Clear existing chunks from memory
func (sd *SessionData) SetRefreshToken(token string) {
// Clear and prepare chunks map for new token
sd.refreshTokenChunks = make(map[int]*sessions.Session)
// Compress token
+120
View File
@@ -77,6 +77,120 @@ func TestTokenCompression(t *testing.T) {
}
// TestSessionManager tests the SessionManager functionality
func TestCookiePrefix(t *testing.T) {
// Create a session and verify cookie names
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set some data to ensure cookies are created
session.SetAuthenticated(true)
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Set new tokens
session.SetAccessToken("test_token")
session.SetRefreshToken("test_refresh_token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Check cookie prefixes
cookies := rr.Result().Cookies()
for _, cookie := range cookies {
if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") {
t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name)
}
}
}
func TestTokenRefreshCleanup(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set a large token that will be split into chunks
largeToken := strings.Repeat("x", 5000)
session.SetAccessToken(largeToken)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get initial cookies
initialCookies := rr.Result().Cookies()
// Create a new request with the initial cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range initialCookies {
newReq.AddCookie(cookie)
}
newRr := httptest.NewRecorder()
// Get session with cookies and set a new token
newSession, err := sm.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Create a response recorder for expired cookies
expiredRr := httptest.NewRecorder()
// Expire old chunk cookies
newSession.expireAccessTokenChunks(expiredRr)
// Set a smaller token that won't need chunks
newSession.SetAccessToken("small_token")
// Save session with new token
if err := newSession.Save(newReq, newRr); err != nil {
t.Fatalf("Failed to save new session: %v", err)
}
// Check cookies in response where old cookies are expired
intermediateResponse := expiredRr.Result()
intermediateCount := 0
chunkCount := 0
expiredCount := 0
for _, cookie := range intermediateResponse.Cookies() {
if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 {
chunkCount++
if cookie.MaxAge < 0 {
expiredCount++
t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
} else if cookie.MaxAge >= 0 {
intermediateCount++
t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
}
// All chunk cookies should be expired
if chunkCount > 0 && chunkCount != expiredCount {
t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount)
}
// Should have fewer active cookies after setting smaller token
if intermediateCount >= len(initialCookies) {
t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies))
}
}
func TestSessionManager(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
@@ -151,6 +265,12 @@ func TestSessionManager(t *testing.T) {
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Set new tokens
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)