Compare commits

...

12 Commits

Author SHA1 Message Date
lukaszraczylo 667b4213fe Well.. that escalated quickly.
Completely forgot that Traefik uses outdated Yaegi and requires compatibility with 1.20 ( pre-generic Go code ).
2025-06-20 19:23:50 +01:00
lukaszraczylo 70443f0855 Add ability to overwrite the default scopes in the settings file 2025-06-20 11:31:29 +01:00
lukaszraczylo 7a443c626c Fix claims issue. 2025-06-20 09:47:22 +01:00
lukaszraczylo 48de8265c5 Multiple changes to improve performance and reduce complexity.
- Optimise the errors and recovery.
- Deduplicate code in metadata cache.
- Remove unused performance monitoring code.
- Simplify session management and settings handling.
2025-06-20 09:00:02 +01:00
lukaszraczylo d8d1b74175 Fieldalignment 2025-06-20 08:07:03 +01:00
lukaszraczylo c233aa92ef Modernize run 2025-06-20 08:02:26 +01:00
lukaszraczylo c400251625 Refactoring code to fix the issues identified by the users. 2025-06-19 10:10:54 +01:00
lukaszraczylo 48faf7fadf Additional fixes and cleanup 2025-06-18 01:09:14 +01:00
lukaszraczylo 84d7cd3d76 Improvements targetting possible memory usage spikes. 2025-06-18 00:50:12 +01:00
lukaszraczylo 488264028b Ensure that appended roles are unique. Update the documentation. 2025-06-18 00:19:26 +01:00
lukaszraczylo e23135ded0 Fixes issue #51 2025-06-18 00:04:10 +01:00
lukaszraczylo cd307f88a1 Fix bug affecting Azure OIDC authentication ( and most likely others ) 2025-06-17 20:21:36 +01:00
33 changed files with 5944 additions and 2239 deletions
+9 -8
View File
@@ -35,11 +35,8 @@ testData:
logoutURL: /oauth2/logout # Path for handling logout requests (if not provided, it will be set to callbackURL + "/logout")
postLogoutRedirectURI: /oidc/different-logout # URL to redirect to after logout (default: "/")
scopes: # OAuth 2.0 scopes to request (default: ["openid", "email", "profile"])
- openid
- email
- profile
- roles # Include this to get role information from the provider
scopes: # Additional scopes to append to defaults ["openid", "profile", "email"]
- roles # Result: ["openid", "profile", "email", "roles"]
allowedUserDomains: # Restricts access to specific email domains (if not provided, relies on OIDC provider)
- company.com
@@ -153,11 +150,15 @@ configuration:
scopes:
type: array
description: |
The OAuth 2.0 scopes to request from the OIDC provider.
Default: ["openid", "profile", "email"]
Additional OAuth 2.0 scopes to append to the default scopes.
Default scopes are always included: ["openid", "profile", "email"]
User-provided scopes are appended to defaults with automatic deduplication.
For example, specifying ["roles", "custom_scope"] results in:
["openid", "profile", "email", "roles", "custom_scope"]
Include "roles" or similar scope if you need role/group information.
Note: For Google OAuth, the middleware automatically handles the
Note: For Google OAuth, the middleware automatically handles the
proper authentication parameters and does NOT require the "offline_access"
scope (which Google rejects as invalid). See documentation for details.
required: false
+88 -42
View File
@@ -67,7 +67,8 @@ The middleware supports the following configuration options:
|-----------|-------------|---------|---------|
| `logoutURL` | The path for handling logout requests | `callbackURL + "/logout"` | `/oauth2/logout` |
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
| `scopes` | OAuth 2.0 scopes to use for authentication | `["openid", "profile", "email"]` (always included by default) | `["roles", "custom_scope"]` (appended to defaults) |
| `overrideScopes` | When true, replaces default scopes with provided scopes instead of appending | `false` | `true` (use only the scopes explicitly provided) |
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
@@ -81,6 +82,79 @@ The middleware supports the following configuration options:
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
## Scope Configuration
### Scope Behavior
The middleware supports two modes for handling OAuth 2.0 scopes, controlled by the `overrideScopes` parameter:
#### Default Append Mode (`overrideScopes: false`)
By default, the middleware uses an **append** behavior for OAuth 2.0 scopes:
- **Default scopes** are always included: `["openid", "profile", "email"]`
- **User-provided scopes** are appended to the defaults with automatic deduplication
- The final scope list maintains the order: defaults first, then user scopes
#### Override Mode (`overrideScopes: true`)
When `overrideScopes` is set to `true`, the middleware uses **replacement** behavior:
- Default scopes are **not** automatically included
- Only the scopes explicitly provided in the `scopes` field are used
- You must include all required scopes explicitly, including `openid` if needed
### Examples:
**Default behavior (no custom scopes):**
```yaml
# No scopes field specified
# Result: ["openid", "profile", "email"]
```
**Default append behavior:**
```yaml
scopes:
- roles
- custom_scope
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
```
**Overlapping scopes with append (automatic deduplication):**
```yaml
scopes:
- openid # Duplicate - will be deduplicated
- roles
- profile # Duplicate - will be deduplicated
- permissions
# Result: ["openid", "profile", "email", "roles", "permissions"]
```
**Using override mode:**
```yaml
overrideScopes: true
scopes:
- openid
- profile
- custom_scope
# Result: ["openid", "profile", "custom_scope"]
```
**Empty scopes list with default behavior:**
```yaml
scopes: []
# Result: ["openid", "profile", "email"]
```
**Empty scopes list with override mode:**
```yaml
overrideScopes: true
scopes: []
# Result: [] (Warning: empty scopes may cause authentication to fail)
```
The default append behavior ensures essential OIDC scopes are always present, while the override mode gives you complete control over the exact scopes requested from the provider.
## Usage Examples
### Basic Configuration
@@ -101,9 +175,7 @@ spec:
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
```
### With Excluded URLs (Public Access Paths)
@@ -124,9 +196,7 @@ spec:
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
excludedURLs:
- /login # covers /login, /login/me, /login/reminder etc.
- /public-data
@@ -152,9 +222,7 @@ spec:
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
allowedUserDomains:
- company.com
- subsidiary.com
@@ -178,9 +246,7 @@ spec:
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
allowedUsers:
- user1@example.com
- user2@another.org
@@ -204,9 +270,7 @@ spec:
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
allowedUserDomains:
- company.com
allowedUsers:
@@ -239,10 +303,7 @@ spec:
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles # Include this to get role information from the provider
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
allowedRolesAndGroups:
- admin
- developer
@@ -269,9 +330,7 @@ spec:
rateLimit: 500 # Requests per second (default: 100)
forceHTTPS: false # Default is true for security
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
```
### With Custom Post-Logout Redirect
@@ -293,9 +352,7 @@ spec:
logoutURL: /oauth2/logout
postLogoutRedirectURI: /logged-out-page # Where to redirect after logout
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
```
### With Templated Headers
@@ -316,10 +373,7 @@ spec:
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
headers:
- name: "X-User-Email"
value: "{{.Claims.email}}"
@@ -352,9 +406,7 @@ spec:
logoutURL: /oauth2/logout
enablePKCE: true # Enables PKCE for added security
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
```
### Google OIDC Configuration Example
@@ -377,9 +429,7 @@ spec:
callbackURL: /oauth2/callback # Adjust if needed
logoutURL: /oauth2/logout # Optional: Adjust if needed
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
# Note: DO NOT manually add offline_access scope for Google
# The middleware automatically handles Google-specific requirements
refreshGracePeriodSeconds: 300 # Optional: Start refresh 5 min before expiry (default 60)
@@ -408,9 +458,7 @@ spec:
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
```
Don't forget to create the secret:
@@ -509,9 +557,7 @@ http:
postLogoutRedirectURI: /logged-out-page
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
scopes:
- openid
- email
- profile
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
allowedUserDomains:
- company.com
allowedUsers:
+53
View File
@@ -2,6 +2,57 @@ package traefikoidc
import "time"
// BackgroundTask represents a recurring task that runs in the background
type BackgroundTask struct {
stopChan chan struct{}
taskFunc func()
logger *Logger
name string
interval time.Duration
}
// NewBackgroundTask creates a new background task
func NewBackgroundTask(name string, interval time.Duration, taskFunc func(), logger *Logger) *BackgroundTask {
return &BackgroundTask{
name: name,
interval: interval,
stopChan: make(chan struct{}),
taskFunc: taskFunc,
logger: logger,
}
}
// Start begins the background task execution
func (bt *BackgroundTask) Start() {
go bt.run()
}
// Stop terminates the background task
func (bt *BackgroundTask) Stop() {
close(bt.stopChan)
}
// run is the main execution loop for the background task
func (bt *BackgroundTask) run() {
ticker := time.NewTicker(bt.interval)
defer ticker.Stop()
bt.logger.Debug("Starting background task: %s", bt.name)
// Run task immediately on startup
bt.taskFunc()
for {
select {
case <-ticker.C:
bt.taskFunc()
case <-bt.stopChan:
bt.logger.Debug("Stopping background task: %s", bt.name)
return
}
}
}
// autoCleanupRoutine periodically calls the provided cleanup function.
// It starts a ticker with the given interval and executes the cleanup function
// on each tick. The routine stops gracefully when a signal is received on the
@@ -12,6 +63,8 @@ import "time"
// - interval: The time duration between cleanup calls.
// - stop: A channel used to signal the routine to stop. Receiving any value will terminate the loop.
// - cleanup: The function to call periodically for cleanup tasks.
//
// Deprecated: Use BackgroundTask instead.
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
+387
View File
@@ -0,0 +1,387 @@
package traefikoidc
import (
"net/http/httptest"
"strings"
"testing"
"time"
"golang.org/x/time/rate"
)
// mockTraefikOidc extends TraefikOidc to override JWT verification for testing
type mockTraefikOidc struct {
*TraefikOidc
}
// Override VerifyToken to avoid JWKS lookup in tests
func (m *mockTraefikOidc) VerifyToken(token string) error {
// Cache test claims to avoid "claims not found" errors
testClaims := map[string]interface{}{
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
}
m.tokenCache.Set(token, testClaims, time.Hour)
return nil // Always succeed for testing
}
// Override VerifyJWTSignatureAndClaims to avoid JWKS lookup in tests
func (m *mockTraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
// Cache test claims to avoid "claims not found" errors
testClaims := map[string]interface{}{
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
}
m.tokenCache.Set(token, testClaims, time.Hour)
return nil // Always succeed for testing
}
func TestAzureOIDCRegression(t *testing.T) {
// Create a mocked TraefikOidc instance configured for Azure AD
mockLogger := NewLogger("debug")
// Configure for Azure AD provider
baseOidc := &TraefikOidc{
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
authURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/authorize",
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
clientID: "test-client-id",
clientSecret: "test-client-secret",
scopes: []string{"openid", "profile", "email"},
refreshGracePeriod: 60 * time.Second,
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Add rate limiter
logger: mockLogger,
httpClient: createDefaultHTTPClient(), // Add HTTP client
jwkCache: &JWKCache{}, // Add JWK cache
tokenCache: NewTokenCache(),
tokenBlacklist: NewCache(),
allowedUserDomains: make(map[string]struct{}),
allowedUsers: make(map[string]struct{}),
allowedRolesAndGroups: make(map[string]struct{}),
excludedURLs: make(map[string]struct{}),
extractClaimsFunc: extractClaims,
}
// Create the mock wrapper
tOidc := &mockTraefikOidc{TraefikOidc: baseOidc}
// Initialize session manager
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, mockLogger)
tOidc.sessionManager = sessionManager
// Mock the JWT verification to avoid JWKS lookup issues
tOidc.tokenVerifier = &mockTokenVerifier{
verifyFunc: func(token string) error {
// For test tokens, always return success and cache claims
if strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
// Cache test claims for JWT tokens
testClaims := map[string]interface{}{
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
}
tOidc.tokenCache.Set(token, testClaims, time.Hour)
return nil
}
// For opaque tokens (non-JWT format), return success
if !strings.Contains(token, ".") || strings.Count(token, ".") != 2 {
return nil
}
// For JWT tokens, cache basic claims to avoid cache lookup issues
testClaims := map[string]interface{}{
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
}
tOidc.tokenCache.Set(token, testClaims, time.Hour)
return nil // Always succeed for test purposes
},
}
// Mock JWT verifier to avoid JWKS lookup
tOidc.jwtVerifier = &mockJWTVerifier{
verifyFunc: func(jwt *JWT, token string) error {
// Also cache claims here to ensure they're available
testClaims := map[string]interface{}{
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
}
tOidc.tokenCache.Set(token, testClaims, time.Hour)
return nil // Always succeed
},
}
t.Run("Azure provider detection works correctly", func(t *testing.T) {
if !tOidc.isAzureProvider() {
t.Error("Azure provider should be detected for Azure AD issuer URL")
}
if tOidc.isGoogleProvider() {
t.Error("Google provider should not be detected for Azure AD issuer URL")
}
})
t.Run("Azure auth URL includes correct parameters", func(t *testing.T) {
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Check that response_mode=query was added for Azure
if !strings.Contains(authURL, "response_mode=query") {
t.Errorf("response_mode=query not added to Azure auth URL: %s", authURL)
}
// Verify offline_access scope is included for Azure providers
if !strings.Contains(authURL, "offline_access") {
t.Errorf("offline_access scope not included in Azure auth URL: %s", authURL)
}
// Verify Azure doesn't get Google-specific parameters
if strings.Contains(authURL, "access_type=offline") {
t.Errorf("access_type=offline incorrectly added to Azure auth URL: %s", authURL)
}
if strings.Contains(authURL, "prompt=consent") {
t.Errorf("prompt=consent incorrectly added to Azure auth URL: %s", authURL)
}
})
t.Run("Azure access token validation takes priority", func(t *testing.T) {
// Create a request and session
req := httptest.NewRequest("GET", "/protected", nil)
session, _ := tOidc.sessionManager.GetSession(req)
// Set up session with Azure-style tokens
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
// Create a valid JWT access token for testing
accessTokenClaims := map[string]interface{}{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user123",
"email": "user@example.com",
}
accessToken, _ := createMockJWT(accessTokenClaims)
session.SetAccessToken(accessToken)
// Create an invalid/expired ID token
idTokenClaims := map[string]interface{}{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired
"iat": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "user123",
"email": "user@example.com",
}
idToken, _ := createMockJWT(idTokenClaims)
session.SetIDToken(idToken)
// Mock the token verification to simulate Azure behavior
originalTokenVerifier := tOidc.tokenVerifier
tOidc.tokenVerifier = &mockTokenVerifier{
verifyFunc: func(token string) error {
if token == accessToken {
// Access token validation succeeds - cache claims
testClaims := map[string]interface{}{
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
}
tOidc.tokenCache.Set(token, testClaims, time.Hour)
return nil
}
if token == idToken {
// ID token validation fails (expired) - don't cache
return newMockError("token has expired")
}
return newMockError("token validation failed")
},
}
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
// Test Azure-specific validation
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
// Azure should prioritize access token, so even with expired ID token,
// user should still be authenticated since access token is valid
if !authenticated {
t.Error("Azure user should be authenticated when access token is valid, even if ID token is expired")
}
if expired {
t.Error("Azure session should not be marked as expired when access token is valid")
}
// May need refresh if we want to get a fresh ID token
if !needsRefresh {
t.Log("Azure session may not need immediate refresh if access token is still valid")
}
})
t.Run("Azure handles opaque access tokens gracefully", func(t *testing.T) {
// Create a request and session
req := httptest.NewRequest("GET", "/protected", nil)
session, _ := tOidc.sessionManager.GetSession(req)
// Set up session with opaque access token (non-JWT)
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ValidAccessToken)
// Create a valid ID token for claims extraction
idTokenClaims := map[string]interface{}{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user123",
"email": "user@example.com",
}
idToken, _ := createMockJWT(idTokenClaims)
session.SetIDToken(idToken)
// Mock the token verification
originalTokenVerifier := tOidc.tokenVerifier
tOidc.tokenVerifier = &mockTokenVerifier{
verifyFunc: func(token string) error {
if token == idToken {
// ID token is valid - cache claims
testClaims := map[string]interface{}{
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
"sub": "test-user",
"email": "test@example.com",
}
tOidc.tokenCache.Set(token, testClaims, time.Hour)
return nil
}
return newMockError("token validation failed")
},
}
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
// Test Azure-specific validation with opaque token
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
// Azure should handle opaque access tokens gracefully
if !authenticated {
t.Error("Azure user should be authenticated with opaque access token")
}
if expired {
t.Error("Azure session should not be expired with valid tokens")
}
if needsRefresh {
t.Log("Azure session with opaque token may signal refresh to get JWT tokens")
}
})
t.Run("Azure CSRF handling during token validation failures", func(t *testing.T) {
// Create a request and session
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, _ := tOidc.sessionManager.GetSession(req)
// Set up session with CSRF token (simulating ongoing auth flow)
session.SetCSRF("test-csrf-token-123")
session.SetNonce("test-nonce-456")
session.SetAuthenticated(false) // Not yet authenticated
// Save session to simulate real scenario
session.Save(req, rw)
// Mock token verification to always fail (simulating Azure token issues)
originalTokenVerifier := tOidc.tokenVerifier
tOidc.tokenVerifier = &mockTokenVerifier{
verifyFunc: func(token string) error {
return newMockError("azure token validation failed")
},
}
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
// Test that CSRF is preserved during Azure validation failures
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
// Should not be authenticated due to validation failure
if authenticated {
t.Error("Should not be authenticated when token validation fails")
}
// Should be marked as expired since no tokens work
if !expired && !needsRefresh {
t.Error("Should be marked as needing refresh or expired when validation fails")
}
// Verify CSRF token is still preserved in session
if session.GetCSRF() != "test-csrf-token-123" {
t.Error("CSRF token should be preserved during Azure token validation failures")
}
if session.GetNonce() != "test-nonce-456" {
t.Error("Nonce should be preserved during Azure token validation failures")
}
})
}
// createMockJWT creates a basic JWT token for testing purposes
func createMockJWT(claims map[string]interface{}) (string, error) {
// Simple mock JWT - in real tests you'd use a proper JWT library
// For this test, we'll create a basic three-part token structure
header := "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0" // {"alg":"RS256","kid":"test-key-id","typ":"JWT"}
// Create a simple payload with test claims
payload := "eyJpc3MiOiJ0ZXN0LWlzc3VlciIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjoxNjM4MzYwMDAwLCJpYXQiOjE2MzgzNTY0MDAsInN1YiI6InVzZXIxMjMiLCJlbWFpbCI6InVzZXJAZXhhbXBsZS5jb20ifQ" // Basic claims
signature := "test-signature"
return header + "." + payload + "." + signature, nil
}
// Mock error type for testing
type mockError struct {
message string
}
func (e *mockError) Error() string {
return e.message
}
func newMockError(message string) error {
return &mockError{message: message}
}
// Mock token verifier for testing
type mockTokenVerifier struct {
verifyFunc func(token string) error
}
func (m *mockTokenVerifier) VerifyToken(token string) error {
if m.verifyFunc != nil {
return m.verifyFunc(token)
}
return nil
}
// Mock JWT verifier for testing
type mockJWTVerifier struct {
verifyFunc func(jwt *JWT, token string) error
}
func (m *mockJWTVerifier) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
if m.verifyFunc != nil {
return m.verifyFunc(jwt, token)
}
return nil
}
+27 -26
View File
@@ -23,42 +23,40 @@ type lruEntry struct {
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
type Cache struct {
// items stores the cached data with string keys.
items map[string]CacheItem
// order maintains the usage order; most recently used items are at the back.
order *list.List
// elems maps keys to their corresponding list elements for O(1) access.
elems map[string]*list.Element
// mutex protects concurrent access to the cache.
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache.
maxSize int
// autoCleanupInterval defines how often Cleanup is called automatically.
items map[string]CacheItem
order *list.List
elems map[string]*list.Element
cleanupTask *BackgroundTask
logger *Logger
maxSize int
autoCleanupInterval time.Duration
// stopCleanup channel to terminate the auto cleanup goroutine.
stopCleanup chan struct{}
mutex sync.RWMutex
}
// DefaultMaxSize is the default maximum number of items in the cache.
const DefaultMaxSize = 500
// NewCache creates a new empty cache instance with default settings.
// It initializes the internal maps and list, sets the default maximum size,
// and starts the automatic cleanup goroutine.
// It initializes the internal maps and list and sets the default maximum size.
func NewCache() *Cache {
return NewCacheWithLogger(nil)
}
// NewCacheWithLogger creates a new cache with a specified logger
func NewCacheWithLogger(logger *Logger) *Cache {
if logger == nil {
logger = newNoOpLogger()
}
c := &Cache{
items: make(map[string]CacheItem, DefaultMaxSize),
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
autoCleanupInterval: 5 * time.Minute,
stopCleanup: make(chan struct{}),
logger: logger,
}
go c.startAutoCleanup()
c.startAutoCleanup()
return c
}
@@ -214,15 +212,18 @@ func (c *Cache) removeItem(key string) {
}
}
// startAutoCleanup starts the background goroutine that automatically calls the Cleanup method
// startAutoCleanup starts the background task that automatically calls the Cleanup method
// at the interval specified by c.autoCleanupInterval.
// It uses the autoCleanupRoutine helper function.
func (c *Cache) startAutoCleanup() {
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
c.cleanupTask = NewBackgroundTask("cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger)
c.cleanupTask.Start()
}
// Close stops the automatic cleanup goroutine associated with this cache instance.
// Close stops the automatic cleanup task associated with this cache instance.
// It should be called when the cache is no longer needed to prevent resource leaks.
func (c *Cache) Close() {
close(c.stopCleanup)
if c.cleanupTask != nil {
c.cleanupTask.Stop()
c.cleanupTask = nil
}
}
+314 -84
View File
@@ -11,6 +11,122 @@ import (
"time"
)
// ErrorRecoveryMechanism defines the common interface for all error recovery mechanisms
type ErrorRecoveryMechanism interface {
// ExecuteWithContext executes a function with error recovery
ExecuteWithContext(ctx context.Context, fn func() error) error
// GetMetrics returns metrics about the error recovery mechanism
GetMetrics() map[string]interface{}
// Reset resets the state of the error recovery mechanism
Reset()
// IsAvailable returns whether the mechanism is available for use
IsAvailable() bool
}
// BaseRecoveryMechanism provides common functionality for error recovery mechanisms
type BaseRecoveryMechanism struct {
startTime time.Time
lastFailureTime time.Time
lastSuccessTime time.Time
logger *Logger
name string
totalRequests int64
totalFailures int64
totalSuccesses int64
mutex sync.RWMutex
}
// NewBaseRecoveryMechanism creates a new base recovery mechanism
func NewBaseRecoveryMechanism(name string, logger *Logger) *BaseRecoveryMechanism {
if logger == nil {
logger = newNoOpLogger()
}
return &BaseRecoveryMechanism{
name: name,
logger: logger,
startTime: time.Now(),
}
}
// RecordRequest records a request to the error recovery mechanism
func (b *BaseRecoveryMechanism) RecordRequest() {
atomic.AddInt64(&b.totalRequests, 1)
}
// RecordSuccess records a successful operation
func (b *BaseRecoveryMechanism) RecordSuccess() {
atomic.AddInt64(&b.totalSuccesses, 1)
b.mutex.Lock()
defer b.mutex.Unlock()
b.lastSuccessTime = time.Now()
}
// RecordFailure records a failed operation
func (b *BaseRecoveryMechanism) RecordFailure() {
atomic.AddInt64(&b.totalFailures, 1)
b.mutex.Lock()
defer b.mutex.Unlock()
b.lastFailureTime = time.Now()
}
// GetBaseMetrics returns base metrics common to all recovery mechanisms
func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
b.mutex.RLock()
defer b.mutex.RUnlock()
metrics := map[string]interface{}{
"total_requests": atomic.LoadInt64(&b.totalRequests),
"total_failures": atomic.LoadInt64(&b.totalFailures),
"total_successes": atomic.LoadInt64(&b.totalSuccesses),
"uptime_seconds": time.Since(b.startTime).Seconds(),
"name": b.name,
}
if !b.lastFailureTime.IsZero() {
metrics["last_failure_time"] = b.lastFailureTime.Format(time.RFC3339)
metrics["seconds_since_last_failure"] = time.Since(b.lastFailureTime).Seconds()
}
if !b.lastSuccessTime.IsZero() {
metrics["last_success_time"] = b.lastSuccessTime.Format(time.RFC3339)
metrics["seconds_since_last_success"] = time.Since(b.lastSuccessTime).Seconds()
}
// Calculate success rate
if metrics["total_requests"].(int64) > 0 {
successRate := float64(metrics["total_successes"].(int64)) / float64(metrics["total_requests"].(int64))
metrics["success_rate"] = successRate
} else {
metrics["success_rate"] = 1.0 // Default to 100% if no requests
}
return metrics
}
// LogInfo logs an informational message
func (b *BaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
if b.logger != nil {
b.logger.Infof("%s: "+format, append([]interface{}{b.name}, args...)...)
}
}
// LogError logs an error message
func (b *BaseRecoveryMechanism) LogError(format string, args ...interface{}) {
if b.logger != nil {
b.logger.Errorf("%s: "+format, append([]interface{}{b.name}, args...)...)
}
}
// LogDebug logs a debug message
func (b *BaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
if b.logger != nil {
b.logger.Debugf("%s: "+format, append([]interface{}{b.name}, args...)...)
}
}
// CircuitBreakerState represents the current state of a circuit breaker
type CircuitBreakerState int
@@ -25,25 +141,12 @@ const (
// CircuitBreaker implements the circuit breaker pattern for external service calls
type CircuitBreaker struct {
// Configuration
maxFailures int // Maximum failures before opening
timeout time.Duration // How long to wait before trying again
resetTimeout time.Duration // How long to wait in half-open state
// State
state CircuitBreakerState
failures int64
lastFailureTime time.Time
lastSuccessTime time.Time
mutex sync.RWMutex
// Metrics
totalRequests int64
totalFailures int64
totalSuccesses int64
// Logger
logger *Logger
*BaseRecoveryMechanism
maxFailures int
timeout time.Duration
resetTimeout time.Duration
state CircuitBreakerState
failures int64
}
// CircuitBreakerConfig holds configuration for circuit breakers
@@ -65,17 +168,17 @@ func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
// NewCircuitBreaker creates a new circuit breaker with the given configuration
func NewCircuitBreaker(config CircuitBreakerConfig, logger *Logger) *CircuitBreaker {
return &CircuitBreaker{
maxFailures: config.MaxFailures,
timeout: config.Timeout,
resetTimeout: config.ResetTimeout,
state: CircuitBreakerClosed,
logger: logger,
BaseRecoveryMechanism: NewBaseRecoveryMechanism("circuit-breaker", logger),
maxFailures: config.MaxFailures,
timeout: config.Timeout,
resetTimeout: config.ResetTimeout,
state: CircuitBreakerClosed,
}
}
// Execute runs the given function with circuit breaker protection
func (cb *CircuitBreaker) Execute(fn func() error) error {
atomic.AddInt64(&cb.totalRequests, 1)
// ExecuteWithContext implements the ErrorRecoveryMechanism interface
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
cb.RecordRequest()
// Check if circuit breaker allows the request
if !cb.allowRequest() {
@@ -87,15 +190,20 @@ func (cb *CircuitBreaker) Execute(fn func() error) error {
// Record the result
if err != nil {
cb.recordFailure()
atomic.AddInt64(&cb.totalFailures, 1)
cb.RecordFailure()
return err
}
cb.recordSuccess()
atomic.AddInt64(&cb.totalSuccesses, 1)
cb.RecordSuccess()
return nil
}
// Execute is the original method for backward compatibility
func (cb *CircuitBreaker) Execute(fn func() error) error {
return cb.ExecuteWithContext(context.Background(), fn)
}
// allowRequest checks if the circuit breaker allows the request
func (cb *CircuitBreaker) allowRequest() bool {
cb.mutex.Lock()
@@ -131,19 +239,18 @@ func (cb *CircuitBreaker) recordFailure() {
defer cb.mutex.Unlock()
cb.failures++
cb.lastFailureTime = time.Now()
switch cb.state {
case CircuitBreakerClosed:
if cb.failures >= int64(cb.maxFailures) {
cb.state = CircuitBreakerOpen
cb.logger.Errorf("Circuit breaker opened after %d failures", cb.failures)
cb.LogError("Circuit breaker opened after %d failures", cb.failures)
}
case CircuitBreakerHalfOpen:
// Go back to open state on any failure in half-open
cb.state = CircuitBreakerOpen
cb.logger.Errorf("Circuit breaker returned to open state after failure in half-open")
cb.LogError("Circuit breaker returned to open state after failure in half-open")
}
}
@@ -152,14 +259,12 @@ func (cb *CircuitBreaker) recordSuccess() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.lastSuccessTime = time.Now()
switch cb.state {
case CircuitBreakerHalfOpen:
// Reset failures and close circuit on success in half-open
cb.failures = 0
cb.state = CircuitBreakerClosed
cb.logger.Infof("Circuit breaker closed after successful request in half-open state")
cb.LogInfo("Circuit breaker closed after successful request in half-open state")
case CircuitBreakerClosed:
// Reset failure count on success
@@ -174,30 +279,58 @@ func (cb *CircuitBreaker) GetState() CircuitBreakerState {
return cb.state
}
// GetMetrics returns circuit breaker metrics
// Reset resets the circuit breaker to its initial state
func (cb *CircuitBreaker) Reset() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.state = CircuitBreakerClosed
atomic.StoreInt64(&cb.failures, 0)
cb.LogInfo("Circuit breaker has been reset")
}
// IsAvailable returns whether the circuit breaker is allowing requests
func (cb *CircuitBreaker) IsAvailable() bool {
return cb.allowRequest()
}
// GetMetrics returns metrics about the circuit breaker
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
state := cb.state
failures := cb.failures
cb.mutex.RUnlock()
return map[string]interface{}{
"state": cb.state,
"failures": cb.failures,
"total_requests": atomic.LoadInt64(&cb.totalRequests),
"total_failures": atomic.LoadInt64(&cb.totalFailures),
"total_successes": atomic.LoadInt64(&cb.totalSuccesses),
"last_failure": cb.lastFailureTime,
"last_success": cb.lastSuccessTime,
metrics := cb.GetBaseMetrics()
// Add circuit breaker specific metrics
stateStr := "unknown"
switch state {
case CircuitBreakerClosed:
stateStr = "closed"
case CircuitBreakerOpen:
stateStr = "open"
case CircuitBreakerHalfOpen:
stateStr = "half-open"
}
metrics["state"] = stateStr
metrics["max_failures"] = cb.maxFailures
metrics["current_failures"] = failures
metrics["timeout_ms"] = cb.timeout.Milliseconds()
metrics["reset_timeout_ms"] = cb.resetTimeout.Milliseconds()
return metrics
}
// RetryConfig holds configuration for retry mechanisms
type RetryConfig struct {
RetryableErrors []string `json:"retryable_errors"`
MaxAttempts int `json:"max_attempts"`
InitialDelay time.Duration `json:"initial_delay"`
MaxDelay time.Duration `json:"max_delay"`
BackoffFactor float64 `json:"backoff_factor"`
EnableJitter bool `json:"enable_jitter"`
RetryableErrors []string `json:"retryable_errors"`
}
// DefaultRetryConfig returns default retry configuration
@@ -219,20 +352,21 @@ func DefaultRetryConfig() RetryConfig {
// RetryExecutor implements retry logic with exponential backoff
type RetryExecutor struct {
*BaseRecoveryMechanism
config RetryConfig
logger *Logger
}
// NewRetryExecutor creates a new retry executor
func NewRetryExecutor(config RetryConfig, logger *Logger) *RetryExecutor {
return &RetryExecutor{
config: config,
logger: logger,
BaseRecoveryMechanism: NewBaseRecoveryMechanism("retry-executor", logger),
config: config,
}
}
// Execute runs the given function with retry logic
func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
// ExecuteWithContext implements the ErrorRecoveryMechanism interface
func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error) error {
re.RecordRequest()
var lastErr error
for attempt := 1; attempt <= re.config.MaxAttempts; attempt++ {
@@ -240,8 +374,9 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
err := fn()
if err == nil {
if attempt > 1 {
re.logger.Infof("Operation succeeded on attempt %d", attempt)
re.LogInfo("Operation succeeded on attempt %d", attempt)
}
re.RecordSuccess()
return nil
}
@@ -249,30 +384,39 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
// Check if error is retryable
if !re.isRetryableError(err) {
re.logger.Debugf("Non-retryable error on attempt %d: %v", attempt, err)
re.LogDebug("Non-retryable error on attempt %d: %v", attempt, err)
re.RecordFailure()
return err
}
// Don't wait after the last attempt
if attempt == re.config.MaxAttempts {
re.RecordFailure()
break
}
// Calculate delay with exponential backoff
delay := re.calculateDelay(attempt)
re.logger.Debugf("Retrying operation after %v (attempt %d/%d): %v",
re.LogDebug("Retrying operation after %v (attempt %d/%d): %v",
delay, attempt, re.config.MaxAttempts, err)
// Wait with context cancellation support
select {
case <-ctx.Done():
re.RecordFailure()
return ctx.Err()
case <-time.After(delay):
// Continue to next attempt
}
}
return fmt.Errorf("operation failed after %d attempts: %w", re.config.MaxAttempts, lastErr)
finalErr := fmt.Errorf("operation failed after %d attempts: %w", re.config.MaxAttempts, lastErr)
return finalErr
}
// Execute runs the given function with retry logic (for backward compatibility)
func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
return re.ExecuteWithContext(ctx, fn)
}
// isRetryableError checks if an error should trigger a retry
@@ -341,10 +485,36 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
return time.Duration(delay)
}
// Reset resets the retry executor state
func (re *RetryExecutor) Reset() {
// Nothing to reset for RetryExecutor
re.LogDebug("Retry executor reset")
}
// IsAvailable always returns true for RetryExecutor
func (re *RetryExecutor) IsAvailable() bool {
return true
}
// GetMetrics returns metrics about the retry executor
func (re *RetryExecutor) GetMetrics() map[string]interface{} {
metrics := re.GetBaseMetrics()
// Add retry executor specific metrics
metrics["max_attempts"] = re.config.MaxAttempts
metrics["initial_delay_ms"] = re.config.InitialDelay.Milliseconds()
metrics["max_delay_ms"] = re.config.MaxDelay.Milliseconds()
metrics["backoff_factor"] = re.config.BackoffFactor
metrics["enable_jitter"] = re.config.EnableJitter
metrics["retryable_errors"] = re.config.RetryableErrors
return metrics
}
// HTTPError represents an HTTP error with status code
type HTTPError struct {
StatusCode int
Message string
StatusCode int
}
// Error implements the error interface
@@ -354,20 +524,12 @@ func (e *HTTPError) Error() string {
// GracefulDegradation implements graceful degradation patterns
type GracefulDegradation struct {
// Fallback functions for different operations
fallbacks map[string]func() (interface{}, error)
// Health checks for dependencies
healthChecks map[string]func() bool
// Configuration
config GracefulDegradationConfig
// State tracking
*BaseRecoveryMechanism
fallbacks map[string]func() (interface{}, error)
healthChecks map[string]func() bool
degradedServices map[string]time.Time
config GracefulDegradationConfig
mutex sync.RWMutex
logger *Logger
}
// GracefulDegradationConfig holds configuration for graceful degradation
@@ -389,11 +551,11 @@ func DefaultGracefulDegradationConfig() GracefulDegradationConfig {
// NewGracefulDegradation creates a new graceful degradation manager
func NewGracefulDegradation(config GracefulDegradationConfig, logger *Logger) *GracefulDegradation {
gd := &GracefulDegradation{
fallbacks: make(map[string]func() (interface{}, error)),
healthChecks: make(map[string]func() bool),
degradedServices: make(map[string]time.Time),
config: config,
logger: logger,
BaseRecoveryMechanism: NewBaseRecoveryMechanism("graceful-degradation", logger),
fallbacks: make(map[string]func() (interface{}, error)),
healthChecks: make(map[string]func() bool),
degradedServices: make(map[string]time.Time),
config: config,
}
// Start health check routine
@@ -416,10 +578,29 @@ func (gd *GracefulDegradation) RegisterHealthCheck(serviceName string, healthChe
gd.healthChecks[serviceName] = healthCheck
}
// ExecuteWithContext implements the ErrorRecoveryMechanism interface
func (gd *GracefulDegradation) ExecuteWithContext(ctx context.Context, fn func() error) error {
gd.RecordRequest()
// Execute with a simple wrapper
_, err := gd.ExecuteWithFallback("default", func() (interface{}, error) {
return nil, fn()
})
if err != nil {
gd.RecordFailure()
} else {
gd.RecordSuccess()
}
return err
}
// ExecuteWithFallback executes a function with fallback support
func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary func() (interface{}, error)) (interface{}, error) {
// Check if service is degraded
if gd.isServiceDegraded(serviceName) {
gd.LogInfo("Service %s is degraded, using fallback", serviceName)
return gd.executeFallback(serviceName)
}
@@ -428,9 +609,11 @@ func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary f
if err != nil {
// Mark service as degraded
gd.markServiceDegraded(serviceName)
gd.LogError("Service %s failed: %v", serviceName, err)
// Try fallback if available
if gd.config.EnableFallbacks {
gd.LogInfo("Using fallback for service %s", serviceName)
return gd.executeFallback(serviceName)
}
@@ -465,7 +648,7 @@ func (gd *GracefulDegradation) markServiceDegraded(serviceName string) {
defer gd.mutex.Unlock()
if _, exists := gd.degradedServices[serviceName]; !exists {
gd.logger.Errorf("Service %s marked as degraded", serviceName)
gd.LogError("Service %s marked as degraded", serviceName)
}
gd.degradedServices[serviceName] = time.Now()
@@ -481,26 +664,27 @@ func (gd *GracefulDegradation) executeFallback(serviceName string) (interface{},
return nil, fmt.Errorf("no fallback available for service %s", serviceName)
}
gd.logger.Infof("Executing fallback for degraded service %s", serviceName)
gd.LogInfo("Executing fallback for degraded service %s", serviceName)
return fallback()
}
// startHealthCheckRoutine starts the background health check routine
func (gd *GracefulDegradation) startHealthCheckRoutine() {
ticker := time.NewTicker(gd.config.HealthCheckInterval)
defer ticker.Stop()
for range ticker.C {
gd.performHealthChecks()
}
healthCheckTask := NewBackgroundTask(
"graceful-degradation-health-check",
gd.config.HealthCheckInterval,
gd.performHealthChecks,
gd.BaseRecoveryMechanism.logger,
)
healthCheckTask.Start()
}
// performHealthChecks runs health checks for all registered services
func (gd *GracefulDegradation) performHealthChecks() {
gd.mutex.RLock()
healthChecks := make(map[string]func() bool)
for name, check := range gd.healthChecks {
healthChecks[name] = check
for k, v := range gd.healthChecks {
healthChecks[k] = v
}
gd.mutex.RUnlock()
@@ -533,13 +717,59 @@ func (gd *GracefulDegradation) GetDegradedServices() []string {
return degraded
}
// Reset resets the state of all degraded services
func (gd *GracefulDegradation) Reset() {
gd.mutex.Lock()
defer gd.mutex.Unlock()
// Clear degraded services
gd.degradedServices = make(map[string]time.Time)
gd.LogInfo("Graceful degradation state has been reset")
}
// IsAvailable returns whether the mechanism is available for use
func (gd *GracefulDegradation) IsAvailable() bool {
return true
}
// GetMetrics returns metrics about the graceful degradation mechanism
func (gd *GracefulDegradation) GetMetrics() map[string]interface{} {
gd.mutex.RLock()
degradedCount := len(gd.degradedServices)
// Get the names of degraded services
degradedServices := make([]string, 0, degradedCount)
for service := range gd.degradedServices {
degradedServices = append(degradedServices, service)
}
// Get total count of registered fallbacks and health checks
fallbackCount := len(gd.fallbacks)
healthCheckCount := len(gd.healthChecks)
gd.mutex.RUnlock()
// Get base metrics
metrics := gd.GetBaseMetrics()
// Add graceful degradation specific metrics
metrics["degraded_services_count"] = degradedCount
metrics["degraded_services"] = degradedServices
metrics["registered_fallbacks_count"] = fallbackCount
metrics["registered_health_checks_count"] = healthCheckCount
metrics["health_check_interval_seconds"] = gd.config.HealthCheckInterval.Seconds()
metrics["recovery_timeout_seconds"] = gd.config.RecoveryTimeout.Seconds()
metrics["fallbacks_enabled"] = gd.config.EnableFallbacks
return metrics
}
// ErrorRecoveryManager coordinates all error recovery mechanisms
type ErrorRecoveryManager struct {
circuitBreakers map[string]*CircuitBreaker
retryExecutor *RetryExecutor
gracefulDegradation *GracefulDegradation
mutex sync.RWMutex
logger *Logger
mutex sync.RWMutex
}
// NewErrorRecoveryManager creates a new error recovery manager
+2 -2
View File
@@ -256,8 +256,8 @@ func TestGracefulDegradation(t *testing.T) {
t.Run("Get degraded services", func(t *testing.T) {
degraded := gd.GetDegradedServices()
found := false
for _, service := range degraded {
if service == "failing-service" {
for _, s := range degraded {
if s == "failing-service" {
found = true
break
}
+19 -11
View File
@@ -94,7 +94,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
session, _ := sessionManager.GetSession(req)
session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetAccessToken("old-access-token")
session.SetAccessToken(ValidAccessToken)
session.SetRefreshToken("valid-refresh-token")
// Create a mock token exchanger that simulates Google's behavior
@@ -106,11 +106,15 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
return nil, fmt.Errorf("invalid token")
}
// Use standardized test tokens instead of ad-hoc strings
testTokens := NewTestTokens()
googleTokens := testTokens.GetGoogleTokenSet()
// Return a simulated Google token response with a new access token
// but without a new refresh token (Google doesn't always return a new refresh token)
return &TokenResponse{
IDToken: "new-id-token-from-google",
AccessToken: "new-access-token-from-google",
IDToken: googleTokens.IDToken,
AccessToken: googleTokens.AccessToken,
RefreshToken: "", // Google often doesn't return a new refresh token
ExpiresIn: 3600,
}, nil
@@ -149,15 +153,19 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
session.GetRefreshToken())
}
// Use the same test tokens for validation
testTokens := NewTestTokens()
expectedTokens := testTokens.GetGoogleTokenSet()
// Check that the tokens were updated correctly
if session.GetIDToken() != "new-id-token-from-google" {
t.Errorf("ID token not updated: got %s, expected 'new-id-token-from-google'",
session.GetIDToken())
if session.GetIDToken() != expectedTokens.IDToken {
t.Errorf("ID token not updated: got %s, expected %s",
session.GetIDToken(), expectedTokens.IDToken)
}
if session.GetAccessToken() != "new-access-token-from-google" {
t.Errorf("Access token not updated: got %s, expected 'new-access-token-from-google'",
session.GetAccessToken())
if session.GetAccessToken() != expectedTokens.AccessToken {
t.Errorf("Access token not updated: got %s, expected %s",
session.GetAccessToken(), expectedTokens.AccessToken)
}
})
// Test that our fix specifically addresses the reported Google error
@@ -296,8 +304,8 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
expectedScopes := []string{"openid", "profile", "email"}
for _, expectedScope := range expectedScopes {
found := false
for _, actualScope := range scopeList {
if actualScope == expectedScope {
for _, s := range scopeList {
if s == expectedScope {
found = true
break
}
+16 -24
View File
@@ -72,20 +72,11 @@ func deriveCodeChallenge(codeVerifier string) string {
// It contains the various tokens and metadata returned after successful
// code exchange or token refresh operations.
type TokenResponse struct {
// IDToken is the OIDC ID token containing user claims
IDToken string `json:"id_token"`
// AccessToken is the OAuth 2.0 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the lifetime in seconds of the access token
ExpiresIn int `json:"expires_in"`
// TokenType is the type of token, typically "Bearer"
TokenType string `json:"token_type"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint.
@@ -123,17 +114,13 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
data.Set("refresh_token", codeOrToken)
}
// Use the reusable token HTTP client, fallback to creating one if not initialized
client := t.tokenHTTPClient
if client == nil {
// Fallback for tests or incomplete initialization - create a temporary client
// with the same behavior as the original implementation
jar, _ := cookiejar.New(nil)
client = &http.Client{
Transport: t.httpClient.Transport,
Timeout: t.httpClient.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
@@ -223,15 +210,21 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
// It stores token claims to avoid repeated validation of the
// same token, improving performance for frequently used tokens.
type TokenCache struct {
// cache is the underlying cache implementation
cache *Cache
}
const (
defaultTokenCacheMaxSize = 1000
defaultTokenCacheCleanupInterval = 2 * time.Minute
)
// NewTokenCache creates and initializes a new TokenCache.
// It internally creates a new generic Cache instance for storage.
func NewTokenCache() *TokenCache {
cache := NewCache()
cache.SetMaxSize(defaultTokenCacheMaxSize)
return &TokenCache{
cache: NewCache(),
cache: cache,
}
}
@@ -303,7 +296,6 @@ func (tc *TokenCache) Close() {
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
ctx := context.Background()
// Only include code verifier if PKCE is enabled
effectiveCodeVerifier := ""
if t.enablePKCE && codeVerifier != "" {
effectiveCodeVerifier = codeVerifier
@@ -352,7 +344,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
accessToken := session.GetAccessToken()
idToken := session.GetIDToken()
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing session: %v", err)
@@ -371,8 +363,8 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
if t.endSessionURL != "" && accessToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
+15 -21
View File
@@ -11,35 +11,29 @@ import (
// InputValidator provides comprehensive input validation and sanitization
type InputValidator struct {
// Configuration
maxTokenLength int
maxURLLength int
maxHeaderLength int
maxClaimLength int
maxEmailLength int
maxUsernameLength int
// Compiled regex patterns
emailRegex *regexp.Regexp
urlRegex *regexp.Regexp
tokenRegex *regexp.Regexp
usernameRegex *regexp.Regexp
// Security patterns to detect
usernameRegex *regexp.Regexp
tokenRegex *regexp.Regexp
logger *Logger
urlRegex *regexp.Regexp
emailRegex *regexp.Regexp
sqlInjectionPatterns []string
xssPatterns []string
pathTraversalPatterns []string
logger *Logger
xssPatterns []string
maxUsernameLength int
maxURLLength int
maxTokenLength int
maxEmailLength int
maxClaimLength int
maxHeaderLength int
}
// ValidationResult represents the result of input validation
type ValidationResult struct {
IsValid bool `json:"is_valid"`
Errors []string `json:"errors,omitempty"`
Warnings []string `json:"warnings,omitempty"`
SanitizedValue string `json:"sanitized_value,omitempty"`
SecurityRisk string `json:"security_risk,omitempty"`
Errors []string `json:"errors,omitempty"`
Warnings []string `json:"warnings,omitempty"`
IsValid bool `json:"is_valid"`
}
// InputValidationConfig holds configuration for input validation
+1 -1
View File
@@ -204,8 +204,8 @@ func TestSanitizeInput(t *testing.T) {
tests := []struct {
name string
input string
maxLen int
expected string
maxLen int
}{
{
name: "Normal text",
+5 -6
View File
@@ -33,13 +33,12 @@ type JWKSet struct {
}
type JWKCache struct {
jwks *JWKSet
expiresAt time.Time
mutex sync.RWMutex
// CacheLifetime is configurable to determine how long the JWKS is cached.
expiresAt time.Time
jwks *JWKSet
internalCache *Cache
CacheLifetime time.Duration
internalCache *Cache // To hold the closable Cache instance from cache.go
maxSize int // Maximum number of items in the cache
maxSize int
mutex sync.RWMutex
}
type JWKCacheInterface interface {
+71 -38
View File
@@ -1,6 +1,7 @@
package traefikoidc
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
@@ -16,37 +17,80 @@ import (
)
var (
replayCacheMu sync.Mutex
replayCache *Cache // Replace unbounded map with bounded Cache
replayCacheMu sync.RWMutex // Use RWMutex for better read performance
replayCache *Cache // Replace unbounded map with bounded Cache
replayCacheOnce sync.Once
)
// initReplayCache initializes the global replay cache with size limit
func initReplayCache() {
if replayCache == nil {
replayCacheOnce.Do(func() {
replayCache = NewCache()
replayCache.SetMaxSize(10000) // Set size limit to 10,000 entries
replayCache.SetMaxSize(10000)
})
}
func cleanupReplayCache() {
replayCacheMu.Lock()
defer replayCacheMu.Unlock()
if replayCache != nil {
replayCache.Close()
replayCache = nil
}
}
// STABILITY FIX: Standardize clock skew tolerance usage
// ClockSkewToleranceFuture defines the tolerance for future-based claims like 'exp'.
// Allows for more leniency with expiration checks.
func getReplayCacheStats() (size int, maxSize int) {
replayCacheMu.RLock()
defer replayCacheMu.RUnlock()
if replayCache == nil {
return 0, 0
}
return 0, 10000
}
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
go func() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
size, maxSize := getReplayCacheStats()
if logger != nil {
logger.Debugf("Replay cache stats: size=%d, maxSize=%d", size, maxSize)
}
replayCacheMu.RLock()
if replayCache != nil {
}
replayCacheMu.RUnlock()
case <-ctx.Done():
cleanupReplayCache()
if logger != nil {
logger.Debug("Replay cache cleanup goroutine stopped due to context cancellation")
}
return
}
}
}()
}
var ClockSkewToleranceFuture = 2 * time.Minute
// ClockSkewTolerancePast defines the tolerance for past-based claims like 'iat' and 'nbf'.
// A smaller tolerance is typically used here to prevent accepting tokens issued too far in the future.
var ClockSkewTolerancePast = 10 * time.Second
// ClockSkewTolerance is deprecated - use ClockSkewToleranceFuture or ClockSkewTolerancePast
// STABILITY FIX: Remove inconsistent usage
var ClockSkewTolerance = ClockSkewToleranceFuture
// JWT represents a JSON Web Token as defined in RFC 7519.
type JWT struct {
Header map[string]interface{}
Claims map[string]interface{}
Signature []byte
Token string
Signature []byte
}
// parseJWT decodes a raw JWT string into its constituent parts: header, claims, and signature.
@@ -75,12 +119,10 @@ func parseJWT(tokenString string) (*JWT, error) {
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
}
// STABILITY FIX: Add comprehensive JSON error handling with panic protection
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
}
// Validate header structure
if jwt.Header == nil {
return nil, fmt.Errorf("invalid JWT format: header is nil after unmarshaling")
}
@@ -90,12 +132,10 @@ func parseJWT(tokenString string) (*JWT, error) {
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
}
// STABILITY FIX: Add comprehensive JSON error handling with panic protection
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Validate claims structure
if jwt.Claims == nil {
return nil, fmt.Errorf("invalid JWT format: claims is nil after unmarshaling")
}
@@ -183,31 +223,23 @@ func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error
}
}
// Implement replay protection by checking the jti (JWT ID)
// Skip replay check if explicitly requested (for revalidation scenarios)
shouldSkipReplay := len(skipReplayCheck) > 0 && skipReplayCheck[0]
if jti, ok := claims["jti"].(string); ok && !shouldSkipReplay {
// Skip replay detection for tokens that are being verified from the cache
if j.Token == "" {
// This is a parsed JWT without the original token string,
// which means it's likely from a cached token verification
return nil
}
// SECURITY FIX: Use bounded Cache with thread-safe operations
replayCacheMu.Lock()
defer replayCacheMu.Unlock()
// Initialize cache if not already done
initReplayCache()
// SECURITY FIX: Check for replay attack using Cache API
if _, exists := replayCache.Get(jti); exists {
return fmt.Errorf("token replay detected")
replayCacheMu.RLock()
_, exists := replayCache.Get(jti)
replayCacheMu.RUnlock()
if exists {
return fmt.Errorf("token replay detected (jti: %s)", jti)
}
// Calculate expiration time
expFloat, ok := claims["exp"].(float64)
var expTime time.Time
if ok {
@@ -216,10 +248,13 @@ func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error
expTime = time.Now().Add(10 * time.Minute)
}
// SECURITY FIX: Add to replay cache with expiration using Cache API
duration := time.Until(expTime)
if duration > 0 {
replayCache.Set(jti, true, duration)
replayCacheMu.Lock()
if replayCache != nil {
replayCache.Set(jti, true, duration)
}
replayCacheMu.Unlock()
}
}
@@ -293,17 +328,15 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
// - An error describing the failure (e.g., "token has expired", "token used before issued").
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
claimTime := time.Unix(int64(unixTime), 0)
now := time.Now() // Use current time without truncation
now := time.Now()
var err error
if future { // 'exp' check
// Token is expired if Now is after (ClaimTime + FutureTolerance)
if future {
allowedExpiry := claimTime.Add(ClockSkewToleranceFuture)
if now.After(allowedExpiry) {
err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC())
}
} else { // 'iat' or 'nbf' check
// Token is invalid if Now is before (ClaimTime - PastTolerance)
} else {
allowedStart := claimTime.Add(-ClockSkewTolerancePast)
if now.Before(allowedStart) {
reason := "not yet valid"
+617 -299
View File
File diff suppressed because it is too large Load Diff
+507 -52
View File
@@ -30,8 +30,8 @@ type TestSuite struct {
ecPrivateKey *ecdsa.PrivateKey
tOidc *TraefikOidc
mockJWKCache *MockJWKCache
token string
sessionManager *SessionManager
token string
}
// Setup initializes the test suite
@@ -410,15 +410,15 @@ func TestServeHTTP(t *testing.T) {
}
tests := []struct {
name string
requestPath string
sessionValues map[interface{}]interface{}
expectedStatus int
expectedBody string
setupSession func(*SessionData)
mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error)
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) // Added for post-request checks
requestHeaders map[string]string // Added for setting headers like Accept
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager)
requestHeaders map[string]string
name string
requestPath string
expectedBody string
expectedStatus int
}{
{
name: "Excluded URL",
@@ -503,7 +503,13 @@ func TestServeHTTP(t *testing.T) {
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken()) // Set expired token
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken) // Set expired token
session.SetRefreshToken("valid-refresh-token") // Set valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
@@ -572,7 +578,13 @@ func TestServeHTTP(t *testing.T) {
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken()) // Expired access token
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken) // Expired access token
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
@@ -594,7 +606,13 @@ func TestServeHTTP(t *testing.T) {
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken()) // Expired access token
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken) // Expired access token
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
@@ -855,10 +873,10 @@ func TestJWKToPEM(t *testing.T) {
ts.Setup()
tests := []struct {
name string
jwk *JWK
expectError bool
name string
errorContains string
expectError bool
}{
{
name: "Unsupported Key Type",
@@ -910,8 +928,8 @@ func TestParseJWT(t *testing.T) {
tests := []struct {
name string
token string
expectError bool
errorContains string
expectError bool
}{
{
name: "Invalid Format",
@@ -971,11 +989,11 @@ func TestHandleCallback(t *testing.T) {
redirectURL := "http://example.com/"
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(*SessionData)
name string
queryParams string
expectedStatus int
}{
{
@@ -1141,7 +1159,7 @@ func TestHandleCallback(t *testing.T) {
}
for _, tc := range tests {
tc := tc // Capture range variable
// Capture range variable
t.Run(tc.name, func(t *testing.T) {
// Clear the global replay cache before each test run
replayCacheMu.Lock()
@@ -1238,12 +1256,12 @@ func TestIsAllowedDomain(t *testing.T) {
ts.Setup()
tests := []struct {
name string
email string
allowedDomains map[string]struct{}
allowedUsers map[string]struct{}
name string
email string
expectedLogOutput string
allowed bool
expectedLogOutput string // For testing log messages
}{
{
name: "Allowed domain",
@@ -1325,11 +1343,11 @@ func TestOIDCHandler(t *testing.T) {
ts.token = "valid.jwt.token"
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
name string
queryParams string
expectedStatus int
blacklist bool
rateLimit bool
@@ -1433,7 +1451,7 @@ func TestOIDCHandler(t *testing.T) {
}
for _, tc := range tests {
tc := tc // Capture range variable
// Capture range variable
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
@@ -1486,31 +1504,33 @@ func TestHandleLogout(t *testing.T) {
defer mockRevocationServer.Close()
tests := []struct {
name string
setupSession func(*SessionData)
name string
endSessionURL string
expectedStatus int
expectedURL string
host string
expectedStatus int
}{
{
name: "Successful logout with end session endpoint",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
session.SetAccessToken(ValidAccessToken)
session.SetIDToken(ValidIDToken)
session.SetRefreshToken(ValidRefreshToken)
},
endSessionURL: "https://provider/end-session",
expectedStatus: http.StatusFound,
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
expectedURL: "https://provider/end-session?id_token_hint=" + url.QueryEscape(ValidIDToken) + "&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
host: "test-host",
},
{
name: "Successful logout without end session endpoint",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
session.SetAccessToken(ValidAccessToken)
session.SetIDToken(ValidIDToken)
session.SetRefreshToken(ValidRefreshToken)
},
endSessionURL: "",
expectedStatus: http.StatusFound,
@@ -1528,8 +1548,9 @@ func TestHandleLogout(t *testing.T) {
name: "Logout with invalid end session URL",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
session.SetAccessToken(ValidAccessToken)
session.SetIDToken(ValidIDToken)
session.SetRefreshToken(ValidRefreshToken)
},
endSessionURL: ":\\invalid-url",
expectedStatus: http.StatusInternalServerError,
@@ -1811,7 +1832,13 @@ func TestHandleExpiredToken(t *testing.T) {
name: "Basic expired token",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("expired.token")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken)
session.SetEmail("test@example.com")
},
expectedPath: "/original/path",
@@ -1820,7 +1847,13 @@ func TestHandleExpiredToken(t *testing.T) {
name: "Session with additional values",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("expired.token")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken)
session.mainSession.Values["custom_value"] = "should-be-cleared"
},
expectedPath: "/another/path",
@@ -2071,12 +2104,12 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
tests := []struct {
name string
allowedRolesAndGroups map[string]struct{}
claims map[string]interface{}
setupSession func(*SessionData)
expectedStatus int
expectedHeaders map[string]string
name string
expectedStatus int
}{
{
name: "User with allowed role",
@@ -2280,10 +2313,10 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
ts.Setup()
tests := []struct {
name string
setupServer func() *httptest.Server
expectError bool
name string
errorContains string
expectError bool
}{
{
name: "Successful token exchange with redirects",
@@ -2307,7 +2340,7 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
for i := 0; i < 3; i++ {
for i := range 3 {
found := false
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
for _, cookie := range cookies {
@@ -2391,9 +2424,9 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL string
state string
nonce string
enablePKCE bool
codeChallenge string
expectedPrefix string
enablePKCE bool
checkPKCE bool
}{
{
@@ -2541,10 +2574,10 @@ func TestExchangeCodeForToken(t *testing.T) {
ts.Setup()
tests := []struct {
name string
enablePKCE bool
codeVerifier string
setupMock func(t *testing.T) *httptest.Server
name string
codeVerifier string
enablePKCE bool
}{
{
name: "With PKCE Enabled and Code Verifier",
@@ -2850,10 +2883,10 @@ func TestJWTVerifyWithSkipReplayCheck(t *testing.T) {
tests := []struct {
name string
errorContains string
skipReplayCheck bool
firstCall bool
expectError bool
errorContains string
}{
{
name: "First verification with skipReplayCheck=false should succeed",
@@ -3083,7 +3116,7 @@ func TestAuthenticationFlowReplayDetection(t *testing.T) {
// Step 2: Subsequent requests (simulate normal request processing)
// These should use the token cache and skip replay detection
for i := 0; i < 3; i++ {
for i := range 3 {
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Subsequent request %d should succeed: %v", i+1, err)
@@ -3204,7 +3237,7 @@ func TestConcurrentTokenValidation(t *testing.T) {
iat := now.Unix()
nbf := now.Unix()
for i := 0; i < 10; i++ {
for i := range 10 {
jti := generateRandomString(16)
jtis = append(jtis, jti)
@@ -3231,9 +3264,9 @@ func TestConcurrentTokenValidation(t *testing.T) {
results := make(chan error, numGoroutines*numIterations)
for g := 0; g < numGoroutines; g++ {
for g := range numGoroutines {
go func(goroutineID int) {
for i := 0; i < numIterations; i++ {
for i := range numIterations {
tokenIndex := (goroutineID + i) % len(tokens)
token := tokens[tokenIndex]
@@ -3250,7 +3283,7 @@ func TestConcurrentTokenValidation(t *testing.T) {
// Collect results
var errors []error
for i := 0; i < numGoroutines*numIterations*2; i++ {
for range numGoroutines * numIterations * 2 {
if err := <-results; err != nil {
errors = append(errors, err)
}
@@ -3306,10 +3339,10 @@ func TestJTIBlacklistBehavior(t *testing.T) {
// Test JTI blacklist behavior
tests := []struct {
name string
action func() error
expectError bool
name string
description string
expectError bool
}{
{
name: "Initial verification adds JTI to blacklist",
@@ -3415,7 +3448,7 @@ func TestSessionBasedTokenRevalidation(t *testing.T) {
// Step 2: Multiple session-based requests (normal request processing)
// These should not trigger replay detection false positives
for i := 0; i < 5; i++ {
for i := range 5 {
err = ts.tOidc.VerifyToken(token)
if err != nil {
t.Errorf("Session request %d should succeed: %v", i+1, err)
@@ -3462,9 +3495,9 @@ func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
nbf := now.Unix()
tests := []struct {
claims map[string]interface{}
name string
tokenType string
claims map[string]interface{}
expectError bool
}{
{
@@ -3564,3 +3597,425 @@ func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
})
}
}
// TestScopeMerging tests the scope append functionality
func TestScopeMerging(t *testing.T) {
// Helper function to compare string slices
equalSlices := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
tests := []struct {
name string
defaultScopes []string
userScopes []string
expectedScopes []string
}{
{
name: "Empty user scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{},
expectedScopes: []string{"openid", "profile", "email"},
},
{
name: "Nil user scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: nil,
expectedScopes: []string{"openid", "profile", "email"},
},
{
name: "New scopes are appended",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"custom_scope", "another_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
},
{
name: "Deduplication - user scope already in defaults",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"openid", "custom_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope"},
},
{
name: "Duplicate user scopes are removed",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"custom_scope", "custom_scope", "another_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
},
{
name: "Multiple overlapping scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"profile", "custom_scope", "email", "another_scope", "profile"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
},
{
name: "Only custom scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"read:users", "write:users", "admin"},
expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin"},
},
{
name: "Empty defaults",
defaultScopes: []string{},
userScopes: []string{"custom1", "custom2"},
expectedScopes: []string{"custom1", "custom2"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Test the mergeScopes function directly
result := mergeScopes(tc.defaultScopes, tc.userScopes)
if !equalSlices(result, tc.expectedScopes) {
t.Errorf("Expected %v, got %v", tc.expectedScopes, result)
}
})
}
}
// TestScopeMergingEdgeCases tests additional edge cases for scope deduplication
func TestScopeMergingEdgeCases(t *testing.T) {
// Helper function to compare string slices
equalSlices := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
tests := []struct {
name string
description string
defaultScopes []string
userScopes []string
expectedScopes []string
}{
{
name: "Case sensitivity preserved",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"OpenID", "PROFILE", "custom"},
expectedScopes: []string{"openid", "profile", "email", "OpenID", "PROFILE", "custom"},
description: "OAuth scopes are case-sensitive, so different cases should be preserved",
},
{
name: "Empty strings in user scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"", "custom", "", "another"},
expectedScopes: []string{"openid", "profile", "email", "", "custom", "another"},
description: "Empty strings should be preserved (though invalid in OAuth)",
},
{
name: "Whitespace scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{" ", "custom", " ", "another"},
expectedScopes: []string{"openid", "profile", "email", " ", "custom", " ", "another"},
description: "Whitespace-only scopes should be preserved as distinct",
},
{
name: "Large number of scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: generateLargeUserScopes(),
expectedScopes: func() []string {
// Manually calculate expected result with proper deduplication
defaults := []string{"openid", "profile", "email"}
userScopes := generateLargeUserScopes()
return mergeScopes(defaults, userScopes)
}(),
description: "Performance test with larger scope lists",
},
{
name: "Complex OAuth scopes with special characters",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"read:users", "write:users", "admin:*", "scope/with/slashes", "scope-with-dashes"},
expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin:*", "scope/with/slashes", "scope-with-dashes"},
description: "Real-world OAuth scopes with colons, slashes, and special characters",
},
{
name: "Duplicate defaults in user scopes multiple times",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"openid", "profile", "openid", "custom", "email", "profile", "custom"},
expectedScopes: []string{"openid", "profile", "email", "custom"},
description: "Multiple duplicates of default scopes should be completely deduplicated",
},
{
name: "All user scopes are duplicates of defaults",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"email", "openid", "profile", "openid"},
expectedScopes: []string{"openid", "profile", "email"},
description: "When all user scopes duplicate defaults, result should be just defaults",
},
{
name: "Single scope scenarios",
defaultScopes: []string{"openid"},
userScopes: []string{"custom"},
expectedScopes: []string{"openid", "custom"},
description: "Minimal case with single scopes",
},
{
name: "Identical scopes in same order",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"openid", "profile", "email"},
expectedScopes: []string{"openid", "profile", "email"},
description: "When user scopes exactly match defaults, no duplication",
},
{
name: "Identical scopes in different order",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"email", "profile", "openid"},
expectedScopes: []string{"openid", "profile", "email"},
description: "Order of defaults is preserved when user scopes are reordered duplicates",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Test the mergeScopes function directly
result := mergeScopes(tc.defaultScopes, tc.userScopes)
if !equalSlices(result, tc.expectedScopes) {
t.Errorf("Expected %v, got %v\nDescription: %s", tc.expectedScopes, result, tc.description)
}
})
}
}
// generateLargeUserScopes creates a large list of user scopes for performance testing
func generateLargeUserScopes() []string {
scopes := make([]string, 100)
for i := range 100 {
scopes[i] = fmt.Sprintf("scope_%d", i)
}
// Add some duplicates to test deduplication performance
scopes = append(scopes, "scope_1", "scope_5", "scope_10", "openid") // Include a default duplicate
return scopes
}
// TestScopeMergingPerformance tests performance with large scope lists
func TestScopeMergingPerformance(t *testing.T) {
// Create large scope lists
defaultScopes := []string{"openid", "profile", "email"}
// Create 1000 user scopes with some duplicates
userScopes := make([]string, 1000)
for i := range 1000 {
if i%10 == 0 {
// Add some duplicates of defaults
userScopes[i] = defaultScopes[i%len(defaultScopes)]
} else if i%7 == 0 {
// Add some internal duplicates
userScopes[i] = fmt.Sprintf("scope_%d", i%50)
} else {
userScopes[i] = fmt.Sprintf("scope_%d", i)
}
}
// Measure performance
start := time.Now()
result := mergeScopes(defaultScopes, userScopes)
duration := time.Since(start)
// Verify result correctness
if len(result) < len(defaultScopes) {
t.Errorf("Result should contain at least the default scopes")
}
// Verify no duplicates exist
seen := make(map[string]bool)
for _, scope := range result {
if seen[scope] {
t.Errorf("Duplicate scope found in result: %s", scope)
}
seen[scope] = true
}
// Performance assertion (should be very fast)
if duration > time.Millisecond*10 {
t.Logf("Performance note: mergeScopes took %v for 1000+ scopes (still acceptable)", duration)
}
t.Logf("Performance: processed %d user scopes in %v, result has %d unique scopes",
len(userScopes), duration, len(result))
}
// TestScopeMergingMemoryEfficiency tests memory efficiency of the mergeScopes function
func TestScopeMergingMemoryEfficiency(t *testing.T) {
defaultScopes := []string{"openid", "profile", "email"}
userScopes := []string{"custom1", "custom2"}
// Test that the function doesn't modify input slices
originalDefaults := make([]string, len(defaultScopes))
copy(originalDefaults, defaultScopes)
originalUser := make([]string, len(userScopes))
copy(originalUser, userScopes)
result := mergeScopes(defaultScopes, userScopes)
// Verify input slices are unchanged
for i, scope := range defaultScopes {
if scope != originalDefaults[i] {
t.Errorf("Default scopes were modified: expected %s, got %s", originalDefaults[i], scope)
}
}
for i, scope := range userScopes {
if scope != originalUser[i] {
t.Errorf("User scopes were modified: expected %s, got %s", originalUser[i], scope)
}
}
// Verify result is independent
result[0] = "modified"
if defaultScopes[0] == "modified" {
t.Error("Modifying result affected input defaults")
}
expectedLength := len(defaultScopes) + len(userScopes)
if len(result) != expectedLength {
t.Errorf("Expected result length %d, got %d", expectedLength, len(result))
}
}
// TestNewWithScopeAppending tests that the New function properly merges scopes
func TestNewWithScopeAppending(t *testing.T) {
// Create mock provider metadata server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
RevokeURL: "https://test-issuer.com/revoke",
EndSessionURL: "https://test-issuer.com/end-session",
}
json.NewEncoder(w).Encode(metadata)
}))
defer mockServer.Close()
tests := []struct {
name string
configScopes []string
expectedScopes []string
}{
{
name: "Default scopes only",
configScopes: []string{},
expectedScopes: []string{"openid", "profile", "email"},
},
{
name: "Custom scopes appended",
configScopes: []string{"custom_scope", "another_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
},
{
name: "Overlapping scopes deduplicated",
configScopes: []string{"openid", "custom_scope"},
expectedScopes: []string{"openid", "profile", "email", "custom_scope"},
},
{
name: "OAuth scopes",
configScopes: []string{"read:users", "write:users", "admin"},
expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create config with test scopes
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
Scopes: tc.configScopes,
}
// Create middleware instance
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test")
if err != nil {
t.Fatalf("Failed to create middleware: %v", err)
}
// Wait for initialization
if m, ok := middleware.(*TraefikOidc); ok {
select {
case <-m.initComplete:
case <-time.After(5 * time.Second):
t.Fatalf("Middleware failed to initialize")
}
// Check that scopes were properly merged
if !equalSlices(m.scopes, tc.expectedScopes) {
t.Errorf("Expected scopes %v, got %v", tc.expectedScopes, m.scopes)
}
} else {
t.Fatalf("Middleware is not of type *TraefikOidc")
}
})
}
}
// TestBuildAuthURLWithMergedScopes tests that the auth URL includes the properly merged scopes
func TestBuildAuthURLWithMergedScopes(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
expectedScopes string
scopes []string
}{
{
name: "Default scopes only",
scopes: []string{"openid", "profile", "email"},
expectedScopes: "openid profile email offline_access",
},
{
name: "Custom scopes appended",
scopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
expectedScopes: "openid profile email custom_scope another_scope offline_access",
},
{
name: "OAuth scopes",
scopes: []string{"openid", "profile", "email", "read:users", "write:users"},
expectedScopes: "openid profile email read:users write:users offline_access",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure the test instance with specific scopes
tOidc := ts.tOidc
tOidc.scopes = tc.scopes
tOidc.authURL = "https://auth.example.com/oauth/authorize"
tOidc.issuerURL = "https://auth.example.com"
// Build auth URL
result := tOidc.buildAuthURL("https://app.example.com/callback", "test-state", "test-nonce", "")
// Parse the resulting URL to verify scopes
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse resulting URL: %v", err)
}
query := parsedURL.Query()
actualScopes := query.Get("scope")
if actualScopes != tc.expectedScopes {
t.Errorf("Expected scopes %q, got %q", tc.expectedScopes, actualScopes)
}
})
}
}
+25 -11
View File
@@ -8,21 +8,31 @@ import (
)
type MetadataCache struct {
metadata *ProviderMetadata
expiresAt time.Time
mutex sync.RWMutex
metadata *ProviderMetadata
cleanupTask *BackgroundTask
logger *Logger
autoCleanupInterval time.Duration
stopCleanup chan struct{}
mutex sync.RWMutex
}
// NewMetadataCache creates a new MetadataCache instance.
// It initializes the cache structure and starts the background cleanup goroutine.
// It initializes the cache structure and starts the background cleanup task.
func NewMetadataCache() *MetadataCache {
return NewMetadataCacheWithLogger(nil)
}
// NewMetadataCacheWithLogger creates a new MetadataCache with a specified logger.
func NewMetadataCacheWithLogger(logger *Logger) *MetadataCache {
if logger == nil {
logger = newNoOpLogger()
}
c := &MetadataCache{
autoCleanupInterval: 5 * time.Minute,
stopCleanup: make(chan struct{}),
logger: logger,
}
go c.startAutoCleanup()
c.startAutoCleanup()
return c
}
@@ -92,20 +102,24 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
c.metadata = metadata
// Set a fixed cache lifetime (e.g., 1 hour)
// TODO: Consider making this configurable or respecting HTTP cache headers
// Consider making this configurable or respecting HTTP cache headers
c.expiresAt = time.Now().Add(1 * time.Hour)
// End of GetMetadata
return metadata, nil
}
// startAutoCleanup starts the background goroutine that periodically calls Cleanup
// startAutoCleanup starts the background task that periodically calls Cleanup
// to remove expired metadata from the cache.
func (c *MetadataCache) startAutoCleanup() {
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
c.cleanupTask = NewBackgroundTask("metadata-cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger)
c.cleanupTask.Start()
}
// Close stops the automatic cleanup goroutine associated with this metadata cache.
// Close stops the automatic cleanup task associated with this metadata cache.
func (c *MetadataCache) Close() {
close(c.stopCleanup)
if c.cleanupTask != nil {
c.cleanupTask.Stop()
c.cleanupTask = nil
}
}
+4 -4
View File
@@ -41,8 +41,8 @@ func TestGetMetadata_Cached(t *testing.T) {
mc := &MetadataCache{
metadata: dummyData,
expiresAt: time.Now().Add(1 * time.Hour),
stopCleanup: make(chan struct{}),
autoCleanupInterval: 5 * time.Minute,
logger: newNoOpLogger(),
}
// Use NewLogger to create a logger that writes errors only.
logger := NewLogger("error")
@@ -58,10 +58,10 @@ func TestGetMetadata_Cached(t *testing.T) {
func TestMetadataCacheAutoCleanup(t *testing.T) {
mc := &MetadataCache{
autoCleanupInterval: 50 * time.Millisecond,
stopCleanup: make(chan struct{}),
logger: newNoOpLogger(),
}
// Start auto cleanup.
go mc.startAutoCleanup()
mc.startAutoCleanup()
mc.mutex.Lock()
mc.metadata = &ProviderMetadata{}
mc.expiresAt = time.Now().Add(-50 * time.Millisecond)
@@ -93,7 +93,7 @@ func TestGetMetadata_FetchError(t *testing.T) {
// Case 1: Cache is empty.
mc := &MetadataCache{
stopCleanup: make(chan struct{}),
logger: newNoOpLogger(),
}
logger := NewLogger("error")
metadata, err := mc.GetMetadata("http://example.com", errorClient, logger)
-709
View File
@@ -1,709 +0,0 @@
package traefikoidc
import (
"runtime"
"sync"
"sync/atomic"
"time"
)
// PerformanceMetrics tracks various performance-related metrics
type PerformanceMetrics struct {
// Cache metrics
cacheHits int64
cacheMisses int64
cacheEvictions int64
cacheSize int64
// Token operation metrics
tokenVerifications int64
tokenValidations int64
tokenRefreshes int64
// Success/failure tracking
successfulVerifications int64
successfulValidations int64
successfulRefreshes int64
failedVerifications int64
failedValidations int64
failedRefreshes int64
// Timing metrics
avgVerificationTime time.Duration
avgValidationTime time.Duration
avgRefreshTime time.Duration
// Resource metrics
memoryUsage int64
goroutineCount int64
memoryPressure int64 // Memory pressure level (0-100)
gcPauseTime int64 // Last GC pause time in nanoseconds
heapSize int64 // Current heap size
heapInUse int64 // Heap memory in use
// Error metrics (kept for backward compatibility)
verificationErrors int64
validationErrors int64
refreshErrors int64
// Rate limiting metrics
rateLimitedRequests int64
// Session metrics
activeSessions int64
sessionCreations int64
sessionDeletions int64
// Timing tracking
timingMutex sync.RWMutex
verificationTimes []time.Duration
validationTimes []time.Duration
refreshTimes []time.Duration
// Start time for uptime calculation
startTime time.Time
logger *Logger
}
// NewPerformanceMetrics creates a new performance metrics tracker
func NewPerformanceMetrics(logger *Logger) *PerformanceMetrics {
pm := &PerformanceMetrics{
startTime: time.Now(),
verificationTimes: make([]time.Duration, 0, 1000), // Keep last 1000 measurements
validationTimes: make([]time.Duration, 0, 1000),
refreshTimes: make([]time.Duration, 0, 1000),
logger: logger,
}
// Start background metrics collection
go pm.startMetricsCollection()
return pm
}
// RecordCacheHit records a cache hit
func (pm *PerformanceMetrics) RecordCacheHit() {
atomic.AddInt64(&pm.cacheHits, 1)
}
// RecordCacheMiss records a cache miss
func (pm *PerformanceMetrics) RecordCacheMiss() {
atomic.AddInt64(&pm.cacheMisses, 1)
}
// RecordCacheEviction records a cache eviction
func (pm *PerformanceMetrics) RecordCacheEviction() {
atomic.AddInt64(&pm.cacheEvictions, 1)
}
// UpdateCacheSize updates the current cache size
func (pm *PerformanceMetrics) UpdateCacheSize(size int64) {
atomic.StoreInt64(&pm.cacheSize, size)
}
// RecordTokenVerification records a token verification operation
func (pm *PerformanceMetrics) RecordTokenVerification(duration time.Duration, success bool) {
atomic.AddInt64(&pm.tokenVerifications, 1)
if success {
atomic.AddInt64(&pm.successfulVerifications, 1)
pm.addVerificationTime(duration)
} else {
atomic.AddInt64(&pm.failedVerifications, 1)
atomic.AddInt64(&pm.verificationErrors, 1)
}
}
// RecordTokenValidation records a token validation operation
func (pm *PerformanceMetrics) RecordTokenValidation(duration time.Duration, success bool) {
atomic.AddInt64(&pm.tokenValidations, 1)
if success {
atomic.AddInt64(&pm.successfulValidations, 1)
pm.addValidationTime(duration)
} else {
atomic.AddInt64(&pm.failedValidations, 1)
atomic.AddInt64(&pm.validationErrors, 1)
}
}
// RecordTokenRefresh records a token refresh operation
func (pm *PerformanceMetrics) RecordTokenRefresh(duration time.Duration, success bool) {
atomic.AddInt64(&pm.tokenRefreshes, 1)
if success {
atomic.AddInt64(&pm.successfulRefreshes, 1)
pm.addRefreshTime(duration)
} else {
atomic.AddInt64(&pm.failedRefreshes, 1)
atomic.AddInt64(&pm.refreshErrors, 1)
}
}
// RecordRateLimitedRequest records a rate-limited request
func (pm *PerformanceMetrics) RecordRateLimitedRequest() {
atomic.AddInt64(&pm.rateLimitedRequests, 1)
}
// RecordSessionCreation records a session creation
func (pm *PerformanceMetrics) RecordSessionCreation() {
atomic.AddInt64(&pm.sessionCreations, 1)
atomic.AddInt64(&pm.activeSessions, 1)
}
// RecordSessionDeletion records a session deletion
func (pm *PerformanceMetrics) RecordSessionDeletion() {
atomic.AddInt64(&pm.sessionDeletions, 1)
atomic.AddInt64(&pm.activeSessions, -1)
}
// addVerificationTime adds a verification time measurement
func (pm *PerformanceMetrics) addVerificationTime(duration time.Duration) {
pm.timingMutex.Lock()
defer pm.timingMutex.Unlock()
pm.verificationTimes = append(pm.verificationTimes, duration)
if len(pm.verificationTimes) > 1000 {
pm.verificationTimes = pm.verificationTimes[1:]
}
pm.updateAverageVerificationTime()
}
// addValidationTime adds a validation time measurement
func (pm *PerformanceMetrics) addValidationTime(duration time.Duration) {
pm.timingMutex.Lock()
defer pm.timingMutex.Unlock()
pm.validationTimes = append(pm.validationTimes, duration)
if len(pm.validationTimes) > 1000 {
pm.validationTimes = pm.validationTimes[1:]
}
pm.updateAverageValidationTime()
}
// addRefreshTime adds a refresh time measurement
func (pm *PerformanceMetrics) addRefreshTime(duration time.Duration) {
pm.timingMutex.Lock()
defer pm.timingMutex.Unlock()
pm.refreshTimes = append(pm.refreshTimes, duration)
if len(pm.refreshTimes) > 1000 {
pm.refreshTimes = pm.refreshTimes[1:]
}
pm.updateAverageRefreshTime()
}
// updateAverageVerificationTime calculates the average verification time
func (pm *PerformanceMetrics) updateAverageVerificationTime() {
if len(pm.verificationTimes) == 0 {
pm.avgVerificationTime = 0
return
}
var total time.Duration
for _, t := range pm.verificationTimes {
total += t
}
pm.avgVerificationTime = total / time.Duration(len(pm.verificationTimes))
}
// updateAverageValidationTime calculates the average validation time
func (pm *PerformanceMetrics) updateAverageValidationTime() {
if len(pm.validationTimes) == 0 {
pm.avgValidationTime = 0
return
}
var total time.Duration
for _, t := range pm.validationTimes {
total += t
}
pm.avgValidationTime = total / time.Duration(len(pm.validationTimes))
}
// updateAverageRefreshTime calculates the average refresh time
func (pm *PerformanceMetrics) updateAverageRefreshTime() {
if len(pm.refreshTimes) == 0 {
pm.avgRefreshTime = 0
return
}
var total time.Duration
for _, t := range pm.refreshTimes {
total += t
}
pm.avgRefreshTime = total / time.Duration(len(pm.refreshTimes))
}
// startMetricsCollection starts background collection of system metrics
func (pm *PerformanceMetrics) startMetricsCollection() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
pm.collectSystemMetrics()
}
}
// collectSystemMetrics collects system-level metrics
func (pm *PerformanceMetrics) collectSystemMetrics() {
// Memory statistics
var m runtime.MemStats
runtime.ReadMemStats(&m)
atomic.StoreInt64(&pm.memoryUsage, int64(m.Alloc))
atomic.StoreInt64(&pm.heapSize, int64(m.HeapSys))
atomic.StoreInt64(&pm.heapInUse, int64(m.HeapInuse))
atomic.StoreInt64(&pm.gcPauseTime, int64(m.PauseNs[(m.NumGC+255)%256]))
// Calculate memory pressure (0-100 scale)
// Based on heap utilization and GC frequency
heapUtilization := float64(m.HeapInuse) / float64(m.HeapSys)
gcFrequency := float64(m.NumGC) / time.Since(pm.startTime).Minutes()
// Memory pressure calculation
pressure := int64(heapUtilization * 50) // 0-50 based on heap utilization
if gcFrequency > 10 { // High GC frequency indicates pressure
pressure += int64((gcFrequency - 10) * 2) // Add up to 50 more
}
if pressure > 100 {
pressure = 100
}
atomic.StoreInt64(&pm.memoryPressure, pressure)
// Goroutine count
atomic.StoreInt64(&pm.goroutineCount, int64(runtime.NumGoroutine()))
// Log memory pressure warnings
if pressure > 80 {
pm.logger.Errorf("High memory pressure detected: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)",
pressure, heapUtilization*100, gcFrequency)
} else if pressure > 60 {
pm.logger.Infof("Moderate memory pressure: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)",
pressure, heapUtilization*100, gcFrequency)
}
}
// GetMetrics returns all current performance metrics
func (pm *PerformanceMetrics) GetMetrics() map[string]interface{} {
pm.timingMutex.RLock()
defer pm.timingMutex.RUnlock()
// Calculate cache hit ratio
hits := atomic.LoadInt64(&pm.cacheHits)
misses := atomic.LoadInt64(&pm.cacheMisses)
var hitRatio float64
if hits+misses > 0 {
hitRatio = float64(hits) / float64(hits+misses)
}
// Calculate error rates
verifications := atomic.LoadInt64(&pm.tokenVerifications)
validations := atomic.LoadInt64(&pm.tokenValidations)
refreshes := atomic.LoadInt64(&pm.tokenRefreshes)
var verificationErrorRate, validationErrorRate, refreshErrorRate float64
if verifications > 0 {
verificationErrorRate = float64(atomic.LoadInt64(&pm.verificationErrors)) / float64(verifications)
}
if validations > 0 {
validationErrorRate = float64(atomic.LoadInt64(&pm.validationErrors)) / float64(validations)
}
if refreshes > 0 {
refreshErrorRate = float64(atomic.LoadInt64(&pm.refreshErrors)) / float64(refreshes)
}
return map[string]interface{}{
// Cache metrics
"cache_hits": hits,
"cache_misses": misses,
"cache_hit_ratio": hitRatio,
"cache_evictions": atomic.LoadInt64(&pm.cacheEvictions),
"cache_size": atomic.LoadInt64(&pm.cacheSize),
// Token operation metrics
"token_verifications": verifications,
"token_validations": validations,
"token_refreshes": refreshes,
"verification_error_rate": verificationErrorRate,
"validation_error_rate": validationErrorRate,
"refresh_error_rate": refreshErrorRate,
// Success/failure metrics
"successful_verifications": atomic.LoadInt64(&pm.successfulVerifications),
"successful_validations": atomic.LoadInt64(&pm.successfulValidations),
"successful_refreshes": atomic.LoadInt64(&pm.successfulRefreshes),
"failed_verifications": atomic.LoadInt64(&pm.failedVerifications),
"failed_validations": atomic.LoadInt64(&pm.failedValidations),
"failed_refreshes": atomic.LoadInt64(&pm.failedRefreshes),
// Timing metrics
"avg_verification_time_ms": pm.avgVerificationTime.Milliseconds(),
"avg_validation_time_ms": pm.avgValidationTime.Milliseconds(),
"avg_refresh_time_ms": pm.avgRefreshTime.Milliseconds(),
// Resource metrics
"memory_usage_bytes": atomic.LoadInt64(&pm.memoryUsage),
"memory_pressure": atomic.LoadInt64(&pm.memoryPressure),
"heap_size_bytes": atomic.LoadInt64(&pm.heapSize),
"heap_inuse_bytes": atomic.LoadInt64(&pm.heapInUse),
"gc_pause_time_ns": atomic.LoadInt64(&pm.gcPauseTime),
"goroutine_count": atomic.LoadInt64(&pm.goroutineCount),
// Rate limiting metrics
"rate_limited_requests": atomic.LoadInt64(&pm.rateLimitedRequests),
// Session metrics
"active_sessions": atomic.LoadInt64(&pm.activeSessions),
"sessions_created": atomic.LoadInt64(&pm.sessionCreations),
"sessions_deleted": atomic.LoadInt64(&pm.sessionDeletions),
"session_creations": atomic.LoadInt64(&pm.sessionCreations),
"session_deletions": atomic.LoadInt64(&pm.sessionDeletions),
// Uptime
"uptime_seconds": time.Since(pm.startTime).Seconds(),
}
}
// GetDetailedTimingMetrics returns detailed timing statistics
func (pm *PerformanceMetrics) GetDetailedTimingMetrics() map[string]interface{} {
pm.timingMutex.RLock()
defer pm.timingMutex.RUnlock()
return map[string]interface{}{
"verification_stats": pm.calculateTimingStats(pm.verificationTimes),
"verification_timing": pm.calculateTimingStats(pm.verificationTimes),
"validation_stats": pm.calculateTimingStats(pm.validationTimes),
"validation_timing": pm.calculateTimingStats(pm.validationTimes),
"refresh_stats": pm.calculateTimingStats(pm.refreshTimes),
"refresh_timing": pm.calculateTimingStats(pm.refreshTimes),
}
}
// calculateTimingStats calculates statistical metrics for timing data
func (pm *PerformanceMetrics) calculateTimingStats(times []time.Duration) map[string]interface{} {
if len(times) == 0 {
return map[string]interface{}{
"count": 0,
"min_ms": float64(0),
"max_ms": float64(0),
"avg_ms": float64(0),
"average_ms": float64(0),
"median_ms": float64(0),
"p95_ms": float64(0),
"p99_ms": float64(0),
}
}
// Sort times for percentile calculations
sortedTimes := make([]time.Duration, len(times))
copy(sortedTimes, times)
// Simple bubble sort for small arrays
for i := 0; i < len(sortedTimes); i++ {
for j := i + 1; j < len(sortedTimes); j++ {
if sortedTimes[i] > sortedTimes[j] {
sortedTimes[i], sortedTimes[j] = sortedTimes[j], sortedTimes[i]
}
}
}
// Calculate statistics
min := sortedTimes[0]
max := sortedTimes[len(sortedTimes)-1]
var total time.Duration
for _, t := range sortedTimes {
total += t
}
avg := total / time.Duration(len(sortedTimes))
median := sortedTimes[len(sortedTimes)/2]
p95 := sortedTimes[int(float64(len(sortedTimes))*0.95)]
p99 := sortedTimes[int(float64(len(sortedTimes))*0.99)]
return map[string]interface{}{
"count": len(sortedTimes),
"min_ms": float64(min.Nanoseconds()) / 1e6,
"max_ms": float64(max.Nanoseconds()) / 1e6,
"avg_ms": float64(avg.Nanoseconds()) / 1e6,
"average_ms": float64(avg.Nanoseconds()) / 1e6,
"median_ms": float64(median.Nanoseconds()) / 1e6,
"p95_ms": float64(p95.Nanoseconds()) / 1e6,
"p99_ms": float64(p99.Nanoseconds()) / 1e6,
}
}
// ResourceMonitor tracks resource usage and limits
type ResourceMonitor struct {
// Memory limits
maxMemoryBytes int64
// Cache limits
maxCacheSize int64
// Session limits
maxSessions int64
// Cache size tracking
cacheSizes map[string]int64
cacheMutex sync.RWMutex
// Monitoring state
alertThresholds map[string]float64
alerts []ResourceAlert
alertsMutex sync.RWMutex
// Performance metrics reference
perfMetrics *PerformanceMetrics
logger *Logger
}
// ResourceAlert represents a resource usage alert
type ResourceAlert struct {
Type string `json:"type"`
Message string `json:"message"`
Threshold float64 `json:"threshold"`
CurrentValue float64 `json:"current_value"`
Timestamp time.Time `json:"timestamp"`
Severity string `json:"severity"`
}
// NewResourceMonitor creates a new resource monitor
func NewResourceMonitor(perfMetrics *PerformanceMetrics, logger *Logger) *ResourceMonitor {
rm := &ResourceMonitor{
maxMemoryBytes: 100 * 1024 * 1024, // 100MB default
maxCacheSize: 10000, // 10k items default
maxSessions: 1000, // 1k sessions default
cacheSizes: make(map[string]int64),
alertThresholds: map[string]float64{
"memory_usage": 0.8, // 80%
"memory_pressure": 0.7, // 70%
"cache_usage": 0.9, // 90%
"session_usage": 0.85, // 85%
"error_rate": 0.1, // 10%
},
alerts: make([]ResourceAlert, 0),
perfMetrics: perfMetrics,
logger: logger,
}
// Start monitoring routine
go rm.startMonitoring()
return rm
}
// SetMemoryLimit sets the maximum memory usage limit
func (rm *ResourceMonitor) SetMemoryLimit(bytes int64) {
rm.maxMemoryBytes = bytes
}
// SetCacheLimit sets the maximum cache size limit
func (rm *ResourceMonitor) SetCacheLimit(size int64) {
rm.maxCacheSize = size
}
// SetSessionLimit sets the maximum session count limit
func (rm *ResourceMonitor) SetSessionLimit(count int64) {
rm.maxSessions = count
}
// UpdateCacheSize updates the size of a specific cache
func (rm *ResourceMonitor) UpdateCacheSize(cacheName string, size int64) {
rm.cacheMutex.Lock()
defer rm.cacheMutex.Unlock()
rm.cacheSizes[cacheName] = size
}
// GetCacheSizes returns current cache sizes
func (rm *ResourceMonitor) GetCacheSizes() map[string]int64 {
rm.cacheMutex.RLock()
defer rm.cacheMutex.RUnlock()
sizes := make(map[string]int64)
for name, size := range rm.cacheSizes {
sizes[name] = size
}
return sizes
}
// startMonitoring starts the background monitoring routine
func (rm *ResourceMonitor) startMonitoring() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
rm.checkResourceUsage()
}
}
// checkResourceUsage checks current resource usage against limits
func (rm *ResourceMonitor) checkResourceUsage() {
metrics := rm.perfMetrics.GetMetrics()
// Check memory usage
if memUsage, ok := metrics["memory_usage_bytes"].(int64); ok {
memUsageRatio := float64(memUsage) / float64(rm.maxMemoryBytes)
if memUsageRatio > rm.alertThresholds["memory_usage"] {
rm.addAlert(ResourceAlert{
Type: "memory_usage",
Message: "Memory usage exceeds threshold",
Threshold: rm.alertThresholds["memory_usage"],
CurrentValue: memUsageRatio,
Timestamp: time.Now(),
Severity: rm.getSeverity(memUsageRatio, rm.alertThresholds["memory_usage"]),
})
}
}
// Check memory pressure
if memPressure, ok := metrics["memory_pressure"].(int64); ok {
pressureRatio := float64(memPressure) / 100.0 // Convert to 0-1 scale
if pressureRatio > rm.alertThresholds["memory_pressure"] {
rm.addAlert(ResourceAlert{
Type: "memory_pressure",
Message: "Memory pressure exceeds threshold",
Threshold: rm.alertThresholds["memory_pressure"],
CurrentValue: pressureRatio,
Timestamp: time.Now(),
Severity: rm.getSeverity(pressureRatio, rm.alertThresholds["memory_pressure"]),
})
}
}
// Check cache usage
if cacheSize, ok := metrics["cache_size"].(int64); ok {
cacheUsageRatio := float64(cacheSize) / float64(rm.maxCacheSize)
if cacheUsageRatio > rm.alertThresholds["cache_usage"] {
rm.addAlert(ResourceAlert{
Type: "cache_usage",
Message: "Cache usage exceeds threshold",
Threshold: rm.alertThresholds["cache_usage"],
CurrentValue: cacheUsageRatio,
Timestamp: time.Now(),
Severity: rm.getSeverity(cacheUsageRatio, rm.alertThresholds["cache_usage"]),
})
}
}
// Check session usage
if activeSessions, ok := metrics["active_sessions"].(int64); ok {
sessionUsageRatio := float64(activeSessions) / float64(rm.maxSessions)
if sessionUsageRatio > rm.alertThresholds["session_usage"] {
rm.addAlert(ResourceAlert{
Type: "session_usage",
Message: "Active session count exceeds threshold",
Threshold: rm.alertThresholds["session_usage"],
CurrentValue: sessionUsageRatio,
Timestamp: time.Now(),
Severity: rm.getSeverity(sessionUsageRatio, rm.alertThresholds["session_usage"]),
})
}
}
// Check error rates
if errorRate, ok := metrics["verification_error_rate"].(float64); ok {
if errorRate > rm.alertThresholds["error_rate"] {
rm.addAlert(ResourceAlert{
Type: "verification_error_rate",
Message: "Token verification error rate exceeds threshold",
Threshold: rm.alertThresholds["error_rate"],
CurrentValue: errorRate,
Timestamp: time.Now(),
Severity: rm.getSeverity(errorRate, rm.alertThresholds["error_rate"]),
})
}
}
}
// getSeverity determines the severity level based on how much the threshold is exceeded
func (rm *ResourceMonitor) getSeverity(currentValue, threshold float64) string {
ratio := currentValue / threshold
if ratio >= 1.5 {
return "critical"
} else if ratio >= 1.2 {
return "high"
} else if ratio >= 1.0 {
return "medium"
}
return "low"
}
// addAlert adds a new resource alert
func (rm *ResourceMonitor) addAlert(alert ResourceAlert) {
rm.alertsMutex.Lock()
defer rm.alertsMutex.Unlock()
// Add alert
rm.alerts = append(rm.alerts, alert)
// Keep only last 100 alerts
if len(rm.alerts) > 100 {
rm.alerts = rm.alerts[1:]
}
// Log the alert
rm.logger.Errorf("Resource Alert [%s/%s]: %s (%.2f%% > %.2f%%)",
alert.Type, alert.Severity, alert.Message,
alert.CurrentValue*100, alert.Threshold*100)
}
// GetAlerts returns current resource alerts
func (rm *ResourceMonitor) GetAlerts() []ResourceAlert {
rm.alertsMutex.RLock()
defer rm.alertsMutex.RUnlock()
alerts := make([]ResourceAlert, len(rm.alerts))
copy(alerts, rm.alerts)
return alerts
}
// GetResourceStatus returns current resource status
func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
metrics := rm.perfMetrics.GetMetrics()
cacheSizes := rm.GetCacheSizes()
status := map[string]interface{}{
"limits": map[string]interface{}{
"max_memory_bytes": rm.maxMemoryBytes,
"max_cache_size": rm.maxCacheSize,
"max_sessions": rm.maxSessions,
},
"thresholds": rm.alertThresholds,
"current": metrics,
"cache_sizes": cacheSizes,
// Add expected keys for tests
"memory_limit": uint64(rm.maxMemoryBytes),
"cache_limit": int(rm.maxCacheSize),
"session_limit": int(rm.maxSessions),
}
// Calculate usage ratios
if memUsage, ok := metrics["memory_usage_bytes"].(int64); ok {
status["memory_usage_ratio"] = float64(memUsage) / float64(rm.maxMemoryBytes)
}
if memPressure, ok := metrics["memory_pressure"].(int64); ok {
status["memory_pressure_ratio"] = float64(memPressure) / 100.0
}
if cacheSize, ok := metrics["cache_size"].(int64); ok {
status["cache_usage_ratio"] = float64(cacheSize) / float64(rm.maxCacheSize)
}
if activeSessions, ok := metrics["active_sessions"].(int64); ok {
status["session_usage_ratio"] = float64(activeSessions) / float64(rm.maxSessions)
}
// Calculate total cache size across all caches
var totalCacheSize int64
for _, size := range cacheSizes {
totalCacheSize += size
}
status["total_cache_size"] = totalCacheSize
return status
}
-324
View File
@@ -1,324 +0,0 @@
package traefikoidc
import (
"testing"
"time"
)
func TestPerformanceMetrics(t *testing.T) {
logger := NewLogger("debug")
metrics := NewPerformanceMetrics(logger)
t.Run("Record cache operations", func(t *testing.T) {
metrics.RecordCacheHit()
metrics.RecordCacheMiss()
metrics.RecordCacheEviction()
metrics.UpdateCacheSize(100)
result := metrics.GetMetrics()
if result["cache_hits"].(int64) != 1 {
t.Errorf("Expected 1 cache hit, got %v", result["cache_hits"])
}
if result["cache_misses"].(int64) != 1 {
t.Errorf("Expected 1 cache miss, got %v", result["cache_misses"])
}
if result["cache_evictions"].(int64) != 1 {
t.Errorf("Expected 1 cache eviction, got %v", result["cache_evictions"])
}
if result["cache_size"].(int64) != 100 {
t.Errorf("Expected cache size 100, got %v", result["cache_size"])
}
})
t.Run("Record token operations", func(t *testing.T) {
start := time.Now()
time.Sleep(10 * time.Millisecond)
metrics.RecordTokenVerification(time.Since(start), true)
start = time.Now()
time.Sleep(5 * time.Millisecond)
metrics.RecordTokenValidation(time.Since(start), false)
start = time.Now()
time.Sleep(15 * time.Millisecond)
metrics.RecordTokenRefresh(time.Since(start), true)
result := metrics.GetMetrics()
if result["token_verifications"].(int64) != 1 {
t.Errorf("Expected 1 token verification, got %v", result["token_verifications"])
}
if result["token_validations"].(int64) != 1 {
t.Errorf("Expected 1 token validation, got %v", result["token_validations"])
}
if result["token_refreshes"].(int64) != 1 {
t.Errorf("Expected 1 token refresh, got %v", result["token_refreshes"])
}
if result["successful_verifications"].(int64) != 1 {
t.Errorf("Expected 1 successful verification, got %v", result["successful_verifications"])
}
if result["failed_validations"].(int64) != 1 {
t.Errorf("Expected 1 failed validation, got %v", result["failed_validations"])
}
})
t.Run("Record rate limiting and sessions", func(t *testing.T) {
metrics.RecordRateLimitedRequest()
metrics.RecordSessionCreation()
metrics.RecordSessionDeletion()
result := metrics.GetMetrics()
if result["rate_limited_requests"].(int64) != 1 {
t.Errorf("Expected 1 rate limited request, got %v", result["rate_limited_requests"])
}
if result["sessions_created"].(int64) != 1 {
t.Errorf("Expected 1 session created, got %v", result["sessions_created"])
}
if result["sessions_deleted"].(int64) != 1 {
t.Errorf("Expected 1 session deleted, got %v", result["sessions_deleted"])
}
})
t.Run("Get detailed timing metrics", func(t *testing.T) {
// Add more timing data
for i := 0; i < 5; i++ {
metrics.RecordTokenVerification(time.Duration(i+1)*time.Millisecond, true)
}
detailed := metrics.GetDetailedTimingMetrics()
if detailed["verification_stats"] == nil {
t.Error("Expected verification stats to be present")
}
verificationStats := detailed["verification_stats"].(map[string]interface{})
if verificationStats["count"].(int) != 6 { // 1 from previous test + 5 new
t.Errorf("Expected 6 verifications, got %v", verificationStats["count"])
}
})
}
func TestResourceMonitor(t *testing.T) {
logger := NewLogger("debug")
metrics := NewPerformanceMetrics(logger)
monitor := NewResourceMonitor(metrics, logger)
t.Run("Set limits", func(t *testing.T) {
monitor.SetMemoryLimit(100 * 1024 * 1024) // 100MB
monitor.SetCacheLimit(1000)
monitor.SetSessionLimit(500)
// Should not panic
})
t.Run("Get resource status", func(t *testing.T) {
status := monitor.GetResourceStatus()
if status["memory_limit"] == nil {
t.Error("Expected memory limit to be set")
}
if status["cache_limit"] == nil {
t.Error("Expected cache limit to be set")
}
if status["session_limit"] == nil {
t.Error("Expected session limit to be set")
}
})
t.Run("Get alerts", func(t *testing.T) {
alerts := monitor.GetAlerts()
// Should return empty slice initially
if alerts == nil {
t.Error("Expected alerts slice to be initialized")
}
})
}
func TestPerformanceMetricsCalculations(t *testing.T) {
logger := NewLogger("debug")
metrics := NewPerformanceMetrics(logger)
t.Run("Average calculation", func(t *testing.T) {
// Record multiple operations with known durations
durations := []time.Duration{
10 * time.Millisecond,
20 * time.Millisecond,
30 * time.Millisecond,
}
for _, d := range durations {
metrics.RecordTokenVerification(d, true)
}
detailed := metrics.GetDetailedTimingMetrics()
verificationStats := detailed["verification_stats"].(map[string]interface{})
// Average should be 20ms
avgMs := verificationStats["average_ms"].(float64)
if avgMs < 19 || avgMs > 21 { // Allow small variance
t.Errorf("Expected average around 20ms, got %f", avgMs)
}
})
t.Run("Min/Max calculation", func(t *testing.T) {
logger := NewLogger("debug")
metrics := NewPerformanceMetrics(logger) // Fresh instance
durations := []time.Duration{
5 * time.Millisecond,
50 * time.Millisecond,
25 * time.Millisecond,
}
for _, d := range durations {
metrics.RecordTokenVerification(d, true)
}
detailed := metrics.GetDetailedTimingMetrics()
verificationStats := detailed["verification_stats"].(map[string]interface{})
minMs := verificationStats["min_ms"].(float64)
maxMs := verificationStats["max_ms"].(float64)
if minMs < 4 || minMs > 6 {
t.Errorf("Expected min around 5ms, got %f", minMs)
}
if maxMs < 49 || maxMs > 51 {
t.Errorf("Expected max around 50ms, got %f", maxMs)
}
})
}
func TestPerformanceMetricsReset(t *testing.T) {
logger := NewLogger("debug")
metrics := NewPerformanceMetrics(logger)
// Record some data
metrics.RecordCacheHit()
metrics.RecordTokenVerification(10*time.Millisecond, true)
// Verify data is there
result := metrics.GetMetrics()
if result["cache_hits"].(int64) != 1 {
t.Error("Expected cache hit to be recorded")
}
// Note: The current implementation doesn't have a reset method,
// but we can test that metrics accumulate correctly
metrics.RecordCacheHit()
result = metrics.GetMetrics()
if result["cache_hits"].(int64) != 2 {
t.Error("Expected cache hits to accumulate")
}
}
func TestPerformanceMetricsConcurrency(t *testing.T) {
logger := NewLogger("debug")
metrics := NewPerformanceMetrics(logger)
// Test concurrent access
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func() {
defer func() { done <- true }()
for j := 0; j < 100; j++ {
metrics.RecordCacheHit()
metrics.RecordTokenVerification(time.Millisecond, true)
}
}()
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
result := metrics.GetMetrics()
// Should have 1000 cache hits (10 goroutines * 100 operations)
if result["cache_hits"].(int64) != 1000 {
t.Errorf("Expected 1000 cache hits, got %v", result["cache_hits"])
}
// Should have 1000 token verifications
if result["token_verifications"].(int64) != 1000 {
t.Errorf("Expected 1000 token verifications, got %v", result["token_verifications"])
}
}
func TestResourceMonitorLimits(t *testing.T) {
logger := NewLogger("debug")
metrics := NewPerformanceMetrics(logger)
monitor := NewResourceMonitor(metrics, logger)
t.Run("Memory limit validation", func(t *testing.T) {
// Set a reasonable memory limit
monitor.SetMemoryLimit(50 * 1024 * 1024) // 50MB
status := monitor.GetResourceStatus()
if status["memory_limit"].(uint64) != 50*1024*1024 {
t.Error("Memory limit not set correctly")
}
})
t.Run("Cache limit validation", func(t *testing.T) {
monitor.SetCacheLimit(2000)
status := monitor.GetResourceStatus()
if status["cache_limit"].(int) != 2000 {
t.Error("Cache limit not set correctly")
}
})
t.Run("Session limit validation", func(t *testing.T) {
monitor.SetSessionLimit(1000)
status := monitor.GetResourceStatus()
if status["session_limit"].(int) != 1000 {
t.Error("Session limit not set correctly")
}
})
}
func TestPerformanceMetricsEdgeCases(t *testing.T) {
logger := NewLogger("debug")
metrics := NewPerformanceMetrics(logger)
t.Run("Zero duration handling", func(t *testing.T) {
metrics.RecordTokenVerification(0, true)
result := metrics.GetMetrics()
if result["token_verifications"].(int64) != 1 {
t.Error("Should record verification even with zero duration")
}
})
t.Run("Very large duration handling", func(t *testing.T) {
largeDuration := time.Hour
metrics.RecordTokenVerification(largeDuration, true)
detailed := metrics.GetDetailedTimingMetrics()
verificationStats := detailed["verification_stats"].(map[string]interface{})
// Should handle large durations without overflow
if verificationStats["max_ms"].(float64) <= 0 {
t.Error("Should handle large durations correctly")
}
})
t.Run("Negative cache size handling", func(t *testing.T) {
// This shouldn't happen in practice, but test robustness
metrics.UpdateCacheSize(-1)
result := metrics.GetMetrics()
// Implementation should handle this gracefully
if result["cache_size"] == nil {
t.Error("Cache size should be present even if negative")
}
})
}
+3 -6
View File
@@ -229,7 +229,7 @@ func TestSessionConcurrencyProtection(t *testing.T) {
// Perform operations on session
s.SetEmail(fmt.Sprintf("user%d-%d@example.com", goroutineID, j))
s.SetAuthenticated(true)
s.SetAccessToken(fmt.Sprintf("token-%d-%d", goroutineID, j))
s.SetAccessToken(ValidAccessToken)
// Save session
testRR := httptest.NewRecorder()
@@ -651,11 +651,8 @@ func TestErrorRecoveryPatterns(t *testing.T) {
// Test recovery from cache corruption
t.Run("CacheCorruption", func(t *testing.T) {
// Corrupt the cache by setting invalid data
ts.tOidc.tokenCache.cache.items["corrupted"] = CacheItem{
Value: "invalid-data",
ExpiresAt: time.Now().Add(time.Hour),
}
// Corrupt the cache by using the Set method to avoid data race
ts.tOidc.tokenCache.cache.Set("corrupted", "invalid-data", time.Hour)
// System should handle corrupted cache gracefully
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
+124
View File
@@ -0,0 +1,124 @@
package traefikoidc
import (
"reflect"
"testing"
)
func TestMergeScopes(t *testing.T) {
testCases := []struct {
name string
defaultScopes []string
userScopes []string
expectedScopes []string
}{
{
name: "Empty user scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{},
expectedScopes: []string{"openid", "profile", "email"},
},
{
name: "Non-overlapping scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"roles", "custom_scope"},
expectedScopes: []string{"openid", "profile", "email", "roles", "custom_scope"},
},
{
name: "Overlapping scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: []string{"openid", "roles", "profile", "permissions"},
expectedScopes: []string{"openid", "profile", "email", "roles", "permissions"},
},
{
name: "Nil user scopes",
defaultScopes: []string{"openid", "profile", "email"},
userScopes: nil,
expectedScopes: []string{"openid", "profile", "email"},
},
{
name: "Nil default scopes",
defaultScopes: nil,
userScopes: []string{"roles", "custom_scope"},
expectedScopes: []string{"roles", "custom_scope"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := mergeScopes(tc.defaultScopes, tc.userScopes)
if !reflect.DeepEqual(result, tc.expectedScopes) {
t.Errorf("Expected %v, got %v", tc.expectedScopes, result)
}
})
}
}
func TestScopesConfiguration(t *testing.T) {
defaultScopes := []string{"openid", "profile", "email"}
userScopes := []string{"roles", "custom_scope"}
t.Run("Default Append Behavior", func(t *testing.T) {
// Create config with user scopes but overrideScopes=false
config := &Config{
Scopes: userScopes,
OverrideScopes: false,
}
// Simulate middleware initialization
var result []string
if config.OverrideScopes {
result = append([]string(nil), config.Scopes...)
} else {
result = mergeScopes(defaultScopes, config.Scopes)
}
// Expect defaultScopes + userScopes with deduplication
expectedScopes := []string{"openid", "profile", "email", "roles", "custom_scope"}
if !reflect.DeepEqual(result, expectedScopes) {
t.Errorf("Expected %v, got %v", expectedScopes, result)
}
})
t.Run("Override Behavior", func(t *testing.T) {
// Create config with user scopes and overrideScopes=true
config := &Config{
Scopes: userScopes,
OverrideScopes: true,
}
// Simulate middleware initialization
var result []string
if config.OverrideScopes {
result = append([]string(nil), config.Scopes...)
} else {
result = mergeScopes(defaultScopes, config.Scopes)
}
// Expect only userScopes
if !reflect.DeepEqual(result, userScopes) {
t.Errorf("Expected %v, got %v", userScopes, result)
}
})
t.Run("Empty Scopes with Override", func(t *testing.T) {
// Create config with empty scopes and overrideScopes=true
config := &Config{
Scopes: []string{},
OverrideScopes: true,
}
// Simulate middleware initialization
var result []string
if config.OverrideScopes {
result = append([]string(nil), config.Scopes...)
} else {
result = mergeScopes(defaultScopes, config.Scopes)
}
// Expect empty scopes - check length instead of DeepEqual
if len(result) != 0 {
t.Errorf("Expected empty slice, got %v with length %d", result, len(result))
}
})
}
+36 -3
View File
@@ -389,8 +389,8 @@ func TestMissingClaims(t *testing.T) {
// Test cases for missing claims
testCases := []struct {
name string
omittedClaims []string
expectedError string
omittedClaims []string
}{
{
name: "Missing Issuer",
@@ -479,8 +479,8 @@ func TestSessionFixationAttack(t *testing.T) {
// Set up the attacker's session with malicious data
attackerSession.SetAuthenticated(true)
attackerSession.SetEmail("attacker@evil.com")
attackerSession.SetIDToken("fake-id-token")
attackerSession.SetAccessToken("fake-access-token")
attackerSession.SetIDToken(ValidIDToken)
attackerSession.SetAccessToken(ValidAccessToken)
// Save the session to get cookies
if err := attackerSession.Save(req, resp); err != nil {
@@ -510,6 +510,31 @@ func TestSessionFixationAttack(t *testing.T) {
w.WriteHeader(http.StatusOK)
})
// Create keys for JWT verification
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
// Create JWK
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
// Create mock JWK cache
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
// Create the TraefikOidc middleware
tOidc := &TraefikOidc{
next: nextHandler,
@@ -519,6 +544,8 @@ func TestSessionFixationAttack(t *testing.T) {
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://test-jwks-url.com",
tokenBlacklist: NewCache(),
tokenCache: NewTokenCache(),
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
@@ -528,7 +555,13 @@ func TestSessionFixationAttack(t *testing.T) {
httpClient: &http.Client{},
initComplete: make(chan struct{}),
sessionManager: sm,
extractClaimsFunc: extractClaims,
}
// Set up the token verifier and JWT verifier
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
close(tOidc.initComplete)
// Now create a victim's request with the attacker's cookies
+158 -164
View File
@@ -6,73 +6,96 @@ import (
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
)
// SecurityEventType represents different types of security events
type SecurityEventType string
const (
// AuthFailure represents an authentication failure event
AuthFailure SecurityEventType = "authentication_failure"
// TokenValidFailure represents a token validation failure event
TokenValidFailure SecurityEventType = "token_validation_failure"
// RateLimitHit represents a rate limit hit event
RateLimitHit SecurityEventType = "rate_limit_hit"
// SuspiciousActivity represents a suspicious activity event
SuspiciousActivity SecurityEventType = "suspicious_activity"
)
// DefaultSeverity returns the default severity level for a security event type
func (t SecurityEventType) DefaultSeverity() string {
switch t {
case AuthFailure:
return "medium"
case TokenValidFailure:
return "medium"
case RateLimitHit:
return "low"
case SuspiciousActivity:
return "high"
default:
return "medium"
}
}
// IPFailureType returns the IP failure tracking type for a security event type
func (t SecurityEventType) IPFailureType() string {
switch t {
case AuthFailure:
return "auth_failure"
case TokenValidFailure:
return "token_failure"
case SuspiciousActivity:
return "suspicious"
default:
return "general"
}
}
// SecurityEvent represents a security-related event that should be logged and monitored
type SecurityEvent struct {
Timestamp time.Time `json:"timestamp"`
Details map[string]interface{} `json:"details,omitempty"`
Type string `json:"type"`
Severity string `json:"severity"`
Timestamp time.Time `json:"timestamp"`
ClientIP string `json:"client_ip"`
UserAgent string `json:"user_agent"`
RequestPath string `json:"request_path"`
Message string `json:"message"`
Details map[string]interface{} `json:"details,omitempty"`
}
// SecurityMonitor tracks security events and suspicious activity patterns
type SecurityMonitor struct {
// Event counters
authFailures int64
tokenValidationFails int64
rateLimitHits int64
suspiciousRequests int64
// IP-based tracking
ipFailures map[string]*IPFailureTracker
ipMutex sync.RWMutex
// Pattern detection
ipFailures map[string]*IPFailureTracker
patternDetector *SuspiciousPatternDetector
// Event handlers
eventHandlers []SecurityEventHandler
// Configuration
config SecurityMonitorConfig
// Logger
logger *Logger
logger *Logger
eventHandlers []SecurityEventHandler
config SecurityMonitorConfig
ipMutex sync.RWMutex
}
// IPFailureTracker tracks failures for a specific IP address
type IPFailureTracker struct {
FailureCount int64
LastFailure time.Time
FirstFailure time.Time
FailureTypes map[string]int64
IsBlocked bool
BlockedUntil time.Time
FailureTypes map[string]int64
FailureCount int64
mutex sync.RWMutex
IsBlocked bool
}
// SuspiciousPatternDetector identifies patterns that may indicate attacks
type SuspiciousPatternDetector struct {
// Time-based windows for pattern detection
shortWindow time.Duration // 1 minute
mediumWindow time.Duration // 5 minutes
longWindow time.Duration // 15 minutes
// Pattern thresholds
rapidFailureThreshold int // failures in short window
distributedAttackThreshold int // failures across IPs in medium window
persistentAttackThreshold int // failures in long window
// Pattern tracking
recentEvents []SecurityEvent
eventsMutex sync.RWMutex
recentEvents []SecurityEvent
shortWindow time.Duration
mediumWindow time.Duration
longWindow time.Duration
rapidFailureThreshold int
distributedAttackThreshold int
persistentAttackThreshold int
eventsMutex sync.RWMutex
}
// SecurityEventHandler defines the interface for handling security events
@@ -82,22 +105,15 @@ type SecurityEventHandler interface {
// SecurityMonitorConfig contains configuration for the security monitor
type SecurityMonitorConfig struct {
// Failure thresholds
MaxFailuresPerIP int `json:"max_failures_per_ip"`
FailureWindowMinutes int `json:"failure_window_minutes"`
BlockDurationMinutes int `json:"block_duration_minutes"`
// Pattern detection settings
EnablePatternDetection bool `json:"enable_pattern_detection"`
MaxFailuresPerIP int `json:"max_failures_per_ip"`
FailureWindowMinutes int `json:"failure_window_minutes"`
BlockDurationMinutes int `json:"block_duration_minutes"`
RapidFailureThreshold int `json:"rapid_failure_threshold"`
// Monitoring settings
EnableDetailedLogging bool `json:"enable_detailed_logging"`
LogSuspiciousOnly bool `json:"log_suspicious_only"`
// Cleanup settings
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
RetentionHours int `json:"retention_hours"`
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
RetentionHours int `json:"retention_hours"`
EnablePatternDetection bool `json:"enable_pattern_detection"`
EnableDetailedLogging bool `json:"enable_detailed_logging"`
LogSuspiciousOnly bool `json:"log_suspicious_only"`
}
// DefaultSecurityMonitorConfig returns a default configuration
@@ -115,6 +131,9 @@ func DefaultSecurityMonitorConfig() SecurityMonitorConfig {
}
}
// cleanupTask holds the BackgroundTask for security cleanup
var cleanupTask *BackgroundTask
// NewSecurityMonitor creates a new security monitor instance
func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityMonitor {
sm := &SecurityMonitor{
@@ -126,7 +145,7 @@ func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityM
}
// Start cleanup routine
go sm.startCleanupRoutine()
sm.startCleanupRoutine()
return sm
}
@@ -144,29 +163,55 @@ func NewSuspiciousPatternDetector() *SuspiciousPatternDetector {
}
}
// RecordAuthenticationFailure records an authentication failure event
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) {
atomic.AddInt64(&sm.authFailures, 1)
// RecordSecurityEvent is a generic method to record any type of security event
func (sm *SecurityMonitor) RecordSecurityEvent(
eventType SecurityEventType,
clientIP, userAgent, requestPath string,
message string,
details map[string]interface{},
trackIPFailure bool) {
// Create event with default values for the event type
event := SecurityEvent{
Type: "authentication_failure",
Severity: "medium",
Type: string(eventType),
Severity: eventType.DefaultSeverity(),
Timestamp: time.Now(),
ClientIP: clientIP,
UserAgent: userAgent,
RequestPath: requestPath,
Message: fmt.Sprintf("Authentication failed: %s", reason),
Message: message,
Details: details,
}
sm.recordIPFailure(clientIP, "auth_failure")
// Track IP failures if requested
if trackIPFailure {
sm.recordIPFailure(clientIP, eventType.IPFailureType())
}
// Process the event
sm.processSecurityEvent(event)
}
// RecordAuthenticationFailure records an authentication failure event
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) {
if details == nil {
details = make(map[string]interface{})
}
details["reason"] = reason
sm.RecordSecurityEvent(
AuthFailure,
clientIP,
userAgent,
requestPath,
fmt.Sprintf("Authentication failed: %s", reason),
details,
true,
)
}
// RecordTokenValidationFailure records a token validation failure
func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) {
atomic.AddInt64(&sm.tokenValidationFails, 1)
details := map[string]interface{}{
"reason": reason,
}
@@ -174,59 +219,50 @@ func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, req
details["token_prefix"] = tokenPrefix
}
event := SecurityEvent{
Type: "token_validation_failure",
Severity: "medium",
Timestamp: time.Now(),
ClientIP: clientIP,
UserAgent: userAgent,
RequestPath: requestPath,
Message: fmt.Sprintf("Token validation failed: %s", reason),
Details: details,
}
sm.recordIPFailure(clientIP, "token_failure")
sm.processSecurityEvent(event)
sm.RecordSecurityEvent(
TokenValidFailure,
clientIP,
userAgent,
requestPath,
fmt.Sprintf("Token validation failed: %s", reason),
details,
true,
)
}
// RecordRateLimitHit records when rate limiting is triggered
func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) {
atomic.AddInt64(&sm.rateLimitHits, 1)
event := SecurityEvent{
Type: "rate_limit_hit",
Severity: "low",
Timestamp: time.Now(),
ClientIP: clientIP,
UserAgent: userAgent,
RequestPath: requestPath,
Message: "Rate limit exceeded",
Details: map[string]interface{}{
"limit_type": "token_verification",
},
details := map[string]interface{}{
"limit_type": "token_verification",
}
sm.recordIPFailure(clientIP, "rate_limit")
sm.processSecurityEvent(event)
sm.RecordSecurityEvent(
RateLimitHit,
clientIP,
userAgent,
requestPath,
"Rate limit exceeded",
details,
true, // Track IP failure for rate limiting
)
}
// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]interface{}) {
atomic.AddInt64(&sm.suspiciousRequests, 1)
event := SecurityEvent{
Type: "suspicious_activity",
Severity: "high",
Timestamp: time.Now(),
ClientIP: clientIP,
UserAgent: userAgent,
RequestPath: requestPath,
Message: fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
Details: details,
if details == nil {
details = make(map[string]interface{})
}
details["activity_type"] = activityType
sm.recordIPFailure(clientIP, "suspicious")
sm.processSecurityEvent(event)
sm.RecordSecurityEvent(
SuspiciousActivity,
clientIP,
userAgent,
requestPath,
fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
details,
true,
)
}
// recordIPFailure tracks failures for a specific IP address
@@ -351,30 +387,11 @@ func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) {
sm.eventHandlers = append(sm.eventHandlers, handler)
}
// GetSecurityMetrics returns current security metrics
// GetSecurityMetrics returns minimal security metrics
// This is kept for API compatibility but doesn't collect actual metrics
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]interface{} {
sm.ipMutex.RLock()
defer sm.ipMutex.RUnlock()
blockedIPs := 0
totalTrackedIPs := len(sm.ipFailures)
for _, tracker := range sm.ipFailures {
tracker.mutex.RLock()
if tracker.IsBlocked && time.Now().Before(tracker.BlockedUntil) {
blockedIPs++
}
tracker.mutex.RUnlock()
}
return map[string]interface{}{
"auth_failures": atomic.LoadInt64(&sm.authFailures),
"token_validation_fails": atomic.LoadInt64(&sm.tokenValidationFails),
"rate_limit_hits": atomic.LoadInt64(&sm.rateLimitHits),
"suspicious_requests": atomic.LoadInt64(&sm.suspiciousRequests),
"blocked_ips": blockedIPs,
"tracked_ips": totalTrackedIPs,
"uptime_hours": time.Since(time.Now().Add(-24 * time.Hour)).Hours(), // Placeholder
"tracked_ips": 0,
}
}
@@ -456,11 +473,20 @@ func (spd *SuspiciousPatternDetector) DetectSuspiciousPatterns() []string {
// startCleanupRoutine starts the background cleanup routine
func (sm *SecurityMonitor) startCleanupRoutine() {
ticker := time.NewTicker(time.Duration(sm.config.CleanupIntervalMinutes) * time.Minute)
defer ticker.Stop()
// Use BackgroundTask abstraction for consistent management
cleanupTask = NewBackgroundTask(
"security-monitor-cleanup",
time.Duration(sm.config.CleanupIntervalMinutes)*time.Minute,
sm.cleanup,
sm.logger)
cleanupTask.Start()
}
for range ticker.C {
sm.cleanup()
// StopCleanupRoutine stops the background cleanup routine
func (sm *SecurityMonitor) StopCleanupRoutine() {
if cleanupTask != nil {
cleanupTask.Stop()
cleanupTask = nil
}
}
@@ -537,36 +563,4 @@ func (h *LoggingSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
}
}
// MetricsSecurityEventHandler tracks security metrics
type MetricsSecurityEventHandler struct {
eventCounts map[string]int64
mutex sync.RWMutex
}
// NewMetricsSecurityEventHandler creates a new metrics event handler
func NewMetricsSecurityEventHandler() *MetricsSecurityEventHandler {
return &MetricsSecurityEventHandler{
eventCounts: make(map[string]int64),
}
}
// HandleSecurityEvent implements SecurityEventHandler
func (h *MetricsSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
h.mutex.Lock()
defer h.mutex.Unlock()
h.eventCounts[event.Type]++
h.eventCounts[fmt.Sprintf("%s_%s", event.Type, event.Severity)]++
}
// GetMetrics returns the current metrics
func (h *MetricsSecurityEventHandler) GetMetrics() map[string]int64 {
h.mutex.RLock()
defer h.mutex.RUnlock()
metrics := make(map[string]int64)
for k, v := range h.eventCounts {
metrics[k] = v
}
return metrics
}
// Note: MetricsSecurityEventHandler has been removed as part of metrics cleanup
+11 -63
View File
@@ -42,42 +42,19 @@ func TestSecurityMonitor(t *testing.T) {
})
t.Run("Token validation failure", func(t *testing.T) {
// Just verify the method doesn't panic
monitor.RecordTokenValidationFailure("192.168.1.3", "test-agent", "/api", "invalid token", "abc123")
metrics := monitor.GetSecurityMetrics()
if metrics["token_validation_fails"].(int64) == 0 {
t.Error("Expected token validation failures to be recorded")
}
})
t.Run("Rate limit hit", func(t *testing.T) {
// Just verify the method doesn't panic
monitor.RecordRateLimitHit("192.168.1.4", "test-agent", "/api")
metrics := monitor.GetSecurityMetrics()
if metrics["rate_limit_hits"].(int64) == 0 {
t.Error("Expected rate limit hits to be recorded")
}
})
t.Run("Suspicious activity", func(t *testing.T) {
details := map[string]interface{}{"pattern": "unusual"}
// Just verify the method doesn't panic
monitor.RecordSuspiciousActivity("192.168.1.5", "test-agent", "/admin", "unusual pattern", "high frequency requests", details)
metrics := monitor.GetSecurityMetrics()
if metrics["suspicious_requests"].(int64) == 0 {
t.Error("Expected suspicious activities to be recorded")
}
})
t.Run("Get security metrics", func(t *testing.T) {
metrics := monitor.GetSecurityMetrics()
if metrics["auth_failures"].(int64) == 0 {
t.Error("Expected some authentication failures")
}
if metrics["blocked_ips"] == nil {
t.Error("Expected blocked IPs count to be present")
}
})
}
@@ -98,8 +75,8 @@ func TestSuspiciousPatternDetector(t *testing.T) {
patterns := detector.DetectSuspiciousPatterns()
found := false
for _, pattern := range patterns {
if pattern == "rapid_failures_from_ip_192.168.1.100" {
for _, p := range patterns {
if p == "rapid_failures_from_ip_192.168.1.100" {
found = true
break
}
@@ -123,8 +100,8 @@ func TestSuspiciousPatternDetector(t *testing.T) {
patterns := detector.DetectSuspiciousPatterns()
found := false
for _, pattern := range patterns {
if pattern == "distributed_attack_pattern" {
for _, p := range patterns {
if p == "distributed_attack_pattern" {
found = true
break
}
@@ -204,24 +181,7 @@ func TestSecurityEventHandlers(t *testing.T) {
handler.HandleSecurityEvent(event)
})
t.Run("Metrics security event handler", func(t *testing.T) {
handler := NewMetricsSecurityEventHandler()
event := SecurityEvent{
Type: "authentication_failure",
ClientIP: "192.168.1.1",
Timestamp: time.Now(),
Message: "Test failure",
Severity: "medium",
}
handler.HandleSecurityEvent(event)
metrics := handler.GetMetrics()
if metrics["authentication_failure"] != 1 {
t.Errorf("Expected 1 authentication failure, got %v", metrics["authentication_failure"])
}
})
// Metrics security event handler test removed as part of metrics cleanup
}
func TestSecurityMonitorEventHandlers(t *testing.T) {
@@ -312,7 +272,7 @@ func TestSecurityEventTypes(t *testing.T) {
logger := NewLogger("debug")
monitor := NewSecurityMonitor(config, logger)
// Test different event types
// Test different event types - just verify they don't panic
monitor.RecordAuthenticationFailure("192.168.1.200", "test-agent", "/login", "invalid password", nil)
monitor.RecordTokenValidationFailure("192.168.1.200", "test-agent", "/api", "expired token", "abc123")
monitor.RecordRateLimitHit("192.168.1.200", "test-agent", "/api")
@@ -320,18 +280,6 @@ func TestSecurityEventTypes(t *testing.T) {
details := map[string]interface{}{"pattern": "test"}
monitor.RecordSuspiciousActivity("192.168.1.200", "test-agent", "/admin", "unusual pattern", "multiple failed logins", details)
metrics := monitor.GetSecurityMetrics()
if metrics["auth_failures"].(int64) == 0 {
t.Error("Expected authentication failures to be recorded")
}
if metrics["token_validation_fails"].(int64) == 0 {
t.Error("Expected token validation failures to be recorded")
}
if metrics["rate_limit_hits"].(int64) == 0 {
t.Error("Expected rate limit hits to be recorded")
}
if metrics["suspicious_requests"].(int64) == 0 {
t.Error("Expected suspicious activities to be recorded")
}
// Just verify GetSecurityMetrics doesn't panic
_ = monitor.GetSecurityMetrics()
}
+1259 -225
View File
File diff suppressed because it is too large Load Diff
+471 -1
View File
@@ -1,12 +1,16 @@
package traefikoidc
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"testing"
"time"
"github.com/gorilla/sessions"
)
func TestSessionPoolMemoryLeak(t *testing.T) {
@@ -218,4 +222,470 @@ func TestSessionObjectTracking(t *testing.T) {
t.Log("Session pool handling verified")
}
// This is intentionally left empty to remove unused code
// TestTokenCompressionIntegrity tests that token compression and decompression maintains JWT integrity
func TestTokenCompressionIntegrity(t *testing.T) {
tests := []struct {
name string
token string
wantFail bool
}{
{
name: "Valid JWT - Small",
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.signature",
},
{
name: "Valid JWT - Large",
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + strings.Repeat("eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9", 100) + ".signature",
},
{
name: "Invalid JWT - Wrong dot count",
token: "invalid.token",
wantFail: true,
},
{
name: "Invalid JWT - No dots",
token: "invalidtoken",
wantFail: true,
},
{
name: "Invalid JWT - Too many dots",
token: "part1.part2.part3.part4",
wantFail: true,
},
{
name: "Empty token",
token: "",
wantFail: false, // Empty tokens are handled gracefully
},
{
name: "Oversized token (>50KB)",
token: "part1." + strings.Repeat("A", 51*1024) + ".part3",
wantFail: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compressed := compressToken(tt.token)
if tt.wantFail {
// For invalid tokens, compression should return original
if compressed != tt.token {
t.Errorf("Expected compression to return original for invalid token, got different result")
}
return
}
// For valid tokens, test round-trip integrity
decompressed := decompressToken(compressed)
if decompressed != tt.token {
t.Errorf("Token integrity lost: original=%q, compressed=%q, decompressed=%q",
tt.token, compressed, decompressed)
}
// Test that decompression is idempotent
decompressed2 := decompressToken(decompressed)
if decompressed2 != tt.token {
t.Errorf("Decompression not idempotent: %q != %q", decompressed2, tt.token)
}
})
}
}
// TestTokenCompressionCorruptionDetection tests that gzip corruption is detected and handled
func TestTokenCompressionCorruptionDetection(t *testing.T) {
validJWT := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.signature"
tests := []struct {
name string
corruptedInput string
expectOriginal bool
}{
{
name: "Invalid base64",
corruptedInput: "!@#$%^&*()",
expectOriginal: true,
},
{
name: "Valid base64 but invalid gzip",
corruptedInput: base64.StdEncoding.EncodeToString([]byte("not gzip data")),
expectOriginal: true,
},
{
name: "Truncated gzip data",
corruptedInput: "H4sI", // Incomplete gzip header
expectOriginal: true,
},
{
name: "Empty string",
corruptedInput: "",
expectOriginal: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := decompressToken(tt.corruptedInput)
if tt.expectOriginal && result != tt.corruptedInput {
t.Errorf("Expected decompression to return original corrupted input, got: %q", result)
}
})
}
// Test that valid compression still works
compressed := compressToken(validJWT)
decompressed := decompressToken(compressed)
if decompressed != validJWT {
t.Errorf("Valid compression/decompression failed: %q != %q", decompressed, validJWT)
}
}
// TestTokenChunkingIntegrity tests that large tokens are properly chunked and reassembled
func TestTokenChunkingIntegrity(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Create tokens of various sizes to test chunking
testTokens := NewTestTokens()
tests := []struct {
name string
tokenSize int
expectChunked bool
}{
{
name: "Small token (no chunking)",
tokenSize: 100,
expectChunked: false,
},
{
name: "Medium token (no chunking)",
tokenSize: 800, // FIXED: Reduced further to account for new conservative chunk size (1200 bytes)
expectChunked: false,
},
{
name: "Large token (chunking required)",
tokenSize: 5000,
expectChunked: true,
},
{
name: "Very large token (multiple chunks)",
tokenSize: 10000,
expectChunked: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// FIXED: Use incompressible tokens to ensure chunking occurs
var token string
if tt.expectChunked {
token = testTokens.CreateIncompressibleToken(tt.tokenSize)
} else {
token = testTokens.CreateLargeValidJWT(tt.tokenSize)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Store the token
session.SetAccessToken(token)
// Retrieve the token
retrievedToken := session.GetAccessToken()
// Verify integrity
if retrievedToken != token {
t.Errorf("Token integrity lost:\nOriginal: %q\nRetrieved: %q", token, retrievedToken)
}
// Check if chunking occurred as expected
hasChunks := len(session.accessTokenChunks) > 0
if tt.expectChunked != hasChunks {
t.Errorf("Chunking expectation mismatch: expected chunked=%v, has chunks=%v", tt.expectChunked, hasChunks)
}
session.ReturnToPool()
})
}
}
// TestTokenChunkingCorruptionResistance tests handling of corrupted chunks
func TestTokenChunkingCorruptionResistance(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Create a large token that will be chunked
largeToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, `{"sub":"test","data":"%s"}`, strings.Repeat("A", 5000))) +
".signature"
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Store the token (this should create chunks)
session.SetAccessToken(largeToken)
if len(session.accessTokenChunks) == 0 {
t.Skip("Token was not chunked, skipping corruption test")
}
tests := []struct {
corruptChunk func(chunks map[int]*sessions.Session)
name string
expectEmpty bool
}{
{
name: "Missing chunk in sequence",
corruptChunk: func(chunks map[int]*sessions.Session) {
// Remove a middle chunk
if len(chunks) > 1 {
delete(chunks, 1)
}
},
expectEmpty: true,
},
{
name: "Empty chunk data",
corruptChunk: func(chunks map[int]*sessions.Session) {
// Set first chunk to empty
if chunk, exists := chunks[0]; exists {
chunk.Values["token_chunk"] = ""
}
},
expectEmpty: true,
},
{
name: "Wrong data type in chunk",
corruptChunk: func(chunks map[int]*sessions.Session) {
// Set chunk data to wrong type
if chunk, exists := chunks[0]; exists {
chunk.Values["token_chunk"] = 123 // Should be string
}
},
expectEmpty: true,
},
{
name: "Oversized chunk",
corruptChunk: func(chunks map[int]*sessions.Session) {
// Set chunk to oversized data
if chunk, exists := chunks[0]; exists {
chunk.Values["token_chunk"] = strings.Repeat("A", maxCookieSize+200)
}
},
expectEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Get a fresh session
freshSession, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get fresh session: %v", err)
}
// Store the token again
freshSession.SetAccessToken(largeToken)
// Apply corruption
tt.corruptChunk(freshSession.accessTokenChunks)
// Try to retrieve the token
retrievedToken := freshSession.GetAccessToken()
if tt.expectEmpty {
if retrievedToken != "" {
t.Errorf("Expected empty token due to corruption, got: %q", retrievedToken)
}
} else {
if retrievedToken != largeToken {
t.Errorf("Expected original token despite corruption, got: %q", retrievedToken)
}
}
freshSession.ReturnToPool()
})
}
session.ReturnToPool()
}
// TestTokenSizeLimits tests that token size limits are enforced
func TestTokenSizeLimits(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
testTokens := NewTestTokens()
tests := []struct {
name string
tokenSize int
expectStored bool
}{
{
name: "Normal size token",
tokenSize: 1000,
expectStored: true,
},
{
name: "Large but acceptable token",
tokenSize: 30000, // FIXED: 30KB to ensure final size < 100KB limit
expectStored: true,
},
{
name: "Oversized token (>100KB)",
tokenSize: 120000, // FIXED: 120KB to ensure rejection after compression
expectStored: false, // Should be rejected
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// FIXED: Use proper token generation that accounts for base64 encoding
var token string
if tt.expectStored {
token = testTokens.CreateLargeValidJWT(tt.tokenSize)
} else {
token = testTokens.CreateIncompressibleToken(tt.tokenSize)
}
// Store the token
session.SetAccessToken(token)
// Try to retrieve it
retrievedToken := session.GetAccessToken()
if tt.expectStored {
if retrievedToken != token {
t.Errorf("Expected token to be stored and retrieved, but got different token")
}
} else {
if retrievedToken == token {
t.Errorf("Expected oversized token to be rejected, but it was stored")
}
}
})
}
}
// TestConcurrentTokenOperations tests thread safety of token operations
func TestConcurrentTokenOperations(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
const numGoroutines = 10
const numOperations = 100
// Test concurrent access and refresh token operations
done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer func() { done <- true }()
for j := 0; j < numOperations; j++ {
// Create unique tokens for each goroutine/operation
accessToken := ValidAccessToken
refreshToken := fmt.Sprintf("refresh_token_%d_%d", id, j)
// Concurrent operations
session.SetAccessToken(accessToken)
session.SetRefreshToken(refreshToken)
retrievedAccess := session.GetAccessToken()
retrievedRefresh := session.GetRefreshToken()
// Verify tokens are still valid (should be one of the tokens set by any goroutine)
if retrievedAccess != "" && strings.Count(retrievedAccess, ".") != 2 {
t.Errorf("Retrieved access token has invalid format: %q", retrievedAccess)
}
if retrievedRefresh != "" && len(retrievedRefresh) < 10 {
t.Errorf("Retrieved refresh token is too short: %q", retrievedRefresh)
}
}
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < numGoroutines; i++ {
<-done
}
}
// TestSessionValidationAndCleanup tests session validation and orphan cleanup
func TestSessionValidationAndCleanup(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
rw := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set tokens that will create chunks
largeToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat(`{"data":"large"}`, 500))) +
".signature"
session.SetAccessToken(largeToken)
session.SetRefreshToken("refresh_token_test")
// Save session to create cookies
if err := session.Save(req, rw); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify chunks were created
if len(session.accessTokenChunks) == 0 {
t.Log("No chunks created, large token test may not be applicable")
}
// Test cleanup by clearing session
if err := session.Clear(req, rw); err != nil {
t.Logf("Clear returned error (may be expected): %v", err)
}
// Verify tokens are cleared
if token := session.GetAccessToken(); token != "" {
t.Errorf("Access token should be empty after clear, got: %q", token)
}
if token := session.GetRefreshToken(); token != "" {
t.Errorf("Refresh token should be empty after clear, got: %q", token)
}
}
+34 -90
View File
@@ -26,96 +26,28 @@ type TemplatedHeader struct {
// It provides all necessary settings to configure OpenID Connect authentication
// with various providers like Auth0, Logto, or any standard OIDC provider.
type Config struct {
// ProviderURL is the base URL of the OIDC provider (required)
// Example: https://accounts.google.com
ProviderURL string `json:"providerURL"`
// RevocationURL is the endpoint for revoking tokens (optional)
// If not provided, it will be discovered from provider metadata
RevocationURL string `json:"revocationURL"`
// EnablePKCE enables Proof Key for Code Exchange (PKCE) for the authorization code flow (optional)
// This enhances security but might not be supported by all OIDC providers
// Default: false
EnablePKCE bool `json:"enablePKCE"`
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
// Example: /oauth2/callback
CallbackURL string `json:"callbackURL"`
// LogoutURL is the path for handling logout requests (optional)
// If not provided, it will be set to CallbackURL + "/logout"
LogoutURL string `json:"logoutURL"`
// ClientID is the OAuth 2.0 client identifier (required)
ClientID string `json:"clientID"`
// ClientSecret is the OAuth 2.0 client secret (required)
ClientSecret string `json:"clientSecret"`
// Scopes defines the OAuth 2.0 scopes to request (optional)
// Defaults to ["openid", "profile", "email"] if not provided
Scopes []string `json:"scopes"`
// LogLevel sets the logging verbosity (optional)
// Valid values: "debug", "info", "error"
// Default: "info"
LogLevel string `json:"logLevel"`
// SessionEncryptionKey is used to encrypt session data (required)
// Must be a secure random string
SessionEncryptionKey string `json:"sessionEncryptionKey"`
// ForceHTTPS forces the use of HTTPS for all URLs (optional)
// Default: false
ForceHTTPS bool `json:"forceHTTPS"`
// RateLimit sets the maximum number of requests per second (optional)
// Default: 100
RateLimit int `json:"rateLimit"`
// ExcludedURLs lists paths that bypass authentication (optional)
// Example: ["/health", "/metrics"]
ExcludedURLs []string `json:"excludedURLs"`
// AllowedUserDomains restricts access to specific email domains (optional)
// Example: ["company.com", "subsidiary.com"]
AllowedUserDomains []string `json:"allowedUserDomains"`
// AllowedUsers restricts access to specific email addresses (optional)
// Example: ["user1@example.com", "user2@example.com"]
AllowedUsers []string `json:"allowedUsers"`
// AllowedRolesAndGroups restricts access to users with specific roles or groups (optional)
// Example: ["admin", "developer"]
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
// OIDCEndSessionURL is the provider's end session endpoint (optional)
// If not provided, it will be discovered from provider metadata
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
// PostLogoutRedirectURI is the URL to redirect to after logout (optional)
// Default: "/"
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
HTTPClient *http.Client
// RefreshGracePeriodSeconds defines how many seconds before a token expires
// the plugin should attempt to refresh it proactively (optional)
// Default: 60
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
// Headers defines custom HTTP headers to set with templated values (optional)
// Values can reference tokens and claims using Go templates with the following variables:
// - {{.AccessToken}} - The access token (ID token)
// - {{.IdToken}} - Same as AccessToken (for consistency)
// - {{.RefreshToken}} - The refresh token
// - {{.Claims.email}} - Access token claims (use proper case for claim names)
// Examples:
//
// [{Name: "X-Forwarded-Email", Value: "{{.Claims.email}}"}]
// [{Name: "Authorization", Value: "Bearer {{.AccessToken}}"}]
Headers []TemplatedHeader `json:"headers"`
HTTPClient *http.Client
ProviderURL string `json:"providerURL"`
RevocationURL string `json:"revocationURL"`
CallbackURL string `json:"callbackURL"`
LogoutURL string `json:"logoutURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
LogLevel string `json:"logLevel"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedUsers []string `json:"allowedUsers"`
Scopes []string `json:"scopes"`
Headers []TemplatedHeader `json:"headers"`
RateLimit int `json:"rateLimit"`
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
ForceHTTPS bool `json:"forceHTTPS"`
EnablePKCE bool `json:"enablePKCE"`
OverrideScopes bool `json:"overrideScopes"`
}
const (
@@ -156,6 +88,7 @@ func CreateConfig() *Config {
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
EnablePKCE: false, // PKCE is opt-in
OverrideScopes: false, // Default to appending scopes, not overriding
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
}
@@ -538,6 +471,17 @@ func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// newNoOpLogger creates a silent logger that doesn't output anything.
// This is useful for internal components that need a logger instance
// but should not produce any output by default.
func newNoOpLogger() *Logger {
return &Logger{
logError: log.New(io.Discard, "", 0),
logInfo: log.New(io.Discard, "", 0),
logDebug: log.New(io.Discard, "", 0),
}
}
// handleError logs an error message using the provided logger and sends an HTTP error
// response to the client with the specified message and status code.
//
+32 -10
View File
@@ -7,6 +7,19 @@ import (
"testing"
)
// Helper function to compare string slices
func equalSlices(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func TestCreateConfig(t *testing.T) {
t.Run("Default Values", func(t *testing.T) {
config := CreateConfig()
@@ -36,27 +49,36 @@ func TestCreateConfig(t *testing.T) {
if !config.ForceHTTPS {
t.Error("Expected ForceHTTPS to be true by default")
}
// Check OverrideScopes default
if config.OverrideScopes {
t.Error("Expected OverrideScopes to be false by default")
}
})
t.Run("Custom Values Preserved", func(t *testing.T) {
t.Run("Config Can Hold Custom Values", func(t *testing.T) {
config := CreateConfig()
config.Scopes = []string{"custom_scope"}
config.LogLevel = "debug"
config.RateLimit = 50
config.ForceHTTPS = false
config.OverrideScopes = true
// Verify custom values are not overwritten
// Verify config struct can hold custom values
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
t.Error("Custom scopes were overwritten")
t.Error("Config struct cannot hold custom scopes")
}
if config.LogLevel != "debug" {
t.Error("Custom log level was overwritten")
t.Error("Config struct cannot hold custom log level")
}
if config.RateLimit != 50 {
t.Error("Custom rate limit was overwritten")
t.Error("Config struct cannot hold custom rate limit")
}
if config.ForceHTTPS {
t.Error("Custom ForceHTTPS value was overwritten")
t.Error("Config struct cannot hold custom ForceHTTPS value")
}
if !config.OverrideScopes {
t.Error("Config struct cannot hold custom OverrideScopes value")
}
})
}
@@ -241,10 +263,10 @@ func TestLogger(t *testing.T) {
var debugBuf, infoBuf, errorBuf bytes.Buffer
tests := []struct {
name string
logLevel string
testFunc func(*Logger)
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
name string
logLevel string
}{
{
name: "Debug Level",
@@ -392,9 +414,9 @@ func TestHandleError(t *testing.T) {
// Test helper types
type testResponseRecorder struct {
statusCode int
body string
headers map[string][]string
body string
statusCode int
}
func (r *testResponseRecorder) Header() http.Header {
+376 -7
View File
@@ -143,6 +143,59 @@ func TestTemplateExecution(t *testing.T) {
expectedValue: "",
expectError: true, // Parsing should fail
},
{
name: "Custom Claims",
templateText: "Role: {{.Claims.role}}, Department: {{.Claims.department}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"email": "user@example.com",
"role": "admin",
"department": "engineering",
},
},
expectedValue: "Role: admin, Department: engineering",
expectError: false,
},
{
name: "Nested Custom Claims",
templateText: "Org: {{.Claims.metadata.organization}}, Team: {{.Claims.metadata.team}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"email": "user@example.com",
"metadata": map[string]interface{}{
"organization": "company-name",
"team": "platform",
},
},
},
expectedValue: "Org: company-name, Team: platform",
expectError: false,
},
{
name: "Email Claims",
templateText: "Email: {{.Claims.email}}, Verified: {{.Claims.email_verified}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"email": "user@example.com",
"email_verified": true,
},
},
expectedValue: "Email: user@example.com, Verified: true",
expectError: false,
},
{
name: "User Identity Claims",
templateText: "Name: {{.Claims.name}}, Subject: {{.Claims.sub}}, Username: {{.Claims.preferred_username}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"name": "John Doe",
"sub": "user123",
"preferred_username": "johndoe",
},
},
expectedValue: "Name: John Doe, Subject: user123, Username: johndoe",
expectError: false,
},
}
for _, tc := range tests {
@@ -176,23 +229,153 @@ func TestTemplateExecution(t *testing.T) {
// TestTemplateExecutionContext tests the specific template data context used in processAuthorizedRequest
func TestTemplateExecutionContext(t *testing.T) {
// Define a test struct that matches the one used in processAuthorizedRequest
// Test cases for map-based template data, matching the new implementation
mapTests := []struct {
name string
templateText string
data map[string]interface{}
expectedValue string
}{
{
name: "Access and ID token distinction with map",
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"IdToken": "id-token-value",
"Claims": map[string]interface{}{},
"RefreshToken": "refresh-token-value",
},
expectedValue: "Access: access-token-value ID: id-token-value",
},
{
name: "Combining tokens and claims with map",
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
data: map[string]interface{}{
"AccessToken": "access-token",
"IdToken": "id-token",
"Claims": map[string]interface{}{
"sub": "user123",
},
"RefreshToken": "refresh-token",
},
expectedValue: "User: user123 Token: access-token",
},
{
name: "Authorization header with Bearer token",
templateText: "Bearer {{.AccessToken}}",
data: map[string]interface{}{
"AccessToken": "jwt-access-token",
"IdToken": "id-token",
"Claims": map[string]interface{}{},
},
expectedValue: "Bearer jwt-access-token",
},
{
name: "Boolean template data with AccessToken",
templateText: "Bearer {{.AccessToken}}",
data: map[string]interface{}{
"AccessToken": true, // Test boolean values to ensure they render correctly
},
expectedValue: "Bearer true",
},
{
name: "Custom non-standard claims in ID token",
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"IdToken": "id-token-value",
"Claims": map[string]interface{}{
"email": "user@example.com",
"role": "admin",
"permissions": "read:all,write:own",
},
},
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
},
{
name: "Deeply nested custom claims",
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"Claims": map[string]interface{}{
"app_metadata": map[string]interface{}{
"organization": map[string]interface{}{
"name": "acme-corp",
"id": "org-123",
},
"team": "platform",
},
},
},
expectedValue: "X-Organization: acme-corp, X-Team: platform",
},
{
name: "Email in claims",
templateText: "X-User-Email: {{.Claims.email}}, X-Email-Verified: {{.Claims.email_verified}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"IdToken": "id-token-value",
"Claims": map[string]interface{}{
"email": "user@example.com",
"email_verified": true,
},
},
expectedValue: "X-User-Email: user@example.com, X-Email-Verified: true",
},
{
name: "User info from claims",
templateText: "X-User-ID: {{.Claims.sub}}, X-User-Name: {{.Claims.name}}, X-Username: {{.Claims.preferred_username}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"IdToken": "id-token-value",
"Claims": map[string]interface{}{
"sub": "user123456",
"name": "Jane Doe",
"preferred_username": "jane.doe",
},
},
expectedValue: "X-User-ID: user123456, X-User-Name: Jane Doe, X-Username: jane.doe",
},
}
// Run map-based tests (matching the new implementation)
for _, tc := range mapTests {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
if err != nil {
t.Fatalf("Failed to execute template: %v", err)
}
result := buf.String()
if result != tc.expectedValue {
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
}
})
}
// For backward compatibility, also test the original struct-based implementation
type templateData struct {
Claims map[string]interface{}
AccessToken string
IdToken string
RefreshToken string
Claims map[string]interface{}
}
// Test cases
tests := []struct {
// Test cases for struct-based template data (original implementation)
structTests := []struct {
name string
templateText string
data templateData
expectedValue string
}{
{
name: "Access and ID token distinction",
name: "Access and ID token distinction with struct",
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
data: templateData{
AccessToken: "access-token-value",
@@ -202,7 +385,7 @@ func TestTemplateExecutionContext(t *testing.T) {
expectedValue: "Access: access-token-value ID: id-token-value",
},
{
name: "Combining tokens and claims",
name: "Combining tokens and claims with struct",
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
data: templateData{
AccessToken: "access-token",
@@ -213,9 +396,36 @@ func TestTemplateExecutionContext(t *testing.T) {
},
expectedValue: "User: user123 Token: access-token",
},
{
name: "Custom claims with struct",
templateText: "X-Custom: {{.Claims.custom_field}}, X-Group: {{.Claims.group}}",
data: templateData{
AccessToken: "access-token",
IdToken: "id-token",
Claims: map[string]interface{}{
"sub": "user123",
"custom_field": "custom-value",
"group": "admins",
},
},
expectedValue: "X-Custom: custom-value, X-Group: admins",
},
{
name: "Email claim in struct context",
templateText: "X-Email: {{.Claims.email}}, X-Name: {{.Claims.name}}",
data: templateData{
AccessToken: "access-token",
IdToken: "id-token",
Claims: map[string]interface{}{
"email": "user@example.com",
"name": "John Smith",
},
},
expectedValue: "X-Email: user@example.com, X-Name: John Smith",
},
}
for _, tc := range tests {
for _, tc := range structTests {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
if err != nil {
@@ -235,3 +445,162 @@ func TestTemplateExecutionContext(t *testing.T) {
})
}
}
// TestRegressionBooleanAccessToken specifically tests the regression case where
// a boolean value was causing "can't evaluate field AccessToken in type bool" error
func TestRegressionBooleanAccessToken(t *testing.T) {
// Test the specific case where we execute a template referencing AccessToken
// using a boolean context value
testCases := []struct {
name string
templateText string
dataContext interface{}
expectedValue string
expectError bool // Added to skip the test that demonstrates the error
}{
{
name: "Map with boolean as root",
templateText: "{{.AccessToken}}",
dataContext: map[string]interface{}{"AccessToken": "token-value"},
expectedValue: "token-value",
expectError: false,
},
{
name: "Boolean as root context",
templateText: "{{.AccessToken}}",
dataContext: true,
expectedValue: "<no value>",
expectError: true, // Skip this test as it demonstrates the error we're fixing
},
{
name: "Bearer with map context",
templateText: "Bearer {{.AccessToken}}",
dataContext: map[string]interface{}{"AccessToken": "token-value"},
expectedValue: "Bearer token-value",
expectError: false,
},
{
name: "Complex nesting with authorization",
templateText: "Authorization: Bearer {{.AccessToken}}",
dataContext: map[string]interface{}{
"AccessToken": "jwt-token-123",
"something": true,
"anotherField": map[string]interface{}{
"nested": "value",
},
},
expectedValue: "Authorization: Bearer jwt-token-123",
expectError: false,
},
{
name: "Custom claims access",
templateText: "X-User-Role: {{.Claims.role}}, X-User-Groups: {{.Claims.groups}}",
dataContext: map[string]interface{}{
"AccessToken": "jwt-token-xyz",
"Claims": map[string]interface{}{
"email": "user@example.com",
"role": "admin",
"groups": "group1,group2,group3",
"custom_data": map[string]interface{}{
"organization": "company-name",
"department": "engineering",
},
},
},
expectedValue: "X-User-Role: admin, X-User-Groups: group1,group2,group3",
expectError: false,
},
{
name: "Nested custom claims access",
templateText: "X-Organization: {{.Claims.custom_data.organization}}, X-Department: {{.Claims.custom_data.department}}",
dataContext: map[string]interface{}{
"Claims": map[string]interface{}{
"custom_data": map[string]interface{}{
"organization": "company-name",
"department": "engineering",
},
},
},
expectedValue: "X-Organization: company-name, X-Department: engineering",
expectError: false,
},
{
name: "Azure AD specific claims",
templateText: "X-TenantID: {{.Claims.tid}}, X-Roles: {{.Claims.roles}}",
dataContext: map[string]interface{}{
"Claims": map[string]interface{}{
"tid": "tenant-id-12345",
"roles": "User,Admin,Developer",
},
},
expectedValue: "X-TenantID: tenant-id-12345, X-Roles: User,Admin,Developer",
expectError: false,
},
{
name: "Auth0 specific claims",
templateText: "X-Permissions: {{.Claims.permissions}}, X-AppMetadata: {{.Claims.app_metadata.plan}}",
dataContext: map[string]interface{}{
"Claims": map[string]interface{}{
"permissions": "read:products,write:orders",
"app_metadata": map[string]interface{}{
"plan": "premium",
"status": "active",
"trial_ended": false,
},
},
},
expectedValue: "X-Permissions: read:products,write:orders, X-AppMetadata: premium",
expectError: false,
},
{
name: "Standard claims with email",
templateText: "X-Email: {{.Claims.email}}, X-Name: {{.Claims.name}}, X-Subject: {{.Claims.sub}}",
dataContext: map[string]interface{}{
"Claims": map[string]interface{}{
"email": "user@example.com",
"name": "John Doe",
"sub": "auth0|12345",
},
},
expectedValue: "X-Email: user@example.com, X-Name: John Doe, X-Subject: auth0|12345",
expectError: false,
},
{
name: "Verified email claim",
templateText: "X-Email: {{.Claims.email}}, X-Email-Verified: {{.Claims.email_verified}}",
dataContext: map[string]interface{}{
"Claims": map[string]interface{}{
"email": "user@example.com",
"email_verified": true,
},
},
expectedValue: "X-Email: user@example.com, X-Email-Verified: true",
expectError: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
// Skip tests that demonstrate the error
if tc.expectError {
t.Skip("Skipping test that demonstrates the error we're fixing")
}
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.dataContext)
if err != nil {
t.Fatalf("Failed to execute template: %v", err)
}
result := buf.String()
if result != tc.expectedValue {
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
}
})
}
}
+5 -3
View File
@@ -19,12 +19,12 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
ts.Setup()
tests := []struct {
name string
headers []TemplatedHeader
sessionSetup func(*SessionData)
claims map[string]interface{}
expectedHeaders map[string]string
interceptedHeaders map[string]string
name string
headers []TemplatedHeader
}{
{
name: "Basic Email Header",
@@ -426,9 +426,9 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
ts.Setup()
tests := []struct {
claims map[string]interface{}
name string
headers []TemplatedHeader
claims map[string]interface{}
shouldExecuteCheck bool
}{
{
@@ -577,6 +577,7 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
func createLargeTemplate(size int) string {
template := "{{with .Claims}}"
for i := 0; i < size; i++ {
if i > 0 {
template += ","
}
@@ -590,6 +591,7 @@ func createLargeTemplate(size int) string {
func createLargeClaims(size int) map[string]interface{} {
claims := make(map[string]interface{})
for i := 0; i < size; i++ {
claims["email"] = "largeclaimsuser@example.com" // Add email claim
key := "field" + string(rune('a'+i%26)) + string(rune('0'+i%10))
claims[key] = "value" + string(rune('a'+i%26)) + string(rune('0'+i%10))
}
+400
View File
@@ -0,0 +1,400 @@
package traefikoidc
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"time"
)
// TestTokens provides a comprehensive set of standardized test tokens
// for consistent testing across the entire codebase.
type TestTokens struct{}
// NewTestTokens creates a new TestTokens instance
func NewTestTokens() *TestTokens {
return &TestTokens{}
}
// Valid JWT tokens for testing
const (
// ValidAccessToken - A properly formatted JWT access token for testing
ValidAccessToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTc1MDI5NDYyOCwiaWF0IjoxNzUwMjkxMDI4LCJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJqdGkiOiJlNDcxN2RhZDBmZjAyOTNkIiwibmJmIjoxNzUwMjkxMDI4LCJub25jZSI6Im5vbmNlMTIzIiwic3ViIjoidGVzdC1zdWJqZWN0In0.bmwp-vk0B7Ir9UiUkzib8L7yJbebJ00o3U9QrB6gP2H9-RfqyCbN8M9Rkx7Rb8Vdh3YzqkBBoLS_G0i414rs2I9uABnTC4E6-63qkGdUrLB7p-XbjcRW2RoIBwXHk7lfumi8eX0uWzBsJ9CY0__UECVsex5XORfBb4Bcqj0LK4y-glxkpI51I7BPySfciWC_PkdaQ1Qe5pCAlxeNs2E9NMGXp-Ox6vAufUzoC2cws1LswGPPP6icQ-Zlzd5WMCIWhdIkN4yTxk8FMqsTC52k2zskRHNSSd4DDVETonfzawZNqDcMpnTyN53sCJ9UHiQTl9mCm61ttYW-W9Gc-ze4Xw"
// ValidIDToken - A properly formatted JWT ID token for testing
ValidIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTc1MDI5NDYyOCwiaWF0IjoxNzUwMjkxMDI4LCJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJqdGkiOiI2YzBjZTZmMTM4Y2EzMzc2IiwibmJmIjoxNzUwMjkxMDI4LCJub25jZSI6Im5vbmNlMTIzIiwic3ViIjoidGVzdC1zdWJqZWN0In0.RBQYejA9vP4lnh2EhFqWerePWaCyDTF0ZE1jlU2xm4g2wWVeaEHpv5SNg92_gwk633N9xx7ugS0UrlEu4qbT7wSb1HBDR00q_andyYnyFk4OoxPpD0AqHkVr-pjS-Z7UCGF3sLgQ4ECmU9695PIys3XvgUGMzEn_mK-PHcpY5AnbBGFsbj7epUld_sb6WfjjjwAa8kKfKObPvaIpuJ4TlxI1Uf0wYOoIA0zh5ipeAn-i8Ud-GErxis1Hp8UQK7IRolXpToiXnFcnf3vI3eCS7Yu3oPl7LRxTxKMCI9h0MCwu25ZNsOg2C9ohyebpU0jbURX9Q74GNOaphv-Lz9rCRA"
// ValidRefreshToken - A properly formatted refresh token for testing
ValidRefreshToken = "valid-refresh-token-12345"
// MinimalValidJWT - The shortest valid JWT for testing
MinimalValidJWT = "h.p.s"
// ValidRefreshTokenGoogle - A Google-style refresh token for testing
ValidRefreshTokenGoogle = "google_refresh_token_12345"
)
// Invalid tokens for testing validation
const (
// InvalidTokenNoDots - Token with no dots (invalid JWT format)
InvalidTokenNoDots = "notajwttoken"
// InvalidTokenOneDot - Token with one dot (invalid JWT format)
InvalidTokenOneDot = "header.payload"
// InvalidTokenThreeDots - Token with three dots (invalid JWT format)
InvalidTokenThreeDots = "header.payload.signature.extra"
// EmptyToken - Empty token
EmptyToken = ""
// CorruptedBase64Token - Token with invalid base64 data for chunking tests
CorruptedBase64Token = "corrupted_base64_!@#$"
)
// CreateLargeValidJWT creates a JWT of approximately the specified size
// This replaces the ad-hoc createLargeValidJWT function in tests
func (tt *TestTokens) CreateLargeValidJWT(targetSize int) string {
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
signature := "signature_" + tt.generateRandomString(32)
// Calculate required payload size
usedSize := len(header) + len(signature) + 2 // account for dots
payloadSize := targetSize - usedSize
if payloadSize < 50 {
payloadSize = 50
}
// Create a payload with realistic JWT claims
claims := map[string]interface{}{
"sub": "user123",
"iss": "https://example.com",
"aud": "client123",
"exp": 9999999999,
"iat": 1000000000,
}
dataSize := payloadSize - 100 // Account for other claims and base64 encoding
if dataSize < 10 {
dataSize = 10 // Minimum data size
}
claims["data"] = tt.generateRandomString(dataSize)
claimsJSON, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
}
// CreateLargeRefreshToken creates a refresh token of approximately the specified size
func (tt *TestTokens) CreateLargeRefreshToken(targetSize int) string {
baseToken := "refresh_token_"
padding := tt.generateRandomString(targetSize - len(baseToken))
return baseToken + padding
}
// CreateExpiredJWT creates an expired JWT token for testing
func (tt *TestTokens) CreateExpiredJWT() string {
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
// Create claims with expired timestamp
claims := map[string]interface{}{
"sub": "user123",
"iss": "https://example.com",
"aud": "client123",
"exp": time.Now().Unix() - 3600, // Expired 1 hour ago
"iat": time.Now().Unix() - 7200, // Issued 2 hours ago
}
claimsJSON, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
signature := "expired_signature"
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
}
// CreateUniqueValidJWT creates a unique valid JWT for concurrent testing
func (tt *TestTokens) CreateUniqueValidJWT(id string) string {
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
claims := map[string]interface{}{
"sub": "user_" + id,
"iss": "https://example.com",
"aud": "client123",
"exp": 9999999999,
"iat": 1000000000,
"jti": id,
}
claimsJSON, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
signature := "sig_" + id
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
}
// CreateIncompressibleToken creates a token that cannot be compressed effectively
// This is useful for testing chunking scenarios where compression doesn't help
func (tt *TestTokens) CreateIncompressibleToken(targetSize int) string {
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
signature := "incompressible_signature_" + tt.generateRandomString(32)
// Calculate required payload size
usedSize := len(header) + len(signature) + 2 // account for dots
payloadSize := max(targetSize-usedSize, 100)
// Generate multiple random fields to prevent compression
randomFields := make(map[string]interface{})
randomFields["sub"] = "user123"
randomFields["iss"] = "https://example.com"
randomFields["aud"] = "client123"
randomFields["exp"] = 9999999999
randomFields["iat"] = 1000000000
// Add many random fields with random data to prevent compression
remainingSize := payloadSize - 200 // Account for base64 encoding and other fields
fieldCount := remainingSize / 100 // ~100 bytes per field
if fieldCount < 1 {
fieldCount = 1
}
for i := 0; i < fieldCount; i++ {
// Generate truly random data for each field
randomBytes := make([]byte, 50)
rand.Read(randomBytes)
fieldName := fmt.Sprintf("random_field_%d_%s", i, tt.generateRandomString(8))
randomFields[fieldName] = base64.StdEncoding.EncodeToString(randomBytes)
}
claimsJSON, _ := json.Marshal(randomFields)
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
token := fmt.Sprintf("%s.%s.%s", header, payload, signature)
// If still too small, pad with more random data
if len(token) < targetSize {
padding := targetSize - len(token)
extraRandomBytes := make([]byte, padding/2)
rand.Read(extraRandomBytes)
randomFields["padding"] = base64.StdEncoding.EncodeToString(extraRandomBytes)
claimsJSON, _ = json.Marshal(randomFields)
payload = base64.RawURLEncoding.EncodeToString(claimsJSON)
token = fmt.Sprintf("%s.%s.%s", header, payload, signature)
}
return token
}
// GetValidTokenSet returns a complete set of valid tokens for testing
func (tt *TestTokens) GetValidTokenSet() TokenSet {
return TokenSet{
AccessToken: ValidAccessToken,
IDToken: ValidIDToken,
RefreshToken: ValidRefreshToken,
}
}
// GetGoogleTokenSet returns tokens that simulate Google OIDC provider responses
func (tt *TestTokens) GetGoogleTokenSet() TokenSet {
return TokenSet{
AccessToken: ValidAccessToken,
IDToken: ValidIDToken,
RefreshToken: ValidRefreshTokenGoogle,
}
}
// GetLargeTokenSet returns a set of large tokens for chunking tests
func (tt *TestTokens) GetLargeTokenSet() TokenSet {
return TokenSet{
AccessToken: tt.CreateLargeValidJWT(5000),
IDToken: tt.CreateLargeValidJWT(2000),
RefreshToken: tt.CreateLargeRefreshToken(3000),
}
}
// GetInvalidTokens returns various invalid tokens for validation testing
func (tt *TestTokens) GetInvalidTokens() InvalidTokenSet {
return InvalidTokenSet{
NoDots: InvalidTokenNoDots,
OneDot: InvalidTokenOneDot,
ThreeDots: InvalidTokenThreeDots,
Empty: EmptyToken,
Corrupted: CorruptedBase64Token,
}
}
// generateRandomString creates a random string of the specified length
func (tt *TestTokens) generateRandomString(length int) string {
// FIXED: Handle negative or zero lengths safely
if length <= 0 {
return ""
}
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := 0; i < length; i++ {
randomByte := make([]byte, 1)
rand.Read(randomByte)
b[i] = charset[int(randomByte[0])%len(charset)]
}
return string(b)
}
// TokenSet represents a complete set of tokens for testing
type TokenSet struct {
AccessToken string
IDToken string
RefreshToken string
}
// InvalidTokenSet represents various invalid tokens for validation testing
type InvalidTokenSet struct {
NoDots string // Token with 0 dots
OneDot string // Token with 1 dot
ThreeDots string // Token with 3 dots
Empty string // Empty token
Corrupted string // Corrupted/invalid characters
}
// TestScenarios provides predefined test scenarios
type TestScenarios struct {
tokens *TestTokens
}
// NewTestScenarios creates a new TestScenarios instance
func NewTestScenarios() *TestScenarios {
return &TestScenarios{
tokens: NewTestTokens(),
}
}
// NormalFlow returns tokens for normal authentication flow testing
func (ts *TestScenarios) NormalFlow() TokenSet {
return ts.tokens.GetValidTokenSet()
}
// GoogleFlow returns tokens simulating Google OIDC provider
func (ts *TestScenarios) GoogleFlow() TokenSet {
return ts.tokens.GetGoogleTokenSet()
}
// ChunkingRequired returns large tokens that require chunking
func (ts *TestScenarios) ChunkingRequired() TokenSet {
return ts.tokens.GetLargeTokenSet()
}
// CorruptionTest returns tokens and corruption scenarios for testing
func (ts *TestScenarios) CorruptionTest() CorruptionTestSet {
return CorruptionTestSet{
ValidTokens: ts.tokens.GetValidTokenSet(),
InvalidTokens: ts.tokens.GetInvalidTokens(),
LargeTokens: ts.tokens.GetLargeTokenSet(),
CorruptedToken: CorruptedBase64Token,
}
}
// ConcurrentTest returns unique tokens for concurrent testing
func (ts *TestScenarios) ConcurrentTest(count int) []TokenSet {
sets := make([]TokenSet, count)
for i := 0; i < count; i++ {
sets[i] = TokenSet{
AccessToken: ts.tokens.CreateUniqueValidJWT(fmt.Sprintf("concurrent_%d", i)),
IDToken: ts.tokens.CreateUniqueValidJWT(fmt.Sprintf("id_%d", i)),
RefreshToken: fmt.Sprintf("refresh_concurrent_%d", i),
}
}
return sets
}
// CorruptionTestSet represents tokens and scenarios for corruption testing
type CorruptionTestSet struct {
ValidTokens TokenSet
InvalidTokens InvalidTokenSet
LargeTokens TokenSet
CorruptedToken string
}
// TokenValidationTestCases returns test cases for token validation
func (tt *TestTokens) TokenValidationTestCases() []ValidationTestCase {
return []ValidationTestCase{
{
Name: "Empty token",
Token: EmptyToken,
ExpectStored: true, // Empty tokens are allowed for clearing
ExpectRetrieved: false, // But return as empty
},
{
Name: "Single dot",
Token: InvalidTokenOneDot,
ExpectStored: false, // Invalid JWT format
ExpectRetrieved: false,
},
{
Name: "No dots",
Token: InvalidTokenNoDots,
ExpectStored: false, // Invalid JWT format
ExpectRetrieved: false,
},
{
Name: "Too many dots",
Token: InvalidTokenThreeDots,
ExpectStored: false, // Invalid JWT format
ExpectRetrieved: false,
},
{
Name: "Valid minimal JWT",
Token: MinimalValidJWT,
ExpectStored: true,
ExpectRetrieved: true,
},
{
Name: "Valid standard JWT",
Token: ValidAccessToken,
ExpectStored: true,
ExpectRetrieved: true,
},
}
}
// ValidationTestCase represents a single token validation test case
type ValidationTestCase struct {
Name string
Token string
ExpectStored bool
ExpectRetrieved bool
}
// Helper functions for common test patterns
// AssertValidTokenStorage verifies that a valid token can be stored and retrieved
func AssertValidTokenStorage(t TestingInterface, session *SessionData, token string) {
session.SetAccessToken(token)
retrieved := session.GetAccessToken()
if retrieved != token {
t.Errorf("Token storage failed: expected %q, got %q", token, retrieved)
}
}
// AssertInvalidTokenRejection verifies that an invalid token is rejected
func AssertInvalidTokenRejection(t TestingInterface, session *SessionData, token string) {
original := session.GetAccessToken()
session.SetAccessToken(token)
after := session.GetAccessToken()
if after != original {
t.Errorf("Invalid token was not rejected: expected %q, got %q", original, after)
}
}
// TestingInterface provides the minimal interface needed for testing
type TestingInterface interface {
Errorf(format string, args ...interface{})
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
+507
View File
@@ -0,0 +1,507 @@
package traefikoidc
import (
"bytes"
"compress/gzip"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/gorilla/sessions"
)
// TestTokenCorruptionScenario reproduces the exact failure pattern from GitHub issue #53:
// Token verified successfully multiple times, then fails with "signature verification failed"
func TestTokenCorruptionScenario(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// Create a valid JWT token
validJWT := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImV4cCI6OTk5OTk5OTk5OX0.signature"
tests := []struct {
corruptionScenario func(*SessionData)
name string
tokenSize int
iterations int
expectConsistent bool
}{
{
name: "Small token - multiple retrievals",
tokenSize: len(validJWT),
iterations: 10,
expectConsistent: true,
},
{
name: "Large chunked token - multiple retrievals",
tokenSize: 5000,
iterations: 10,
expectConsistent: true,
},
{
name: "Compression corruption simulation",
tokenSize: 2000,
iterations: 5,
expectConsistent: false, // Will be corrupted intentionally
corruptionScenario: func(session *SessionData) {
// Simulate corruption by directly modifying session values
if session.accessSession != nil {
// Simulate corrupted compressed data
session.accessSession.Values["token"] = "corrupted_base64_!@#$"
session.accessSession.Values["compressed"] = true
}
},
},
{
name: "Chunk reassembly corruption simulation",
tokenSize: 25000, // Large enough to force chunking even after compression
iterations: 5,
expectConsistent: false, // Will be corrupted intentionally
corruptionScenario: func(session *SessionData) {
// Simulate chunk corruption with invalid base64 characters
if len(session.accessTokenChunks) > 0 {
if chunk, exists := session.accessTokenChunks[0]; exists {
chunk.Values["token_chunk"] = "invalid_base64_!@#$%"
}
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Create token of specified size
token := createTokenOfSize(validJWT, tt.tokenSize)
// 1. Store the token
session.SetAccessToken(token)
t.Logf("Stored token of size %d bytes", len(token))
// 2. Verify token can be retrieved multiple times successfully
var retrievedTokens []string
for i := 0; i < tt.iterations; i++ {
retrieved := session.GetAccessToken()
retrievedTokens = append(retrievedTokens, retrieved)
if tt.expectConsistent && retrieved != token {
t.Errorf("Iteration %d: Token mismatch, expected consistency", i)
break
}
}
// 3. Apply corruption scenario if specified
if tt.corruptionScenario != nil {
tt.corruptionScenario(session)
}
// 4. Retrieve token after potential corruption
finalRetrieved := session.GetAccessToken()
if tt.expectConsistent {
// With fixes, token should still be retrievable correctly
if finalRetrieved != token {
t.Errorf("Final retrieval failed - corruption not handled correctly")
t.Logf("Expected: %q", token)
t.Logf("Got: %q", finalRetrieved)
}
} else {
// For corruption scenarios, expect empty string (graceful failure)
if finalRetrieved != "" {
t.Errorf("Expected corruption to result in empty token, got: %q", finalRetrieved)
}
}
// 5. Verify all previous retrievals were consistent (if expected)
if tt.expectConsistent {
for i, retrieved := range retrievedTokens {
if retrieved != token {
t.Errorf("Iteration %d produced inconsistent result", i)
}
}
}
})
}
}
// TestCompressionIntegrityFailure tests scenarios where compression fails integrity checks
func TestCompressionIntegrityFailure(t *testing.T) {
tests := []struct {
name string
token string
expectSame bool
}{
{
name: "Valid JWT",
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig",
expectSame: true,
},
{
name: "Invalid JWT - wrong dots",
token: "invalid.token",
expectSame: true, // Should return unchanged
},
{
name: "Oversized token",
token: "header." + strings.Repeat("A", 60000) + ".sig",
expectSame: true, // Should return unchanged due to size limit
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compressed := compressToken(tt.token)
if tt.expectSame && compressed != tt.token {
// If we expect the token to remain the same but it was compressed,
// verify round-trip integrity
decompressed := decompressToken(compressed)
if decompressed != tt.token {
t.Errorf("Compression integrity failed: original=%q, decompressed=%q", tt.token, decompressed)
}
}
})
}
}
// TestChunkReassemblyEdgeCases tests edge cases in chunk reassembly that could cause corruption
func TestChunkReassemblyEdgeCases(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Create a large token that will definitely be chunked
largeToken := createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig", 8000)
// Store the token to create chunks
session.SetAccessToken(largeToken)
if len(session.accessTokenChunks) == 0 {
t.Skip("Token was not chunked, skipping reassembly tests")
}
t.Logf("Token was split into %d chunks", len(session.accessTokenChunks))
// Test various corruption scenarios
corruptionTests := []struct {
corruption func(map[int]*sessions.Session)
name string
expectEmpty bool
}{
{
name: "Gap in chunk sequence",
corruption: func(chunks map[int]*sessions.Session) {
// Remove chunk 1 if it exists
delete(chunks, 1)
},
expectEmpty: true,
},
{
name: "Chunk with nil value",
corruption: func(chunks map[int]*sessions.Session) {
if chunk, exists := chunks[0]; exists {
chunk.Values["token_chunk"] = nil
}
},
expectEmpty: true,
},
{
name: "Chunk with wrong type",
corruption: func(chunks map[int]*sessions.Session) {
if chunk, exists := chunks[0]; exists {
chunk.Values["token_chunk"] = 12345 // Should be string
}
},
expectEmpty: true,
},
{
name: "Empty chunk data",
corruption: func(chunks map[int]*sessions.Session) {
if chunk, exists := chunks[0]; exists {
chunk.Values["token_chunk"] = ""
}
},
expectEmpty: true,
},
{
name: "Excessive chunk count",
corruption: func(chunks map[int]*sessions.Session) {
// This test simulates having too many chunks (>50 limit)
// We'll create a scenario by adding many fake chunks
for i := 0; i < 60; i++ {
fakeSession := &sessions.Session{Values: make(map[interface{}]interface{})}
fakeSession.Values["token_chunk"] = "fake_chunk_data"
chunks[i] = fakeSession
}
},
expectEmpty: true,
},
}
for _, ct := range corruptionTests {
t.Run(ct.name, func(t *testing.T) {
// Get a fresh session for each test
freshReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
freshSession, err := sm.GetSession(freshReq)
if err != nil {
t.Fatalf("Failed to get fresh session: %v", err)
}
defer freshSession.ReturnToPool()
// Store the large token again
freshSession.SetAccessToken(largeToken)
// Apply corruption
ct.corruption(freshSession.accessTokenChunks)
// Try to retrieve the token
retrieved := freshSession.GetAccessToken()
if ct.expectEmpty {
if retrieved != "" {
t.Errorf("Expected empty token due to corruption, got: %q", retrieved)
}
} else {
if retrieved != largeToken {
t.Errorf("Expected original token, got: %q", retrieved)
}
}
})
}
}
// TestRaceConditionProtection tests that concurrent access doesn't cause corruption
func TestRaceConditionProtection(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
const numGoroutines = 20
const numOperations = 50
// Create tokens of different sizes
tokens := []string{
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig1",
createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig2", 3000),
createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig3", 6000),
}
var wg sync.WaitGroup
errChan := make(chan error, numGoroutines*numOperations)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
tokenIndex := (goroutineID + j) % len(tokens)
expectedToken := tokens[tokenIndex]
// Set token
session.SetAccessToken(expectedToken)
// Retrieve token
retrieved := session.GetAccessToken()
// Verify it's a valid JWT (should have exactly 2 dots)
if retrieved != "" && strings.Count(retrieved, ".") != 2 {
errChan <- fmt.Errorf("goroutine %d, op %d: invalid JWT format in retrieved token: %q",
goroutineID, j, retrieved)
continue
}
// The retrieved token should be one of the valid tokens we set
// (due to concurrent access, it might not be the exact one we just set)
isValidToken := false
for _, validToken := range tokens {
if retrieved == validToken {
isValidToken = true
break
}
}
if retrieved != "" && !isValidToken {
errChan <- fmt.Errorf("goroutine %d, op %d: retrieved unknown token: %q",
goroutineID, j, retrieved)
}
}
}(i)
}
wg.Wait()
close(errChan)
// Check for any errors
for err := range errChan {
t.Error(err)
}
}
// TestMemoryExhaustionProtection tests protection against memory exhaustion attacks
func TestMemoryExhaustionProtection(t *testing.T) {
tests := []struct {
setupCorruption func() string
name string
expectRejection bool
}{
{
name: "Extremely large compressed data",
setupCorruption: func() string {
return base64.StdEncoding.EncodeToString(bytes.Repeat([]byte("A"), 200*1024)) // 200KB
},
expectRejection: true,
},
{
name: "Malformed gzip bomb attempt",
setupCorruption: func() string {
// Create data that looks like gzip but would decompress to huge size
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
gz.Write(bytes.Repeat([]byte("A"), 10*1024)) // 10KB that compresses well
gz.Close()
compressed := buf.Bytes()
// Modify to make it potentially dangerous
return base64.StdEncoding.EncodeToString(compressed)
},
expectRejection: false, // Our decompression has size limits
},
{
name: "Token with excessive chunk simulation",
setupCorruption: func() string {
// This will be tested in the session layer
return strings.Repeat("chunk.", 100) + "final"
},
expectRejection: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
corruptedData := tt.setupCorruption()
result := decompressToken(corruptedData)
if tt.expectRejection {
// Should return original corrupted data, not attempt decompression
if result != corruptedData {
t.Errorf("Expected rejection of dangerous data, but decompression was attempted")
}
}
// Verify no excessive memory was used (this test would catch OOM in practice)
// The fact that we reach this point means memory limits were effective
})
}
}
// TestBackwardCompatibility ensures that sessions created before the fixes still work
func TestBackwardCompatibility(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Simulate old-style session data (without new validation fields)
oldStyleToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.oldsig"
// Manually set token without going through new SetAccessToken validation
session.accessSession.Values["token"] = oldStyleToken
session.accessSession.Values["compressed"] = false
// Should still be retrievable
retrieved := session.GetAccessToken()
if retrieved != oldStyleToken {
t.Errorf("Backward compatibility failed: expected %q, got %q", oldStyleToken, retrieved)
}
// Test with simulated old compressed token
oldCompressed := compressToken(oldStyleToken)
session.accessSession.Values["token"] = oldCompressed
session.accessSession.Values["compressed"] = true
retrieved2 := session.GetAccessToken()
if retrieved2 != oldStyleToken {
t.Errorf("Backward compatibility with compression failed: expected %q, got %q", oldStyleToken, retrieved2)
}
}
// createTokenOfSize creates a JWT token of approximately the specified size
func createTokenOfSize(baseToken string, targetSize int) string {
parts := strings.Split(baseToken, ".")
if len(parts) != 3 {
return baseToken
}
header, payload, signature := parts[0], parts[1], parts[2]
currentSize := len(baseToken)
if currentSize >= targetSize {
return baseToken
}
// Expand the payload to reach target size
paddingNeeded := targetSize - len(header) - len(signature) - 2 // Account for dots
if paddingNeeded > 0 {
// Decode current payload, add padding, re-encode
decoded, err := base64.RawURLEncoding.DecodeString(payload)
if err != nil {
// If we can't decode, just pad with random base64-safe characters to resist compression
randomBytes := make([]byte, paddingNeeded)
rand.Read(randomBytes)
// Encode as base64 to make it base64-safe
padData := base64.RawURLEncoding.EncodeToString(randomBytes)
payload = payload + padData
} else {
// Add padding to the JSON - use random data to resist compression
randomBytes := make([]byte, paddingNeeded/2)
rand.Read(randomBytes)
// Encode as base64 to make it JSON-safe
padData := base64.StdEncoding.EncodeToString(randomBytes)
newPayload := fmt.Sprintf(`{"original":%s,"padding":"%s"}`, string(decoded), padData)
payload = base64.RawURLEncoding.EncodeToString([]byte(newPayload))
}
}
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
}
+368 -5
View File
@@ -2,8 +2,12 @@ package traefikoidc
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"text/template"
"time"
@@ -15,10 +19,10 @@ import (
func TestTokenTypeDistinction(t *testing.T) {
// Define test data where AccessToken and IdToken are deliberately different
type templateData struct {
Claims map[string]interface{}
AccessToken string
IdToken string
RefreshToken string
Claims map[string]interface{}
}
testData := templateData{
@@ -257,10 +261,10 @@ func TestSessionIDTokenAccessToken(t *testing.T) {
t.Fatalf("Failed to get session: %v", err)
}
// Set test tokens
idToken := "test-id-token-123"
accessToken := "test-access-token-456"
refreshToken := "test-refresh-token-789"
// Set test tokens using standardized tokens
idToken := ValidIDToken
accessToken := ValidAccessToken
refreshToken := ValidRefreshToken
// Store tokens in session
session.SetIDToken(idToken)
@@ -309,3 +313,362 @@ func TestSessionIDTokenAccessToken(t *testing.T) {
t.Errorf("ID token and Access token should be different, but both are %q", retrievedIDToken)
}
}
// TestTokenCorruptionIntegrationFlows tests the complete token handling flow with corruption scenarios
func TestTokenCorruptionIntegrationFlows(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
tests := []struct {
corruptAction func(*SessionData)
name string
accessToken string
refreshToken string
idToken string
expectSuccess bool
}{
{
name: "Normal flow - small tokens",
accessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.access_sig",
refreshToken: "refresh_token_12345",
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
expectSuccess: true,
},
{
name: "Normal flow - large tokens (chunked)",
accessToken: createLargeValidJWT(5000),
refreshToken: createLargeRefreshToken(3000),
idToken: createLargeValidJWT(2000),
expectSuccess: true,
},
{
name: "Corrupted access token compression",
accessToken: createLargeValidJWT(3000),
refreshToken: "refresh_token_12345",
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
expectSuccess: false,
corruptAction: func(session *SessionData) {
// Corrupt compressed access token
if session.accessSession != nil {
session.accessSession.Values["token"] = "corrupted_compressed_data_!@#"
session.accessSession.Values["compressed"] = true
}
},
},
{
name: "Corrupted chunk in large token",
accessToken: createLargeValidJWT(8000), // Force chunking
refreshToken: "refresh_token_12345",
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
expectSuccess: false,
corruptAction: func(session *SessionData) {
// Corrupt first chunk
if len(session.accessTokenChunks) > 0 {
if chunk, exists := session.accessTokenChunks[0]; exists {
chunk.Values["token_chunk"] = "corrupted_chunk_data"
}
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
// Get session
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Store tokens
session.SetAccessToken(tt.accessToken)
session.SetRefreshToken(tt.refreshToken)
session.SetIDToken(tt.idToken)
session.SetAuthenticated(true)
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Apply corruption if specified
if tt.corruptAction != nil {
tt.corruptAction(session)
}
// Test token retrieval after corruption
retrievedAccess := session.GetAccessToken()
retrievedRefresh := session.GetRefreshToken()
retrievedID := session.GetIDToken()
if tt.expectSuccess {
if retrievedAccess != tt.accessToken {
t.Errorf("Access token corruption: expected %q, got %q", tt.accessToken, retrievedAccess)
}
if retrievedRefresh != tt.refreshToken {
t.Errorf("Refresh token corruption: expected %q, got %q", tt.refreshToken, retrievedRefresh)
}
if retrievedID != tt.idToken {
t.Errorf("ID token corruption: expected %q, got %q", tt.idToken, retrievedID)
}
} else {
// For corruption scenarios, access token should be empty (graceful failure)
if retrievedAccess != "" {
t.Errorf("Expected corrupted access token to return empty, got: %q", retrievedAccess)
}
// Other tokens should still work
if retrievedRefresh != tt.refreshToken {
t.Errorf("Refresh token should not be affected by access token corruption: expected %q, got %q",
tt.refreshToken, retrievedRefresh)
}
}
})
}
}
// TestSessionPersistenceWithCorruption tests that session corruption is handled across requests
func TestSessionPersistenceWithCorruption(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
// First request - store tokens
req1 := httptest.NewRequest("GET", "/test", nil)
rr1 := httptest.NewRecorder()
session1, err := sm.GetSession(req1)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
largeToken := createLargeValidJWT(6000)
session1.SetAccessToken(largeToken)
session1.SetAuthenticated(true)
if err := session1.Save(req1, rr1); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get cookies from first response
cookies := rr1.Result().Cookies()
session1.ReturnToPool()
// Second request - retrieve tokens with cookies
req2 := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
req2.AddCookie(cookie)
}
session2, err := sm.GetSession(req2)
if err != nil {
t.Fatalf("Failed to get session from cookies: %v", err)
}
defer session2.ReturnToPool()
// Verify token can be retrieved
retrieved := session2.GetAccessToken()
if retrieved != largeToken {
t.Errorf("Token persistence failed: expected %q, got %q", largeToken, retrieved)
}
// Simulate corruption by modifying chunks
if len(session2.accessTokenChunks) > 0 {
// Corrupt a middle chunk
chunkIndex := len(session2.accessTokenChunks) / 2
if chunk, exists := session2.accessTokenChunks[chunkIndex]; exists {
chunk.Values["token_chunk"] = "corrupted"
}
// Try to retrieve again - should detect corruption and return empty
retrievedAfterCorruption := session2.GetAccessToken()
if retrievedAfterCorruption != "" {
t.Errorf("Expected corruption to be detected, but got token: %q", retrievedAfterCorruption)
}
}
}
// TestConcurrentTokenOperationsWithCorruption tests concurrent access with intentional corruption
func TestConcurrentTokenOperationsWithCorruption(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "/test", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
const numGoroutines = 10
const numOperations = 20
done := make(chan bool, numGoroutines)
errorChan := make(chan error, numGoroutines*numOperations)
// Start concurrent operations
for i := 0; i < numGoroutines; i++ {
go func(goroutineID int) {
defer func() { done <- true }()
for j := 0; j < numOperations; j++ {
// Create a unique valid token for each operation
token := fmt.Sprintf("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwib3AiOiIxMjMifQ.sig_%d_%d",
goroutineID, j)
// Store token
session.SetAccessToken(token)
// Retrieve token
retrieved := session.GetAccessToken()
// Validate retrieved token format
if retrieved != "" {
if strings.Count(retrieved, ".") != 2 {
errorChan <- fmt.Errorf("goroutine %d, op %d: invalid JWT format: %q",
goroutineID, j, retrieved)
continue
}
// Check if it's a reasonable length
if len(retrieved) < 10 || len(retrieved) > 100000 {
errorChan <- fmt.Errorf("goroutine %d, op %d: suspicious token length %d: %q",
goroutineID, j, len(retrieved), retrieved)
}
}
// Occasionally simulate corruption to test error handling
if j%5 == 0 && len(session.accessTokenChunks) > 0 {
// Intentionally corrupt a random chunk
for chunkID, chunk := range session.accessTokenChunks {
if chunkID%2 == 0 {
chunk.Values["token_chunk"] = "intentionally_corrupted"
break
}
}
}
}
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < numGoroutines; i++ {
<-done
}
close(errorChan)
// Check for any unexpected errors
errorCount := 0
for err := range errorChan {
t.Logf("Concurrent operation error: %v", err)
errorCount++
}
// We expect some corruption-related "errors" due to intentional corruption,
// but not format-related errors which would indicate actual corruption bugs
if errorCount > numGoroutines*numOperations/4 { // Allow up to 25% corruption-related issues
t.Errorf("Too many errors during concurrent operations: %d", errorCount)
}
}
// TestTokenValidationEdgeCases tests edge cases in token validation
func TestTokenValidationEdgeCases(t *testing.T) {
logger := NewLogger("debug")
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
req := httptest.NewRequest("GET", "/test", nil)
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
defer session.ReturnToPool()
// Use standardized test tokens
testTokens := NewTestTokens()
edgeCases := testTokens.TokenValidationTestCases()
for _, ec := range edgeCases {
t.Run(ec.Name, func(t *testing.T) {
// Clear any previous token
session.SetAccessToken("")
// Store the test token
originalToken := session.GetAccessToken()
session.SetAccessToken(ec.Token)
afterStoreToken := session.GetAccessToken()
if ec.ExpectStored {
if afterStoreToken != ec.Token {
t.Errorf("Expected token to be stored, but got different value")
}
} else {
if afterStoreToken != originalToken {
t.Errorf("Expected invalid token to be rejected, but it was stored")
}
}
// Test retrieval
finalToken := session.GetAccessToken()
if ec.ExpectRetrieved {
if finalToken != ec.Token {
t.Errorf("Expected token to be retrievable: %q, got: %q", ec.Token, finalToken)
}
} else {
if finalToken != "" {
t.Errorf("Expected empty token due to invalid format, got: %q", finalToken)
}
}
})
}
}
// Helper functions for test data creation
// createLargeValidJWT creates a JWT of approximately the specified size
func createLargeValidJWT(targetSize int) string {
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
signature := "signature_" + generateRandomString(32)
// Calculate required payload size
usedSize := len(header) + len(signature) + 2 // account for dots
payloadSize := targetSize - usedSize
if payloadSize < 50 {
payloadSize = 50
}
// Create a payload with realistic JWT claims
claims := map[string]interface{}{
"sub": "user123",
"iss": "https://example.com",
"aud": "client123",
"exp": 9999999999,
"iat": 1000000000,
"data": generateRandomString(payloadSize - 100), // Account for other claims
}
claimsJSON, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
}
// createLargeRefreshToken creates a refresh token of approximately the specified size
func createLargeRefreshToken(targetSize int) string {
baseToken := "refresh_token_"
padding := generateRandomString(targetSize - len(baseToken))
return baseToken + padding
}