Files
traefikoidc/session_test.go
T
lukaszraczylo 6efb78b7a8 Smarter approach to the cookies (#103)
* Smarter approach to the cookies

  - Single maxCookieSize = 1400 constant with clear documentation
  - Combined cookie storage for ~40-45% size reduction
  - Backward compatible migration from legacy cookies

* Tuneup the code.
2025-12-12 18:35:06 +00:00

3279 lines
94 KiB
Go

package traefikoidc
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/gorilla/sessions"
"github.com/stretchr/testify/assert"
)
// ============================================================================
// SESSION TEST FRAMEWORK
// ============================================================================
// SessionTestCase describes a single session test scenario as a set of
// optional lifecycle hooks plus descriptive metadata.
//
// NOTE(review): nothing in the visible part of this file consumes this type,
// so the roles below are inferred from field names and signatures — confirm
// against whichever runner uses it.
type SessionTestCase struct {
	// setup presumably prepares framework state before execute runs.
	setup func(*SessionTestFramework)
	// execute runs the operation under test and returns its error.
	execute func(*SessionTestFramework) error
	// validate receives the test handle, an error (presumably execute's
	// result) and the framework.
	validate func(*testing.T, error, *SessionTestFramework)
	// cleanup presumably releases whatever setup/execute created.
	cleanup     func(*SessionTestFramework)
	name        string
	scenario    string
	sessionType string
	// skipReason, when non-empty, presumably causes the case to be skipped.
	skipReason string
	// iterations is presumably how many times execute is repeated.
	iterations int
	timeout    time.Duration
	// concurrent presumably requests running iterations in parallel.
	concurrent bool
}
// SessionTestFramework bundles shared fixtures for session tests: a mock OIDC
// provider, generated test tokens, metric counters, configuration and a list
// of cleanup callbacks. Build it with NewSessionTestFramework and call Cleanup
// when finished.
type SessionTestFramework struct {
	t *testing.T
	// mockProvider is the httptest server started by setupMockProvider.
	mockProvider *httptest.Server
	// testTokens maps a token type ("access", "id", ...) to the most recently
	// generated token of that type; written under mu by generateTestToken.
	testTokens map[string]string
	// metrics holds counters; TokensGenerated is bumped with atomic.AddInt64.
	metrics *SessionTestMetrics
	config  *SessionTestConfig
	// requests, responses and sessionIDs start out empty.
	// NOTE(review): no code in this section appends to them.
	requests   []*http.Request
	responses  []*httptest.ResponseRecorder
	sessionIDs []string
	// cleanupFuncs are invoked by Cleanup.
	cleanupFuncs []func()
	// mu is taken when writing testTokens, which may happen from the mock
	// provider's handler goroutines.
	mu sync.RWMutex
}
// SessionTestMetrics holds counters collected while a SessionTestFramework is
// in use. TokensGenerated is incremented with atomic.AddInt64 by the token
// generators.
//
// NOTE(review): the other counters and Duration are not written by any code
// visible in this section — confirm where they are populated.
type SessionTestMetrics struct {
	SessionsCreated   int64
	SessionsDestroyed int64
	TokensGenerated   int64
	TokensValidated   int64
	ChunksCreated     int64
	ChunksRetrieved   int64
	ErrorCount        int64
	Duration          time.Duration
}
// SessionTestConfig holds tunables for a SessionTestFramework; the framework
// constructor fills in test defaults.
//
// NOTE(review): units and semantics below are inferred from names, e.g.
// MaxChunkSize is presumably bytes per cookie chunk — confirm with consumers.
type SessionTestConfig struct {
	CookieDomain      string
	EncryptionKey     string
	MaxChunkSize      int
	MaxSessions       int
	SessionTimeout    time.Duration
	EnableHTTPS       bool
	EnableCompression bool
}
// NewSessionTestFramework builds a ready-to-use SessionTestFramework for t.
// All bookkeeping collections start out empty (and non-nil), the configuration
// carries the test defaults, and a mock OIDC provider is already listening
// when the function returns. Call Cleanup afterwards to stop the provider.
func NewSessionTestFramework(t *testing.T) *SessionTestFramework {
	defaults := &SessionTestConfig{
		CookieDomain:      "",
		EncryptionKey:     generateTestKey(),
		MaxChunkSize:      3900,
		MaxSessions:       1000,
		SessionTimeout:    time.Hour,
		EnableHTTPS:       false,
		EnableCompression: true,
	}

	fw := &SessionTestFramework{
		t:            t,
		config:       defaults,
		metrics:      new(SessionTestMetrics),
		testTokens:   map[string]string{},
		requests:     []*http.Request{},
		responses:    []*httptest.ResponseRecorder{},
		sessionIDs:   []string{},
		cleanupFuncs: []func(){},
	}

	// The provider has to exist before any token is generated, because
	// generated tokens embed its URL as the issuer.
	fw.setupMockProvider()
	return fw
}
// generateTestKey returns a fixed 48-character key (the lowercase hex
// alphabet repeated three times). It is deterministic on purpose so tests are
// reproducible; it is not random key material.
func generateTestKey() string {
	const hexAlphabet = "0123456789abcdef"
	return strings.Repeat(hexAlphabet, 3)
}
// setupMockProvider starts an httptest server imitating a minimal OIDC
// provider and registers its Close in the framework's cleanup list.
//
// Endpoints:
//   - /.well-known/openid-configuration: discovery document pointing at this server
//   - /token: freshly generated access, ID and refresh tokens
//   - /userinfo: a fixed test user
//
// Any other path yields 404. JSON responses carry Content-Type
// application/json, which OIDC clients may check before decoding.
func (f *SessionTestFramework) setupMockProvider() {
	writeJSON := func(w http.ResponseWriter, body map[string]interface{}) {
		w.Header().Set("Content-Type", "application/json")
		// Encoding a map of strings/ints only fails if the client connection
		// is gone, in which case there is nobody left to report it to.
		_ = json.NewEncoder(w).Encode(body)
	}

	f.mockProvider = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// f.mockProvider is assigned as soon as NewServer returns, before any
		// client learns the URL, so reading it here is safe in practice.
		switch r.URL.Path {
		case "/.well-known/openid-configuration":
			writeJSON(w, map[string]interface{}{
				"issuer":                 f.mockProvider.URL,
				"authorization_endpoint": f.mockProvider.URL + "/auth",
				"token_endpoint":         f.mockProvider.URL + "/token",
				"userinfo_endpoint":      f.mockProvider.URL + "/userinfo",
				"jwks_uri":               f.mockProvider.URL + "/jwks",
			})
		case "/token":
			writeJSON(w, map[string]interface{}{
				"access_token":  f.generateTestToken("access", 3600),
				"id_token":      f.generateTestToken("id", 3600),
				"refresh_token": f.generateTestToken("refresh", 86400),
				"token_type":    "Bearer",
				"expires_in":    3600,
			})
		case "/userinfo":
			writeJSON(w, map[string]interface{}{
				"sub":   "test-user-id",
				"email": "test@example.com",
				"name":  "Test User",
			})
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	f.cleanupFuncs = append(f.cleanupFuncs, f.mockProvider.Close)
}
// generateTestToken builds a JWT-shaped token of the given type that expires
// expiresIn seconds from now. The issuer is the mock provider's URL and the
// signature is 64 random bytes, so the token cannot be cryptographically
// verified. The token is also recorded in f.testTokens under tokenType
// (overwriting any previous one) and counted in metrics.TokensGenerated.
// Safe for concurrent use: the mock provider calls it from handler goroutines.
func (f *SessionTestFramework) generateTestToken(tokenType string, expiresIn int) string {
	atomic.AddInt64(&f.metrics.TokensGenerated, 1)

	// Take a single timestamp so exp - iat is exactly expiresIn; two separate
	// time.Now() calls could straddle a second boundary.
	now := time.Now()

	header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
	claims := map[string]interface{}{
		"iss": f.mockProvider.URL,
		"sub": "test-user-id",
		"aud": "test-client-id",
		"exp": now.Add(time.Duration(expiresIn) * time.Second).Unix(),
		"iat": now.Unix(),
		"typ": tokenType,
	}
	// Marshalling a map of strings and int64s cannot fail.
	claimsJSON, _ := json.Marshal(claims)
	payload := base64.RawURLEncoding.EncodeToString(claimsJSON)

	// Random bytes stand in for an RS256 signature. Should Read ever fail,
	// the buffer is merely left (partly) zeroed, which is harmless here.
	signature := make([]byte, 64)
	_, _ = rand.Read(signature)
	sig := base64.RawURLEncoding.EncodeToString(signature)

	token := fmt.Sprintf("%s.%s.%s", header, payload, sig)

	f.mu.Lock()
	f.testTokens[tokenType] = token
	f.mu.Unlock()
	return token
}
// generateLargeToken returns a JWT-shaped token whose total length is size
// bytes (or size-1, due to base64 rounding), for exercising cookie chunking.
// If size is smaller than the fixed header/claims/signature overhead, the
// smallest possible token is returned instead. The signature is random and
// unverifiable. Each call counts toward metrics.TokensGenerated.
//
// The previous implementation base64-encoded its padding budget and then
// base64-encoded the claims again, yielding tokens roughly 1.8x the requested
// size. Here the padding is plain letters sized from the measured overhead.
func (f *SessionTestFramework) generateLargeToken(size int) string {
	atomic.AddInt64(&f.metrics.TokensGenerated, 1)

	header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))

	// Random bytes stand in for an RS256 signature; a Read failure would only
	// leave them zeroed, which is harmless for a test token.
	signature := make([]byte, 64)
	_, _ = rand.Read(signature)
	sig := base64.RawURLEncoding.EncodeToString(signature)

	now := time.Now()
	claims := map[string]interface{}{
		"iss":     f.mockProvider.URL,
		"sub":     "test-user-id",
		"aud":     "test-client-id",
		"exp":     now.Add(time.Hour).Unix(),
		"iat":     now.Unix(),
		"padding": "",
	}
	// Measure the claims JSON with empty padding to learn the fixed overhead.
	// Marshalling strings and int64s cannot fail.
	baseJSON, _ := json.Marshal(claims)

	// Characters left for the payload: total minus header, signature and the
	// two separating dots.
	payloadBudget := size - len(header) - len(sig) - 2
	// Unpadded base64 emits 4 characters per 3 input bytes.
	jsonBudget := payloadBudget * 3 / 4
	paddingSize := jsonBudget - len(baseJSON)
	if paddingSize < 0 {
		paddingSize = 0
	}

	// Plain ASCII letters need no JSON escaping, so each one adds exactly one
	// byte to the marshalled claims.
	padding := make([]byte, paddingSize)
	for i := range padding {
		padding[i] = byte('A' + (i % 26))
	}
	claims["padding"] = string(padding)

	claimsJSON, _ := json.Marshal(claims)
	payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
	return header + "." + payload + "." + sig
}
// Cleanup runs every function registered in cleanupFuncs (for example the
// mock provider's Close). Functions run in reverse registration order,
// mirroring defer and testing.T.Cleanup, so resources registered later are
// torn down before those they may depend on. The list is then cleared, so
// calling Cleanup again is a no-op.
func (f *SessionTestFramework) Cleanup() {
	for i := len(f.cleanupFuncs) - 1; i >= 0; i-- {
		f.cleanupFuncs[i]()
	}
	f.cleanupFuncs = nil
}
// ============================================================================
// SESSION CHUNK MANAGER TESTS
// ============================================================================
// createMockRequest returns a bare GET request for http://example.com with no
// body and no cookies, suitable for minting fresh gorilla sessions in tests.
func createMockRequest() *http.Request {
	return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
}
// TestNewSessionChunkManager checks that a positive limit is stored as given.
func TestNewSessionChunkManager(t *testing.T) {
	const limit = 10
	mgr := NewSessionChunkManager(limit)
	if mgr == nil {
		t.Fatal("Expected non-nil session chunk manager")
	}
	if got := mgr.maxChunks; got != limit {
		t.Errorf("Expected maxChunks 10, got %d", got)
	}
}
// TestNewSessionChunkManagerDefaultLimit checks that a zero limit is replaced
// by the default of 20.
func TestNewSessionChunkManagerDefaultLimit(t *testing.T) {
	if got := NewSessionChunkManager(0).maxChunks; got != 20 {
		t.Errorf("Expected default maxChunks 20, got %d", got)
	}
}
// TestNewSessionChunkManagerNegativeLimit checks that a negative limit is
// replaced by the default of 20.
func TestNewSessionChunkManagerNegativeLimit(t *testing.T) {
	if got := NewSessionChunkManager(-5).maxChunks; got != 20 {
		t.Errorf("Expected default maxChunks 20, got %d", got)
	}
}
// TestCleanupChunksWithoutWriter checks that CleanupChunks empties the chunk
// map when no response writer is supplied.
func TestCleanupChunksWithoutWriter(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))
	chunkMap := make(map[int]*sessions.Session)

	// Five chunk sessions, each carrying dummy token data.
	for idx := 0; idx < 5; idx++ {
		s, _ := cookieStore.New(createMockRequest(), "chunk")
		s.Values["token_chunk"] = "chunk-data"
		chunkMap[idx] = s
	}

	// A nil writer is passed; only the clearing of the map is verified.
	mgr.CleanupChunks(chunkMap, nil)
	if n := len(chunkMap); n != 0 {
		t.Errorf("Expected chunks map to be empty, got %d entries", n)
	}
}
// TestCleanupChunksWithWriter checks that CleanupChunks empties the chunk map
// when a response writer is supplied. The cookies written to the recorder are
// not inspected here.
func TestCleanupChunksWithWriter(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))
	chunkMap := make(map[int]*sessions.Session)

	for idx := 0; idx < 3; idx++ {
		s, _ := cookieStore.New(createMockRequest(), "chunk")
		s.Values["token_chunk"] = "chunk-data"
		s.Options = &sessions.Options{MaxAge: 3600}
		chunkMap[idx] = s
	}

	rec := httptest.NewRecorder()
	mgr.CleanupChunks(chunkMap, rec)
	if n := len(chunkMap); n != 0 {
		t.Errorf("Expected chunks map to be empty, got %d entries", n)
	}
}
// TestCleanupChunksNilSession checks that nil entries in the chunk map do not
// trip up CleanupChunks and that the map still ends up empty.
func TestCleanupChunksNilSession(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	chunkMap := map[int]*sessions.Session{0: nil, 1: nil}

	mgr.CleanupChunks(chunkMap, httptest.NewRecorder())
	if n := len(chunkMap); n != 0 {
		t.Errorf("Expected chunks map to be empty, got %d entries", n)
	}
}
// TestCleanupChunksEmptyMap checks that cleaning an empty map is a no-op.
func TestCleanupChunksEmptyMap(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	chunkMap := map[int]*sessions.Session{}

	mgr.CleanupChunks(chunkMap, nil)
	if len(chunkMap) > 0 {
		t.Error("Expected chunks map to remain empty")
	}
}
// TestValidateAndCleanChunksWithinLimit checks that a map below the limit
// passes validation and is left untouched.
func TestValidateAndCleanChunksWithinLimit(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	chunkMap := make(map[int]*sessions.Session)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))

	for idx := 0; idx < 5; idx++ {
		chunkMap[idx], _ = cookieStore.New(createMockRequest(), "chunk")
	}

	if ok := mgr.ValidateAndCleanChunks(chunkMap); !ok {
		t.Error("Expected validation to pass for chunks within limit")
	}
	if n := len(chunkMap); n != 5 {
		t.Errorf("Expected chunks to remain intact, got %d", n)
	}
}
// TestValidateAndCleanChunksExceedLimit checks that a map over the limit fails
// validation and is emptied.
func TestValidateAndCleanChunksExceedLimit(t *testing.T) {
	mgr := NewSessionChunkManager(5)
	chunkMap := make(map[int]*sessions.Session)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))

	// Twice as many chunks as the manager allows.
	for idx := 0; idx < 10; idx++ {
		chunkMap[idx], _ = cookieStore.New(createMockRequest(), "chunk")
	}

	if ok := mgr.ValidateAndCleanChunks(chunkMap); ok {
		t.Error("Expected validation to fail for chunks exceeding limit")
	}
	if n := len(chunkMap); n != 0 {
		t.Errorf("Expected chunks to be cleared, got %d", n)
	}
}
// TestValidateAndCleanChunksAtLimit checks that a map holding exactly the
// limit passes validation and is left untouched.
func TestValidateAndCleanChunksAtLimit(t *testing.T) {
	mgr := NewSessionChunkManager(5)
	chunkMap := make(map[int]*sessions.Session)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))

	for idx := 0; idx < 5; idx++ {
		chunkMap[idx], _ = cookieStore.New(createMockRequest(), "chunk")
	}

	if ok := mgr.ValidateAndCleanChunks(chunkMap); !ok {
		t.Error("Expected validation to pass for chunks at limit")
	}
	if n := len(chunkMap); n != 5 {
		t.Errorf("Expected chunks to remain intact, got %d", n)
	}
}
// TestSafeSetChunkValidIndex checks that an in-range index is accepted and
// stores the given session.
func TestSafeSetChunkValidIndex(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	chunkMap := make(map[int]*sessions.Session)
	s, _ := sessions.NewCookieStore([]byte("test-secret")).New(createMockRequest(), "chunk")

	if !mgr.SafeSetChunk(chunkMap, 5, s) {
		t.Error("Expected SafeSetChunk to succeed for valid index")
	}
	if chunkMap[5] != s {
		t.Error("Expected session to be set at index 5")
	}
}
// TestSafeSetChunkNegativeIndex checks that a negative index is rejected and
// nothing is stored.
func TestSafeSetChunkNegativeIndex(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	chunkMap := make(map[int]*sessions.Session)
	s, _ := sessions.NewCookieStore([]byte("test-secret")).New(createMockRequest(), "chunk")

	if mgr.SafeSetChunk(chunkMap, -1, s) {
		t.Error("Expected SafeSetChunk to fail for negative index")
	}
	if len(chunkMap) != 0 {
		t.Error("Expected chunks map to remain empty")
	}
}
// TestSafeSetChunkIndexTooHigh checks that an index equal to maxChunks is
// rejected and nothing is stored.
func TestSafeSetChunkIndexTooHigh(t *testing.T) {
	const limit = 10
	mgr := NewSessionChunkManager(limit)
	chunkMap := make(map[int]*sessions.Session)
	s, _ := sessions.NewCookieStore([]byte("test-secret")).New(createMockRequest(), "chunk")

	if mgr.SafeSetChunk(chunkMap, limit, s) {
		t.Error("Expected SafeSetChunk to fail for index >= maxChunks")
	}
	if len(chunkMap) != 0 {
		t.Error("Expected chunks map to remain empty")
	}
}
// TestSafeSetChunkExceedingLimit fills a manager to its limit, then checks
// two things. Overwriting an existing slot must still be allowed. A new index
// at the limit must be rejected without growing the map.
func TestSafeSetChunkExceedingLimit(t *testing.T) {
	manager := NewSessionChunkManager(5)
	chunks := make(map[int]*sessions.Session)
	store := sessions.NewCookieStore([]byte("test-secret"))

	// Occupy every permitted index (0..4).
	for i := 0; i < 5; i++ {
		session, _ := store.New(createMockRequest(), "chunk")
		chunks[i] = session
	}

	// Replacing an existing index must succeed even though the map is full.
	replacement, _ := store.New(createMockRequest(), "chunk")
	if !manager.SafeSetChunk(chunks, 2, replacement) {
		t.Error("Expected SafeSetChunk to succeed for existing index")
	}
	if chunks[2] != replacement {
		t.Error("Expected session to be replaced at index 2")
	}

	// A brand-new index at the limit must be refused and leave the map as is.
	extra, _ := store.New(createMockRequest(), "chunk")
	if manager.SafeSetChunk(chunks, 5, extra) {
		t.Error("Expected SafeSetChunk to fail for a new index once the limit is reached")
	}
	if len(chunks) != 5 {
		t.Errorf("Expected chunk count to stay at 5, got %d", len(chunks))
	}
}
// TestSafeSetChunkReplaceExisting checks that setting an occupied index
// replaces the stored session.
func TestSafeSetChunkReplaceExisting(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	chunkMap := make(map[int]*sessions.Session)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))
	first, _ := cookieStore.New(createMockRequest(), "chunk1")
	second, _ := cookieStore.New(createMockRequest(), "chunk2")

	// Seed slot 3; the outcome of this call is not under test.
	mgr.SafeSetChunk(chunkMap, 3, first)

	if !mgr.SafeSetChunk(chunkMap, 3, second) {
		t.Error("Expected SafeSetChunk to succeed for replacing existing chunk")
	}
	if chunkMap[3] != second {
		t.Error("Expected session to be replaced at index 3")
	}
}
// TestGetChunkCount checks that GetChunkCount reports the number of entries.
func TestGetChunkCount(t *testing.T) {
	const want = 7
	mgr := NewSessionChunkManager(10)
	chunkMap := make(map[int]*sessions.Session)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))

	for idx := 0; idx < want; idx++ {
		chunkMap[idx], _ = cookieStore.New(createMockRequest(), "chunk")
	}

	if got := mgr.GetChunkCount(chunkMap); got != want {
		t.Errorf("Expected chunk count 7, got %d", got)
	}
}
// TestGetChunkCountEmpty checks that an empty map reports zero chunks.
func TestGetChunkCountEmpty(t *testing.T) {
	got := NewSessionChunkManager(10).GetChunkCount(map[int]*sessions.Session{})
	if got != 0 {
		t.Errorf("Expected chunk count 0, got %d", got)
	}
}
// TestCompactChunksNoGaps checks that compacting an already-contiguous map
// keeps all five slots populated.
func TestCompactChunksNoGaps(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	chunkMap := make(map[int]*sessions.Session)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))

	for idx := 0; idx < 5; idx++ {
		s, _ := cookieStore.New(createMockRequest(), "chunk")
		s.Values["index"] = idx
		chunkMap[idx] = s
	}

	compacted := mgr.CompactChunks(chunkMap)
	if got := len(compacted); got != 5 {
		t.Errorf("Expected 5 compacted chunks, got %d", got)
	}
	for idx := 0; idx < 5; idx++ {
		if compacted[idx] == nil {
			t.Errorf("Expected chunk at index %d", idx)
		}
	}
}
// TestCompactChunksWithGaps checks that sparse indices are renumbered into a
// contiguous 0..n-1 range.
func TestCompactChunksWithGaps(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	chunkMap := make(map[int]*sessions.Session)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))

	for _, slot := range []int{0, 2, 5, 7} {
		s, _ := cookieStore.New(createMockRequest(), "chunk")
		s.Values["original_index"] = slot
		chunkMap[slot] = s
	}

	compacted := mgr.CompactChunks(chunkMap)
	if got := len(compacted); got != 4 {
		t.Errorf("Expected 4 compacted chunks, got %d", got)
	}
	for idx := 0; idx < 4; idx++ {
		if compacted[idx] == nil {
			t.Errorf("Expected chunk at compacted index %d", idx)
		}
	}
}
// TestCompactChunksWithNilEntries checks that nil entries are dropped and the
// remaining sessions occupy indices 0..2.
func TestCompactChunksWithNilEntries(t *testing.T) {
	mgr := NewSessionChunkManager(10)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))
	first, _ := cookieStore.New(createMockRequest(), "chunk1")
	second, _ := cookieStore.New(createMockRequest(), "chunk2")
	third, _ := cookieStore.New(createMockRequest(), "chunk3")

	chunkMap := map[int]*sessions.Session{
		0: first,
		1: nil,
		2: second,
		3: nil,
		4: third,
	}

	compacted := mgr.CompactChunks(chunkMap)
	if got := len(compacted); got != 3 {
		t.Errorf("Expected 3 compacted chunks (nil entries removed), got %d", got)
	}
	for idx := 0; idx < 3; idx++ {
		if compacted[idx] == nil {
			t.Errorf("Expected non-nil chunk at compacted index %d", idx)
		}
	}
}
// TestCompactChunksEmpty checks that compacting an empty map yields nothing.
func TestCompactChunksEmpty(t *testing.T) {
	compacted := NewSessionChunkManager(10).CompactChunks(map[int]*sessions.Session{})
	if n := len(compacted); n != 0 {
		t.Errorf("Expected empty compacted map, got %d entries", n)
	}
}
// TestSessionChunkManagerConcurrentOperations hammers one chunk map with
// concurrent SafeSetChunk, GetChunkCount and ValidateAndCleanChunks calls and
// then checks that the manager still reports a sane count.
//
// NOTE(review): chunkMap is a plain map shared across goroutines; this is
// only race-free if the manager serializes access internally — run with -race.
func TestSessionChunkManagerConcurrentOperations(t *testing.T) {
	mgr := NewSessionChunkManager(50)
	chunkMap := make(map[int]*sessions.Session)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))

	var wg sync.WaitGroup
	// launch runs fn on its own goroutine, tracked by wg.
	launch := func(fn func()) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fn()
		}()
	}

	// 20 writers, each targeting its own index.
	for i := 0; i < 20; i++ {
		slot := i
		launch(func() {
			s, _ := cookieStore.New(createMockRequest(), "chunk")
			mgr.SafeSetChunk(chunkMap, slot, s)
		})
	}
	// 10 readers.
	for i := 0; i < 10; i++ {
		launch(func() { _ = mgr.GetChunkCount(chunkMap) })
	}
	// 5 validators; at most 20 chunks stay within the limit of 50.
	for i := 0; i < 5; i++ {
		launch(func() { _ = mgr.ValidateAndCleanChunks(chunkMap) })
	}
	wg.Wait()

	if n := mgr.GetChunkCount(chunkMap); n < 0 || n > 50 {
		t.Errorf("Unexpected chunk count after concurrent operations: %d", n)
	}
}
// TestSessionChunkManagerLargeChunkCount checks validation and counting with
// 500 chunks against a limit of 1000.
func TestSessionChunkManagerLargeChunkCount(t *testing.T) {
	const total = 500
	mgr := NewSessionChunkManager(1000)
	chunkMap := make(map[int]*sessions.Session)
	cookieStore := sessions.NewCookieStore([]byte("test-secret"))

	for idx := 0; idx < total; idx++ {
		chunkMap[idx], _ = cookieStore.New(createMockRequest(), "chunk")
	}

	if ok := mgr.ValidateAndCleanChunks(chunkMap); !ok {
		t.Error("Expected validation to pass for 500 chunks with limit 1000")
	}
	if got := mgr.GetChunkCount(chunkMap); got != total {
		t.Errorf("Expected 500 chunks, got %d", got)
	}
}
// TestSessionChunkManagerBoundaryConditions tabulates validation outcomes for
// chunk counts around and far beyond the limit.
func TestSessionChunkManagerBoundaryConditions(t *testing.T) {
	cases := []struct {
		label    string
		limit    int
		count    int
		wantPass bool
	}{
		{label: "exactly at limit", limit: 10, count: 10, wantPass: true},
		{label: "one over limit", limit: 10, count: 11, wantPass: false},
		{label: "way over limit", limit: 10, count: 50, wantPass: false},
		{label: "zero chunks with limit", limit: 10, count: 0, wantPass: true},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			mgr := NewSessionChunkManager(tc.limit)
			chunkMap := make(map[int]*sessions.Session)
			cookieStore := sessions.NewCookieStore([]byte("test-secret"))
			for idx := 0; idx < tc.count; idx++ {
				chunkMap[idx], _ = cookieStore.New(createMockRequest(), "chunk")
			}

			if got := mgr.ValidateAndCleanChunks(chunkMap); got != tc.wantPass {
				t.Errorf("Expected validation result %v, got %v", tc.wantPass, got)
			}
		})
	}
}
// ============================================================================
// SESSION HELPER TESTS
// ============================================================================
// TestSetCodeVerifier_NoChange tests the branch where the code verifier value doesn't change
func TestSetCodeVerifier_NoChange(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
defer sm.Shutdown()
req := httptest.NewRequest("GET", "http://example.com/test", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Set initial code verifier
initialVerifier := "test-code-verifier-12345"
session.SetCodeVerifier(initialVerifier)
if !session.IsDirty() {
t.Error("Session should be dirty after first SetCodeVerifier")
}
// Mark clean to test the no-change branch
session.dirty = false
// Set the same code verifier again - this should hit the uncovered branch
session.SetCodeVerifier(initialVerifier)
// Verify that dirty flag remains false (no change occurred)
if session.IsDirty() {
t.Error("Session should not be dirty when setting same code verifier value")
}
// Verify the code verifier value is still correct
if got := session.GetCodeVerifier(); got != initialVerifier {
t.Errorf("Expected code verifier %q, got %q", initialVerifier, got)
}
}
// TestClearTokenChunks_EmptyChunks verifies that clearTokenChunks tolerates
// empty chunk maps: both a freshly made map and the session's own token chunk
// maps on a brand-new session.
func TestClearTokenChunks_EmptyChunks(t *testing.T) {
	logger := NewLogger("debug")
	sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
	if err != nil {
		t.Fatalf("Failed to create session manager: %v", err)
	}
	defer sm.Shutdown()

	req := httptest.NewRequest("GET", "http://example.com/test", nil)
	session, err := sm.GetSession(req)
	if err != nil {
		t.Fatalf("Failed to get session: %v", err)
	}
	if session == nil {
		t.Fatal("GetSession returned a nil session without an error")
	}
	defer session.ReturnToPool()

	// With an empty map the clearing loop body never runs; this must be a no-op.
	emptyChunks := make(map[int]*sessions.Session)
	session.clearTokenChunks(req, emptyChunks)
	if len(emptyChunks) != 0 {
		t.Errorf("Expected empty chunk map to stay empty, got %d entries", len(emptyChunks))
	}

	// The session's own chunk maps are presumably empty on a request without
	// cookies; clearing them must not panic.
	session.clearTokenChunks(req, session.accessTokenChunks)
	session.clearTokenChunks(req, session.refreshTokenChunks)
	session.clearTokenChunks(req, session.idTokenChunks)
}
// TestClearTokenChunks_WithSessions tests the branch where the chunks map contains actual sessions
func TestClearTokenChunks_WithSessions(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
defer sm.Shutdown()
req := httptest.NewRequest("GET", "http://example.com/test", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Create chunks map with actual sessions
chunksWithSessions := make(map[int]*sessions.Session)
// Create a few test sessions and add them to the chunks map
for i := 0; i < 3; i++ {
chunkSession, err := sm.store.Get(req, fmt.Sprintf("test_chunk_%d", i))
if err != nil {
t.Fatalf("Failed to create test chunk session: %v", err)
}
// Add some test data to the session
chunkSession.Values["test_data"] = fmt.Sprintf("chunk_%d_data", i)
chunkSession.Values["chunk_index"] = i
chunksWithSessions[i] = chunkSession
}
// Verify chunks have data before clearing
if len(chunksWithSessions) != 3 {
t.Errorf("Expected 3 chunks, got %d", len(chunksWithSessions))
}
for i, chunkSession := range chunksWithSessions {
if chunkSession.Values["test_data"] == nil {
t.Errorf("Chunk %d should have test data before clearing", i)
}
}
// Call clearTokenChunks - this should hit the loop body and clear all sessions
session.clearTokenChunks(req, chunksWithSessions)
// Verify that the sessions were cleared
for i, chunkSession := range chunksWithSessions {
if len(chunkSession.Values) != 0 {
t.Errorf("Chunk %d should have no values after clearing, but has %d values", i, len(chunkSession.Values))
}
// Verify MaxAge was set to -1 (expired)
if chunkSession.Options.MaxAge != -1 {
t.Errorf("Chunk %d should have MaxAge=-1 (expired), but has MaxAge=%d", i, chunkSession.Options.MaxAge)
}
}
}
// ============================================================================
// SESSION POOL AND MEMORY TESTS
// ============================================================================
// TestSessionPoolMemoryLeak checks that sessions handed out by GetSession can
// be released through each supported path and logs the pooled-object count
// afterwards. The paths are Clear, explicit ReturnToPool, and GetSession on a
// second manager. Every subtest forces a GC in teardown so released objects
// can be reclaimed before the next one runs. Each SessionManager the test
// creates is shut down when its subtest ends.
func TestSessionPoolMemoryLeak(t *testing.T) {
	config := GetTestConfig()
	if config.ShouldSkipTest(t, TestTypeLeakDetection) {
		return
	}

	noopSetup := func(t *testing.T) error { return nil }
	// gcTeardown gives the runtime a chance to collect whatever the subtest released.
	gcTeardown := func(t *testing.T) error {
		runtime.GC()
		time.Sleep(100 * time.Millisecond)
		return nil
	}

	tests := []TableTestCase{
		{
			Name:        "Successful session creation and return",
			Description: "Test that sessions are properly created and returned to pool",
			Setup:       noopSetup,
			Teardown:    gcTeardown,
		},
		{
			Name:        "Explicit ReturnToPool method",
			Description: "Test that explicit pool return works correctly",
			Setup:       noopSetup,
			Teardown:    gcTeardown,
		},
		{
			Name:        "Error path in GetSession",
			Description: "Test pool behavior when GetSession fails",
			Setup:       noopSetup,
			Teardown:    gcTeardown,
		},
	}

	// Custom test execution since we need to test memory behavior
	for _, test := range tests {
		t.Run(test.Name, func(t *testing.T) {
			if test.Setup != nil {
				if err := test.Setup(t); err != nil {
					t.Fatalf("Setup failed: %v", err)
				}
			}
			if test.Teardown != nil {
				defer func() {
					if err := test.Teardown(t); err != nil {
						t.Errorf("Teardown failed: %v", err)
					}
				}()
			}

			logger := NewLogger("debug")
			sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
			if err != nil {
				t.Fatalf("Failed to create session manager: %v", err)
			}
			// Shut the manager down when the subtest ends. Deferred after the
			// teardown, so it runs before the GC teardown.
			defer func() {
				if shutdownErr := sm.Shutdown(); shutdownErr != nil {
					t.Errorf("Failed to shutdown SessionManager: %v", shutdownErr)
				}
			}()

			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
			switch test.Name {
			case "Successful session creation and return":
				session, err := sm.GetSession(req)
				if err != nil {
					t.Fatalf("GetSession failed: %v", err)
				}
				if clearErr := session.Clear(req, nil); clearErr != nil {
					t.Logf("Clear returned: %v", clearErr)
				}
			case "Explicit ReturnToPool method":
				session, err := sm.GetSession(req)
				if err != nil {
					t.Fatalf("GetSession failed: %v", err)
				}
				session.ReturnToPool()
			case "Error path in GetSession":
				badSM, err := NewSessionManager("different0123456789abcdef0123456789abcdef0123456789", false, "", "", 0, logger)
				if err != nil {
					// Previously ignored, which led to a nil dereference below.
					t.Fatalf("Failed to create second session manager: %v", err)
				}
				defer func() {
					if shutdownErr := badSM.Shutdown(); shutdownErr != nil {
						t.Errorf("Failed to shutdown second SessionManager: %v", shutdownErr)
					}
				}()
				session, err := badSM.GetSession(req)
				if err == nil {
					t.Log("Note: Expected error when using mismatched encryption keys")
					if session != nil {
						// A successful session must still go back to its pool.
						session.ReturnToPool()
					}
				}
			}

			pooledCount := getPooledObjects(sm)
			t.Logf("Pooled objects count: %d", pooledCount)
		})
	}
}
// TestSessionErrorHandling puts malformed and edge-case values into the main
// session cookie. GetSession must either fail with the expected error prefix
// or handle the value gracefully. Each subtest shuts down its own
// SessionManager and returns any session it obtained to the pool, so
// resources do not pile up across the many generated edge cases.
func TestSessionErrorHandling(t *testing.T) {
	config := GetTestConfig()
	if config.ShouldSkipTest(t, TestTypeQuick) {
		return
	}
	edgeGen := NewEdgeCaseGenerator()
	// Generate edge case strings for cookie values
	edgeCases := edgeGen.GenerateStringEdgeCases()

	tests := []TableTestCase{
		{
			Name:        "Corrupt cookie value",
			Description: "Test handling of corrupted cookie values",
			Input:       "corrupt-value",
			Expected:    "failed to get main session:",
		},
		{
			Name:        "Invalid base64 cookie",
			Description: "Test handling of invalid base64 in cookies",
			Input:       "!@#$%^&*()",
			Expected:    "failed to get main session:",
		},
		{
			Name:        "Empty cookie value",
			Description: "Test handling of empty cookie values",
			Input:       "",
			Expected:    "", // Empty should work without error
		},
	}

	// Add edge cases dynamically, skipping ones with low control bytes.
	for i, edgeCase := range edgeCases {
		if len(edgeCase) > 0 && !strings.ContainsAny(edgeCase, "\x00\x01\x02") {
			tests = append(tests, TableTestCase{
				Name:        fmt.Sprintf("Edge case %d", i),
				Description: fmt.Sprintf("Test edge case string: %q", edgeCase[:minInt(20, len(edgeCase))]),
				Input:       edgeCase,
				Expected:    "", // Most edge cases should be handled gracefully
			})
		}
	}

	for _, test := range tests {
		t.Run(test.Name, func(t *testing.T) {
			logger := NewLogger("debug")
			sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
			if err != nil {
				t.Fatalf("Failed to create session manager: %v", err)
			}
			defer func() {
				if shutdownErr := sm.Shutdown(); shutdownErr != nil {
					t.Errorf("Failed to shutdown SessionManager: %v", shutdownErr)
				}
			}()

			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
			if input, ok := test.Input.(string); ok && input != "" {
				req.AddCookie(&http.Cookie{
					Name:  defaultCookiePrefix + mainCookieSuffix,
					Value: input,
				})
			}

			session, err := sm.GetSession(req)
			if err == nil && session != nil {
				// Hand the session back; deferred later, so it runs before Shutdown.
				defer session.ReturnToPool()
			}

			if expected, ok := test.Expected.(string); ok && expected != "" {
				if err == nil {
					t.Error("Expected error, got nil")
				} else if !strings.Contains(err.Error(), expected) {
					t.Errorf("Unexpected error message: %v", err)
				}
			} else if err != nil {
				// For empty expected, either success or a handled failure is acceptable.
				t.Logf("Got expected error for edge case: %v", err)
			}
		})
	}
}
// TestSessionClearAlwaysReturnsToPool runs repeated GetSession/Clear cycles
// under the memory-leak harness. In each cycle Clear is made to fail via the
// X-Test-Error request header. The test then confirms that a fresh manager
// still hands out and clears sessions normally.
func TestSessionClearAlwaysReturnsToPool(t *testing.T) {
	if GetTestConfig().ShouldSkipTest(t, TestTypeQuick) {
		return
	}
	const encryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
	runner := NewTestSuiteRunner()

	// failingClearCycle performs one cycle. It reports an error if setup
	// fails or if Clear unexpectedly succeeds.
	failingClearCycle := func() error {
		lg := NewLogger("debug")
		mgr, err := NewSessionManager(encryptionKey, false, "", "", 0, lg)
		if err != nil {
			return fmt.Errorf("failed to create session manager: %w", err)
		}
		// Shut the manager down on every exit path.
		defer func() {
			if shutdownErr := mgr.Shutdown(); shutdownErr != nil {
				lg.Errorf("Failed to shutdown SessionManager: %v", shutdownErr)
			}
		}()

		req := httptest.NewRequest("GET", "http://example.com/foo", nil)
		req.Header.Set("X-Test-Error", "true")
		sess, err := mgr.GetSession(req)
		if err != nil {
			return fmt.Errorf("GetSession failed: %w", err)
		}
		if clearErr := sess.Clear(req, httptest.NewRecorder()); clearErr == nil {
			return fmt.Errorf("expected error from Clear with X-Test-Error header")
		}
		return nil
	}

	runner.RunMemoryLeakTests(t, []MemoryLeakTestCase{{
		Name:               "Session clear with error returns to pool",
		Description:        "Verify sessions return to pool even when Clear() errors",
		Iterations:         10,
		MaxGoroutineGrowth: 2,
		MaxMemoryGrowthMB:  5.0,
		GCBetweenRuns:      true,
		Timeout:            30 * time.Second,
		Operation:          failingClearCycle,
	}})

	// After the failing cycles, a normal GetSession/Clear must still work.
	t.Run("Verify pool still works after errors", func(t *testing.T) {
		mgr, err := NewSessionManager(encryptionKey, false, "", "", 0, NewLogger("debug"))
		if err != nil {
			t.Fatalf("Failed to create session manager: %v", err)
		}
		defer func() {
			if shutdownErr := mgr.Shutdown(); shutdownErr != nil {
				t.Errorf("Failed to shutdown SessionManager: %v", shutdownErr)
			}
		}()

		req := httptest.NewRequest("GET", "http://example.com/foo", nil)
		sess, err := mgr.GetSession(req)
		if err != nil {
			t.Fatalf("Second GetSession failed: %v", err)
		}
		sess.Clear(req, nil) // result intentionally ignored; only reuse matters here
		t.Log("Session returned to pool despite errors")
	})
}
// TestSessionObjectTracking exercises three pool scenarios:
//   - the session pool's allocator is configured;
//   - repeated get/return cycles work;
//   - returning a session whose mainSession was nil-ed out does not blow up.
//
// Each subtest shuts down its own SessionManager so managers do not leak
// across subtests.
func TestSessionObjectTracking(t *testing.T) {
	config := GetTestConfig()
	if config.ShouldSkipTest(t, TestTypeQuick) {
		return
	}

	tests := []TableTestCase{
		{
			Name:        "Session pool has New function",
			Description: "Verify that session pool is properly configured",
			Setup: func(t *testing.T) error {
				return nil
			},
		},
		{
			Name:        "Multiple session creation and disposal",
			Description: "Test creating and disposing multiple sessions",
			Input:       5,
		},
		{
			Name:        "Session with nil mainSession",
			Description: "Test error handling with corrupted session state",
		},
	}

	for _, test := range tests {
		t.Run(test.Name, func(t *testing.T) {
			if test.Setup != nil {
				if err := test.Setup(t); err != nil {
					t.Fatalf("Setup failed: %v", err)
				}
			}

			logger := NewLogger("debug")
			sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
			if err != nil {
				t.Fatalf("Failed to create session manager: %v", err)
			}
			defer func() {
				if shutdownErr := sm.Shutdown(); shutdownErr != nil {
					t.Errorf("Failed to shutdown SessionManager: %v", shutdownErr)
				}
			}()

			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
			switch test.Name {
			case "Session pool has New function":
				if sm.sessionPool.New == nil {
					t.Error("Expected sessionPool.New function to be set")
				}
			case "Multiple session creation and disposal":
				count, ok := test.Input.(int)
				if !ok {
					t.Fatalf("Expected int Input, got %T", test.Input)
				}
				for i := 0; i < count; i++ {
					session, err := sm.GetSession(req)
					if err != nil {
						t.Fatalf("GetSession failed: %v", err)
					}
					session.ReturnToPool()
				}
			case "Session with nil mainSession":
				session, err := sm.GetSession(req)
				if err != nil {
					t.Fatalf("GetSession failed: %v", err)
				}
				session.mainSession = nil // Deliberately cause bad state
				session.ReturnToPool()
			}

			runtime.GC()
			time.Sleep(100 * time.Millisecond)
			t.Log("Session pool handling verified")
		})
	}
}
// ============================================================================
// TOKEN COMPRESSION AND CHUNKING TESTS
// ============================================================================
// TestTokenCompressionIntegrity drives compressToken/decompressToken over a
// table of inputs. The table holds well-formed JWTs of several sizes,
// malformed tokens, the empty string, an oversized incompressible token, and
// generated string edge cases between 1 and 999 bytes long.
//
// Entries with Expected=false must come back from compressToken unchanged.
// Every other entry must survive a compress/decompress round trip. Feeding the
// decompressed result through decompressToken a second time must also still
// yield the original input.
func TestTokenCompressionIntegrity(t *testing.T) {
	if GetTestConfig().ShouldSkipTest(t, TestTypeExtended) {
		return
	}
	tokens := NewTestTokens()
	generator := NewEdgeCaseGenerator()
	_ = NewTestSuiteRunner() // result intentionally unused

	cases := []TableTestCase{
		{Name: "Valid JWT Small", Input: tokens.GetValidTokenSet().AccessToken, Expected: true},
		{Name: "Valid JWT Large", Input: tokens.CreateLargeValidJWT(5000), Expected: true},
		{Name: "Minimal Valid JWT", Input: MinimalValidJWT, Expected: true},
		// Malformed tokens: compressToken is expected to hand them back untouched.
		{Name: "Invalid JWT Wrong dot count", Input: InvalidTokenOneDot, Expected: false},
		{Name: "Invalid JWT No dots", Input: InvalidTokenNoDots, Expected: false},
		{Name: "Invalid JWT Too many dots", Input: InvalidTokenThreeDots, Expected: false},
		{Name: "Empty token", Input: "", Expected: true},
		// Anything above the ~50KB limit must be rejected, i.e. returned as-is.
		{Name: "Oversized token", Input: tokens.CreateIncompressibleToken(55000), Expected: false},
	}

	for idx, s := range generator.GenerateStringEdgeCases() {
		// Keep only non-empty edge cases that are a reasonable fixture size.
		if len(s) == 0 || len(s) >= 1000 {
			continue
		}
		cases = append(cases, TableTestCase{
			Name:     fmt.Sprintf("Edge case string %d", idx),
			Input:    s,
			Expected: true,
		})
	}

	for _, tc := range cases {
		t.Run(tc.Name, func(t *testing.T) {
			input := tc.Input.(string)
			roundTrips := tc.Expected.(bool)
			compressed := compressToken(input)

			if !roundTrips {
				if compressed != input {
					t.Errorf("Expected compression to return original for invalid token, got different result")
				}
				return
			}

			restored := decompressToken(compressed)
			if restored != input {
				t.Errorf("Token integrity lost: original=%q, compressed=%q, decompressed=%q",
					input, compressed, restored)
			}
			if again := decompressToken(restored); again != input {
				t.Errorf("Decompression not idempotent: %q != %q", again, input)
			}
		})
	}
}
// TestTokenCompressionCorruptionDetection passes decompressToken inputs that
// are not valid compressed tokens and requires each to be returned unchanged.
// The inputs are: invalid base64, base64 of non-gzip bytes, a truncated gzip
// header, and the empty string. A final subtest confirms that a genuine JWT
// still survives compressToken followed by decompressToken.
func TestTokenCompressionCorruptionDetection(t *testing.T) {
	cfg := GetTestConfig()
	if cfg.ShouldSkipTest(t, TestTypeExtended) {
		return
	}
	fixtures := NewTestTokens()
	_ = NewTestSuiteRunner() // result intentionally unused

	corrupted := []TableTestCase{
		{Name: "Invalid base64", Input: "!@#$%^&*()", Expected: true},
		{Name: "Valid base64 but invalid gzip", Input: base64.StdEncoding.EncodeToString([]byte("not gzip data")), Expected: true},
		{Name: "Truncated gzip data", Input: "H4sI", Expected: true}, // incomplete gzip header
		{Name: "Empty string", Input: "", Expected: true},
	}

	for _, tc := range corrupted {
		t.Run(tc.Name, func(t *testing.T) {
			in := tc.Input.(string)
			wantUnchanged := tc.Expected.(bool)
			got := decompressToken(in)
			if !wantUnchanged {
				return
			}
			if got != in {
				t.Errorf("Expected decompression to return original corrupted input, got: %q", got)
			}
		})
	}

	// Sanity check: the corruption handling must not break the happy path.
	t.Run("Valid compression verification", func(t *testing.T) {
		jwt := fixtures.GetValidTokenSet().AccessToken
		if roundTripped := decompressToken(compressToken(jwt)); roundTripped != jwt {
			t.Errorf("Valid compression/decompression failed: %q != %q", roundTripped, jwt)
		}
	})
}
// TestTokenChunkingIntegrity stores access tokens of several sizes in a fresh
// session and checks two things: the token reads back unchanged, and the
// session's accessTokenChunks is non-empty exactly when chunking is expected.
//
// When chunking is expected, an incompressible token is used; presumably this
// keeps compression from shrinking it under the cookie limit. Otherwise a
// well-formed JWT of the requested size is used.
func TestTokenChunkingIntegrity(t *testing.T) {
	cfg := GetTestConfig()
	if cfg.ShouldSkipTest(t, TestTypeExtended) {
		return
	}
	fixtures := NewTestTokens()
	_ = NewEdgeCaseGenerator() // result intentionally unused
	_ = NewTestSuiteRunner()   // result intentionally unused

	// chunkingCase is the Input payload of each table entry.
	type chunkingCase struct {
		size          int
		expectChunked bool
	}

	cases := []TableTestCase{
		{
			Name:        "Small token no chunking",
			Description: "Small tokens should not be chunked",
			Input:       chunkingCase{size: 100, expectChunked: false},
		},
		{
			Name:        "Medium token no chunking",
			Description: "Medium tokens should not be chunked",
			Input:       chunkingCase{size: 800, expectChunked: false},
		},
		{
			Name:        "Large token chunking required",
			Description: "Large tokens should be chunked",
			Input:       chunkingCase{size: 5000, expectChunked: true},
		},
		{
			Name:        "Very large token multiple chunks",
			Description: "Very large tokens should create multiple chunks",
			Input:       chunkingCase{size: 10000, expectChunked: true},
		},
	}

	for _, tc := range cases {
		t.Run(tc.Name, func(t *testing.T) {
			sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, NewLogger("debug"))
			if err != nil {
				t.Fatalf("Failed to create session manager: %v", err)
			}
			spec := tc.Input.(chunkingCase)

			var token string
			switch {
			case spec.expectChunked:
				token = fixtures.CreateIncompressibleToken(spec.size)
			default:
				token = fixtures.CreateLargeValidJWT(spec.size)
			}

			session, err := sm.GetSession(httptest.NewRequest("GET", "http://example.com/foo", nil))
			if err != nil {
				t.Fatalf("Failed to get session: %v", err)
			}

			session.SetAccessToken(token)
			if got := session.GetAccessToken(); got != token {
				t.Errorf("Token integrity lost:\nOriginal: %q\nRetrieved: %q", token, got)
			}

			chunked := len(session.accessTokenChunks) > 0
			if chunked != spec.expectChunked {
				t.Errorf("Chunking expectation mismatch: expected chunked=%v, has chunks=%v",
					spec.expectChunked, chunked)
			}
			session.ReturnToPool()
		})
	}
}
// TestTokenChunkingCorruptionResistance stores an 8000-byte incompressible
// access token, which is expected to be split into chunks. It then tampers
// with the session's accessTokenChunks map in several ways. In every scenario
// GetAccessToken must then return "" rather than a partially reconstructed
// token. If the token was not chunked at all, the subtest is skipped.
func TestTokenChunkingCorruptionResistance(t *testing.T) {
	cfg := GetTestConfig()
	if cfg.ShouldSkipTest(t, TestTypeExtended) {
		return
	}
	fixtures := NewTestTokens()
	_ = NewTestSuiteRunner() // result intentionally unused

	// corruptor mutates the chunk sessions in place; it is the Input payload.
	type corruptor = func(map[int]*sessions.Session)

	// setFirstChunk overwrites the "token_chunk" value of chunk 0, if present.
	setFirstChunk := func(v interface{}) corruptor {
		return func(chunks map[int]*sessions.Session) {
			if first, ok := chunks[0]; ok {
				first.Values["token_chunk"] = v
			}
		}
	}

	scenarios := []TableTestCase{
		{
			Name:        "Missing chunk in sequence",
			Description: "Test handling when a chunk is missing from sequence",
			Input: corruptor(func(chunks map[int]*sessions.Session) {
				// Only drop chunk 1 when there is more than one chunk.
				if len(chunks) > 1 {
					delete(chunks, 1)
				}
			}),
			Expected: true, // Expect empty result
		},
		{
			Name:        "Empty chunk data",
			Description: "Test handling when chunk contains empty data",
			Input:       setFirstChunk(""),
			Expected:    true,
		},
		{
			Name:        "Wrong data type in chunk",
			Description: "Test handling when chunk contains wrong data type",
			Input:       setFirstChunk(123), // Should be string
			Expected:    true,
		},
		{
			Name:        "Oversized chunk",
			Description: "Test handling when chunk exceeds size limits",
			Input:       setFirstChunk(strings.Repeat("A", maxCookieSize+200)),
			Expected:    true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.Name, func(t *testing.T) {
			sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, NewLogger("debug"))
			if err != nil {
				t.Fatalf("Failed to create session manager: %v", err)
			}

			original := fixtures.CreateIncompressibleToken(8000)
			session, err := sm.GetSession(httptest.NewRequest("GET", "http://example.com/foo", nil))
			if err != nil {
				t.Fatalf("Failed to get session: %v", err)
			}

			session.SetAccessToken(original)
			if len(session.accessTokenChunks) == 0 {
				t.Skip("Token was not chunked, skipping corruption test")
			}

			sc.Input.(corruptor)(session.accessTokenChunks)

			got := session.GetAccessToken()
			switch wantEmpty := sc.Expected.(bool); {
			case wantEmpty && got != "":
				t.Errorf("Expected empty token due to corruption, got: %q", got)
			case !wantEmpty && got != original:
				t.Errorf("Expected original token despite corruption, got: %q", got)
			}
			session.ReturnToPool()
		})
	}
}
// TestTokenSizeLimits stores access tokens of assorted sizes in a fresh
// session and checks whether each one can be read back.
//
// Sizes expected to be accepted use a well-formed JWT of that size and must
// read back unchanged. The 120KB case uses an incompressible token and must
// NOT be returned by GetAccessToken. Generated integer edge cases strictly
// between 0 and 100000 are added as accepted sizes.
func TestTokenSizeLimits(t *testing.T) {
	if GetTestConfig().ShouldSkipTest(t, TestTypeExtended) {
		return
	}
	fixtures := NewTestTokens()
	gen := NewEdgeCaseGenerator()
	_ = NewTestSuiteRunner() // result intentionally unused

	const edgeCaseCeiling = 100000

	cases := []TableTestCase{
		{Name: "Normal size token", Input: 1000, Expected: true},
		{Name: "Large but acceptable token", Input: 20000, Expected: true},   // 20KB
		{Name: "Oversized token rejection", Input: 120000, Expected: false}, // 120KB, must be rejected
	}
	for _, n := range gen.GenerateIntegerEdgeCases() {
		inRange := n > 0 && n < edgeCaseCeiling
		if !inRange {
			continue
		}
		cases = append(cases, TableTestCase{
			Name:     fmt.Sprintf("Edge case size %d", n),
			Input:    n,
			Expected: n < edgeCaseCeiling, // always true after the range filter above
		})
	}

	for _, tc := range cases {
		t.Run(tc.Name, func(t *testing.T) {
			sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, NewLogger("debug"))
			if err != nil {
				t.Fatalf("Failed to create session manager: %v", err)
			}
			session, err := sm.GetSession(httptest.NewRequest("GET", "http://example.com/foo", nil))
			if err != nil {
				t.Fatalf("Failed to get session: %v", err)
			}
			defer session.ReturnToPool()

			size := tc.Input.(int)
			accept := tc.Expected.(bool)

			var token string
			switch {
			case accept:
				token = fixtures.CreateLargeValidJWT(size)
			default:
				token = fixtures.CreateIncompressibleToken(size)
			}

			session.SetAccessToken(token)
			stored := session.GetAccessToken() == token

			switch {
			case accept && !stored:
				t.Errorf("Expected token to be stored and retrieved, but got different token")
			case !accept && stored:
				t.Errorf("Expected oversized token to be rejected, but it was stored")
			}
		})
	}
}
// TestConcurrentTokenOperations hammers a single session from several
// goroutines. Each goroutine repeatedly sets an access and a refresh token and
// reads both back.
//
// Writers race each other, so a read may legitimately return another
// goroutine's token, and exact equality is not asserted. Instead, every
// non-empty value read back must still look well-formed:
//   - an access token must contain exactly two dots;
//   - a refresh token must be at least 10 characters long.
//
// The first malformed value observed is returned as the operation's error.
// The operation runs under RunMemoryLeakTests with the limits configured below.
func TestConcurrentTokenOperations(t *testing.T) {
	config := GetTestConfig()
	if config.ShouldSkipTest(t, TestTypeConcurrencyStress) {
		return
	}
	testTokens := NewTestTokens()
	runner := NewTestSuiteRunner()
	memoryTests := []MemoryLeakTestCase{
		{
			Name:               "Concurrent token operations",
			Description:        "Test thread safety of concurrent token operations",
			Iterations:         50,
			MaxGoroutineGrowth: 5, // Allow some growth for goroutines
			MaxMemoryGrowthMB:  10.0,
			GCBetweenRuns:      true,
			Timeout:            60 * time.Second,
			Operation: func() error {
				logger := NewLogger("debug")
				sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
				if err != nil {
					return fmt.Errorf("failed to create session manager: %w", err)
				}
				req := httptest.NewRequest("GET", "http://example.com/foo", nil)
				session, err := sm.GetSession(req)
				if err != nil {
					return fmt.Errorf("failed to get session: %w", err)
				}
				defer session.ReturnToPool()

				const numGoroutines = 10
				const numOperations = 100

				var wg sync.WaitGroup
				// Buffered so every worker can report one failure without blocking.
				errCh := make(chan error, numGoroutines)

				for i := 0; i < numGoroutines; i++ {
					wg.Add(1)
					go func(id int) {
						defer wg.Done()
						for j := 0; j < numOperations; j++ {
							// Create unique tokens for each goroutine/operation.
							accessToken := testTokens.CreateUniqueValidJWT(fmt.Sprintf("%d_%d", id, j))
							refreshToken := fmt.Sprintf("refresh_token_%d_%d", id, j)

							session.SetAccessToken(accessToken)
							session.SetRefreshToken(refreshToken)
							retrievedAccess := session.GetAccessToken()
							retrievedRefresh := session.GetRefreshToken()

							// Exact matches cannot be guaranteed under concurrent writes,
							// but whatever is read back must not be a torn/garbled value.
							if retrievedAccess != "" {
								if dots := strings.Count(retrievedAccess, "."); dots != 2 {
									errCh <- fmt.Errorf("goroutine %d op %d: malformed access token read back (len %d, %d dots)",
										id, j, len(retrievedAccess), dots)
									return
								}
							}
							if retrievedRefresh != "" && len(retrievedRefresh) < 10 {
								errCh <- fmt.Errorf("goroutine %d op %d: refresh token read back is too short: %q",
									id, j, retrievedRefresh)
								return
							}
						}
					}(i)
				}

				wg.Wait()
				close(errCh)
				// A closed, empty channel yields ok == false: no worker failed.
				if err, ok := <-errCh; ok {
					return err
				}
				return nil
			},
		},
	}
	runner.RunMemoryLeakTests(t, memoryTests)
}
// TestSessionValidationAndCleanup runs three scenarios, each against a fresh
// session:
//   - "Session creation and token storage": a normal token set must read back
//     unchanged after being stored.
//   - "Large token chunking validation": the same check for the large token
//     set, whose tokens are described as requiring chunking.
//   - "Session cleanup verification": after Save and Clear, both tokens must
//     be empty.
//
// Every scenario saves the session to a recorder, and Save must succeed. A
// table entry whose Name has no matching scenario fails instead of passing
// without doing anything.
func TestSessionValidationAndCleanup(t *testing.T) {
	config := GetTestConfig()
	if config.ShouldSkipTest(t, TestTypeExtended) {
		return
	}
	testTokens := NewTestTokens()
	tests := []TableTestCase{
		{
			Name:        "Session creation and token storage",
			Description: "Test basic session validation and cleanup",
		},
		{
			Name:        "Large token chunking validation",
			Description: "Test validation with tokens that require chunking",
		},
		{
			Name:        "Session cleanup verification",
			Description: "Test that sessions are properly cleaned up",
		},
	}
	for _, test := range tests {
		t.Run(test.Name, func(t *testing.T) {
			logger := NewLogger("debug")
			sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
			if err != nil {
				t.Fatalf("Failed to create session manager: %v", err)
			}
			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
			rw := httptest.NewRecorder()
			session, err := sm.GetSession(req)
			if err != nil {
				t.Fatalf("Failed to get session: %v", err)
			}

			// verifyStored fails the subtest unless the session returns exactly
			// the tokens that were just stored. Only lengths are printed because
			// large tokens would flood the log.
			verifyStored := func(wantAccess, wantRefresh string) {
				t.Helper()
				if got := session.GetAccessToken(); got != wantAccess {
					t.Errorf("Access token not preserved: got length %d, want length %d", len(got), len(wantAccess))
				}
				if got := session.GetRefreshToken(); got != wantRefresh {
					t.Errorf("Refresh token not preserved: got length %d, want length %d", len(got), len(wantRefresh))
				}
			}

			switch test.Name {
			case "Session creation and token storage":
				// Test with normal tokens.
				tokenSet := testTokens.GetValidTokenSet()
				session.SetAccessToken(tokenSet.AccessToken)
				session.SetRefreshToken(tokenSet.RefreshToken)
				verifyStored(tokenSet.AccessToken, tokenSet.RefreshToken)
			case "Large token chunking validation":
				// Set tokens that will create chunks.
				largeTokenSet := testTokens.GetLargeTokenSet()
				session.SetAccessToken(largeTokenSet.AccessToken)
				session.SetRefreshToken(largeTokenSet.RefreshToken)
				verifyStored(largeTokenSet.AccessToken, largeTokenSet.RefreshToken)
			case "Session cleanup verification":
				// Set tokens here; they are cleared after the save below.
				session.SetAccessToken(testTokens.GetValidTokenSet().AccessToken)
				session.SetRefreshToken("refresh_token_test")
			default:
				t.Fatalf("no scenario implemented for test case %q", test.Name)
			}

			// Save session to create cookies.
			if err := session.Save(req, rw); err != nil {
				t.Fatalf("Failed to save session: %v", err)
			}

			// For the cleanup scenario, verify clearing works.
			if test.Name == "Session cleanup verification" {
				if err := session.Clear(req, rw); err != nil {
					t.Logf("Clear returned error (may be expected): %v", err)
				}
				if token := session.GetAccessToken(); token != "" {
					t.Errorf("Access token should be empty after clear, got: %q", token)
				}
				if token := session.GetRefreshToken(); token != "" {
					t.Errorf("Refresh token should be empty after clear, got: %q", token)
				}
			}
		})
	}
}
// TestLargeIDTokenChunking checks the full life cycle of an oversized ID token.
//
//  1. Storing the token and saving the session emits at least the expected
//     number of "<defaultCookiePrefix><idTokenSuffix>_*" chunk cookies.
//  2. A new request carrying every emitted cookie reconstructs the identical
//     token.
//  3. After SetIDToken("") and another Save, any chunk cookies in that
//     response must carry MaxAge == -1 (expired).
func TestLargeIDTokenChunking(t *testing.T) {
	cfg := GetTestConfig()
	if cfg.ShouldSkipTest(t, TestTypeExtended) {
		return
	}
	_ = NewTestSuiteRunner() // result intentionally unused

	chunkPrefix := defaultCookiePrefix + idTokenSuffix + "_"
	// onlyChunks keeps the cookies whose names mark them as ID-token chunks.
	onlyChunks := func(all []*http.Cookie) []*http.Cookie {
		var out []*http.Cookie
		for _, c := range all {
			if strings.HasPrefix(c.Name, chunkPrefix) {
				out = append(out, c)
			}
		}
		return out
	}

	cases := []TableTestCase{
		{
			Name:        "Large ID token chunking 20KB",
			Description: "Test that large ID tokens are properly chunked",
			Input:       20000,
			Expected:    2, // minimum number of chunk cookies
		},
		{
			Name:        "Very large ID token chunking 50KB",
			Description: "Test very large ID token chunking",
			Input:       50000,
			Expected:    5, // minimum number of chunk cookies
		},
	}

	for _, tc := range cases {
		t.Run(tc.Name, func(t *testing.T) {
			sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, NewLogger("debug"))
			if err != nil {
				t.Fatalf("Failed to create session manager: %v", err)
			}
			size, minChunks := tc.Input.(int), tc.Expected.(int)

			// Phase 1: store the token and save, which should emit chunk cookies.
			original := createLargeIDToken(size)
			t.Logf("Created large ID token with length: %d", len(original))
			firstReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
			firstResp := httptest.NewRecorder()
			sess, err := sm.GetSession(firstReq)
			if err != nil {
				t.Fatalf("Failed to get session: %v", err)
			}
			sess.SetIDToken(original)
			t.Logf("Set large ID token in session")
			if err := sess.Save(firstReq, firstResp); err != nil {
				t.Fatalf("Failed to save session: %v", err)
			}

			inMemory := sess.GetIDToken()
			t.Logf("Retrieved ID token length: %d", len(inMemory))
			if want, got := len(original), len(inMemory); want != got {
				t.Errorf("Token length mismatch: expected %d, got %d", want, got)
			}

			saved := firstResp.Result().Cookies()
			t.Logf("Total cookies in response: %d", len(saved))
			if n := len(onlyChunks(saved)); n < minChunks {
				t.Fatalf("Expected at least %d chunk cookies, got %d", minChunks, n)
			}

			// Phase 2: replay every cookie on a new request and reload the session.
			replayReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
			for _, c := range saved {
				replayReq.AddCookie(c)
			}
			reloaded, err := sm.GetSession(replayReq)
			if err != nil {
				t.Fatalf("Failed to get session from chunked cookies: %v", err)
			}
			if restored := reloaded.GetIDToken(); restored != original {
				t.Errorf("Retrieved ID token doesn't match original. Expected length: %d, got: %d",
					len(original), len(restored))
			}

			// Phase 3: clear the ID token; any chunk cookie sent back must be expired.
			reloaded.SetIDToken("")
			clearResp := httptest.NewRecorder()
			if err := reloaded.Save(replayReq, clearResp); err != nil {
				t.Fatalf("Failed to save session after clearing ID token: %v", err)
			}
			for _, c := range onlyChunks(clearResp.Result().Cookies()) {
				if c.MaxAge != -1 {
					t.Errorf("Expected chunk cookie %s to be expired (MaxAge=-1), got MaxAge=%d",
						c.Name, c.MaxAge)
				}
			}
		})
	}
}
// ============================================================================
// CONSOLIDATED SESSION TESTS
// ============================================================================
// TestSessionConsolidated runs all consolidated session tests
func TestSessionConsolidated(t *testing.T) {
testCases := []SessionTestCase{
// Session Creation Tests
{
name: "session_basic_creation",
scenario: "creation",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Simulate session creation
req := httptest.NewRequest("GET", "http://example.com/", nil)
f.requests = append(f.requests, req)
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Session creation should succeed")
assert.Greater(t, f.metrics.SessionsCreated, int64(0), "Session should be created")
},
},
{
name: "session_pool_reuse",
scenario: "creation",
sessionType: "user",
iterations: 100,
execute: func(f *SessionTestFramework) error {
for i := 0; i < 100; i++ {
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed, "Sessions should be properly pooled")
},
},
{
name: "session_concurrent_creation",
scenario: "creation",
sessionType: "user",
concurrent: true,
iterations: 50,
execute: func(f *SessionTestFramework) error {
var wg sync.WaitGroup
errs := make(chan error, 50)
for i := 0; i < 50; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Simulate concurrent session creation
req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%d", id), nil)
f.mu.Lock()
f.requests = append(f.requests, req)
f.mu.Unlock()
}(i)
}
wg.Wait()
close(errs)
for err := range errs {
if err != nil {
return err
}
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.Equal(t, int64(50), f.metrics.SessionsCreated, "All concurrent sessions should be created")
},
},
// Session Validation Tests
{
name: "session_token_validation",
scenario: "validation",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
token := f.generateTestToken("access", 3600)
atomic.AddInt64(&f.metrics.TokensValidated, 1)
// Validate token format
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Token validation should succeed")
assert.Greater(t, f.metrics.TokensValidated, int64(0))
},
},
{
name: "session_corrupted_token_detection",
scenario: "validation",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
token := f.generateTestToken("access", 3600)
// Corrupt the token by modifying the signature
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
// Corrupt the signature part
corrupted := parts[0] + "." + parts[1] + ".corrupted!"
atomic.AddInt64(&f.metrics.TokensValidated, 1)
// Validate should detect corruption - corrupted tokens should fail validation
corruptedParts := strings.Split(corrupted, ".")
if len(corruptedParts) == 3 {
// Try to decode the corrupted signature
_, err := base64.RawURLEncoding.DecodeString(corruptedParts[2])
if err == nil {
return fmt.Errorf("corruption not detected")
}
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Corruption detection should work")
},
},
{
name: "session_expired_token_handling",
scenario: "validation",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Generate an expired token
token := f.generateTestToken("access", -3600) // negative expiry
atomic.AddInt64(&f.metrics.TokensValidated, 1)
// Parse and check expiry
parts := strings.Split(token, ".")
if len(parts) == 3 {
payload, _ := base64.RawURLEncoding.DecodeString(parts[1])
var claims map[string]interface{}
json.Unmarshal(payload, &claims)
if exp, ok := claims["exp"].(float64); ok {
if exp < float64(time.Now().Unix()) {
atomic.AddInt64(&f.metrics.ErrorCount, 1)
return nil // Expected behavior
}
}
}
return fmt.Errorf("expired token not detected")
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Expired token should be detected")
assert.Greater(t, f.metrics.ErrorCount, int64(0))
},
},
// Session Expiration Tests
{
name: "session_ttl_expiration",
scenario: "expiration",
sessionType: "user",
timeout: 3 * time.Second,
execute: func(f *SessionTestFramework) error {
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Simulate session with short TTL
time.Sleep(100 * time.Millisecond) // Don't sleep for full timeout
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed)
},
},
{
name: "session_refresh_token_expiry",
scenario: "expiration",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
refreshToken := f.generateTestToken("refresh", 86400)
atomic.AddInt64(&f.metrics.TokensValidated, 1)
// Check refresh token is valid for longer period
parts := strings.Split(refreshToken, ".")
if len(parts) == 3 {
payload, _ := base64.RawURLEncoding.DecodeString(parts[1])
var claims map[string]interface{}
json.Unmarshal(payload, &claims)
if exp, ok := claims["exp"].(float64); ok {
timeUntilExpiry := time.Until(time.Unix(int64(exp), 0))
if timeUntilExpiry < 23*time.Hour {
return fmt.Errorf("refresh token expiry too short: %v", timeUntilExpiry)
}
}
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Refresh token should have correct expiry")
},
},
// Session Persistence Tests
{
name: "session_cookie_persistence",
scenario: "persistence",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
req := httptest.NewRequest("GET", "http://example.com/", nil)
w := httptest.NewRecorder()
// Set session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: "test-session-123",
Path: "/",
HttpOnly: true,
Secure: f.config.EnableHTTPS,
SameSite: http.SameSiteLaxMode,
})
f.requests = append(f.requests, req)
f.responses = append(f.responses, w)
// Verify cookie was set
cookies := w.Result().Cookies()
if len(cookies) == 0 {
return fmt.Errorf("no cookies set")
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.NotEmpty(t, f.responses, "Response should be recorded")
},
},
{
name: "session_state_preservation",
scenario: "persistence",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Store state
state := map[string]interface{}{
"user_id": "test-user",
"email": "test@example.com",
"roles": []string{"user", "admin"},
}
// Serialize and deserialize to test persistence
data, err := json.Marshal(state)
if err != nil {
return err
}
var restored map[string]interface{}
if err := json.Unmarshal(data, &restored); err != nil {
return err
}
// Verify state preserved
if restored["user_id"] != state["user_id"] {
return fmt.Errorf("state not preserved")
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Session state should be preserved")
},
},
// Session Cleanup Tests
{
name: "session_proper_cleanup",
scenario: "cleanup",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Create and destroy sessions
for i := 0; i < 10; i++ {
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
sessionID := fmt.Sprintf("session-%d", i)
f.sessionIDs = append(f.sessionIDs, sessionID)
}
// Cleanup all sessions
for range f.sessionIDs {
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
}
f.sessionIDs = nil
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed)
assert.Empty(t, f.sessionIDs, "All sessions should be cleaned up")
},
},
{
name: "session_goroutine_leak_prevention",
scenario: "cleanup",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
initialGoroutines := runtime.NumGoroutine()
// Create sessions that might spawn goroutines
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
time.Sleep(10 * time.Millisecond)
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
}(i)
}
wg.Wait()
runtime.GC()
time.Sleep(100 * time.Millisecond)
finalGoroutines := runtime.NumGoroutine()
if finalGoroutines > initialGoroutines+2 { // Allow small variance
return fmt.Errorf("goroutine leak detected: %d -> %d", initialGoroutines, finalGoroutines)
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "No goroutine leaks should occur")
},
},
// Session Chunking Tests
{
name: "session_large_token_chunking",
scenario: "chunking",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Generate a large token that requires chunking
largeToken := f.generateLargeToken(10000) // 10KB token
// Calculate expected chunks
chunkSize := f.config.MaxChunkSize
expectedChunks := (len(largeToken) + chunkSize - 1) / chunkSize
// Simulate chunking
chunks := make([]string, 0)
for i := 0; i < len(largeToken); i += chunkSize {
end := i + chunkSize
if end > len(largeToken) {
end = len(largeToken)
}
chunks = append(chunks, largeToken[i:end])
atomic.AddInt64(&f.metrics.ChunksCreated, 1)
}
if len(chunks) != expectedChunks {
return fmt.Errorf("expected %d chunks, got %d", expectedChunks, len(chunks))
}
// Simulate reconstruction
reconstructed := strings.Join(chunks, "")
if reconstructed != largeToken {
return fmt.Errorf("token reconstruction failed")
}
atomic.AddInt64(&f.metrics.ChunksRetrieved, int64(len(chunks)))
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Token chunking should work correctly")
assert.Greater(t, f.metrics.ChunksCreated, int64(0))
assert.Equal(t, f.metrics.ChunksCreated, f.metrics.ChunksRetrieved)
},
},
{
name: "session_chunk_boundary_validation",
scenario: "chunking",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Test exact boundary conditions
testSizes := []int{
f.config.MaxChunkSize - 1,
f.config.MaxChunkSize,
f.config.MaxChunkSize + 1,
f.config.MaxChunkSize * 2,
f.config.MaxChunkSize*2 - 1,
f.config.MaxChunkSize*2 + 1,
}
for _, size := range testSizes {
token := f.generateLargeToken(size)
actualSize := len(token)
expectedChunks := (actualSize + f.config.MaxChunkSize - 1) / f.config.MaxChunkSize
actualChunks := 0
for i := 0; i < len(token); i += f.config.MaxChunkSize {
actualChunks++
atomic.AddInt64(&f.metrics.ChunksCreated, 1)
}
if actualChunks != expectedChunks {
return fmt.Errorf("size %d (actual token size %d): expected %d chunks, got %d", size, actualSize, expectedChunks, actualChunks)
}
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Chunk boundaries should be handled correctly")
},
},
// Session Security Tests
{
name: "session_csrf_token_management",
scenario: "security",
sessionType: "csrf",
execute: func(f *SessionTestFramework) error {
// Generate CSRF token
csrfToken := make([]byte, 32)
if _, err := rand.Read(csrfToken); err != nil {
return err
}
csrfString := base64.RawURLEncoding.EncodeToString(csrfToken)
// Store in session
f.testTokens["csrf"] = csrfString
// Validate CSRF token
if len(csrfString) < 40 {
return fmt.Errorf("CSRF token too short")
}
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
atomic.AddInt64(&f.metrics.TokensValidated, 1)
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "CSRF token should be properly managed")
assert.NotEmpty(t, f.testTokens["csrf"])
},
},
{
name: "session_injection_prevention",
scenario: "security",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Test various injection attempts
maliciousInputs := []string{
`{"admin": true}`,
`<script>alert('xss')</script>`,
`'; DROP TABLE sessions; --`,
`../../../etc/passwd`,
string([]byte{0x00, 0x01, 0x02}), // null bytes
}
for _, input := range maliciousInputs {
// Validate that input is properly sanitized
sanitized := base64.StdEncoding.EncodeToString([]byte(input))
decoded, err := base64.StdEncoding.DecodeString(sanitized)
if err != nil {
return err
}
if string(decoded) != input {
return fmt.Errorf("sanitization changed input unexpectedly")
}
atomic.AddInt64(&f.metrics.TokensValidated, 1)
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Injection attempts should be handled safely")
},
},
{
name: "session_secure_cookie_settings",
scenario: "security",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
w := httptest.NewRecorder()
// Test secure cookie settings
cookie := &http.Cookie{
Name: "session",
Value: "test-session",
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
MaxAge: 3600,
}
http.SetCookie(w, cookie)
// Verify cookie attributes
cookies := w.Result().Cookies()
if len(cookies) == 0 {
return fmt.Errorf("no cookie set")
}
c := cookies[0]
if !c.HttpOnly {
return fmt.Errorf("cookie not HttpOnly")
}
if c.SameSite != http.SameSiteStrictMode {
return fmt.Errorf("incorrect SameSite setting")
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Secure cookie settings should be enforced")
},
},
// Session Stress Tests
{
name: "session_high_concurrency_stress",
scenario: "creation",
sessionType: "user",
concurrent: true,
iterations: 1000,
timeout: 30 * time.Second,
execute: func(f *SessionTestFramework) error {
var wg sync.WaitGroup
errors := make([]error, 0)
// Run high concurrency test
concurrency := 100
iterations := 10
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
// Create session
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Generate tokens
f.generateTestToken("access", 3600)
f.generateTestToken("refresh", 86400)
// Validate tokens
atomic.AddInt64(&f.metrics.TokensValidated, 2)
// Cleanup session
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
// Small delay to simulate real usage
time.Sleep(time.Millisecond)
}
}(i)
}
wg.Wait()
if len(errors) > 0 {
return errors[0]
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "High concurrency stress test should pass")
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed, "All sessions should be cleaned up")
},
},
{
name: "session_memory_bounds_enforcement",
scenario: "cleanup",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
maxSessions := f.config.MaxSessions
// Try to create more sessions than allowed
for i := 0; i < maxSessions+100; i++ {
sessionID := fmt.Sprintf("session-%d", i)
f.sessionIDs = append(f.sessionIDs, sessionID)
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Enforce max sessions
if len(f.sessionIDs) > maxSessions {
// Remove oldest session
f.sessionIDs = f.sessionIDs[1:]
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
}
}
if len(f.sessionIDs) > maxSessions {
return fmt.Errorf("max sessions exceeded: %d > %d", len(f.sessionIDs), maxSessions)
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Memory bounds should be enforced")
assert.LessOrEqual(t, len(f.sessionIDs), f.config.MaxSessions)
},
},
}
// Run all test cases
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.skipReason != "" {
t.Skip(tc.skipReason)
}
framework := NewSessionTestFramework(t)
defer framework.Cleanup()
// Setup
if tc.setup != nil {
tc.setup(framework)
}
// Cleanup
if tc.cleanup != nil {
defer tc.cleanup(framework)
}
// Set timeout if specified
if tc.timeout > 0 {
timer := time.NewTimer(tc.timeout)
done := make(chan bool)
go func() {
err := tc.execute(framework)
tc.validate(t, err, framework)
done <- true
}()
select {
case <-done:
timer.Stop()
case <-timer.C:
t.Fatal("Test timeout exceeded")
}
} else {
// Execute test
err := tc.execute(framework)
// Validate results
tc.validate(t, err, framework)
}
})
}
}
// ============================================================================
// SESSION STATE PRESERVATION TESTS (6-HOUR TOKEN EXPIRY SCENARIOS)
// ============================================================================
// TestSessionStatePreservationWithExpiredTokens verifies that data stored in a
// session survives a reload from cookies when the access/ID tokens have expired.
//
// It checks three things:
//   - the authentication flag, email and custom "user_data_*" values are still
//     present after the session is reloaded from the cookies of a previous save;
//   - the session can be saved again after expired tokens were set on it;
//   - replacing the tokens (as a successful refresh would) leaves the
//     authentication flag, email and user data intact.
//
// The "6 hours later" condition is modelled purely through token contents:
// createExpiredJWTToken stamps the tokens with an exp claim 6 hours in the past.
// No wall-clock delay is involved, so the test does not sleep.
func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
	t.Log("Testing session state preservation with expired tokens")
	logger := NewLogger("debug")
	sm, err := NewSessionManager("test-session-key-32-bytes-long-12345", false, "", "", 0, logger)
	if err != nil {
		t.Fatalf("Failed to create session manager: %v", err)
	}

	// Simulate real-world session data that should be preserved.
	originalUserData := map[string]interface{}{
		"user_id":     "user-12345",
		"email":       "test.user@company.com",
		"name":        "Test User",
		"roles":       []string{"admin", "user"},
		"pref_theme":  "dark",
		"pref_lang":   "en",
		"last_active": "2023-01-01T10:00:00Z",
	}

	// Create the initial session with valid tokens.
	req1 := httptest.NewRequest("GET", "/initial", nil)
	rr1 := httptest.NewRecorder()
	session1, err := sm.GetSession(req1)
	if err != nil {
		t.Fatalf("Failed to get initial session: %v", err)
	}
	// Set up the initial session state (what the user has right after logging in).
	session1.SetAuthenticated(true)
	session1.SetEmail(originalUserData["email"].(string))
	session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
	session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
	session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
	// Store each user-data entry under its own "user_data_<key>" value instead
	// of storing the whole map as a single value.
	for k, v := range originalUserData {
		session1.mainSession.Values["user_data_"+k] = v
	}
	session1.mainSession.Values["session_created"] = time.Now().Unix() // Store as int64 for gob
	session1.mainSession.Values["custom_flag"] = true
	if err := session1.Save(req1, rr1); err != nil {
		t.Fatalf("Failed to save initial session: %v", err)
	}
	initialCookies := rr1.Result().Cookies()
	session1.ReturnToPool()
	t.Log("Initial session created with user data")

	// "Fast-forward" 6 hours: instead of waiting, build tokens whose exp claim is
	// already 6 hours in the past. The browser holds tokens like these after a
	// long period of inactivity.
	expiredTime := time.Now().Add(-6 * time.Hour)
	expiredAccessToken := createExpiredJWTToken("user-12345", "test.user@company.com", expiredTime)
	expiredIDToken := createExpiredJWTToken("user-12345", "test.user@company.com", expiredTime)

	// The user returns after inactivity and makes a request carrying the old cookies.
	req2 := httptest.NewRequest("GET", "/protected-resource", nil)
	for _, cookie := range initialCookies {
		req2.AddCookie(cookie)
	}
	session2, err := sm.GetSession(req2)
	if err != nil {
		t.Fatalf("Failed to get session after 6 hours: %v", err)
	}
	defer session2.ReturnToPool()

	// Capture the reloaded state before touching tokens. This mirrors middleware
	// that must keep session state while it attempts a token refresh.
	originalAuth := session2.GetAuthenticated()
	originalEmail := session2.GetEmail()
	// Reconstruct user data from the individually stored keys.
	originalUserDataStored := make(map[string]interface{})
	for k := range originalUserData {
		if storedValue, exists := session2.mainSession.Values["user_data_"+k]; exists {
			originalUserDataStored[k] = storedValue
		}
	}

	// Update the session with expired tokens (what middleware does when tokens expire).
	// The refresh token set earlier is left untouched.
	session2.SetAccessToken(expiredAccessToken)
	session2.SetIDToken(expiredIDToken)
	t.Log("Session loaded after 6-hour expiry, checking state preservation")

	// Verify the authentication state is preserved.
	if !originalAuth {
		t.Error("Authentication state lost during session reload")
	}
	// Verify the email is preserved.
	if originalEmail != originalUserData["email"].(string) {
		t.Errorf("User email lost during session reload - Expected: %s, Got: %s",
			originalUserData["email"], originalEmail)
	}
	// Verify the custom user data is preserved. Only string-valued entries are
	// compared; "roles" is a slice and is not comparable with !=.
	if len(originalUserDataStored) == 0 {
		t.Error("All custom user data lost during session reload")
	} else {
		if originalUserDataStored["user_id"] != originalUserData["user_id"] {
			t.Error("User ID lost from session data")
		}
		if originalUserDataStored["name"] != originalUserData["name"] {
			t.Error("User name lost from session data")
		}
		if originalUserDataStored["pref_theme"] != originalUserData["pref_theme"] {
			t.Error("User theme preference lost from session data")
		}
		if originalUserDataStored["pref_lang"] != originalUserData["pref_lang"] {
			t.Error("User language preference lost from session data")
		}
	}

	// Note: the system may reject invalid/expired tokens during storage; that is
	// acceptable behavior, so a mismatch is only logged.
	currentAccessToken := session2.GetAccessToken()
	if currentAccessToken != expiredAccessToken {
		t.Logf("INFO: Access token was not stored (possibly rejected due to expiry) - Expected: %s, Got: %s",
			expiredAccessToken, currentAccessToken)
	}

	// Verify the session can be saved again after token expiry without losing data.
	rr2 := httptest.NewRecorder()
	if err := session2.Save(req2, rr2); err != nil {
		t.Errorf("Cannot save session after token expiry: %v", err)
	} else {
		t.Log("Session successfully saved after token expiry")
		// Verify cookies are still set.
		newCookies := rr2.Result().Cookies()
		if len(newCookies) == 0 {
			t.Error("No session cookies set after saving expired token session")
		}
	}

	// Simulate a successful token refresh by installing a fresh token set.
	newAccessToken := "refreshed-access-token-longer-than-20-chars"
	newIDToken := "refreshed-id-token-longer-than-20-chars"
	newRefreshToken := "new-refresh-token-after-successful-renewal"
	session2.SetAccessToken(newAccessToken)
	session2.SetIDToken(newIDToken)
	session2.SetRefreshToken(newRefreshToken)

	// Verify all session data is still intact after the token refresh.
	postRefreshAuth := session2.GetAuthenticated()
	postRefreshEmail := session2.GetEmail()
	userDataPresent := true
	for k := range originalUserData {
		if session2.mainSession.Values["user_data_"+k] == nil {
			userDataPresent = false
			break
		}
	}
	if !postRefreshAuth {
		t.Error("Authentication state lost after token refresh")
	}
	if postRefreshEmail != originalUserData["email"].(string) {
		t.Error("User email lost after token refresh")
	}
	if !userDataPresent {
		t.Error("User data lost after token refresh")
	}
	t.Log("Session state preservation test completed")
}
// TestSessionExpiryVsTokenExpiry tests the distinction between session expiry and token expiry.
//
// Each scenario builds a fresh authenticated session. Its "created_at" value is
// backdated by sessionAge, and its access token carries an exp claim of
// now+tokenExpiry (a negative tokenExpiry means the token has already expired).
// The test classifies the session as expired when sessionAge exceeds
// absoluteSessionTimeout (declared elsewhere in the package). It classifies the
// token as expired when tokenExpiry < 0.
//
// NOTE(review): the test mostly logs its classification. The only hard
// assertions are:
//   - a session expired by age must not still report GetAuthenticated()==true;
//   - a non-expired session must report GetAuthenticated()==true;
//   - scenarios that expect a refresh must really have expired tokens.
// Whether the first can hold depends on GetAuthenticated consulting
// "created_at", which is not visible here. Confirm against the session code.
func TestSessionExpiryVsTokenExpiry(t *testing.T) {
t.Log("Testing session expiry vs token expiry distinction")
logger := NewLogger("debug")
sm, err := NewSessionManager("session-vs-token-test-key-32-bytes", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Each entry pairs a session age with a token expiry offset and states which
// outcome (session expiry vs. token refresh) is expected.
scenarios := []struct {
name                string
expectedBehavior    string
sessionAge          time.Duration
tokenExpiry         time.Duration
sessionShouldExpire bool
tokenShouldRefresh  bool
}{
{
name:                "New session, expired tokens",
sessionAge:          5 * time.Minute,
tokenExpiry:         -6 * time.Hour,
expectedBehavior:    "Session valid, tokens should refresh",
sessionShouldExpire: false,
tokenShouldRefresh:  true,
},
{
name:                "Old session, valid tokens",
sessionAge:          25 * time.Hour,
tokenExpiry:         2 * time.Hour,
expectedBehavior:    "Session expired, redirect to login even with valid tokens",
sessionShouldExpire: true,
tokenShouldRefresh:  false,
},
{
name:                "Both session and tokens expired",
sessionAge:          25 * time.Hour,
tokenExpiry:         -6 * time.Hour,
expectedBehavior:    "Both expired, clear session and redirect to login",
sessionShouldExpire: true,
tokenShouldRefresh:  false,
},
{
name:                "Recent session, recently expired tokens",
sessionAge:          30 * time.Minute,
tokenExpiry:         -10 * time.Minute,
expectedBehavior:    "Session valid, tokens recently expired, should refresh",
sessionShouldExpire: false,
tokenShouldRefresh:  true,
},
}
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
t.Logf("Testing: %s", scenario.expectedBehavior)
// Create session at specific "age"
sessionCreatedAt := time.Now().Add(-scenario.sessionAge)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Set up session with specific creation time (stored as Unix seconds)
session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
// Create tokens with specific expiry. Despite the helper's name, a positive
// tokenExpiry yields a token whose exp lies in the future.
tokenExpiredAt := time.Now().Add(scenario.tokenExpiry)
accessToken := createExpiredJWTToken("test-user", "test@example.com", tokenExpiredAt)
session.SetAccessToken(accessToken)
session.SetRefreshToken("test-refresh-token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Test session validity check: classification is computed from the
// scenario inputs, not read back from the session.
isSessionExpired := scenario.sessionAge > absoluteSessionTimeout
isTokenExpired := scenario.tokenExpiry < 0
t.Logf("Session age: %v (expired: %t)", scenario.sessionAge, isSessionExpired)
// NOTE: for not-yet-expired tokens (positive tokenExpiry) this prints a
// negative "ago" duration.
t.Logf("Token expiry: %v ago (expired: %t)", -scenario.tokenExpiry, isTokenExpired)
if scenario.sessionShouldExpire {
if isSessionExpired && session.GetAuthenticated() {
t.Errorf("Session should be expired after %v but is still authenticated", scenario.sessionAge)
}
} else {
if !isSessionExpired && !session.GetAuthenticated() {
t.Errorf("Session should be valid (age: %v) but shows as not authenticated", scenario.sessionAge)
}
}
// A scenario expecting a refresh is only meaningful if its tokens are expired;
// otherwise the table itself is wrong.
if scenario.tokenShouldRefresh {
if !isTokenExpired {
t.Errorf("Test setup error - tokens should be expired but expiry is: %v", scenario.tokenExpiry)
}
t.Logf("Should attempt token refresh for scenario: %s", scenario.name)
} else {
if isSessionExpired {
t.Logf("Correctly identified that session is expired - no need to refresh tokens")
}
}
// Check for critical scenario: confusing session expiry with token expiry.
// This branch only logs; it asserts nothing.
if !isSessionExpired && isTokenExpired {
t.Logf("CRITICAL SCENARIO: Valid session (%v old) but expired tokens (%v ago)",
scenario.sessionAge, -scenario.tokenExpiry)
t.Logf("Expected: System should refresh tokens and continue session")
if scenario.name == "New session, expired tokens" && scenario.tokenExpiry == -6*time.Hour {
t.Logf("This represents the 6-hour browser inactivity scenario")
}
}
})
}
}
// TestSessionCleanupOnTokenExpiry tests that session cleanup after token expiry
// is selective.
//
// For each scenario the test stores authentication, email, "user_data",
// "preferences", "authentication" and "temp_data" values plus an access token
// whose exp claim is now+tokenExpiry. It then applies a simulated cleanup
// policy:
//   - if the token is older than absoluteSessionTimeout (declared elsewhere)
//     and the scenario expects cleanup, every value is wiped;
//   - otherwise the session is preserved so a token refresh can continue it.
//
// Every item listed in shouldPreserve or shouldRemove is then checked. The
// 6-hour inactivity case must keep both authentication and user data.
func TestSessionCleanupOnTokenExpiry(t *testing.T) {
	t.Log("Testing session cleanup on token expiry")
	logger := NewLogger("debug")
	sm, err := NewSessionManager("cleanup-test-key-32-bytes-long-123", false, "", "", 0, logger)
	if err != nil {
		t.Fatalf("Failed to create session manager: %v", err)
	}
	scenarios := []struct {
		name           string
		shouldPreserve []string
		shouldRemove   []string
		tokenExpiry    time.Duration
		shouldCleanup  bool
	}{
		{
			name:           "Recently expired tokens - preserve session",
			tokenExpiry:    -30 * time.Minute,
			shouldCleanup:  false,
			shouldPreserve: []string{"user_data", "preferences", "authentication"},
			shouldRemove:   []string{},
		},
		{
			name:           "Long expired tokens - cleanup selectively",
			tokenExpiry:    -25 * time.Hour,
			shouldCleanup:  true,
			shouldPreserve: []string{},
			shouldRemove:   []string{"user_data", "preferences", "authentication"},
		},
		{
			name:           "6-hour expired tokens - preserve for refresh",
			tokenExpiry:    -6 * time.Hour,
			shouldCleanup:  false,
			shouldPreserve: []string{"user_data", "preferences", "authentication"},
			shouldRemove:   []string{},
		},
	}
	for _, scenario := range scenarios {
		t.Run(scenario.name, func(t *testing.T) {
			t.Logf("Testing cleanup behavior: %s", scenario.name)
			req := httptest.NewRequest("GET", "/test", nil)
			rr := httptest.NewRecorder()
			session, err := sm.GetSession(req)
			if err != nil {
				t.Fatalf("Failed to get session: %v", err)
			}
			defer session.ReturnToPool()

			// Populate the session with data that should be preserved or removed.
			session.SetAuthenticated(true)
			session.SetEmail("cleanup@example.com")
			session.mainSession.Values["user_data"] = "Test User|user-123"
			session.mainSession.Values["preferences"] = "theme:dark,lang:en"
			session.mainSession.Values["authentication"] = true
			session.mainSession.Values["temp_data"] = "should-be-cleaned"

			// Set tokens that expired scenario.tokenExpiry ago.
			expiredTime := time.Now().Add(scenario.tokenExpiry)
			expiredToken := createExpiredJWTToken("user-123", "cleanup@example.com", expiredTime)
			session.SetAccessToken(expiredToken)
			session.SetRefreshToken("test-refresh-token")
			if err := session.Save(req, rr); err != nil {
				t.Fatalf("Failed to save session: %v", err)
			}

			// Simulate token-expiry detection and the cleanup decision.
			tokenExpired := scenario.tokenExpiry < 0
			sessionTooOld := scenario.tokenExpiry < -absoluteSessionTimeout
			t.Logf("Token expired: %t, Session too old: %t", tokenExpired, sessionTooOld)

			// Capture the session state before cleanup.
			preCleanupAuth := session.GetAuthenticated()
			preCleanupData := session.mainSession.Values["user_data"]
			preCleanupPrefs := session.mainSession.Values["preferences"]
			if scenario.shouldCleanup {
				if sessionTooOld {
					session.SetAuthenticated(false)
					session.SetEmail("")
					session.SetAccessToken("")
					session.SetRefreshToken("")
					// Deleting map entries while ranging over the map is allowed in Go.
					for key := range session.mainSession.Values {
						delete(session.mainSession.Values, key)
					}
					t.Log("Applied full cleanup for expired session")
				}
			} else {
				t.Log("Preserving session for token refresh")
			}

			// Capture the session state after cleanup.
			postCleanupAuth := session.GetAuthenticated()
			postCleanupData := session.mainSession.Values["user_data"]
			postCleanupPrefs := session.mainSession.Values["preferences"]

			// Verify preservation expectations.
			for _, item := range scenario.shouldPreserve {
				switch item {
				case "authentication":
					if !postCleanupAuth && preCleanupAuth {
						t.Errorf("Authentication state was cleaned up but should be preserved")
					}
				case "user_data":
					if postCleanupData == nil && preCleanupData != nil {
						t.Errorf("User data was cleaned up but should be preserved")
					}
				case "preferences":
					if postCleanupPrefs == nil && preCleanupPrefs != nil {
						t.Errorf("User preferences were cleaned up but should be preserved")
					}
				}
			}

			// Verify removal expectations. Every item a scenario may list in
			// shouldRemove has a case here; previously "preferences" was listed
			// but never checked.
			for _, item := range scenario.shouldRemove {
				switch item {
				case "authentication":
					if postCleanupAuth && scenario.shouldCleanup {
						t.Errorf("Authentication state not cleaned up when it should be")
					}
				case "user_data":
					if postCleanupData != nil && scenario.shouldCleanup {
						t.Errorf("User data not cleaned up when session is expired")
					}
				case "preferences":
					if postCleanupPrefs != nil && scenario.shouldCleanup {
						t.Errorf("User preferences not cleaned up when session is expired")
					}
				}
			}

			// The critical 6-hour inactivity scenario must keep the session usable for a refresh.
			if scenario.tokenExpiry == -6*time.Hour {
				if !postCleanupAuth {
					t.Error("6-hour token expiry caused session cleanup - session should be preserved for token refresh")
				}
				if postCleanupData == nil {
					t.Error("6-hour token expiry caused user data loss - user data should be preserved during token refresh")
				}
			}
		})
	}
}
// ============================================================================
// HELPER FUNCTIONS
// ============================================================================
// getPooledObjects reports how many *SessionData values can currently be drained
// from sm.sessionPool. At most 100 are inspected.
//
// Scanning stops early when the pool hands back nil, or a value that is not a
// *SessionData. A foreign value is put back before stopping. Every drained
// *SessionData is returned to the pool before the count is reported.
//
// NOTE(review): if the pool has a New func, Get never yields nil and the result
// will typically be the 100 cap. Confirm how sessionPool is constructed.
func getPooledObjects(sm *SessionManager) int {
	const maxDrain = 100

	var drained []*SessionData
	for len(drained) < maxDrain {
		item := sm.sessionPool.Get()
		if item == nil {
			break
		}
		sd, isSession := item.(*SessionData)
		if !isSession {
			// Not a session object: hand it back untouched and stop scanning.
			sm.sessionPool.Put(item)
			break
		}
		drained = append(drained, sd)
	}

	// Restore everything that was taken out so the pool is left as found.
	for _, sd := range drained {
		sm.sessionPool.Put(sd)
	}
	return len(drained)
}
// createLargeIDToken builds a JWT-shaped string ("header.payload.signature")
// whose payload is random base64url data. It is used to exercise cookie chunking.
//
// The result is NOT exactly size bytes. The payload is capped at
// size-len(header)-100, so for size >= 136 the token length is size-55: a
// 36-byte header and a 43-byte signature plus two dots around a size-136
// payload. Callers that care about the exact length should measure len() of
// the result.
//
// Sizes too small to hold any payload, including zero and negative sizes,
// produce an empty payload (an 81-byte token). Previously they panicked on a
// negative slice bound or a negative make length.
func createLargeIDToken(size int) string {
	const (
		header    = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
		signature = "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
	)

	// Room left for the payload after reserving the header plus 100 bytes of slack.
	maxPayload := size - len(header) - 100
	if maxPayload < 0 {
		maxPayload = 0
	}

	// size*3/4 random bytes encode to roughly size base64 characters, which is
	// always at least maxPayload, so the truncation below fixes the final length.
	byteCount := size * 3 / 4
	if byteCount < 0 {
		byteCount = 0
	}
	randomBytes := make([]byte, byteCount)
	if _, err := rand.Read(randomBytes); err != nil {
		// Fall back to a deterministic pattern; only the size of the payload matters.
		for i := range randomBytes {
			randomBytes[i] = byte(i % 256)
		}
	}

	encoded := base64.RawURLEncoding.EncodeToString(randomBytes)
	if len(encoded) > maxPayload {
		encoded = encoded[:maxPayload]
	}
	return header + "." + encoded + "." + signature
}
// minInt returns the smaller of a and b; when they are equal, either value is the result.
func minInt(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
// createExpiredJWTToken assembles an unsigned, JWT-shaped test token:
// a fixed RS256 header, a base64url (unpadded) JSON claims segment, and a
// base64url encoding of a fake signature string.
//
// The claims carry sub=userID, email=email, exp=expiredTime, iat one hour
// before expiredTime, plus fixed iss and aud values. Despite the name, a future
// expiredTime produces a not-yet-expired token.
func createExpiredJWTToken(userID, email string, expiredTime time.Time) string {
	const (
		rs256Header   = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
		fakeSignature = "fake-signature-for-testing"
	)
	enc := base64.RawURLEncoding

	issuedAt := expiredTime.Add(-time.Hour)
	payload := map[string]interface{}{
		"sub":   userID,
		"email": email,
		"exp":   expiredTime.Unix(),
		"iat":   issuedAt.Unix(),
		"iss":   "https://test-provider.com",
		"aud":   "test-client-id",
	}
	// Marshalling a map holding only strings and int64s cannot fail, so the
	// error is deliberately discarded.
	payloadJSON, _ := json.Marshal(payload)

	body := enc.EncodeToString(payloadJSON)
	sig := enc.EncodeToString([]byte(fakeSignature))
	return rs256Header + "." + body + "." + sig
}
// TestCookiePrefixIsolation tests that different cookie prefixes create isolated sessions.
// This addresses GitHub issue #87: multiple middleware instances must not share sessions.
//
// Two managers are created with distinct prefixes. For every cookie the manager
// names (main, access token, refresh token, ID token), the test checks that:
//   - the two managers produce different names;
//   - each name starts with its manager's prefix.
// Previously only the main and access-token names were compared, and only the
// main name was checked for its prefix.
func TestCookiePrefixIsolation(t *testing.T) {
	logger := NewLogger("info")
	encryptionKey := strings.Repeat("a", 32)

	const (
		expectedPrefix1 = "_oidc_userauth_"
		expectedPrefix2 = "_oidc_adminauth_"
	)

	// Create two session managers with different cookie prefixes.
	sm1, err := NewSessionManager(encryptionKey, false, "", expectedPrefix1, 0, logger)
	if err != nil {
		t.Fatalf("Failed to create session manager 1: %v", err)
	}
	sm2, err := NewSessionManager(encryptionKey, false, "", expectedPrefix2, 0, logger)
	if err != nil {
		t.Fatalf("Failed to create session manager 2: %v", err)
	}

	// Each row pairs the same cookie's name from both managers.
	pairs := []struct {
		kind   string
		first  string
		second string
	}{
		{"main", sm1.mainCookieName(), sm2.mainCookieName()},
		{"access token", sm1.accessTokenCookieName(), sm2.accessTokenCookieName()},
		{"refresh token", sm1.refreshTokenCookieName(), sm2.refreshTokenCookieName()},
		{"ID token", sm1.idTokenCookieName(), sm2.idTokenCookieName()},
	}
	for _, p := range pairs {
		if p.first == p.second {
			t.Errorf("Expected different %s cookie names, got same: %s", p.kind, p.first)
		}
		if !strings.HasPrefix(p.first, expectedPrefix1) {
			t.Errorf("Expected %s cookie name to start with %s, got %s", p.kind, expectedPrefix1, p.first)
		}
		if !strings.HasPrefix(p.second, expectedPrefix2) {
			t.Errorf("Expected %s cookie name to start with %s, got %s", p.kind, expectedPrefix2, p.second)
		}
	}

	t.Logf("Session Manager 1 cookies: main=%s, access=%s, refresh=%s, id=%s",
		sm1.mainCookieName(), sm1.accessTokenCookieName(), sm1.refreshTokenCookieName(), sm1.idTokenCookieName())
	t.Logf("Session Manager 2 cookies: main=%s, access=%s, refresh=%s, id=%s",
		sm2.mainCookieName(), sm2.accessTokenCookieName(), sm2.refreshTokenCookieName(), sm2.idTokenCookieName())
}
// TestCookiePrefixDefault checks that a SessionManager built with an empty cookie
// prefix falls back to defaultCookiePrefix. Each of its four cookie names must be
// exactly defaultCookiePrefix followed by the matching suffix constant.
func TestCookiePrefixDefault(t *testing.T) {
	logger := NewLogger("info")
	key := strings.Repeat("a", 32)

	// Passing "" as the prefix argument should select the default prefix.
	sm, err := NewSessionManager(key, false, "", "", 0, logger)
	if err != nil {
		t.Fatalf("Failed to create session manager: %v", err)
	}

	if got := sm.mainCookieName(); !strings.HasPrefix(got, defaultCookiePrefix) {
		t.Errorf("Expected default prefix %s, got cookie name %s", defaultCookiePrefix, got)
	}

	// Full-name checks, in the same order as the cookies are listed elsewhere.
	checks := []struct {
		label string
		want  string
		got   string
	}{
		{"main", defaultCookiePrefix + mainCookieSuffix, sm.mainCookieName()},
		{"access", defaultCookiePrefix + accessTokenSuffix, sm.accessTokenCookieName()},
		{"refresh", defaultCookiePrefix + refreshTokenSuffix, sm.refreshTokenCookieName()},
		{"ID", defaultCookiePrefix + idTokenSuffix, sm.idTokenCookieName()},
	}
	for _, c := range checks {
		if c.got != c.want {
			t.Errorf("Expected %s cookie name %s, got %s", c.label, c.want, c.got)
		}
	}
}