Compare commits

...

2 Commits

Author SHA1 Message Date
Arul 784b161732 Fix for cookie length (#58)
* Enhance session management by adding support for chunked id token in main session

* Add test for large ID token chunking in session management
2025-07-22 09:30:04 +01:00
lukaszraczylo efa0cd708b Fixes issue #50 2025-05-26 02:48:20 +01:00
5 changed files with 1060 additions and 19 deletions
+6 -2
View File
@@ -123,11 +123,12 @@ func parseJWT(tokenString string) (*JWT, error) {
// Parameters:
// - issuerURL: The expected issuer URL (e.g., "https://accounts.google.com").
// - clientID: The expected audience value (the client ID of this application).
// - skipReplayCheck: If true, skips JTI replay detection (used for revalidation of cached tokens).
//
// Returns:
// - nil if all standard claims are valid.
// - An error describing the first validation failure encountered.
func (j *JWT) Verify(issuerURL, clientID string) error {
func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error {
// Validate algorithm to prevent algorithm switching attacks
alg, ok := j.Header["alg"].(string)
if !ok {
@@ -183,7 +184,10 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
}
// Implement replay protection by checking the jti (JWT ID)
if jti, ok := claims["jti"].(string); ok {
// Skip replay check if explicitly requested (for revalidation scenarios)
shouldSkipReplay := len(skipReplayCheck) > 0 && skipReplayCheck[0]
if jti, ok := claims["jti"].(string); ok && !shouldSkipReplay {
// Skip replay detection for tokens that are being verified from the cache
if j.Token == "" {
// This is a parsed JWT without the original token string,
+2 -2
View File
@@ -363,8 +363,8 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
return fmt.Errorf("signature verification failed: %w", err)
}
// Verify standard claims
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
// Verify standard claims - skip replay check since it's already handled in VerifyToken
if err := jwt.Verify(t.issuerURL, t.clientID, true); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
+758
View File
@@ -2806,3 +2806,761 @@ func TestVerifyTimeConstraint(t *testing.T) {
})
}
} // Add missing closing brace for TestVerifyTimeConstraint
// ===== JWT REPLAY DETECTION TESTS =====
// These tests ensure the replay detection fix works correctly and prevents regressions
// TestJWTVerifyWithSkipReplayCheck tests the new skipReplayCheck parameter functionality
func TestJWTVerifyWithSkipReplayCheck(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Clear the global replay cache before test
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT: %v", err)
}
tests := []struct {
name string
skipReplayCheck bool
firstCall bool
expectError bool
errorContains string
}{
{
name: "First verification with skipReplayCheck=false should succeed",
skipReplayCheck: false,
firstCall: true,
expectError: false,
},
{
name: "Second verification with skipReplayCheck=false should fail (replay detected)",
skipReplayCheck: false,
firstCall: false,
expectError: true,
errorContains: "token replay detected",
},
{
name: "Verification with skipReplayCheck=true should always succeed",
skipReplayCheck: true,
firstCall: false, // Even on subsequent calls
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if tc.firstCall {
// Clear replay cache for first call tests
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
}
err := jwt.Verify("https://test-issuer.com", "test-client-id", tc.skipReplayCheck)
if tc.expectError {
if err == nil {
t.Errorf("Expected error containing '%s', but got nil", tc.errorContains)
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
}
} else {
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
}
})
}
}
// TestJWTVerifyBackwardCompatibility tests that calls without the skipReplayCheck parameter default to replay checking
func TestJWTVerifyBackwardCompatibility(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Clear the global replay cache
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT: %v", err)
}
// First call with old signature (no skipReplayCheck parameter) should succeed
err = jwt.Verify("https://test-issuer.com", "test-client-id")
if err != nil {
t.Errorf("First verification should succeed, got: %v", err)
}
// Second call with old signature should fail due to replay detection
err = jwt.Verify("https://test-issuer.com", "test-client-id")
if err == nil {
t.Error("Second verification should fail due to replay detection")
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error, got: %v", err)
}
}
// TestTokenReplayDetectionFalsePositiveFix tests the specific scenario that was causing false positives
func TestTokenReplayDetectionFalsePositiveFix(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Clear the global replay cache
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Simulate the authentication flow that was causing false positives:
// 1. Initial authentication adds JTI to cache
// 2. Subsequent request validation should not trigger false positive
// Step 1: Initial authentication (this would add JTI to cache)
jwt1, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for initial auth: %v", err)
}
err = jwt1.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check
if err != nil {
t.Fatalf("Initial authentication should succeed: %v", err)
}
// Step 2: Subsequent request validation (this should skip replay check to avoid false positive)
jwt2, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for subsequent request: %v", err)
}
err = jwt2.Verify("https://test-issuer.com", "test-client-id", true) // Skip replay check
if err != nil {
t.Errorf("Subsequent request validation should succeed with skipReplayCheck=true: %v", err)
}
// Step 3: Verify that actual replay attacks are still detected
jwt3, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for replay attack test: %v", err)
}
err = jwt3.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check
if err == nil {
t.Error("Actual replay attack should be detected when skipReplayCheck=false")
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error, got: %v", err)
}
}
// TestAuthenticationFlowReplayDetection tests the complete authentication flow
func TestAuthenticationFlowReplayDetection(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Clear the global replay cache
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Test the complete flow:
// 1. Initial authentication (should add JTI to cache)
// 2. Multiple subsequent requests (should not trigger false positives)
// 3. Actual replay attack from different source (should be detected)
// Step 1: Initial authentication
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Fatalf("Initial authentication should succeed: %v", err)
}
// Verify JTI is in cache
replayCacheMu.Lock()
_, exists := replayCache.Get(jti)
replayCacheMu.Unlock()
if !exists {
t.Error("JTI should be added to replay cache during initial authentication")
}
// Step 2: Subsequent requests (simulate normal request processing)
// These should use the token cache and skip replay detection
for i := 0; i < 3; i++ {
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Subsequent request %d should succeed: %v", i+1, err)
}
}
// Step 3: Simulate actual replay attack by directly calling JWT.Verify with replay check
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for replay attack test: %v", err)
}
err = jwt.Verify("https://test-issuer.com", "test-client-id", false) // Force replay check
if err == nil {
t.Error("Actual replay attack should be detected")
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error, got: %v", err)
}
}
// TestActualReplayAttackDetection ensures real replay attacks are still properly detected
func TestActualReplayAttackDetection(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Clear the global replay cache
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT: %v", err)
}
// First verification should succeed
err = jwt.Verify("https://test-issuer.com", "test-client-id", false)
if err != nil {
t.Fatalf("First verification should succeed: %v", err)
}
// Simulate different types of replay attacks
replayTests := []struct {
name string
description string
}{
{
name: "Direct replay attack",
description: "Same token used again with replay checking enabled",
},
{
name: "Replay from different source",
description: "Token intercepted and replayed by attacker",
},
}
for _, rt := range replayTests {
t.Run(rt.name, func(t *testing.T) {
// Parse token again (simulating replay)
replayJWT, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT for replay test: %v", err)
}
// Attempt replay with normal replay checking
err = replayJWT.Verify("https://test-issuer.com", "test-client-id", false)
if err == nil {
t.Errorf("Replay attack should be detected for: %s", rt.description)
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error for %s, got: %v", rt.description, err)
}
})
}
}
// TestConcurrentTokenValidation tests thread safety of replay detection
func TestConcurrentTokenValidation(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Configure rate limiter to allow more requests for concurrent testing
ts.tOidc.limiter = rate.NewLimiter(rate.Limit(1000), 1000) // Allow 1000 requests per second with burst of 1000
// Clear the global replay cache
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Create multiple tokens with unique JTIs
var tokens []string
var jtis []string
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
for i := 0; i < 10; i++ {
jti := generateRandomString(16)
jtis = append(jtis, jti)
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT %d: %v", i, err)
}
tokens = append(tokens, token)
}
// Test concurrent validation
const numGoroutines = 20
const numIterations = 5
results := make(chan error, numGoroutines*numIterations)
for g := 0; g < numGoroutines; g++ {
go func(goroutineID int) {
for i := 0; i < numIterations; i++ {
tokenIndex := (goroutineID + i) % len(tokens)
token := tokens[tokenIndex]
// First validation should succeed
err := ts.tOidc.VerifyToken(token)
results <- err
// Subsequent validation with same token should also succeed (uses cache)
err = ts.tOidc.VerifyToken(token)
results <- err
}
}(g)
}
// Collect results
var errors []error
for i := 0; i < numGoroutines*numIterations*2; i++ {
if err := <-results; err != nil {
errors = append(errors, err)
}
}
// All validations should succeed (no race conditions)
if len(errors) > 0 {
t.Errorf("Expected no errors in concurrent validation, got %d errors: %v", len(errors), errors)
}
// Verify all JTIs are in cache
replayCacheMu.Lock()
for i, jti := range jtis {
if _, exists := replayCache.Get(jti); !exists {
t.Errorf("JTI %d (%s) should be in replay cache", i, jti)
}
}
replayCacheMu.Unlock()
}
// TestJTIBlacklistBehavior tests the JTI blacklist cache management
func TestJTIBlacklistBehavior(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Clear the global replay cache
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Test JTI blacklist behavior
tests := []struct {
name string
action func() error
expectError bool
description string
}{
{
name: "Initial verification adds JTI to blacklist",
action: func() error {
return ts.tOidc.VerifyToken(token)
},
expectError: false,
description: "First verification should succeed and add JTI to blacklist",
},
{
name: "JTI exists in blacklist after verification",
action: func() error {
replayCacheMu.Lock()
defer replayCacheMu.Unlock()
if _, exists := replayCache.Get(jti); !exists {
return fmt.Errorf("JTI not found in blacklist cache")
}
return nil
},
expectError: false,
description: "JTI should be present in blacklist cache",
},
{
name: "Subsequent verification uses cache (no replay check)",
action: func() error {
return ts.tOidc.VerifyToken(token)
},
expectError: false,
description: "Subsequent verification should succeed using token cache",
},
{
name: "Direct JWT verification detects replay",
action: func() error {
jwt, err := parseJWT(token)
if err != nil {
return err
}
return jwt.Verify("https://test-issuer.com", "test-client-id", false)
},
expectError: true,
description: "Direct JWT verification should detect replay",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := tc.action()
if tc.expectError {
if err == nil {
t.Errorf("Expected error for %s, but got nil", tc.description)
}
} else {
if err != nil {
t.Errorf("Expected no error for %s, but got: %v", tc.description, err)
}
}
})
}
}
// TestSessionBasedTokenRevalidation tests token revalidation in session-based scenarios
func TestSessionBasedTokenRevalidation(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Clear the global replay cache
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Create a test JWT with unique JTI
jti := generateRandomString(16)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Simulate session-based token revalidation scenario
// This tests the specific case that was causing false positives
// Step 1: Initial authentication (callback processing)
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Fatalf("Initial authentication should succeed: %v", err)
}
// Step 2: Multiple session-based requests (normal request processing)
// These should not trigger replay detection false positives
for i := 0; i < 5; i++ {
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Session request %d should succeed: %v", i+1, err)
}
}
// Step 3: Verify token is in both caches appropriately
// Check token cache
if _, exists := ts.tOidc.tokenCache.Get(token); !exists {
t.Error("Token should be in token cache")
}
// Check replay cache
replayCacheMu.Lock()
_, inReplayCache := replayCache.Get(jti)
replayCacheMu.Unlock()
if !inReplayCache {
t.Error("JTI should be in replay cache")
}
// Step 4: Verify that clearing token cache still allows validation
ts.tOidc.tokenCache = NewTokenCache() // Clear token cache
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Token validation should succeed even after cache clear: %v", err)
}
}
// TestEdgeCasesWithDifferentTokenTypes tests replay detection with different token types
func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Clear the global replay cache
replayCacheMu.Lock()
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
tests := []struct {
name string
tokenType string
claims map[string]interface{}
expectError bool
}{
{
name: "ID Token with JTI",
tokenType: "id_token",
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": generateRandomString(16),
"token_type": "id_token",
},
expectError: false,
},
{
name: "Access Token with JTI",
tokenType: "access_token",
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"scope": "openid profile email",
"jti": generateRandomString(16),
"token_type": "access_token",
},
expectError: false,
},
{
name: "Token without JTI",
tokenType: "no_jti",
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
// No JTI claim
},
expectError: false, // Should still work, just no replay protection
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create token with specific claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// First verification should succeed
err = ts.tOidc.VerifyToken(token)
if tc.expectError {
if err == nil {
t.Errorf("Expected error for token type %s, but got nil", tc.tokenType)
}
} else {
if err != nil {
t.Errorf("Expected no error for token type %s, but got: %v", tc.tokenType, err)
}
}
// Second verification should also succeed (uses cache)
if !tc.expectError {
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Second verification should succeed for token type %s: %v", tc.tokenType, err)
}
}
// Test direct JWT verification for replay detection
if !tc.expectError && tc.claims["jti"] != nil {
jwt, err := parseJWT(token)
if err != nil {
t.Fatalf("Failed to parse JWT: %v", err)
}
// This should detect replay for tokens with JTI
err = jwt.Verify("https://test-issuer.com", "test-client-id", false)
if err == nil {
t.Errorf("Expected replay detection for token type %s with JTI", tc.tokenType)
} else if !strings.Contains(err.Error(), "token replay detected") {
t.Errorf("Expected 'token replay detected' error for token type %s, got: %v", tc.tokenType, err)
}
}
})
}
}
+130 -15
View File
@@ -194,6 +194,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
idTokenChunks: make(map[int]*sessions.Session),
refreshMutex: sync.Mutex{}, // Initialize the mutex
sessionMutex: sync.RWMutex{}, // Initialize the session mutex
dirty: false, // Initialize dirty flag
@@ -280,10 +281,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
for k := range sessionData.refreshTokenChunks {
delete(sessionData.refreshTokenChunks, k)
}
for k := range sessionData.idTokenChunks {
delete(sessionData.idTokenChunks, k)
}
// Retrieve chunked token sessions.
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
sm.getTokenChunkSessions(r, mainCookieName, sessionData.idTokenChunks)
return sessionData, nil
}
@@ -335,6 +340,10 @@ type SessionData struct {
// when it exceeds the maximum cookie size.
refreshTokenChunks map[int]*sessions.Session
// idTokenChunks stores additional chunks of the ID token
// when it exceeds the maximum cookie size.
idTokenChunks map[int]*sessions.Session
// refreshMutex protects refresh token operations within this session instance.
refreshMutex sync.Mutex
@@ -420,6 +429,12 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i))
}
// Save ID token chunks.
for i, sessionChunk := range sd.idTokenChunks {
sessionChunk.Options = options
saveOrLogError(sessionChunk, fmt.Sprintf("ID token chunk %d", i))
}
if firstErr == nil {
sd.dirty = false // Reset dirty flag only if all saves were successful
}
@@ -467,6 +482,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear chunk sessions.
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
sd.clearTokenChunks(r, sd.idTokenChunks)
// Create a guaranteed error when the response writer is set
// This is primarily for testing - in production w will often be nil
@@ -648,6 +664,9 @@ func (sd *SessionData) Reset() {
for k := range sd.refreshTokenChunks {
delete(sd.refreshTokenChunks, k)
}
for k := range sd.idTokenChunks {
delete(sd.idTokenChunks, k)
}
// Reset state flags
sd.dirty = false
@@ -926,6 +945,30 @@ func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
}
}
// expireIDTokenChunks finds all existing ID token chunk cookies (_oidc_raczylo_N)
// associated with the current request, clears their values, and sets their MaxAge to -1.
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
// the expiring Set-Cookie headers. This is used internally when setting a new ID token.
//
// Parameters:
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
func (sd *SessionData) expireIDTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", mainCookieName, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired ID token cookie: %v", err)
}
}
}
}
// splitIntoChunks divides a string `s` into a slice of strings, where each element
// has a maximum length of `chunkSize`.
//
@@ -1077,6 +1120,14 @@ func (sd *SessionData) SetIncomingPath(path string) {
// Returns:
// - The complete, decompressed ID token string, or an empty string if not found.
func (sd *SessionData) GetIDToken() string {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
return sd.getIDTokenUnsafe()
}
// getIDTokenUnsafe is the internal implementation without mutex protection
func (sd *SessionData) getIDTokenUnsafe() string {
token, _ := sd.mainSession.Values["id_token"].(string)
if token != "" {
compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool)
@@ -1085,33 +1136,97 @@ func (sd *SessionData) GetIDToken() string {
}
return token
}
return ""
// Reassemble token from chunks.
if len(sd.idTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.idTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["id_token_chunk"].(string)
chunks = append(chunks, chunk)
}
token = strings.Join(chunks, "")
compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// SetIDToken stores the provided ID token in the session.
// It first expires any existing ID token chunk cookies.
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
// it's stored directly in the primary main session. Otherwise, the compressed token
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_0, _oidc_raczylo_1, etc.).
//
// Parameters:
// - token: The ID token string to store.
func (sd *SessionData) SetIDToken(token string) {
currentIDToken := sd.GetIDToken() // Gets fully reassembled, decompressed token
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
currentIDToken := sd.getIDTokenUnsafe()
if currentIDToken == token {
// This handles cases where token is "" and currentIDToken is also "", no change.
// Or token is "abc" and currentIDToken is "abc", no change.
// If token is empty, and current is also empty, it's not a change.
// This check handles both empty and non-empty identical cases.
return
}
sd.dirty = true
// Expire any existing chunk cookies first.
if sd.request != nil {
sd.expireIDTokenChunks(nil) // Will be saved when Save() is called.
}
// Clear and prepare chunks map for new token.
sd.idTokenChunks = make(map[int]*sessions.Session)
if token == "" { // Clearing the token
// STABILITY FIX: Add nil checks before accessing session values
if sd.mainSession != nil {
sd.mainSession.Values["id_token"] = ""
sd.mainSession.Values["id_token_compressed"] = false
}
// sd.idTokenChunks is already cleared
return
}
sd.dirty = true // Mark as dirty because a change is being made
if token == "" {
sd.mainSession.Values["id_token"] = ""
sd.mainSession.Values["id_token_compressed"] = false
return
}
// Compress token
// Compress token.
compressed := compressToken(token)
sd.mainSession.Values["id_token"] = compressed
sd.mainSession.Values["id_token_compressed"] = true
if len(compressed) <= maxCookieSize {
// STABILITY FIX: Add nil checks before accessing session values
if sd.mainSession != nil {
sd.mainSession.Values["id_token"] = compressed
sd.mainSession.Values["id_token_compressed"] = true
}
} else {
// Split compressed token into chunks.
if sd.mainSession != nil {
sd.mainSession.Values["id_token"] = "" // Main cookie won't hold the token directly
sd.mainSession.Values["id_token_compressed"] = true // Data in chunks is compressed
}
chunks := splitIntoChunks(compressed, maxCookieSize)
for i, chunkData := range chunks {
sessionName := fmt.Sprintf("%s_%d", mainCookieName, i)
// Ensure sd.request is available, otherwise log warning or handle error
if sd.request == nil {
sd.manager.logger.Infof("SetIDToken: sd.request is nil, cannot get/create chunk session %s", sessionName)
// Potentially skip this chunk or error out, depending on desired robustness
continue
}
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["id_token_chunk"] = chunkData
sd.idTokenChunks[i] = session
}
}
}
// GetRedirectCount retrieves the current redirect count from the session.
+164
View File
@@ -1,6 +1,9 @@
package traefikoidc
import (
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
@@ -218,4 +221,165 @@ func TestSessionObjectTracking(t *testing.T) {
t.Log("Session pool handling verified")
}
// TestLargeIDTokenChunking tests that large ID tokens are properly chunked across multiple cookies
func TestLargeIDTokenChunking(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Create a large ID token (>4KB) to force chunking
largeIDToken := createLargeIDToken(20000) // 20KB token to ensure chunking after compression
t.Logf("Created large ID token with length: %d", len(largeIDToken))
// Create a request and response recorder
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rr := httptest.NewRecorder()
// Get session and set large ID token
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set the large ID token
session.SetIDToken(largeIDToken)
t.Logf("Set large ID token in session")
// Let's check what the GetIDToken returns to confirm it's set
retrievedToken := session.GetIDToken()
t.Logf("Retrieved ID token length: %d", len(retrievedToken))
if len(retrievedToken) != len(largeIDToken) {
t.Errorf("Token length mismatch: expected %d, got %d", len(largeIDToken), len(retrievedToken))
}
// Let's check what's in the main session directly
if idToken, ok := session.mainSession.Values["id_token"].(string); ok {
t.Logf("Main session id_token length: %d", len(idToken))
if compressed, ok := session.mainSession.Values["id_token_compressed"].(bool); ok {
t.Logf("Main session id_token_compressed: %v", compressed)
}
} else {
t.Logf("Main session id_token not found or not a string")
}
// Save the session to trigger chunking
err = session.Save(req, rr)
if err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify that chunked cookies were created
cookies := rr.Result().Cookies()
t.Logf("Total cookies in response: %d", len(cookies))
for _, cookie := range cookies {
valuePreview := cookie.Value
if len(valuePreview) > 50 {
valuePreview = valuePreview[:50] + "..."
}
t.Logf("Cookie: %s = %s (len=%d)", cookie.Name, valuePreview, len(cookie.Value))
}
var mainCookie *http.Cookie
var chunkCookies []*http.Cookie
for _, cookie := range cookies {
if cookie.Name == mainCookieName {
mainCookie = cookie
} else if strings.HasPrefix(cookie.Name, mainCookieName+"_") {
chunkCookies = append(chunkCookies, cookie)
}
}
// Verify main cookie exists
if mainCookie == nil {
t.Fatal("Main cookie not found in response")
}
// Verify chunk cookies exist (should be at least 2 for a 5KB token)
if len(chunkCookies) < 2 {
t.Fatalf("Expected at least 2 chunk cookies, got %d", len(chunkCookies))
}
// Verify chunk cookie naming convention
expectedChunkNames := make(map[string]bool)
for i := 0; i < len(chunkCookies); i++ {
expectedChunkNames[mainCookieName+"_"+fmt.Sprintf("%d", i)] = true
}
for _, cookie := range chunkCookies {
if !expectedChunkNames[cookie.Name] {
t.Errorf("Unexpected chunk cookie name: %s", cookie.Name)
}
}
// Test token retrieval from chunked cookies
// Create a new request with all the cookies
newReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Get session and retrieve the ID token
retrievedSession, err := sm.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get session from chunked cookies: %v", err)
}
retrievedToken2 := retrievedSession.GetIDToken()
// Verify the retrieved token matches the original
if retrievedToken2 != largeIDToken {
t.Errorf("Retrieved ID token doesn't match original. Expected length: %d, got: %d", len(largeIDToken), len(retrievedToken2))
}
// Test clearing the ID token removes all chunks
retrievedSession.SetIDToken("")
clearRR := httptest.NewRecorder()
err = retrievedSession.Save(newReq, clearRR)
if err != nil {
t.Fatalf("Failed to save session after clearing ID token: %v", err)
}
// Verify chunks are expired (MaxAge = -1)
clearCookies := clearRR.Result().Cookies()
for _, cookie := range clearCookies {
if strings.HasPrefix(cookie.Name, mainCookieName+"_") {
if cookie.MaxAge != -1 {
t.Errorf("Expected chunk cookie %s to be expired (MaxAge=-1), got MaxAge=%d", cookie.Name, cookie.MaxAge)
}
}
}
}
// createLargeIDToken creates a JWT-like token of specified size for testing
func createLargeIDToken(size int) string {
// Create truly random data that won't compress well
randomBytes := make([]byte, size*3/4) // base64 encoding increases size by ~4/3
_, err := rand.Read(randomBytes)
if err != nil {
// Fallback to pseudo-random if crypto/rand fails
for i := range randomBytes {
randomBytes[i] = byte(i % 256)
}
}
// Base64 encode the random data to make it look like a JWT
encoded := base64.StdEncoding.EncodeToString(randomBytes)
// Create JWT-like structure with truly random data
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
// Truncate or pad to desired size
if len(encoded) > size-len(header)-100 {
encoded = encoded[:size-len(header)-100]
}
signature := "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
return header + "." + encoded + "." + signature
}
// This is intentionally left empty to remove unused code