mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Fix remaining issues with session handling and add additional tests.
This commit is contained in:
@@ -2,7 +2,6 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -182,11 +181,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
|
||||
// Generate default session encryption key if not provided
|
||||
if config.SessionEncryptionKey == "" {
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session encryption key: %w", err)
|
||||
}
|
||||
config.SessionEncryptionKey = fmt.Sprintf("%x", key) // Convert to hex string
|
||||
// Generate a fixed key for Traefik Hub testing
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
}
|
||||
|
||||
// Ensure key meets minimum length requirement
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
|
||||
// Setup HTTP client
|
||||
|
||||
+20
-6
@@ -416,19 +416,26 @@ func (sd *SessionData) GetAccessToken() string {
|
||||
// multiple cookie chunks to handle large tokens while staying within
|
||||
// browser cookie size limits. Any existing token or chunks are cleared
|
||||
// before setting the new token.
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Expire any existing chunk cookies
|
||||
// expireAccessTokenChunks expires any existing access token chunk cookies
|
||||
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
// Expire the cookie
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
// Save expired cookie
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear existing chunks from memory
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Clear and prepare chunks map for new token
|
||||
sd.accessTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
// Compress token
|
||||
@@ -493,19 +500,26 @@ func (sd *SessionData) GetRefreshToken() string {
|
||||
// multiple cookie chunks to handle large tokens while staying within
|
||||
// browser cookie size limits. Any existing token or chunks are cleared
|
||||
// before setting the new token.
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Expire any existing chunk cookies
|
||||
// expireRefreshTokenChunks expires any existing refresh token chunk cookies
|
||||
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
// Expire the cookie
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
// Save expired cookie
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear existing chunks from memory
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Clear and prepare chunks map for new token
|
||||
sd.refreshTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
// Compress token
|
||||
|
||||
+120
@@ -77,6 +77,120 @@ func TestTokenCompression(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestSessionManager tests the SessionManager functionality
|
||||
|
||||
func TestCookiePrefix(t *testing.T) {
|
||||
// Create a session and verify cookie names
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set some data to ensure cookies are created
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Expire any existing cookies
|
||||
session.expireAccessTokenChunks(rr)
|
||||
session.expireRefreshTokenChunks(rr)
|
||||
|
||||
// Set new tokens
|
||||
session.SetAccessToken("test_token")
|
||||
session.SetRefreshToken("test_refresh_token")
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Check cookie prefixes
|
||||
cookies := rr.Result().Cookies()
|
||||
for _, cookie := range cookies {
|
||||
if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") {
|
||||
t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRefreshCleanup(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set a large token that will be split into chunks
|
||||
largeToken := strings.Repeat("x", 5000)
|
||||
session.SetAccessToken(largeToken)
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get initial cookies
|
||||
initialCookies := rr.Result().Cookies()
|
||||
|
||||
// Create a new request with the initial cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range initialCookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
newRr := httptest.NewRecorder()
|
||||
|
||||
// Get session with cookies and set a new token
|
||||
newSession, err := sm.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
// Create a response recorder for expired cookies
|
||||
expiredRr := httptest.NewRecorder()
|
||||
|
||||
// Expire old chunk cookies
|
||||
newSession.expireAccessTokenChunks(expiredRr)
|
||||
|
||||
// Set a smaller token that won't need chunks
|
||||
newSession.SetAccessToken("small_token")
|
||||
|
||||
// Save session with new token
|
||||
if err := newSession.Save(newReq, newRr); err != nil {
|
||||
t.Fatalf("Failed to save new session: %v", err)
|
||||
}
|
||||
|
||||
// Check cookies in response where old cookies are expired
|
||||
intermediateResponse := expiredRr.Result()
|
||||
intermediateCount := 0
|
||||
chunkCount := 0
|
||||
expiredCount := 0
|
||||
|
||||
for _, cookie := range intermediateResponse.Cookies() {
|
||||
if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 {
|
||||
chunkCount++
|
||||
if cookie.MaxAge < 0 {
|
||||
expiredCount++
|
||||
t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
} else if cookie.MaxAge >= 0 {
|
||||
intermediateCount++
|
||||
t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
}
|
||||
|
||||
// All chunk cookies should be expired
|
||||
if chunkCount > 0 && chunkCount != expiredCount {
|
||||
t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount)
|
||||
}
|
||||
|
||||
// Should have fewer active cookies after setting smaller token
|
||||
if intermediateCount >= len(initialCookies) {
|
||||
t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManager(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
@@ -151,6 +265,12 @@ func TestSessionManager(t *testing.T) {
|
||||
// Set session values
|
||||
session.SetAuthenticated(tc.authenticated)
|
||||
session.SetEmail(tc.email)
|
||||
|
||||
// Expire any existing cookies
|
||||
session.expireAccessTokenChunks(rr)
|
||||
session.expireRefreshTokenChunks(rr)
|
||||
|
||||
// Set new tokens
|
||||
session.SetAccessToken(tc.accessToken)
|
||||
session.SetRefreshToken(tc.refreshToken)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user