mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
3 Commits
v0.6.0
...
v0.7.0-beta4
| Author | SHA1 | Date | |
|---|---|---|---|
| 784b161732 | |||
| efa0cd708b | |||
| 99881f5837 |
+18
-13
@@ -123,19 +123,24 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
data.Set("refresh_token", codeOrToken)
|
||||
}
|
||||
|
||||
// Create a cookie jar for this request to handle redirects with cookies
|
||||
jar, _ := cookiejar.New(nil)
|
||||
client := &http.Client{
|
||||
Transport: t.httpClient.Transport,
|
||||
Timeout: t.httpClient.Timeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Jar: jar,
|
||||
// Use the reusable token HTTP client, fallback to creating one if not initialized
|
||||
client := t.tokenHTTPClient
|
||||
if client == nil {
|
||||
// Fallback for tests or incomplete initialization - create a temporary client
|
||||
// with the same behavior as the original implementation
|
||||
jar, _ := cookiejar.New(nil)
|
||||
client = &http.Client{
|
||||
Transport: t.httpClient.Transport,
|
||||
Timeout: t.httpClient.Timeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Jar: jar,
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||
|
||||
@@ -17,26 +17,14 @@ import (
|
||||
|
||||
var (
|
||||
replayCacheMu sync.Mutex
|
||||
replayCache = make(map[string]time.Time)
|
||||
replayCache *Cache // Replace unbounded map with bounded Cache
|
||||
)
|
||||
|
||||
// cleanupReplayCache iterates through the replay cache and removes entries
|
||||
// whose expiration time is before the current time. This function should be
|
||||
// called periodically to prevent the cache from growing indefinitely.
|
||||
// It acquires a mutex to ensure thread safety during cleanup.
|
||||
// SECURITY FIX: Add proper locking protection for cleanupReplayCache
|
||||
func cleanupReplayCache() {
|
||||
now := time.Now()
|
||||
// SECURITY FIX: Use safe iteration with proper locking
|
||||
toDelete := make([]string, 0)
|
||||
for token, expiry := range replayCache {
|
||||
if expiry.Before(now) {
|
||||
toDelete = append(toDelete, token)
|
||||
}
|
||||
}
|
||||
// Delete expired entries
|
||||
for _, token := range toDelete {
|
||||
delete(replayCache, token)
|
||||
// initReplayCache initializes the global replay cache with size limit
|
||||
func initReplayCache() {
|
||||
if replayCache == nil {
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000) // Set size limit to 10,000 entries
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,11 +123,12 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
// Parameters:
|
||||
// - issuerURL: The expected issuer URL (e.g., "https://accounts.google.com").
|
||||
// - clientID: The expected audience value (the client ID of this application).
|
||||
// - skipReplayCheck: If true, skips JTI replay detection (used for revalidation of cached tokens).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if all standard claims are valid.
|
||||
// - An error describing the first validation failure encountered.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error {
|
||||
// Validate algorithm to prevent algorithm switching attacks
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
if !ok {
|
||||
@@ -195,7 +184,10 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
}
|
||||
|
||||
// Implement replay protection by checking the jti (JWT ID)
|
||||
if jti, ok := claims["jti"].(string); ok {
|
||||
// Skip replay check if explicitly requested (for revalidation scenarios)
|
||||
shouldSkipReplay := len(skipReplayCheck) > 0 && skipReplayCheck[0]
|
||||
|
||||
if jti, ok := claims["jti"].(string); ok && !shouldSkipReplay {
|
||||
// Skip replay detection for tokens that are being verified from the cache
|
||||
if j.Token == "" {
|
||||
// This is a parsed JWT without the original token string,
|
||||
@@ -203,15 +195,15 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SECURITY FIX: Implement thread-safe replay cache operations with proper locking
|
||||
// SECURITY FIX: Use bounded Cache with thread-safe operations
|
||||
replayCacheMu.Lock()
|
||||
defer replayCacheMu.Unlock() // Ensure unlock happens even if panic occurs
|
||||
defer replayCacheMu.Unlock()
|
||||
|
||||
// SECURITY FIX: Clean up expired entries safely
|
||||
cleanupReplayCache()
|
||||
// Initialize cache if not already done
|
||||
initReplayCache()
|
||||
|
||||
// SECURITY FIX: Check for replay attack with atomic operation
|
||||
if _, exists := replayCache[jti]; exists {
|
||||
// SECURITY FIX: Check for replay attack using Cache API
|
||||
if _, exists := replayCache.Get(jti); exists {
|
||||
return fmt.Errorf("token replay detected")
|
||||
}
|
||||
|
||||
@@ -224,8 +216,11 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
expTime = time.Now().Add(10 * time.Minute)
|
||||
}
|
||||
|
||||
// SECURITY FIX: Add to replay cache atomically
|
||||
replayCache[jti] = expTime
|
||||
// SECURITY FIX: Add to replay cache with expiration using Cache API
|
||||
duration := time.Until(expTime)
|
||||
if duration > 0 {
|
||||
replayCache.Set(jti, true, duration)
|
||||
}
|
||||
}
|
||||
|
||||
sub, ok := claims["sub"].(string)
|
||||
|
||||
@@ -9,9 +9,11 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
@@ -59,6 +61,33 @@ func createDefaultHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
// createTokenHTTPClient creates a specialized HTTP client for token operations.
|
||||
// It reuses the transport from the main HTTP client but adds cookie jar support
|
||||
// and optimized redirect handling for OIDC token endpoints.
|
||||
//
|
||||
// Parameters:
|
||||
// - baseClient: The base HTTP client to derive transport settings from.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the configured http.Client optimized for token operations.
|
||||
func createTokenHTTPClient(baseClient *http.Client) *http.Client {
|
||||
// Create a cookie jar for handling redirects with cookies
|
||||
jar, _ := cookiejar.New(nil)
|
||||
|
||||
return &http.Client{
|
||||
Transport: baseClient.Transport, // Reuse the transport from base client
|
||||
Timeout: baseClient.Timeout, // Reuse the timeout from base client
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Jar: jar, // Add cookie jar for redirect handling
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
ConstSessionTimeout = 86400 // Session timeout in seconds
|
||||
defaultBlacklistDuration = 24 * time.Hour // Default duration to blacklist a JTI
|
||||
@@ -105,6 +134,7 @@ type TraefikOidc struct {
|
||||
scheme string
|
||||
tokenCache *TokenCache
|
||||
httpClient *http.Client
|
||||
tokenHTTPClient *http.Client // Reusable HTTP client for token operations
|
||||
logger *Logger
|
||||
tokenVerifier TokenVerifier
|
||||
jwtVerifier JWTVerifier
|
||||
@@ -124,6 +154,7 @@ type TraefikOidc struct {
|
||||
headerTemplates map[string]*template.Template // Parsed templates for custom headers
|
||||
tokenCleanupStopChan chan struct{} // Channel to stop token cleanup goroutine
|
||||
metadataRefreshStopChan chan struct{} // Channel to stop metadata refresh goroutine
|
||||
goroutineWG sync.WaitGroup // WaitGroup to track background goroutines
|
||||
}
|
||||
|
||||
// ProviderMetadata holds OIDC provider metadata
|
||||
@@ -243,7 +274,14 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
|
||||
// Also update the global replayCache for backwards compatibility
|
||||
replayCacheMu.Lock()
|
||||
replayCache[jti] = expiry
|
||||
// Initialize cache if not already done
|
||||
if replayCache == nil {
|
||||
initReplayCache()
|
||||
}
|
||||
duration := time.Until(expiry)
|
||||
if duration > 0 {
|
||||
replayCache.Set(jti, true, duration)
|
||||
}
|
||||
replayCacheMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -325,8 +363,8 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
// Verify standard claims
|
||||
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
|
||||
// Verify standard claims - skip replay check since it's already handled in VerifyToken
|
||||
if err := jwt.Verify(t.issuerURL, t.clientID, true); err != nil {
|
||||
return fmt.Errorf("standard claim verification failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -416,6 +454,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: NewTokenCache(),
|
||||
httpClient: httpClient,
|
||||
tokenHTTPClient: createTokenHTTPClient(httpClient),
|
||||
excludedURLs: createStringMap(config.ExcludedURLs),
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
|
||||
@@ -526,30 +565,34 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
// - providerURL: The base URL of bogged OIDC provider, used for subsequent refresh attempts.
|
||||
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
// No defer ticker.Stop() here, it's stopped in the select case
|
||||
t.goroutineWG.Add(1) // Track this goroutine
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.logger.Debug("Refreshing OIDC metadata")
|
||||
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to refresh metadata: %v", err)
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
defer t.goroutineWG.Done() // Signal completion when goroutine exits
|
||||
defer ticker.Stop() // Ensure ticker is always stopped
|
||||
|
||||
if metadata != nil {
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
t.logger.Debug("Successfully refreshed metadata")
|
||||
} else {
|
||||
t.logger.Error("Received nil metadata during refresh")
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.logger.Debug("Refreshing OIDC metadata")
|
||||
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to refresh metadata: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
t.logger.Debug("Successfully refreshed metadata")
|
||||
} else {
|
||||
t.logger.Error("Received nil metadata during refresh")
|
||||
}
|
||||
case <-t.metadataRefreshStopChan:
|
||||
t.logger.Debug("Metadata refresh goroutine stopped.")
|
||||
return
|
||||
}
|
||||
case <-t.metadataRefreshStopChan:
|
||||
ticker.Stop()
|
||||
t.logger.Debug("Metadata refresh goroutine stopped.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// discoverProviderMetadata attempts to fetch the OIDC provider's configuration from its
|
||||
@@ -1720,8 +1763,11 @@ func (t *TraefikOidc) validateHost(host string) error {
|
||||
// the token cache, token blacklist cache, and JWK cache.
|
||||
func (t *TraefikOidc) startTokenCleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
|
||||
t.goroutineWG.Add(1) // Track this goroutine
|
||||
go func() {
|
||||
// No defer ticker.Stop() here, it's stopped in the select case
|
||||
defer t.goroutineWG.Done() // Signal completion when goroutine exits
|
||||
defer ticker.Stop() // Ensure ticker is always stopped
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
@@ -1741,7 +1787,6 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
t.jwkCache.Cleanup()
|
||||
}
|
||||
case <-t.tokenCleanupStopChan:
|
||||
ticker.Stop()
|
||||
t.logger.Debug("Token cleanup goroutine stopped.")
|
||||
return
|
||||
}
|
||||
@@ -2231,20 +2276,36 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
|
||||
_, _ = rw.Write([]byte(htmlBody)) // Ignore write error as header is already sent
|
||||
}
|
||||
|
||||
// Close stops all background goroutines and closes resources.
|
||||
// Close stops all background goroutines and closes resources with proper timeout.
|
||||
func (t *TraefikOidc) Close() error {
|
||||
t.logger.Debug("Closing TraefikOidc plugin instance")
|
||||
// Signal and close tokenCleanupStopChan
|
||||
|
||||
// Signal all goroutines to stop
|
||||
if t.tokenCleanupStopChan != nil {
|
||||
close(t.tokenCleanupStopChan)
|
||||
t.logger.Debug("tokenCleanupStopChan closed")
|
||||
}
|
||||
// Signal and close metadataRefreshStopChan
|
||||
if t.metadataRefreshStopChan != nil {
|
||||
close(t.metadataRefreshStopChan)
|
||||
t.logger.Debug("metadataRefreshStopChan closed")
|
||||
}
|
||||
|
||||
// Wait for all goroutines to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
t.goroutineWG.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Wait for goroutines to finish or timeout after 10 seconds
|
||||
select {
|
||||
case <-done:
|
||||
t.logger.Debug("All background goroutines stopped gracefully")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.logger.Errorf("Timeout waiting for background goroutines to stop")
|
||||
// Continue with cleanup even if goroutines didn't stop gracefully
|
||||
}
|
||||
|
||||
// Close caches
|
||||
// These Close methods should stop their respective autoCleanupRoutine goroutines
|
||||
if t.tokenBlacklist != nil {
|
||||
|
||||
+764
-3
@@ -722,7 +722,8 @@ func TestServeHTTP(t *testing.T) {
|
||||
|
||||
// Reset the global replayCache to prevent "token replay detected" errors
|
||||
replayCacheMu.Lock()
|
||||
replayCache = make(map[string]time.Time) // Reset the global cache
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Store original tokenVerifier to restore later
|
||||
@@ -734,7 +735,8 @@ func TestServeHTTP(t *testing.T) {
|
||||
VerifyFunc: func(token string) error {
|
||||
// Clear replay cache before token verification
|
||||
replayCacheMu.Lock()
|
||||
replayCache = make(map[string]time.Time)
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Call the original verifier's VerifyToken method
|
||||
@@ -1143,7 +1145,8 @@ func TestHandleCallback(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Clear the global replay cache before each test run
|
||||
replayCacheMu.Lock()
|
||||
replayCache = make(map[string]time.Time) // Reset the global cache
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Explicitly clear the shared blacklist at the start of each sub-test
|
||||
@@ -2803,3 +2806,761 @@ func TestVerifyTimeConstraint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
} // Add missing closing brace for TestVerifyTimeConstraint
|
||||
|
||||
// ===== JWT REPLAY DETECTION TESTS =====
|
||||
// These tests ensure the replay detection fix works correctly and prevents regressions
|
||||
|
||||
// TestJWTVerifyWithSkipReplayCheck tests the new skipReplayCheck parameter functionality
|
||||
func TestJWTVerifyWithSkipReplayCheck(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Clear the global replay cache before test
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Create a test JWT with unique JTI
|
||||
jti := generateRandomString(16)
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse JWT: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
skipReplayCheck bool
|
||||
firstCall bool
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "First verification with skipReplayCheck=false should succeed",
|
||||
skipReplayCheck: false,
|
||||
firstCall: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Second verification with skipReplayCheck=false should fail (replay detected)",
|
||||
skipReplayCheck: false,
|
||||
firstCall: false,
|
||||
expectError: true,
|
||||
errorContains: "token replay detected",
|
||||
},
|
||||
{
|
||||
name: "Verification with skipReplayCheck=true should always succeed",
|
||||
skipReplayCheck: true,
|
||||
firstCall: false, // Even on subsequent calls
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.firstCall {
|
||||
// Clear replay cache for first call tests
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
}
|
||||
|
||||
err := jwt.Verify("https://test-issuer.com", "test-client-id", tc.skipReplayCheck)
|
||||
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', but got nil", tc.errorContains)
|
||||
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
||||
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTVerifyBackwardCompatibility tests that calls without the skipReplayCheck parameter default to replay checking
|
||||
func TestJWTVerifyBackwardCompatibility(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Clear the global replay cache
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Create a test JWT with unique JTI
|
||||
jti := generateRandomString(16)
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse JWT: %v", err)
|
||||
}
|
||||
|
||||
// First call with old signature (no skipReplayCheck parameter) should succeed
|
||||
err = jwt.Verify("https://test-issuer.com", "test-client-id")
|
||||
if err != nil {
|
||||
t.Errorf("First verification should succeed, got: %v", err)
|
||||
}
|
||||
|
||||
// Second call with old signature should fail due to replay detection
|
||||
err = jwt.Verify("https://test-issuer.com", "test-client-id")
|
||||
if err == nil {
|
||||
t.Error("Second verification should fail due to replay detection")
|
||||
} else if !strings.Contains(err.Error(), "token replay detected") {
|
||||
t.Errorf("Expected 'token replay detected' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenReplayDetectionFalsePositiveFix tests the specific scenario that was causing false positives
|
||||
func TestTokenReplayDetectionFalsePositiveFix(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Clear the global replay cache
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Create a test JWT with unique JTI
|
||||
jti := generateRandomString(16)
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// Simulate the authentication flow that was causing false positives:
|
||||
// 1. Initial authentication adds JTI to cache
|
||||
// 2. Subsequent request validation should not trigger false positive
|
||||
|
||||
// Step 1: Initial authentication (this would add JTI to cache)
|
||||
jwt1, err := parseJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse JWT for initial auth: %v", err)
|
||||
}
|
||||
|
||||
err = jwt1.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check
|
||||
if err != nil {
|
||||
t.Fatalf("Initial authentication should succeed: %v", err)
|
||||
}
|
||||
|
||||
// Step 2: Subsequent request validation (this should skip replay check to avoid false positive)
|
||||
jwt2, err := parseJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse JWT for subsequent request: %v", err)
|
||||
}
|
||||
|
||||
err = jwt2.Verify("https://test-issuer.com", "test-client-id", true) // Skip replay check
|
||||
if err != nil {
|
||||
t.Errorf("Subsequent request validation should succeed with skipReplayCheck=true: %v", err)
|
||||
}
|
||||
|
||||
// Step 3: Verify that actual replay attacks are still detected
|
||||
jwt3, err := parseJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse JWT for replay attack test: %v", err)
|
||||
}
|
||||
|
||||
err = jwt3.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check
|
||||
if err == nil {
|
||||
t.Error("Actual replay attack should be detected when skipReplayCheck=false")
|
||||
} else if !strings.Contains(err.Error(), "token replay detected") {
|
||||
t.Errorf("Expected 'token replay detected' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthenticationFlowReplayDetection tests the complete authentication flow
|
||||
func TestAuthenticationFlowReplayDetection(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Clear the global replay cache
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Create a test JWT with unique JTI
|
||||
jti := generateRandomString(16)
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// Test the complete flow:
|
||||
// 1. Initial authentication (should add JTI to cache)
|
||||
// 2. Multiple subsequent requests (should not trigger false positives)
|
||||
// 3. Actual replay attack from different source (should be detected)
|
||||
|
||||
// Step 1: Initial authentication
|
||||
err = ts.tOidc.VerifyToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Initial authentication should succeed: %v", err)
|
||||
}
|
||||
|
||||
// Verify JTI is in cache
|
||||
replayCacheMu.Lock()
|
||||
_, exists := replayCache.Get(jti)
|
||||
replayCacheMu.Unlock()
|
||||
if !exists {
|
||||
t.Error("JTI should be added to replay cache during initial authentication")
|
||||
}
|
||||
|
||||
// Step 2: Subsequent requests (simulate normal request processing)
|
||||
// These should use the token cache and skip replay detection
|
||||
for i := 0; i < 3; i++ {
|
||||
err = ts.tOidc.VerifyToken(token)
|
||||
if err != nil {
|
||||
t.Errorf("Subsequent request %d should succeed: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Simulate actual replay attack by directly calling JWT.Verify with replay check
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse JWT for replay attack test: %v", err)
|
||||
}
|
||||
|
||||
err = jwt.Verify("https://test-issuer.com", "test-client-id", false) // Force replay check
|
||||
if err == nil {
|
||||
t.Error("Actual replay attack should be detected")
|
||||
} else if !strings.Contains(err.Error(), "token replay detected") {
|
||||
t.Errorf("Expected 'token replay detected' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestActualReplayAttackDetection ensures real replay attacks are still properly detected
|
||||
func TestActualReplayAttackDetection(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Clear the global replay cache
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Create a test JWT with unique JTI
|
||||
jti := generateRandomString(16)
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse JWT: %v", err)
|
||||
}
|
||||
|
||||
// First verification should succeed
|
||||
err = jwt.Verify("https://test-issuer.com", "test-client-id", false)
|
||||
if err != nil {
|
||||
t.Fatalf("First verification should succeed: %v", err)
|
||||
}
|
||||
|
||||
// Simulate different types of replay attacks
|
||||
replayTests := []struct {
|
||||
name string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Direct replay attack",
|
||||
description: "Same token used again with replay checking enabled",
|
||||
},
|
||||
{
|
||||
name: "Replay from different source",
|
||||
description: "Token intercepted and replayed by attacker",
|
||||
},
|
||||
}
|
||||
|
||||
for _, rt := range replayTests {
|
||||
t.Run(rt.name, func(t *testing.T) {
|
||||
// Parse token again (simulating replay)
|
||||
replayJWT, err := parseJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse JWT for replay test: %v", err)
|
||||
}
|
||||
|
||||
// Attempt replay with normal replay checking
|
||||
err = replayJWT.Verify("https://test-issuer.com", "test-client-id", false)
|
||||
if err == nil {
|
||||
t.Errorf("Replay attack should be detected for: %s", rt.description)
|
||||
} else if !strings.Contains(err.Error(), "token replay detected") {
|
||||
t.Errorf("Expected 'token replay detected' error for %s, got: %v", rt.description, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentTokenValidation tests thread safety of replay detection
|
||||
func TestConcurrentTokenValidation(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Configure rate limiter to allow more requests for concurrent testing
|
||||
ts.tOidc.limiter = rate.NewLimiter(rate.Limit(1000), 1000) // Allow 1000 requests per second with burst of 1000
|
||||
|
||||
// Clear the global replay cache
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Create multiple tokens with unique JTIs
|
||||
var tokens []string
|
||||
var jtis []string
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
jti := generateRandomString(16)
|
||||
jtis = append(jtis, jti)
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT %d: %v", i, err)
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
// Test concurrent validation
|
||||
const numGoroutines = 20
|
||||
const numIterations = 5
|
||||
|
||||
results := make(chan error, numGoroutines*numIterations)
|
||||
|
||||
for g := 0; g < numGoroutines; g++ {
|
||||
go func(goroutineID int) {
|
||||
for i := 0; i < numIterations; i++ {
|
||||
tokenIndex := (goroutineID + i) % len(tokens)
|
||||
token := tokens[tokenIndex]
|
||||
|
||||
// First validation should succeed
|
||||
err := ts.tOidc.VerifyToken(token)
|
||||
results <- err
|
||||
|
||||
// Subsequent validation with same token should also succeed (uses cache)
|
||||
err = ts.tOidc.VerifyToken(token)
|
||||
results <- err
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
|
||||
// Collect results
|
||||
var errors []error
|
||||
for i := 0; i < numGoroutines*numIterations*2; i++ {
|
||||
if err := <-results; err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
// All validations should succeed (no race conditions)
|
||||
if len(errors) > 0 {
|
||||
t.Errorf("Expected no errors in concurrent validation, got %d errors: %v", len(errors), errors)
|
||||
}
|
||||
|
||||
// Verify all JTIs are in cache
|
||||
replayCacheMu.Lock()
|
||||
for i, jti := range jtis {
|
||||
if _, exists := replayCache.Get(jti); !exists {
|
||||
t.Errorf("JTI %d (%s) should be in replay cache", i, jti)
|
||||
}
|
||||
}
|
||||
replayCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// TestJTIBlacklistBehavior tests the JTI blacklist cache management
|
||||
func TestJTIBlacklistBehavior(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Clear the global replay cache
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Create a test JWT with unique JTI
|
||||
jti := generateRandomString(16)
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// Test JTI blacklist behavior
|
||||
tests := []struct {
|
||||
name string
|
||||
action func() error
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Initial verification adds JTI to blacklist",
|
||||
action: func() error {
|
||||
return ts.tOidc.VerifyToken(token)
|
||||
},
|
||||
expectError: false,
|
||||
description: "First verification should succeed and add JTI to blacklist",
|
||||
},
|
||||
{
|
||||
name: "JTI exists in blacklist after verification",
|
||||
action: func() error {
|
||||
replayCacheMu.Lock()
|
||||
defer replayCacheMu.Unlock()
|
||||
if _, exists := replayCache.Get(jti); !exists {
|
||||
return fmt.Errorf("JTI not found in blacklist cache")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
expectError: false,
|
||||
description: "JTI should be present in blacklist cache",
|
||||
},
|
||||
{
|
||||
name: "Subsequent verification uses cache (no replay check)",
|
||||
action: func() error {
|
||||
return ts.tOidc.VerifyToken(token)
|
||||
},
|
||||
expectError: false,
|
||||
description: "Subsequent verification should succeed using token cache",
|
||||
},
|
||||
{
|
||||
name: "Direct JWT verification detects replay",
|
||||
action: func() error {
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return jwt.Verify("https://test-issuer.com", "test-client-id", false)
|
||||
},
|
||||
expectError: true,
|
||||
description: "Direct JWT verification should detect replay",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.action()
|
||||
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for %s, but got nil", tc.description)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for %s, but got: %v", tc.description, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionBasedTokenRevalidation tests token revalidation in session-based scenarios
|
||||
func TestSessionBasedTokenRevalidation(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Clear the global replay cache
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Create a test JWT with unique JTI
|
||||
jti := generateRandomString(16)
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// Simulate session-based token revalidation scenario
|
||||
// This tests the specific case that was causing false positives
|
||||
|
||||
// Step 1: Initial authentication (callback processing)
|
||||
err = ts.tOidc.VerifyToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Initial authentication should succeed: %v", err)
|
||||
}
|
||||
|
||||
// Step 2: Multiple session-based requests (normal request processing)
|
||||
// These should not trigger replay detection false positives
|
||||
for i := 0; i < 5; i++ {
|
||||
err = ts.tOidc.VerifyToken(token)
|
||||
if err != nil {
|
||||
t.Errorf("Session request %d should succeed: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Verify token is in both caches appropriately
|
||||
// Check token cache
|
||||
if _, exists := ts.tOidc.tokenCache.Get(token); !exists {
|
||||
t.Error("Token should be in token cache")
|
||||
}
|
||||
|
||||
// Check replay cache
|
||||
replayCacheMu.Lock()
|
||||
_, inReplayCache := replayCache.Get(jti)
|
||||
replayCacheMu.Unlock()
|
||||
if !inReplayCache {
|
||||
t.Error("JTI should be in replay cache")
|
||||
}
|
||||
|
||||
// Step 4: Verify that clearing token cache still allows validation
|
||||
ts.tOidc.tokenCache = NewTokenCache() // Clear token cache
|
||||
|
||||
err = ts.tOidc.VerifyToken(token)
|
||||
if err != nil {
|
||||
t.Errorf("Token validation should succeed even after cache clear: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEdgeCasesWithDifferentTokenTypes tests replay detection with different token types
|
||||
func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Clear the global replay cache
|
||||
replayCacheMu.Lock()
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenType string
|
||||
claims map[string]interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "ID Token with JTI",
|
||||
tokenType: "id_token",
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": generateRandomString(16),
|
||||
"token_type": "id_token",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Access Token with JTI",
|
||||
tokenType: "access_token",
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"scope": "openid profile email",
|
||||
"jti": generateRandomString(16),
|
||||
"token_type": "access_token",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Token without JTI",
|
||||
tokenType: "no_jti",
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
// No JTI claim
|
||||
},
|
||||
expectError: false, // Should still work, just no replay protection
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create token with specific claims
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// First verification should succeed
|
||||
err = ts.tOidc.VerifyToken(token)
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for token type %s, but got nil", tc.tokenType)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for token type %s, but got: %v", tc.tokenType, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Second verification should also succeed (uses cache)
|
||||
if !tc.expectError {
|
||||
err = ts.tOidc.VerifyToken(token)
|
||||
if err != nil {
|
||||
t.Errorf("Second verification should succeed for token type %s: %v", tc.tokenType, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test direct JWT verification for replay detection
|
||||
if !tc.expectError && tc.claims["jti"] != nil {
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse JWT: %v", err)
|
||||
}
|
||||
|
||||
// This should detect replay for tokens with JTI
|
||||
err = jwt.Verify("https://test-issuer.com", "test-client-id", false)
|
||||
if err == nil {
|
||||
t.Errorf("Expected replay detection for token type %s with JTI", tc.tokenType)
|
||||
} else if !strings.Contains(err.Error(), "token replay detected") {
|
||||
t.Errorf("Expected 'token replay detected' error for token type %s, got: %v", tc.tokenType, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,10 @@ type PerformanceMetrics struct {
|
||||
// Resource metrics
|
||||
memoryUsage int64
|
||||
goroutineCount int64
|
||||
memoryPressure int64 // Memory pressure level (0-100)
|
||||
gcPauseTime int64 // Last GC pause time in nanoseconds
|
||||
heapSize int64 // Current heap size
|
||||
heapInUse int64 // Heap memory in use
|
||||
|
||||
// Error metrics (kept for backward compatibility)
|
||||
verificationErrors int64
|
||||
@@ -251,9 +255,36 @@ func (pm *PerformanceMetrics) collectSystemMetrics() {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
atomic.StoreInt64(&pm.memoryUsage, int64(m.Alloc))
|
||||
atomic.StoreInt64(&pm.heapSize, int64(m.HeapSys))
|
||||
atomic.StoreInt64(&pm.heapInUse, int64(m.HeapInuse))
|
||||
atomic.StoreInt64(&pm.gcPauseTime, int64(m.PauseNs[(m.NumGC+255)%256]))
|
||||
|
||||
// Calculate memory pressure (0-100 scale)
|
||||
// Based on heap utilization and GC frequency
|
||||
heapUtilization := float64(m.HeapInuse) / float64(m.HeapSys)
|
||||
gcFrequency := float64(m.NumGC) / time.Since(pm.startTime).Minutes()
|
||||
|
||||
// Memory pressure calculation
|
||||
pressure := int64(heapUtilization * 50) // 0-50 based on heap utilization
|
||||
if gcFrequency > 10 { // High GC frequency indicates pressure
|
||||
pressure += int64((gcFrequency - 10) * 2) // Add up to 50 more
|
||||
}
|
||||
if pressure > 100 {
|
||||
pressure = 100
|
||||
}
|
||||
atomic.StoreInt64(&pm.memoryPressure, pressure)
|
||||
|
||||
// Goroutine count
|
||||
atomic.StoreInt64(&pm.goroutineCount, int64(runtime.NumGoroutine()))
|
||||
|
||||
// Log memory pressure warnings
|
||||
if pressure > 80 {
|
||||
pm.logger.Errorf("High memory pressure detected: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)",
|
||||
pressure, heapUtilization*100, gcFrequency)
|
||||
} else if pressure > 60 {
|
||||
pm.logger.Infof("Moderate memory pressure: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)",
|
||||
pressure, heapUtilization*100, gcFrequency)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns all current performance metrics
|
||||
@@ -317,6 +348,10 @@ func (pm *PerformanceMetrics) GetMetrics() map[string]interface{} {
|
||||
|
||||
// Resource metrics
|
||||
"memory_usage_bytes": atomic.LoadInt64(&pm.memoryUsage),
|
||||
"memory_pressure": atomic.LoadInt64(&pm.memoryPressure),
|
||||
"heap_size_bytes": atomic.LoadInt64(&pm.heapSize),
|
||||
"heap_inuse_bytes": atomic.LoadInt64(&pm.heapInUse),
|
||||
"gc_pause_time_ns": atomic.LoadInt64(&pm.gcPauseTime),
|
||||
"goroutine_count": atomic.LoadInt64(&pm.goroutineCount),
|
||||
|
||||
// Rate limiting metrics
|
||||
@@ -414,6 +449,10 @@ type ResourceMonitor struct {
|
||||
// Session limits
|
||||
maxSessions int64
|
||||
|
||||
// Cache size tracking
|
||||
cacheSizes map[string]int64
|
||||
cacheMutex sync.RWMutex
|
||||
|
||||
// Monitoring state
|
||||
alertThresholds map[string]float64
|
||||
alerts []ResourceAlert
|
||||
@@ -441,11 +480,13 @@ func NewResourceMonitor(perfMetrics *PerformanceMetrics, logger *Logger) *Resour
|
||||
maxMemoryBytes: 100 * 1024 * 1024, // 100MB default
|
||||
maxCacheSize: 10000, // 10k items default
|
||||
maxSessions: 1000, // 1k sessions default
|
||||
cacheSizes: make(map[string]int64),
|
||||
alertThresholds: map[string]float64{
|
||||
"memory_usage": 0.8, // 80%
|
||||
"cache_usage": 0.9, // 90%
|
||||
"session_usage": 0.85, // 85%
|
||||
"error_rate": 0.1, // 10%
|
||||
"memory_usage": 0.8, // 80%
|
||||
"memory_pressure": 0.7, // 70%
|
||||
"cache_usage": 0.9, // 90%
|
||||
"session_usage": 0.85, // 85%
|
||||
"error_rate": 0.1, // 10%
|
||||
},
|
||||
alerts: make([]ResourceAlert, 0),
|
||||
perfMetrics: perfMetrics,
|
||||
@@ -473,6 +514,25 @@ func (rm *ResourceMonitor) SetSessionLimit(count int64) {
|
||||
rm.maxSessions = count
|
||||
}
|
||||
|
||||
// UpdateCacheSize updates the size of a specific cache
|
||||
func (rm *ResourceMonitor) UpdateCacheSize(cacheName string, size int64) {
|
||||
rm.cacheMutex.Lock()
|
||||
defer rm.cacheMutex.Unlock()
|
||||
rm.cacheSizes[cacheName] = size
|
||||
}
|
||||
|
||||
// GetCacheSizes returns current cache sizes
|
||||
func (rm *ResourceMonitor) GetCacheSizes() map[string]int64 {
|
||||
rm.cacheMutex.RLock()
|
||||
defer rm.cacheMutex.RUnlock()
|
||||
|
||||
sizes := make(map[string]int64)
|
||||
for name, size := range rm.cacheSizes {
|
||||
sizes[name] = size
|
||||
}
|
||||
return sizes
|
||||
}
|
||||
|
||||
// startMonitoring starts the background monitoring routine
|
||||
func (rm *ResourceMonitor) startMonitoring() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
@@ -502,6 +562,21 @@ func (rm *ResourceMonitor) checkResourceUsage() {
|
||||
}
|
||||
}
|
||||
|
||||
// Check memory pressure
|
||||
if memPressure, ok := metrics["memory_pressure"].(int64); ok {
|
||||
pressureRatio := float64(memPressure) / 100.0 // Convert to 0-1 scale
|
||||
if pressureRatio > rm.alertThresholds["memory_pressure"] {
|
||||
rm.addAlert(ResourceAlert{
|
||||
Type: "memory_pressure",
|
||||
Message: "Memory pressure exceeds threshold",
|
||||
Threshold: rm.alertThresholds["memory_pressure"],
|
||||
CurrentValue: pressureRatio,
|
||||
Timestamp: time.Now(),
|
||||
Severity: rm.getSeverity(pressureRatio, rm.alertThresholds["memory_pressure"]),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache usage
|
||||
if cacheSize, ok := metrics["cache_size"].(int64); ok {
|
||||
cacheUsageRatio := float64(cacheSize) / float64(rm.maxCacheSize)
|
||||
@@ -592,6 +667,7 @@ func (rm *ResourceMonitor) GetAlerts() []ResourceAlert {
|
||||
// GetResourceStatus returns current resource status
|
||||
func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
|
||||
metrics := rm.perfMetrics.GetMetrics()
|
||||
cacheSizes := rm.GetCacheSizes()
|
||||
|
||||
status := map[string]interface{}{
|
||||
"limits": map[string]interface{}{
|
||||
@@ -599,8 +675,9 @@ func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
|
||||
"max_cache_size": rm.maxCacheSize,
|
||||
"max_sessions": rm.maxSessions,
|
||||
},
|
||||
"thresholds": rm.alertThresholds,
|
||||
"current": metrics,
|
||||
"thresholds": rm.alertThresholds,
|
||||
"current": metrics,
|
||||
"cache_sizes": cacheSizes,
|
||||
// Add expected keys for tests
|
||||
"memory_limit": uint64(rm.maxMemoryBytes),
|
||||
"cache_limit": int(rm.maxCacheSize),
|
||||
@@ -611,6 +688,9 @@ func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
|
||||
if memUsage, ok := metrics["memory_usage_bytes"].(int64); ok {
|
||||
status["memory_usage_ratio"] = float64(memUsage) / float64(rm.maxMemoryBytes)
|
||||
}
|
||||
if memPressure, ok := metrics["memory_pressure"].(int64); ok {
|
||||
status["memory_pressure_ratio"] = float64(memPressure) / 100.0
|
||||
}
|
||||
if cacheSize, ok := metrics["cache_size"].(int64); ok {
|
||||
status["cache_usage_ratio"] = float64(cacheSize) / float64(rm.maxCacheSize)
|
||||
}
|
||||
@@ -618,5 +698,12 @@ func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
|
||||
status["session_usage_ratio"] = float64(activeSessions) / float64(rm.maxSessions)
|
||||
}
|
||||
|
||||
// Calculate total cache size across all caches
|
||||
var totalCacheSize int64
|
||||
for _, size := range cacheSizes {
|
||||
totalCacheSize += size
|
||||
}
|
||||
status["total_cache_size"] = totalCacheSize
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
+188
-20
@@ -190,13 +190,19 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
|
||||
// Initialize session pool.
|
||||
sm.sessionPool.New = func() interface{} {
|
||||
// Initialize SessionData with necessary fields and the mutex.
|
||||
return &SessionData{
|
||||
sd := &SessionData{
|
||||
manager: sm,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshMutex: sync.Mutex{}, // Initialize the mutex
|
||||
dirty: false, // Initialize dirty flag
|
||||
idTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshMutex: sync.Mutex{}, // Initialize the mutex
|
||||
sessionMutex: sync.RWMutex{}, // Initialize the session mutex
|
||||
dirty: false, // Initialize dirty flag
|
||||
inUse: false, // Initialize in-use flag
|
||||
}
|
||||
// Ensure the object is properly reset when created
|
||||
sd.Reset()
|
||||
return sd
|
||||
}
|
||||
|
||||
return sm, nil
|
||||
@@ -275,10 +281,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
for k := range sessionData.refreshTokenChunks {
|
||||
delete(sessionData.refreshTokenChunks, k)
|
||||
}
|
||||
for k := range sessionData.idTokenChunks {
|
||||
delete(sessionData.idTokenChunks, k)
|
||||
}
|
||||
|
||||
// Retrieve chunked token sessions.
|
||||
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
|
||||
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
|
||||
sm.getTokenChunkSessions(r, mainCookieName, sessionData.idTokenChunks)
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
@@ -330,6 +340,10 @@ type SessionData struct {
|
||||
// when it exceeds the maximum cookie size.
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
|
||||
// idTokenChunks stores additional chunks of the ID token
|
||||
// when it exceeds the maximum cookie size.
|
||||
idTokenChunks map[int]*sessions.Session
|
||||
|
||||
// refreshMutex protects refresh token operations within this session instance.
|
||||
refreshMutex sync.Mutex
|
||||
|
||||
@@ -415,6 +429,12 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
saveOrLogError(sessionChunk, fmt.Sprintf("refresh token chunk %d", i))
|
||||
}
|
||||
|
||||
// Save ID token chunks.
|
||||
for i, sessionChunk := range sd.idTokenChunks {
|
||||
sessionChunk.Options = options
|
||||
saveOrLogError(sessionChunk, fmt.Sprintf("ID token chunk %d", i))
|
||||
}
|
||||
|
||||
if firstErr == nil {
|
||||
sd.dirty = false // Reset dirty flag only if all saves were successful
|
||||
}
|
||||
@@ -462,6 +482,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear chunk sessions.
|
||||
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
||||
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
||||
sd.clearTokenChunks(r, sd.idTokenChunks)
|
||||
|
||||
// Create a guaranteed error when the response writer is set
|
||||
// This is primarily for testing - in production w will often be nil
|
||||
@@ -482,6 +503,8 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// STABILITY FIX: Mark as not in use and return session to pool, regardless of error.
|
||||
// This ensures the session is always returned to the pool, preventing memory leaks.
|
||||
sd.inUse = false
|
||||
// Reset the session data before returning to pool to prevent data leakage
|
||||
sd.Reset()
|
||||
sd.manager.sessionPool.Put(sd)
|
||||
|
||||
// Return the error from Save, if any
|
||||
@@ -602,6 +625,55 @@ func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset clears all session data and prepares the SessionData object for reuse.
|
||||
// This method is called when returning objects to the pool to prevent data leakage
|
||||
// between different users/sessions.
|
||||
func (sd *SessionData) Reset() {
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
// Clear all session values if sessions exist
|
||||
if sd.mainSession != nil {
|
||||
for k := range sd.mainSession.Values {
|
||||
delete(sd.mainSession.Values, k)
|
||||
}
|
||||
sd.mainSession.ID = ""
|
||||
sd.mainSession.IsNew = true
|
||||
}
|
||||
|
||||
if sd.accessSession != nil {
|
||||
for k := range sd.accessSession.Values {
|
||||
delete(sd.accessSession.Values, k)
|
||||
}
|
||||
sd.accessSession.ID = ""
|
||||
sd.accessSession.IsNew = true
|
||||
}
|
||||
|
||||
if sd.refreshSession != nil {
|
||||
for k := range sd.refreshSession.Values {
|
||||
delete(sd.refreshSession.Values, k)
|
||||
}
|
||||
sd.refreshSession.ID = ""
|
||||
sd.refreshSession.IsNew = true
|
||||
}
|
||||
|
||||
// Clear chunk maps
|
||||
for k := range sd.accessTokenChunks {
|
||||
delete(sd.accessTokenChunks, k)
|
||||
}
|
||||
for k := range sd.refreshTokenChunks {
|
||||
delete(sd.refreshTokenChunks, k)
|
||||
}
|
||||
for k := range sd.idTokenChunks {
|
||||
delete(sd.idTokenChunks, k)
|
||||
}
|
||||
|
||||
// Reset state flags
|
||||
sd.dirty = false
|
||||
sd.inUse = false
|
||||
sd.request = nil
|
||||
}
|
||||
|
||||
// ReturnToPool explicitly returns this SessionData object to the pool.
|
||||
// This should be called when you're done with a SessionData in any error path
|
||||
// where Clear() is not called, to prevent memory leaks.
|
||||
@@ -609,8 +681,8 @@ func (sd *SessionData) ReturnToPool() {
|
||||
if sd != nil && sd.manager != nil {
|
||||
// STABILITY FIX: Only return to pool if not currently in use
|
||||
if !sd.inUse {
|
||||
// Clear request reference to avoid memory leaks
|
||||
sd.request = nil
|
||||
// Reset the session data before returning to pool
|
||||
sd.Reset()
|
||||
sd.manager.sessionPool.Put(sd)
|
||||
}
|
||||
}
|
||||
@@ -873,6 +945,30 @@ func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// expireIDTokenChunks finds all existing ID token chunk cookies (_oidc_raczylo_N)
|
||||
// associated with the current request, clears their values, and sets their MaxAge to -1.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
|
||||
// the expiring Set-Cookie headers. This is used internally when setting a new ID token.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
func (sd *SessionData) expireIDTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", mainCookieName, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
if w != nil {
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("failed to save expired ID token cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks divides a string `s` into a slice of strings, where each element
|
||||
// has a maximum length of `chunkSize`.
|
||||
//
|
||||
@@ -1024,6 +1120,14 @@ func (sd *SessionData) SetIncomingPath(path string) {
|
||||
// Returns:
|
||||
// - The complete, decompressed ID token string, or an empty string if not found.
|
||||
func (sd *SessionData) GetIDToken() string {
|
||||
sd.sessionMutex.RLock()
|
||||
defer sd.sessionMutex.RUnlock()
|
||||
|
||||
return sd.getIDTokenUnsafe()
|
||||
}
|
||||
|
||||
// getIDTokenUnsafe is the internal implementation without mutex protection
|
||||
func (sd *SessionData) getIDTokenUnsafe() string {
|
||||
token, _ := sd.mainSession.Values["id_token"].(string)
|
||||
if token != "" {
|
||||
compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool)
|
||||
@@ -1032,33 +1136,97 @@ func (sd *SessionData) GetIDToken() string {
|
||||
}
|
||||
return token
|
||||
}
|
||||
return ""
|
||||
|
||||
// Reassemble token from chunks.
|
||||
if len(sd.idTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.idTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["id_token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
token = strings.Join(chunks, "")
|
||||
compressed, _ := sd.mainSession.Values["id_token_compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// SetIDToken stores the provided ID token in the session.
|
||||
// It first expires any existing ID token chunk cookies.
|
||||
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
|
||||
// it's stored directly in the primary main session. Otherwise, the compressed token
|
||||
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_0, _oidc_raczylo_1, etc.).
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The ID token string to store.
|
||||
func (sd *SessionData) SetIDToken(token string) {
|
||||
currentIDToken := sd.GetIDToken() // Gets fully reassembled, decompressed token
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
currentIDToken := sd.getIDTokenUnsafe()
|
||||
if currentIDToken == token {
|
||||
// This handles cases where token is "" and currentIDToken is also "", no change.
|
||||
// Or token is "abc" and currentIDToken is "abc", no change.
|
||||
// If token is empty, and current is also empty, it's not a change.
|
||||
// This check handles both empty and non-empty identical cases.
|
||||
return
|
||||
}
|
||||
sd.dirty = true
|
||||
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
sd.expireIDTokenChunks(nil) // Will be saved when Save() is called.
|
||||
}
|
||||
|
||||
// Clear and prepare chunks map for new token.
|
||||
sd.idTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
if token == "" { // Clearing the token
|
||||
// STABILITY FIX: Add nil checks before accessing session values
|
||||
if sd.mainSession != nil {
|
||||
sd.mainSession.Values["id_token"] = ""
|
||||
sd.mainSession.Values["id_token_compressed"] = false
|
||||
}
|
||||
// sd.idTokenChunks is already cleared
|
||||
return
|
||||
}
|
||||
|
||||
sd.dirty = true // Mark as dirty because a change is being made
|
||||
|
||||
if token == "" {
|
||||
sd.mainSession.Values["id_token"] = ""
|
||||
sd.mainSession.Values["id_token_compressed"] = false
|
||||
return
|
||||
}
|
||||
|
||||
// Compress token
|
||||
// Compress token.
|
||||
compressed := compressToken(token)
|
||||
sd.mainSession.Values["id_token"] = compressed
|
||||
sd.mainSession.Values["id_token_compressed"] = true
|
||||
|
||||
if len(compressed) <= maxCookieSize {
|
||||
// STABILITY FIX: Add nil checks before accessing session values
|
||||
if sd.mainSession != nil {
|
||||
sd.mainSession.Values["id_token"] = compressed
|
||||
sd.mainSession.Values["id_token_compressed"] = true
|
||||
}
|
||||
} else {
|
||||
// Split compressed token into chunks.
|
||||
if sd.mainSession != nil {
|
||||
sd.mainSession.Values["id_token"] = "" // Main cookie won't hold the token directly
|
||||
sd.mainSession.Values["id_token_compressed"] = true // Data in chunks is compressed
|
||||
}
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
for i, chunkData := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", mainCookieName, i)
|
||||
// Ensure sd.request is available, otherwise log warning or handle error
|
||||
if sd.request == nil {
|
||||
sd.manager.logger.Infof("SetIDToken: sd.request is nil, cannot get/create chunk session %s", sessionName)
|
||||
// Potentially skip this chunk or error out, depending on desired robustness
|
||||
continue
|
||||
}
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["id_token_chunk"] = chunkData
|
||||
sd.idTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRedirectCount retrieves the current redirect count from the session.
|
||||
|
||||
+164
@@ -1,6 +1,9 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
@@ -218,4 +221,165 @@ func TestSessionObjectTracking(t *testing.T) {
|
||||
t.Log("Session pool handling verified")
|
||||
}
|
||||
|
||||
// TestLargeIDTokenChunking tests that large ID tokens are properly chunked across multiple cookies
|
||||
func TestLargeIDTokenChunking(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
// Create a large ID token (>4KB) to force chunking
|
||||
largeIDToken := createLargeIDToken(20000) // 20KB token to ensure chunking after compression
|
||||
t.Logf("Created large ID token with length: %d", len(largeIDToken))
|
||||
|
||||
// Create a request and response recorder
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get session and set large ID token
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set the large ID token
|
||||
session.SetIDToken(largeIDToken)
|
||||
t.Logf("Set large ID token in session")
|
||||
|
||||
// Let's check what the GetIDToken returns to confirm it's set
|
||||
retrievedToken := session.GetIDToken()
|
||||
t.Logf("Retrieved ID token length: %d", len(retrievedToken))
|
||||
if len(retrievedToken) != len(largeIDToken) {
|
||||
t.Errorf("Token length mismatch: expected %d, got %d", len(largeIDToken), len(retrievedToken))
|
||||
}
|
||||
|
||||
// Let's check what's in the main session directly
|
||||
if idToken, ok := session.mainSession.Values["id_token"].(string); ok {
|
||||
t.Logf("Main session id_token length: %d", len(idToken))
|
||||
if compressed, ok := session.mainSession.Values["id_token_compressed"].(bool); ok {
|
||||
t.Logf("Main session id_token_compressed: %v", compressed)
|
||||
}
|
||||
} else {
|
||||
t.Logf("Main session id_token not found or not a string")
|
||||
}
|
||||
|
||||
// Save the session to trigger chunking
|
||||
err = session.Save(req, rr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify that chunked cookies were created
|
||||
cookies := rr.Result().Cookies()
|
||||
t.Logf("Total cookies in response: %d", len(cookies))
|
||||
|
||||
for _, cookie := range cookies {
|
||||
valuePreview := cookie.Value
|
||||
if len(valuePreview) > 50 {
|
||||
valuePreview = valuePreview[:50] + "..."
|
||||
}
|
||||
t.Logf("Cookie: %s = %s (len=%d)", cookie.Name, valuePreview, len(cookie.Value))
|
||||
}
|
||||
|
||||
var mainCookie *http.Cookie
|
||||
var chunkCookies []*http.Cookie
|
||||
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == mainCookieName {
|
||||
mainCookie = cookie
|
||||
} else if strings.HasPrefix(cookie.Name, mainCookieName+"_") {
|
||||
chunkCookies = append(chunkCookies, cookie)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify main cookie exists
|
||||
if mainCookie == nil {
|
||||
t.Fatal("Main cookie not found in response")
|
||||
}
|
||||
|
||||
// Verify chunk cookies exist (should be at least 2 for a 5KB token)
|
||||
if len(chunkCookies) < 2 {
|
||||
t.Fatalf("Expected at least 2 chunk cookies, got %d", len(chunkCookies))
|
||||
}
|
||||
|
||||
// Verify chunk cookie naming convention
|
||||
expectedChunkNames := make(map[string]bool)
|
||||
for i := 0; i < len(chunkCookies); i++ {
|
||||
expectedChunkNames[mainCookieName+"_"+fmt.Sprintf("%d", i)] = true
|
||||
}
|
||||
|
||||
for _, cookie := range chunkCookies {
|
||||
if !expectedChunkNames[cookie.Name] {
|
||||
t.Errorf("Unexpected chunk cookie name: %s", cookie.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Test token retrieval from chunked cookies
|
||||
// Create a new request with all the cookies
|
||||
newReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get session and retrieve the ID token
|
||||
retrievedSession, err := sm.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session from chunked cookies: %v", err)
|
||||
}
|
||||
|
||||
retrievedToken2 := retrievedSession.GetIDToken()
|
||||
|
||||
// Verify the retrieved token matches the original
|
||||
if retrievedToken2 != largeIDToken {
|
||||
t.Errorf("Retrieved ID token doesn't match original. Expected length: %d, got: %d", len(largeIDToken), len(retrievedToken2))
|
||||
}
|
||||
|
||||
// Test clearing the ID token removes all chunks
|
||||
retrievedSession.SetIDToken("")
|
||||
|
||||
clearRR := httptest.NewRecorder()
|
||||
err = retrievedSession.Save(newReq, clearRR)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save session after clearing ID token: %v", err)
|
||||
}
|
||||
|
||||
// Verify chunks are expired (MaxAge = -1)
|
||||
clearCookies := clearRR.Result().Cookies()
|
||||
for _, cookie := range clearCookies {
|
||||
if strings.HasPrefix(cookie.Name, mainCookieName+"_") {
|
||||
if cookie.MaxAge != -1 {
|
||||
t.Errorf("Expected chunk cookie %s to be expired (MaxAge=-1), got MaxAge=%d", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createLargeIDToken creates a JWT-like token of specified size for testing
|
||||
func createLargeIDToken(size int) string {
|
||||
// Create truly random data that won't compress well
|
||||
randomBytes := make([]byte, size*3/4) // base64 encoding increases size by ~4/3
|
||||
_, err := rand.Read(randomBytes)
|
||||
if err != nil {
|
||||
// Fallback to pseudo-random if crypto/rand fails
|
||||
for i := range randomBytes {
|
||||
randomBytes[i] = byte(i % 256)
|
||||
}
|
||||
}
|
||||
|
||||
// Base64 encode the random data to make it look like a JWT
|
||||
encoded := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
|
||||
// Create JWT-like structure with truly random data
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
|
||||
// Truncate or pad to desired size
|
||||
if len(encoded) > size-len(header)-100 {
|
||||
encoded = encoded[:size-len(header)-100]
|
||||
}
|
||||
|
||||
signature := "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
|
||||
|
||||
return header + "." + encoded + "." + signature
|
||||
}
|
||||
|
||||
// This is intentionally left empty to remove unused code
|
||||
|
||||
Reference in New Issue
Block a user