Well, release it finally.

This commit is contained in:
2025-01-21 19:31:51 +00:00
parent dfb9c0771e
commit 025107fe3e
11 changed files with 231 additions and 198 deletions
+1 -1
View File
@@ -22,7 +22,7 @@ testData:
- raczylo.com - raczylo.com
allowedRolesAndGroups: allowedRolesAndGroups:
- guest-endpoints - guest-endpoints
sessionEncryptionKey: potato-secret sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
forceHTTPS: false forceHTTPS: false
logLevel: debug # debug, info, warn, error logLevel: debug # debug, info, warn, error
rateLimit: 100 # Simple rate limiter to prevent brute force attacks rateLimit: 100 # Simple rate limiter to prevent brute force attacks
+2
View File
@@ -19,6 +19,8 @@ Middleware currently supports following scenarios:
#### How to configure... #### How to configure...
* `sessionEncryptionKey` should be at least 32 bytes long.
##### Keeping secrets secret ##### Keeping secrets secret
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys. This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
+3
View File
@@ -50,6 +50,7 @@ func NewCache() *Cache {
// - key: Unique identifier for the cached item // - key: Unique identifier for the cached item
// - value: The data to cache (can be of any type) // - value: The data to cache (can be of any type)
// - expiration: How long the item should remain in the cache // - expiration: How long the item should remain in the cache
//
// Thread-safe: Uses write locking to ensure safe concurrent access. // Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock() c.mutex.Lock()
@@ -80,9 +81,11 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
// Get retrieves an item from the cache if it exists and hasn't expired. // Get retrieves an item from the cache if it exists and hasn't expired.
// Parameters: // Parameters:
// - key: The identifier of the item to retrieve // - key: The identifier of the item to retrieve
//
// Returns: // Returns:
// - value: The cached data (nil if not found or expired) // - value: The cached data (nil if not found or expired)
// - found: true if the item was found and is valid, false otherwise // - found: true if the item was found and is valid, false otherwise
//
// Thread-safe: Uses read locking to ensure safe concurrent access. // Thread-safe: Uses read locking to ensure safe concurrent access.
func (c *Cache) Get(key string) (interface{}, bool) { func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock() c.mutex.RLock()
+1
View File
@@ -19,6 +19,7 @@ import (
// newSessionOptions creates secure session cookie options. // newSessionOptions creates secure session cookie options.
// Parameters: // Parameters:
// - isSecure: Whether to set the Secure flag on cookies // - isSecure: Whether to set the Secure flag on cookies
//
// Returns session options configured for security with: // Returns session options configured for security with:
// - HttpOnly flag to prevent JavaScript access // - HttpOnly flag to prevent JavaScript access
// - SameSite=Lax for CSRF protection // - SameSite=Lax for CSRF protection
+2
View File
@@ -81,6 +81,7 @@ type JWKCacheInterface interface {
// Parameters: // Parameters:
// - jwksURL: The URL of the JWKS endpoint // - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for fetching keys // - httpClient: The HTTP client to use for fetching keys
//
// Returns: // Returns:
// - The JSON Web Key Set // - The JSON Web Key Set
// - An error if the keys cannot be retrieved or parsed // - An error if the keys cannot be retrieved or parsed
@@ -115,6 +116,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
// Parameters: // Parameters:
// - jwksURL: The URL of the JWKS endpoint // - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for the request // - httpClient: The HTTP client to use for the request
//
// Returns: // Returns:
// - The parsed JSON Web Key Set // - The parsed JSON Web Key Set
// - An error if the request fails or the response is invalid // - An error if the request fails or the response is invalid
+9
View File
@@ -38,6 +38,7 @@ type JWT struct {
// (header, claims, signature) using base64url decoding. // (header, claims, signature) using base64url decoding.
// Parameters: // Parameters:
// - tokenString: The raw JWT token string // - tokenString: The raw JWT token string
//
// Returns: // Returns:
// - A parsed JWT struct // - A parsed JWT struct
// - An error if the token format is invalid or parsing fails // - An error if the token format is invalid or parsing fails
@@ -88,6 +89,7 @@ func parseJWT(tokenString string) (*JWT, error) {
// - not before time (nbf) is in the past (with clock skew tolerance) // - not before time (nbf) is in the past (with clock skew tolerance)
// - subject (sub) is present and not empty // - subject (sub) is present and not empty
// - algorithm matches expected value to prevent algorithm switching attacks // - algorithm matches expected value to prevent algorithm switching attacks
//
// Returns an error if any validation fails. // Returns an error if any validation fails.
func (j *JWT) Verify(issuerURL, clientID string) error { func (j *JWT) Verify(issuerURL, clientID string) error {
// Debug logging of validation parameters // Debug logging of validation parameters
@@ -174,6 +176,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
// Parameters: // Parameters:
// - tokenAudience: The audience claim from the token // - tokenAudience: The audience claim from the token
// - expectedAudience: The expected audience value // - expectedAudience: The expected audience value
//
// Returns an error if validation fails. // Returns an error if validation fails.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error { func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
// Debug logging // Debug logging
@@ -207,6 +210,7 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
// Parameters: // Parameters:
// - tokenIssuer: The issuer claim from the token // - tokenIssuer: The issuer claim from the token
// - expectedIssuer: The expected issuer URL // - expectedIssuer: The expected issuer URL
//
// Returns an error if validation fails. // Returns an error if validation fails.
func verifyIssuer(tokenIssuer, expectedIssuer string) error { func verifyIssuer(tokenIssuer, expectedIssuer string) error {
// Debug logging // Debug logging
@@ -227,6 +231,7 @@ const clockSkewTolerance = 2 * time.Minute
// The expiration time is compared against the current time with clock skew tolerance. // The expiration time is compared against the current time with clock skew tolerance.
// Parameters: // Parameters:
// - expiration: The expiration timestamp from the token // - expiration: The expiration timestamp from the token
//
// Returns an error if the token has expired. // Returns an error if the token has expired.
func verifyExpiration(expiration float64) error { func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0) expirationTime := time.Unix(int64(expiration), 0)
@@ -257,6 +262,7 @@ func verifyExpiration(expiration float64) error {
// Ensures the token wasn't issued in the future, accounting for clock skew. // Ensures the token wasn't issued in the future, accounting for clock skew.
// Parameters: // Parameters:
// - issuedAt: The issued-at timestamp from the token // - issuedAt: The issued-at timestamp from the token
//
// Returns an error if the token was issued in the future. // Returns an error if the token was issued in the future.
func verifyIssuedAt(issuedAt float64) error { func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0) issuedAtTime := time.Unix(int64(issuedAt), 0)
@@ -287,6 +293,7 @@ func verifyIssuedAt(issuedAt float64) error {
// Ensures the token is not used before its valid time period, accounting for clock skew. // Ensures the token is not used before its valid time period, accounting for clock skew.
// Parameters: // Parameters:
// - notBefore: The not-before timestamp from the token // - notBefore: The not-before timestamp from the token
//
// Returns an error if the token is not yet valid. // Returns an error if the token is not yet valid.
func verifyNotBefore(notBefore float64) error { func verifyNotBefore(notBefore float64) error {
notBeforeTime := time.Unix(int64(notBefore), 0) notBeforeTime := time.Unix(int64(notBefore), 0)
@@ -318,10 +325,12 @@ func verifyNotBefore(notBefore float64) error {
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5) // - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
// - RSA-PSS: PS256, PS384, PS512 // - RSA-PSS: PS256, PS384, PS512
// - ECDSA: ES256, ES384, ES512 // - ECDSA: ES256, ES384, ES512
//
// Parameters: // Parameters:
// - tokenString: The complete JWT token string // - tokenString: The complete JWT token string
// - publicKeyPEM: The PEM-encoded public key for verification // - publicKeyPEM: The PEM-encoded public key for verification
// - alg: The signature algorithm identifier // - alg: The signature algorithm identifier
//
// Returns an error if signature verification fails. // Returns an error if signature verification fails.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error { func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Debug logging // Debug logging
+37 -25
View File
@@ -12,6 +12,8 @@ import (
"strings" "strings"
"time" "time"
"runtime"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
@@ -185,9 +187,18 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
} }
// Initialize logger
logger := NewLogger(config.LogLevel)
// Ensure key meets minimum length requirement // Ensure key meets minimum length requirement
if len(config.SessionEncryptionKey) < minEncryptionKeyLength { if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength) if runtime.Compiler == "yaegi" {
// Set default encryption key for Yaegi (Traefik Plugin Analyzer)
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
logger.Infof("Session encryption key is too short; using default key for analyzer")
} else {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
} }
// Setup HTTP client // Setup HTTP client
@@ -195,19 +206,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: 15 * time.Second, // Reduced timeout Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive KeepAlive: 15 * time.Second, // Reduced keepalive
} }
return dialer.DialContext(ctx, network, addr) return dialer.DialContext(ctx, network, addr)
}, },
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
ExpectContinueTimeout: 0, ExpectContinueTimeout: 0,
MaxIdleConns: 30, // Reduced from 100 MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100 MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections MaxConnsPerHost: 50, // Limit max connections
} }
var httpClient *http.Client var httpClient *http.Client
@@ -245,12 +256,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit), limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(), tokenCache: NewTokenCache(),
httpClient: httpClient, httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs), excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains), allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups), allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}), initComplete: make(chan struct{}),
} }
// Assign the initialized logger
t.logger = logger
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims t.extractClaimsFunc = extractClaims
@@ -400,24 +412,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return return
} }
// Get session // Get session
session, err := t.sessionManager.GetSession(req) session, err := t.sessionManager.GetSession(req)
if err != nil { if err != nil {
t.logger.Errorf("Error getting session: %v", err) t.logger.Errorf("Error getting session: %v", err)
// Obtain a new session and clear any residual session cookies // Obtain a new session and clear any residual session cookies
session, _ = t.sessionManager.GetSession(req) session, _ = t.sessionManager.GetSession(req)
session.Clear(req, rw) session.Clear(req, rw)
// Build redirect URL // Build redirect URL
scheme := t.determineScheme(req) scheme := t.determineScheme(req)
host := t.determineHost(req) host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath) redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Initiate authentication // Initiate authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL) t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return return
} }
// Build redirect URL // Build redirect URL
scheme := t.determineScheme(req) scheme := t.determineScheme(req)
+4 -4
View File
@@ -1370,10 +1370,10 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
// Create base config // Create base config
config := &Config{ config := &Config{
ProviderURL: mockServer.URL, ProviderURL: mockServer.URL,
ClientID: "test-client", ClientID: "test-client",
ClientSecret: "test-secret", ClientSecret: "test-secret",
CallbackURL: "/callback", CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough", SessionEncryptionKey: "test-encryption-key-thats-long-enough",
} }
+12 -9
View File
@@ -28,8 +28,8 @@ func generateSecureRandomString(length int) string {
// Cookie names and configuration constants used for session management // Cookie names and configuration constants used for session management
const ( const (
// Using fixed prefixes for consistent cookie naming across restarts // Using fixed prefixes for consistent cookie naming across restarts
mainCookieName = "_oidc_raczylo_m" mainCookieName = "_oidc_raczylo_m"
accessTokenCookie = "_oidc_raczylo_a" accessTokenCookie = "_oidc_raczylo_a"
refreshTokenCookie = "_oidc_raczylo_r" refreshTokenCookie = "_oidc_raczylo_r"
) )
@@ -111,6 +111,7 @@ type SessionManager struct {
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes) // - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme // - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
// - logger: Logger instance for recording session-related events // - logger: Logger instance for recording session-related events
//
// The manager handles session creation, storage, and cookie security settings. // The manager handles session creation, storage, and cookie security settings.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager { func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
// Validate encryption key length // Validate encryption key length
@@ -127,8 +128,8 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S
// Initialize session pool // Initialize session pool
sm.sessionPool.New = func() interface{} { sm.sessionPool.New = func() interface{} {
return &SessionData{ return &SessionData{
manager: sm, manager: sm,
accessTokenChunks: make(map[int]*sessions.Session), accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session), refreshTokenChunks: make(map[int]*sessions.Session),
} }
} }
@@ -139,6 +140,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S
// getSessionOptions returns secure session options configured for the current request. // getSessionOptions returns secure session options configured for the current request.
// Parameters: // Parameters:
// - isSecure: Whether the current request is using HTTPS // - isSecure: Whether the current request is using HTTPS
//
// The options ensure cookies are: // The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript) // - HTTP-only (not accessible via JavaScript)
// - Secure when using HTTPS or when forceHTTPS is enabled // - Secure when using HTTPS or when forceHTTPS is enabled
@@ -172,11 +174,11 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
// Check for absolute session timeout // Check for absolute session timeout
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok { if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout { if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
// Session has expired // Session has expired
sm.sessionPool.Put(sessionData) sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("session expired") return nil, fmt.Errorf("session expired")
} }
} }
sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie) sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)
@@ -557,6 +559,7 @@ func (sd *SessionData) SetRefreshToken(token string) {
// Parameters: // Parameters:
// - s: The string to split // - s: The string to split
// - chunkSize: Maximum size of each chunk // - chunkSize: Maximum size of each chunk
//
// Returns an array of string chunks, each no larger than chunkSize. // Returns an array of string chunks, each no larger than chunkSize.
func splitIntoChunks(s string, chunkSize int) []string { func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string var chunks []string
+114 -114
View File
@@ -198,160 +198,160 @@ func TestSessionManager(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
authenticated bool authenticated bool
email string email string
accessToken string accessToken string
refreshToken string refreshToken string
expectedCookieCount int expectedCookieCount int
wantCompressed bool // Whether tokens should be compressed wantCompressed bool // Whether tokens should be compressed
}{ }{
{ {
name: "Short tokens", name: "Short tokens",
authenticated: true, authenticated: true,
email: "test@example.com", email: "test@example.com",
accessToken: "shortaccesstoken", accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken", refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh expectedCookieCount: 3, // main, access, refresh
wantCompressed: true, wantCompressed: true,
}, },
{ {
name: "Long tokens exceeding 4096 bytes", name: "Long tokens exceeding 4096 bytes",
authenticated: true, authenticated: true,
email: "test@example.com", email: "test@example.com",
accessToken: strings.Repeat("x", 5000), accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000), refreshToken: strings.Repeat("y", 6000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)), expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
wantCompressed: true, wantCompressed: true,
}, },
{ {
name: "REALLY long tokens, exceeding 25000 bytes", name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true, authenticated: true,
email: "test@example.com", email: "test@example.com",
accessToken: strings.Repeat("x", 25000), accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000), refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)), expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
wantCompressed: true, wantCompressed: true,
}, },
{ {
name: "Unauthenticated session", name: "Unauthenticated session",
authenticated: false, authenticated: false,
email: "", email: "",
accessToken: "", accessToken: "",
refreshToken: "", refreshToken: "",
expectedCookieCount: 3, // main, access, refresh expectedCookieCount: 3, // main, access, refresh
wantCompressed: false, wantCompressed: false,
}, },
{ {
name: "Random content tokens", name: "Random content tokens",
authenticated: true, authenticated: true,
email: "test@example.com", email: "test@example.com",
accessToken: generateRandomString(5000), accessToken: generateRandomString(5000),
refreshToken: generateRandomString(5000), refreshToken: generateRandomString(5000),
expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)), expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)),
wantCompressed: true, wantCompressed: true,
}, },
} }
for _, tc := range tests { for _, tc := range tests {
tc := tc // Capture range variable tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil) req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
session, err := ts.sessionManager.GetSession(req) session, err := ts.sessionManager.GetSession(req)
if err != nil { if err != nil {
t.Fatalf("Failed to get session: %v", err) t.Fatalf("Failed to get session: %v", err)
} }
// Set session values // Set session values
session.SetAuthenticated(tc.authenticated) session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email) session.SetEmail(tc.email)
// Expire any existing cookies // Expire any existing cookies
session.expireAccessTokenChunks(rr) session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr) session.expireRefreshTokenChunks(rr)
// Set new tokens // Set new tokens
session.SetAccessToken(tc.accessToken) session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken) session.SetRefreshToken(tc.refreshToken)
// Save session // Save session
if err := session.Save(req, rr); err != nil { if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err) t.Fatalf("Failed to save session: %v", err)
} }
// Verify cookies are set and compression is used when appropriate // Verify cookies are set and compression is used when appropriate
cookies := rr.Result().Cookies() cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount { if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies)) t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
} }
// Verify compression is working by checking token sizes // Verify compression is working by checking token sizes
for _, cookie := range cookies { for _, cookie := range cookies {
if strings.Contains(cookie.Name, accessTokenCookie) { if strings.Contains(cookie.Name, accessTokenCookie) {
// Get original and stored sizes // Get original and stored sizes
originalSize := len(tc.accessToken) originalSize := len(tc.accessToken)
storedSize := len(cookie.Value) storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed { if originalSize > 100 && tc.wantCompressed {
// For large tokens, verify some compression occurred // For large tokens, verify some compression occurred
compressionRatio := float64(storedSize) / float64(originalSize) compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)", t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize) compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 { // Allow some overhead, but should see compression if compressionRatio > 0.9 { // Allow some overhead, but should see compression
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)", t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio) cookie.Name, compressionRatio)
}
}
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
originalSize := len(tc.refreshToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 {
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
} }
} }
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
originalSize := len(tc.refreshToken)
storedSize := len(cookie.Value)
// Create a new request with the cookies if originalSize > 100 && tc.wantCompressed {
newReq := httptest.NewRequest("GET", "/test", nil) compressionRatio := float64(storedSize) / float64(originalSize)
for _, cookie := range cookies { t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
newReq.AddCookie(cookie) compressionRatio, originalSize, storedSize)
}
// Get the session again and verify values if compressionRatio > 0.9 {
newSession, err := ts.sessionManager.GetSession(newReq) t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
if err != nil { cookie.Name, compressionRatio)
t.Fatalf("Failed to get new session: %v", err) }
} }
}
}
// Verify session values // Create a new request with the cookies
if newSession.GetAuthenticated() != tc.authenticated { newReq := httptest.NewRequest("GET", "/test", nil)
t.Errorf("Authentication status not preserved") for _, cookie := range cookies {
} newReq.AddCookie(cookie)
if email := newSession.GetEmail(); email != tc.email { }
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
}
// Verify session pooling by checking if the session is reused // Get the session again and verify values
session2, _ := ts.sessionManager.GetSession(newReq) newSession, err := ts.sessionManager.GetSession(newReq)
if session2 == newSession { if err != nil {
t.Error("Session not properly pooled") t.Fatalf("Failed to get new session: %v", err)
} }
})
// Verify session values
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
}
// Verify session pooling by checking if the session is reused
session2, _ := ts.sessionManager.GetSession(newReq)
if session2 == newSession {
t.Error("Session not properly pooled")
}
})
} }
} }
+1
View File
@@ -221,6 +221,7 @@ type Logger struct {
// - "debug": Outputs all messages (debug, info, error) // - "debug": Outputs all messages (debug, info, error)
// - "info": Outputs info and error messages // - "info": Outputs info and error messages
// - "error": Outputs only error messages // - "error": Outputs only error messages
//
// Error messages are always written to stderr, while info and debug // Error messages are always written to stderr, while info and debug
// messages are written to stdout when enabled. // messages are written to stdout when enabled.
func NewLogger(logLevel string) *Logger { func NewLogger(logLevel string) *Logger {