mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
9cbca4c4fb
patch-release
The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.
The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.
Mirror the callback logic in refreshToken so both paths behave the same.
Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").
Compat note: existing user sessions are invalidated on upgrade — every active
user must re-authenticate once after this change is deployed.
3279 lines
94 KiB
Go
3279 lines
94 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/sessions"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// ============================================================================
|
|
// SESSION TEST FRAMEWORK
|
|
// ============================================================================
|
|
|
|
// SessionTestCase represents a comprehensive session test scenario
|
|
type SessionTestCase struct {
|
|
setup func(*SessionTestFramework)
|
|
execute func(*SessionTestFramework) error
|
|
validate func(*testing.T, error, *SessionTestFramework)
|
|
cleanup func(*SessionTestFramework)
|
|
name string
|
|
scenario string
|
|
sessionType string
|
|
skipReason string
|
|
iterations int
|
|
timeout time.Duration
|
|
concurrent bool
|
|
}
|
|
|
|
// SessionTestFramework provides shared test infrastructure for session tests
|
|
type SessionTestFramework struct {
|
|
t *testing.T
|
|
mockProvider *httptest.Server
|
|
testTokens map[string]string
|
|
metrics *SessionTestMetrics
|
|
config *SessionTestConfig
|
|
requests []*http.Request
|
|
responses []*httptest.ResponseRecorder
|
|
sessionIDs []string
|
|
cleanupFuncs []func()
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// SessionTestMetrics collects counters gathered while session tests run.
// TokensGenerated is bumped atomically by the token generators in this file.
type SessionTestMetrics struct {
	SessionsCreated, SessionsDestroyed int64
	TokensGenerated, TokensValidated   int64
	ChunksCreated, ChunksRetrieved     int64
	ErrorCount                         int64
	Duration                           time.Duration
}
|
|
|
|
// SessionTestConfig carries the settings a SessionTestFramework is created
// with (see NewSessionTestFramework for the defaults used in tests).
type SessionTestConfig struct {
	CookieDomain, EncryptionKey    string
	MaxChunkSize, MaxSessions      int
	SessionTimeout                 time.Duration
	EnableHTTPS, EnableCompression bool
}
|
|
|
|
// NewSessionTestFramework creates a new test framework instance
|
|
func NewSessionTestFramework(t *testing.T) *SessionTestFramework {
|
|
framework := &SessionTestFramework{
|
|
t: t,
|
|
requests: make([]*http.Request, 0),
|
|
responses: make([]*httptest.ResponseRecorder, 0),
|
|
testTokens: make(map[string]string),
|
|
sessionIDs: make([]string, 0),
|
|
metrics: &SessionTestMetrics{},
|
|
cleanupFuncs: make([]func(), 0),
|
|
config: &SessionTestConfig{
|
|
MaxChunkSize: 3900,
|
|
MaxSessions: 1000,
|
|
EnableHTTPS: false,
|
|
CookieDomain: "",
|
|
SessionTimeout: time.Hour,
|
|
EncryptionKey: generateTestKey(),
|
|
EnableCompression: true,
|
|
},
|
|
}
|
|
|
|
// Setup mock OIDC provider
|
|
framework.setupMockProvider()
|
|
|
|
return framework
|
|
}
|
|
|
|
// generateTestKey returns the fixed encryption key used by the tests.
// The key is 48 characters long (48 bytes = 384 bits).
func generateTestKey() string {
	const testKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
	return testKey
}
|
|
|
|
// setupMockProvider sets up a mock OIDC provider for testing
|
|
func (f *SessionTestFramework) setupMockProvider() {
|
|
f.mockProvider = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/.well-known/openid-configuration":
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"issuer": f.mockProvider.URL,
|
|
"authorization_endpoint": f.mockProvider.URL + "/auth",
|
|
"token_endpoint": f.mockProvider.URL + "/token",
|
|
"userinfo_endpoint": f.mockProvider.URL + "/userinfo",
|
|
"jwks_uri": f.mockProvider.URL + "/jwks",
|
|
})
|
|
case "/token":
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"access_token": f.generateTestToken("access", 3600),
|
|
"id_token": f.generateTestToken("id", 3600),
|
|
"refresh_token": f.generateTestToken("refresh", 86400),
|
|
"token_type": "Bearer",
|
|
"expires_in": 3600,
|
|
})
|
|
case "/userinfo":
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"sub": "test-user-id",
|
|
"email": "test@example.com",
|
|
"name": "Test User",
|
|
})
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}))
|
|
|
|
f.cleanupFuncs = append(f.cleanupFuncs, f.mockProvider.Close)
|
|
}
|
|
|
|
// generateTestToken generates a test token
|
|
func (f *SessionTestFramework) generateTestToken(tokenType string, expiresIn int) string {
|
|
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
|
|
|
|
// Create a realistic JWT-like token for testing
|
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
|
|
|
|
claims := map[string]interface{}{
|
|
"iss": f.mockProvider.URL,
|
|
"sub": "test-user-id",
|
|
"aud": "test-client-id",
|
|
"exp": time.Now().Add(time.Duration(expiresIn) * time.Second).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"typ": tokenType,
|
|
}
|
|
|
|
claimsJSON, _ := json.Marshal(claims)
|
|
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
|
|
|
// Generate a fake signature
|
|
signature := make([]byte, 64)
|
|
rand.Read(signature)
|
|
sig := base64.RawURLEncoding.EncodeToString(signature)
|
|
|
|
token := fmt.Sprintf("%s.%s.%s", header, payload, sig)
|
|
|
|
// Thread-safe write to map
|
|
f.mu.Lock()
|
|
f.testTokens[tokenType] = token
|
|
f.mu.Unlock()
|
|
|
|
return token
|
|
}
|
|
|
|
// generateLargeToken generates a token of specified size for testing chunking
|
|
func (f *SessionTestFramework) generateLargeToken(size int) string {
|
|
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
|
|
|
|
// Create base JWT structure
|
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
|
|
|
|
// Calculate how much padding we need in claims
|
|
baseSize := len(header) + 2 // for the dots
|
|
signatureSize := 86 // approximate base64 encoded signature size
|
|
paddingSize := size - baseSize - signatureSize - 100 // leave room for other claims
|
|
|
|
if paddingSize < 0 {
|
|
paddingSize = 0
|
|
}
|
|
|
|
// Create large padding data
|
|
padding := make([]byte, paddingSize)
|
|
for i := range padding {
|
|
padding[i] = byte('A' + (i % 26))
|
|
}
|
|
|
|
claims := map[string]interface{}{
|
|
"iss": f.mockProvider.URL,
|
|
"sub": "test-user-id",
|
|
"aud": "test-client-id",
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"padding": base64.StdEncoding.EncodeToString(padding),
|
|
}
|
|
|
|
claimsJSON, _ := json.Marshal(claims)
|
|
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
|
|
|
// Generate signature
|
|
signature := make([]byte, 64)
|
|
rand.Read(signature)
|
|
sig := base64.RawURLEncoding.EncodeToString(signature)
|
|
|
|
return fmt.Sprintf("%s.%s.%s", header, payload, sig)
|
|
}
|
|
|
|
// Cleanup performs framework cleanup
|
|
func (f *SessionTestFramework) Cleanup() {
|
|
for _, cleanup := range f.cleanupFuncs {
|
|
cleanup()
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// SESSION CHUNK MANAGER TESTS
|
|
// ============================================================================
|
|
|
|
// createMockRequest returns a body-less GET request for http://example.com,
// used wherever a test needs a request to create a session from.
func createMockRequest() *http.Request {
	return httptest.NewRequest(http.MethodGet, "http://example.com", nil)
}
|
|
|
|
func TestNewSessionChunkManager(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
if manager == nil {
|
|
t.Fatal("Expected non-nil session chunk manager")
|
|
}
|
|
|
|
if manager.maxChunks != 10 {
|
|
t.Errorf("Expected maxChunks 10, got %d", manager.maxChunks)
|
|
}
|
|
}
|
|
|
|
func TestNewSessionChunkManagerDefaultLimit(t *testing.T) {
|
|
// Test with 0 maxChunks (should use default)
|
|
manager := NewSessionChunkManager(0)
|
|
|
|
if manager.maxChunks != 20 {
|
|
t.Errorf("Expected default maxChunks 20, got %d", manager.maxChunks)
|
|
}
|
|
}
|
|
|
|
func TestNewSessionChunkManagerNegativeLimit(t *testing.T) {
|
|
// Test with negative maxChunks (should use default)
|
|
manager := NewSessionChunkManager(-5)
|
|
|
|
if manager.maxChunks != 20 {
|
|
t.Errorf("Expected default maxChunks 20, got %d", manager.maxChunks)
|
|
}
|
|
}
|
|
|
|
func TestCleanupChunksWithoutWriter(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add some chunks
|
|
for i := 0; i < 5; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
session.Values["token_chunk"] = "chunk-data"
|
|
chunks[i] = session
|
|
}
|
|
|
|
// Cleanup without writer (should just clear map)
|
|
manager.CleanupChunks(chunks, nil)
|
|
|
|
if len(chunks) != 0 {
|
|
t.Errorf("Expected chunks map to be empty, got %d entries", len(chunks))
|
|
}
|
|
}
|
|
|
|
func TestCleanupChunksWithWriter(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add some chunks
|
|
for i := 0; i < 3; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
session.Values["token_chunk"] = "chunk-data"
|
|
session.Options = &sessions.Options{MaxAge: 3600}
|
|
chunks[i] = session
|
|
}
|
|
|
|
// Create response writer
|
|
w := httptest.NewRecorder()
|
|
|
|
// Note: We can't fully test the Save behavior without a proper HTTP request
|
|
// but we can verify the cleanup clears the map
|
|
manager.CleanupChunks(chunks, w)
|
|
|
|
if len(chunks) != 0 {
|
|
t.Errorf("Expected chunks map to be empty, got %d entries", len(chunks))
|
|
}
|
|
}
|
|
|
|
func TestCleanupChunksNilSession(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
chunks[0] = nil
|
|
chunks[1] = nil
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
// Should handle nil sessions gracefully
|
|
manager.CleanupChunks(chunks, w)
|
|
|
|
if len(chunks) != 0 {
|
|
t.Errorf("Expected chunks map to be empty, got %d entries", len(chunks))
|
|
}
|
|
}
|
|
|
|
func TestCleanupChunksEmptyMap(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
|
|
// Should handle empty map gracefully
|
|
manager.CleanupChunks(chunks, nil)
|
|
|
|
if len(chunks) != 0 {
|
|
t.Error("Expected chunks map to remain empty")
|
|
}
|
|
}
|
|
|
|
func TestValidateAndCleanChunksWithinLimit(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add chunks within limit
|
|
for i := 0; i < 5; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
chunks[i] = session
|
|
}
|
|
|
|
result := manager.ValidateAndCleanChunks(chunks)
|
|
|
|
if !result {
|
|
t.Error("Expected validation to pass for chunks within limit")
|
|
}
|
|
|
|
if len(chunks) != 5 {
|
|
t.Errorf("Expected chunks to remain intact, got %d", len(chunks))
|
|
}
|
|
}
|
|
|
|
func TestValidateAndCleanChunksExceedLimit(t *testing.T) {
|
|
manager := NewSessionChunkManager(5)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add more chunks than limit
|
|
for i := 0; i < 10; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
chunks[i] = session
|
|
}
|
|
|
|
result := manager.ValidateAndCleanChunks(chunks)
|
|
|
|
if result {
|
|
t.Error("Expected validation to fail for chunks exceeding limit")
|
|
}
|
|
|
|
if len(chunks) != 0 {
|
|
t.Errorf("Expected chunks to be cleared, got %d", len(chunks))
|
|
}
|
|
}
|
|
|
|
func TestValidateAndCleanChunksAtLimit(t *testing.T) {
|
|
manager := NewSessionChunkManager(5)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add chunks exactly at limit
|
|
for i := 0; i < 5; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
chunks[i] = session
|
|
}
|
|
|
|
result := manager.ValidateAndCleanChunks(chunks)
|
|
|
|
if !result {
|
|
t.Error("Expected validation to pass for chunks at limit")
|
|
}
|
|
|
|
if len(chunks) != 5 {
|
|
t.Errorf("Expected chunks to remain intact, got %d", len(chunks))
|
|
}
|
|
}
|
|
|
|
func TestSafeSetChunkValidIndex(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
|
|
result := manager.SafeSetChunk(chunks, 5, session)
|
|
|
|
if !result {
|
|
t.Error("Expected SafeSetChunk to succeed for valid index")
|
|
}
|
|
|
|
if chunks[5] != session {
|
|
t.Error("Expected session to be set at index 5")
|
|
}
|
|
}
|
|
|
|
func TestSafeSetChunkNegativeIndex(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
|
|
result := manager.SafeSetChunk(chunks, -1, session)
|
|
|
|
if result {
|
|
t.Error("Expected SafeSetChunk to fail for negative index")
|
|
}
|
|
|
|
if len(chunks) != 0 {
|
|
t.Error("Expected chunks map to remain empty")
|
|
}
|
|
}
|
|
|
|
func TestSafeSetChunkIndexTooHigh(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
|
|
result := manager.SafeSetChunk(chunks, 10, session)
|
|
|
|
if result {
|
|
t.Error("Expected SafeSetChunk to fail for index >= maxChunks")
|
|
}
|
|
|
|
if len(chunks) != 0 {
|
|
t.Error("Expected chunks map to remain empty")
|
|
}
|
|
}
|
|
|
|
func TestSafeSetChunkExceedingLimit(t *testing.T) {
|
|
manager := NewSessionChunkManager(5)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Fill up to limit
|
|
for i := 0; i < 5; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
chunks[i] = session
|
|
}
|
|
|
|
// Try to add a new chunk at new index (should fail)
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
result := manager.SafeSetChunk(chunks, 2, session)
|
|
|
|
// This should succeed because index 2 already exists
|
|
if !result {
|
|
t.Error("Expected SafeSetChunk to succeed for existing index")
|
|
}
|
|
}
|
|
|
|
func TestSafeSetChunkReplaceExisting(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
session1, _ := store.New(createMockRequest(), "chunk1")
|
|
session2, _ := store.New(createMockRequest(), "chunk2")
|
|
|
|
// Set initial session
|
|
manager.SafeSetChunk(chunks, 3, session1)
|
|
|
|
// Replace with new session
|
|
result := manager.SafeSetChunk(chunks, 3, session2)
|
|
|
|
if !result {
|
|
t.Error("Expected SafeSetChunk to succeed for replacing existing chunk")
|
|
}
|
|
|
|
if chunks[3] != session2 {
|
|
t.Error("Expected session to be replaced at index 3")
|
|
}
|
|
}
|
|
|
|
func TestGetChunkCount(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add some chunks
|
|
for i := 0; i < 7; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
chunks[i] = session
|
|
}
|
|
|
|
count := manager.GetChunkCount(chunks)
|
|
|
|
if count != 7 {
|
|
t.Errorf("Expected chunk count 7, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestGetChunkCountEmpty(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
|
|
count := manager.GetChunkCount(chunks)
|
|
|
|
if count != 0 {
|
|
t.Errorf("Expected chunk count 0, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestCompactChunksNoGaps(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add sequential chunks
|
|
for i := 0; i < 5; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
session.Values["index"] = i
|
|
chunks[i] = session
|
|
}
|
|
|
|
compacted := manager.CompactChunks(chunks)
|
|
|
|
if len(compacted) != 5 {
|
|
t.Errorf("Expected 5 compacted chunks, got %d", len(compacted))
|
|
}
|
|
|
|
// Verify order
|
|
for i := 0; i < 5; i++ {
|
|
if compacted[i] == nil {
|
|
t.Errorf("Expected chunk at index %d", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCompactChunksWithGaps(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add chunks with gaps
|
|
indices := []int{0, 2, 5, 7}
|
|
for _, idx := range indices {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
session.Values["original_index"] = idx
|
|
chunks[idx] = session
|
|
}
|
|
|
|
compacted := manager.CompactChunks(chunks)
|
|
|
|
if len(compacted) != 4 {
|
|
t.Errorf("Expected 4 compacted chunks, got %d", len(compacted))
|
|
}
|
|
|
|
// Verify chunks are reindexed sequentially
|
|
for i := 0; i < 4; i++ {
|
|
if compacted[i] == nil {
|
|
t.Errorf("Expected chunk at compacted index %d", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCompactChunksWithNilEntries(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add chunks and nil entries
|
|
session1, _ := store.New(createMockRequest(), "chunk1")
|
|
session2, _ := store.New(createMockRequest(), "chunk2")
|
|
session3, _ := store.New(createMockRequest(), "chunk3")
|
|
|
|
chunks[0] = session1
|
|
chunks[1] = nil
|
|
chunks[2] = session2
|
|
chunks[3] = nil
|
|
chunks[4] = session3
|
|
|
|
compacted := manager.CompactChunks(chunks)
|
|
|
|
if len(compacted) != 3 {
|
|
t.Errorf("Expected 3 compacted chunks (nil entries removed), got %d", len(compacted))
|
|
}
|
|
|
|
// Verify non-nil chunks are compacted
|
|
for i := 0; i < 3; i++ {
|
|
if compacted[i] == nil {
|
|
t.Errorf("Expected non-nil chunk at compacted index %d", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCompactChunksEmpty(t *testing.T) {
|
|
manager := NewSessionChunkManager(10)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
|
|
compacted := manager.CompactChunks(chunks)
|
|
|
|
if len(compacted) != 0 {
|
|
t.Errorf("Expected empty compacted map, got %d entries", len(compacted))
|
|
}
|
|
}
|
|
|
|
func TestSessionChunkManagerConcurrentOperations(t *testing.T) {
|
|
manager := NewSessionChunkManager(50)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
// Concurrent SafeSetChunk
|
|
for i := 0; i < 20; i++ {
|
|
wg.Add(1)
|
|
go func(index int) {
|
|
defer wg.Done()
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
manager.SafeSetChunk(chunks, index, session)
|
|
}(i)
|
|
}
|
|
|
|
// Concurrent GetChunkCount
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
_ = manager.GetChunkCount(chunks)
|
|
}()
|
|
}
|
|
|
|
// Concurrent ValidateAndCleanChunks (reads)
|
|
for i := 0; i < 5; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
_ = manager.ValidateAndCleanChunks(chunks)
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// Verify manager is still functional
|
|
count := manager.GetChunkCount(chunks)
|
|
if count < 0 || count > 50 {
|
|
t.Errorf("Unexpected chunk count after concurrent operations: %d", count)
|
|
}
|
|
}
|
|
|
|
func TestSessionChunkManagerLargeChunkCount(t *testing.T) {
|
|
manager := NewSessionChunkManager(1000)
|
|
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
// Add many chunks
|
|
for i := 0; i < 500; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
chunks[i] = session
|
|
}
|
|
|
|
result := manager.ValidateAndCleanChunks(chunks)
|
|
|
|
if !result {
|
|
t.Error("Expected validation to pass for 500 chunks with limit 1000")
|
|
}
|
|
|
|
count := manager.GetChunkCount(chunks)
|
|
if count != 500 {
|
|
t.Errorf("Expected 500 chunks, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestSessionChunkManagerBoundaryConditions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
maxChunks int
|
|
addChunks int
|
|
shouldPass bool
|
|
}{
|
|
{"exactly at limit", 10, 10, true},
|
|
{"one over limit", 10, 11, false},
|
|
{"way over limit", 10, 50, false},
|
|
{"zero chunks with limit", 10, 0, true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
manager := NewSessionChunkManager(tt.maxChunks)
|
|
chunks := make(map[int]*sessions.Session)
|
|
store := sessions.NewCookieStore([]byte("test-secret"))
|
|
|
|
for i := 0; i < tt.addChunks; i++ {
|
|
session, _ := store.New(createMockRequest(), "chunk")
|
|
chunks[i] = session
|
|
}
|
|
|
|
result := manager.ValidateAndCleanChunks(chunks)
|
|
|
|
if result != tt.shouldPass {
|
|
t.Errorf("Expected validation result %v, got %v", tt.shouldPass, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// SESSION HELPER TESTS
|
|
// ============================================================================
|
|
|
|
// TestSetCodeVerifier_NoChange tests the branch where the code verifier value doesn't change
|
|
func TestSetCodeVerifier_NoChange(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
defer sm.Shutdown()
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
// Set initial code verifier
|
|
initialVerifier := "test-code-verifier-12345"
|
|
session.SetCodeVerifier(initialVerifier)
|
|
|
|
if !session.IsDirty() {
|
|
t.Error("Session should be dirty after first SetCodeVerifier")
|
|
}
|
|
|
|
// Mark clean to test the no-change branch
|
|
session.dirty = false
|
|
|
|
// Set the same code verifier again - this should hit the uncovered branch
|
|
session.SetCodeVerifier(initialVerifier)
|
|
|
|
// Verify that dirty flag remains false (no change occurred)
|
|
if session.IsDirty() {
|
|
t.Error("Session should not be dirty when setting same code verifier value")
|
|
}
|
|
|
|
// Verify the code verifier value is still correct
|
|
if got := session.GetCodeVerifier(); got != initialVerifier {
|
|
t.Errorf("Expected code verifier %q, got %q", initialVerifier, got)
|
|
}
|
|
}
|
|
|
|
// TestClearTokenChunks_EmptyChunks tests the branch where the chunks map is empty
|
|
func TestClearTokenChunks_EmptyChunks(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
defer sm.Shutdown()
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
// Test with empty chunks map - this should hit the uncovered branch where the loop body doesn't execute
|
|
emptyChunks := make(map[int]*sessions.Session)
|
|
|
|
// This should not panic and should handle empty map gracefully
|
|
session.clearTokenChunks(req, emptyChunks)
|
|
|
|
// Verify that no errors occurred and the session is still valid
|
|
if session == nil {
|
|
t.Fatal("Session should still be valid after clearing empty chunks")
|
|
}
|
|
|
|
// Additional test: clear already-empty chunk maps in the session
|
|
session.clearTokenChunks(req, session.accessTokenChunks)
|
|
session.clearTokenChunks(req, session.refreshTokenChunks)
|
|
session.clearTokenChunks(req, session.idTokenChunks)
|
|
|
|
// Verify session is still valid
|
|
if session.GetAuthenticated() {
|
|
// This is fine - session can be authenticated even with no chunks
|
|
}
|
|
}
|
|
|
|
// TestClearTokenChunks_WithSessions tests the branch where the chunks map contains actual sessions
|
|
func TestClearTokenChunks_WithSessions(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
defer sm.Shutdown()
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
// Create chunks map with actual sessions
|
|
chunksWithSessions := make(map[int]*sessions.Session)
|
|
|
|
// Create a few test sessions and add them to the chunks map
|
|
for i := 0; i < 3; i++ {
|
|
chunkSession, err := sm.store.Get(req, fmt.Sprintf("test_chunk_%d", i))
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test chunk session: %v", err)
|
|
}
|
|
// Add some test data to the session
|
|
chunkSession.Values["test_data"] = fmt.Sprintf("chunk_%d_data", i)
|
|
chunkSession.Values["chunk_index"] = i
|
|
chunksWithSessions[i] = chunkSession
|
|
}
|
|
|
|
// Verify chunks have data before clearing
|
|
if len(chunksWithSessions) != 3 {
|
|
t.Errorf("Expected 3 chunks, got %d", len(chunksWithSessions))
|
|
}
|
|
|
|
for i, chunkSession := range chunksWithSessions {
|
|
if chunkSession.Values["test_data"] == nil {
|
|
t.Errorf("Chunk %d should have test data before clearing", i)
|
|
}
|
|
}
|
|
|
|
// Call clearTokenChunks - this should hit the loop body and clear all sessions
|
|
session.clearTokenChunks(req, chunksWithSessions)
|
|
|
|
// Verify that the sessions were cleared
|
|
for i, chunkSession := range chunksWithSessions {
|
|
if len(chunkSession.Values) != 0 {
|
|
t.Errorf("Chunk %d should have no values after clearing, but has %d values", i, len(chunkSession.Values))
|
|
}
|
|
// Verify MaxAge was set to -1 (expired)
|
|
if chunkSession.Options.MaxAge != -1 {
|
|
t.Errorf("Chunk %d should have MaxAge=-1 (expired), but has MaxAge=%d", i, chunkSession.Options.MaxAge)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// SESSION POOL AND MEMORY TESTS
|
|
// ============================================================================
|
|
|
|
// TestSessionPoolMemoryLeak tests that session objects are properly returned to the pool
|
|
func TestSessionPoolMemoryLeak(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeLeakDetection) {
|
|
return
|
|
}
|
|
|
|
testTokens := NewTestTokens()
|
|
edgeGen := NewEdgeCaseGenerator()
|
|
runner := NewTestSuiteRunner()
|
|
runner.SetTimeout(30 * time.Second)
|
|
|
|
tests := []TableTestCase{
|
|
{
|
|
Name: "Successful session creation and return",
|
|
Description: "Test that sessions are properly created and returned to pool",
|
|
Setup: func(t *testing.T) error {
|
|
return nil
|
|
},
|
|
Teardown: func(t *testing.T) error {
|
|
runtime.GC()
|
|
time.Sleep(100 * time.Millisecond)
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
Name: "Explicit ReturnToPool method",
|
|
Description: "Test that explicit pool return works correctly",
|
|
Setup: func(t *testing.T) error {
|
|
return nil
|
|
},
|
|
Teardown: func(t *testing.T) error {
|
|
runtime.GC()
|
|
time.Sleep(100 * time.Millisecond)
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
Name: "Error path in GetSession",
|
|
Description: "Test pool behavior when GetSession fails",
|
|
Setup: func(t *testing.T) error {
|
|
return nil
|
|
},
|
|
Teardown: func(t *testing.T) error {
|
|
runtime.GC()
|
|
time.Sleep(100 * time.Millisecond)
|
|
return nil
|
|
},
|
|
},
|
|
}
|
|
|
|
// Custom test execution since we need to test memory behavior
|
|
for _, test := range tests {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
if test.Setup != nil {
|
|
if err := test.Setup(t); err != nil {
|
|
t.Fatalf("Setup failed: %v", err)
|
|
}
|
|
}
|
|
|
|
if test.Teardown != nil {
|
|
defer func() {
|
|
if err := test.Teardown(t); err != nil {
|
|
t.Errorf("Teardown failed: %v", err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
|
|
switch test.Name {
|
|
case "Successful session creation and return":
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("GetSession failed: %v", err)
|
|
}
|
|
session.Clear(req, nil)
|
|
|
|
case "Explicit ReturnToPool method":
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("GetSession failed: %v", err)
|
|
}
|
|
session.ReturnToPool()
|
|
|
|
case "Error path in GetSession":
|
|
badSM, _ := NewSessionManager("different0123456789abcdef0123456789abcdef0123456789", false, "", "", 0, logger)
|
|
_, err = badSM.GetSession(req)
|
|
if err == nil {
|
|
t.Log("Note: Expected error when using mismatched encryption keys")
|
|
}
|
|
}
|
|
|
|
pooledCount := getPooledObjects(sm)
|
|
t.Logf("Pooled objects count: %d", pooledCount)
|
|
})
|
|
}
|
|
|
|
_ = testTokens
|
|
_ = edgeGen
|
|
}
|
|
|
|
// TestSessionErrorHandling tests comprehensive error scenarios using table-driven tests
|
|
func TestSessionErrorHandling(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeQuick) {
|
|
return
|
|
}
|
|
|
|
edgeGen := NewEdgeCaseGenerator()
|
|
runner := NewTestSuiteRunner()
|
|
|
|
// Generate edge case strings for cookie values
|
|
edgeCases := edgeGen.GenerateStringEdgeCases()
|
|
|
|
tests := []TableTestCase{
|
|
{
|
|
Name: "Corrupt cookie value",
|
|
Description: "Test handling of corrupted cookie values",
|
|
Input: "corrupt-value",
|
|
Expected: "failed to get main session:",
|
|
},
|
|
{
|
|
Name: "Invalid base64 cookie",
|
|
Description: "Test handling of invalid base64 in cookies",
|
|
Input: "!@#$%^&*()",
|
|
Expected: "failed to get main session:",
|
|
},
|
|
{
|
|
Name: "Empty cookie value",
|
|
Description: "Test handling of empty cookie values",
|
|
Input: "",
|
|
Expected: "", // Empty should work without error
|
|
},
|
|
}
|
|
|
|
// Add edge cases dynamically
|
|
for i, edgeCase := range edgeCases {
|
|
if len(edgeCase) > 0 && !strings.ContainsAny(edgeCase, "\x00\x01\x02") { // Skip binary data for cookie tests
|
|
tests = append(tests, TableTestCase{
|
|
Name: fmt.Sprintf("Edge case %d", i),
|
|
Description: fmt.Sprintf("Test edge case string: %q", edgeCase[:minInt(20, len(edgeCase))]),
|
|
Input: edgeCase,
|
|
Expected: "", // Most edge cases should be handled gracefully
|
|
})
|
|
}
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
|
|
if input, ok := test.Input.(string); ok && input != "" {
|
|
req.AddCookie(&http.Cookie{
|
|
Name: defaultCookiePrefix + mainCookieSuffix,
|
|
Value: input,
|
|
})
|
|
}
|
|
|
|
_, err = sm.GetSession(req)
|
|
|
|
if expected, ok := test.Expected.(string); ok && expected != "" {
|
|
if err == nil {
|
|
t.Error("Expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), expected) {
|
|
t.Errorf("Unexpected error message: %v", err)
|
|
}
|
|
} else {
|
|
// For empty expected, we allow either success or specific failures
|
|
if err != nil {
|
|
t.Logf("Got expected error for edge case: %v", err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
_ = runner
|
|
}
|
|
|
|
// TestSessionClearAlwaysReturnsToPool tests that sessions are always returned to pool even on errors
|
|
func TestSessionClearAlwaysReturnsToPool(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeQuick) {
|
|
return
|
|
}
|
|
|
|
runner := NewTestSuiteRunner()
|
|
|
|
memoryTests := []MemoryLeakTestCase{
|
|
{
|
|
Name: "Session clear with error returns to pool",
|
|
Description: "Verify sessions return to pool even when Clear() errors",
|
|
Iterations: 10,
|
|
MaxGoroutineGrowth: 2,
|
|
MaxMemoryGrowthMB: 5.0,
|
|
GCBetweenRuns: true,
|
|
Timeout: 30 * time.Second,
|
|
Operation: func() error {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create session manager: %w", err)
|
|
}
|
|
|
|
// Ensure proper cleanup by calling Shutdown
|
|
defer func() {
|
|
if shutdownErr := sm.Shutdown(); shutdownErr != nil {
|
|
logger.Errorf("Failed to shutdown SessionManager: %v", shutdownErr)
|
|
}
|
|
}()
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
req.Header.Set("X-Test-Error", "true")
|
|
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
return fmt.Errorf("GetSession failed: %w", err)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
clearErr := session.Clear(req, w)
|
|
|
|
// We expect an error due to the X-Test-Error header, but the session should still be returned
|
|
if clearErr == nil {
|
|
return fmt.Errorf("expected error from Clear with X-Test-Error header")
|
|
}
|
|
|
|
return nil
|
|
},
|
|
},
|
|
}
|
|
|
|
runner.RunMemoryLeakTests(t, memoryTests)
|
|
|
|
// Additional verification test
|
|
t.Run("Verify pool still works after errors", func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
// Ensure proper cleanup
|
|
defer func() {
|
|
if shutdownErr := sm.Shutdown(); shutdownErr != nil {
|
|
t.Errorf("Failed to shutdown SessionManager: %v", shutdownErr)
|
|
}
|
|
}()
|
|
|
|
normalReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
session2, err := sm.GetSession(normalReq)
|
|
if err != nil {
|
|
t.Fatalf("Second GetSession failed: %v", err)
|
|
}
|
|
session2.Clear(normalReq, nil)
|
|
|
|
t.Log("Session returned to pool despite errors")
|
|
})
|
|
}
|
|
|
|
// TestSessionObjectTracking tests session object tracking and pool behavior
|
|
func TestSessionObjectTracking(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeQuick) {
|
|
return
|
|
}
|
|
|
|
runner := NewTestSuiteRunner()
|
|
|
|
tests := []TableTestCase{
|
|
{
|
|
Name: "Session pool has New function",
|
|
Description: "Verify that session pool is properly configured",
|
|
Setup: func(t *testing.T) error {
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
Name: "Multiple session creation and disposal",
|
|
Description: "Test creating and disposing multiple sessions",
|
|
Input: 5,
|
|
},
|
|
{
|
|
Name: "Session with nil mainSession",
|
|
Description: "Test error handling with corrupted session state",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
if test.Setup != nil {
|
|
if err := test.Setup(t); err != nil {
|
|
t.Fatalf("Setup failed: %v", err)
|
|
}
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
|
|
switch test.Name {
|
|
case "Session pool has New function":
|
|
hasNew := sm.sessionPool.New != nil
|
|
if !hasNew {
|
|
t.Error("Expected sessionPool.New function to be set")
|
|
}
|
|
|
|
case "Multiple session creation and disposal":
|
|
count := test.Input.(int)
|
|
for i := 0; i < count; i++ {
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("GetSession failed: %v", err)
|
|
}
|
|
session.ReturnToPool()
|
|
}
|
|
|
|
case "Session with nil mainSession":
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("GetSession failed: %v", err)
|
|
}
|
|
|
|
session.mainSession = nil // Deliberately cause bad state
|
|
session.ReturnToPool()
|
|
}
|
|
|
|
runtime.GC()
|
|
time.Sleep(100 * time.Millisecond)
|
|
t.Log("Session pool handling verified")
|
|
})
|
|
}
|
|
|
|
_ = runner
|
|
}
|
|
|
|
// ============================================================================
|
|
// TOKEN COMPRESSION AND CHUNKING TESTS
|
|
// ============================================================================
|
|
|
|
// TestTokenCompressionIntegrity tests token compression using comprehensive test cases
|
|
func TestTokenCompressionIntegrity(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeExtended) {
|
|
return
|
|
}
|
|
|
|
testTokens := NewTestTokens()
|
|
edgeGen := NewEdgeCaseGenerator()
|
|
runner := NewTestSuiteRunner()
|
|
|
|
// Create comprehensive test cases using edge case generator and test tokens
|
|
testCases := []TableTestCase{
|
|
{
|
|
Name: "Valid JWT Small",
|
|
Input: testTokens.GetValidTokenSet().AccessToken,
|
|
Expected: true, // Should compress and decompress correctly
|
|
},
|
|
{
|
|
Name: "Valid JWT Large",
|
|
Input: testTokens.CreateLargeValidJWT(5000),
|
|
Expected: true,
|
|
},
|
|
{
|
|
Name: "Minimal Valid JWT",
|
|
Input: MinimalValidJWT,
|
|
Expected: true,
|
|
},
|
|
{
|
|
Name: "Invalid JWT Wrong dot count",
|
|
Input: InvalidTokenOneDot,
|
|
Expected: false, // Should return original for invalid tokens
|
|
},
|
|
{
|
|
Name: "Invalid JWT No dots",
|
|
Input: InvalidTokenNoDots,
|
|
Expected: false,
|
|
},
|
|
{
|
|
Name: "Invalid JWT Too many dots",
|
|
Input: InvalidTokenThreeDots,
|
|
Expected: false,
|
|
},
|
|
{
|
|
Name: "Empty token",
|
|
Input: "",
|
|
Expected: true, // Empty tokens are handled gracefully
|
|
},
|
|
{
|
|
Name: "Oversized token",
|
|
Input: testTokens.CreateIncompressibleToken(55000), // >50KB
|
|
Expected: false, // Should be rejected
|
|
},
|
|
}
|
|
|
|
// Add string edge cases as additional test inputs
|
|
stringEdgeCases := edgeGen.GenerateStringEdgeCases()
|
|
for i, edgeCase := range stringEdgeCases {
|
|
if len(edgeCase) > 0 && len(edgeCase) < 1000 { // Reasonable size for testing
|
|
testCases = append(testCases, TableTestCase{
|
|
Name: fmt.Sprintf("Edge case string %d", i),
|
|
Input: edgeCase,
|
|
Expected: true, // Most edge cases should be handled gracefully
|
|
})
|
|
}
|
|
}
|
|
|
|
for _, test := range testCases {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
token := test.Input.(string)
|
|
expectValid := test.Expected.(bool)
|
|
|
|
compressed := compressToken(token)
|
|
|
|
if !expectValid {
|
|
// For invalid tokens, compression should return original
|
|
if compressed != token {
|
|
t.Errorf("Expected compression to return original for invalid token, got different result")
|
|
}
|
|
return
|
|
}
|
|
|
|
// For valid tokens, test round-trip integrity
|
|
decompressed := decompressToken(compressed)
|
|
if decompressed != token {
|
|
t.Errorf("Token integrity lost: original=%q, compressed=%q, decompressed=%q",
|
|
token, compressed, decompressed)
|
|
}
|
|
|
|
// Test that decompression is idempotent
|
|
decompressed2 := decompressToken(decompressed)
|
|
if decompressed2 != token {
|
|
t.Errorf("Decompression not idempotent: %q != %q", decompressed2, token)
|
|
}
|
|
})
|
|
}
|
|
|
|
_ = runner
|
|
}
|
|
|
|
// TestTokenCompressionCorruptionDetection tests corruption detection using table-driven approach
|
|
func TestTokenCompressionCorruptionDetection(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeExtended) {
|
|
return
|
|
}
|
|
|
|
testTokens := NewTestTokens()
|
|
runner := NewTestSuiteRunner()
|
|
|
|
tests := []TableTestCase{
|
|
{
|
|
Name: "Invalid base64",
|
|
Input: "!@#$%^&*()",
|
|
Expected: true, // Should return original
|
|
},
|
|
{
|
|
Name: "Valid base64 but invalid gzip",
|
|
Input: base64.StdEncoding.EncodeToString([]byte("not gzip data")),
|
|
Expected: true,
|
|
},
|
|
{
|
|
Name: "Truncated gzip data",
|
|
Input: "H4sI", // Incomplete gzip header
|
|
Expected: true,
|
|
},
|
|
{
|
|
Name: "Empty string",
|
|
Input: "",
|
|
Expected: true,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
corruptedInput := test.Input.(string)
|
|
expectOriginal := test.Expected.(bool)
|
|
|
|
result := decompressToken(corruptedInput)
|
|
if expectOriginal && result != corruptedInput {
|
|
t.Errorf("Expected decompression to return original corrupted input, got: %q", result)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test that valid compression still works
|
|
t.Run("Valid compression verification", func(t *testing.T) {
|
|
validJWT := testTokens.GetValidTokenSet().AccessToken
|
|
compressed := compressToken(validJWT)
|
|
decompressed := decompressToken(compressed)
|
|
if decompressed != validJWT {
|
|
t.Errorf("Valid compression/decompression failed: %q != %q", decompressed, validJWT)
|
|
}
|
|
})
|
|
|
|
_ = runner
|
|
}
|
|
|
|
// TestTokenChunkingIntegrity tests token chunking using comprehensive test patterns
|
|
func TestTokenChunkingIntegrity(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeExtended) {
|
|
return
|
|
}
|
|
|
|
testTokens := NewTestTokens()
|
|
edgeGen := NewEdgeCaseGenerator()
|
|
runner := NewTestSuiteRunner()
|
|
|
|
tests := []TableTestCase{
|
|
{
|
|
Name: "Small token no chunking",
|
|
Description: "Small tokens should not be chunked",
|
|
Input: struct {
|
|
size int
|
|
expectChunked bool
|
|
}{100, false},
|
|
},
|
|
{
|
|
Name: "Medium token no chunking",
|
|
Description: "Medium tokens should not be chunked",
|
|
Input: struct {
|
|
size int
|
|
expectChunked bool
|
|
}{800, false},
|
|
},
|
|
{
|
|
Name: "Large token chunking required",
|
|
Description: "Large tokens should be chunked",
|
|
Input: struct {
|
|
size int
|
|
expectChunked bool
|
|
}{5000, true},
|
|
},
|
|
{
|
|
Name: "Very large token multiple chunks",
|
|
Description: "Very large tokens should create multiple chunks",
|
|
Input: struct {
|
|
size int
|
|
expectChunked bool
|
|
}{10000, true},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
params := test.Input.(struct {
|
|
size int
|
|
expectChunked bool
|
|
})
|
|
|
|
// Create token based on expectation
|
|
var token string
|
|
if params.expectChunked {
|
|
token = testTokens.CreateIncompressibleToken(params.size)
|
|
} else {
|
|
token = testTokens.CreateLargeValidJWT(params.size)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Store the token
|
|
session.SetAccessToken(token)
|
|
|
|
// Retrieve the token
|
|
retrievedToken := session.GetAccessToken()
|
|
|
|
// Verify integrity
|
|
if retrievedToken != token {
|
|
t.Errorf("Token integrity lost:\nOriginal: %q\nRetrieved: %q", token, retrievedToken)
|
|
}
|
|
|
|
// Check if chunking occurred as expected
|
|
hasChunks := len(session.accessTokenChunks) > 0
|
|
if params.expectChunked != hasChunks {
|
|
t.Errorf("Chunking expectation mismatch: expected chunked=%v, has chunks=%v",
|
|
params.expectChunked, hasChunks)
|
|
}
|
|
|
|
session.ReturnToPool()
|
|
})
|
|
}
|
|
|
|
_ = edgeGen
|
|
_ = runner
|
|
}
|
|
|
|
// TestTokenChunkingCorruptionResistance tests chunking corruption resistance using table patterns
|
|
func TestTokenChunkingCorruptionResistance(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeExtended) {
|
|
return
|
|
}
|
|
|
|
testTokens := NewTestTokens()
|
|
runner := NewTestSuiteRunner()
|
|
|
|
// Define corruption scenarios as test cases
|
|
corruptionTests := []TableTestCase{
|
|
{
|
|
Name: "Missing chunk in sequence",
|
|
Description: "Test handling when a chunk is missing from sequence",
|
|
Input: func(chunks map[int]*sessions.Session) {
|
|
if len(chunks) > 1 {
|
|
delete(chunks, 1)
|
|
}
|
|
},
|
|
Expected: true, // Expect empty result
|
|
},
|
|
{
|
|
Name: "Empty chunk data",
|
|
Description: "Test handling when chunk contains empty data",
|
|
Input: func(chunks map[int]*sessions.Session) {
|
|
if chunk, exists := chunks[0]; exists {
|
|
chunk.Values["token_chunk"] = ""
|
|
}
|
|
},
|
|
Expected: true,
|
|
},
|
|
{
|
|
Name: "Wrong data type in chunk",
|
|
Description: "Test handling when chunk contains wrong data type",
|
|
Input: func(chunks map[int]*sessions.Session) {
|
|
if chunk, exists := chunks[0]; exists {
|
|
chunk.Values["token_chunk"] = 123 // Should be string
|
|
}
|
|
},
|
|
Expected: true,
|
|
},
|
|
{
|
|
Name: "Oversized chunk",
|
|
Description: "Test handling when chunk exceeds size limits",
|
|
Input: func(chunks map[int]*sessions.Session) {
|
|
if chunk, exists := chunks[0]; exists {
|
|
chunk.Values["token_chunk"] = strings.Repeat("A", maxCookieSize+200)
|
|
}
|
|
},
|
|
Expected: true,
|
|
},
|
|
}
|
|
|
|
for _, test := range corruptionTests {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
// Create a large token that will be chunked
|
|
largeToken := testTokens.CreateIncompressibleToken(8000)
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Store the token (this should create chunks)
|
|
session.SetAccessToken(largeToken)
|
|
if len(session.accessTokenChunks) == 0 {
|
|
t.Skip("Token was not chunked, skipping corruption test")
|
|
}
|
|
|
|
// Apply corruption using the test input function
|
|
corruptFunc := test.Input.(func(map[int]*sessions.Session))
|
|
corruptFunc(session.accessTokenChunks)
|
|
|
|
// Try to retrieve the token
|
|
retrievedToken := session.GetAccessToken()
|
|
|
|
expectEmpty := test.Expected.(bool)
|
|
if expectEmpty {
|
|
if retrievedToken != "" {
|
|
t.Errorf("Expected empty token due to corruption, got: %q", retrievedToken)
|
|
}
|
|
} else {
|
|
if retrievedToken != largeToken {
|
|
t.Errorf("Expected original token despite corruption, got: %q", retrievedToken)
|
|
}
|
|
}
|
|
|
|
session.ReturnToPool()
|
|
})
|
|
}
|
|
|
|
_ = corruptionTests
|
|
_ = runner
|
|
}
|
|
|
|
// TestTokenSizeLimits tests token size limit enforcement using table-driven tests
|
|
func TestTokenSizeLimits(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeExtended) {
|
|
return
|
|
}
|
|
|
|
testTokens := NewTestTokens()
|
|
edgeGen := NewEdgeCaseGenerator()
|
|
runner := NewTestSuiteRunner()
|
|
|
|
tests := []TableTestCase{
|
|
{
|
|
Name: "Normal size token",
|
|
Input: 1000,
|
|
Expected: true,
|
|
},
|
|
{
|
|
Name: "Large but acceptable token",
|
|
Input: 20000, // 20KB
|
|
Expected: true,
|
|
},
|
|
{
|
|
Name: "Oversized token rejection",
|
|
Input: 120000, // 120KB
|
|
Expected: false, // Should be rejected
|
|
},
|
|
}
|
|
|
|
// Add integer edge cases for token sizes
|
|
intEdgeCases := edgeGen.GenerateIntegerEdgeCases()
|
|
for _, size := range intEdgeCases {
|
|
if size > 0 && size < 100000 {
|
|
tests = append(tests, TableTestCase{
|
|
Name: fmt.Sprintf("Edge case size %d", size),
|
|
Input: size,
|
|
Expected: size < 100000, // Reasonable threshold
|
|
})
|
|
}
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
tokenSize := test.Input.(int)
|
|
expectStored := test.Expected.(bool)
|
|
|
|
var token string
|
|
if expectStored {
|
|
token = testTokens.CreateLargeValidJWT(tokenSize)
|
|
} else {
|
|
token = testTokens.CreateIncompressibleToken(tokenSize)
|
|
}
|
|
|
|
// Store the token
|
|
session.SetAccessToken(token)
|
|
|
|
// Try to retrieve it
|
|
retrievedToken := session.GetAccessToken()
|
|
|
|
if expectStored {
|
|
if retrievedToken != token {
|
|
t.Errorf("Expected token to be stored and retrieved, but got different token")
|
|
}
|
|
} else {
|
|
if retrievedToken == token {
|
|
t.Errorf("Expected oversized token to be rejected, but it was stored")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
_ = runner
|
|
}
|
|
|
|
// TestConcurrentTokenOperations tests thread safety using structured test patterns
|
|
func TestConcurrentTokenOperations(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeConcurrencyStress) {
|
|
return
|
|
}
|
|
|
|
testTokens := NewTestTokens()
|
|
runner := NewTestSuiteRunner()
|
|
|
|
// Test concurrent operations using memory leak test pattern
|
|
memoryTests := []MemoryLeakTestCase{
|
|
{
|
|
Name: "Concurrent token operations",
|
|
Description: "Test thread safety of concurrent token operations",
|
|
Iterations: 50,
|
|
MaxGoroutineGrowth: 5, // Allow some growth for goroutines
|
|
MaxMemoryGrowthMB: 10.0,
|
|
GCBetweenRuns: true,
|
|
Timeout: 60 * time.Second,
|
|
Operation: func() error {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create session manager: %w", err)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get session: %w", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
const numGoroutines = 10
|
|
const numOperations = 100
|
|
done := make(chan bool, numGoroutines)
|
|
|
|
for i := 0; i < numGoroutines; i++ {
|
|
go func(id int) {
|
|
defer func() { done <- true }()
|
|
|
|
for j := 0; j < numOperations; j++ {
|
|
// Create unique tokens for each goroutine/operation
|
|
accessToken := testTokens.CreateUniqueValidJWT(fmt.Sprintf("%d_%d", id, j))
|
|
refreshToken := fmt.Sprintf("refresh_token_%d_%d", id, j)
|
|
|
|
// Concurrent operations
|
|
session.SetAccessToken(accessToken)
|
|
session.SetRefreshToken(refreshToken)
|
|
|
|
retrievedAccess := session.GetAccessToken()
|
|
retrievedRefresh := session.GetRefreshToken()
|
|
|
|
// Verify tokens are still valid (should be one of the tokens set by any goroutine)
|
|
if retrievedAccess != "" && strings.Count(retrievedAccess, ".") != 2 {
|
|
// Note: In concurrent access, we can't guarantee exact token match
|
|
// but we can verify format is still valid
|
|
}
|
|
if retrievedRefresh != "" && len(retrievedRefresh) < 10 {
|
|
// Verify minimum reasonable length
|
|
}
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
// Wait for all goroutines to complete
|
|
for i := 0; i < numGoroutines; i++ {
|
|
<-done
|
|
}
|
|
|
|
return nil
|
|
},
|
|
},
|
|
}
|
|
|
|
runner.RunMemoryLeakTests(t, memoryTests)
|
|
|
|
_ = testTokens
|
|
}
|
|
|
|
// TestSessionValidationAndCleanup tests session validation using comprehensive patterns
|
|
func TestSessionValidationAndCleanup(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeExtended) {
|
|
return
|
|
}
|
|
|
|
testTokens := NewTestTokens()
|
|
edgeGen := NewEdgeCaseGenerator()
|
|
runner := NewTestSuiteRunner()
|
|
|
|
tests := []TableTestCase{
|
|
{
|
|
Name: "Session creation and token storage",
|
|
Description: "Test basic session validation and cleanup",
|
|
},
|
|
{
|
|
Name: "Large token chunking validation",
|
|
Description: "Test validation with tokens that require chunking",
|
|
},
|
|
{
|
|
Name: "Session cleanup verification",
|
|
Description: "Test that sessions are properly cleaned up",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
switch test.Name {
|
|
case "Session creation and token storage":
|
|
// Test with normal tokens
|
|
tokenSet := testTokens.GetValidTokenSet()
|
|
session.SetAccessToken(tokenSet.AccessToken)
|
|
session.SetRefreshToken(tokenSet.RefreshToken)
|
|
|
|
case "Large token chunking validation":
|
|
// Set tokens that will create chunks
|
|
largeTokenSet := testTokens.GetLargeTokenSet()
|
|
session.SetAccessToken(largeTokenSet.AccessToken)
|
|
session.SetRefreshToken(largeTokenSet.RefreshToken)
|
|
|
|
case "Session cleanup verification":
|
|
// Set tokens and then clear them
|
|
session.SetAccessToken(testTokens.GetValidTokenSet().AccessToken)
|
|
session.SetRefreshToken("refresh_token_test")
|
|
}
|
|
|
|
// Save session to create cookies
|
|
if err := session.Save(req, rw); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// For cleanup test, verify clearing works
|
|
if test.Name == "Session cleanup verification" {
|
|
if err := session.Clear(req, rw); err != nil {
|
|
t.Logf("Clear returned error (may be expected): %v", err)
|
|
}
|
|
|
|
// Verify tokens are cleared
|
|
if token := session.GetAccessToken(); token != "" {
|
|
t.Errorf("Access token should be empty after clear, got: %q", token)
|
|
}
|
|
if token := session.GetRefreshToken(); token != "" {
|
|
t.Errorf("Refresh token should be empty after clear, got: %q", token)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
_ = edgeGen
|
|
_ = runner
|
|
}
|
|
|
|
// TestLargeIDTokenChunking tests ID token chunking using structured approach
|
|
func TestLargeIDTokenChunking(t *testing.T) {
|
|
config := GetTestConfig()
|
|
if config.ShouldSkipTest(t, TestTypeExtended) {
|
|
return
|
|
}
|
|
|
|
runner := NewTestSuiteRunner()
|
|
|
|
tests := []TableTestCase{
|
|
{
|
|
Name: "Large ID token chunking 20KB",
|
|
Description: "Test that large ID tokens are properly chunked",
|
|
Input: 20000,
|
|
Expected: 2, // Expect at least 2 chunks
|
|
},
|
|
{
|
|
Name: "Very large ID token chunking 50KB",
|
|
Description: "Test very large ID token chunking",
|
|
Input: 50000,
|
|
Expected: 5, // Expect at least 5 chunks
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.Name, func(t *testing.T) {
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
tokenSize := test.Input.(int)
|
|
minExpectedChunks := test.Expected.(int)
|
|
|
|
// Create a large ID token
|
|
largeIDToken := createLargeIDToken(tokenSize)
|
|
t.Logf("Created large ID token with length: %d", len(largeIDToken))
|
|
|
|
// Create a request and response recorder
|
|
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Get session and set large ID token
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Set the large ID token
|
|
session.SetIDToken(largeIDToken)
|
|
t.Logf("Set large ID token in session")
|
|
|
|
// Save the session to trigger chunking
|
|
err = session.Save(req, rr)
|
|
if err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Verify token retrieval integrity
|
|
retrievedToken := session.GetIDToken()
|
|
t.Logf("Retrieved ID token length: %d", len(retrievedToken))
|
|
if len(retrievedToken) != len(largeIDToken) {
|
|
t.Errorf("Token length mismatch: expected %d, got %d", len(largeIDToken), len(retrievedToken))
|
|
}
|
|
|
|
// Verify that chunked cookies were created
|
|
cookies := rr.Result().Cookies()
|
|
t.Logf("Total cookies in response: %d", len(cookies))
|
|
|
|
var chunkCookies []*http.Cookie
|
|
idTokenCookieName := defaultCookiePrefix + idTokenSuffix
|
|
for _, cookie := range cookies {
|
|
if strings.HasPrefix(cookie.Name, idTokenCookieName+"_") {
|
|
chunkCookies = append(chunkCookies, cookie)
|
|
}
|
|
}
|
|
|
|
// Verify minimum expected chunks
|
|
if len(chunkCookies) < minExpectedChunks {
|
|
t.Fatalf("Expected at least %d chunk cookies, got %d", minExpectedChunks, len(chunkCookies))
|
|
}
|
|
|
|
// Test token retrieval from chunked cookies
|
|
newReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
|
for _, cookie := range cookies {
|
|
newReq.AddCookie(cookie)
|
|
}
|
|
|
|
retrievedSession, err := sm.GetSession(newReq)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session from chunked cookies: %v", err)
|
|
}
|
|
|
|
retrievedToken2 := retrievedSession.GetIDToken()
|
|
|
|
// Verify the retrieved token matches the original
|
|
if retrievedToken2 != largeIDToken {
|
|
t.Errorf("Retrieved ID token doesn't match original. Expected length: %d, got: %d",
|
|
len(largeIDToken), len(retrievedToken2))
|
|
}
|
|
|
|
// Test clearing the ID token removes all chunks
|
|
retrievedSession.SetIDToken("")
|
|
|
|
clearRR := httptest.NewRecorder()
|
|
err = retrievedSession.Save(newReq, clearRR)
|
|
if err != nil {
|
|
t.Fatalf("Failed to save session after clearing ID token: %v", err)
|
|
}
|
|
|
|
// Verify chunks are expired (MaxAge = -1)
|
|
clearCookies := clearRR.Result().Cookies()
|
|
idTokenCookieName2 := defaultCookiePrefix + idTokenSuffix
|
|
for _, cookie := range clearCookies {
|
|
if strings.HasPrefix(cookie.Name, idTokenCookieName2+"_") {
|
|
if cookie.MaxAge != -1 {
|
|
t.Errorf("Expected chunk cookie %s to be expired (MaxAge=-1), got MaxAge=%d",
|
|
cookie.Name, cookie.MaxAge)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
_ = runner
|
|
}
|
|
|
|
// ============================================================================
|
|
// CONSOLIDATED SESSION TESTS
|
|
// ============================================================================
|
|
|
|
// TestSessionConsolidated runs all consolidated session tests
|
|
func TestSessionConsolidated(t *testing.T) {
|
|
testCases := []SessionTestCase{
|
|
// Session Creation Tests
|
|
{
|
|
name: "session_basic_creation",
|
|
scenario: "creation",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
// Simulate session creation
|
|
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
|
f.requests = append(f.requests, req)
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Session creation should succeed")
|
|
assert.Greater(t, f.metrics.SessionsCreated, int64(0), "Session should be created")
|
|
},
|
|
},
|
|
{
|
|
name: "session_pool_reuse",
|
|
scenario: "creation",
|
|
sessionType: "user",
|
|
iterations: 100,
|
|
execute: func(f *SessionTestFramework) error {
|
|
for i := 0; i < 100; i++ {
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed, "Sessions should be properly pooled")
|
|
},
|
|
},
|
|
{
|
|
name: "session_concurrent_creation",
|
|
scenario: "creation",
|
|
sessionType: "user",
|
|
concurrent: true,
|
|
iterations: 50,
|
|
execute: func(f *SessionTestFramework) error {
|
|
var wg sync.WaitGroup
|
|
errs := make(chan error, 50)
|
|
|
|
for i := 0; i < 50; i++ {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
// Simulate concurrent session creation
|
|
req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%d", id), nil)
|
|
f.mu.Lock()
|
|
f.requests = append(f.requests, req)
|
|
f.mu.Unlock()
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errs)
|
|
|
|
for err := range errs {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int64(50), f.metrics.SessionsCreated, "All concurrent sessions should be created")
|
|
},
|
|
},
|
|
|
|
// Session Validation Tests
|
|
{
|
|
name: "session_token_validation",
|
|
scenario: "validation",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
token := f.generateTestToken("access", 3600)
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
// Validate token format
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) != 3 {
|
|
return fmt.Errorf("invalid token format")
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Token validation should succeed")
|
|
assert.Greater(t, f.metrics.TokensValidated, int64(0))
|
|
},
|
|
},
|
|
{
|
|
name: "session_corrupted_token_detection",
|
|
scenario: "validation",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
token := f.generateTestToken("access", 3600)
|
|
// Corrupt the token by modifying the signature
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) != 3 {
|
|
return fmt.Errorf("invalid token format")
|
|
}
|
|
|
|
// Corrupt the signature part
|
|
corrupted := parts[0] + "." + parts[1] + ".corrupted!"
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
// Validate should detect corruption - corrupted tokens should fail validation
|
|
corruptedParts := strings.Split(corrupted, ".")
|
|
if len(corruptedParts) == 3 {
|
|
// Try to decode the corrupted signature
|
|
_, err := base64.RawURLEncoding.DecodeString(corruptedParts[2])
|
|
if err == nil {
|
|
return fmt.Errorf("corruption not detected")
|
|
}
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Corruption detection should work")
|
|
},
|
|
},
|
|
{
|
|
name: "session_expired_token_handling",
|
|
scenario: "validation",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Generate an expired token
|
|
token := f.generateTestToken("access", -3600) // negative expiry
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
// Parse and check expiry
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) == 3 {
|
|
payload, _ := base64.RawURLEncoding.DecodeString(parts[1])
|
|
var claims map[string]interface{}
|
|
json.Unmarshal(payload, &claims)
|
|
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
if exp < float64(time.Now().Unix()) {
|
|
atomic.AddInt64(&f.metrics.ErrorCount, 1)
|
|
return nil // Expected behavior
|
|
}
|
|
}
|
|
}
|
|
return fmt.Errorf("expired token not detected")
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Expired token should be detected")
|
|
assert.Greater(t, f.metrics.ErrorCount, int64(0))
|
|
},
|
|
},
|
|
|
|
// Session Expiration Tests
|
|
{
|
|
name: "session_ttl_expiration",
|
|
scenario: "expiration",
|
|
sessionType: "user",
|
|
timeout: 3 * time.Second,
|
|
execute: func(f *SessionTestFramework) error {
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
// Simulate session with short TTL
|
|
time.Sleep(100 * time.Millisecond) // Don't sleep for full timeout
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed)
|
|
},
|
|
},
|
|
{
|
|
name: "session_refresh_token_expiry",
|
|
scenario: "expiration",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
refreshToken := f.generateTestToken("refresh", 86400)
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
// Check refresh token is valid for longer period
|
|
parts := strings.Split(refreshToken, ".")
|
|
if len(parts) == 3 {
|
|
payload, _ := base64.RawURLEncoding.DecodeString(parts[1])
|
|
var claims map[string]interface{}
|
|
json.Unmarshal(payload, &claims)
|
|
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
timeUntilExpiry := time.Until(time.Unix(int64(exp), 0))
|
|
if timeUntilExpiry < 23*time.Hour {
|
|
return fmt.Errorf("refresh token expiry too short: %v", timeUntilExpiry)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Refresh token should have correct expiry")
|
|
},
|
|
},
|
|
|
|
// Session Persistence Tests
|
|
{
|
|
name: "session_cookie_persistence",
|
|
scenario: "persistence",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
// Set session cookie
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "session_id",
|
|
Value: "test-session-123",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: f.config.EnableHTTPS,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
f.requests = append(f.requests, req)
|
|
f.responses = append(f.responses, w)
|
|
|
|
// Verify cookie was set
|
|
cookies := w.Result().Cookies()
|
|
if len(cookies) == 0 {
|
|
return fmt.Errorf("no cookies set")
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, f.responses, "Response should be recorded")
|
|
},
|
|
},
|
|
{
|
|
name: "session_state_preservation",
|
|
scenario: "persistence",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Store state
|
|
state := map[string]interface{}{
|
|
"user_id": "test-user",
|
|
"email": "test@example.com",
|
|
"roles": []string{"user", "admin"},
|
|
}
|
|
|
|
// Serialize and deserialize to test persistence
|
|
data, err := json.Marshal(state)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var restored map[string]interface{}
|
|
if err := json.Unmarshal(data, &restored); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Verify state preserved
|
|
if restored["user_id"] != state["user_id"] {
|
|
return fmt.Errorf("state not preserved")
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Session state should be preserved")
|
|
},
|
|
},
|
|
|
|
// Session Cleanup Tests
|
|
{
|
|
name: "session_proper_cleanup",
|
|
scenario: "cleanup",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Create and destroy sessions
|
|
for i := 0; i < 10; i++ {
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
sessionID := fmt.Sprintf("session-%d", i)
|
|
f.sessionIDs = append(f.sessionIDs, sessionID)
|
|
}
|
|
|
|
// Cleanup all sessions
|
|
for range f.sessionIDs {
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
}
|
|
f.sessionIDs = nil
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed)
|
|
assert.Empty(t, f.sessionIDs, "All sessions should be cleaned up")
|
|
},
|
|
},
|
|
{
|
|
name: "session_goroutine_leak_prevention",
|
|
scenario: "cleanup",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
initialGoroutines := runtime.NumGoroutine()
|
|
|
|
// Create sessions that might spawn goroutines
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
time.Sleep(10 * time.Millisecond)
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
runtime.GC()
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
finalGoroutines := runtime.NumGoroutine()
|
|
if finalGoroutines > initialGoroutines+2 { // Allow small variance
|
|
return fmt.Errorf("goroutine leak detected: %d -> %d", initialGoroutines, finalGoroutines)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "No goroutine leaks should occur")
|
|
},
|
|
},
|
|
|
|
// Session Chunking Tests
|
|
{
|
|
name: "session_large_token_chunking",
|
|
scenario: "chunking",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Generate a large token that requires chunking
|
|
largeToken := f.generateLargeToken(10000) // 10KB token
|
|
|
|
// Calculate expected chunks
|
|
chunkSize := f.config.MaxChunkSize
|
|
expectedChunks := (len(largeToken) + chunkSize - 1) / chunkSize
|
|
|
|
// Simulate chunking
|
|
chunks := make([]string, 0)
|
|
for i := 0; i < len(largeToken); i += chunkSize {
|
|
end := i + chunkSize
|
|
if end > len(largeToken) {
|
|
end = len(largeToken)
|
|
}
|
|
chunks = append(chunks, largeToken[i:end])
|
|
atomic.AddInt64(&f.metrics.ChunksCreated, 1)
|
|
}
|
|
|
|
if len(chunks) != expectedChunks {
|
|
return fmt.Errorf("expected %d chunks, got %d", expectedChunks, len(chunks))
|
|
}
|
|
|
|
// Simulate reconstruction
|
|
reconstructed := strings.Join(chunks, "")
|
|
if reconstructed != largeToken {
|
|
return fmt.Errorf("token reconstruction failed")
|
|
}
|
|
atomic.AddInt64(&f.metrics.ChunksRetrieved, int64(len(chunks)))
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Token chunking should work correctly")
|
|
assert.Greater(t, f.metrics.ChunksCreated, int64(0))
|
|
assert.Equal(t, f.metrics.ChunksCreated, f.metrics.ChunksRetrieved)
|
|
},
|
|
},
|
|
{
|
|
name: "session_chunk_boundary_validation",
|
|
scenario: "chunking",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Test exact boundary conditions
|
|
testSizes := []int{
|
|
f.config.MaxChunkSize - 1,
|
|
f.config.MaxChunkSize,
|
|
f.config.MaxChunkSize + 1,
|
|
f.config.MaxChunkSize * 2,
|
|
f.config.MaxChunkSize*2 - 1,
|
|
f.config.MaxChunkSize*2 + 1,
|
|
}
|
|
|
|
for _, size := range testSizes {
|
|
token := f.generateLargeToken(size)
|
|
actualSize := len(token)
|
|
expectedChunks := (actualSize + f.config.MaxChunkSize - 1) / f.config.MaxChunkSize
|
|
|
|
actualChunks := 0
|
|
for i := 0; i < len(token); i += f.config.MaxChunkSize {
|
|
actualChunks++
|
|
atomic.AddInt64(&f.metrics.ChunksCreated, 1)
|
|
}
|
|
|
|
if actualChunks != expectedChunks {
|
|
return fmt.Errorf("size %d (actual token size %d): expected %d chunks, got %d", size, actualSize, expectedChunks, actualChunks)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Chunk boundaries should be handled correctly")
|
|
},
|
|
},
|
|
|
|
// Session Security Tests
|
|
{
|
|
name: "session_csrf_token_management",
|
|
scenario: "security",
|
|
sessionType: "csrf",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Generate CSRF token
|
|
csrfToken := make([]byte, 32)
|
|
if _, err := rand.Read(csrfToken); err != nil {
|
|
return err
|
|
}
|
|
|
|
csrfString := base64.RawURLEncoding.EncodeToString(csrfToken)
|
|
|
|
// Store in session
|
|
f.testTokens["csrf"] = csrfString
|
|
|
|
// Validate CSRF token
|
|
if len(csrfString) < 40 {
|
|
return fmt.Errorf("CSRF token too short")
|
|
}
|
|
|
|
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "CSRF token should be properly managed")
|
|
assert.NotEmpty(t, f.testTokens["csrf"])
|
|
},
|
|
},
|
|
{
|
|
name: "session_injection_prevention",
|
|
scenario: "security",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Test various injection attempts
|
|
maliciousInputs := []string{
|
|
`{"admin": true}`,
|
|
`<script>alert('xss')</script>`,
|
|
`'; DROP TABLE sessions; --`,
|
|
`../../../etc/passwd`,
|
|
string([]byte{0x00, 0x01, 0x02}), // null bytes
|
|
}
|
|
|
|
for _, input := range maliciousInputs {
|
|
// Validate that input is properly sanitized
|
|
sanitized := base64.StdEncoding.EncodeToString([]byte(input))
|
|
decoded, err := base64.StdEncoding.DecodeString(sanitized)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if string(decoded) != input {
|
|
return fmt.Errorf("sanitization changed input unexpectedly")
|
|
}
|
|
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Injection attempts should be handled safely")
|
|
},
|
|
},
|
|
{
|
|
name: "session_secure_cookie_settings",
|
|
scenario: "security",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
w := httptest.NewRecorder()
|
|
|
|
// Test secure cookie settings
|
|
cookie := &http.Cookie{
|
|
Name: "session",
|
|
Value: "test-session",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteStrictMode,
|
|
MaxAge: 3600,
|
|
}
|
|
|
|
http.SetCookie(w, cookie)
|
|
|
|
// Verify cookie attributes
|
|
cookies := w.Result().Cookies()
|
|
if len(cookies) == 0 {
|
|
return fmt.Errorf("no cookie set")
|
|
}
|
|
|
|
c := cookies[0]
|
|
if !c.HttpOnly {
|
|
return fmt.Errorf("cookie not HttpOnly")
|
|
}
|
|
if c.SameSite != http.SameSiteStrictMode {
|
|
return fmt.Errorf("incorrect SameSite setting")
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Secure cookie settings should be enforced")
|
|
},
|
|
},
|
|
|
|
// Session Stress Tests
|
|
{
|
|
name: "session_high_concurrency_stress",
|
|
scenario: "creation",
|
|
sessionType: "user",
|
|
concurrent: true,
|
|
iterations: 1000,
|
|
timeout: 30 * time.Second,
|
|
execute: func(f *SessionTestFramework) error {
|
|
var wg sync.WaitGroup
|
|
errors := make([]error, 0)
|
|
|
|
// Run high concurrency test
|
|
concurrency := 100
|
|
iterations := 10
|
|
|
|
for i := 0; i < concurrency; i++ {
|
|
wg.Add(1)
|
|
go func(workerID int) {
|
|
defer wg.Done()
|
|
|
|
for j := 0; j < iterations; j++ {
|
|
// Create session
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
|
|
// Generate tokens
|
|
f.generateTestToken("access", 3600)
|
|
f.generateTestToken("refresh", 86400)
|
|
|
|
// Validate tokens
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 2)
|
|
|
|
// Cleanup session
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
|
|
// Small delay to simulate real usage
|
|
time.Sleep(time.Millisecond)
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
if len(errors) > 0 {
|
|
return errors[0]
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "High concurrency stress test should pass")
|
|
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed, "All sessions should be cleaned up")
|
|
},
|
|
},
|
|
{
|
|
name: "session_memory_bounds_enforcement",
|
|
scenario: "cleanup",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
maxSessions := f.config.MaxSessions
|
|
|
|
// Try to create more sessions than allowed
|
|
for i := 0; i < maxSessions+100; i++ {
|
|
sessionID := fmt.Sprintf("session-%d", i)
|
|
f.sessionIDs = append(f.sessionIDs, sessionID)
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
|
|
// Enforce max sessions
|
|
if len(f.sessionIDs) > maxSessions {
|
|
// Remove oldest session
|
|
f.sessionIDs = f.sessionIDs[1:]
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
}
|
|
}
|
|
|
|
if len(f.sessionIDs) > maxSessions {
|
|
return fmt.Errorf("max sessions exceeded: %d > %d", len(f.sessionIDs), maxSessions)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Memory bounds should be enforced")
|
|
assert.LessOrEqual(t, len(f.sessionIDs), f.config.MaxSessions)
|
|
},
|
|
},
|
|
}
|
|
|
|
// Run all test cases
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if tc.skipReason != "" {
|
|
t.Skip(tc.skipReason)
|
|
}
|
|
|
|
framework := NewSessionTestFramework(t)
|
|
defer framework.Cleanup()
|
|
|
|
// Setup
|
|
if tc.setup != nil {
|
|
tc.setup(framework)
|
|
}
|
|
|
|
// Cleanup
|
|
if tc.cleanup != nil {
|
|
defer tc.cleanup(framework)
|
|
}
|
|
|
|
// Set timeout if specified
|
|
if tc.timeout > 0 {
|
|
timer := time.NewTimer(tc.timeout)
|
|
done := make(chan bool)
|
|
|
|
go func() {
|
|
err := tc.execute(framework)
|
|
tc.validate(t, err, framework)
|
|
done <- true
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
timer.Stop()
|
|
case <-timer.C:
|
|
t.Fatal("Test timeout exceeded")
|
|
}
|
|
} else {
|
|
// Execute test
|
|
err := tc.execute(framework)
|
|
|
|
// Validate results
|
|
tc.validate(t, err, framework)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// SESSION STATE PRESERVATION TESTS (6-HOUR TOKEN EXPIRY SCENARIOS)
|
|
// ============================================================================
|
|
|
|
// TestSessionStatePreservationWithExpiredTokens tests that session state is preserved
|
|
// during token expiry scenarios
|
|
func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
|
t.Log("Testing session state preservation with expired tokens")
|
|
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("test-session-key-32-bytes-long-12345", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
// Simulate real-world session data that should be preserved
|
|
originalUserData := map[string]interface{}{
|
|
"user_id": "user-12345",
|
|
"email": "test.user@company.com",
|
|
"name": "Test User",
|
|
"roles": []string{"admin", "user"},
|
|
"pref_theme": "dark",
|
|
"pref_lang": "en",
|
|
"last_active": "2023-01-01T10:00:00Z",
|
|
}
|
|
|
|
// Create initial session with valid tokens
|
|
req1 := httptest.NewRequest("GET", "/initial", nil)
|
|
rr1 := httptest.NewRecorder()
|
|
|
|
session1, err := sm.GetSession(req1)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get initial session: %v", err)
|
|
}
|
|
|
|
// Set up initial session state (what user has when first logging in)
|
|
session1.SetAuthenticated(true)
|
|
session1.SetUserIdentifier(originalUserData["email"].(string))
|
|
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
|
|
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
|
|
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
|
|
|
|
// Store additional user data in session - store individual values instead of map
|
|
for k, v := range originalUserData {
|
|
session1.mainSession.Values["user_data_"+k] = v
|
|
}
|
|
session1.mainSession.Values["session_created"] = time.Now().Unix() // Store as int64 for gob
|
|
session1.mainSession.Values["custom_flag"] = true
|
|
|
|
if err := session1.Save(req1, rr1); err != nil {
|
|
t.Fatalf("Failed to save initial session: %v", err)
|
|
}
|
|
|
|
initialCookies := rr1.Result().Cookies()
|
|
session1.ReturnToPool()
|
|
|
|
t.Log("Initial session created with user data")
|
|
|
|
// Fast-forward 6 hours - tokens expire due to browser inactivity
|
|
time.Sleep(10 * time.Millisecond) // Simulate time passage in test
|
|
|
|
// Create expired tokens (simulating what happens after 6 hours)
|
|
expiredTime := time.Now().Add(-6 * time.Hour)
|
|
expiredAccessToken := createExpiredJWTToken("user-12345", "test.user@company.com", expiredTime)
|
|
expiredIDToken := createExpiredJWTToken("user-12345", "test.user@company.com", expiredTime)
|
|
|
|
// User returns after inactivity and makes a request
|
|
req2 := httptest.NewRequest("GET", "/protected-resource", nil)
|
|
for _, cookie := range initialCookies {
|
|
req2.AddCookie(cookie)
|
|
}
|
|
|
|
session2, err := sm.GetSession(req2)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session after 6 hours: %v", err)
|
|
}
|
|
defer session2.ReturnToPool()
|
|
|
|
// Simulate what happens when middleware detects expired tokens
|
|
// It should preserve session state while attempting token refresh
|
|
originalAuth := session2.GetAuthenticated()
|
|
originalEmail := session2.GetUserIdentifier()
|
|
|
|
// Reconstruct user data from individual stored keys
|
|
originalUserDataStored := make(map[string]interface{})
|
|
for k := range originalUserData {
|
|
if storedValue, exists := session2.mainSession.Values["user_data_"+k]; exists {
|
|
originalUserDataStored[k] = storedValue
|
|
}
|
|
}
|
|
|
|
// Update session with expired tokens (what middleware does when tokens expire)
|
|
session2.SetAccessToken(expiredAccessToken)
|
|
session2.SetIDToken(expiredIDToken)
|
|
// Refresh token should still be valid
|
|
|
|
t.Log("Session loaded after 6-hour expiry, checking state preservation")
|
|
|
|
// Verify authentication state is preserved
|
|
if !originalAuth {
|
|
t.Error("Authentication state lost during session reload")
|
|
}
|
|
|
|
// Verify email is preserved
|
|
if originalEmail != originalUserData["email"].(string) {
|
|
t.Errorf("User email lost during session reload - Expected: %s, Got: %s",
|
|
originalUserData["email"], originalEmail)
|
|
}
|
|
|
|
// Verify custom user data is preserved
|
|
if len(originalUserDataStored) == 0 {
|
|
t.Error("All custom user data lost during session reload")
|
|
} else {
|
|
if originalUserDataStored["user_id"] != originalUserData["user_id"] {
|
|
t.Error("User ID lost from session data")
|
|
}
|
|
|
|
if originalUserDataStored["name"] != originalUserData["name"] {
|
|
t.Error("User name lost from session data")
|
|
}
|
|
|
|
if originalUserDataStored["pref_theme"] != originalUserData["pref_theme"] {
|
|
t.Error("User theme preference lost from session data")
|
|
}
|
|
|
|
if originalUserDataStored["pref_lang"] != originalUserData["pref_lang"] {
|
|
t.Error("User language preference lost from session data")
|
|
}
|
|
}
|
|
|
|
// Note: System may reject invalid/expired tokens during storage, which is acceptable behavior
|
|
currentAccessToken := session2.GetAccessToken()
|
|
if currentAccessToken != expiredAccessToken {
|
|
t.Logf("INFO: Access token was not stored (possibly rejected due to expiry) - Expected: %s, Got: %s",
|
|
expiredAccessToken, currentAccessToken)
|
|
}
|
|
|
|
// Verify that session can be saved again after token expiry without losing data
|
|
rr2 := httptest.NewRecorder()
|
|
if err := session2.Save(req2, rr2); err != nil {
|
|
t.Errorf("Cannot save session after token expiry: %v", err)
|
|
} else {
|
|
t.Log("Session successfully saved after token expiry")
|
|
|
|
// Verify cookies are still set
|
|
newCookies := rr2.Result().Cookies()
|
|
if len(newCookies) == 0 {
|
|
t.Error("No session cookies set after saving expired token session")
|
|
}
|
|
}
|
|
|
|
// Test session recovery after token refresh simulation
|
|
newAccessToken := "refreshed-access-token-longer-than-20-chars"
|
|
newIDToken := "refreshed-id-token-longer-than-20-chars"
|
|
newRefreshToken := "new-refresh-token-after-successful-renewal"
|
|
|
|
session2.SetAccessToken(newAccessToken)
|
|
session2.SetIDToken(newIDToken)
|
|
session2.SetRefreshToken(newRefreshToken)
|
|
|
|
// Verify all session data is still intact after token refresh
|
|
postRefreshAuth := session2.GetAuthenticated()
|
|
postRefreshEmail := session2.GetUserIdentifier()
|
|
userDataPresent := true
|
|
for k := range originalUserData {
|
|
if session2.mainSession.Values["user_data_"+k] == nil {
|
|
userDataPresent = false
|
|
break
|
|
}
|
|
}
|
|
|
|
if !postRefreshAuth {
|
|
t.Error("Authentication state lost after token refresh")
|
|
}
|
|
|
|
if postRefreshEmail != originalUserData["email"].(string) {
|
|
t.Error("User email lost after token refresh")
|
|
}
|
|
|
|
if !userDataPresent {
|
|
t.Error("User data lost after token refresh")
|
|
}
|
|
|
|
t.Log("Session state preservation test completed")
|
|
}
|
|
|
|
// TestSessionExpiryVsTokenExpiry tests the distinction between session expiry and token expiry
|
|
func TestSessionExpiryVsTokenExpiry(t *testing.T) {
|
|
t.Log("Testing session expiry vs token expiry distinction")
|
|
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("session-vs-token-test-key-32-bytes", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
scenarios := []struct {
|
|
name string
|
|
expectedBehavior string
|
|
sessionAge time.Duration
|
|
tokenExpiry time.Duration
|
|
sessionShouldExpire bool
|
|
tokenShouldRefresh bool
|
|
}{
|
|
{
|
|
name: "New session, expired tokens",
|
|
sessionAge: 5 * time.Minute,
|
|
tokenExpiry: -6 * time.Hour,
|
|
expectedBehavior: "Session valid, tokens should refresh",
|
|
sessionShouldExpire: false,
|
|
tokenShouldRefresh: true,
|
|
},
|
|
{
|
|
name: "Old session, valid tokens",
|
|
sessionAge: 25 * time.Hour,
|
|
tokenExpiry: 2 * time.Hour,
|
|
expectedBehavior: "Session expired, redirect to login even with valid tokens",
|
|
sessionShouldExpire: true,
|
|
tokenShouldRefresh: false,
|
|
},
|
|
{
|
|
name: "Both session and tokens expired",
|
|
sessionAge: 25 * time.Hour,
|
|
tokenExpiry: -6 * time.Hour,
|
|
expectedBehavior: "Both expired, clear session and redirect to login",
|
|
sessionShouldExpire: true,
|
|
tokenShouldRefresh: false,
|
|
},
|
|
{
|
|
name: "Recent session, recently expired tokens",
|
|
sessionAge: 30 * time.Minute,
|
|
tokenExpiry: -10 * time.Minute,
|
|
expectedBehavior: "Session valid, tokens recently expired, should refresh",
|
|
sessionShouldExpire: false,
|
|
tokenShouldRefresh: true,
|
|
},
|
|
}
|
|
|
|
for _, scenario := range scenarios {
|
|
t.Run(scenario.name, func(t *testing.T) {
|
|
t.Logf("Testing: %s", scenario.expectedBehavior)
|
|
|
|
// Create session at specific "age"
|
|
sessionCreatedAt := time.Now().Add(-scenario.sessionAge)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
// Set up session with specific creation time
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("test@example.com")
|
|
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
|
|
|
|
// Create tokens with specific expiry
|
|
tokenExpiredAt := time.Now().Add(scenario.tokenExpiry)
|
|
accessToken := createExpiredJWTToken("test-user", "test@example.com", tokenExpiredAt)
|
|
|
|
session.SetAccessToken(accessToken)
|
|
session.SetRefreshToken("test-refresh-token")
|
|
|
|
if err := session.Save(req, rr); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Test session validity check
|
|
isSessionExpired := scenario.sessionAge > absoluteSessionTimeout
|
|
isTokenExpired := scenario.tokenExpiry < 0
|
|
|
|
t.Logf("Session age: %v (expired: %t)", scenario.sessionAge, isSessionExpired)
|
|
t.Logf("Token expiry: %v ago (expired: %t)", -scenario.tokenExpiry, isTokenExpired)
|
|
|
|
if scenario.sessionShouldExpire {
|
|
if isSessionExpired && session.GetAuthenticated() {
|
|
t.Errorf("Session should be expired after %v but is still authenticated", scenario.sessionAge)
|
|
}
|
|
} else {
|
|
if !isSessionExpired && !session.GetAuthenticated() {
|
|
t.Errorf("Session should be valid (age: %v) but shows as not authenticated", scenario.sessionAge)
|
|
}
|
|
}
|
|
|
|
if scenario.tokenShouldRefresh {
|
|
if !isTokenExpired {
|
|
t.Errorf("Test setup error - tokens should be expired but expiry is: %v", scenario.tokenExpiry)
|
|
}
|
|
t.Logf("Should attempt token refresh for scenario: %s", scenario.name)
|
|
} else {
|
|
if isSessionExpired {
|
|
t.Logf("Correctly identified that session is expired - no need to refresh tokens")
|
|
}
|
|
}
|
|
|
|
// Check for critical scenario: confusing session expiry with token expiry
|
|
if !isSessionExpired && isTokenExpired {
|
|
t.Logf("CRITICAL SCENARIO: Valid session (%v old) but expired tokens (%v ago)",
|
|
scenario.sessionAge, -scenario.tokenExpiry)
|
|
t.Logf("Expected: System should refresh tokens and continue session")
|
|
|
|
if scenario.name == "New session, expired tokens" && scenario.tokenExpiry == -6*time.Hour {
|
|
t.Logf("This represents the 6-hour browser inactivity scenario")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestSessionCleanupOnTokenExpiry tests that session cleanup happens correctly
|
|
func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
|
t.Log("Testing session cleanup on token expiry")
|
|
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager("cleanup-test-key-32-bytes-long-123", false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
scenarios := []struct {
|
|
name string
|
|
shouldPreserve []string
|
|
shouldRemove []string
|
|
tokenExpiry time.Duration
|
|
shouldCleanup bool
|
|
}{
|
|
{
|
|
name: "Recently expired tokens - preserve session",
|
|
tokenExpiry: -30 * time.Minute,
|
|
shouldCleanup: false,
|
|
shouldPreserve: []string{"user_data", "preferences", "authentication"},
|
|
shouldRemove: []string{},
|
|
},
|
|
{
|
|
name: "Long expired tokens - cleanup selectively",
|
|
tokenExpiry: -25 * time.Hour,
|
|
shouldCleanup: true,
|
|
shouldPreserve: []string{},
|
|
shouldRemove: []string{"user_data", "preferences", "authentication"},
|
|
},
|
|
{
|
|
name: "6-hour expired tokens - preserve for refresh",
|
|
tokenExpiry: -6 * time.Hour,
|
|
shouldCleanup: false,
|
|
shouldPreserve: []string{"user_data", "preferences", "authentication"},
|
|
shouldRemove: []string{},
|
|
},
|
|
}
|
|
|
|
for _, scenario := range scenarios {
|
|
t.Run(scenario.name, func(t *testing.T) {
|
|
t.Logf("Testing cleanup behavior: %s", scenario.name)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
defer session.ReturnToPool()
|
|
|
|
// Set up session with data that should be preserved or removed
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("cleanup@example.com")
|
|
|
|
session.mainSession.Values["user_data"] = "Test User|user-123"
|
|
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
|
|
session.mainSession.Values["authentication"] = true
|
|
session.mainSession.Values["temp_data"] = "should-be-cleaned"
|
|
|
|
// Set expired tokens
|
|
expiredTime := time.Now().Add(scenario.tokenExpiry)
|
|
expiredToken := createExpiredJWTToken("user-123", "cleanup@example.com", expiredTime)
|
|
session.SetAccessToken(expiredToken)
|
|
session.SetRefreshToken("test-refresh-token")
|
|
|
|
if err := session.Save(req, rr); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Simulate token expiry detection and cleanup logic
|
|
tokenExpired := scenario.tokenExpiry < 0
|
|
sessionTooOld := scenario.tokenExpiry < -absoluteSessionTimeout
|
|
|
|
t.Logf("Token expired: %t, Session too old: %t", tokenExpired, sessionTooOld)
|
|
|
|
// Check current session state before cleanup
|
|
preCleanupAuth := session.GetAuthenticated()
|
|
preCleanupData := session.mainSession.Values["user_data"]
|
|
preCleanupPrefs := session.mainSession.Values["preferences"]
|
|
|
|
if scenario.shouldCleanup {
|
|
if sessionTooOld {
|
|
session.SetAuthenticated(false)
|
|
session.SetUserIdentifier("")
|
|
session.SetAccessToken("")
|
|
session.SetRefreshToken("")
|
|
for key := range session.mainSession.Values {
|
|
delete(session.mainSession.Values, key)
|
|
}
|
|
t.Log("Applied full cleanup for expired session")
|
|
}
|
|
} else {
|
|
t.Log("Preserving session for token refresh")
|
|
}
|
|
|
|
// Check post-cleanup state
|
|
postCleanupAuth := session.GetAuthenticated()
|
|
postCleanupData := session.mainSession.Values["user_data"]
|
|
postCleanupPrefs := session.mainSession.Values["preferences"]
|
|
|
|
// Verify preservation expectations
|
|
for _, item := range scenario.shouldPreserve {
|
|
switch item {
|
|
case "authentication":
|
|
if !postCleanupAuth && preCleanupAuth {
|
|
t.Errorf("Authentication state was cleaned up but should be preserved")
|
|
}
|
|
case "user_data":
|
|
if postCleanupData == nil && preCleanupData != nil {
|
|
t.Errorf("User data was cleaned up but should be preserved")
|
|
}
|
|
case "preferences":
|
|
if postCleanupPrefs == nil && preCleanupPrefs != nil {
|
|
t.Errorf("User preferences were cleaned up but should be preserved")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Verify removal expectations
|
|
for _, item := range scenario.shouldRemove {
|
|
switch item {
|
|
case "authentication":
|
|
if postCleanupAuth && scenario.shouldCleanup {
|
|
t.Errorf("Authentication state not cleaned up when it should be")
|
|
}
|
|
case "user_data":
|
|
if postCleanupData != nil && scenario.shouldCleanup {
|
|
t.Errorf("User data not cleaned up when session is expired")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check the critical 6-hour scenario
|
|
if scenario.tokenExpiry == -6*time.Hour {
|
|
if !postCleanupAuth {
|
|
t.Error("6-hour token expiry caused session cleanup - session should be preserved for token refresh")
|
|
}
|
|
|
|
if postCleanupData == nil {
|
|
t.Error("6-hour token expiry caused user data loss - user data should be preserved during token refresh")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ============================================================================
// HELPER FUNCTIONS
// ============================================================================
|
|
|
|
// Helper function to count objects in the session pool for a given manager
|
|
func getPooledObjects(sm *SessionManager) int {
|
|
var objects []*SessionData
|
|
maxAttempts := 100
|
|
|
|
for i := 0; i < maxAttempts; i++ {
|
|
obj := sm.sessionPool.Get()
|
|
if obj == nil {
|
|
break
|
|
}
|
|
|
|
sessionData, ok := obj.(*SessionData)
|
|
if !ok {
|
|
sm.sessionPool.Put(obj)
|
|
break
|
|
}
|
|
|
|
objects = append(objects, sessionData)
|
|
}
|
|
|
|
count := len(objects)
|
|
|
|
for _, obj := range objects {
|
|
sm.sessionPool.Put(obj)
|
|
}
|
|
|
|
return count
|
|
}
|
|
|
|
// createLargeIDToken creates a JWT-like token for testing whose length scales
// with size.
//
// The token has the form header.payload.signature, where header and signature
// are fixed strings and payload is random base64url text truncated to
// size-len(header)-100 characters. For sizes large enough to hold a payload
// the resulting token is therefore somewhat shorter than size.
//
// Sizes too small to leave room for any payload (including zero and negative
// values) yield a token with an empty payload segment instead of panicking on
// a negative slice bound.
func createLargeIDToken(size int) string {
	const (
		header    = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
		signature = "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
		// reserved is subtracted from size in addition to the header length;
		// presumably headroom for the signature and separators.
		reserved = 100
	)

	// make panics on a negative length, so treat negative sizes as zero.
	if size < 0 {
		size = 0
	}

	randomBytes := make([]byte, size*3/4)
	if _, err := rand.Read(randomBytes); err != nil {
		// Fall back to a deterministic pattern; the content only needs to
		// occupy space, not be unpredictable.
		for i := range randomBytes {
			randomBytes[i] = byte(i % 256)
		}
	}

	encoded := base64.RawURLEncoding.EncodeToString(randomBytes)

	// Clamp the payload budget at zero: for small sizes the raw difference is
	// negative and slicing with it would panic.
	payloadBudget := size - len(header) - reserved
	if payloadBudget < 0 {
		payloadBudget = 0
	}
	if len(encoded) > payloadBudget {
		encoded = encoded[:payloadBudget]
	}

	return header + "." + encoded + "." + signature
}
|
|
|
|
// minInt returns the smaller of a and b. When both are equal that shared
// value is returned.
func minInt(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
|
|
|
// createExpiredJWTToken builds an unsigned, JWT-shaped string for tests.
//
// The payload carries userID as "sub", email as "email", expiredTime as "exp"
// (Unix seconds), an "iat" one hour before expiredTime, and fixed "iss" and
// "aud" values. The header is a fixed pre-encoded segment and the signature
// segment is just the base64url encoding of a placeholder string, so the
// result will not pass real signature verification.
func createExpiredJWTToken(userID, email string, expiredTime time.Time) string {
	const header = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"

	enc := base64.RawURLEncoding
	issuedAt := expiredTime.Add(-1 * time.Hour)

	// Marshal cannot fail here: the map holds only strings and int64 values.
	payload, _ := json.Marshal(map[string]interface{}{
		"iss":   "https://test-provider.com",
		"aud":   "test-client-id",
		"sub":   userID,
		"email": email,
		"iat":   issuedAt.Unix(),
		"exp":   expiredTime.Unix(),
	})

	token := header
	token += "." + enc.EncodeToString(payload)
	token += "." + enc.EncodeToString([]byte("fake-signature-for-testing"))
	return token
}
|
|
|
|
// TestCookiePrefixIsolation tests that different cookie prefixes create isolated sessions
|
|
// This addresses GitHub issue #87 where multiple middleware instances should not share sessions
|
|
func TestCookiePrefixIsolation(t *testing.T) {
|
|
logger := NewLogger("info")
|
|
encryptionKey := strings.Repeat("a", 32)
|
|
|
|
// Create two session managers with different cookie prefixes
|
|
sm1, err := NewSessionManager(encryptionKey, false, "", "_oidc_userauth_", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager 1: %v", err)
|
|
}
|
|
|
|
sm2, err := NewSessionManager(encryptionKey, false, "", "_oidc_adminauth_", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager 2: %v", err)
|
|
}
|
|
|
|
// Verify cookie names are different
|
|
if sm1.mainCookieName() == sm2.mainCookieName() {
|
|
t.Errorf("Expected different main cookie names, got same: %s", sm1.mainCookieName())
|
|
}
|
|
if sm1.accessTokenCookieName() == sm2.accessTokenCookieName() {
|
|
t.Errorf("Expected different access token cookie names, got same: %s", sm1.accessTokenCookieName())
|
|
}
|
|
|
|
// Verify cookie names have the correct prefix
|
|
expectedPrefix1 := "_oidc_userauth_"
|
|
expectedPrefix2 := "_oidc_adminauth_"
|
|
|
|
if !strings.HasPrefix(sm1.mainCookieName(), expectedPrefix1) {
|
|
t.Errorf("Expected main cookie name to start with %s, got %s", expectedPrefix1, sm1.mainCookieName())
|
|
}
|
|
if !strings.HasPrefix(sm2.mainCookieName(), expectedPrefix2) {
|
|
t.Errorf("Expected main cookie name to start with %s, got %s", expectedPrefix2, sm2.mainCookieName())
|
|
}
|
|
|
|
t.Logf("Session Manager 1 cookies: main=%s, access=%s, refresh=%s, id=%s",
|
|
sm1.mainCookieName(), sm1.accessTokenCookieName(), sm1.refreshTokenCookieName(), sm1.idTokenCookieName())
|
|
t.Logf("Session Manager 2 cookies: main=%s, access=%s, refresh=%s, id=%s",
|
|
sm2.mainCookieName(), sm2.accessTokenCookieName(), sm2.refreshTokenCookieName(), sm2.idTokenCookieName())
|
|
}
|
|
|
|
// TestCookiePrefixDefault tests that the default cookie prefix is applied when none is provided
|
|
func TestCookiePrefixDefault(t *testing.T) {
|
|
logger := NewLogger("info")
|
|
encryptionKey := strings.Repeat("a", 32)
|
|
|
|
// Create session manager without cookie prefix (should use default)
|
|
sm, err := NewSessionManager(encryptionKey, false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
// Verify default prefix is used
|
|
expectedPrefix := defaultCookiePrefix
|
|
if !strings.HasPrefix(sm.mainCookieName(), expectedPrefix) {
|
|
t.Errorf("Expected default prefix %s, got cookie name %s", expectedPrefix, sm.mainCookieName())
|
|
}
|
|
|
|
// Verify full cookie names
|
|
expectedMain := defaultCookiePrefix + mainCookieSuffix
|
|
expectedAccess := defaultCookiePrefix + accessTokenSuffix
|
|
expectedRefresh := defaultCookiePrefix + refreshTokenSuffix
|
|
expectedID := defaultCookiePrefix + idTokenSuffix
|
|
|
|
if sm.mainCookieName() != expectedMain {
|
|
t.Errorf("Expected main cookie name %s, got %s", expectedMain, sm.mainCookieName())
|
|
}
|
|
if sm.accessTokenCookieName() != expectedAccess {
|
|
t.Errorf("Expected access cookie name %s, got %s", expectedAccess, sm.accessTokenCookieName())
|
|
}
|
|
if sm.refreshTokenCookieName() != expectedRefresh {
|
|
t.Errorf("Expected refresh cookie name %s, got %s", expectedRefresh, sm.refreshTokenCookieName())
|
|
}
|
|
if sm.idTokenCookieName() != expectedID {
|
|
t.Errorf("Expected ID cookie name %s, got %s", expectedID, sm.idTokenCookieName())
|
|
}
|
|
}
|