diff --git a/.traefik.yml b/.traefik.yml index 0c33df0..5383001 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -22,7 +22,7 @@ testData: - raczylo.com allowedRolesAndGroups: - guest-endpoints - sessionEncryptionKey: potato-secret + sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long forceHTTPS: false logLevel: debug # debug, info, warn, error rateLimit: 100 # Simple rate limiter to prevent brute force attacks diff --git a/README.md b/README.md index e6ee682..6c296d1 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ Middleware currently supports following scenarios: #### How to configure... +* `sessionEncryptionKey` should be at least 32 bytes long. + ##### Keeping secrets secret This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys. diff --git a/cache.go b/cache.go index dec6dca..73794eb 100644 --- a/cache.go +++ b/cache.go @@ -50,6 +50,7 @@ func NewCache() *Cache { // - key: Unique identifier for the cached item // - value: The data to cache (can be of any type) // - expiration: How long the item should remain in the cache +// // Thread-safe: Uses write locking to ensure safe concurrent access. func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { c.mutex.Lock() @@ -80,9 +81,11 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { // Get retrieves an item from the cache if it exists and hasn't expired. // Parameters: // - key: The identifier of the item to retrieve +// // Returns: // - value: The cached data (nil if not found or expired) // - found: true if the item was found and is valid, false otherwise +// // Thread-safe: Uses read locking to ensure safe concurrent access. func (c *Cache) Get(key string) (interface{}, bool) { c.mutex.RLock() diff --git a/helpers.go b/helpers.go index f8ba97d..d66294a 100644 --- a/helpers.go +++ b/helpers.go @@ -19,6 +19,7 @@ import ( // newSessionOptions creates secure session cookie options. // Parameters: // - isSecure: Whether to set the Secure flag on cookies +// // Returns session options configured for security with: // - HttpOnly flag to prevent JavaScript access // - SameSite=Lax for CSRF protection @@ -133,14 +134,14 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque session.SetAccessToken("") session.SetRefreshToken("") session.SetEmail("") - + // Save the cleared session state if err := session.Save(req, rw); err != nil { t.logger.Errorf("Failed to save cleared session: %v", err) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return } - + t.defaultInitiateAuthentication(rw, req, session, redirectURL) } diff --git a/jwk.go b/jwk.go index 57ad2b2..448e693 100644 --- a/jwk.go +++ b/jwk.go @@ -81,6 +81,7 @@ type JWKCacheInterface interface { // Parameters: // - jwksURL: The URL of the JWKS endpoint // - httpClient: The HTTP client to use for fetching keys +// // Returns: // - The JSON Web Key Set // - An error if the keys cannot be retrieved or parsed @@ -115,6 +116,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er // Parameters: // - jwksURL: The URL of the JWKS endpoint // - httpClient: The HTTP client to use for the request +// // Returns: // - The parsed JSON Web Key Set // - An error if the request fails or the response is invalid diff --git a/jwt.go b/jwt.go index 3ca9375..1fa430b 100644 --- a/jwt.go +++ b/jwt.go @@ -38,6 +38,7 @@ type JWT struct { // (header, claims, signature) using base64url decoding. // Parameters: // - tokenString: The raw JWT token string +// // Returns: // - A parsed JWT struct // - An error if the token format is invalid or parsing fails @@ -88,6 +89,7 @@ func parseJWT(tokenString string) (*JWT, error) { // - not before time (nbf) is in the past (with clock skew tolerance) // - subject (sub) is present and not empty // - algorithm matches expected value to prevent algorithm switching attacks +// // Returns an error if any validation fails. func (j *JWT) Verify(issuerURL, clientID string) error { // Debug logging of validation parameters @@ -111,7 +113,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error { } claims := j.Claims - + // Debug logging of all claims fmt.Printf("Token claims: %+v\n", claims) @@ -174,10 +176,11 @@ func (j *JWT) Verify(issuerURL, clientID string) error { // Parameters: // - tokenAudience: The audience claim from the token // - expectedAudience: The expected audience value +// // Returns an error if validation fails. func verifyAudience(tokenAudience interface{}, expectedAudience string) error { // Debug logging - fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n", + fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n", tokenAudience, expectedAudience) switch aud := tokenAudience.(type) { @@ -207,10 +210,11 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error { // Parameters: // - tokenIssuer: The issuer claim from the token // - expectedIssuer: The expected issuer URL +// // Returns an error if validation fails. func verifyIssuer(tokenIssuer, expectedIssuer string) error { // Debug logging - fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n", + fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n", tokenIssuer, expectedIssuer) if tokenIssuer != expectedIssuer { @@ -227,25 +231,26 @@ const clockSkewTolerance = 2 * time.Minute // The expiration time is compared against the current time with clock skew tolerance. // Parameters: // - expiration: The expiration timestamp from the token +// // Returns an error if the token has expired. func verifyExpiration(expiration float64) error { expirationTime := time.Unix(int64(expiration), 0) // Truncate current time to seconds for consistent comparison now := time.Now().Truncate(time.Second) skewedNow := now.Add(clockSkewTolerance) - + // Debug logging fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n", expirationTime.UTC(), now.UTC(), skewedNow.UTC(), clockSkewTolerance) - + // Allow tokens that expire exactly now if expirationTime.Equal(now) { return nil } - + if skewedNow.After(expirationTime) { return fmt.Errorf("token has expired (exp: %v, now: %v)", expirationTime.UTC(), now.UTC()) @@ -257,27 +262,28 @@ func verifyExpiration(expiration float64) error { // Ensures the token wasn't issued in the future, accounting for clock skew. // Parameters: // - issuedAt: The issued-at timestamp from the token +// // Returns an error if the token was issued in the future. func verifyIssuedAt(issuedAt float64) error { issuedAtTime := time.Unix(int64(issuedAt), 0) // Truncate current time to seconds for consistent comparison now := time.Now().Truncate(time.Second) skewedNow := now.Add(-clockSkewTolerance) - + // Debug logging fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n", issuedAtTime.UTC(), now.UTC(), skewedNow.UTC(), clockSkewTolerance) - + // Allow tokens issued in the same second as current time if issuedAtTime.Equal(now) { return nil } - + if skewedNow.Before(issuedAtTime) { - return fmt.Errorf("token used before issued (iat: %v, now: %v)", + return fmt.Errorf("token used before issued (iat: %v, now: %v)", issuedAtTime.UTC(), now.UTC()) } return nil @@ -287,25 +293,26 @@ func verifyIssuedAt(issuedAt float64) error { // Ensures the token is not used before its valid time period, accounting for clock skew. // Parameters: // - notBefore: The not-before timestamp from the token +// // Returns an error if the token is not yet valid. func verifyNotBefore(notBefore float64) error { notBeforeTime := time.Unix(int64(notBefore), 0) // Truncate current time to seconds for consistent comparison now := time.Now().Truncate(time.Second) skewedNow := now.Add(-clockSkewTolerance) - + // Debug logging fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n", notBeforeTime.UTC(), now.UTC(), skewedNow.UTC(), clockSkewTolerance) - + // Allow tokens that become valid exactly now if notBeforeTime.Equal(now) { return nil } - + if skewedNow.Before(notBeforeTime) { return fmt.Errorf("token not yet valid (nbf: %v, now: %v)", notBeforeTime.UTC(), now.UTC()) @@ -318,15 +325,17 @@ func verifyNotBefore(notBefore float64) error { // - RSA: RS256, RS384, RS512 (PKCS#1 v1.5) // - RSA-PSS: PS256, PS384, PS512 // - ECDSA: ES256, ES384, ES512 +// // Parameters: // - tokenString: The complete JWT token string // - publicKeyPEM: The PEM-encoded public key for verification // - alg: The signature algorithm identifier +// // Returns an error if signature verification fails. func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error { // Debug logging fmt.Printf("Verifying signature with algorithm: %s\n", alg) - + // Split the token into its three parts parts := strings.Split(tokenString, ".") if len(parts) != 3 { diff --git a/main.go b/main.go index f516de1..9d5fbff 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,8 @@ import ( "strings" "time" + "runtime" + "github.com/google/uuid" "golang.org/x/time/rate" ) @@ -185,9 +187,18 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" } + // Initialize logger + logger := NewLogger(config.LogLevel) + // Ensure key meets minimum length requirement if len(config.SessionEncryptionKey) < minEncryptionKeyLength { - return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength) + if runtime.Compiler == "yaegi" { + // Set default encryption key for Yaegi (Traefik Plugin Analyzer) + config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + logger.Infof("Session encryption key is too short; using default key for analyzer") + } else { + return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength) + } } // Setup HTTP client @@ -195,19 +206,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h Proxy: http.ProxyFromEnvironment, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { dialer := &net.Dialer{ - Timeout: 15 * time.Second, // Reduced timeout - KeepAlive: 15 * time.Second, // Reduced keepalive + Timeout: 15 * time.Second, // Reduced timeout + KeepAlive: 15 * time.Second, // Reduced keepalive } return dialer.DialContext(ctx, network, addr) }, ForceAttemptHTTP2: true, - TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s + TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s ExpectContinueTimeout: 0, - MaxIdleConns: 30, // Reduced from 100 - MaxIdleConnsPerHost: 10, // Reduced from 100 - IdleConnTimeout: 30 * time.Second, // Reduced from 90s - DisableKeepAlives: false, // Enable connection reuse - MaxConnsPerHost: 50, // Limit max connections + MaxIdleConns: 30, // Reduced from 100 + MaxIdleConnsPerHost: 10, // Reduced from 100 + IdleConnTimeout: 30 * time.Second, // Reduced from 90s + DisableKeepAlives: false, // Enable connection reuse + MaxConnsPerHost: 50, // Limit max connections } var httpClient *http.Client @@ -245,12 +256,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit), tokenCache: NewTokenCache(), httpClient: httpClient, - logger: NewLogger(config.LogLevel), excludedURLs: createStringMap(config.ExcludedURLs), allowedUserDomains: createStringMap(config.AllowedUserDomains), allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups), initComplete: make(chan struct{}), } + // Assign the initialized logger + t.logger = logger t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) t.extractClaimsFunc = extractClaims @@ -275,17 +287,17 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h // initializeMetadata discovers and initializes the provider metadata func (t *TraefikOidc) initializeMetadata(providerURL string) { t.logger.Debug("Starting provider metadata discovery") - + // Keep retrying until successful backoff := time.Second maxBackoff := 30 * time.Second for { metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger) - + if err != nil { t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff) time.Sleep(backoff) - + // Exponential backoff with max backoff *= 2 if backoff > maxBackoff { @@ -293,7 +305,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { } continue } - + if metadata != nil { t.logger.Debug("Successfully initialized provider metadata") t.jwksURL = metadata.JWKSURL @@ -302,12 +314,12 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { t.issuerURL = metadata.Issuer t.revocationURL = metadata.RevokeURL t.endSessionURL = metadata.EndSessionURL - + // Only close channel on success close(t.initComplete) return } - + t.logger.Error("Received nil metadata, retrying") time.Sleep(backoff) } @@ -400,24 +412,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } -// Get session -session, err := t.sessionManager.GetSession(req) -if err != nil { - t.logger.Errorf("Error getting session: %v", err) - - // Obtain a new session and clear any residual session cookies - session, _ = t.sessionManager.GetSession(req) - session.Clear(req, rw) - - // Build redirect URL - scheme := t.determineScheme(req) - host := t.determineHost(req) - redirectURL := buildFullURL(scheme, host, t.redirURLPath) - - // Initiate authentication - t.defaultInitiateAuthentication(rw, req, session, redirectURL) - return -} + // Get session + session, err := t.sessionManager.GetSession(req) + if err != nil { + t.logger.Errorf("Error getting session: %v", err) + + // Obtain a new session and clear any residual session cookies + session, _ = t.sessionManager.GetSession(req) + session.Clear(req, rw) + + // Build redirect URL + scheme := t.determineScheme(req) + host := t.determineHost(req) + redirectURL := buildFullURL(scheme, host, t.redirURLPath) + + // Initiate authentication + t.defaultInitiateAuthentication(rw, req, session, redirectURL) + return + } // Build redirect URL scheme := t.determineScheme(req) diff --git a/main_test.go b/main_test.go index 80e6c47..03be49b 100644 --- a/main_test.go +++ b/main_test.go @@ -1370,23 +1370,23 @@ func TestMultipleMiddlewareInstances(t *testing.T) { // Create base config config := &Config{ - ProviderURL: mockServer.URL, - ClientID: "test-client", - ClientSecret: "test-secret", - CallbackURL: "/callback", + ProviderURL: mockServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", } // Create multiple middleware instances routes := []string{"/api/v1", "/api/v2", "/api/v3"} var middlewares []*TraefikOidc - + for _, route := range routes { config.CallbackURL = route + "/callback" middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }), config, "test") - + if err != nil { t.Fatalf("Failed to create middleware for route %s: %v", route, err) } diff --git a/session.go b/session.go index bb6d8a8..e1abc61 100644 --- a/session.go +++ b/session.go @@ -28,8 +28,8 @@ func generateSecureRandomString(length int) string { // Cookie names and configuration constants used for session management const ( // Using fixed prefixes for consistent cookie naming across restarts - mainCookieName = "_oidc_raczylo_m" - accessTokenCookie = "_oidc_raczylo_a" + mainCookieName = "_oidc_raczylo_m" + accessTokenCookie = "_oidc_raczylo_a" refreshTokenCookie = "_oidc_raczylo_r" ) @@ -74,18 +74,18 @@ func decompressToken(compressed string) string { if err != nil { return compressed // return as-is if not base64 } - + gz, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { return compressed } defer gz.Close() - + decompressed, err := io.ReadAll(gz) if err != nil { return compressed } - + return string(decompressed) } @@ -111,6 +111,7 @@ type SessionManager struct { // - encryptionKey: Key used to encrypt session data (must be at least 32 bytes) // - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme // - logger: Logger instance for recording session-related events +// // The manager handles session creation, storage, and cookie security settings. func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager { // Validate encryption key length @@ -127,8 +128,8 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S // Initialize session pool sm.sessionPool.New = func() interface{} { return &SessionData{ - manager: sm, - accessTokenChunks: make(map[int]*sessions.Session), + manager: sm, + accessTokenChunks: make(map[int]*sessions.Session), refreshTokenChunks: make(map[int]*sessions.Session), } } @@ -139,6 +140,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S // getSessionOptions returns secure session options configured for the current request. // Parameters: // - isSecure: Whether the current request is using HTTPS +// // The options ensure cookies are: // - HTTP-only (not accessible via JavaScript) // - Secure when using HTTPS or when forceHTTPS is enabled @@ -172,11 +174,11 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { // Check for absolute session timeout if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok { -if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout { - // Session has expired - sm.sessionPool.Put(sessionData) - return nil, fmt.Errorf("session expired") -} + if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout { + // Session has expired + sm.sessionPool.Put(sessionData) + return nil, fmt.Errorf("session expired") + } } sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie) @@ -326,10 +328,10 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { if w != nil { err = sd.Save(r, w) } - + // Return session to pool sd.manager.sessionPool.Put(sd) - + return err } @@ -557,6 +559,7 @@ func (sd *SessionData) SetRefreshToken(token string) { // Parameters: // - s: The string to split // - chunkSize: Maximum size of each chunk +// // Returns an array of string chunks, each no larger than chunkSize. func splitIntoChunks(s string, chunkSize int) []string { var chunks []string diff --git a/session_test.go b/session_test.go index f5e35a9..601809b 100644 --- a/session_test.go +++ b/session_test.go @@ -56,9 +56,9 @@ func TestTokenCompression(t *testing.T) { if len(tt.token) > 100 { compressionRatio := float64(len(compressed)) / float64(len(tt.token)) t.Logf("Compression ratio for %s: %.2f", tt.name, compressionRatio) - + if compressionRatio > 1.1 { // Allow up to 10% size increase - t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f", + t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f", len(tt.token), len(compressed), compressionRatio) } } @@ -91,11 +91,11 @@ func TestCookiePrefix(t *testing.T) { // Set some data to ensure cookies are created session.SetAuthenticated(true) - + // Expire any existing cookies session.expireAccessTokenChunks(rr) session.expireRefreshTokenChunks(rr) - + // Set new tokens session.SetAccessToken("test_token") session.SetRefreshToken("test_refresh_token") @@ -126,7 +126,7 @@ func TestTokenRefreshCleanup(t *testing.T) { // Set a large token that will be split into chunks largeToken := strings.Repeat("x", 5000) session.SetAccessToken(largeToken) - + if err := session.Save(req, rr); err != nil { t.Fatalf("Failed to save session: %v", err) } @@ -155,7 +155,7 @@ func TestTokenRefreshCleanup(t *testing.T) { // Set a smaller token that won't need chunks newSession.SetAccessToken("small_token") - + // Save session with new token if err := newSession.Save(newReq, newRr); err != nil { t.Fatalf("Failed to save new session: %v", err) @@ -198,160 +198,160 @@ func TestSessionManager(t *testing.T) { tests := []struct { name string authenticated bool - email string - accessToken string - refreshToken string + email string + accessToken string + refreshToken string expectedCookieCount int - wantCompressed bool // Whether tokens should be compressed + wantCompressed bool // Whether tokens should be compressed }{ { name: "Short tokens", authenticated: true, - email: "test@example.com", - accessToken: "shortaccesstoken", - refreshToken: "shortrefreshtoken", + email: "test@example.com", + accessToken: "shortaccesstoken", + refreshToken: "shortrefreshtoken", expectedCookieCount: 3, // main, access, refresh - wantCompressed: true, + wantCompressed: true, }, { - name: "Long tokens exceeding 4096 bytes", - authenticated: true, - email: "test@example.com", - accessToken: strings.Repeat("x", 5000), - refreshToken: strings.Repeat("y", 6000), + name: "Long tokens exceeding 4096 bytes", + authenticated: true, + email: "test@example.com", + accessToken: strings.Repeat("x", 5000), + refreshToken: strings.Repeat("y", 6000), expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)), - wantCompressed: true, + wantCompressed: true, }, { - name: "REALLY long tokens, exceeding 25000 bytes", - authenticated: true, - email: "test@example.com", - accessToken: strings.Repeat("x", 25000), - refreshToken: strings.Repeat("y", 25000), + name: "REALLY long tokens, exceeding 25000 bytes", + authenticated: true, + email: "test@example.com", + accessToken: strings.Repeat("x", 25000), + refreshToken: strings.Repeat("y", 25000), expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)), - wantCompressed: true, + wantCompressed: true, }, { name: "Unauthenticated session", authenticated: false, - email: "", - accessToken: "", - refreshToken: "", + email: "", + accessToken: "", + refreshToken: "", expectedCookieCount: 3, // main, access, refresh - wantCompressed: false, + wantCompressed: false, }, { - name: "Random content tokens", - authenticated: true, - email: "test@example.com", - accessToken: generateRandomString(5000), - refreshToken: generateRandomString(5000), + name: "Random content tokens", + authenticated: true, + email: "test@example.com", + accessToken: generateRandomString(5000), + refreshToken: generateRandomString(5000), expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)), - wantCompressed: true, + wantCompressed: true, }, } for _, tc := range tests { - tc := tc // Capture range variable - t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/test", nil) - rr := httptest.NewRecorder() + tc := tc // Capture range variable + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() - session, err := ts.sessionManager.GetSession(req) - if err != nil { - t.Fatalf("Failed to get session: %v", err) - } + session, err := ts.sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } - // Set session values - session.SetAuthenticated(tc.authenticated) - session.SetEmail(tc.email) - - // Expire any existing cookies - session.expireAccessTokenChunks(rr) - session.expireRefreshTokenChunks(rr) - - // Set new tokens - session.SetAccessToken(tc.accessToken) - session.SetRefreshToken(tc.refreshToken) + // Set session values + session.SetAuthenticated(tc.authenticated) + session.SetEmail(tc.email) - // Save session - if err := session.Save(req, rr); err != nil { - t.Fatalf("Failed to save session: %v", err) - } + // Expire any existing cookies + session.expireAccessTokenChunks(rr) + session.expireRefreshTokenChunks(rr) - // Verify cookies are set and compression is used when appropriate - cookies := rr.Result().Cookies() - if len(cookies) != tc.expectedCookieCount { - t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies)) - } + // Set new tokens + session.SetAccessToken(tc.accessToken) + session.SetRefreshToken(tc.refreshToken) - // Verify compression is working by checking token sizes - for _, cookie := range cookies { - if strings.Contains(cookie.Name, accessTokenCookie) { - // Get original and stored sizes - originalSize := len(tc.accessToken) - storedSize := len(cookie.Value) - - if originalSize > 100 && tc.wantCompressed { - // For large tokens, verify some compression occurred - compressionRatio := float64(storedSize) / float64(originalSize) - t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)", - compressionRatio, originalSize, storedSize) - - if compressionRatio > 0.9 { // Allow some overhead, but should see compression - t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)", - cookie.Name, compressionRatio) - } - } - } else if strings.Contains(cookie.Name, refreshTokenCookie) { - originalSize := len(tc.refreshToken) - storedSize := len(cookie.Value) - - if originalSize > 100 && tc.wantCompressed { - compressionRatio := float64(storedSize) / float64(originalSize) - t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)", - compressionRatio, originalSize, storedSize) - - if compressionRatio > 0.9 { - t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)", - cookie.Name, compressionRatio) - } - } + // Save session + if err := session.Save(req, rr); err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + // Verify cookies are set and compression is used when appropriate + cookies := rr.Result().Cookies() + if len(cookies) != tc.expectedCookieCount { + t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies)) + } + + // Verify compression is working by checking token sizes + for _, cookie := range cookies { + if strings.Contains(cookie.Name, accessTokenCookie) { + // Get original and stored sizes + originalSize := len(tc.accessToken) + storedSize := len(cookie.Value) + + if originalSize > 100 && tc.wantCompressed { + // For large tokens, verify some compression occurred + compressionRatio := float64(storedSize) / float64(originalSize) + t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)", + compressionRatio, originalSize, storedSize) + + if compressionRatio > 0.9 { // Allow some overhead, but should see compression + t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)", + cookie.Name, compressionRatio) } } + } else if strings.Contains(cookie.Name, refreshTokenCookie) { + originalSize := len(tc.refreshToken) + storedSize := len(cookie.Value) - // Create a new request with the cookies - newReq := httptest.NewRequest("GET", "/test", nil) - for _, cookie := range cookies { - newReq.AddCookie(cookie) - } + if originalSize > 100 && tc.wantCompressed { + compressionRatio := float64(storedSize) / float64(originalSize) + t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)", + compressionRatio, originalSize, storedSize) - // Get the session again and verify values - newSession, err := ts.sessionManager.GetSession(newReq) - if err != nil { - t.Fatalf("Failed to get new session: %v", err) + if compressionRatio > 0.9 { + t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)", + cookie.Name, compressionRatio) + } } + } + } - // Verify session values - if newSession.GetAuthenticated() != tc.authenticated { - t.Errorf("Authentication status not preserved") - } - if email := newSession.GetEmail(); email != tc.email { - t.Errorf("Expected email %s, got %s", tc.email, email) - } - if token := newSession.GetAccessToken(); token != tc.accessToken { - t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken)) - } - if token := newSession.GetRefreshToken(); token != tc.refreshToken { - t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken)) - } + // Create a new request with the cookies + newReq := httptest.NewRequest("GET", "/test", nil) + for _, cookie := range cookies { + newReq.AddCookie(cookie) + } - // Verify session pooling by checking if the session is reused - session2, _ := ts.sessionManager.GetSession(newReq) - if session2 == newSession { - t.Error("Session not properly pooled") - } - }) + // Get the session again and verify values + newSession, err := ts.sessionManager.GetSession(newReq) + if err != nil { + t.Fatalf("Failed to get new session: %v", err) + } + + // Verify session values + if newSession.GetAuthenticated() != tc.authenticated { + t.Errorf("Authentication status not preserved") + } + if email := newSession.GetEmail(); email != tc.email { + t.Errorf("Expected email %s, got %s", tc.email, email) + } + if token := newSession.GetAccessToken(); token != tc.accessToken { + t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken)) + } + if token := newSession.GetRefreshToken(); token != tc.refreshToken { + t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken)) + } + + // Verify session pooling by checking if the session is reused + session2, _ := ts.sessionManager.GetSession(newReq) + if session2 == newSession { + t.Error("Session not properly pooled") + } + }) } } @@ -362,12 +362,12 @@ func calculateExpectedCookieCount(accessToken, refreshToken string) int { calculateChunks := func(token string) int { // Compress token (matching the actual implementation) compressed := compressToken(token) - + // If compressed token fits in one cookie, no additional chunks needed if len(compressed) <= maxCookieSize { return 0 } - + // Calculate chunks needed for compressed token return len(splitIntoChunks(compressed, maxCookieSize)) } diff --git a/settings.go b/settings.go index 700b1c7..2d5380a 100644 --- a/settings.go +++ b/settings.go @@ -221,6 +221,7 @@ type Logger struct { // - "debug": Outputs all messages (debug, info, error) // - "info": Outputs info and error messages // - "error": Outputs only error messages +// // Error messages are always written to stderr, while info and debug // messages are written to stdout when enabled. func NewLogger(logLevel string) *Logger { @@ -229,7 +230,7 @@ func NewLogger(logLevel string) *Logger { logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime) logError.SetOutput(os.Stderr) - + if logLevel == "debug" || logLevel == "info" { logInfo.SetOutput(os.Stdout) }