mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
1b49e133da
* Fix bug affecting Azure OIDC authentication ( and most likely others ) * Fixes issue #51 * Ensure that appended roles are unique. Update the documentation. * Improvements targetting possible memory usage spikes. * Additional fixes and cleanup * Refactoring code to fix the issues identified by the users. * Modernize run * Fieldalignment * Multiple changes to improve performance and reduce complexity. - Optimise the errors and recovery. - Deduplicate code in metadata cache. - Remove unused performance monitoring code. - Simplify session management and settings handling. * Fix claims issue. * Add ability to overwrite the default scopes in the settings file * Well.. that escalated quickly. Completely forgot that Traefik uses outdated Yaegi and requires compatibility with 1.20 ( pre-generic Go code ). * Bugfix #51: Ensures that user provided scopes overrides work. * fixup! Bugfix #51: Ensures that user provided scopes overrides work. * fixup! fixup! Bugfix #51: Ensures that user provided scopes overrides work. * Abstract the provider logic into a separate package. * Additional micro fixes and cleanups. * Simplify all the things. * fixup! Simplify all the things. * fixup! fixup! Simplify all the things. * fixup! fixup! fixup! Simplify all the things. * fixup! fixup! fixup! fixup! Simplify all the things. * ... * Cleanup tests. * fixup! Cleanup tests. * fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! fixup! Cleanup tests. * Issue #53: Fix CSRF token handling in reverse proxy 1. ✅ HTTPS Detection Fixed (session.go:723) - Now uses X-Forwarded-Proto header instead of r.URL.Scheme - Properly detects HTTPS in reverse proxy environments 2. ✅ SameSite Cookie Attribute Fixed - Removed automatic SameSiteStrictMode for HTTPS (would break OAuth) - Keeps SameSiteLaxMode to allow OAuth callbacks from external domains - Only uses Strict for AJAX requests which don't involve OAuth redirects 3. 
✅ Cookie Domain Handling Fixed - Now respects X-Forwarded-Host header for cookie domain - Ensures cookies are set for the public domain, not internal proxy domain 4. ✅ EnhanceSessionSecurity Properly Integrated - Function is now actually called during session save - Applies security enhancements without breaking OAuth flow Why Issue #53 Failed Before: 1. Cookies were not marked Secure in HTTPS environments (browser wouldn't send them back) 2. If they had been Secure with SameSite=Strict, Azure callbacks would still fail 3. Cookie domain might have been wrong (internal vs public domain) Why It Works Now: 1. Cookies are properly marked Secure for HTTPS 2. Uses SameSite=Lax to allow OAuth provider callbacks 3. Cookie domain uses public domain from X-Forwarded-Host 4. CSRF token persists through the entire OAuth flow * Next set of enhancements together with memory usage improvements. * Memory leak fixes and optimisations. * CSRF and Cookie Domain fixes * fixup! CSRF and Cookie Domain fixes * Metadata cache leak fix + profiling * fixup! Metadata cache leak fix + profiling * Memory leaks hunting, part 1337. * Further pursue of perfection. * fixup! Further pursue of perfection. * fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * Clear race conditions * fixup! Clear race conditions * Weekend fun with memory leaks * Splitting code into multiple files with reasonable testing coverage. 
``` ok github.com/lukaszraczylo/traefikoidc 117.017s coverage: 72.6% of statements ok github.com/lukaszraczylo/traefikoidc/auth 0.505s coverage: 87.1% of statements ok github.com/lukaszraczylo/traefikoidc/circuit_breaker 0.283s coverage: 99.0% of statements github.com/lukaszraczylo/traefikoidc/config coverage: 0.0% of statements ok github.com/lukaszraczylo/traefikoidc/handlers 0.349s coverage: 98.2% of statements ok github.com/lukaszraczylo/traefikoidc/internal/providers (cached) coverage: 94.3% of statements ok github.com/lukaszraczylo/traefikoidc/middleware 0.808s coverage: 78.0% of statements ok github.com/lukaszraczylo/traefikoidc/recovery 0.653s coverage: 100.0% of statements ok github.com/lukaszraczylo/traefikoidc/session/chunking (cached) coverage: 87.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/core (cached) coverage: 85.6% of statements ok github.com/lukaszraczylo/traefikoidc/session/crypto (cached) coverage: 81.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/storage (cached) coverage: 93.5% of statements ok github.com/lukaszraczylo/traefikoidc/session/validators (cached) coverage: 98.8% of statements ```` * fixup! Splitting code into multiple files with reasonable testing coverage. * fixup! fixup! Splitting code into multiple files with reasonable testing coverage. * Weekend fun with further optimisations. * fixup! Weekend fun with further optimisations. * fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * Pre-release cleanup. * Enhance test coverage. * fixup! Enhance test coverage. * fixup! fixup! Enhance test coverage. * fixup! fixup! fixup! Enhance test coverage.
1001 lines
28 KiB
Go
1001 lines
28 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// SessionTestCase represents a comprehensive session test scenario
|
|
type SessionTestCase struct {
|
|
name string
|
|
scenario string // "creation", "validation", "expiration", "persistence", "cleanup", "chunking", "security"
|
|
sessionType string // "user", "admin", "api", "guest", "csrf"
|
|
setup func(*SessionTestFramework)
|
|
execute func(*SessionTestFramework) error
|
|
validate func(*testing.T, error, *SessionTestFramework)
|
|
cleanup func(*SessionTestFramework)
|
|
concurrent bool
|
|
iterations int
|
|
timeout time.Duration
|
|
skipReason string
|
|
}
|
|
|
|
// SessionTestFramework provides shared test infrastructure for session tests
|
|
type SessionTestFramework struct {
|
|
t *testing.T
|
|
mockProvider *httptest.Server
|
|
requests []*http.Request
|
|
responses []*httptest.ResponseRecorder
|
|
testTokens map[string]string
|
|
sessionIDs []string
|
|
mu sync.RWMutex
|
|
metrics *SessionTestMetrics
|
|
cleanupFuncs []func()
|
|
config *SessionTestConfig
|
|
}
|
|
|
|
// SessionTestMetrics tracks test activity counters.
//
// The int64 counters are incremented through sync/atomic by the tests in this
// file, so they are kept as plain int64 fields that can be addressed directly.
type SessionTestMetrics struct {
	// SessionsCreated counts simulated session creations.
	SessionsCreated int64
	// SessionsDestroyed counts simulated session teardowns.
	SessionsDestroyed int64
	// TokensGenerated counts tokens produced by the token helpers.
	TokensGenerated int64
	// TokensValidated counts token validation attempts.
	TokensValidated int64
	// ChunksCreated counts chunks produced when splitting large tokens.
	ChunksCreated int64
	// ChunksRetrieved counts chunks consumed when reassembling tokens.
	ChunksRetrieved int64
	// ErrorCount counts expected error conditions observed by a test case.
	ErrorCount int64
	// Duration is available for recording elapsed time.
	Duration time.Duration
}

// SessionTestConfig holds the tunables shared by the session test cases.
type SessionTestConfig struct {
	// MaxChunkSize is the chunk length, in bytes, used when splitting tokens.
	MaxChunkSize int
	// MaxSessions is the upper bound on tracked session IDs.
	MaxSessions int
	// EnableHTTPS controls the Secure attribute of cookies set by test cases.
	EnableHTTPS bool
	// CookieDomain is the cookie domain setting.
	CookieDomain string
	// SessionTimeout is the session lifetime setting.
	SessionTimeout time.Duration
	// EncryptionKey is the session encryption key setting.
	EncryptionKey string
	// EnableCompression is the compression setting.
	EnableCompression bool
}

// NewSessionTestFramework creates a new test framework instance
|
|
func NewSessionTestFramework(t *testing.T) *SessionTestFramework {
|
|
framework := &SessionTestFramework{
|
|
t: t,
|
|
requests: make([]*http.Request, 0),
|
|
responses: make([]*httptest.ResponseRecorder, 0),
|
|
testTokens: make(map[string]string),
|
|
sessionIDs: make([]string, 0),
|
|
metrics: &SessionTestMetrics{},
|
|
cleanupFuncs: make([]func(), 0),
|
|
config: &SessionTestConfig{
|
|
MaxChunkSize: 3900,
|
|
MaxSessions: 1000,
|
|
EnableHTTPS: false,
|
|
CookieDomain: "",
|
|
SessionTimeout: time.Hour,
|
|
EncryptionKey: generateTestKey(),
|
|
EnableCompression: true,
|
|
},
|
|
}
|
|
|
|
// Setup mock OIDC provider
|
|
framework.setupMockProvider()
|
|
|
|
return framework
|
|
}
|
|
|
|
// generateTestKey returns the fixed encryption key used by the tests.
//
// The key is a constant 48-character ASCII string (48 bytes = 384 bits), so
// every framework instance shares the same deterministic value.
func generateTestKey() string {
	return "0123456789abcdef0123456789abcdef0123456789abcdef"
}

// setupMockProvider sets up a mock OIDC provider for testing
|
|
func (f *SessionTestFramework) setupMockProvider() {
|
|
f.mockProvider = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/.well-known/openid-configuration":
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"issuer": f.mockProvider.URL,
|
|
"authorization_endpoint": f.mockProvider.URL + "/auth",
|
|
"token_endpoint": f.mockProvider.URL + "/token",
|
|
"userinfo_endpoint": f.mockProvider.URL + "/userinfo",
|
|
"jwks_uri": f.mockProvider.URL + "/jwks",
|
|
})
|
|
case "/token":
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"access_token": f.generateTestToken("access", 3600),
|
|
"id_token": f.generateTestToken("id", 3600),
|
|
"refresh_token": f.generateTestToken("refresh", 86400),
|
|
"token_type": "Bearer",
|
|
"expires_in": 3600,
|
|
})
|
|
case "/userinfo":
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"sub": "test-user-id",
|
|
"email": "test@example.com",
|
|
"name": "Test User",
|
|
})
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}))
|
|
|
|
f.cleanupFuncs = append(f.cleanupFuncs, f.mockProvider.Close)
|
|
}
|
|
|
|
// generateTestToken generates a test token
|
|
func (f *SessionTestFramework) generateTestToken(tokenType string, expiresIn int) string {
|
|
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
|
|
|
|
// Create a realistic JWT-like token for testing
|
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
|
|
|
|
claims := map[string]interface{}{
|
|
"iss": f.mockProvider.URL,
|
|
"sub": "test-user-id",
|
|
"aud": "test-client-id",
|
|
"exp": time.Now().Add(time.Duration(expiresIn) * time.Second).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"typ": tokenType,
|
|
}
|
|
|
|
claimsJSON, _ := json.Marshal(claims)
|
|
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
|
|
|
// Generate a fake signature
|
|
signature := make([]byte, 64)
|
|
rand.Read(signature)
|
|
sig := base64.RawURLEncoding.EncodeToString(signature)
|
|
|
|
token := fmt.Sprintf("%s.%s.%s", header, payload, sig)
|
|
|
|
// Thread-safe write to map
|
|
f.mu.Lock()
|
|
f.testTokens[tokenType] = token
|
|
f.mu.Unlock()
|
|
|
|
return token
|
|
}
|
|
|
|
// generateLargeToken generates a token of specified size for testing chunking
|
|
func (f *SessionTestFramework) generateLargeToken(size int) string {
|
|
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
|
|
|
|
// Create base JWT structure
|
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
|
|
|
|
// Calculate how much padding we need in claims
|
|
baseSize := len(header) + 2 // for the dots
|
|
signatureSize := 86 // approximate base64 encoded signature size
|
|
paddingSize := size - baseSize - signatureSize - 100 // leave room for other claims
|
|
|
|
if paddingSize < 0 {
|
|
paddingSize = 0
|
|
}
|
|
|
|
// Create large padding data
|
|
padding := make([]byte, paddingSize)
|
|
for i := range padding {
|
|
padding[i] = byte('A' + (i % 26))
|
|
}
|
|
|
|
claims := map[string]interface{}{
|
|
"iss": f.mockProvider.URL,
|
|
"sub": "test-user-id",
|
|
"aud": "test-client-id",
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"padding": base64.StdEncoding.EncodeToString(padding),
|
|
}
|
|
|
|
claimsJSON, _ := json.Marshal(claims)
|
|
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
|
|
|
// Generate signature
|
|
signature := make([]byte, 64)
|
|
rand.Read(signature)
|
|
sig := base64.RawURLEncoding.EncodeToString(signature)
|
|
|
|
return fmt.Sprintf("%s.%s.%s", header, payload, sig)
|
|
}
|
|
|
|
// Cleanup performs framework cleanup
|
|
func (f *SessionTestFramework) Cleanup() {
|
|
for _, cleanup := range f.cleanupFuncs {
|
|
cleanup()
|
|
}
|
|
}
|
|
|
|
// TestSessionConsolidated runs all consolidated session tests
|
|
func TestSessionConsolidated(t *testing.T) {
|
|
testCases := []SessionTestCase{
|
|
// Session Creation Tests
|
|
{
|
|
name: "session_basic_creation",
|
|
scenario: "creation",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
// Simulate session creation
|
|
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
|
f.requests = append(f.requests, req)
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Session creation should succeed")
|
|
assert.Greater(t, f.metrics.SessionsCreated, int64(0), "Session should be created")
|
|
},
|
|
},
|
|
{
|
|
name: "session_pool_reuse",
|
|
scenario: "creation",
|
|
sessionType: "user",
|
|
iterations: 100,
|
|
execute: func(f *SessionTestFramework) error {
|
|
for i := 0; i < 100; i++ {
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed, "Sessions should be properly pooled")
|
|
},
|
|
},
|
|
{
|
|
name: "session_concurrent_creation",
|
|
scenario: "creation",
|
|
sessionType: "user",
|
|
concurrent: true,
|
|
iterations: 50,
|
|
execute: func(f *SessionTestFramework) error {
|
|
var wg sync.WaitGroup
|
|
errs := make(chan error, 50)
|
|
|
|
for i := 0; i < 50; i++ {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
// Simulate concurrent session creation
|
|
req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%d", id), nil)
|
|
f.mu.Lock()
|
|
f.requests = append(f.requests, req)
|
|
f.mu.Unlock()
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errs)
|
|
|
|
for err := range errs {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int64(50), f.metrics.SessionsCreated, "All concurrent sessions should be created")
|
|
},
|
|
},
|
|
|
|
// Session Validation Tests
|
|
{
|
|
name: "session_token_validation",
|
|
scenario: "validation",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
token := f.generateTestToken("access", 3600)
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
// Validate token format
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) != 3 {
|
|
return fmt.Errorf("invalid token format")
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Token validation should succeed")
|
|
assert.Greater(t, f.metrics.TokensValidated, int64(0))
|
|
},
|
|
},
|
|
{
|
|
name: "session_corrupted_token_detection",
|
|
scenario: "validation",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
token := f.generateTestToken("access", 3600)
|
|
// Corrupt the token by modifying the signature
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) != 3 {
|
|
return fmt.Errorf("invalid token format")
|
|
}
|
|
|
|
// Corrupt the signature part
|
|
corrupted := parts[0] + "." + parts[1] + ".corrupted!"
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
// Validate should detect corruption - corrupted tokens should fail validation
|
|
corruptedParts := strings.Split(corrupted, ".")
|
|
if len(corruptedParts) == 3 {
|
|
// Try to decode the corrupted signature
|
|
_, err := base64.RawURLEncoding.DecodeString(corruptedParts[2])
|
|
if err == nil {
|
|
return fmt.Errorf("corruption not detected")
|
|
}
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Corruption detection should work")
|
|
},
|
|
},
|
|
{
|
|
name: "session_expired_token_handling",
|
|
scenario: "validation",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Generate an expired token
|
|
token := f.generateTestToken("access", -3600) // negative expiry
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
// Parse and check expiry
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) == 3 {
|
|
payload, _ := base64.RawURLEncoding.DecodeString(parts[1])
|
|
var claims map[string]interface{}
|
|
json.Unmarshal(payload, &claims)
|
|
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
if exp < float64(time.Now().Unix()) {
|
|
atomic.AddInt64(&f.metrics.ErrorCount, 1)
|
|
return nil // Expected behavior
|
|
}
|
|
}
|
|
}
|
|
return fmt.Errorf("expired token not detected")
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Expired token should be detected")
|
|
assert.Greater(t, f.metrics.ErrorCount, int64(0))
|
|
},
|
|
},
|
|
|
|
// Session Expiration Tests
|
|
{
|
|
name: "session_ttl_expiration",
|
|
scenario: "expiration",
|
|
sessionType: "user",
|
|
timeout: 3 * time.Second,
|
|
execute: func(f *SessionTestFramework) error {
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
// Simulate session with short TTL
|
|
time.Sleep(100 * time.Millisecond) // Don't sleep for full timeout
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed)
|
|
},
|
|
},
|
|
{
|
|
name: "session_refresh_token_expiry",
|
|
scenario: "expiration",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
refreshToken := f.generateTestToken("refresh", 86400)
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
// Check refresh token is valid for longer period
|
|
parts := strings.Split(refreshToken, ".")
|
|
if len(parts) == 3 {
|
|
payload, _ := base64.RawURLEncoding.DecodeString(parts[1])
|
|
var claims map[string]interface{}
|
|
json.Unmarshal(payload, &claims)
|
|
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
timeUntilExpiry := time.Until(time.Unix(int64(exp), 0))
|
|
if timeUntilExpiry < 23*time.Hour {
|
|
return fmt.Errorf("refresh token expiry too short: %v", timeUntilExpiry)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Refresh token should have correct expiry")
|
|
},
|
|
},
|
|
|
|
// Session Persistence Tests
|
|
{
|
|
name: "session_cookie_persistence",
|
|
scenario: "persistence",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
// Set session cookie
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "session_id",
|
|
Value: "test-session-123",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: f.config.EnableHTTPS,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
f.requests = append(f.requests, req)
|
|
f.responses = append(f.responses, w)
|
|
|
|
// Verify cookie was set
|
|
cookies := w.Result().Cookies()
|
|
if len(cookies) == 0 {
|
|
return fmt.Errorf("no cookies set")
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, f.responses, "Response should be recorded")
|
|
},
|
|
},
|
|
{
|
|
name: "session_state_preservation",
|
|
scenario: "persistence",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Store state
|
|
state := map[string]interface{}{
|
|
"user_id": "test-user",
|
|
"email": "test@example.com",
|
|
"roles": []string{"user", "admin"},
|
|
}
|
|
|
|
// Serialize and deserialize to test persistence
|
|
data, err := json.Marshal(state)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var restored map[string]interface{}
|
|
if err := json.Unmarshal(data, &restored); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Verify state preserved
|
|
if restored["user_id"] != state["user_id"] {
|
|
return fmt.Errorf("state not preserved")
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Session state should be preserved")
|
|
},
|
|
},
|
|
|
|
// Session Cleanup Tests
|
|
{
|
|
name: "session_proper_cleanup",
|
|
scenario: "cleanup",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Create and destroy sessions
|
|
for i := 0; i < 10; i++ {
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
sessionID := fmt.Sprintf("session-%d", i)
|
|
f.sessionIDs = append(f.sessionIDs, sessionID)
|
|
}
|
|
|
|
// Cleanup all sessions
|
|
for range f.sessionIDs {
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
}
|
|
f.sessionIDs = nil
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed)
|
|
assert.Empty(t, f.sessionIDs, "All sessions should be cleaned up")
|
|
},
|
|
},
|
|
{
|
|
name: "session_goroutine_leak_prevention",
|
|
scenario: "cleanup",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
initialGoroutines := runtime.NumGoroutine()
|
|
|
|
// Create sessions that might spawn goroutines
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
time.Sleep(10 * time.Millisecond)
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
runtime.GC()
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
finalGoroutines := runtime.NumGoroutine()
|
|
if finalGoroutines > initialGoroutines+2 { // Allow small variance
|
|
return fmt.Errorf("goroutine leak detected: %d -> %d", initialGoroutines, finalGoroutines)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "No goroutine leaks should occur")
|
|
},
|
|
},
|
|
|
|
// Session Chunking Tests
|
|
{
|
|
name: "session_large_token_chunking",
|
|
scenario: "chunking",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Generate a large token that requires chunking
|
|
largeToken := f.generateLargeToken(10000) // 10KB token
|
|
|
|
// Calculate expected chunks
|
|
chunkSize := f.config.MaxChunkSize
|
|
expectedChunks := (len(largeToken) + chunkSize - 1) / chunkSize
|
|
|
|
// Simulate chunking
|
|
chunks := make([]string, 0)
|
|
for i := 0; i < len(largeToken); i += chunkSize {
|
|
end := i + chunkSize
|
|
if end > len(largeToken) {
|
|
end = len(largeToken)
|
|
}
|
|
chunks = append(chunks, largeToken[i:end])
|
|
atomic.AddInt64(&f.metrics.ChunksCreated, 1)
|
|
}
|
|
|
|
if len(chunks) != expectedChunks {
|
|
return fmt.Errorf("expected %d chunks, got %d", expectedChunks, len(chunks))
|
|
}
|
|
|
|
// Simulate reconstruction
|
|
reconstructed := strings.Join(chunks, "")
|
|
if reconstructed != largeToken {
|
|
return fmt.Errorf("token reconstruction failed")
|
|
}
|
|
atomic.AddInt64(&f.metrics.ChunksRetrieved, int64(len(chunks)))
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Token chunking should work correctly")
|
|
assert.Greater(t, f.metrics.ChunksCreated, int64(0))
|
|
assert.Equal(t, f.metrics.ChunksCreated, f.metrics.ChunksRetrieved)
|
|
},
|
|
},
|
|
{
|
|
name: "session_chunk_boundary_validation",
|
|
scenario: "chunking",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Test exact boundary conditions
|
|
testSizes := []int{
|
|
f.config.MaxChunkSize - 1,
|
|
f.config.MaxChunkSize,
|
|
f.config.MaxChunkSize + 1,
|
|
f.config.MaxChunkSize * 2,
|
|
f.config.MaxChunkSize*2 - 1,
|
|
f.config.MaxChunkSize*2 + 1,
|
|
}
|
|
|
|
for _, size := range testSizes {
|
|
token := f.generateLargeToken(size)
|
|
actualSize := len(token)
|
|
expectedChunks := (actualSize + f.config.MaxChunkSize - 1) / f.config.MaxChunkSize
|
|
|
|
actualChunks := 0
|
|
for i := 0; i < len(token); i += f.config.MaxChunkSize {
|
|
actualChunks++
|
|
atomic.AddInt64(&f.metrics.ChunksCreated, 1)
|
|
}
|
|
|
|
if actualChunks != expectedChunks {
|
|
return fmt.Errorf("size %d (actual token size %d): expected %d chunks, got %d", size, actualSize, expectedChunks, actualChunks)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Chunk boundaries should be handled correctly")
|
|
},
|
|
},
|
|
|
|
// Session Security Tests
|
|
{
|
|
name: "session_csrf_token_management",
|
|
scenario: "security",
|
|
sessionType: "csrf",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Generate CSRF token
|
|
csrfToken := make([]byte, 32)
|
|
if _, err := rand.Read(csrfToken); err != nil {
|
|
return err
|
|
}
|
|
|
|
csrfString := base64.RawURLEncoding.EncodeToString(csrfToken)
|
|
|
|
// Store in session
|
|
f.testTokens["csrf"] = csrfString
|
|
|
|
// Validate CSRF token
|
|
if len(csrfString) < 40 {
|
|
return fmt.Errorf("CSRF token too short")
|
|
}
|
|
|
|
atomic.AddInt64(&f.metrics.TokensGenerated, 1)
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "CSRF token should be properly managed")
|
|
assert.NotEmpty(t, f.testTokens["csrf"])
|
|
},
|
|
},
|
|
{
|
|
name: "session_injection_prevention",
|
|
scenario: "security",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
// Test various injection attempts
|
|
maliciousInputs := []string{
|
|
`{"admin": true}`,
|
|
`<script>alert('xss')</script>`,
|
|
`'; DROP TABLE sessions; --`,
|
|
`../../../etc/passwd`,
|
|
string([]byte{0x00, 0x01, 0x02}), // null bytes
|
|
}
|
|
|
|
for _, input := range maliciousInputs {
|
|
// Validate that input is properly sanitized
|
|
sanitized := base64.StdEncoding.EncodeToString([]byte(input))
|
|
decoded, err := base64.StdEncoding.DecodeString(sanitized)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if string(decoded) != input {
|
|
return fmt.Errorf("sanitization changed input unexpectedly")
|
|
}
|
|
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 1)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Injection attempts should be handled safely")
|
|
},
|
|
},
|
|
{
|
|
name: "session_secure_cookie_settings",
|
|
scenario: "security",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
w := httptest.NewRecorder()
|
|
|
|
// Test secure cookie settings
|
|
cookie := &http.Cookie{
|
|
Name: "session",
|
|
Value: "test-session",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteStrictMode,
|
|
MaxAge: 3600,
|
|
}
|
|
|
|
http.SetCookie(w, cookie)
|
|
|
|
// Verify cookie attributes
|
|
cookies := w.Result().Cookies()
|
|
if len(cookies) == 0 {
|
|
return fmt.Errorf("no cookie set")
|
|
}
|
|
|
|
c := cookies[0]
|
|
if !c.HttpOnly {
|
|
return fmt.Errorf("cookie not HttpOnly")
|
|
}
|
|
if c.SameSite != http.SameSiteStrictMode {
|
|
return fmt.Errorf("incorrect SameSite setting")
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Secure cookie settings should be enforced")
|
|
},
|
|
},
|
|
|
|
// Session Stress Tests
|
|
{
|
|
name: "session_high_concurrency_stress",
|
|
scenario: "creation",
|
|
sessionType: "user",
|
|
concurrent: true,
|
|
iterations: 1000,
|
|
timeout: 30 * time.Second,
|
|
execute: func(f *SessionTestFramework) error {
|
|
var wg sync.WaitGroup
|
|
errors := make([]error, 0)
|
|
|
|
// Run high concurrency test
|
|
concurrency := 100
|
|
iterations := 10
|
|
|
|
for i := 0; i < concurrency; i++ {
|
|
wg.Add(1)
|
|
go func(workerID int) {
|
|
defer wg.Done()
|
|
|
|
for j := 0; j < iterations; j++ {
|
|
// Create session
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
|
|
// Generate tokens
|
|
f.generateTestToken("access", 3600)
|
|
f.generateTestToken("refresh", 86400)
|
|
|
|
// Validate tokens
|
|
atomic.AddInt64(&f.metrics.TokensValidated, 2)
|
|
|
|
// Cleanup session
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
|
|
// Small delay to simulate real usage
|
|
time.Sleep(time.Millisecond)
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
if len(errors) > 0 {
|
|
return errors[0]
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "High concurrency stress test should pass")
|
|
assert.Equal(t, f.metrics.SessionsCreated, f.metrics.SessionsDestroyed, "All sessions should be cleaned up")
|
|
},
|
|
},
|
|
{
|
|
name: "session_memory_bounds_enforcement",
|
|
scenario: "cleanup",
|
|
sessionType: "user",
|
|
execute: func(f *SessionTestFramework) error {
|
|
maxSessions := f.config.MaxSessions
|
|
|
|
// Try to create more sessions than allowed
|
|
for i := 0; i < maxSessions+100; i++ {
|
|
sessionID := fmt.Sprintf("session-%d", i)
|
|
f.sessionIDs = append(f.sessionIDs, sessionID)
|
|
atomic.AddInt64(&f.metrics.SessionsCreated, 1)
|
|
|
|
// Enforce max sessions
|
|
if len(f.sessionIDs) > maxSessions {
|
|
// Remove oldest session
|
|
f.sessionIDs = f.sessionIDs[1:]
|
|
atomic.AddInt64(&f.metrics.SessionsDestroyed, 1)
|
|
}
|
|
}
|
|
|
|
if len(f.sessionIDs) > maxSessions {
|
|
return fmt.Errorf("max sessions exceeded: %d > %d", len(f.sessionIDs), maxSessions)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
validate: func(t *testing.T, err error, f *SessionTestFramework) {
|
|
assert.NoError(t, err, "Memory bounds should be enforced")
|
|
assert.LessOrEqual(t, len(f.sessionIDs), f.config.MaxSessions)
|
|
},
|
|
},
|
|
}
|
|
|
|
// Run all test cases
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if tc.skipReason != "" {
|
|
t.Skip(tc.skipReason)
|
|
}
|
|
|
|
framework := NewSessionTestFramework(t)
|
|
defer framework.Cleanup()
|
|
|
|
// Setup
|
|
if tc.setup != nil {
|
|
tc.setup(framework)
|
|
}
|
|
|
|
// Cleanup
|
|
if tc.cleanup != nil {
|
|
defer tc.cleanup(framework)
|
|
}
|
|
|
|
// Set timeout if specified
|
|
if tc.timeout > 0 {
|
|
timer := time.NewTimer(tc.timeout)
|
|
done := make(chan bool)
|
|
|
|
go func() {
|
|
err := tc.execute(framework)
|
|
tc.validate(t, err, framework)
|
|
done <- true
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
timer.Stop()
|
|
case <-timer.C:
|
|
t.Fatal("Test timeout exceeded")
|
|
}
|
|
} else {
|
|
// Execute test
|
|
err := tc.execute(framework)
|
|
|
|
// Validate results
|
|
tc.validate(t, err, framework)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Benchmark tests
|
|
func BenchmarkSessionCreation(b *testing.B) {
|
|
framework := &SessionTestFramework{
|
|
metrics: &SessionTestMetrics{},
|
|
testTokens: make(map[string]string),
|
|
config: &SessionTestConfig{
|
|
MaxChunkSize: 3900,
|
|
MaxSessions: 1000,
|
|
},
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
atomic.AddInt64(&framework.metrics.SessionsCreated, 1)
|
|
atomic.AddInt64(&framework.metrics.SessionsDestroyed, 1)
|
|
}
|
|
|
|
b.ReportMetric(float64(framework.metrics.SessionsCreated)/float64(b.N), "sessions/op")
|
|
}
|
|
|
|
func BenchmarkTokenGeneration(b *testing.B) {
|
|
framework := NewSessionTestFramework(&testing.T{})
|
|
defer framework.Cleanup()
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
framework.generateTestToken("access", 3600)
|
|
}
|
|
|
|
b.ReportMetric(float64(framework.metrics.TokensGenerated)/float64(b.N), "tokens/op")
|
|
}
|
|
|
|
func BenchmarkTokenValidation(b *testing.B) {
|
|
framework := NewSessionTestFramework(&testing.T{})
|
|
defer framework.Cleanup()
|
|
|
|
token := framework.generateTestToken("access", 3600)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) == 3 {
|
|
atomic.AddInt64(&framework.metrics.TokensValidated, 1)
|
|
}
|
|
}
|
|
|
|
b.ReportMetric(float64(framework.metrics.TokensValidated)/float64(b.N), "validations/op")
|
|
}
|
|
|
|
func BenchmarkLargeTokenChunking(b *testing.B) {
|
|
framework := &SessionTestFramework{
|
|
metrics: &SessionTestMetrics{},
|
|
testTokens: make(map[string]string),
|
|
config: &SessionTestConfig{
|
|
MaxChunkSize: 3900,
|
|
},
|
|
}
|
|
|
|
// Generate test token once
|
|
largeToken := strings.Repeat("A", 10000)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
chunks := make([]string, 0)
|
|
for j := 0; j < len(largeToken); j += framework.config.MaxChunkSize {
|
|
end := j + framework.config.MaxChunkSize
|
|
if end > len(largeToken) {
|
|
end = len(largeToken)
|
|
}
|
|
chunks = append(chunks, largeToken[j:end])
|
|
atomic.AddInt64(&framework.metrics.ChunksCreated, 1)
|
|
}
|
|
|
|
// Reconstruct
|
|
_ = strings.Join(chunks, "")
|
|
atomic.AddInt64(&framework.metrics.ChunksRetrieved, int64(len(chunks)))
|
|
}
|
|
|
|
b.ReportMetric(float64(framework.metrics.ChunksCreated)/float64(b.N), "chunks_created/op")
|
|
b.ReportMetric(float64(framework.metrics.ChunksRetrieved)/float64(b.N), "chunks_retrieved/op")
|
|
}
|
|
|
|
func BenchmarkConcurrentSessionOperations(b *testing.B) {
|
|
framework := &SessionTestFramework{
|
|
metrics: &SessionTestMetrics{},
|
|
testTokens: make(map[string]string),
|
|
sessionIDs: make([]string, 0),
|
|
config: &SessionTestConfig{
|
|
MaxSessions: 10000,
|
|
},
|
|
}
|
|
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
for pb.Next() {
|
|
// Create session
|
|
atomic.AddInt64(&framework.metrics.SessionsCreated, 1)
|
|
|
|
// Generate token
|
|
token := make([]byte, 32)
|
|
rand.Read(token)
|
|
tokenStr := base64.RawURLEncoding.EncodeToString(token)
|
|
atomic.AddInt64(&framework.metrics.TokensGenerated, 1)
|
|
|
|
// Validate token
|
|
if len(tokenStr) > 0 {
|
|
atomic.AddInt64(&framework.metrics.TokensValidated, 1)
|
|
}
|
|
|
|
// Destroy session
|
|
atomic.AddInt64(&framework.metrics.SessionsDestroyed, 1)
|
|
}
|
|
})
|
|
|
|
b.ReportMetric(float64(framework.metrics.SessionsCreated)/float64(b.N), "sessions/op")
|
|
b.ReportMetric(float64(framework.metrics.TokensGenerated)/float64(b.N), "tokens/op")
|
|
}
|