Files
traefikoidc/session_test.go
lukaszraczylo 9cbca4c4fb fix(refresh): honor userIdentifierClaim in token refresh path (#132)
patch-release

The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.

The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.

Mirror the callback logic in refreshToken so both paths behave the same.

Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").

Compat note: existing user sessions invalidate on upgrade — every active
user re-authenticates once after deploying this change.
2026-05-07 09:21:41 +01:00

3279 lines
94 KiB
Go

package traefikoidc
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/gorilla/sessions"
"github.com/stretchr/testify/assert"
)
// ============================================================================
// SESSION TEST FRAMEWORK
// ============================================================================
// SessionTestCase represents a comprehensive session test scenario
type SessionTestCase struct {
setup func(*SessionTestFramework)
execute func(*SessionTestFramework) error
validate func(*testing.T, error, *SessionTestFramework)
cleanup func(*SessionTestFramework)
name string
scenario string
sessionType string
skipReason string
iterations int
timeout time.Duration
concurrent bool
}
// SessionTestFramework provides shared test infrastructure for session tests
type SessionTestFramework struct {
t *testing.T
mockProvider *httptest.Server
testTokens map[string]string
metrics *SessionTestMetrics
config *SessionTestConfig
requests []*http.Request
responses []*httptest.ResponseRecorder
sessionIDs []string
cleanupFuncs []func()
mu sync.RWMutex
}
// SessionTestMetrics tracks test performance metrics
type SessionTestMetrics struct {
SessionsCreated int64
SessionsDestroyed int64
TokensGenerated int64
TokensValidated int64
ChunksCreated int64
ChunksRetrieved int64
ErrorCount int64
Duration time.Duration
}
// SessionTestConfig holds test configuration
type SessionTestConfig struct {
CookieDomain string
EncryptionKey string
MaxChunkSize int
MaxSessions int
SessionTimeout time.Duration
EnableHTTPS bool
EnableCompression bool
}
// NewSessionTestFramework creates a new test framework instance
func NewSessionTestFramework(t *testing.T) *SessionTestFramework {
framework := &SessionTestFramework{
t: t,
requests: make([]*http.Request, 0),
responses: make([]*httptest.ResponseRecorder, 0),
testTokens: make(map[string]string),
sessionIDs: make([]string, 0),
metrics: &SessionTestMetrics{},
cleanupFuncs: make([]func(), 0),
config: &SessionTestConfig{
MaxChunkSize: 3900,
MaxSessions: 1000,
EnableHTTPS: false,
CookieDomain: "",
SessionTimeout: time.Hour,
EncryptionKey: generateTestKey(),
EnableCompression: true,
},
}
// Setup mock OIDC provider
framework.setupMockProvider()
return framework
}
// generateTestKey generates a test encryption key
func generateTestKey() string {
// 48 bytes = 384 bits for testing
return "0123456789abcdef0123456789abcdef0123456789abcdef"
}
// setupMockProvider sets up a mock OIDC provider for testing
func (f *SessionTestFramework) setupMockProvider() {
f.mockProvider = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]interface{}{
"issuer": f.mockProvider.URL,
"authorization_endpoint": f.mockProvider.URL + "/auth",
"token_endpoint": f.mockProvider.URL + "/token",
"userinfo_endpoint": f.mockProvider.URL + "/userinfo",
"jwks_uri": f.mockProvider.URL + "/jwks",
})
case "/token":
json.NewEncoder(w).Encode(map[string]interface{}{
"access_token": f.generateTestToken("access", 3600),
"id_token": f.generateTestToken("id", 3600),
"refresh_token": f.generateTestToken("refresh", 86400),
"token_type": "Bearer",
"expires_in": 3600,
})
case "/userinfo":
json.NewEncoder(w).Encode(map[string]interface{}{
"sub": "test-user-id",
"email": "test@example.com",
"name": "Test User",
})
default:
w.WriteHeader(http.StatusNotFound)
}
}))
f.cleanupFuncs = append(f.cleanupFuncs, f.mockProvider.Close)
}
// generateTestToken generates a test token
func (f *SessionTestFramework) generateTestToken(tokenType string, expiresIn int) string {
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
// Create a realistic JWT-like token for testing
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
claims := map[string]interface{}{
"iss": f.mockProvider.URL,
"sub": "test-user-id",
"aud": "test-client-id",
"exp": time.Now().Add(time.Duration(expiresIn) * time.Second).Unix(),
"iat": time.Now().Unix(),
"typ": tokenType,
}
claimsJSON, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
// Generate a fake signature
signature := make([]byte, 64)
rand.Read(signature)
sig := base64.RawURLEncoding.EncodeToString(signature)
token := fmt.Sprintf("%s.%s.%s", header, payload, sig)
// Thread-safe write to map
f.mu.Lock()
f.testTokens[tokenType] = token
f.mu.Unlock()
return token
}
// generateLargeToken generates a token of specified size for testing chunking
func (f *SessionTestFramework) generateLargeToken(size int) string {
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
// Create base JWT structure
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
// Calculate how much padding we need in claims
baseSize := len(header) + 2 // for the dots
signatureSize := 86 // approximate base64 encoded signature size
paddingSize := size - baseSize - signatureSize - 100 // leave room for other claims
if paddingSize < 0 {
paddingSize = 0
}
// Create large padding data
padding := make([]byte, paddingSize)
for i := range padding {
padding[i] = byte('A' + (i % 26))
}
claims := map[string]interface{}{
"iss": f.mockProvider.URL,
"sub": "test-user-id",
"aud": "test-client-id",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"padding": base64.StdEncoding.EncodeToString(padding),
}
claimsJSON, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
// Generate signature
signature := make([]byte, 64)
rand.Read(signature)
sig := base64.RawURLEncoding.EncodeToString(signature)
return fmt.Sprintf("%s.%s.%s", header, payload, sig)
}
// Cleanup performs framework cleanup
func (f *SessionTestFramework) Cleanup() {
for _, cleanup := range f.cleanupFuncs {
cleanup()
}
}
// ============================================================================
// SESSION CHUNK MANAGER TESTS
// ============================================================================
// Helper function to create a mock HTTP request for session creation
func createMockRequest() *http.Request {
req := httptest.NewRequest("GET", "http://example.com", nil)
return req
}
func TestNewSessionChunkManager(t *testing.T) {
manager := NewSessionChunkManager(10)
if manager == nil {
t.Fatal("Expected non-nil session chunk manager")
}
if manager.maxChunks != 10 {
t.Errorf("Expected maxChunks 10, got %d", manager.maxChunks)
}
}
func TestNewSessionChunkManagerDefaultLimit(t *testing.T) {
// Test with 0 maxChunks (should use default)
manager := NewSessionChunkManager(0)
if manager.maxChunks != 20 {
t.Errorf("Expected default maxChunks 20, got %d", manager.maxChunks)
}
}
func TestNewSessionChunkManagerNegativeLimit(t *testing.T) {
// Test with negative maxChunks (should use default)
manager := NewSessionChunkManager(-5)
if manager.maxChunks != 20 {
t.Errorf("Expected default maxChunks 20, got %d", manager.maxChunks)
}
}
func TestCleanupChunksWithoutWriter(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add some chunks
for i := 0; i < 5; i++ {
session, _ := store.New(createMockRequest(), "chunk")
session.Values["token_chunk"] = "chunk-data"
chunks[i] = session
}
// Cleanup without writer (should just clear map)
manager.CleanupChunks(chunks, nil)
if len(chunks) != 0 {
t.Errorf("Expected chunks map to be empty, got %d entries", len(chunks))
}
}
func TestCleanupChunksWithWriter(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add some chunks
for i := 0; i < 3; i++ {
session, _ := store.New(createMockRequest(), "chunk")
session.Values["token_chunk"] = "chunk-data"
session.Options = &sessions.Options{MaxAge: 3600}
chunks[i] = session
}
// Create response writer
w := httptest.NewRecorder()
// Note: We can't fully test the Save behavior without a proper HTTP request
// but we can verify the cleanup clears the map
manager.CleanupChunks(chunks, w)
if len(chunks) != 0 {
t.Errorf("Expected chunks map to be empty, got %d entries", len(chunks))
}
}
func TestCleanupChunksNilSession(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
chunks[0] = nil
chunks[1] = nil
w := httptest.NewRecorder()
// Should handle nil sessions gracefully
manager.CleanupChunks(chunks, w)
if len(chunks) != 0 {
t.Errorf("Expected chunks map to be empty, got %d entries", len(chunks))
}
}
func TestCleanupChunksEmptyMap(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
// Should handle empty map gracefully
manager.CleanupChunks(chunks, nil)
if len(chunks) != 0 {
t.Error("Expected chunks map to remain empty")
}
}
func TestValidateAndCleanChunksWithinLimit(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add chunks within limit
for i := 0; i < 5; i++ {
session, _ := store.New(createMockRequest(), "chunk")
chunks[i] = session
}
result := manager.ValidateAndCleanChunks(chunks)
if !result {
t.Error("Expected validation to pass for chunks within limit")
}
if len(chunks) != 5 {
t.Errorf("Expected chunks to remain intact, got %d", len(chunks))
}
}
func TestValidateAndCleanChunksExceedLimit(t *testing.T) {
manager := NewSessionChunkManager(5)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add more chunks than limit
for i := 0; i < 10; i++ {
session, _ := store.New(createMockRequest(), "chunk")
chunks[i] = session
}
result := manager.ValidateAndCleanChunks(chunks)
if result {
t.Error("Expected validation to fail for chunks exceeding limit")
}
if len(chunks) != 0 {
t.Errorf("Expected chunks to be cleared, got %d", len(chunks))
}
}
func TestValidateAndCleanChunksAtLimit(t *testing.T) {
manager := NewSessionChunkManager(5)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add chunks exactly at limit
for i := 0; i < 5; i++ {
session, _ := store.New(createMockRequest(), "chunk")
chunks[i] = session
}
result := manager.ValidateAndCleanChunks(chunks)
if !result {
t.Error("Expected validation to pass for chunks at limit")
}
if len(chunks) != 5 {
t.Errorf("Expected chunks to remain intact, got %d", len(chunks))
}
}
func TestSafeSetChunkValidIndex(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
session, _ := store.New(createMockRequest(), "chunk")
result := manager.SafeSetChunk(chunks, 5, session)
if !result {
t.Error("Expected SafeSetChunk to succeed for valid index")
}
if chunks[5] != session {
t.Error("Expected session to be set at index 5")
}
}
func TestSafeSetChunkNegativeIndex(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
session, _ := store.New(createMockRequest(), "chunk")
result := manager.SafeSetChunk(chunks, -1, session)
if result {
t.Error("Expected SafeSetChunk to fail for negative index")
}
if len(chunks) != 0 {
t.Error("Expected chunks map to remain empty")
}
}
func TestSafeSetChunkIndexTooHigh(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
session, _ := store.New(createMockRequest(), "chunk")
result := manager.SafeSetChunk(chunks, 10, session)
if result {
t.Error("Expected SafeSetChunk to fail for index >= maxChunks")
}
if len(chunks) != 0 {
t.Error("Expected chunks map to remain empty")
}
}
func TestSafeSetChunkExceedingLimit(t *testing.T) {
manager := NewSessionChunkManager(5)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Fill up to limit
for i := 0; i < 5; i++ {
session, _ := store.New(createMockRequest(), "chunk")
chunks[i] = session
}
// Try to add a new chunk at new index (should fail)
session, _ := store.New(createMockRequest(), "chunk")
result := manager.SafeSetChunk(chunks, 2, session)
// This should succeed because index 2 already exists
if !result {
t.Error("Expected SafeSetChunk to succeed for existing index")
}
}
func TestSafeSetChunkReplaceExisting(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
session1, _ := store.New(createMockRequest(), "chunk1")
session2, _ := store.New(createMockRequest(), "chunk2")
// Set initial session
manager.SafeSetChunk(chunks, 3, session1)
// Replace with new session
result := manager.SafeSetChunk(chunks, 3, session2)
if !result {
t.Error("Expected SafeSetChunk to succeed for replacing existing chunk")
}
if chunks[3] != session2 {
t.Error("Expected session to be replaced at index 3")
}
}
func TestGetChunkCount(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add some chunks
for i := 0; i < 7; i++ {
session, _ := store.New(createMockRequest(), "chunk")
chunks[i] = session
}
count := manager.GetChunkCount(chunks)
if count != 7 {
t.Errorf("Expected chunk count 7, got %d", count)
}
}
func TestGetChunkCountEmpty(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
count := manager.GetChunkCount(chunks)
if count != 0 {
t.Errorf("Expected chunk count 0, got %d", count)
}
}
func TestCompactChunksNoGaps(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add sequential chunks
for i := 0; i < 5; i++ {
session, _ := store.New(createMockRequest(), "chunk")
session.Values["index"] = i
chunks[i] = session
}
compacted := manager.CompactChunks(chunks)
if len(compacted) != 5 {
t.Errorf("Expected 5 compacted chunks, got %d", len(compacted))
}
// Verify order
for i := 0; i < 5; i++ {
if compacted[i] == nil {
t.Errorf("Expected chunk at index %d", i)
}
}
}
func TestCompactChunksWithGaps(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add chunks with gaps
indices := []int{0, 2, 5, 7}
for _, idx := range indices {
session, _ := store.New(createMockRequest(), "chunk")
session.Values["original_index"] = idx
chunks[idx] = session
}
compacted := manager.CompactChunks(chunks)
if len(compacted) != 4 {
t.Errorf("Expected 4 compacted chunks, got %d", len(compacted))
}
// Verify chunks are reindexed sequentially
for i := 0; i < 4; i++ {
if compacted[i] == nil {
t.Errorf("Expected chunk at compacted index %d", i)
}
}
}
func TestCompactChunksWithNilEntries(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add chunks and nil entries
session1, _ := store.New(createMockRequest(), "chunk1")
session2, _ := store.New(createMockRequest(), "chunk2")
session3, _ := store.New(createMockRequest(), "chunk3")
chunks[0] = session1
chunks[1] = nil
chunks[2] = session2
chunks[3] = nil
chunks[4] = session3
compacted := manager.CompactChunks(chunks)
if len(compacted) != 3 {
t.Errorf("Expected 3 compacted chunks (nil entries removed), got %d", len(compacted))
}
// Verify non-nil chunks are compacted
for i := 0; i < 3; i++ {
if compacted[i] == nil {
t.Errorf("Expected non-nil chunk at compacted index %d", i)
}
}
}
func TestCompactChunksEmpty(t *testing.T) {
manager := NewSessionChunkManager(10)
chunks := make(map[int]*sessions.Session)
compacted := manager.CompactChunks(chunks)
if len(compacted) != 0 {
t.Errorf("Expected empty compacted map, got %d entries", len(compacted))
}
}
func TestSessionChunkManagerConcurrentOperations(t *testing.T) {
manager := NewSessionChunkManager(50)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
var wg sync.WaitGroup
// Concurrent SafeSetChunk
for i := 0; i < 20; i++ {
wg.Add(1)
go func(index int) {
defer wg.Done()
session, _ := store.New(createMockRequest(), "chunk")
manager.SafeSetChunk(chunks, index, session)
}(i)
}
// Concurrent GetChunkCount
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = manager.GetChunkCount(chunks)
}()
}
// Concurrent ValidateAndCleanChunks (reads)
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = manager.ValidateAndCleanChunks(chunks)
}()
}
wg.Wait()
// Verify manager is still functional
count := manager.GetChunkCount(chunks)
if count < 0 || count > 50 {
t.Errorf("Unexpected chunk count after concurrent operations: %d", count)
}
}
func TestSessionChunkManagerLargeChunkCount(t *testing.T) {
manager := NewSessionChunkManager(1000)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
// Add many chunks
for i := 0; i < 500; i++ {
session, _ := store.New(createMockRequest(), "chunk")
chunks[i] = session
}
result := manager.ValidateAndCleanChunks(chunks)
if !result {
t.Error("Expected validation to pass for 500 chunks with limit 1000")
}
count := manager.GetChunkCount(chunks)
if count != 500 {
t.Errorf("Expected 500 chunks, got %d", count)
}
}
func TestSessionChunkManagerBoundaryConditions(t *testing.T) {
tests := []struct {
name string
maxChunks int
addChunks int
shouldPass bool
}{
{"exactly at limit", 10, 10, true},
{"one over limit", 10, 11, false},
{"way over limit", 10, 50, false},
{"zero chunks with limit", 10, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := NewSessionChunkManager(tt.maxChunks)
chunks := make(map[int]*sessions.Session)
store := sessions.NewCookieStore([]byte("test-secret"))
for i := 0; i < tt.addChunks; i++ {
session, _ := store.New(createMockRequest(), "chunk")
chunks[i] = session
}
result := manager.ValidateAndCleanChunks(chunks)
if result != tt.shouldPass {
t.Errorf("Expected validation result %v, got %v", tt.shouldPass, result)
}
})
}
}
// ============================================================================
// SESSION HELPER TESTS
// ============================================================================
// TestSetCodeVerifier_NoChange tests the branch where the code verifier value doesn't change
func TestSetCodeVerifier_NoChange(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
defer sm.Shutdown()
req := httptest.NewRequest("GET", "http://example.com/test", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Set initial code verifier
initialVerifier := "test-code-verifier-12345"
session.SetCodeVerifier(initialVerifier)
if !session.IsDirty() {
t.Error("Session should be dirty after first SetCodeVerifier")
}
// Mark clean to test the no-change branch
session.dirty = false
// Set the same code verifier again - this should hit the uncovered branch
session.SetCodeVerifier(initialVerifier)
// Verify that dirty flag remains false (no change occurred)
if session.IsDirty() {
t.Error("Session should not be dirty when setting same code verifier value")
}
// Verify the code verifier value is still correct
if got := session.GetCodeVerifier(); got != initialVerifier {
t.Errorf("Expected code verifier %q, got %q", initialVerifier, got)
}
}
// TestClearTokenChunks_EmptyChunks tests the branch where the chunks map is empty
func TestClearTokenChunks_EmptyChunks(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
defer sm.Shutdown()
req := httptest.NewRequest("GET", "http://example.com/test", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Test with empty chunks map - this should hit the uncovered branch where the loop body doesn't execute
emptyChunks := make(map[int]*sessions.Session)
// This should not panic and should handle empty map gracefully
session.clearTokenChunks(req, emptyChunks)
// Verify that no errors occurred and the session is still valid
if session == nil {
t.Fatal("Session should still be valid after clearing empty chunks")
}
// Additional test: clear already-empty chunk maps in the session
session.clearTokenChunks(req, session.accessTokenChunks)
session.clearTokenChunks(req, session.refreshTokenChunks)
session.clearTokenChunks(req, session.idTokenChunks)
// Verify session is still valid
if session.GetAuthenticated() {
// This is fine - session can be authenticated even with no chunks
}
}
// TestClearTokenChunks_WithSessions tests the branch where the chunks map contains actual sessions
func TestClearTokenChunks_WithSessions(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
defer sm.Shutdown()
req := httptest.NewRequest("GET", "http://example.com/test", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Create chunks map with actual sessions
chunksWithSessions := make(map[int]*sessions.Session)
// Create a few test sessions and add them to the chunks map
for i := 0; i < 3; i++ {
chunkSession, err := sm.store.Get(req, fmt.Sprintf("test_chunk_%d", i))
if err != nil {
t.Fatalf("Failed to create test chunk session: %v", err)
}
// Add some test data to the session
chunkSession.Values["test_data"] = fmt.Sprintf("chunk_%d_data", i)
chunkSession.Values["chunk_index"] = i
chunksWithSessions[i] = chunkSession
}
// Verify chunks have data before clearing
if len(chunksWithSessions) != 3 {
t.Errorf("Expected 3 chunks, got %d", len(chunksWithSessions))
}
for i, chunkSession := range chunksWithSessions {
if chunkSession.Values["test_data"] == nil {
t.Errorf("Chunk %d should have test data before clearing", i)
}
}
// Call clearTokenChunks - this should hit the loop body and clear all sessions
session.clearTokenChunks(req, chunksWithSessions)
// Verify that the sessions were cleared
for i, chunkSession := range chunksWithSessions {
if len(chunkSession.Values) != 0 {
t.Errorf("Chunk %d should have no values after clearing, but has %d values", i, len(chunkSession.Values))
}
// Verify MaxAge was set to -1 (expired)
if chunkSession.Options.MaxAge != -1 {
t.Errorf("Chunk %d should have MaxAge=-1 (expired), but has MaxAge=%d", i, chunkSession.Options.MaxAge)
}
}
}
// ============================================================================
// SESSION POOL AND MEMORY TESTS
// ============================================================================
// TestSessionPoolMemoryLeak tests that session objects are properly returned to the pool
func TestSessionPoolMemoryLeak(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeLeakDetection) {
return
}
testTokens := NewTestTokens()
edgeGen := NewEdgeCaseGenerator()
runner := NewTestSuiteRunner()
runner.SetTimeout(30 * time.Second)
tests := []TableTestCase{
{
Name: "Successful session creation and return",
Description: "Test that sessions are properly created and returned to pool",
Setup: func(t *testing.T) error {
return nil
},
Teardown: func(t *testing.T) error {
runtime.GC()
time.Sleep(100 * time.Millisecond)
return nil
},
},
{
Name: "Explicit ReturnToPool method",
Description: "Test that explicit pool return works correctly",
Setup: func(t *testing.T) error {
return nil
},
Teardown: func(t *testing.T) error {
runtime.GC()
time.Sleep(100 * time.Millisecond)
return nil
},
},
{
Name: "Error path in GetSession",
Description: "Test pool behavior when GetSession fails",
Setup: func(t *testing.T) error {
return nil
},
Teardown: func(t *testing.T) error {
runtime.GC()
time.Sleep(100 * time.Millisecond)
return nil
},
},
}
// Custom test execution since we need to test memory behavior
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
if test.Setup != nil {
if err := test.Setup(t); err != nil {
t.Fatalf("Setup failed: %v", err)
}
}
if test.Teardown != nil {
defer func() {
if err := test.Teardown(t); err != nil {
t.Errorf("Teardown failed: %v", err)
}
}()
}
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
switch test.Name {
case "Successful session creation and return":
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("GetSession failed: %v", err)
}
session.Clear(req, nil)
case "Explicit ReturnToPool method":
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("GetSession failed: %v", err)
}
session.ReturnToPool()
case "Error path in GetSession":
badSM, _ := NewSessionManager("different0123456789abcdef0123456789abcdef0123456789", false, "", "", 0, logger)
_, err = badSM.GetSession(req)
if err == nil {
t.Log("Note: Expected error when using mismatched encryption keys")
}
}
pooledCount := getPooledObjects(sm)
t.Logf("Pooled objects count: %d", pooledCount)
})
}
_ = testTokens
_ = edgeGen
}
// TestSessionErrorHandling tests comprehensive error scenarios using table-driven tests
func TestSessionErrorHandling(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeQuick) {
return
}
edgeGen := NewEdgeCaseGenerator()
runner := NewTestSuiteRunner()
// Generate edge case strings for cookie values
edgeCases := edgeGen.GenerateStringEdgeCases()
tests := []TableTestCase{
{
Name: "Corrupt cookie value",
Description: "Test handling of corrupted cookie values",
Input: "corrupt-value",
Expected: "failed to get main session:",
},
{
Name: "Invalid base64 cookie",
Description: "Test handling of invalid base64 in cookies",
Input: "!@#$%^&*()",
Expected: "failed to get main session:",
},
{
Name: "Empty cookie value",
Description: "Test handling of empty cookie values",
Input: "",
Expected: "", // Empty should work without error
},
}
// Add edge cases dynamically
for i, edgeCase := range edgeCases {
if len(edgeCase) > 0 && !strings.ContainsAny(edgeCase, "\x00\x01\x02") { // Skip binary data for cookie tests
tests = append(tests, TableTestCase{
Name: fmt.Sprintf("Edge case %d", i),
Description: fmt.Sprintf("Test edge case string: %q", edgeCase[:minInt(20, len(edgeCase))]),
Input: edgeCase,
Expected: "", // Most edge cases should be handled gracefully
})
}
}
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
if input, ok := test.Input.(string); ok && input != "" {
req.AddCookie(&http.Cookie{
Name: defaultCookiePrefix + mainCookieSuffix,
Value: input,
})
}
_, err = sm.GetSession(req)
if expected, ok := test.Expected.(string); ok && expected != "" {
if err == nil {
t.Error("Expected error, got nil")
} else if !strings.Contains(err.Error(), expected) {
t.Errorf("Unexpected error message: %v", err)
}
} else {
// For empty expected, we allow either success or specific failures
if err != nil {
t.Logf("Got expected error for edge case: %v", err)
}
}
})
}
_ = runner
}
// TestSessionClearAlwaysReturnsToPool tests that sessions are always returned to pool even on errors
func TestSessionClearAlwaysReturnsToPool(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeQuick) {
return
}
runner := NewTestSuiteRunner()
memoryTests := []MemoryLeakTestCase{
{
Name: "Session clear with error returns to pool",
Description: "Verify sessions return to pool even when Clear() errors",
Iterations: 10,
MaxGoroutineGrowth: 2,
MaxMemoryGrowthMB: 5.0,
GCBetweenRuns: true,
Timeout: 30 * time.Second,
Operation: func() error {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
return fmt.Errorf("failed to create session manager: %w", err)
}
// Ensure proper cleanup by calling Shutdown
defer func() {
if shutdownErr := sm.Shutdown(); shutdownErr != nil {
logger.Errorf("Failed to shutdown SessionManager: %v", shutdownErr)
}
}()
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
req.Header.Set("X-Test-Error", "true")
session, err := sm.GetSession(req)
if err != nil {
return fmt.Errorf("GetSession failed: %w", err)
}
w := httptest.NewRecorder()
clearErr := session.Clear(req, w)
// We expect an error due to the X-Test-Error header, but the session should still be returned
if clearErr == nil {
return fmt.Errorf("expected error from Clear with X-Test-Error header")
}
return nil
},
},
}
runner.RunMemoryLeakTests(t, memoryTests)
// Additional verification test
t.Run("Verify pool still works after errors", func(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Ensure proper cleanup
defer func() {
if shutdownErr := sm.Shutdown(); shutdownErr != nil {
t.Errorf("Failed to shutdown SessionManager: %v", shutdownErr)
}
}()
normalReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
session2, err := sm.GetSession(normalReq)
if err != nil {
t.Fatalf("Second GetSession failed: %v", err)
}
session2.Clear(normalReq, nil)
t.Log("Session returned to pool despite errors")
})
}
// TestSessionObjectTracking tests session object tracking and pool behavior
func TestSessionObjectTracking(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeQuick) {
return
}
runner := NewTestSuiteRunner()
tests := []TableTestCase{
{
Name: "Session pool has New function",
Description: "Verify that session pool is properly configured",
Setup: func(t *testing.T) error {
return nil
},
},
{
Name: "Multiple session creation and disposal",
Description: "Test creating and disposing multiple sessions",
Input: 5,
},
{
Name: "Session with nil mainSession",
Description: "Test error handling with corrupted session state",
},
}
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
if test.Setup != nil {
if err := test.Setup(t); err != nil {
t.Fatalf("Setup failed: %v", err)
}
}
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
switch test.Name {
case "Session pool has New function":
hasNew := sm.sessionPool.New != nil
if !hasNew {
t.Error("Expected sessionPool.New function to be set")
}
case "Multiple session creation and disposal":
count := test.Input.(int)
for i := 0; i < count; i++ {
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("GetSession failed: %v", err)
}
session.ReturnToPool()
}
case "Session with nil mainSession":
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("GetSession failed: %v", err)
}
session.mainSession = nil // Deliberately cause bad state
session.ReturnToPool()
}
runtime.GC()
time.Sleep(100 * time.Millisecond)
t.Log("Session pool handling verified")
})
}
_ = runner
}
// ============================================================================
// TOKEN COMPRESSION AND CHUNKING TESTS
// ============================================================================
// TestTokenCompressionIntegrity tests token compression using comprehensive test cases
func TestTokenCompressionIntegrity(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeExtended) {
return
}
testTokens := NewTestTokens()
edgeGen := NewEdgeCaseGenerator()
runner := NewTestSuiteRunner()
// Create comprehensive test cases using edge case generator and test tokens
testCases := []TableTestCase{
{
Name: "Valid JWT Small",
Input: testTokens.GetValidTokenSet().AccessToken,
Expected: true, // Should compress and decompress correctly
},
{
Name: "Valid JWT Large",
Input: testTokens.CreateLargeValidJWT(5000),
Expected: true,
},
{
Name: "Minimal Valid JWT",
Input: MinimalValidJWT,
Expected: true,
},
{
Name: "Invalid JWT Wrong dot count",
Input: InvalidTokenOneDot,
Expected: false, // Should return original for invalid tokens
},
{
Name: "Invalid JWT No dots",
Input: InvalidTokenNoDots,
Expected: false,
},
{
Name: "Invalid JWT Too many dots",
Input: InvalidTokenThreeDots,
Expected: false,
},
{
Name: "Empty token",
Input: "",
Expected: true, // Empty tokens are handled gracefully
},
{
Name: "Oversized token",
Input: testTokens.CreateIncompressibleToken(55000), // >50KB
Expected: false, // Should be rejected
},
}
// Add string edge cases as additional test inputs
stringEdgeCases := edgeGen.GenerateStringEdgeCases()
for i, edgeCase := range stringEdgeCases {
if len(edgeCase) > 0 && len(edgeCase) < 1000 { // Reasonable size for testing
testCases = append(testCases, TableTestCase{
Name: fmt.Sprintf("Edge case string %d", i),
Input: edgeCase,
Expected: true, // Most edge cases should be handled gracefully
})
}
}
for _, test := range testCases {
t.Run(test.Name, func(t *testing.T) {
token := test.Input.(string)
expectValid := test.Expected.(bool)
compressed := compressToken(token)
if !expectValid {
// For invalid tokens, compression should return original
if compressed != token {
t.Errorf("Expected compression to return original for invalid token, got different result")
}
return
}
// For valid tokens, test round-trip integrity
decompressed := decompressToken(compressed)
if decompressed != token {
t.Errorf("Token integrity lost: original=%q, compressed=%q, decompressed=%q",
token, compressed, decompressed)
}
// Test that decompression is idempotent
decompressed2 := decompressToken(decompressed)
if decompressed2 != token {
t.Errorf("Decompression not idempotent: %q != %q", decompressed2, token)
}
})
}
_ = runner
}
// TestTokenCompressionCorruptionDetection tests corruption detection using table-driven approach
func TestTokenCompressionCorruptionDetection(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeExtended) {
return
}
testTokens := NewTestTokens()
runner := NewTestSuiteRunner()
tests := []TableTestCase{
{
Name: "Invalid base64",
Input: "!@#$%^&*()",
Expected: true, // Should return original
},
{
Name: "Valid base64 but invalid gzip",
Input: base64.StdEncoding.EncodeToString([]byte("not gzip data")),
Expected: true,
},
{
Name: "Truncated gzip data",
Input: "H4sI", // Incomplete gzip header
Expected: true,
},
{
Name: "Empty string",
Input: "",
Expected: true,
},
}
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
corruptedInput := test.Input.(string)
expectOriginal := test.Expected.(bool)
result := decompressToken(corruptedInput)
if expectOriginal && result != corruptedInput {
t.Errorf("Expected decompression to return original corrupted input, got: %q", result)
}
})
}
// Test that valid compression still works
t.Run("Valid compression verification", func(t *testing.T) {
validJWT := testTokens.GetValidTokenSet().AccessToken
compressed := compressToken(validJWT)
decompressed := decompressToken(compressed)
if decompressed != validJWT {
t.Errorf("Valid compression/decompression failed: %q != %q", decompressed, validJWT)
}
})
_ = runner
}
// TestTokenChunkingIntegrity tests token chunking using comprehensive test patterns
func TestTokenChunkingIntegrity(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeExtended) {
return
}
testTokens := NewTestTokens()
edgeGen := NewEdgeCaseGenerator()
runner := NewTestSuiteRunner()
tests := []TableTestCase{
{
Name: "Small token no chunking",
Description: "Small tokens should not be chunked",
Input: struct {
size int
expectChunked bool
}{100, false},
},
{
Name: "Medium token no chunking",
Description: "Medium tokens should not be chunked",
Input: struct {
size int
expectChunked bool
}{800, false},
},
{
Name: "Large token chunking required",
Description: "Large tokens should be chunked",
Input: struct {
size int
expectChunked bool
}{5000, true},
},
{
Name: "Very large token multiple chunks",
Description: "Very large tokens should create multiple chunks",
Input: struct {
size int
expectChunked bool
}{10000, true},
},
}
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
params := test.Input.(struct {
size int
expectChunked bool
})
// Create token based on expectation
var token string
if params.expectChunked {
token = testTokens.CreateIncompressibleToken(params.size)
} else {
token = testTokens.CreateLargeValidJWT(params.size)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Store the token
session.SetAccessToken(token)
// Retrieve the token
retrievedToken := session.GetAccessToken()
// Verify integrity
if retrievedToken != token {
t.Errorf("Token integrity lost:\nOriginal: %q\nRetrieved: %q", token, retrievedToken)
}
// Check if chunking occurred as expected
hasChunks := len(session.accessTokenChunks) > 0
if params.expectChunked != hasChunks {
t.Errorf("Chunking expectation mismatch: expected chunked=%v, has chunks=%v",
params.expectChunked, hasChunks)
}
session.ReturnToPool()
})
}
_ = edgeGen
_ = runner
}
// TestTokenChunkingCorruptionResistance tests chunking corruption resistance using table patterns
func TestTokenChunkingCorruptionResistance(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeExtended) {
return
}
testTokens := NewTestTokens()
runner := NewTestSuiteRunner()
// Define corruption scenarios as test cases
corruptionTests := []TableTestCase{
{
Name: "Missing chunk in sequence",
Description: "Test handling when a chunk is missing from sequence",
Input: func(chunks map[int]*sessions.Session) {
if len(chunks) > 1 {
delete(chunks, 1)
}
},
Expected: true, // Expect empty result
},
{
Name: "Empty chunk data",
Description: "Test handling when chunk contains empty data",
Input: func(chunks map[int]*sessions.Session) {
if chunk, exists := chunks[0]; exists {
chunk.Values["token_chunk"] = ""
}
},
Expected: true,
},
{
Name: "Wrong data type in chunk",
Description: "Test handling when chunk contains wrong data type",
Input: func(chunks map[int]*sessions.Session) {
if chunk, exists := chunks[0]; exists {
chunk.Values["token_chunk"] = 123 // Should be string
}
},
Expected: true,
},
{
Name: "Oversized chunk",
Description: "Test handling when chunk exceeds size limits",
Input: func(chunks map[int]*sessions.Session) {
if chunk, exists := chunks[0]; exists {
chunk.Values["token_chunk"] = strings.Repeat("A", maxCookieSize+200)
}
},
Expected: true,
},
}
for _, test := range corruptionTests {
t.Run(test.Name, func(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Create a large token that will be chunked
largeToken := testTokens.CreateIncompressibleToken(8000)
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Store the token (this should create chunks)
session.SetAccessToken(largeToken)
if len(session.accessTokenChunks) == 0 {
t.Skip("Token was not chunked, skipping corruption test")
}
// Apply corruption using the test input function
corruptFunc := test.Input.(func(map[int]*sessions.Session))
corruptFunc(session.accessTokenChunks)
// Try to retrieve the token
retrievedToken := session.GetAccessToken()
expectEmpty := test.Expected.(bool)
if expectEmpty {
if retrievedToken != "" {
t.Errorf("Expected empty token due to corruption, got: %q", retrievedToken)
}
} else {
if retrievedToken != largeToken {
t.Errorf("Expected original token despite corruption, got: %q", retrievedToken)
}
}
session.ReturnToPool()
})
}
_ = corruptionTests
_ = runner
}
// TestTokenSizeLimits tests token size limit enforcement using table-driven tests
func TestTokenSizeLimits(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeExtended) {
return
}
testTokens := NewTestTokens()
edgeGen := NewEdgeCaseGenerator()
runner := NewTestSuiteRunner()
tests := []TableTestCase{
{
Name: "Normal size token",
Input: 1000,
Expected: true,
},
{
Name: "Large but acceptable token",
Input: 20000, // 20KB
Expected: true,
},
{
Name: "Oversized token rejection",
Input: 120000, // 120KB
Expected: false, // Should be rejected
},
}
// Add integer edge cases for token sizes
intEdgeCases := edgeGen.GenerateIntegerEdgeCases()
for _, size := range intEdgeCases {
if size > 0 && size < 100000 {
tests = append(tests, TableTestCase{
Name: fmt.Sprintf("Edge case size %d", size),
Input: size,
Expected: size < 100000, // Reasonable threshold
})
}
}
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
tokenSize := test.Input.(int)
expectStored := test.Expected.(bool)
var token string
if expectStored {
token = testTokens.CreateLargeValidJWT(tokenSize)
} else {
token = testTokens.CreateIncompressibleToken(tokenSize)
}
// Store the token
session.SetAccessToken(token)
// Try to retrieve it
retrievedToken := session.GetAccessToken()
if expectStored {
if retrievedToken != token {
t.Errorf("Expected token to be stored and retrieved, but got different token")
}
} else {
if retrievedToken == token {
t.Errorf("Expected oversized token to be rejected, but it was stored")
}
}
})
}
_ = runner
}
// TestConcurrentTokenOperations tests thread safety using structured test patterns
func TestConcurrentTokenOperations(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeConcurrencyStress) {
return
}
testTokens := NewTestTokens()
runner := NewTestSuiteRunner()
// Test concurrent operations using memory leak test pattern
memoryTests := []MemoryLeakTestCase{
{
Name: "Concurrent token operations",
Description: "Test thread safety of concurrent token operations",
Iterations: 50,
MaxGoroutineGrowth: 5, // Allow some growth for goroutines
MaxMemoryGrowthMB: 10.0,
GCBetweenRuns: true,
Timeout: 60 * time.Second,
Operation: func() error {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
return fmt.Errorf("failed to create session manager: %w", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
}
defer session.ReturnToPool()
const numGoroutines = 10
const numOperations = 100
done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer func() { done <- true }()
for j := 0; j < numOperations; j++ {
// Create unique tokens for each goroutine/operation
accessToken := testTokens.CreateUniqueValidJWT(fmt.Sprintf("%d_%d", id, j))
refreshToken := fmt.Sprintf("refresh_token_%d_%d", id, j)
// Concurrent operations
session.SetAccessToken(accessToken)
session.SetRefreshToken(refreshToken)
retrievedAccess := session.GetAccessToken()
retrievedRefresh := session.GetRefreshToken()
// Verify tokens are still valid (should be one of the tokens set by any goroutine)
if retrievedAccess != "" && strings.Count(retrievedAccess, ".") != 2 {
// Note: In concurrent access, we can't guarantee exact token match
// but we can verify format is still valid
}
if retrievedRefresh != "" && len(retrievedRefresh) < 10 {
// Verify minimum reasonable length
}
}
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < numGoroutines; i++ {
<-done
}
return nil
},
},
}
runner.RunMemoryLeakTests(t, memoryTests)
_ = testTokens
}
// TestSessionValidationAndCleanup tests session validation using comprehensive patterns
func TestSessionValidationAndCleanup(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeExtended) {
return
}
testTokens := NewTestTokens()
edgeGen := NewEdgeCaseGenerator()
runner := NewTestSuiteRunner()
tests := []TableTestCase{
{
Name: "Session creation and token storage",
Description: "Test basic session validation and cleanup",
},
{
Name: "Large token chunking validation",
Description: "Test validation with tokens that require chunking",
},
{
Name: "Session cleanup verification",
Description: "Test that sessions are properly cleaned up",
},
}
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rw := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
switch test.Name {
case "Session creation and token storage":
// Test with normal tokens
tokenSet := testTokens.GetValidTokenSet()
session.SetAccessToken(tokenSet.AccessToken)
session.SetRefreshToken(tokenSet.RefreshToken)
case "Large token chunking validation":
// Set tokens that will create chunks
largeTokenSet := testTokens.GetLargeTokenSet()
session.SetAccessToken(largeTokenSet.AccessToken)
session.SetRefreshToken(largeTokenSet.RefreshToken)
case "Session cleanup verification":
// Set tokens and then clear them
session.SetAccessToken(testTokens.GetValidTokenSet().AccessToken)
session.SetRefreshToken("refresh_token_test")
}
// Save session to create cookies
if err := session.Save(req, rw); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// For cleanup test, verify clearing works
if test.Name == "Session cleanup verification" {
if err := session.Clear(req, rw); err != nil {
t.Logf("Clear returned error (may be expected): %v", err)
}
// Verify tokens are cleared
if token := session.GetAccessToken(); token != "" {
t.Errorf("Access token should be empty after clear, got: %q", token)
}
if token := session.GetRefreshToken(); token != "" {
t.Errorf("Refresh token should be empty after clear, got: %q", token)
}
}
})
}
_ = edgeGen
_ = runner
}
// TestLargeIDTokenChunking tests ID token chunking using structured approach
func TestLargeIDTokenChunking(t *testing.T) {
config := GetTestConfig()
if config.ShouldSkipTest(t, TestTypeExtended) {
return
}
runner := NewTestSuiteRunner()
tests := []TableTestCase{
{
Name: "Large ID token chunking 20KB",
Description: "Test that large ID tokens are properly chunked",
Input: 20000,
Expected: 2, // Expect at least 2 chunks
},
{
Name: "Very large ID token chunking 50KB",
Description: "Test very large ID token chunking",
Input: 50000,
Expected: 5, // Expect at least 5 chunks
},
}
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
tokenSize := test.Input.(int)
minExpectedChunks := test.Expected.(int)
// Create a large ID token
largeIDToken := createLargeIDToken(tokenSize)
t.Logf("Created large ID token with length: %d", len(largeIDToken))
// Create a request and response recorder
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()
// Get session and set large ID token
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set the large ID token
session.SetIDToken(largeIDToken)
t.Logf("Set large ID token in session")
// Save the session to trigger chunking
err = session.Save(req, rr)
if err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify token retrieval integrity
retrievedToken := session.GetIDToken()
t.Logf("Retrieved ID token length: %d", len(retrievedToken))
if len(retrievedToken) != len(largeIDToken) {
t.Errorf("Token length mismatch: expected %d, got %d", len(largeIDToken), len(retrievedToken))
}
// Verify that chunked cookies were created
cookies := rr.Result().Cookies()
t.Logf("Total cookies in response: %d", len(cookies))
var chunkCookies []*http.Cookie
idTokenCookieName := defaultCookiePrefix + idTokenSuffix
for _, cookie := range cookies {
if strings.HasPrefix(cookie.Name, idTokenCookieName+"_") {
chunkCookies = append(chunkCookies, cookie)
}
}
// Verify minimum expected chunks
if len(chunkCookies) < minExpectedChunks {
t.Fatalf("Expected at least %d chunk cookies, got %d", minExpectedChunks, len(chunkCookies))
}
// Test token retrieval from chunked cookies
newReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
retrievedSession, err := sm.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get session from chunked cookies: %v", err)
}
retrievedToken2 := retrievedSession.GetIDToken()
// Verify the retrieved token matches the original
if retrievedToken2 != largeIDToken {
t.Errorf("Retrieved ID token doesn't match original. Expected length: %d, got: %d",
len(largeIDToken), len(retrievedToken2))
}
// Test clearing the ID token removes all chunks
retrievedSession.SetIDToken("")
clearRR := httptest.NewRecorder()
err = retrievedSession.Save(newReq, clearRR)
if err != nil {
t.Fatalf("Failed to save session after clearing ID token: %v", err)
}
// Verify chunks are expired (MaxAge = -1)
clearCookies := clearRR.Result().Cookies()
idTokenCookieName2 := defaultCookiePrefix + idTokenSuffix
for _, cookie := range clearCookies {
if strings.HasPrefix(cookie.Name, idTokenCookieName2+"_") {
if cookie.MaxAge != -1 {
t.Errorf("Expected chunk cookie %s to be expired (MaxAge=-1), got MaxAge=%d",
cookie.Name, cookie.MaxAge)
}
}
}
})
}
_ = runner
}
// ============================================================================
// CONSOLIDATED SESSION TESTS
// ============================================================================
// TestSessionConsolidated runs all consolidated session tests
func TestSessionConsolidated(t *testing.T) {
testCases := []SessionTestCase{
// Session Creation Tests
{
name: "session_basic_creation",
scenario: "creation",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Simulate session creation
req := httptest.NewRequest("GET", "http://example.com/", nil)
f.requests = append(f.requests, req)
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Session creation should succeed")
assert.Greater(t, f.metrics.SessionsCreated, int64(0), "Session should be created")
},
},
{
name: "session_pool_reuse",
scenario: "creation",
sessionType: "user",
iterations: 100,
execute: func(f *SessionTestFramework) error {
for i := 0; i < 100; i++ {
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed, "Sessions should be properly pooled")
},
},
{
name: "session_concurrent_creation",
scenario: "creation",
sessionType: "user",
concurrent: true,
iterations: 50,
execute: func(f *SessionTestFramework) error {
var wg sync.WaitGroup
errs := make(chan error, 50)
for i := 0; i < 50; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Simulate concurrent session creation
req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%d", id), nil)
f.mu.Lock()
f.requests = append(f.requests, req)
f.mu.Unlock()
}(i)
}
wg.Wait()
close(errs)
for err := range errs {
if err != nil {
return err
}
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.Equal(t, int64(50), f.metrics.SessionsCreated, "All concurrent sessions should be created")
},
},
// Session Validation Tests
{
name: "session_token_validation",
scenario: "validation",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
token := f.generateTestToken("access", 3600)
atomic.AddInt64(&f.metrics.TokensValidated, 1)
// Validate token format
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Token validation should succeed")
assert.Greater(t, f.metrics.TokensValidated, int64(0))
},
},
{
name: "session_corrupted_token_detection",
scenario: "validation",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
token := f.generateTestToken("access", 3600)
// Corrupt the token by modifying the signature
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
// Corrupt the signature part
corrupted := parts[0] + "." + parts[1] + ".corrupted!"
atomic.AddInt64(&f.metrics.TokensValidated, 1)
// Validate should detect corruption - corrupted tokens should fail validation
corruptedParts := strings.Split(corrupted, ".")
if len(corruptedParts) == 3 {
// Try to decode the corrupted signature
_, err := base64.RawURLEncoding.DecodeString(corruptedParts[2])
if err == nil {
return fmt.Errorf("corruption not detected")
}
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Corruption detection should work")
},
},
{
name: "session_expired_token_handling",
scenario: "validation",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Generate an expired token
token := f.generateTestToken("access", -3600) // negative expiry
atomic.AddInt64(&f.metrics.TokensValidated, 1)
// Parse and check expiry
parts := strings.Split(token, ".")
if len(parts) == 3 {
payload, _ := base64.RawURLEncoding.DecodeString(parts[1])
var claims map[string]interface{}
json.Unmarshal(payload, &claims)
if exp, ok := claims["exp"].(float64); ok {
if exp < float64(time.Now().Unix()) {
atomic.AddInt64(&f.metrics.ErrorCount, 1)
return nil // Expected behavior
}
}
}
return fmt.Errorf("expired token not detected")
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Expired token should be detected")
assert.Greater(t, f.metrics.ErrorCount, int64(0))
},
},
// Session Expiration Tests
{
name: "session_ttl_expiration",
scenario: "expiration",
sessionType: "user",
timeout: 3 * time.Second,
execute: func(f *SessionTestFramework) error {
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Simulate session with short TTL
time.Sleep(100 * time.Millisecond) // Don't sleep for full timeout
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed)
},
},
{
name: "session_refresh_token_expiry",
scenario: "expiration",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
refreshToken := f.generateTestToken("refresh", 86400)
atomic.AddInt64(&f.metrics.TokensValidated, 1)
// Check refresh token is valid for longer period
parts := strings.Split(refreshToken, ".")
if len(parts) == 3 {
payload, _ := base64.RawURLEncoding.DecodeString(parts[1])
var claims map[string]interface{}
json.Unmarshal(payload, &claims)
if exp, ok := claims["exp"].(float64); ok {
timeUntilExpiry := time.Until(time.Unix(int64(exp), 0))
if timeUntilExpiry < 23*time.Hour {
return fmt.Errorf("refresh token expiry too short: %v", timeUntilExpiry)
}
}
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Refresh token should have correct expiry")
},
},
// Session Persistence Tests
{
name: "session_cookie_persistence",
scenario: "persistence",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
req := httptest.NewRequest("GET", "http://example.com/", nil)
w := httptest.NewRecorder()
// Set session cookie
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: "test-session-123",
Path: "/",
HttpOnly: true,
Secure: f.config.EnableHTTPS,
SameSite: http.SameSiteLaxMode,
})
f.requests = append(f.requests, req)
f.responses = append(f.responses, w)
// Verify cookie was set
cookies := w.Result().Cookies()
if len(cookies) == 0 {
return fmt.Errorf("no cookies set")
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.NotEmpty(t, f.responses, "Response should be recorded")
},
},
{
name: "session_state_preservation",
scenario: "persistence",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Store state
state := map[string]interface{}{
"user_id": "test-user",
"email": "test@example.com",
"roles": []string{"user", "admin"},
}
// Serialize and deserialize to test persistence
data, err := json.Marshal(state)
if err != nil {
return err
}
var restored map[string]interface{}
if err := json.Unmarshal(data, &restored); err != nil {
return err
}
// Verify state preserved
if restored["user_id"] != state["user_id"] {
return fmt.Errorf("state not preserved")
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Session state should be preserved")
},
},
// Session Cleanup Tests
{
name: "session_proper_cleanup",
scenario: "cleanup",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Create and destroy sessions
for i := 0; i < 10; i++ {
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
sessionID := fmt.Sprintf("session-%d", i)
f.sessionIDs = append(f.sessionIDs, sessionID)
}
// Cleanup all sessions
for range f.sessionIDs {
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
}
f.sessionIDs = nil
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err)
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed)
assert.Empty(t, f.sessionIDs, "All sessions should be cleaned up")
},
},
{
name: "session_goroutine_leak_prevention",
scenario: "cleanup",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
initialGoroutines := runtime.NumGoroutine()
// Create sessions that might spawn goroutines
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
time.Sleep(10 * time.Millisecond)
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
}(i)
}
wg.Wait()
runtime.GC()
time.Sleep(100 * time.Millisecond)
finalGoroutines := runtime.NumGoroutine()
if finalGoroutines > initialGoroutines+2 { // Allow small variance
return fmt.Errorf("goroutine leak detected: %d -> %d", initialGoroutines, finalGoroutines)
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "No goroutine leaks should occur")
},
},
// Session Chunking Tests
{
name: "session_large_token_chunking",
scenario: "chunking",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Generate a large token that requires chunking
largeToken := f.generateLargeToken(10000) // 10KB token
// Calculate expected chunks
chunkSize := f.config.MaxChunkSize
expectedChunks := (len(largeToken) + chunkSize - 1) / chunkSize
// Simulate chunking
chunks := make([]string, 0)
for i := 0; i < len(largeToken); i += chunkSize {
end := i + chunkSize
if end > len(largeToken) {
end = len(largeToken)
}
chunks = append(chunks, largeToken[i:end])
atomic.AddInt64(&f.metrics.ChunksCreated, 1)
}
if len(chunks) != expectedChunks {
return fmt.Errorf("expected %d chunks, got %d", expectedChunks, len(chunks))
}
// Simulate reconstruction
reconstructed := strings.Join(chunks, "")
if reconstructed != largeToken {
return fmt.Errorf("token reconstruction failed")
}
atomic.AddInt64(&f.metrics.ChunksRetrieved, int64(len(chunks)))
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Token chunking should work correctly")
assert.Greater(t, f.metrics.ChunksCreated, int64(0))
assert.Equal(t, f.metrics.ChunksCreated, f.metrics.ChunksRetrieved)
},
},
{
name: "session_chunk_boundary_validation",
scenario: "chunking",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Test exact boundary conditions
testSizes := []int{
f.config.MaxChunkSize - 1,
f.config.MaxChunkSize,
f.config.MaxChunkSize + 1,
f.config.MaxChunkSize * 2,
f.config.MaxChunkSize*2 - 1,
f.config.MaxChunkSize*2 + 1,
}
for _, size := range testSizes {
token := f.generateLargeToken(size)
actualSize := len(token)
expectedChunks := (actualSize + f.config.MaxChunkSize - 1) / f.config.MaxChunkSize
actualChunks := 0
for i := 0; i < len(token); i += f.config.MaxChunkSize {
actualChunks++
atomic.AddInt64(&f.metrics.ChunksCreated, 1)
}
if actualChunks != expectedChunks {
return fmt.Errorf("size %d (actual token size %d): expected %d chunks, got %d", size, actualSize, expectedChunks, actualChunks)
}
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Chunk boundaries should be handled correctly")
},
},
// Session Security Tests
{
name: "session_csrf_token_management",
scenario: "security",
sessionType: "csrf",
execute: func(f *SessionTestFramework) error {
// Generate CSRF token
csrfToken := make([]byte, 32)
if _, err := rand.Read(csrfToken); err != nil {
return err
}
csrfString := base64.RawURLEncoding.EncodeToString(csrfToken)
// Store in session
f.testTokens["csrf"] = csrfString
// Validate CSRF token
if len(csrfString) < 40 {
return fmt.Errorf("CSRF token too short")
}
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
atomic.AddInt64(&f.metrics.TokensValidated, 1)
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "CSRF token should be properly managed")
assert.NotEmpty(t, f.testTokens["csrf"])
},
},
{
name: "session_injection_prevention",
scenario: "security",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
// Test various injection attempts
maliciousInputs := []string{
`{"admin": true}`,
`<script>alert('xss')</script>`,
`'; DROP TABLE sessions; --`,
`../../../etc/passwd`,
string([]byte{0x00, 0x01, 0x02}), // null bytes
}
for _, input := range maliciousInputs {
// Validate that input is properly sanitized
sanitized := base64.StdEncoding.EncodeToString([]byte(input))
decoded, err := base64.StdEncoding.DecodeString(sanitized)
if err != nil {
return err
}
if string(decoded) != input {
return fmt.Errorf("sanitization changed input unexpectedly")
}
atomic.AddInt64(&f.metrics.TokensValidated, 1)
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Injection attempts should be handled safely")
},
},
{
name: "session_secure_cookie_settings",
scenario: "security",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
w := httptest.NewRecorder()
// Test secure cookie settings
cookie := &http.Cookie{
Name: "session",
Value: "test-session",
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
MaxAge: 3600,
}
http.SetCookie(w, cookie)
// Verify cookie attributes
cookies := w.Result().Cookies()
if len(cookies) == 0 {
return fmt.Errorf("no cookie set")
}
c := cookies[0]
if !c.HttpOnly {
return fmt.Errorf("cookie not HttpOnly")
}
if c.SameSite != http.SameSiteStrictMode {
return fmt.Errorf("incorrect SameSite setting")
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Secure cookie settings should be enforced")
},
},
// Session Stress Tests
{
name: "session_high_concurrency_stress",
scenario: "creation",
sessionType: "user",
concurrent: true,
iterations: 1000,
timeout: 30 * time.Second,
execute: func(f *SessionTestFramework) error {
var wg sync.WaitGroup
errors := make([]error, 0)
// Run high concurrency test
concurrency := 100
iterations := 10
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
// Create session
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Generate tokens
f.generateTestToken("access", 3600)
f.generateTestToken("refresh", 86400)
// Validate tokens
atomic.AddInt64(&f.metrics.TokensValidated, 2)
// Cleanup session
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
// Small delay to simulate real usage
time.Sleep(time.Millisecond)
}
}(i)
}
wg.Wait()
if len(errors) > 0 {
return errors[0]
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "High concurrency stress test should pass")
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed, "All sessions should be cleaned up")
},
},
{
name: "session_memory_bounds_enforcement",
scenario: "cleanup",
sessionType: "user",
execute: func(f *SessionTestFramework) error {
maxSessions := f.config.MaxSessions
// Try to create more sessions than allowed
for i := 0; i < maxSessions+100; i++ {
sessionID := fmt.Sprintf("session-%d", i)
f.sessionIDs = append(f.sessionIDs, sessionID)
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
// Enforce max sessions
if len(f.sessionIDs) > maxSessions {
// Remove oldest session
f.sessionIDs = f.sessionIDs[1:]
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
}
}
if len(f.sessionIDs) > maxSessions {
return fmt.Errorf("max sessions exceeded: %d > %d", len(f.sessionIDs), maxSessions)
}
return nil
},
validate: func(t *testing.T, err error, f *SessionTestFramework) {
assert.NoError(t, err, "Memory bounds should be enforced")
assert.LessOrEqual(t, len(f.sessionIDs), f.config.MaxSessions)
},
},
}
// Run all test cases
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.skipReason != "" {
t.Skip(tc.skipReason)
}
framework := NewSessionTestFramework(t)
defer framework.Cleanup()
// Setup
if tc.setup != nil {
tc.setup(framework)
}
// Cleanup
if tc.cleanup != nil {
defer tc.cleanup(framework)
}
// Set timeout if specified
if tc.timeout > 0 {
timer := time.NewTimer(tc.timeout)
done := make(chan bool)
go func() {
err := tc.execute(framework)
tc.validate(t, err, framework)
done <- true
}()
select {
case <-done:
timer.Stop()
case <-timer.C:
t.Fatal("Test timeout exceeded")
}
} else {
// Execute test
err := tc.execute(framework)
// Validate results
tc.validate(t, err, framework)
}
})
}
}
// ============================================================================
// SESSION STATE PRESERVATION TESTS (6-HOUR TOKEN EXPIRY SCENARIOS)
// ============================================================================
// TestSessionStatePreservationWithExpiredTokens tests that session state is preserved
// during token expiry scenarios
func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
t.Log("Testing session state preservation with expired tokens")
logger := NewLogger("debug")
sm, err := NewSessionManager("test-session-key-32-bytes-long-12345", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Simulate real-world session data that should be preserved
originalUserData := map[string]interface{}{
"user_id": "user-12345",
"email": "test.user@company.com",
"name": "Test User",
"roles": []string{"admin", "user"},
"pref_theme": "dark",
"pref_lang": "en",
"last_active": "2023-01-01T10:00:00Z",
}
// Create initial session with valid tokens
req1 := httptest.NewRequest("GET", "/initial", nil)
rr1 := httptest.NewRecorder()
session1, err := sm.GetSession(req1)
if err != nil {
t.Fatalf("Failed to get initial session: %v", err)
}
// Set up initial session state (what user has when first logging in)
session1.SetAuthenticated(true)
session1.SetUserIdentifier(originalUserData["email"].(string))
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
// Store additional user data in session - store individual values instead of map
for k, v := range originalUserData {
session1.mainSession.Values["user_data_"+k] = v
}
session1.mainSession.Values["session_created"] = time.Now().Unix() // Store as int64 for gob
session1.mainSession.Values["custom_flag"] = true
if err := session1.Save(req1, rr1); err != nil {
t.Fatalf("Failed to save initial session: %v", err)
}
initialCookies := rr1.Result().Cookies()
session1.ReturnToPool()
t.Log("Initial session created with user data")
// Fast-forward 6 hours - tokens expire due to browser inactivity
time.Sleep(10 * time.Millisecond) // Simulate time passage in test
// Create expired tokens (simulating what happens after 6 hours)
expiredTime := time.Now().Add(-6 * time.Hour)
expiredAccessToken := createExpiredJWTToken("user-12345", "test.user@company.com", expiredTime)
expiredIDToken := createExpiredJWTToken("user-12345", "test.user@company.com", expiredTime)
// User returns after inactivity and makes a request
req2 := httptest.NewRequest("GET", "/protected-resource", nil)
for _, cookie := range initialCookies {
req2.AddCookie(cookie)
}
session2, err := sm.GetSession(req2)
if err != nil {
t.Fatalf("Failed to get session after 6 hours: %v", err)
}
defer session2.ReturnToPool()
// Simulate what happens when middleware detects expired tokens
// It should preserve session state while attempting token refresh
originalAuth := session2.GetAuthenticated()
originalEmail := session2.GetUserIdentifier()
// Reconstruct user data from individual stored keys
originalUserDataStored := make(map[string]interface{})
for k := range originalUserData {
if storedValue, exists := session2.mainSession.Values["user_data_"+k]; exists {
originalUserDataStored[k] = storedValue
}
}
// Update session with expired tokens (what middleware does when tokens expire)
session2.SetAccessToken(expiredAccessToken)
session2.SetIDToken(expiredIDToken)
// Refresh token should still be valid
t.Log("Session loaded after 6-hour expiry, checking state preservation")
// Verify authentication state is preserved
if !originalAuth {
t.Error("Authentication state lost during session reload")
}
// Verify email is preserved
if originalEmail != originalUserData["email"].(string) {
t.Errorf("User email lost during session reload - Expected: %s, Got: %s",
originalUserData["email"], originalEmail)
}
// Verify custom user data is preserved
if len(originalUserDataStored) == 0 {
t.Error("All custom user data lost during session reload")
} else {
if originalUserDataStored["user_id"] != originalUserData["user_id"] {
t.Error("User ID lost from session data")
}
if originalUserDataStored["name"] != originalUserData["name"] {
t.Error("User name lost from session data")
}
if originalUserDataStored["pref_theme"] != originalUserData["pref_theme"] {
t.Error("User theme preference lost from session data")
}
if originalUserDataStored["pref_lang"] != originalUserData["pref_lang"] {
t.Error("User language preference lost from session data")
}
}
// Note: System may reject invalid/expired tokens during storage, which is acceptable behavior
currentAccessToken := session2.GetAccessToken()
if currentAccessToken != expiredAccessToken {
t.Logf("INFO: Access token was not stored (possibly rejected due to expiry) - Expected: %s, Got: %s",
expiredAccessToken, currentAccessToken)
}
// Verify that session can be saved again after token expiry without losing data
rr2 := httptest.NewRecorder()
if err := session2.Save(req2, rr2); err != nil {
t.Errorf("Cannot save session after token expiry: %v", err)
} else {
t.Log("Session successfully saved after token expiry")
// Verify cookies are still set
newCookies := rr2.Result().Cookies()
if len(newCookies) == 0 {
t.Error("No session cookies set after saving expired token session")
}
}
// Test session recovery after token refresh simulation
newAccessToken := "refreshed-access-token-longer-than-20-chars"
newIDToken := "refreshed-id-token-longer-than-20-chars"
newRefreshToken := "new-refresh-token-after-successful-renewal"
session2.SetAccessToken(newAccessToken)
session2.SetIDToken(newIDToken)
session2.SetRefreshToken(newRefreshToken)
// Verify all session data is still intact after token refresh
postRefreshAuth := session2.GetAuthenticated()
postRefreshEmail := session2.GetUserIdentifier()
userDataPresent := true
for k := range originalUserData {
if session2.mainSession.Values["user_data_"+k] == nil {
userDataPresent = false
break
}
}
if !postRefreshAuth {
t.Error("Authentication state lost after token refresh")
}
if postRefreshEmail != originalUserData["email"].(string) {
t.Error("User email lost after token refresh")
}
if !userDataPresent {
t.Error("User data lost after token refresh")
}
t.Log("Session state preservation test completed")
}
// TestSessionExpiryVsTokenExpiry tests the distinction between session expiry and token expiry
func TestSessionExpiryVsTokenExpiry(t *testing.T) {
t.Log("Testing session expiry vs token expiry distinction")
logger := NewLogger("debug")
sm, err := NewSessionManager("session-vs-token-test-key-32-bytes", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
scenarios := []struct {
name string
expectedBehavior string
sessionAge time.Duration
tokenExpiry time.Duration
sessionShouldExpire bool
tokenShouldRefresh bool
}{
{
name: "New session, expired tokens",
sessionAge: 5 * time.Minute,
tokenExpiry: -6 * time.Hour,
expectedBehavior: "Session valid, tokens should refresh",
sessionShouldExpire: false,
tokenShouldRefresh: true,
},
{
name: "Old session, valid tokens",
sessionAge: 25 * time.Hour,
tokenExpiry: 2 * time.Hour,
expectedBehavior: "Session expired, redirect to login even with valid tokens",
sessionShouldExpire: true,
tokenShouldRefresh: false,
},
{
name: "Both session and tokens expired",
sessionAge: 25 * time.Hour,
tokenExpiry: -6 * time.Hour,
expectedBehavior: "Both expired, clear session and redirect to login",
sessionShouldExpire: true,
tokenShouldRefresh: false,
},
{
name: "Recent session, recently expired tokens",
sessionAge: 30 * time.Minute,
tokenExpiry: -10 * time.Minute,
expectedBehavior: "Session valid, tokens recently expired, should refresh",
sessionShouldExpire: false,
tokenShouldRefresh: true,
},
}
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
t.Logf("Testing: %s", scenario.expectedBehavior)
// Create session at specific "age"
sessionCreatedAt := time.Now().Add(-scenario.sessionAge)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Set up session with specific creation time
session.SetAuthenticated(true)
session.SetUserIdentifier("test@example.com")
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
// Create tokens with specific expiry
tokenExpiredAt := time.Now().Add(scenario.tokenExpiry)
accessToken := createExpiredJWTToken("test-user", "test@example.com", tokenExpiredAt)
session.SetAccessToken(accessToken)
session.SetRefreshToken("test-refresh-token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Test session validity check
isSessionExpired := scenario.sessionAge > absoluteSessionTimeout
isTokenExpired := scenario.tokenExpiry < 0
t.Logf("Session age: %v (expired: %t)", scenario.sessionAge, isSessionExpired)
t.Logf("Token expiry: %v ago (expired: %t)", -scenario.tokenExpiry, isTokenExpired)
if scenario.sessionShouldExpire {
if isSessionExpired && session.GetAuthenticated() {
t.Errorf("Session should be expired after %v but is still authenticated", scenario.sessionAge)
}
} else {
if !isSessionExpired && !session.GetAuthenticated() {
t.Errorf("Session should be valid (age: %v) but shows as not authenticated", scenario.sessionAge)
}
}
if scenario.tokenShouldRefresh {
if !isTokenExpired {
t.Errorf("Test setup error - tokens should be expired but expiry is: %v", scenario.tokenExpiry)
}
t.Logf("Should attempt token refresh for scenario: %s", scenario.name)
} else {
if isSessionExpired {
t.Logf("Correctly identified that session is expired - no need to refresh tokens")
}
}
// Check for critical scenario: confusing session expiry with token expiry
if !isSessionExpired && isTokenExpired {
t.Logf("CRITICAL SCENARIO: Valid session (%v old) but expired tokens (%v ago)",
scenario.sessionAge, -scenario.tokenExpiry)
t.Logf("Expected: System should refresh tokens and continue session")
if scenario.name == "New session, expired tokens" && scenario.tokenExpiry == -6*time.Hour {
t.Logf("This represents the 6-hour browser inactivity scenario")
}
}
})
}
}
// TestSessionCleanupOnTokenExpiry tests that session cleanup happens correctly
func TestSessionCleanupOnTokenExpiry(t *testing.T) {
t.Log("Testing session cleanup on token expiry")
logger := NewLogger("debug")
sm, err := NewSessionManager("cleanup-test-key-32-bytes-long-123", false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
scenarios := []struct {
name string
shouldPreserve []string
shouldRemove []string
tokenExpiry time.Duration
shouldCleanup bool
}{
{
name: "Recently expired tokens - preserve session",
tokenExpiry: -30 * time.Minute,
shouldCleanup: false,
shouldPreserve: []string{"user_data", "preferences", "authentication"},
shouldRemove: []string{},
},
{
name: "Long expired tokens - cleanup selectively",
tokenExpiry: -25 * time.Hour,
shouldCleanup: true,
shouldPreserve: []string{},
shouldRemove: []string{"user_data", "preferences", "authentication"},
},
{
name: "6-hour expired tokens - preserve for refresh",
tokenExpiry: -6 * time.Hour,
shouldCleanup: false,
shouldPreserve: []string{"user_data", "preferences", "authentication"},
shouldRemove: []string{},
},
}
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
t.Logf("Testing cleanup behavior: %s", scenario.name)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Set up session with data that should be preserved or removed
session.SetAuthenticated(true)
session.SetUserIdentifier("cleanup@example.com")
session.mainSession.Values["user_data"] = "Test User|user-123"
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
session.mainSession.Values["authentication"] = true
session.mainSession.Values["temp_data"] = "should-be-cleaned"
// Set expired tokens
expiredTime := time.Now().Add(scenario.tokenExpiry)
expiredToken := createExpiredJWTToken("user-123", "cleanup@example.com", expiredTime)
session.SetAccessToken(expiredToken)
session.SetRefreshToken("test-refresh-token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Simulate token expiry detection and cleanup logic
tokenExpired := scenario.tokenExpiry < 0
sessionTooOld := scenario.tokenExpiry < -absoluteSessionTimeout
t.Logf("Token expired: %t, Session too old: %t", tokenExpired, sessionTooOld)
// Check current session state before cleanup
preCleanupAuth := session.GetAuthenticated()
preCleanupData := session.mainSession.Values["user_data"]
preCleanupPrefs := session.mainSession.Values["preferences"]
if scenario.shouldCleanup {
if sessionTooOld {
session.SetAuthenticated(false)
session.SetUserIdentifier("")
session.SetAccessToken("")
session.SetRefreshToken("")
for key := range session.mainSession.Values {
delete(session.mainSession.Values, key)
}
t.Log("Applied full cleanup for expired session")
}
} else {
t.Log("Preserving session for token refresh")
}
// Check post-cleanup state
postCleanupAuth := session.GetAuthenticated()
postCleanupData := session.mainSession.Values["user_data"]
postCleanupPrefs := session.mainSession.Values["preferences"]
// Verify preservation expectations
for _, item := range scenario.shouldPreserve {
switch item {
case "authentication":
if !postCleanupAuth && preCleanupAuth {
t.Errorf("Authentication state was cleaned up but should be preserved")
}
case "user_data":
if postCleanupData == nil && preCleanupData != nil {
t.Errorf("User data was cleaned up but should be preserved")
}
case "preferences":
if postCleanupPrefs == nil && preCleanupPrefs != nil {
t.Errorf("User preferences were cleaned up but should be preserved")
}
}
}
// Verify removal expectations
for _, item := range scenario.shouldRemove {
switch item {
case "authentication":
if postCleanupAuth && scenario.shouldCleanup {
t.Errorf("Authentication state not cleaned up when it should be")
}
case "user_data":
if postCleanupData != nil && scenario.shouldCleanup {
t.Errorf("User data not cleaned up when session is expired")
}
}
}
// Check the critical 6-hour scenario
if scenario.tokenExpiry == -6*time.Hour {
if !postCleanupAuth {
t.Error("6-hour token expiry caused session cleanup - session should be preserved for token refresh")
}
if postCleanupData == nil {
t.Error("6-hour token expiry caused user data loss - user data should be preserved during token refresh")
}
}
})
}
}
// ============================================================================
// HELPER FUNCTIONS
// ============================================================================
// Helper function to count objects in the session pool for a given manager
func getPooledObjects(sm *SessionManager) int {
var objects []*SessionData
maxAttempts := 100
for i := 0; i < maxAttempts; i++ {
obj := sm.sessionPool.Get()
if obj == nil {
break
}
sessionData, ok := obj.(*SessionData)
if !ok {
sm.sessionPool.Put(obj)
break
}
objects = append(objects, sessionData)
}
count := len(objects)
for _, obj := range objects {
sm.sessionPool.Put(obj)
}
return count
}
// createLargeIDToken creates a JWT-like token of specified size for testing
func createLargeIDToken(size int) string {
randomBytes := make([]byte, size*3/4)
_, err := rand.Read(randomBytes)
if err != nil {
for i := range randomBytes {
randomBytes[i] = byte(i % 256)
}
}
encoded := base64.RawURLEncoding.EncodeToString(randomBytes)
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
if len(encoded) > size-len(header)-100 {
encoded = encoded[:size-len(header)-100]
}
signature := "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
return header + "." + encoded + "." + signature
}
// minInt returns the minimum of two integers
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
// Helper function to create expired JWT tokens for testing
func createExpiredJWTToken(userID, email string, expiredTime time.Time) string {
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
claims := map[string]interface{}{
"sub": userID,
"email": email,
"exp": expiredTime.Unix(),
"iat": expiredTime.Add(-1 * time.Hour).Unix(),
"iss": "https://test-provider.com",
"aud": "test-client-id",
}
claimsJSON, _ := json.Marshal(claims)
claimsEncoded := base64.RawURLEncoding.EncodeToString(claimsJSON)
signature := "fake-signature-for-testing"
signatureEncoded := base64.RawURLEncoding.EncodeToString([]byte(signature))
return header + "." + claimsEncoded + "." + signatureEncoded
}
// TestCookiePrefixIsolation tests that different cookie prefixes create isolated sessions
// This addresses GitHub issue #87 where multiple middleware instances should not share sessions
func TestCookiePrefixIsolation(t *testing.T) {
logger := NewLogger("info")
encryptionKey := strings.Repeat("a", 32)
// Create two session managers with different cookie prefixes
sm1, err := NewSessionManager(encryptionKey, false, "", "_oidc_userauth_", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager 1: %v", err)
}
sm2, err := NewSessionManager(encryptionKey, false, "", "_oidc_adminauth_", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager 2: %v", err)
}
// Verify cookie names are different
if sm1.mainCookieName() == sm2.mainCookieName() {
t.Errorf("Expected different main cookie names, got same: %s", sm1.mainCookieName())
}
if sm1.accessTokenCookieName() == sm2.accessTokenCookieName() {
t.Errorf("Expected different access token cookie names, got same: %s", sm1.accessTokenCookieName())
}
// Verify cookie names have the correct prefix
expectedPrefix1 := "_oidc_userauth_"
expectedPrefix2 := "_oidc_adminauth_"
if !strings.HasPrefix(sm1.mainCookieName(), expectedPrefix1) {
t.Errorf("Expected main cookie name to start with %s, got %s", expectedPrefix1, sm1.mainCookieName())
}
if !strings.HasPrefix(sm2.mainCookieName(), expectedPrefix2) {
t.Errorf("Expected main cookie name to start with %s, got %s", expectedPrefix2, sm2.mainCookieName())
}
t.Logf("Session Manager 1 cookies: main=%s, access=%s, refresh=%s, id=%s",
sm1.mainCookieName(), sm1.accessTokenCookieName(), sm1.refreshTokenCookieName(), sm1.idTokenCookieName())
t.Logf("Session Manager 2 cookies: main=%s, access=%s, refresh=%s, id=%s",
sm2.mainCookieName(), sm2.accessTokenCookieName(), sm2.refreshTokenCookieName(), sm2.idTokenCookieName())
}
// TestCookiePrefixDefault tests that the default cookie prefix is applied when none is provided
func TestCookiePrefixDefault(t *testing.T) {
logger := NewLogger("info")
encryptionKey := strings.Repeat("a", 32)
// Create session manager without cookie prefix (should use default)
sm, err := NewSessionManager(encryptionKey, false, "", "", 0, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Verify default prefix is used
expectedPrefix := defaultCookiePrefix
if !strings.HasPrefix(sm.mainCookieName(), expectedPrefix) {
t.Errorf("Expected default prefix %s, got cookie name %s", expectedPrefix, sm.mainCookieName())
}
// Verify full cookie names
expectedMain := defaultCookiePrefix + mainCookieSuffix
expectedAccess := defaultCookiePrefix + accessTokenSuffix
expectedRefresh := defaultCookiePrefix + refreshTokenSuffix
expectedID := defaultCookiePrefix + idTokenSuffix
if sm.mainCookieName() != expectedMain {
t.Errorf("Expected main cookie name %s, got %s", expectedMain, sm.mainCookieName())
}
if sm.accessTokenCookieName() != expectedAccess {
t.Errorf("Expected access cookie name %s, got %s", expectedAccess, sm.accessTokenCookieName())
}
if sm.refreshTokenCookieName() != expectedRefresh {
t.Errorf("Expected refresh cookie name %s, got %s", expectedRefresh, sm.refreshTokenCookieName())
}
if sm.idTokenCookieName() != expectedID {
t.Errorf("Expected ID cookie name %s, got %s", expectedID, sm.idTokenCookieName())
}
}