mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
11 Commits
v0.7.8
...
v0.6.2-beta6
| Author | SHA1 | Date | |
|---|---|---|---|
| 70443f0855 | |||
| 7a443c626c | |||
| 48de8265c5 | |||
| d8d1b74175 | |||
| c233aa92ef | |||
| c400251625 | |||
| 48faf7fadf | |||
| 84d7cd3d76 | |||
| 488264028b | |||
| e23135ded0 | |||
| cd307f88a1 |
+9
-8
@@ -35,11 +35,8 @@ testData:
|
||||
logoutURL: /oauth2/logout # Path for handling logout requests (if not provided, it will be set to callbackURL + "/logout")
|
||||
postLogoutRedirectURI: /oidc/different-logout # URL to redirect to after logout (default: "/")
|
||||
|
||||
scopes: # OAuth 2.0 scopes to request (default: ["openid", "email", "profile"])
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Include this to get role information from the provider
|
||||
scopes: # Additional scopes to append to defaults ["openid", "profile", "email"]
|
||||
- roles # Result: ["openid", "profile", "email", "roles"]
|
||||
|
||||
allowedUserDomains: # Restricts access to specific email domains (if not provided, relies on OIDC provider)
|
||||
- company.com
|
||||
@@ -153,11 +150,15 @@ configuration:
|
||||
scopes:
|
||||
type: array
|
||||
description: |
|
||||
The OAuth 2.0 scopes to request from the OIDC provider.
|
||||
Default: ["openid", "profile", "email"]
|
||||
Additional OAuth 2.0 scopes to append to the default scopes.
|
||||
Default scopes are always included: ["openid", "profile", "email"]
|
||||
|
||||
User-provided scopes are appended to defaults with automatic deduplication.
|
||||
For example, specifying ["roles", "custom_scope"] results in:
|
||||
["openid", "profile", "email", "roles", "custom_scope"]
|
||||
|
||||
Include "roles" or similar scope if you need role/group information.
|
||||
Note: For Google OAuth, the middleware automatically handles the
|
||||
Note: For Google OAuth, the middleware automatically handles the
|
||||
proper authentication parameters and does NOT require the "offline_access"
|
||||
scope (which Google rejects as invalid). See documentation for details.
|
||||
required: false
|
||||
|
||||
@@ -67,7 +67,8 @@ The middleware supports the following configuration options:
|
||||
|-----------|-------------|---------|---------|
|
||||
| `logoutURL` | The path for handling logout requests | `callbackURL + "/logout"` | `/oauth2/logout` |
|
||||
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
|
||||
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
|
||||
| `scopes` | OAuth 2.0 scopes to use for authentication | `["openid", "profile", "email"]` (always included by default) | `["roles", "custom_scope"]` (appended to defaults) |
|
||||
| `overrideScopes` | When true, replaces default scopes with provided scopes instead of appending | `false` | `true` (use only the scopes explicitly provided) |
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
@@ -81,6 +82,79 @@ The middleware supports the following configuration options:
|
||||
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
|
||||
## Scope Configuration
|
||||
|
||||
### Scope Behavior
|
||||
|
||||
The middleware supports two modes for handling OAuth 2.0 scopes, controlled by the `overrideScopes` parameter:
|
||||
|
||||
#### Default Append Mode (`overrideScopes: false`)
|
||||
|
||||
By default, the middleware uses an **append** behavior for OAuth 2.0 scopes:
|
||||
|
||||
- **Default scopes** are always included: `["openid", "profile", "email"]`
|
||||
- **User-provided scopes** are appended to the defaults with automatic deduplication
|
||||
- The final scope list maintains the order: defaults first, then user scopes
|
||||
|
||||
#### Override Mode (`overrideScopes: true`)
|
||||
|
||||
When `overrideScopes` is set to `true`, the middleware uses **replacement** behavior:
|
||||
|
||||
- Default scopes are **not** automatically included
|
||||
- Only the scopes explicitly provided in the `scopes` field are used
|
||||
- You must include all required scopes explicitly, including `openid` if needed
|
||||
|
||||
### Examples:
|
||||
|
||||
**Default behavior (no custom scopes):**
|
||||
```yaml
|
||||
# No scopes field specified
|
||||
# Result: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
**Default append behavior:**
|
||||
```yaml
|
||||
scopes:
|
||||
- roles
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
|
||||
```
|
||||
|
||||
**Overlapping scopes with append (automatic deduplication):**
|
||||
```yaml
|
||||
scopes:
|
||||
- openid # Duplicate - will be deduplicated
|
||||
- roles
|
||||
- profile # Duplicate - will be deduplicated
|
||||
- permissions
|
||||
# Result: ["openid", "profile", "email", "roles", "permissions"]
|
||||
```
|
||||
|
||||
**Using override mode:**
|
||||
```yaml
|
||||
overrideScopes: true
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "custom_scope"]
|
||||
```
|
||||
|
||||
**Empty scopes list with default behavior:**
|
||||
```yaml
|
||||
scopes: []
|
||||
# Result: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
**Empty scopes list with override mode:**
|
||||
```yaml
|
||||
overrideScopes: true
|
||||
scopes: []
|
||||
# Result: [] (Warning: empty scopes may cause authentication to fail)
|
||||
```
|
||||
|
||||
The default append behavior ensures essential OIDC scopes are always present, while the override mode gives you complete control over the exact scopes requested from the provider.
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Configuration
|
||||
@@ -101,9 +175,7 @@ spec:
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
### With Excluded URLs (Public Access Paths)
|
||||
@@ -124,9 +196,7 @@ spec:
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
excludedURLs:
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /public-data
|
||||
@@ -152,9 +222,7 @@ spec:
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
@@ -178,9 +246,7 @@ spec:
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedUsers:
|
||||
- user1@example.com
|
||||
- user2@another.org
|
||||
@@ -204,9 +270,7 @@ spec:
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
allowedUsers:
|
||||
@@ -239,10 +303,7 @@ spec:
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Include this to get role information from the provider
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
@@ -269,9 +330,7 @@ spec:
|
||||
rateLimit: 500 # Requests per second (default: 100)
|
||||
forceHTTPS: false # Default is true for security
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
### With Custom Post-Logout Redirect
|
||||
@@ -293,9 +352,7 @@ spec:
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /logged-out-page # Where to redirect after logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
### With Templated Headers
|
||||
@@ -316,10 +373,7 @@ spec:
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
@@ -352,9 +406,7 @@ spec:
|
||||
logoutURL: /oauth2/logout
|
||||
enablePKCE: true # Enables PKCE for added security
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
### Google OIDC Configuration Example
|
||||
@@ -377,9 +429,7 @@ spec:
|
||||
callbackURL: /oauth2/callback # Adjust if needed
|
||||
logoutURL: /oauth2/logout # Optional: Adjust if needed
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
# Note: DO NOT manually add offline_access scope for Google
|
||||
# The middleware automatically handles Google-specific requirements
|
||||
refreshGracePeriodSeconds: 300 # Optional: Start refresh 5 min before expiry (default 60)
|
||||
@@ -408,9 +458,7 @@ spec:
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
Don't forget to create the secret:
|
||||
@@ -509,9 +557,7 @@ http:
|
||||
postLogoutRedirectURI: /logged-out-page
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
allowedUsers:
|
||||
|
||||
@@ -2,6 +2,57 @@ package traefikoidc
|
||||
|
||||
import "time"
|
||||
|
||||
// BackgroundTask represents a recurring task that runs in the background
|
||||
type BackgroundTask struct {
|
||||
stopChan chan struct{}
|
||||
taskFunc func()
|
||||
logger *Logger
|
||||
name string
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewBackgroundTask creates a new background task
|
||||
func NewBackgroundTask(name string, interval time.Duration, taskFunc func(), logger *Logger) *BackgroundTask {
|
||||
return &BackgroundTask{
|
||||
name: name,
|
||||
interval: interval,
|
||||
stopChan: make(chan struct{}),
|
||||
taskFunc: taskFunc,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the background task execution
|
||||
func (bt *BackgroundTask) Start() {
|
||||
go bt.run()
|
||||
}
|
||||
|
||||
// Stop terminates the background task
|
||||
func (bt *BackgroundTask) Stop() {
|
||||
close(bt.stopChan)
|
||||
}
|
||||
|
||||
// run is the main execution loop for the background task
|
||||
func (bt *BackgroundTask) run() {
|
||||
ticker := time.NewTicker(bt.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
bt.logger.Debug("Starting background task: %s", bt.name)
|
||||
|
||||
// Run task immediately on startup
|
||||
bt.taskFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
bt.taskFunc()
|
||||
case <-bt.stopChan:
|
||||
bt.logger.Debug("Stopping background task: %s", bt.name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// autoCleanupRoutine periodically calls the provided cleanup function.
|
||||
// It starts a ticker with the given interval and executes the cleanup function
|
||||
// on each tick. The routine stops gracefully when a signal is received on the
|
||||
@@ -12,6 +63,8 @@ import "time"
|
||||
// - interval: The time duration between cleanup calls.
|
||||
// - stop: A channel used to signal the routine to stop. Receiving any value will terminate the loop.
|
||||
// - cleanup: The function to call periodically for cleanup tasks.
|
||||
//
|
||||
// Deprecated: Use BackgroundTask instead.
|
||||
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -0,0 +1,387 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// mockTraefikOidc extends TraefikOidc to override JWT verification for testing
|
||||
type mockTraefikOidc struct {
|
||||
*TraefikOidc
|
||||
}
|
||||
|
||||
// Override VerifyToken to avoid JWKS lookup in tests
|
||||
func (m *mockTraefikOidc) VerifyToken(token string) error {
|
||||
// Cache test claims to avoid "claims not found" errors
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
m.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil // Always succeed for testing
|
||||
}
|
||||
|
||||
// Override VerifyJWTSignatureAndClaims to avoid JWKS lookup in tests
|
||||
func (m *mockTraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
// Cache test claims to avoid "claims not found" errors
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
m.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil // Always succeed for testing
|
||||
}
|
||||
|
||||
func TestAzureOIDCRegression(t *testing.T) {
|
||||
// Create a mocked TraefikOidc instance configured for Azure AD
|
||||
mockLogger := NewLogger("debug")
|
||||
|
||||
// Configure for Azure AD provider
|
||||
baseOidc := &TraefikOidc{
|
||||
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
authURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/authorize",
|
||||
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
|
||||
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
refreshGracePeriod: 60 * time.Second,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Add rate limiter
|
||||
logger: mockLogger,
|
||||
httpClient: createDefaultHTTPClient(), // Add HTTP client
|
||||
jwkCache: &JWKCache{}, // Add JWK cache
|
||||
tokenCache: NewTokenCache(),
|
||||
tokenBlacklist: NewCache(),
|
||||
allowedUserDomains: make(map[string]struct{}),
|
||||
allowedUsers: make(map[string]struct{}),
|
||||
allowedRolesAndGroups: make(map[string]struct{}),
|
||||
excludedURLs: make(map[string]struct{}),
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
|
||||
// Create the mock wrapper
|
||||
tOidc := &mockTraefikOidc{TraefikOidc: baseOidc}
|
||||
|
||||
// Initialize session manager
|
||||
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, mockLogger)
|
||||
tOidc.sessionManager = sessionManager
|
||||
|
||||
// Mock the JWT verification to avoid JWKS lookup issues
|
||||
tOidc.tokenVerifier = &mockTokenVerifier{
|
||||
verifyFunc: func(token string) error {
|
||||
// For test tokens, always return success and cache claims
|
||||
if strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
|
||||
// Cache test claims for JWT tokens
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil
|
||||
}
|
||||
// For opaque tokens (non-JWT format), return success
|
||||
if !strings.Contains(token, ".") || strings.Count(token, ".") != 2 {
|
||||
return nil
|
||||
}
|
||||
// For JWT tokens, cache basic claims to avoid cache lookup issues
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil // Always succeed for test purposes
|
||||
},
|
||||
}
|
||||
|
||||
// Mock JWT verifier to avoid JWKS lookup
|
||||
tOidc.jwtVerifier = &mockJWTVerifier{
|
||||
verifyFunc: func(jwt *JWT, token string) error {
|
||||
// Also cache claims here to ensure they're available
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil // Always succeed
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Azure provider detection works correctly", func(t *testing.T) {
|
||||
if !tOidc.isAzureProvider() {
|
||||
t.Error("Azure provider should be detected for Azure AD issuer URL")
|
||||
}
|
||||
|
||||
if tOidc.isGoogleProvider() {
|
||||
t.Error("Google provider should not be detected for Azure AD issuer URL")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure auth URL includes correct parameters", func(t *testing.T) {
|
||||
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Check that response_mode=query was added for Azure
|
||||
if !strings.Contains(authURL, "response_mode=query") {
|
||||
t.Errorf("response_mode=query not added to Azure auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Verify offline_access scope is included for Azure providers
|
||||
if !strings.Contains(authURL, "offline_access") {
|
||||
t.Errorf("offline_access scope not included in Azure auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Verify Azure doesn't get Google-specific parameters
|
||||
if strings.Contains(authURL, "access_type=offline") {
|
||||
t.Errorf("access_type=offline incorrectly added to Azure auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
if strings.Contains(authURL, "prompt=consent") {
|
||||
t.Errorf("prompt=consent incorrectly added to Azure auth URL: %s", authURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure access token validation takes priority", func(t *testing.T) {
|
||||
// Create a request and session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
session, _ := tOidc.sessionManager.GetSession(req)
|
||||
|
||||
// Set up session with Azure-style tokens
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
|
||||
// Create a valid JWT access token for testing
|
||||
accessTokenClaims := map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
accessToken, _ := createMockJWT(accessTokenClaims)
|
||||
session.SetAccessToken(accessToken)
|
||||
|
||||
// Create an invalid/expired ID token
|
||||
idTokenClaims := map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
idToken, _ := createMockJWT(idTokenClaims)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
// Mock the token verification to simulate Azure behavior
|
||||
originalTokenVerifier := tOidc.tokenVerifier
|
||||
tOidc.tokenVerifier = &mockTokenVerifier{
|
||||
verifyFunc: func(token string) error {
|
||||
if token == accessToken {
|
||||
// Access token validation succeeds - cache claims
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil
|
||||
}
|
||||
if token == idToken {
|
||||
// ID token validation fails (expired) - don't cache
|
||||
return newMockError("token has expired")
|
||||
}
|
||||
return newMockError("token validation failed")
|
||||
},
|
||||
}
|
||||
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
|
||||
|
||||
// Test Azure-specific validation
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
|
||||
|
||||
// Azure should prioritize access token, so even with expired ID token,
|
||||
// user should still be authenticated since access token is valid
|
||||
if !authenticated {
|
||||
t.Error("Azure user should be authenticated when access token is valid, even if ID token is expired")
|
||||
}
|
||||
|
||||
if expired {
|
||||
t.Error("Azure session should not be marked as expired when access token is valid")
|
||||
}
|
||||
|
||||
// May need refresh if we want to get a fresh ID token
|
||||
if !needsRefresh {
|
||||
t.Log("Azure session may not need immediate refresh if access token is still valid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure handles opaque access tokens gracefully", func(t *testing.T) {
|
||||
// Create a request and session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
session, _ := tOidc.sessionManager.GetSession(req)
|
||||
|
||||
// Set up session with opaque access token (non-JWT)
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(ValidAccessToken)
|
||||
|
||||
// Create a valid ID token for claims extraction
|
||||
idTokenClaims := map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
idToken, _ := createMockJWT(idTokenClaims)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
// Mock the token verification
|
||||
originalTokenVerifier := tOidc.tokenVerifier
|
||||
tOidc.tokenVerifier = &mockTokenVerifier{
|
||||
verifyFunc: func(token string) error {
|
||||
if token == idToken {
|
||||
// ID token is valid - cache claims
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil
|
||||
}
|
||||
return newMockError("token validation failed")
|
||||
},
|
||||
}
|
||||
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
|
||||
|
||||
// Test Azure-specific validation with opaque token
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
|
||||
|
||||
// Azure should handle opaque access tokens gracefully
|
||||
if !authenticated {
|
||||
t.Error("Azure user should be authenticated with opaque access token")
|
||||
}
|
||||
|
||||
if expired {
|
||||
t.Error("Azure session should not be expired with valid tokens")
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
t.Log("Azure session with opaque token may signal refresh to get JWT tokens")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure CSRF handling during token validation failures", func(t *testing.T) {
|
||||
// Create a request and session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
session, _ := tOidc.sessionManager.GetSession(req)
|
||||
|
||||
// Set up session with CSRF token (simulating ongoing auth flow)
|
||||
session.SetCSRF("test-csrf-token-123")
|
||||
session.SetNonce("test-nonce-456")
|
||||
session.SetAuthenticated(false) // Not yet authenticated
|
||||
|
||||
// Save session to simulate real scenario
|
||||
session.Save(req, rw)
|
||||
|
||||
// Mock token verification to always fail (simulating Azure token issues)
|
||||
originalTokenVerifier := tOidc.tokenVerifier
|
||||
tOidc.tokenVerifier = &mockTokenVerifier{
|
||||
verifyFunc: func(token string) error {
|
||||
return newMockError("azure token validation failed")
|
||||
},
|
||||
}
|
||||
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
|
||||
|
||||
// Test that CSRF is preserved during Azure validation failures
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
|
||||
|
||||
// Should not be authenticated due to validation failure
|
||||
if authenticated {
|
||||
t.Error("Should not be authenticated when token validation fails")
|
||||
}
|
||||
|
||||
// Should be marked as expired since no tokens work
|
||||
if !expired && !needsRefresh {
|
||||
t.Error("Should be marked as needing refresh or expired when validation fails")
|
||||
}
|
||||
|
||||
// Verify CSRF token is still preserved in session
|
||||
if session.GetCSRF() != "test-csrf-token-123" {
|
||||
t.Error("CSRF token should be preserved during Azure token validation failures")
|
||||
}
|
||||
|
||||
if session.GetNonce() != "test-nonce-456" {
|
||||
t.Error("Nonce should be preserved during Azure token validation failures")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// createMockJWT creates a basic JWT token for testing purposes
|
||||
func createMockJWT(claims map[string]any) (string, error) {
|
||||
// Simple mock JWT - in real tests you'd use a proper JWT library
|
||||
// For this test, we'll create a basic three-part token structure
|
||||
header := "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0" // {"alg":"RS256","kid":"test-key-id","typ":"JWT"}
|
||||
|
||||
// Create a simple payload with test claims
|
||||
payload := "eyJpc3MiOiJ0ZXN0LWlzc3VlciIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjoxNjM4MzYwMDAwLCJpYXQiOjE2MzgzNTY0MDAsInN1YiI6InVzZXIxMjMiLCJlbWFpbCI6InVzZXJAZXhhbXBsZS5jb20ifQ" // Basic claims
|
||||
|
||||
signature := "test-signature"
|
||||
|
||||
return header + "." + payload + "." + signature, nil
|
||||
}
|
||||
|
||||
// Mock error type for testing
|
||||
type mockError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *mockError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func newMockError(message string) error {
|
||||
return &mockError{message: message}
|
||||
}
|
||||
|
||||
// Mock token verifier for testing
|
||||
type mockTokenVerifier struct {
|
||||
verifyFunc func(token string) error
|
||||
}
|
||||
|
||||
func (m *mockTokenVerifier) VerifyToken(token string) error {
|
||||
if m.verifyFunc != nil {
|
||||
return m.verifyFunc(token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mock JWT verifier for testing
|
||||
type mockJWTVerifier struct {
|
||||
verifyFunc func(jwt *JWT, token string) error
|
||||
}
|
||||
|
||||
func (m *mockJWTVerifier) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
if m.verifyFunc != nil {
|
||||
return m.verifyFunc(jwt, token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
// CacheItem represents an item stored in the cache with its associated metadata.
|
||||
type CacheItem struct {
|
||||
// Value is the cached data of any type.
|
||||
Value interface{}
|
||||
Value any
|
||||
|
||||
// ExpiresAt is the timestamp when this item should be considered expired.
|
||||
ExpiresAt time.Time
|
||||
@@ -23,42 +23,40 @@ type lruEntry struct {
|
||||
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
|
||||
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
|
||||
type Cache struct {
|
||||
// items stores the cached data with string keys.
|
||||
items map[string]CacheItem
|
||||
|
||||
// order maintains the usage order; most recently used items are at the back.
|
||||
order *list.List
|
||||
|
||||
// elems maps keys to their corresponding list elements for O(1) access.
|
||||
elems map[string]*list.Element
|
||||
|
||||
// mutex protects concurrent access to the cache.
|
||||
mutex sync.RWMutex
|
||||
|
||||
// maxSize is the maximum number of items allowed in the cache.
|
||||
maxSize int
|
||||
// autoCleanupInterval defines how often Cleanup is called automatically.
|
||||
items map[string]CacheItem
|
||||
order *list.List
|
||||
elems map[string]*list.Element
|
||||
cleanupTask *BackgroundTask
|
||||
logger *Logger
|
||||
maxSize int
|
||||
autoCleanupInterval time.Duration
|
||||
// stopCleanup channel to terminate the auto cleanup goroutine.
|
||||
stopCleanup chan struct{}
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// DefaultMaxSize is the default maximum number of items in the cache.
|
||||
const DefaultMaxSize = 500
|
||||
|
||||
// NewCache creates a new empty cache instance with default settings.
|
||||
// It initializes the internal maps and list, sets the default maximum size,
|
||||
// and starts the automatic cleanup goroutine.
|
||||
// It initializes the internal maps and list and sets the default maximum size.
|
||||
func NewCache() *Cache {
|
||||
return NewCacheWithLogger(nil)
|
||||
}
|
||||
|
||||
// NewCacheWithLogger creates a new cache with a specified logger
|
||||
func NewCacheWithLogger(logger *Logger) *Cache {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
c := &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
order: list.New(),
|
||||
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||
maxSize: DefaultMaxSize,
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
stopCleanup: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
go c.startAutoCleanup()
|
||||
c.startAutoCleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -68,7 +66,7 @@ func NewCache() *Cache {
|
||||
// If the key does not exist and the cache is full, the least recently used item is evicted
|
||||
// before adding the new item.
|
||||
// The expiration duration is relative to the time Set is called.
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
func (c *Cache) Set(key string, value any, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@@ -106,7 +104,7 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
// Accessing an item moves it to the most recently used position in the LRU list.
|
||||
// If the item does not exist or has expired, nil and false are returned, and the
|
||||
// expired item is removed from the cache.
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
func (c *Cache) Get(key string) (any, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@@ -214,15 +212,18 @@ func (c *Cache) removeItem(key string) {
|
||||
}
|
||||
}
|
||||
|
||||
// startAutoCleanup starts the background goroutine that automatically calls the Cleanup method
|
||||
// startAutoCleanup starts the background task that automatically calls the Cleanup method
|
||||
// at the interval specified by c.autoCleanupInterval.
|
||||
// It uses the autoCleanupRoutine helper function.
|
||||
func (c *Cache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
c.cleanupTask = NewBackgroundTask("cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger)
|
||||
c.cleanupTask.Start()
|
||||
}
|
||||
|
||||
// Close stops the automatic cleanup goroutine associated with this cache instance.
|
||||
// Close stops the automatic cleanup task associated with this cache instance.
|
||||
// It should be called when the cache is no longer needed to prevent resource leaks.
|
||||
func (c *Cache) Close() {
|
||||
close(c.stopCleanup)
|
||||
if c.cleanupTask != nil {
|
||||
c.cleanupTask.Stop()
|
||||
c.cleanupTask = nil
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -49,7 +49,7 @@ func TestCache_SetMaxSize(t *testing.T) {
|
||||
newMaxSize := 3
|
||||
|
||||
// Add more items than the new max size
|
||||
for i := 0; i < originalMaxSize; i++ {
|
||||
for i := range originalMaxSize {
|
||||
key := "key" + string(rune('A'+i))
|
||||
c.Set(key, i, 1*time.Hour)
|
||||
}
|
||||
|
||||
+322
-93
@@ -3,6 +3,7 @@ package traefikoidc
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
@@ -11,6 +12,122 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrorRecoveryMechanism defines the common interface for all error recovery mechanisms
|
||||
type ErrorRecoveryMechanism interface {
|
||||
// ExecuteWithContext executes a function with error recovery
|
||||
ExecuteWithContext(ctx context.Context, fn func() error) error
|
||||
// GetMetrics returns metrics about the error recovery mechanism
|
||||
GetMetrics() map[string]any
|
||||
// Reset resets the state of the error recovery mechanism
|
||||
Reset()
|
||||
// IsAvailable returns whether the mechanism is available for use
|
||||
IsAvailable() bool
|
||||
}
|
||||
|
||||
// BaseRecoveryMechanism provides common functionality for error recovery mechanisms
|
||||
type BaseRecoveryMechanism struct {
|
||||
startTime time.Time
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
logger *Logger
|
||||
name string
|
||||
totalRequests int64
|
||||
totalFailures int64
|
||||
totalSuccesses int64
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBaseRecoveryMechanism creates a new base recovery mechanism
|
||||
func NewBaseRecoveryMechanism(name string, logger *Logger) *BaseRecoveryMechanism {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
return &BaseRecoveryMechanism{
|
||||
name: name,
|
||||
logger: logger,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordRequest records a request to the error recovery mechanism
|
||||
func (b *BaseRecoveryMechanism) RecordRequest() {
|
||||
atomic.AddInt64(&b.totalRequests, 1)
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
func (b *BaseRecoveryMechanism) RecordSuccess() {
|
||||
atomic.AddInt64(&b.totalSuccesses, 1)
|
||||
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
b.lastSuccessTime = time.Now()
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
func (b *BaseRecoveryMechanism) RecordFailure() {
|
||||
atomic.AddInt64(&b.totalFailures, 1)
|
||||
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
b.lastFailureTime = time.Now()
|
||||
}
|
||||
|
||||
// GetBaseMetrics returns base metrics common to all recovery mechanisms
|
||||
func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]any {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
|
||||
metrics := map[string]any{
|
||||
"total_requests": atomic.LoadInt64(&b.totalRequests),
|
||||
"total_failures": atomic.LoadInt64(&b.totalFailures),
|
||||
"total_successes": atomic.LoadInt64(&b.totalSuccesses),
|
||||
"uptime_seconds": time.Since(b.startTime).Seconds(),
|
||||
"name": b.name,
|
||||
}
|
||||
|
||||
if !b.lastFailureTime.IsZero() {
|
||||
metrics["last_failure_time"] = b.lastFailureTime.Format(time.RFC3339)
|
||||
metrics["seconds_since_last_failure"] = time.Since(b.lastFailureTime).Seconds()
|
||||
}
|
||||
|
||||
if !b.lastSuccessTime.IsZero() {
|
||||
metrics["last_success_time"] = b.lastSuccessTime.Format(time.RFC3339)
|
||||
metrics["seconds_since_last_success"] = time.Since(b.lastSuccessTime).Seconds()
|
||||
}
|
||||
|
||||
// Calculate success rate
|
||||
if metrics["total_requests"].(int64) > 0 {
|
||||
successRate := float64(metrics["total_successes"].(int64)) / float64(metrics["total_requests"].(int64))
|
||||
metrics["success_rate"] = successRate
|
||||
} else {
|
||||
metrics["success_rate"] = 1.0 // Default to 100% if no requests
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// LogInfo logs an informational message
|
||||
func (b *BaseRecoveryMechanism) LogInfo(format string, args ...any) {
|
||||
if b.logger != nil {
|
||||
b.logger.Infof("%s: "+format, append([]any{b.name}, args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogError logs an error message
|
||||
func (b *BaseRecoveryMechanism) LogError(format string, args ...any) {
|
||||
if b.logger != nil {
|
||||
b.logger.Errorf("%s: "+format, append([]any{b.name}, args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogDebug logs a debug message
|
||||
func (b *BaseRecoveryMechanism) LogDebug(format string, args ...any) {
|
||||
if b.logger != nil {
|
||||
b.logger.Debugf("%s: "+format, append([]any{b.name}, args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerState represents the current state of a circuit breaker
|
||||
type CircuitBreakerState int
|
||||
|
||||
@@ -25,25 +142,12 @@ const (
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for external service calls
|
||||
type CircuitBreaker struct {
|
||||
// Configuration
|
||||
maxFailures int // Maximum failures before opening
|
||||
timeout time.Duration // How long to wait before trying again
|
||||
resetTimeout time.Duration // How long to wait in half-open state
|
||||
|
||||
// State
|
||||
state CircuitBreakerState
|
||||
failures int64
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
mutex sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalRequests int64
|
||||
totalFailures int64
|
||||
totalSuccesses int64
|
||||
|
||||
// Logger
|
||||
logger *Logger
|
||||
*BaseRecoveryMechanism
|
||||
maxFailures int
|
||||
timeout time.Duration
|
||||
resetTimeout time.Duration
|
||||
state CircuitBreakerState
|
||||
failures int64
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration for circuit breakers
|
||||
@@ -65,17 +169,17 @@ func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||
// NewCircuitBreaker creates a new circuit breaker with the given configuration
|
||||
func NewCircuitBreaker(config CircuitBreakerConfig, logger *Logger) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
maxFailures: config.MaxFailures,
|
||||
timeout: config.Timeout,
|
||||
resetTimeout: config.ResetTimeout,
|
||||
state: CircuitBreakerClosed,
|
||||
logger: logger,
|
||||
BaseRecoveryMechanism: NewBaseRecoveryMechanism("circuit-breaker", logger),
|
||||
maxFailures: config.MaxFailures,
|
||||
timeout: config.Timeout,
|
||||
resetTimeout: config.ResetTimeout,
|
||||
state: CircuitBreakerClosed,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs the given function with circuit breaker protection
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
atomic.AddInt64(&cb.totalRequests, 1)
|
||||
// ExecuteWithContext implements the ErrorRecoveryMechanism interface
|
||||
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
cb.RecordRequest()
|
||||
|
||||
// Check if circuit breaker allows the request
|
||||
if !cb.allowRequest() {
|
||||
@@ -87,15 +191,20 @@ func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
// Record the result
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
atomic.AddInt64(&cb.totalFailures, 1)
|
||||
cb.RecordFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
cb.recordSuccess()
|
||||
atomic.AddInt64(&cb.totalSuccesses, 1)
|
||||
cb.RecordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute is the original method for backward compatibility
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
return cb.ExecuteWithContext(context.Background(), fn)
|
||||
}
|
||||
|
||||
// allowRequest checks if the circuit breaker allows the request
|
||||
func (cb *CircuitBreaker) allowRequest() bool {
|
||||
cb.mutex.Lock()
|
||||
@@ -131,19 +240,18 @@ func (cb *CircuitBreaker) recordFailure() {
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.failures++
|
||||
cb.lastFailureTime = time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerClosed:
|
||||
if cb.failures >= int64(cb.maxFailures) {
|
||||
cb.state = CircuitBreakerOpen
|
||||
cb.logger.Errorf("Circuit breaker opened after %d failures", cb.failures)
|
||||
cb.LogError("Circuit breaker opened after %d failures", cb.failures)
|
||||
}
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
// Go back to open state on any failure in half-open
|
||||
cb.state = CircuitBreakerOpen
|
||||
cb.logger.Errorf("Circuit breaker returned to open state after failure in half-open")
|
||||
cb.LogError("Circuit breaker returned to open state after failure in half-open")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,14 +260,12 @@ func (cb *CircuitBreaker) recordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.lastSuccessTime = time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerHalfOpen:
|
||||
// Reset failures and close circuit on success in half-open
|
||||
cb.failures = 0
|
||||
cb.state = CircuitBreakerClosed
|
||||
cb.logger.Infof("Circuit breaker closed after successful request in half-open state")
|
||||
cb.LogInfo("Circuit breaker closed after successful request in half-open state")
|
||||
|
||||
case CircuitBreakerClosed:
|
||||
// Reset failure count on success
|
||||
@@ -174,30 +280,58 @@ func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
return cb.state
|
||||
}
|
||||
|
||||
// GetMetrics returns circuit breaker metrics
|
||||
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
// Reset resets the circuit breaker to its initial state
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"state": cb.state,
|
||||
"failures": cb.failures,
|
||||
"total_requests": atomic.LoadInt64(&cb.totalRequests),
|
||||
"total_failures": atomic.LoadInt64(&cb.totalFailures),
|
||||
"total_successes": atomic.LoadInt64(&cb.totalSuccesses),
|
||||
"last_failure": cb.lastFailureTime,
|
||||
"last_success": cb.lastSuccessTime,
|
||||
cb.state = CircuitBreakerClosed
|
||||
atomic.StoreInt64(&cb.failures, 0)
|
||||
cb.LogInfo("Circuit breaker has been reset")
|
||||
}
|
||||
|
||||
// IsAvailable returns whether the circuit breaker is allowing requests
|
||||
func (cb *CircuitBreaker) IsAvailable() bool {
|
||||
return cb.allowRequest()
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the circuit breaker
|
||||
func (cb *CircuitBreaker) GetMetrics() map[string]any {
|
||||
cb.mutex.RLock()
|
||||
state := cb.state
|
||||
failures := cb.failures
|
||||
cb.mutex.RUnlock()
|
||||
|
||||
metrics := cb.GetBaseMetrics()
|
||||
|
||||
// Add circuit breaker specific metrics
|
||||
stateStr := "unknown"
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
stateStr = "closed"
|
||||
case CircuitBreakerOpen:
|
||||
stateStr = "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
stateStr = "half-open"
|
||||
}
|
||||
|
||||
metrics["state"] = stateStr
|
||||
metrics["max_failures"] = cb.maxFailures
|
||||
metrics["current_failures"] = failures
|
||||
metrics["timeout_ms"] = cb.timeout.Milliseconds()
|
||||
metrics["reset_timeout_ms"] = cb.resetTimeout.Milliseconds()
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// RetryConfig holds configuration for retry mechanisms
|
||||
type RetryConfig struct {
|
||||
RetryableErrors []string `json:"retryable_errors"`
|
||||
MaxAttempts int `json:"max_attempts"`
|
||||
InitialDelay time.Duration `json:"initial_delay"`
|
||||
MaxDelay time.Duration `json:"max_delay"`
|
||||
BackoffFactor float64 `json:"backoff_factor"`
|
||||
EnableJitter bool `json:"enable_jitter"`
|
||||
RetryableErrors []string `json:"retryable_errors"`
|
||||
}
|
||||
|
||||
// DefaultRetryConfig returns default retry configuration
|
||||
@@ -219,20 +353,21 @@ func DefaultRetryConfig() RetryConfig {
|
||||
|
||||
// RetryExecutor implements retry logic with exponential backoff
|
||||
type RetryExecutor struct {
|
||||
*BaseRecoveryMechanism
|
||||
config RetryConfig
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewRetryExecutor creates a new retry executor
|
||||
func NewRetryExecutor(config RetryConfig, logger *Logger) *RetryExecutor {
|
||||
return &RetryExecutor{
|
||||
config: config,
|
||||
logger: logger,
|
||||
BaseRecoveryMechanism: NewBaseRecoveryMechanism("retry-executor", logger),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs the given function with retry logic
|
||||
func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
// ExecuteWithContext implements the ErrorRecoveryMechanism interface
|
||||
func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
re.RecordRequest()
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= re.config.MaxAttempts; attempt++ {
|
||||
@@ -240,8 +375,9 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
err := fn()
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
re.logger.Infof("Operation succeeded on attempt %d", attempt)
|
||||
re.LogInfo("Operation succeeded on attempt %d", attempt)
|
||||
}
|
||||
re.RecordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -249,30 +385,39 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
|
||||
// Check if error is retryable
|
||||
if !re.isRetryableError(err) {
|
||||
re.logger.Debugf("Non-retryable error on attempt %d: %v", attempt, err)
|
||||
re.LogDebug("Non-retryable error on attempt %d: %v", attempt, err)
|
||||
re.RecordFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
// Don't wait after the last attempt
|
||||
if attempt == re.config.MaxAttempts {
|
||||
re.RecordFailure()
|
||||
break
|
||||
}
|
||||
|
||||
// Calculate delay with exponential backoff
|
||||
delay := re.calculateDelay(attempt)
|
||||
re.logger.Debugf("Retrying operation after %v (attempt %d/%d): %v",
|
||||
re.LogDebug("Retrying operation after %v (attempt %d/%d): %v",
|
||||
delay, attempt, re.config.MaxAttempts, err)
|
||||
|
||||
// Wait with context cancellation support
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
re.RecordFailure()
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
// Continue to next attempt
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("operation failed after %d attempts: %w", re.config.MaxAttempts, lastErr)
|
||||
finalErr := fmt.Errorf("operation failed after %d attempts: %w", re.config.MaxAttempts, lastErr)
|
||||
return finalErr
|
||||
}
|
||||
|
||||
// Execute runs the given function with retry logic (for backward compatibility)
|
||||
func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
return re.ExecuteWithContext(ctx, fn)
|
||||
}
|
||||
|
||||
// isRetryableError checks if an error should trigger a retry
|
||||
@@ -341,10 +486,36 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
return time.Duration(delay)
|
||||
}
|
||||
|
||||
// Reset resets the retry executor state
|
||||
func (re *RetryExecutor) Reset() {
|
||||
// Nothing to reset for RetryExecutor
|
||||
re.LogDebug("Retry executor reset")
|
||||
}
|
||||
|
||||
// IsAvailable always returns true for RetryExecutor
|
||||
func (re *RetryExecutor) IsAvailable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the retry executor
|
||||
func (re *RetryExecutor) GetMetrics() map[string]any {
|
||||
metrics := re.GetBaseMetrics()
|
||||
|
||||
// Add retry executor specific metrics
|
||||
metrics["max_attempts"] = re.config.MaxAttempts
|
||||
metrics["initial_delay_ms"] = re.config.InitialDelay.Milliseconds()
|
||||
metrics["max_delay_ms"] = re.config.MaxDelay.Milliseconds()
|
||||
metrics["backoff_factor"] = re.config.BackoffFactor
|
||||
metrics["enable_jitter"] = re.config.EnableJitter
|
||||
metrics["retryable_errors"] = re.config.RetryableErrors
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// HTTPError represents an HTTP error with status code
|
||||
type HTTPError struct {
|
||||
StatusCode int
|
||||
Message string
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
@@ -354,20 +525,12 @@ func (e *HTTPError) Error() string {
|
||||
|
||||
// GracefulDegradation implements graceful degradation patterns
|
||||
type GracefulDegradation struct {
|
||||
// Fallback functions for different operations
|
||||
fallbacks map[string]func() (interface{}, error)
|
||||
|
||||
// Health checks for dependencies
|
||||
healthChecks map[string]func() bool
|
||||
|
||||
// Configuration
|
||||
config GracefulDegradationConfig
|
||||
|
||||
// State tracking
|
||||
*BaseRecoveryMechanism
|
||||
fallbacks map[string]func() (any, error)
|
||||
healthChecks map[string]func() bool
|
||||
degradedServices map[string]time.Time
|
||||
config GracefulDegradationConfig
|
||||
mutex sync.RWMutex
|
||||
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// GracefulDegradationConfig holds configuration for graceful degradation
|
||||
@@ -389,11 +552,11 @@ func DefaultGracefulDegradationConfig() GracefulDegradationConfig {
|
||||
// NewGracefulDegradation creates a new graceful degradation manager
|
||||
func NewGracefulDegradation(config GracefulDegradationConfig, logger *Logger) *GracefulDegradation {
|
||||
gd := &GracefulDegradation{
|
||||
fallbacks: make(map[string]func() (interface{}, error)),
|
||||
healthChecks: make(map[string]func() bool),
|
||||
degradedServices: make(map[string]time.Time),
|
||||
config: config,
|
||||
logger: logger,
|
||||
BaseRecoveryMechanism: NewBaseRecoveryMechanism("graceful-degradation", logger),
|
||||
fallbacks: make(map[string]func() (any, error)),
|
||||
healthChecks: make(map[string]func() bool),
|
||||
degradedServices: make(map[string]time.Time),
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Start health check routine
|
||||
@@ -403,7 +566,7 @@ func NewGracefulDegradation(config GracefulDegradationConfig, logger *Logger) *G
|
||||
}
|
||||
|
||||
// RegisterFallback registers a fallback function for a service
|
||||
func (gd *GracefulDegradation) RegisterFallback(serviceName string, fallback func() (interface{}, error)) {
|
||||
func (gd *GracefulDegradation) RegisterFallback(serviceName string, fallback func() (any, error)) {
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
gd.fallbacks[serviceName] = fallback
|
||||
@@ -416,10 +579,29 @@ func (gd *GracefulDegradation) RegisterHealthCheck(serviceName string, healthChe
|
||||
gd.healthChecks[serviceName] = healthCheck
|
||||
}
|
||||
|
||||
// ExecuteWithContext implements the ErrorRecoveryMechanism interface
|
||||
func (gd *GracefulDegradation) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
gd.RecordRequest()
|
||||
|
||||
// Execute with a simple wrapper
|
||||
_, err := gd.ExecuteWithFallback("default", func() (any, error) {
|
||||
return nil, fn()
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
gd.RecordFailure()
|
||||
} else {
|
||||
gd.RecordSuccess()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecuteWithFallback executes a function with fallback support
|
||||
func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary func() (interface{}, error)) (interface{}, error) {
|
||||
func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary func() (any, error)) (any, error) {
|
||||
// Check if service is degraded
|
||||
if gd.isServiceDegraded(serviceName) {
|
||||
gd.LogInfo("Service %s is degraded, using fallback", serviceName)
|
||||
return gd.executeFallback(serviceName)
|
||||
}
|
||||
|
||||
@@ -428,9 +610,11 @@ func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary f
|
||||
if err != nil {
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded(serviceName)
|
||||
gd.LogError("Service %s failed: %v", serviceName, err)
|
||||
|
||||
// Try fallback if available
|
||||
if gd.config.EnableFallbacks {
|
||||
gd.LogInfo("Using fallback for service %s", serviceName)
|
||||
return gd.executeFallback(serviceName)
|
||||
}
|
||||
|
||||
@@ -465,14 +649,14 @@ func (gd *GracefulDegradation) markServiceDegraded(serviceName string) {
|
||||
defer gd.mutex.Unlock()
|
||||
|
||||
if _, exists := gd.degradedServices[serviceName]; !exists {
|
||||
gd.logger.Errorf("Service %s marked as degraded", serviceName)
|
||||
gd.LogError("Service %s marked as degraded", serviceName)
|
||||
}
|
||||
|
||||
gd.degradedServices[serviceName] = time.Now()
|
||||
}
|
||||
|
||||
// executeFallback executes the fallback function for a service
|
||||
func (gd *GracefulDegradation) executeFallback(serviceName string) (interface{}, error) {
|
||||
func (gd *GracefulDegradation) executeFallback(serviceName string) (any, error) {
|
||||
gd.mutex.RLock()
|
||||
fallback, exists := gd.fallbacks[serviceName]
|
||||
gd.mutex.RUnlock()
|
||||
@@ -481,27 +665,26 @@ func (gd *GracefulDegradation) executeFallback(serviceName string) (interface{},
|
||||
return nil, fmt.Errorf("no fallback available for service %s", serviceName)
|
||||
}
|
||||
|
||||
gd.logger.Infof("Executing fallback for degraded service %s", serviceName)
|
||||
gd.LogInfo("Executing fallback for degraded service %s", serviceName)
|
||||
return fallback()
|
||||
}
|
||||
|
||||
// startHealthCheckRoutine starts the background health check routine
|
||||
func (gd *GracefulDegradation) startHealthCheckRoutine() {
|
||||
ticker := time.NewTicker(gd.config.HealthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
gd.performHealthChecks()
|
||||
}
|
||||
healthCheckTask := NewBackgroundTask(
|
||||
"graceful-degradation-health-check",
|
||||
gd.config.HealthCheckInterval,
|
||||
gd.performHealthChecks,
|
||||
gd.BaseRecoveryMechanism.logger,
|
||||
)
|
||||
healthCheckTask.Start()
|
||||
}
|
||||
|
||||
// performHealthChecks runs health checks for all registered services
|
||||
func (gd *GracefulDegradation) performHealthChecks() {
|
||||
gd.mutex.RLock()
|
||||
healthChecks := make(map[string]func() bool)
|
||||
for name, check := range gd.healthChecks {
|
||||
healthChecks[name] = check
|
||||
}
|
||||
maps.Copy(healthChecks, gd.healthChecks)
|
||||
gd.mutex.RUnlock()
|
||||
|
||||
for serviceName, healthCheck := range healthChecks {
|
||||
@@ -533,13 +716,59 @@ func (gd *GracefulDegradation) GetDegradedServices() []string {
|
||||
return degraded
|
||||
}
|
||||
|
||||
// Reset resets the state of all degraded services
|
||||
func (gd *GracefulDegradation) Reset() {
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
|
||||
// Clear degraded services
|
||||
gd.degradedServices = make(map[string]time.Time)
|
||||
gd.LogInfo("Graceful degradation state has been reset")
|
||||
}
|
||||
|
||||
// IsAvailable returns whether the mechanism is available for use
|
||||
func (gd *GracefulDegradation) IsAvailable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the graceful degradation mechanism
|
||||
func (gd *GracefulDegradation) GetMetrics() map[string]any {
|
||||
gd.mutex.RLock()
|
||||
degradedCount := len(gd.degradedServices)
|
||||
|
||||
// Get the names of degraded services
|
||||
degradedServices := make([]string, 0, degradedCount)
|
||||
for service := range gd.degradedServices {
|
||||
degradedServices = append(degradedServices, service)
|
||||
}
|
||||
|
||||
// Get total count of registered fallbacks and health checks
|
||||
fallbackCount := len(gd.fallbacks)
|
||||
healthCheckCount := len(gd.healthChecks)
|
||||
gd.mutex.RUnlock()
|
||||
|
||||
// Get base metrics
|
||||
metrics := gd.GetBaseMetrics()
|
||||
|
||||
// Add graceful degradation specific metrics
|
||||
metrics["degraded_services_count"] = degradedCount
|
||||
metrics["degraded_services"] = degradedServices
|
||||
metrics["registered_fallbacks_count"] = fallbackCount
|
||||
metrics["registered_health_checks_count"] = healthCheckCount
|
||||
metrics["health_check_interval_seconds"] = gd.config.HealthCheckInterval.Seconds()
|
||||
metrics["recovery_timeout_seconds"] = gd.config.RecoveryTimeout.Seconds()
|
||||
metrics["fallbacks_enabled"] = gd.config.EnableFallbacks
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// ErrorRecoveryManager coordinates all error recovery mechanisms
|
||||
type ErrorRecoveryManager struct {
|
||||
circuitBreakers map[string]*CircuitBreaker
|
||||
retryExecutor *RetryExecutor
|
||||
gracefulDegradation *GracefulDegradation
|
||||
mutex sync.RWMutex
|
||||
logger *Logger
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewErrorRecoveryManager creates a new error recovery manager
|
||||
@@ -576,14 +805,14 @@ func (erm *ErrorRecoveryManager) ExecuteWithRecovery(ctx context.Context, servic
|
||||
}
|
||||
|
||||
// GetRecoveryMetrics returns metrics for all recovery mechanisms
|
||||
func (erm *ErrorRecoveryManager) GetRecoveryMetrics() map[string]interface{} {
|
||||
func (erm *ErrorRecoveryManager) GetRecoveryMetrics() map[string]any {
|
||||
erm.mutex.RLock()
|
||||
defer erm.mutex.RUnlock()
|
||||
|
||||
metrics := make(map[string]interface{})
|
||||
metrics := make(map[string]any)
|
||||
|
||||
// Circuit breaker metrics
|
||||
cbMetrics := make(map[string]interface{})
|
||||
cbMetrics := make(map[string]any)
|
||||
for name, cb := range erm.circuitBreakers {
|
||||
cbMetrics[name] = cb.GetMetrics()
|
||||
}
|
||||
|
||||
+6
-11
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -207,7 +208,7 @@ func TestGracefulDegradation(t *testing.T) {
|
||||
}()
|
||||
|
||||
t.Run("Register fallback and health check", func(t *testing.T) {
|
||||
gd.RegisterFallback("test-service", func() (interface{}, error) {
|
||||
gd.RegisterFallback("test-service", func() (any, error) {
|
||||
return "fallback-result", nil
|
||||
})
|
||||
|
||||
@@ -222,12 +223,12 @@ func TestGracefulDegradation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Execute with fallback on failure", func(t *testing.T) {
|
||||
gd.RegisterFallback("failing-service", func() (interface{}, error) {
|
||||
gd.RegisterFallback("failing-service", func() (any, error) {
|
||||
return "fallback-result", nil
|
||||
})
|
||||
|
||||
// First call should fail and mark service as degraded
|
||||
result, err := gd.ExecuteWithFallback("failing-service", func() (interface{}, error) {
|
||||
result, err := gd.ExecuteWithFallback("failing-service", func() (any, error) {
|
||||
return nil, errors.New("service failure")
|
||||
})
|
||||
if err != nil {
|
||||
@@ -244,7 +245,7 @@ func TestGracefulDegradation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("No fallback available", func(t *testing.T) {
|
||||
_, err := gd.ExecuteWithFallback("no-fallback-service", func() (interface{}, error) {
|
||||
_, err := gd.ExecuteWithFallback("no-fallback-service", func() (any, error) {
|
||||
return nil, errors.New("service failure")
|
||||
})
|
||||
|
||||
@@ -255,13 +256,7 @@ func TestGracefulDegradation(t *testing.T) {
|
||||
|
||||
t.Run("Get degraded services", func(t *testing.T) {
|
||||
degraded := gd.GetDegradedServices()
|
||||
found := false
|
||||
for _, service := range degraded {
|
||||
if service == "failing-service" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(degraded, "failing-service")
|
||||
if !found {
|
||||
t.Error("Expected failing-service to be in degraded list")
|
||||
}
|
||||
|
||||
+23
-20
@@ -9,6 +9,7 @@ import (
|
||||
"math/big"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -94,7 +95,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetAccessToken("old-access-token")
|
||||
session.SetAccessToken(ValidAccessToken)
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
|
||||
// Create a mock token exchanger that simulates Google's behavior
|
||||
@@ -106,11 +107,15 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// Use standardized test tokens instead of ad-hoc strings
|
||||
testTokens := NewTestTokens()
|
||||
googleTokens := testTokens.GetGoogleTokenSet()
|
||||
|
||||
// Return a simulated Google token response with a new access token
|
||||
// but without a new refresh token (Google doesn't always return a new refresh token)
|
||||
return &TokenResponse{
|
||||
IDToken: "new-id-token-from-google",
|
||||
AccessToken: "new-access-token-from-google",
|
||||
IDToken: googleTokens.IDToken,
|
||||
AccessToken: googleTokens.AccessToken,
|
||||
RefreshToken: "", // Google often doesn't return a new refresh token
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
@@ -127,9 +132,9 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tOidc.extractClaimsFunc = func(token string) (map[string]interface{}, error) {
|
||||
tOidc.extractClaimsFunc = func(token string) (map[string]any, error) {
|
||||
// Return mock claims
|
||||
return map[string]interface{}{
|
||||
return map[string]any{
|
||||
"email": "test@example.com",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
}, nil
|
||||
@@ -149,15 +154,19 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
session.GetRefreshToken())
|
||||
}
|
||||
|
||||
// Use the same test tokens for validation
|
||||
testTokens := NewTestTokens()
|
||||
expectedTokens := testTokens.GetGoogleTokenSet()
|
||||
|
||||
// Check that the tokens were updated correctly
|
||||
if session.GetIDToken() != "new-id-token-from-google" {
|
||||
t.Errorf("ID token not updated: got %s, expected 'new-id-token-from-google'",
|
||||
session.GetIDToken())
|
||||
if session.GetIDToken() != expectedTokens.IDToken {
|
||||
t.Errorf("ID token not updated: got %s, expected %s",
|
||||
session.GetIDToken(), expectedTokens.IDToken)
|
||||
}
|
||||
|
||||
if session.GetAccessToken() != "new-access-token-from-google" {
|
||||
t.Errorf("Access token not updated: got %s, expected 'new-access-token-from-google'",
|
||||
session.GetAccessToken())
|
||||
if session.GetAccessToken() != expectedTokens.AccessToken {
|
||||
t.Errorf("Access token not updated: got %s, expected %s",
|
||||
session.GetAccessToken(), expectedTokens.AccessToken)
|
||||
}
|
||||
})
|
||||
// Test that our fix specifically addresses the reported Google error
|
||||
@@ -295,13 +304,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
scopeList := strings.Split(scope, " ")
|
||||
expectedScopes := []string{"openid", "profile", "email"}
|
||||
for _, expectedScope := range expectedScopes {
|
||||
found := false
|
||||
for _, actualScope := range scopeList {
|
||||
if actualScope == expectedScope {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(scopeList, expectedScope)
|
||||
if !found {
|
||||
t.Errorf("Expected scope '%s' not found in scope parameter: %s", expectedScope, scope)
|
||||
}
|
||||
@@ -377,7 +380,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
nbf := now.Unix()
|
||||
|
||||
// Create initial ID token
|
||||
initialIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
initialIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://accounts.google.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -393,7 +396,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create refresh ID token
|
||||
refreshedIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
refreshedIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://accounts.google.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
|
||||
+21
-29
@@ -72,20 +72,11 @@ func deriveCodeChallenge(codeVerifier string) string {
|
||||
// It contains the various tokens and metadata returned after successful
|
||||
// code exchange or token refresh operations.
|
||||
type TokenResponse struct {
|
||||
// IDToken is the OIDC ID token containing user claims
|
||||
IDToken string `json:"id_token"`
|
||||
|
||||
// AccessToken is the OAuth 2.0 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
|
||||
// RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
|
||||
// ExpiresIn is the lifetime in seconds of the access token
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// TokenType is the type of token, typically "Bearer"
|
||||
TokenType string `json:"token_type"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint.
|
||||
@@ -123,17 +114,13 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
data.Set("refresh_token", codeOrToken)
|
||||
}
|
||||
|
||||
// Use the reusable token HTTP client, fallback to creating one if not initialized
|
||||
client := t.tokenHTTPClient
|
||||
if client == nil {
|
||||
// Fallback for tests or incomplete initialization - create a temporary client
|
||||
// with the same behavior as the original implementation
|
||||
jar, _ := cookiejar.New(nil)
|
||||
client = &http.Client{
|
||||
Transport: t.httpClient.Transport,
|
||||
Timeout: t.httpClient.Timeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
@@ -200,7 +187,7 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
// Returns:
|
||||
// - A map representing the JSON claims extracted from the token payload.
|
||||
// - An error if the token format is invalid, decoding fails, or JSON unmarshaling fails.
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
func extractClaims(tokenString string) (map[string]any, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
@@ -211,7 +198,7 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
return nil, fmt.Errorf("failed to decode token payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
|
||||
}
|
||||
@@ -223,15 +210,21 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
// It stores token claims to avoid repeated validation of the
|
||||
// same token, improving performance for frequently used tokens.
|
||||
type TokenCache struct {
|
||||
// cache is the underlying cache implementation
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
const (
|
||||
defaultTokenCacheMaxSize = 1000
|
||||
defaultTokenCacheCleanupInterval = 2 * time.Minute
|
||||
)
|
||||
|
||||
// NewTokenCache creates and initializes a new TokenCache.
|
||||
// It internally creates a new generic Cache instance for storage.
|
||||
func NewTokenCache() *TokenCache {
|
||||
cache := NewCache()
|
||||
cache.SetMaxSize(defaultTokenCacheMaxSize)
|
||||
|
||||
return &TokenCache{
|
||||
cache: NewCache(),
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,7 +236,7 @@ func NewTokenCache() *TokenCache {
|
||||
// - token: The raw token string (used as the key).
|
||||
// - claims: The map of claims associated with the token.
|
||||
// - expiration: The duration for which the cache entry should be valid.
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
func (tc *TokenCache) Set(token string, claims map[string]any, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
}
|
||||
@@ -257,13 +250,13 @@ func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiratio
|
||||
// Returns:
|
||||
// - The cached claims map if found and valid.
|
||||
// - A boolean indicating whether the token was found in the cache (true if found, false otherwise).
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
func (tc *TokenCache) Get(token string) (map[string]any, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
claims, ok := value.(map[string]interface{})
|
||||
claims, ok := value.(map[string]any)
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
@@ -303,7 +296,6 @@ func (tc *TokenCache) Close() {
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Only include code verifier if PKCE is enabled
|
||||
effectiveCodeVerifier := ""
|
||||
if t.enablePKCE && codeVerifier != "" {
|
||||
effectiveCodeVerifier = codeVerifier
|
||||
@@ -352,7 +344,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accessToken := session.GetAccessToken()
|
||||
idToken := session.GetIDToken()
|
||||
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing session: %v", err)
|
||||
@@ -371,8 +363,8 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
if t.endSessionURL != "" && accessToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
|
||||
if t.endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to build logout URL: %v", err)
|
||||
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
||||
|
||||
+16
-22
@@ -11,35 +11,29 @@ import (
|
||||
|
||||
// InputValidator provides comprehensive input validation and sanitization
|
||||
type InputValidator struct {
|
||||
// Configuration
|
||||
maxTokenLength int
|
||||
maxURLLength int
|
||||
maxHeaderLength int
|
||||
maxClaimLength int
|
||||
maxEmailLength int
|
||||
maxUsernameLength int
|
||||
|
||||
// Compiled regex patterns
|
||||
emailRegex *regexp.Regexp
|
||||
urlRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
usernameRegex *regexp.Regexp
|
||||
|
||||
// Security patterns to detect
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
xssPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
|
||||
logger *Logger
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
}
|
||||
|
||||
// ValidationResult represents the result of input validation
|
||||
type ValidationResult struct {
|
||||
IsValid bool `json:"is_valid"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
SanitizedValue string `json:"sanitized_value,omitempty"`
|
||||
SecurityRisk string `json:"security_risk,omitempty"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
IsValid bool `json:"is_valid"`
|
||||
}
|
||||
|
||||
// InputValidationConfig holds configuration for input validation
|
||||
@@ -620,7 +614,7 @@ func (iv *InputValidator) SanitizeInput(input string, maxLength int) string {
|
||||
}
|
||||
|
||||
// ValidateBoundaryValues validates numeric boundary values
|
||||
func (iv *InputValidator) ValidateBoundaryValues(value interface{}, min, max int64) ValidationResult {
|
||||
func (iv *InputValidator) ValidateBoundaryValues(value any, min, max int64) ValidationResult {
|
||||
result := ValidationResult{IsValid: true, Errors: []string{}, Warnings: []string{}}
|
||||
|
||||
var numValue int64
|
||||
|
||||
@@ -204,8 +204,8 @@ func TestSanitizeInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
{
|
||||
name: "Normal text",
|
||||
@@ -246,7 +246,7 @@ func TestValidateBoundaryValues(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("Valid boundary values", func(t *testing.T) {
|
||||
validValues := []interface{}{
|
||||
validValues := []any{
|
||||
int(50),
|
||||
int64(100),
|
||||
float64(75.5),
|
||||
@@ -261,7 +261,7 @@ func TestValidateBoundaryValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid boundary values", func(t *testing.T) {
|
||||
invalidValues := []interface{}{
|
||||
invalidValues := []any{
|
||||
int(-1),
|
||||
int64(2000),
|
||||
"not a number",
|
||||
|
||||
@@ -33,13 +33,12 @@ type JWKSet struct {
|
||||
}
|
||||
|
||||
type JWKCache struct {
|
||||
jwks *JWKSet
|
||||
expiresAt time.Time
|
||||
mutex sync.RWMutex
|
||||
// CacheLifetime is configurable to determine how long the JWKS is cached.
|
||||
expiresAt time.Time
|
||||
jwks *JWKSet
|
||||
internalCache *Cache
|
||||
CacheLifetime time.Duration
|
||||
internalCache *Cache // To hold the closable Cache instance from cache.go
|
||||
maxSize int // Maximum number of items in the cache
|
||||
maxSize int
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type JWKCacheInterface interface {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
@@ -16,37 +17,80 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
replayCacheMu sync.Mutex
|
||||
replayCache *Cache // Replace unbounded map with bounded Cache
|
||||
replayCacheMu sync.RWMutex // Use RWMutex for better read performance
|
||||
replayCache *Cache // Replace unbounded map with bounded Cache
|
||||
replayCacheOnce sync.Once
|
||||
)
|
||||
|
||||
// initReplayCache initializes the global replay cache with size limit
|
||||
func initReplayCache() {
|
||||
if replayCache == nil {
|
||||
replayCacheOnce.Do(func() {
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000) // Set size limit to 10,000 entries
|
||||
replayCache.SetMaxSize(10000)
|
||||
})
|
||||
}
|
||||
|
||||
func cleanupReplayCache() {
|
||||
replayCacheMu.Lock()
|
||||
defer replayCacheMu.Unlock()
|
||||
|
||||
if replayCache != nil {
|
||||
replayCache.Close()
|
||||
replayCache = nil
|
||||
}
|
||||
}
|
||||
|
||||
// STABILITY FIX: Standardize clock skew tolerance usage
|
||||
// ClockSkewToleranceFuture defines the tolerance for future-based claims like 'exp'.
|
||||
// Allows for more leniency with expiration checks.
|
||||
func getReplayCacheStats() (size int, maxSize int) {
|
||||
replayCacheMu.RLock()
|
||||
defer replayCacheMu.RUnlock()
|
||||
|
||||
if replayCache == nil {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
return 0, 10000
|
||||
}
|
||||
|
||||
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
size, maxSize := getReplayCacheStats()
|
||||
if logger != nil {
|
||||
logger.Debugf("Replay cache stats: size=%d, maxSize=%d", size, maxSize)
|
||||
}
|
||||
|
||||
replayCacheMu.RLock()
|
||||
if replayCache != nil {
|
||||
}
|
||||
replayCacheMu.RUnlock()
|
||||
|
||||
case <-ctx.Done():
|
||||
cleanupReplayCache()
|
||||
if logger != nil {
|
||||
logger.Debug("Replay cache cleanup goroutine stopped due to context cancellation")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var ClockSkewToleranceFuture = 2 * time.Minute
|
||||
|
||||
// ClockSkewTolerancePast defines the tolerance for past-based claims like 'iat' and 'nbf'.
|
||||
// A smaller tolerance is typically used here to prevent accepting tokens issued too far in the future.
|
||||
var ClockSkewTolerancePast = 10 * time.Second
|
||||
|
||||
// ClockSkewTolerance is deprecated - use ClockSkewToleranceFuture or ClockSkewTolerancePast
|
||||
// STABILITY FIX: Remove inconsistent usage
|
||||
var ClockSkewTolerance = ClockSkewToleranceFuture
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
type JWT struct {
|
||||
Header map[string]interface{}
|
||||
Claims map[string]interface{}
|
||||
Signature []byte
|
||||
Header map[string]any
|
||||
Claims map[string]any
|
||||
Token string
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
// parseJWT decodes a raw JWT string into its constituent parts: header, claims, and signature.
|
||||
@@ -75,12 +119,10 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
||||
}
|
||||
// STABILITY FIX: Add comprehensive JSON error handling with panic protection
|
||||
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
||||
}
|
||||
|
||||
// Validate header structure
|
||||
if jwt.Header == nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: header is nil after unmarshaling")
|
||||
}
|
||||
@@ -90,12 +132,10 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
||||
}
|
||||
|
||||
// STABILITY FIX: Add comprehensive JSON error handling with panic protection
|
||||
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
// Validate claims structure
|
||||
if jwt.Claims == nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: claims is nil after unmarshaling")
|
||||
}
|
||||
@@ -183,31 +223,23 @@ func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error
|
||||
}
|
||||
}
|
||||
|
||||
// Implement replay protection by checking the jti (JWT ID)
|
||||
// Skip replay check if explicitly requested (for revalidation scenarios)
|
||||
shouldSkipReplay := len(skipReplayCheck) > 0 && skipReplayCheck[0]
|
||||
|
||||
if jti, ok := claims["jti"].(string); ok && !shouldSkipReplay {
|
||||
// Skip replay detection for tokens that are being verified from the cache
|
||||
if j.Token == "" {
|
||||
// This is a parsed JWT without the original token string,
|
||||
// which means it's likely from a cached token verification
|
||||
return nil
|
||||
}
|
||||
|
||||
// SECURITY FIX: Use bounded Cache with thread-safe operations
|
||||
replayCacheMu.Lock()
|
||||
defer replayCacheMu.Unlock()
|
||||
|
||||
// Initialize cache if not already done
|
||||
initReplayCache()
|
||||
|
||||
// SECURITY FIX: Check for replay attack using Cache API
|
||||
if _, exists := replayCache.Get(jti); exists {
|
||||
return fmt.Errorf("token replay detected")
|
||||
replayCacheMu.RLock()
|
||||
_, exists := replayCache.Get(jti)
|
||||
replayCacheMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return fmt.Errorf("token replay detected (jti: %s)", jti)
|
||||
}
|
||||
|
||||
// Calculate expiration time
|
||||
expFloat, ok := claims["exp"].(float64)
|
||||
var expTime time.Time
|
||||
if ok {
|
||||
@@ -216,10 +248,13 @@ func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error
|
||||
expTime = time.Now().Add(10 * time.Minute)
|
||||
}
|
||||
|
||||
// SECURITY FIX: Add to replay cache with expiration using Cache API
|
||||
duration := time.Until(expTime)
|
||||
if duration > 0 {
|
||||
replayCache.Set(jti, true, duration)
|
||||
replayCacheMu.Lock()
|
||||
if replayCache != nil {
|
||||
replayCache.Set(jti, true, duration)
|
||||
}
|
||||
replayCacheMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,13 +276,13 @@ func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error
|
||||
// Returns:
|
||||
// - nil if the expected audience is found.
|
||||
// - An error if the claim type is invalid or the expected audience is not present.
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
func verifyAudience(tokenAudience any, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
if aud != expectedAudience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
case []interface{}:
|
||||
case []any:
|
||||
found := false
|
||||
for _, v := range aud {
|
||||
if str, ok := v.(string); ok && str == expectedAudience {
|
||||
@@ -293,17 +328,15 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
// - An error describing the failure (e.g., "token has expired", "token used before issued").
|
||||
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
|
||||
claimTime := time.Unix(int64(unixTime), 0)
|
||||
now := time.Now() // Use current time without truncation
|
||||
now := time.Now()
|
||||
|
||||
var err error
|
||||
if future { // 'exp' check
|
||||
// Token is expired if Now is after (ClaimTime + FutureTolerance)
|
||||
if future {
|
||||
allowedExpiry := claimTime.Add(ClockSkewToleranceFuture)
|
||||
if now.After(allowedExpiry) {
|
||||
err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC())
|
||||
}
|
||||
} else { // 'iat' or 'nbf' check
|
||||
// Token is invalid if Now is before (ClaimTime - PastTolerance)
|
||||
} else {
|
||||
allowedStart := claimTime.Add(-ClockSkewTolerancePast)
|
||||
if now.Before(allowedStart) {
|
||||
reason := "not yet valid"
|
||||
|
||||
+577
-122
File diff suppressed because it is too large
Load Diff
+25
-11
@@ -8,21 +8,31 @@ import (
|
||||
)
|
||||
|
||||
type MetadataCache struct {
|
||||
metadata *ProviderMetadata
|
||||
expiresAt time.Time
|
||||
mutex sync.RWMutex
|
||||
metadata *ProviderMetadata
|
||||
cleanupTask *BackgroundTask
|
||||
logger *Logger
|
||||
autoCleanupInterval time.Duration
|
||||
stopCleanup chan struct{}
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMetadataCache creates a new MetadataCache instance.
|
||||
// It initializes the cache structure and starts the background cleanup goroutine.
|
||||
// It initializes the cache structure and starts the background cleanup task.
|
||||
func NewMetadataCache() *MetadataCache {
|
||||
return NewMetadataCacheWithLogger(nil)
|
||||
}
|
||||
|
||||
// NewMetadataCacheWithLogger creates a new MetadataCache with a specified logger.
|
||||
func NewMetadataCacheWithLogger(logger *Logger) *MetadataCache {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
c := &MetadataCache{
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
stopCleanup: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
go c.startAutoCleanup()
|
||||
c.startAutoCleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -92,20 +102,24 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
|
||||
|
||||
c.metadata = metadata
|
||||
// Set a fixed cache lifetime (e.g., 1 hour)
|
||||
// TODO: Consider making this configurable or respecting HTTP cache headers
|
||||
// Consider making this configurable or respecting HTTP cache headers
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
|
||||
// End of GetMetadata
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// startAutoCleanup starts the background goroutine that periodically calls Cleanup
|
||||
// startAutoCleanup starts the background task that periodically calls Cleanup
|
||||
// to remove expired metadata from the cache.
|
||||
func (c *MetadataCache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
c.cleanupTask = NewBackgroundTask("metadata-cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger)
|
||||
c.cleanupTask.Start()
|
||||
}
|
||||
|
||||
// Close stops the automatic cleanup goroutine associated with this metadata cache.
|
||||
// Close stops the automatic cleanup task associated with this metadata cache.
|
||||
func (c *MetadataCache) Close() {
|
||||
close(c.stopCleanup)
|
||||
if c.cleanupTask != nil {
|
||||
c.cleanupTask.Stop()
|
||||
c.cleanupTask = nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,8 +41,8 @@ func TestGetMetadata_Cached(t *testing.T) {
|
||||
mc := &MetadataCache{
|
||||
metadata: dummyData,
|
||||
expiresAt: time.Now().Add(1 * time.Hour),
|
||||
stopCleanup: make(chan struct{}),
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
logger: newNoOpLogger(),
|
||||
}
|
||||
// Use NewLogger to create a logger that writes errors only.
|
||||
logger := NewLogger("error")
|
||||
@@ -58,10 +58,10 @@ func TestGetMetadata_Cached(t *testing.T) {
|
||||
func TestMetadataCacheAutoCleanup(t *testing.T) {
|
||||
mc := &MetadataCache{
|
||||
autoCleanupInterval: 50 * time.Millisecond,
|
||||
stopCleanup: make(chan struct{}),
|
||||
logger: newNoOpLogger(),
|
||||
}
|
||||
// Start auto cleanup.
|
||||
go mc.startAutoCleanup()
|
||||
mc.startAutoCleanup()
|
||||
mc.mutex.Lock()
|
||||
mc.metadata = &ProviderMetadata{}
|
||||
mc.expiresAt = time.Now().Add(-50 * time.Millisecond)
|
||||
@@ -93,7 +93,7 @@ func TestGetMetadata_FetchError(t *testing.T) {
|
||||
|
||||
// Case 1: Cache is empty.
|
||||
mc := &MetadataCache{
|
||||
stopCleanup: make(chan struct{}),
|
||||
logger: newNoOpLogger(),
|
||||
}
|
||||
logger := NewLogger("error")
|
||||
metadata, err := mc.GetMetadata("http://example.com", errorClient, logger)
|
||||
|
||||
@@ -1,709 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PerformanceMetrics tracks various performance-related metrics
|
||||
type PerformanceMetrics struct {
|
||||
// Cache metrics
|
||||
cacheHits int64
|
||||
cacheMisses int64
|
||||
cacheEvictions int64
|
||||
cacheSize int64
|
||||
|
||||
// Token operation metrics
|
||||
tokenVerifications int64
|
||||
tokenValidations int64
|
||||
tokenRefreshes int64
|
||||
|
||||
// Success/failure tracking
|
||||
successfulVerifications int64
|
||||
successfulValidations int64
|
||||
successfulRefreshes int64
|
||||
failedVerifications int64
|
||||
failedValidations int64
|
||||
failedRefreshes int64
|
||||
|
||||
// Timing metrics
|
||||
avgVerificationTime time.Duration
|
||||
avgValidationTime time.Duration
|
||||
avgRefreshTime time.Duration
|
||||
|
||||
// Resource metrics
|
||||
memoryUsage int64
|
||||
goroutineCount int64
|
||||
memoryPressure int64 // Memory pressure level (0-100)
|
||||
gcPauseTime int64 // Last GC pause time in nanoseconds
|
||||
heapSize int64 // Current heap size
|
||||
heapInUse int64 // Heap memory in use
|
||||
|
||||
// Error metrics (kept for backward compatibility)
|
||||
verificationErrors int64
|
||||
validationErrors int64
|
||||
refreshErrors int64
|
||||
|
||||
// Rate limiting metrics
|
||||
rateLimitedRequests int64
|
||||
|
||||
// Session metrics
|
||||
activeSessions int64
|
||||
sessionCreations int64
|
||||
sessionDeletions int64
|
||||
|
||||
// Timing tracking
|
||||
timingMutex sync.RWMutex
|
||||
verificationTimes []time.Duration
|
||||
validationTimes []time.Duration
|
||||
refreshTimes []time.Duration
|
||||
|
||||
// Start time for uptime calculation
|
||||
startTime time.Time
|
||||
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewPerformanceMetrics creates a new performance metrics tracker
|
||||
func NewPerformanceMetrics(logger *Logger) *PerformanceMetrics {
|
||||
pm := &PerformanceMetrics{
|
||||
startTime: time.Now(),
|
||||
verificationTimes: make([]time.Duration, 0, 1000), // Keep last 1000 measurements
|
||||
validationTimes: make([]time.Duration, 0, 1000),
|
||||
refreshTimes: make([]time.Duration, 0, 1000),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Start background metrics collection
|
||||
go pm.startMetricsCollection()
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
// RecordCacheHit records a cache hit
|
||||
func (pm *PerformanceMetrics) RecordCacheHit() {
|
||||
atomic.AddInt64(&pm.cacheHits, 1)
|
||||
}
|
||||
|
||||
// RecordCacheMiss records a cache miss
|
||||
func (pm *PerformanceMetrics) RecordCacheMiss() {
|
||||
atomic.AddInt64(&pm.cacheMisses, 1)
|
||||
}
|
||||
|
||||
// RecordCacheEviction records a cache eviction
|
||||
func (pm *PerformanceMetrics) RecordCacheEviction() {
|
||||
atomic.AddInt64(&pm.cacheEvictions, 1)
|
||||
}
|
||||
|
||||
// UpdateCacheSize updates the current cache size
|
||||
func (pm *PerformanceMetrics) UpdateCacheSize(size int64) {
|
||||
atomic.StoreInt64(&pm.cacheSize, size)
|
||||
}
|
||||
|
||||
// RecordTokenVerification records a token verification operation
|
||||
func (pm *PerformanceMetrics) RecordTokenVerification(duration time.Duration, success bool) {
|
||||
atomic.AddInt64(&pm.tokenVerifications, 1)
|
||||
|
||||
if success {
|
||||
atomic.AddInt64(&pm.successfulVerifications, 1)
|
||||
pm.addVerificationTime(duration)
|
||||
} else {
|
||||
atomic.AddInt64(&pm.failedVerifications, 1)
|
||||
atomic.AddInt64(&pm.verificationErrors, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordTokenValidation records a token validation operation
|
||||
func (pm *PerformanceMetrics) RecordTokenValidation(duration time.Duration, success bool) {
|
||||
atomic.AddInt64(&pm.tokenValidations, 1)
|
||||
|
||||
if success {
|
||||
atomic.AddInt64(&pm.successfulValidations, 1)
|
||||
pm.addValidationTime(duration)
|
||||
} else {
|
||||
atomic.AddInt64(&pm.failedValidations, 1)
|
||||
atomic.AddInt64(&pm.validationErrors, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordTokenRefresh records a token refresh operation
|
||||
func (pm *PerformanceMetrics) RecordTokenRefresh(duration time.Duration, success bool) {
|
||||
atomic.AddInt64(&pm.tokenRefreshes, 1)
|
||||
|
||||
if success {
|
||||
atomic.AddInt64(&pm.successfulRefreshes, 1)
|
||||
pm.addRefreshTime(duration)
|
||||
} else {
|
||||
atomic.AddInt64(&pm.failedRefreshes, 1)
|
||||
atomic.AddInt64(&pm.refreshErrors, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordRateLimitedRequest records a rate-limited request
|
||||
func (pm *PerformanceMetrics) RecordRateLimitedRequest() {
|
||||
atomic.AddInt64(&pm.rateLimitedRequests, 1)
|
||||
}
|
||||
|
||||
// RecordSessionCreation records a session creation
|
||||
func (pm *PerformanceMetrics) RecordSessionCreation() {
|
||||
atomic.AddInt64(&pm.sessionCreations, 1)
|
||||
atomic.AddInt64(&pm.activeSessions, 1)
|
||||
}
|
||||
|
||||
// RecordSessionDeletion records a session deletion
|
||||
func (pm *PerformanceMetrics) RecordSessionDeletion() {
|
||||
atomic.AddInt64(&pm.sessionDeletions, 1)
|
||||
atomic.AddInt64(&pm.activeSessions, -1)
|
||||
}
|
||||
|
||||
// addVerificationTime adds a verification time measurement
|
||||
func (pm *PerformanceMetrics) addVerificationTime(duration time.Duration) {
|
||||
pm.timingMutex.Lock()
|
||||
defer pm.timingMutex.Unlock()
|
||||
|
||||
pm.verificationTimes = append(pm.verificationTimes, duration)
|
||||
if len(pm.verificationTimes) > 1000 {
|
||||
pm.verificationTimes = pm.verificationTimes[1:]
|
||||
}
|
||||
|
||||
pm.updateAverageVerificationTime()
|
||||
}
|
||||
|
||||
// addValidationTime adds a validation time measurement
|
||||
func (pm *PerformanceMetrics) addValidationTime(duration time.Duration) {
|
||||
pm.timingMutex.Lock()
|
||||
defer pm.timingMutex.Unlock()
|
||||
|
||||
pm.validationTimes = append(pm.validationTimes, duration)
|
||||
if len(pm.validationTimes) > 1000 {
|
||||
pm.validationTimes = pm.validationTimes[1:]
|
||||
}
|
||||
|
||||
pm.updateAverageValidationTime()
|
||||
}
|
||||
|
||||
// addRefreshTime adds a refresh time measurement
|
||||
func (pm *PerformanceMetrics) addRefreshTime(duration time.Duration) {
|
||||
pm.timingMutex.Lock()
|
||||
defer pm.timingMutex.Unlock()
|
||||
|
||||
pm.refreshTimes = append(pm.refreshTimes, duration)
|
||||
if len(pm.refreshTimes) > 1000 {
|
||||
pm.refreshTimes = pm.refreshTimes[1:]
|
||||
}
|
||||
|
||||
pm.updateAverageRefreshTime()
|
||||
}
|
||||
|
||||
// updateAverageVerificationTime calculates the average verification time
|
||||
func (pm *PerformanceMetrics) updateAverageVerificationTime() {
|
||||
if len(pm.verificationTimes) == 0 {
|
||||
pm.avgVerificationTime = 0
|
||||
return
|
||||
}
|
||||
|
||||
var total time.Duration
|
||||
for _, t := range pm.verificationTimes {
|
||||
total += t
|
||||
}
|
||||
pm.avgVerificationTime = total / time.Duration(len(pm.verificationTimes))
|
||||
}
|
||||
|
||||
// updateAverageValidationTime calculates the average validation time
|
||||
func (pm *PerformanceMetrics) updateAverageValidationTime() {
|
||||
if len(pm.validationTimes) == 0 {
|
||||
pm.avgValidationTime = 0
|
||||
return
|
||||
}
|
||||
|
||||
var total time.Duration
|
||||
for _, t := range pm.validationTimes {
|
||||
total += t
|
||||
}
|
||||
pm.avgValidationTime = total / time.Duration(len(pm.validationTimes))
|
||||
}
|
||||
|
||||
// updateAverageRefreshTime calculates the average refresh time
|
||||
func (pm *PerformanceMetrics) updateAverageRefreshTime() {
|
||||
if len(pm.refreshTimes) == 0 {
|
||||
pm.avgRefreshTime = 0
|
||||
return
|
||||
}
|
||||
|
||||
var total time.Duration
|
||||
for _, t := range pm.refreshTimes {
|
||||
total += t
|
||||
}
|
||||
pm.avgRefreshTime = total / time.Duration(len(pm.refreshTimes))
|
||||
}
|
||||
|
||||
// startMetricsCollection starts background collection of system metrics
|
||||
func (pm *PerformanceMetrics) startMetricsCollection() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
pm.collectSystemMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
// collectSystemMetrics collects system-level metrics
|
||||
func (pm *PerformanceMetrics) collectSystemMetrics() {
|
||||
// Memory statistics
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
atomic.StoreInt64(&pm.memoryUsage, int64(m.Alloc))
|
||||
atomic.StoreInt64(&pm.heapSize, int64(m.HeapSys))
|
||||
atomic.StoreInt64(&pm.heapInUse, int64(m.HeapInuse))
|
||||
atomic.StoreInt64(&pm.gcPauseTime, int64(m.PauseNs[(m.NumGC+255)%256]))
|
||||
|
||||
// Calculate memory pressure (0-100 scale)
|
||||
// Based on heap utilization and GC frequency
|
||||
heapUtilization := float64(m.HeapInuse) / float64(m.HeapSys)
|
||||
gcFrequency := float64(m.NumGC) / time.Since(pm.startTime).Minutes()
|
||||
|
||||
// Memory pressure calculation
|
||||
pressure := int64(heapUtilization * 50) // 0-50 based on heap utilization
|
||||
if gcFrequency > 10 { // High GC frequency indicates pressure
|
||||
pressure += int64((gcFrequency - 10) * 2) // Add up to 50 more
|
||||
}
|
||||
if pressure > 100 {
|
||||
pressure = 100
|
||||
}
|
||||
atomic.StoreInt64(&pm.memoryPressure, pressure)
|
||||
|
||||
// Goroutine count
|
||||
atomic.StoreInt64(&pm.goroutineCount, int64(runtime.NumGoroutine()))
|
||||
|
||||
// Log memory pressure warnings
|
||||
if pressure > 80 {
|
||||
pm.logger.Errorf("High memory pressure detected: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)",
|
||||
pressure, heapUtilization*100, gcFrequency)
|
||||
} else if pressure > 60 {
|
||||
pm.logger.Infof("Moderate memory pressure: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)",
|
||||
pressure, heapUtilization*100, gcFrequency)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns all current performance metrics
|
||||
func (pm *PerformanceMetrics) GetMetrics() map[string]interface{} {
|
||||
pm.timingMutex.RLock()
|
||||
defer pm.timingMutex.RUnlock()
|
||||
|
||||
// Calculate cache hit ratio
|
||||
hits := atomic.LoadInt64(&pm.cacheHits)
|
||||
misses := atomic.LoadInt64(&pm.cacheMisses)
|
||||
var hitRatio float64
|
||||
if hits+misses > 0 {
|
||||
hitRatio = float64(hits) / float64(hits+misses)
|
||||
}
|
||||
|
||||
// Calculate error rates
|
||||
verifications := atomic.LoadInt64(&pm.tokenVerifications)
|
||||
validations := atomic.LoadInt64(&pm.tokenValidations)
|
||||
refreshes := atomic.LoadInt64(&pm.tokenRefreshes)
|
||||
|
||||
var verificationErrorRate, validationErrorRate, refreshErrorRate float64
|
||||
|
||||
if verifications > 0 {
|
||||
verificationErrorRate = float64(atomic.LoadInt64(&pm.verificationErrors)) / float64(verifications)
|
||||
}
|
||||
if validations > 0 {
|
||||
validationErrorRate = float64(atomic.LoadInt64(&pm.validationErrors)) / float64(validations)
|
||||
}
|
||||
if refreshes > 0 {
|
||||
refreshErrorRate = float64(atomic.LoadInt64(&pm.refreshErrors)) / float64(refreshes)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
// Cache metrics
|
||||
"cache_hits": hits,
|
||||
"cache_misses": misses,
|
||||
"cache_hit_ratio": hitRatio,
|
||||
"cache_evictions": atomic.LoadInt64(&pm.cacheEvictions),
|
||||
"cache_size": atomic.LoadInt64(&pm.cacheSize),
|
||||
|
||||
// Token operation metrics
|
||||
"token_verifications": verifications,
|
||||
"token_validations": validations,
|
||||
"token_refreshes": refreshes,
|
||||
"verification_error_rate": verificationErrorRate,
|
||||
"validation_error_rate": validationErrorRate,
|
||||
"refresh_error_rate": refreshErrorRate,
|
||||
|
||||
// Success/failure metrics
|
||||
"successful_verifications": atomic.LoadInt64(&pm.successfulVerifications),
|
||||
"successful_validations": atomic.LoadInt64(&pm.successfulValidations),
|
||||
"successful_refreshes": atomic.LoadInt64(&pm.successfulRefreshes),
|
||||
"failed_verifications": atomic.LoadInt64(&pm.failedVerifications),
|
||||
"failed_validations": atomic.LoadInt64(&pm.failedValidations),
|
||||
"failed_refreshes": atomic.LoadInt64(&pm.failedRefreshes),
|
||||
|
||||
// Timing metrics
|
||||
"avg_verification_time_ms": pm.avgVerificationTime.Milliseconds(),
|
||||
"avg_validation_time_ms": pm.avgValidationTime.Milliseconds(),
|
||||
"avg_refresh_time_ms": pm.avgRefreshTime.Milliseconds(),
|
||||
|
||||
// Resource metrics
|
||||
"memory_usage_bytes": atomic.LoadInt64(&pm.memoryUsage),
|
||||
"memory_pressure": atomic.LoadInt64(&pm.memoryPressure),
|
||||
"heap_size_bytes": atomic.LoadInt64(&pm.heapSize),
|
||||
"heap_inuse_bytes": atomic.LoadInt64(&pm.heapInUse),
|
||||
"gc_pause_time_ns": atomic.LoadInt64(&pm.gcPauseTime),
|
||||
"goroutine_count": atomic.LoadInt64(&pm.goroutineCount),
|
||||
|
||||
// Rate limiting metrics
|
||||
"rate_limited_requests": atomic.LoadInt64(&pm.rateLimitedRequests),
|
||||
|
||||
// Session metrics
|
||||
"active_sessions": atomic.LoadInt64(&pm.activeSessions),
|
||||
"sessions_created": atomic.LoadInt64(&pm.sessionCreations),
|
||||
"sessions_deleted": atomic.LoadInt64(&pm.sessionDeletions),
|
||||
"session_creations": atomic.LoadInt64(&pm.sessionCreations),
|
||||
"session_deletions": atomic.LoadInt64(&pm.sessionDeletions),
|
||||
|
||||
// Uptime
|
||||
"uptime_seconds": time.Since(pm.startTime).Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetDetailedTimingMetrics returns detailed timing statistics
|
||||
func (pm *PerformanceMetrics) GetDetailedTimingMetrics() map[string]interface{} {
|
||||
pm.timingMutex.RLock()
|
||||
defer pm.timingMutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"verification_stats": pm.calculateTimingStats(pm.verificationTimes),
|
||||
"verification_timing": pm.calculateTimingStats(pm.verificationTimes),
|
||||
"validation_stats": pm.calculateTimingStats(pm.validationTimes),
|
||||
"validation_timing": pm.calculateTimingStats(pm.validationTimes),
|
||||
"refresh_stats": pm.calculateTimingStats(pm.refreshTimes),
|
||||
"refresh_timing": pm.calculateTimingStats(pm.refreshTimes),
|
||||
}
|
||||
}
|
||||
|
||||
// calculateTimingStats calculates statistical metrics for timing data
|
||||
func (pm *PerformanceMetrics) calculateTimingStats(times []time.Duration) map[string]interface{} {
|
||||
if len(times) == 0 {
|
||||
return map[string]interface{}{
|
||||
"count": 0,
|
||||
"min_ms": float64(0),
|
||||
"max_ms": float64(0),
|
||||
"avg_ms": float64(0),
|
||||
"average_ms": float64(0),
|
||||
"median_ms": float64(0),
|
||||
"p95_ms": float64(0),
|
||||
"p99_ms": float64(0),
|
||||
}
|
||||
}
|
||||
|
||||
// Sort times for percentile calculations
|
||||
sortedTimes := make([]time.Duration, len(times))
|
||||
copy(sortedTimes, times)
|
||||
|
||||
// Simple bubble sort for small arrays
|
||||
for i := 0; i < len(sortedTimes); i++ {
|
||||
for j := i + 1; j < len(sortedTimes); j++ {
|
||||
if sortedTimes[i] > sortedTimes[j] {
|
||||
sortedTimes[i], sortedTimes[j] = sortedTimes[j], sortedTimes[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate statistics
|
||||
min := sortedTimes[0]
|
||||
max := sortedTimes[len(sortedTimes)-1]
|
||||
|
||||
var total time.Duration
|
||||
for _, t := range sortedTimes {
|
||||
total += t
|
||||
}
|
||||
avg := total / time.Duration(len(sortedTimes))
|
||||
|
||||
median := sortedTimes[len(sortedTimes)/2]
|
||||
p95 := sortedTimes[int(float64(len(sortedTimes))*0.95)]
|
||||
p99 := sortedTimes[int(float64(len(sortedTimes))*0.99)]
|
||||
|
||||
return map[string]interface{}{
|
||||
"count": len(sortedTimes),
|
||||
"min_ms": float64(min.Nanoseconds()) / 1e6,
|
||||
"max_ms": float64(max.Nanoseconds()) / 1e6,
|
||||
"avg_ms": float64(avg.Nanoseconds()) / 1e6,
|
||||
"average_ms": float64(avg.Nanoseconds()) / 1e6,
|
||||
"median_ms": float64(median.Nanoseconds()) / 1e6,
|
||||
"p95_ms": float64(p95.Nanoseconds()) / 1e6,
|
||||
"p99_ms": float64(p99.Nanoseconds()) / 1e6,
|
||||
}
|
||||
}
|
||||
|
||||
// ResourceMonitor tracks resource usage and limits
|
||||
type ResourceMonitor struct {
|
||||
// Memory limits
|
||||
maxMemoryBytes int64
|
||||
|
||||
// Cache limits
|
||||
maxCacheSize int64
|
||||
|
||||
// Session limits
|
||||
maxSessions int64
|
||||
|
||||
// Cache size tracking
|
||||
cacheSizes map[string]int64
|
||||
cacheMutex sync.RWMutex
|
||||
|
||||
// Monitoring state
|
||||
alertThresholds map[string]float64
|
||||
alerts []ResourceAlert
|
||||
alertsMutex sync.RWMutex
|
||||
|
||||
// Performance metrics reference
|
||||
perfMetrics *PerformanceMetrics
|
||||
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// ResourceAlert represents a resource usage alert
|
||||
type ResourceAlert struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
Threshold float64 `json:"threshold"`
|
||||
CurrentValue float64 `json:"current_value"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Severity string `json:"severity"`
|
||||
}
|
||||
|
||||
// NewResourceMonitor creates a new resource monitor
|
||||
func NewResourceMonitor(perfMetrics *PerformanceMetrics, logger *Logger) *ResourceMonitor {
|
||||
rm := &ResourceMonitor{
|
||||
maxMemoryBytes: 100 * 1024 * 1024, // 100MB default
|
||||
maxCacheSize: 10000, // 10k items default
|
||||
maxSessions: 1000, // 1k sessions default
|
||||
cacheSizes: make(map[string]int64),
|
||||
alertThresholds: map[string]float64{
|
||||
"memory_usage": 0.8, // 80%
|
||||
"memory_pressure": 0.7, // 70%
|
||||
"cache_usage": 0.9, // 90%
|
||||
"session_usage": 0.85, // 85%
|
||||
"error_rate": 0.1, // 10%
|
||||
},
|
||||
alerts: make([]ResourceAlert, 0),
|
||||
perfMetrics: perfMetrics,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Start monitoring routine
|
||||
go rm.startMonitoring()
|
||||
|
||||
return rm
|
||||
}
|
||||
|
||||
// SetMemoryLimit sets the maximum memory usage limit
|
||||
func (rm *ResourceMonitor) SetMemoryLimit(bytes int64) {
|
||||
rm.maxMemoryBytes = bytes
|
||||
}
|
||||
|
||||
// SetCacheLimit sets the maximum cache size limit
|
||||
func (rm *ResourceMonitor) SetCacheLimit(size int64) {
|
||||
rm.maxCacheSize = size
|
||||
}
|
||||
|
||||
// SetSessionLimit sets the maximum session count limit
|
||||
func (rm *ResourceMonitor) SetSessionLimit(count int64) {
|
||||
rm.maxSessions = count
|
||||
}
|
||||
|
||||
// UpdateCacheSize updates the size of a specific cache
|
||||
func (rm *ResourceMonitor) UpdateCacheSize(cacheName string, size int64) {
|
||||
rm.cacheMutex.Lock()
|
||||
defer rm.cacheMutex.Unlock()
|
||||
rm.cacheSizes[cacheName] = size
|
||||
}
|
||||
|
||||
// GetCacheSizes returns current cache sizes
|
||||
func (rm *ResourceMonitor) GetCacheSizes() map[string]int64 {
|
||||
rm.cacheMutex.RLock()
|
||||
defer rm.cacheMutex.RUnlock()
|
||||
|
||||
sizes := make(map[string]int64)
|
||||
for name, size := range rm.cacheSizes {
|
||||
sizes[name] = size
|
||||
}
|
||||
return sizes
|
||||
}
|
||||
|
||||
// startMonitoring starts the background monitoring routine
|
||||
func (rm *ResourceMonitor) startMonitoring() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rm.checkResourceUsage()
|
||||
}
|
||||
}
|
||||
|
||||
// checkResourceUsage checks current resource usage against limits
|
||||
func (rm *ResourceMonitor) checkResourceUsage() {
|
||||
metrics := rm.perfMetrics.GetMetrics()
|
||||
|
||||
// Check memory usage
|
||||
if memUsage, ok := metrics["memory_usage_bytes"].(int64); ok {
|
||||
memUsageRatio := float64(memUsage) / float64(rm.maxMemoryBytes)
|
||||
if memUsageRatio > rm.alertThresholds["memory_usage"] {
|
||||
rm.addAlert(ResourceAlert{
|
||||
Type: "memory_usage",
|
||||
Message: "Memory usage exceeds threshold",
|
||||
Threshold: rm.alertThresholds["memory_usage"],
|
||||
CurrentValue: memUsageRatio,
|
||||
Timestamp: time.Now(),
|
||||
Severity: rm.getSeverity(memUsageRatio, rm.alertThresholds["memory_usage"]),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Check memory pressure
|
||||
if memPressure, ok := metrics["memory_pressure"].(int64); ok {
|
||||
pressureRatio := float64(memPressure) / 100.0 // Convert to 0-1 scale
|
||||
if pressureRatio > rm.alertThresholds["memory_pressure"] {
|
||||
rm.addAlert(ResourceAlert{
|
||||
Type: "memory_pressure",
|
||||
Message: "Memory pressure exceeds threshold",
|
||||
Threshold: rm.alertThresholds["memory_pressure"],
|
||||
CurrentValue: pressureRatio,
|
||||
Timestamp: time.Now(),
|
||||
Severity: rm.getSeverity(pressureRatio, rm.alertThresholds["memory_pressure"]),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache usage
|
||||
if cacheSize, ok := metrics["cache_size"].(int64); ok {
|
||||
cacheUsageRatio := float64(cacheSize) / float64(rm.maxCacheSize)
|
||||
if cacheUsageRatio > rm.alertThresholds["cache_usage"] {
|
||||
rm.addAlert(ResourceAlert{
|
||||
Type: "cache_usage",
|
||||
Message: "Cache usage exceeds threshold",
|
||||
Threshold: rm.alertThresholds["cache_usage"],
|
||||
CurrentValue: cacheUsageRatio,
|
||||
Timestamp: time.Now(),
|
||||
Severity: rm.getSeverity(cacheUsageRatio, rm.alertThresholds["cache_usage"]),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Check session usage
|
||||
if activeSessions, ok := metrics["active_sessions"].(int64); ok {
|
||||
sessionUsageRatio := float64(activeSessions) / float64(rm.maxSessions)
|
||||
if sessionUsageRatio > rm.alertThresholds["session_usage"] {
|
||||
rm.addAlert(ResourceAlert{
|
||||
Type: "session_usage",
|
||||
Message: "Active session count exceeds threshold",
|
||||
Threshold: rm.alertThresholds["session_usage"],
|
||||
CurrentValue: sessionUsageRatio,
|
||||
Timestamp: time.Now(),
|
||||
Severity: rm.getSeverity(sessionUsageRatio, rm.alertThresholds["session_usage"]),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Check error rates
|
||||
if errorRate, ok := metrics["verification_error_rate"].(float64); ok {
|
||||
if errorRate > rm.alertThresholds["error_rate"] {
|
||||
rm.addAlert(ResourceAlert{
|
||||
Type: "verification_error_rate",
|
||||
Message: "Token verification error rate exceeds threshold",
|
||||
Threshold: rm.alertThresholds["error_rate"],
|
||||
CurrentValue: errorRate,
|
||||
Timestamp: time.Now(),
|
||||
Severity: rm.getSeverity(errorRate, rm.alertThresholds["error_rate"]),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getSeverity determines the severity level based on how much the threshold is exceeded
|
||||
func (rm *ResourceMonitor) getSeverity(currentValue, threshold float64) string {
|
||||
ratio := currentValue / threshold
|
||||
if ratio >= 1.5 {
|
||||
return "critical"
|
||||
} else if ratio >= 1.2 {
|
||||
return "high"
|
||||
} else if ratio >= 1.0 {
|
||||
return "medium"
|
||||
}
|
||||
return "low"
|
||||
}
|
||||
|
||||
// addAlert adds a new resource alert
|
||||
func (rm *ResourceMonitor) addAlert(alert ResourceAlert) {
|
||||
rm.alertsMutex.Lock()
|
||||
defer rm.alertsMutex.Unlock()
|
||||
|
||||
// Add alert
|
||||
rm.alerts = append(rm.alerts, alert)
|
||||
|
||||
// Keep only last 100 alerts
|
||||
if len(rm.alerts) > 100 {
|
||||
rm.alerts = rm.alerts[1:]
|
||||
}
|
||||
|
||||
// Log the alert
|
||||
rm.logger.Errorf("Resource Alert [%s/%s]: %s (%.2f%% > %.2f%%)",
|
||||
alert.Type, alert.Severity, alert.Message,
|
||||
alert.CurrentValue*100, alert.Threshold*100)
|
||||
}
|
||||
|
||||
// GetAlerts returns current resource alerts
|
||||
func (rm *ResourceMonitor) GetAlerts() []ResourceAlert {
|
||||
rm.alertsMutex.RLock()
|
||||
defer rm.alertsMutex.RUnlock()
|
||||
|
||||
alerts := make([]ResourceAlert, len(rm.alerts))
|
||||
copy(alerts, rm.alerts)
|
||||
return alerts
|
||||
}
|
||||
|
||||
// GetResourceStatus returns current resource status
|
||||
func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
|
||||
metrics := rm.perfMetrics.GetMetrics()
|
||||
cacheSizes := rm.GetCacheSizes()
|
||||
|
||||
status := map[string]interface{}{
|
||||
"limits": map[string]interface{}{
|
||||
"max_memory_bytes": rm.maxMemoryBytes,
|
||||
"max_cache_size": rm.maxCacheSize,
|
||||
"max_sessions": rm.maxSessions,
|
||||
},
|
||||
"thresholds": rm.alertThresholds,
|
||||
"current": metrics,
|
||||
"cache_sizes": cacheSizes,
|
||||
// Add expected keys for tests
|
||||
"memory_limit": uint64(rm.maxMemoryBytes),
|
||||
"cache_limit": int(rm.maxCacheSize),
|
||||
"session_limit": int(rm.maxSessions),
|
||||
}
|
||||
|
||||
// Calculate usage ratios
|
||||
if memUsage, ok := metrics["memory_usage_bytes"].(int64); ok {
|
||||
status["memory_usage_ratio"] = float64(memUsage) / float64(rm.maxMemoryBytes)
|
||||
}
|
||||
if memPressure, ok := metrics["memory_pressure"].(int64); ok {
|
||||
status["memory_pressure_ratio"] = float64(memPressure) / 100.0
|
||||
}
|
||||
if cacheSize, ok := metrics["cache_size"].(int64); ok {
|
||||
status["cache_usage_ratio"] = float64(cacheSize) / float64(rm.maxCacheSize)
|
||||
}
|
||||
if activeSessions, ok := metrics["active_sessions"].(int64); ok {
|
||||
status["session_usage_ratio"] = float64(activeSessions) / float64(rm.maxSessions)
|
||||
}
|
||||
|
||||
// Calculate total cache size across all caches
|
||||
var totalCacheSize int64
|
||||
for _, size := range cacheSizes {
|
||||
totalCacheSize += size
|
||||
}
|
||||
status["total_cache_size"] = totalCacheSize
|
||||
|
||||
return status
|
||||
}
|
||||
@@ -1,324 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPerformanceMetrics(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
metrics := NewPerformanceMetrics(logger)
|
||||
|
||||
t.Run("Record cache operations", func(t *testing.T) {
|
||||
metrics.RecordCacheHit()
|
||||
metrics.RecordCacheMiss()
|
||||
metrics.RecordCacheEviction()
|
||||
metrics.UpdateCacheSize(100)
|
||||
|
||||
result := metrics.GetMetrics()
|
||||
|
||||
if result["cache_hits"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 cache hit, got %v", result["cache_hits"])
|
||||
}
|
||||
if result["cache_misses"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 cache miss, got %v", result["cache_misses"])
|
||||
}
|
||||
if result["cache_evictions"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 cache eviction, got %v", result["cache_evictions"])
|
||||
}
|
||||
if result["cache_size"].(int64) != 100 {
|
||||
t.Errorf("Expected cache size 100, got %v", result["cache_size"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Record token operations", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
metrics.RecordTokenVerification(time.Since(start), true)
|
||||
|
||||
start = time.Now()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
metrics.RecordTokenValidation(time.Since(start), false)
|
||||
|
||||
start = time.Now()
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
metrics.RecordTokenRefresh(time.Since(start), true)
|
||||
|
||||
result := metrics.GetMetrics()
|
||||
|
||||
if result["token_verifications"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 token verification, got %v", result["token_verifications"])
|
||||
}
|
||||
if result["token_validations"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 token validation, got %v", result["token_validations"])
|
||||
}
|
||||
if result["token_refreshes"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 token refresh, got %v", result["token_refreshes"])
|
||||
}
|
||||
if result["successful_verifications"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 successful verification, got %v", result["successful_verifications"])
|
||||
}
|
||||
if result["failed_validations"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 failed validation, got %v", result["failed_validations"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Record rate limiting and sessions", func(t *testing.T) {
|
||||
metrics.RecordRateLimitedRequest()
|
||||
metrics.RecordSessionCreation()
|
||||
metrics.RecordSessionDeletion()
|
||||
|
||||
result := metrics.GetMetrics()
|
||||
|
||||
if result["rate_limited_requests"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 rate limited request, got %v", result["rate_limited_requests"])
|
||||
}
|
||||
if result["sessions_created"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 session created, got %v", result["sessions_created"])
|
||||
}
|
||||
if result["sessions_deleted"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 session deleted, got %v", result["sessions_deleted"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Get detailed timing metrics", func(t *testing.T) {
|
||||
// Add more timing data
|
||||
for i := 0; i < 5; i++ {
|
||||
metrics.RecordTokenVerification(time.Duration(i+1)*time.Millisecond, true)
|
||||
}
|
||||
|
||||
detailed := metrics.GetDetailedTimingMetrics()
|
||||
|
||||
if detailed["verification_stats"] == nil {
|
||||
t.Error("Expected verification stats to be present")
|
||||
}
|
||||
|
||||
verificationStats := detailed["verification_stats"].(map[string]interface{})
|
||||
if verificationStats["count"].(int) != 6 { // 1 from previous test + 5 new
|
||||
t.Errorf("Expected 6 verifications, got %v", verificationStats["count"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResourceMonitor(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
metrics := NewPerformanceMetrics(logger)
|
||||
monitor := NewResourceMonitor(metrics, logger)
|
||||
|
||||
t.Run("Set limits", func(t *testing.T) {
|
||||
monitor.SetMemoryLimit(100 * 1024 * 1024) // 100MB
|
||||
monitor.SetCacheLimit(1000)
|
||||
monitor.SetSessionLimit(500)
|
||||
|
||||
// Should not panic
|
||||
})
|
||||
|
||||
t.Run("Get resource status", func(t *testing.T) {
|
||||
status := monitor.GetResourceStatus()
|
||||
|
||||
if status["memory_limit"] == nil {
|
||||
t.Error("Expected memory limit to be set")
|
||||
}
|
||||
if status["cache_limit"] == nil {
|
||||
t.Error("Expected cache limit to be set")
|
||||
}
|
||||
if status["session_limit"] == nil {
|
||||
t.Error("Expected session limit to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Get alerts", func(t *testing.T) {
|
||||
alerts := monitor.GetAlerts()
|
||||
|
||||
// Should return empty slice initially
|
||||
if alerts == nil {
|
||||
t.Error("Expected alerts slice to be initialized")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPerformanceMetricsCalculations(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
metrics := NewPerformanceMetrics(logger)
|
||||
|
||||
t.Run("Average calculation", func(t *testing.T) {
|
||||
// Record multiple operations with known durations
|
||||
durations := []time.Duration{
|
||||
10 * time.Millisecond,
|
||||
20 * time.Millisecond,
|
||||
30 * time.Millisecond,
|
||||
}
|
||||
|
||||
for _, d := range durations {
|
||||
metrics.RecordTokenVerification(d, true)
|
||||
}
|
||||
|
||||
detailed := metrics.GetDetailedTimingMetrics()
|
||||
verificationStats := detailed["verification_stats"].(map[string]interface{})
|
||||
|
||||
// Average should be 20ms
|
||||
avgMs := verificationStats["average_ms"].(float64)
|
||||
if avgMs < 19 || avgMs > 21 { // Allow small variance
|
||||
t.Errorf("Expected average around 20ms, got %f", avgMs)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Min/Max calculation", func(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
metrics := NewPerformanceMetrics(logger) // Fresh instance
|
||||
|
||||
durations := []time.Duration{
|
||||
5 * time.Millisecond,
|
||||
50 * time.Millisecond,
|
||||
25 * time.Millisecond,
|
||||
}
|
||||
|
||||
for _, d := range durations {
|
||||
metrics.RecordTokenVerification(d, true)
|
||||
}
|
||||
|
||||
detailed := metrics.GetDetailedTimingMetrics()
|
||||
verificationStats := detailed["verification_stats"].(map[string]interface{})
|
||||
|
||||
minMs := verificationStats["min_ms"].(float64)
|
||||
maxMs := verificationStats["max_ms"].(float64)
|
||||
|
||||
if minMs < 4 || minMs > 6 {
|
||||
t.Errorf("Expected min around 5ms, got %f", minMs)
|
||||
}
|
||||
if maxMs < 49 || maxMs > 51 {
|
||||
t.Errorf("Expected max around 50ms, got %f", maxMs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPerformanceMetricsReset(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
metrics := NewPerformanceMetrics(logger)
|
||||
|
||||
// Record some data
|
||||
metrics.RecordCacheHit()
|
||||
metrics.RecordTokenVerification(10*time.Millisecond, true)
|
||||
|
||||
// Verify data is there
|
||||
result := metrics.GetMetrics()
|
||||
if result["cache_hits"].(int64) != 1 {
|
||||
t.Error("Expected cache hit to be recorded")
|
||||
}
|
||||
|
||||
// Note: The current implementation doesn't have a reset method,
|
||||
// but we can test that metrics accumulate correctly
|
||||
metrics.RecordCacheHit()
|
||||
result = metrics.GetMetrics()
|
||||
if result["cache_hits"].(int64) != 2 {
|
||||
t.Error("Expected cache hits to accumulate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformanceMetricsConcurrency(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
metrics := NewPerformanceMetrics(logger)
|
||||
|
||||
// Test concurrent access
|
||||
done := make(chan bool, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := 0; j < 100; j++ {
|
||||
metrics.RecordCacheHit()
|
||||
metrics.RecordTokenVerification(time.Millisecond, true)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
result := metrics.GetMetrics()
|
||||
|
||||
// Should have 1000 cache hits (10 goroutines * 100 operations)
|
||||
if result["cache_hits"].(int64) != 1000 {
|
||||
t.Errorf("Expected 1000 cache hits, got %v", result["cache_hits"])
|
||||
}
|
||||
|
||||
// Should have 1000 token verifications
|
||||
if result["token_verifications"].(int64) != 1000 {
|
||||
t.Errorf("Expected 1000 token verifications, got %v", result["token_verifications"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceMonitorLimits(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
metrics := NewPerformanceMetrics(logger)
|
||||
monitor := NewResourceMonitor(metrics, logger)
|
||||
|
||||
t.Run("Memory limit validation", func(t *testing.T) {
|
||||
// Set a reasonable memory limit
|
||||
monitor.SetMemoryLimit(50 * 1024 * 1024) // 50MB
|
||||
|
||||
status := monitor.GetResourceStatus()
|
||||
if status["memory_limit"].(uint64) != 50*1024*1024 {
|
||||
t.Error("Memory limit not set correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Cache limit validation", func(t *testing.T) {
|
||||
monitor.SetCacheLimit(2000)
|
||||
|
||||
status := monitor.GetResourceStatus()
|
||||
if status["cache_limit"].(int) != 2000 {
|
||||
t.Error("Cache limit not set correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Session limit validation", func(t *testing.T) {
|
||||
monitor.SetSessionLimit(1000)
|
||||
|
||||
status := monitor.GetResourceStatus()
|
||||
if status["session_limit"].(int) != 1000 {
|
||||
t.Error("Session limit not set correctly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPerformanceMetricsEdgeCases(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
metrics := NewPerformanceMetrics(logger)
|
||||
|
||||
t.Run("Zero duration handling", func(t *testing.T) {
|
||||
metrics.RecordTokenVerification(0, true)
|
||||
|
||||
result := metrics.GetMetrics()
|
||||
if result["token_verifications"].(int64) != 1 {
|
||||
t.Error("Should record verification even with zero duration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Very large duration handling", func(t *testing.T) {
|
||||
largeDuration := time.Hour
|
||||
metrics.RecordTokenVerification(largeDuration, true)
|
||||
|
||||
detailed := metrics.GetDetailedTimingMetrics()
|
||||
verificationStats := detailed["verification_stats"].(map[string]interface{})
|
||||
|
||||
// Should handle large durations without overflow
|
||||
if verificationStats["max_ms"].(float64) <= 0 {
|
||||
t.Error("Should handle large durations correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Negative cache size handling", func(t *testing.T) {
|
||||
// This shouldn't happen in practice, but test robustness
|
||||
metrics.UpdateCacheSize(-1)
|
||||
|
||||
result := metrics.GetMetrics()
|
||||
// Implementation should handle this gracefully
|
||||
if result["cache_size"] == nil {
|
||||
t.Error("Cache size should be present even if negative")
|
||||
}
|
||||
})
|
||||
}
|
||||
+24
-27
@@ -22,8 +22,8 @@ func TestConcurrentTokenVerification(t *testing.T) {
|
||||
|
||||
// Create multiple valid tokens to avoid replay detection
|
||||
tokens := make([]string, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
for i := range 10 {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -71,11 +71,11 @@ func TestConcurrentTokenVerification(t *testing.T) {
|
||||
var errorCount int64
|
||||
errors := make(chan error, numGoroutines*verificationsPerGoroutine)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < verificationsPerGoroutine; j++ {
|
||||
for j := range verificationsPerGoroutine {
|
||||
tokenIndex := (goroutineID*verificationsPerGoroutine + j) % len(tokens)
|
||||
err := tOidc.VerifyToken(tokens[tokenIndex])
|
||||
if err != nil {
|
||||
@@ -144,8 +144,8 @@ func TestCacheMemoryExhaustion(t *testing.T) {
|
||||
const numTokens = 500
|
||||
tokens := make([]string, numTokens)
|
||||
|
||||
for i := 0; i < numTokens; i++ {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
for i := range numTokens {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -161,7 +161,7 @@ func TestCacheMemoryExhaustion(t *testing.T) {
|
||||
tokens[i] = token
|
||||
|
||||
// Add to cache
|
||||
claims := map[string]interface{}{
|
||||
claims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -210,7 +210,7 @@ func TestSessionConcurrencyProtection(t *testing.T) {
|
||||
var successCount int64
|
||||
var errorCount int64
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
@@ -218,7 +218,7 @@ func TestSessionConcurrencyProtection(t *testing.T) {
|
||||
// Each goroutine gets its own request and session
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
for j := 0; j < operationsPerGoroutine; j++ {
|
||||
for j := range operationsPerGoroutine {
|
||||
// Get a fresh session for each operation
|
||||
s, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -229,7 +229,7 @@ func TestSessionConcurrencyProtection(t *testing.T) {
|
||||
// Perform operations on session
|
||||
s.SetEmail(fmt.Sprintf("user%d-%d@example.com", goroutineID, j))
|
||||
s.SetAuthenticated(true)
|
||||
s.SetAccessToken(fmt.Sprintf("token-%d-%d", goroutineID, j))
|
||||
s.SetAccessToken(ValidAccessToken)
|
||||
|
||||
// Save session
|
||||
testRR := httptest.NewRecorder()
|
||||
@@ -276,11 +276,11 @@ func TestParallelCacheOperations(t *testing.T) {
|
||||
var deleteCount int64
|
||||
|
||||
// Start multiple goroutines performing cache operations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < operationsPerGoroutine; j++ {
|
||||
for j := range operationsPerGoroutine {
|
||||
key := fmt.Sprintf("key-%d-%d", goroutineID, j)
|
||||
value := fmt.Sprintf("value-%d-%d", goroutineID, j)
|
||||
|
||||
@@ -377,7 +377,7 @@ func TestOversizedTokenHandling(t *testing.T) {
|
||||
|
||||
// Create an oversized token with large claims
|
||||
largeClaim := strings.Repeat("x", 10000) // 10KB claim
|
||||
oversizedClaims := map[string]interface{}{
|
||||
oversizedClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -411,7 +411,7 @@ func TestOversizedTokenHandling(t *testing.T) {
|
||||
|
||||
// Test extremely long token (beyond reasonable limits)
|
||||
extremelyLongClaim := strings.Repeat("y", 100000) // 100KB claim
|
||||
extremeClaims := map[string]interface{}{
|
||||
extremeClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -528,7 +528,7 @@ func TestMaliciousInputValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify the system is still functional after malicious input
|
||||
validToken, createErr := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validToken, createErr := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -609,7 +609,7 @@ func TestResourceLimits(t *testing.T) {
|
||||
defer cache.Close()
|
||||
|
||||
// Try to overwhelm the cache
|
||||
for i := 0; i < 1000; i++ {
|
||||
for i := range 1000 {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
value := fmt.Sprintf("value-%d", i)
|
||||
cache.Set(key, value, time.Minute)
|
||||
@@ -627,7 +627,7 @@ func TestResourceLimits(t *testing.T) {
|
||||
denied := 0
|
||||
|
||||
// Make many requests quickly
|
||||
for i := 0; i < 100; i++ {
|
||||
for range 100 {
|
||||
if limiter.Allow() {
|
||||
allowed++
|
||||
} else {
|
||||
@@ -651,14 +651,11 @@ func TestErrorRecoveryPatterns(t *testing.T) {
|
||||
|
||||
// Test recovery from cache corruption
|
||||
t.Run("CacheCorruption", func(t *testing.T) {
|
||||
// Corrupt the cache by setting invalid data
|
||||
ts.tOidc.tokenCache.cache.items["corrupted"] = CacheItem{
|
||||
Value: "invalid-data",
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
// Corrupt the cache by using the Set method to avoid data race
|
||||
ts.tOidc.tokenCache.cache.Set("corrupted", "invalid-data", time.Hour)
|
||||
|
||||
// System should handle corrupted cache gracefully
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -684,7 +681,7 @@ func TestErrorRecoveryPatterns(t *testing.T) {
|
||||
ts.tOidc.tokenBlacklist.Set("corrupted-entry", "invalid-data", time.Hour)
|
||||
|
||||
// System should still function
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -716,8 +713,8 @@ func TestPerformanceUnderLoad(t *testing.T) {
|
||||
// Create multiple valid tokens
|
||||
const numTokens = 100
|
||||
tokens := make([]string, numTokens)
|
||||
for i := 0; i < numTokens; i++ {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
for i := range numTokens {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -760,7 +757,7 @@ func TestPerformanceUnderLoad(t *testing.T) {
|
||||
const iterations = 1000
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
for i := range iterations {
|
||||
tokenIndex := i % numTokens
|
||||
err := tOidc.VerifyToken(tokens[tokenIndex])
|
||||
if err != nil {
|
||||
|
||||
+124
@@ -0,0 +1,124 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMergeScopes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
defaultScopes []string
|
||||
userScopes []string
|
||||
expectedScopes []string
|
||||
}{
|
||||
{
|
||||
name: "Empty user scopes",
|
||||
defaultScopes: []string{"openid", "profile", "email"},
|
||||
userScopes: []string{},
|
||||
expectedScopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "Non-overlapping scopes",
|
||||
defaultScopes: []string{"openid", "profile", "email"},
|
||||
userScopes: []string{"roles", "custom_scope"},
|
||||
expectedScopes: []string{"openid", "profile", "email", "roles", "custom_scope"},
|
||||
},
|
||||
{
|
||||
name: "Overlapping scopes",
|
||||
defaultScopes: []string{"openid", "profile", "email"},
|
||||
userScopes: []string{"openid", "roles", "profile", "permissions"},
|
||||
expectedScopes: []string{"openid", "profile", "email", "roles", "permissions"},
|
||||
},
|
||||
{
|
||||
name: "Nil user scopes",
|
||||
defaultScopes: []string{"openid", "profile", "email"},
|
||||
userScopes: nil,
|
||||
expectedScopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "Nil default scopes",
|
||||
defaultScopes: nil,
|
||||
userScopes: []string{"roles", "custom_scope"},
|
||||
expectedScopes: []string{"roles", "custom_scope"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := mergeScopes(tc.defaultScopes, tc.userScopes)
|
||||
if !reflect.DeepEqual(result, tc.expectedScopes) {
|
||||
t.Errorf("Expected %v, got %v", tc.expectedScopes, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopesConfiguration(t *testing.T) {
|
||||
defaultScopes := []string{"openid", "profile", "email"}
|
||||
userScopes := []string{"roles", "custom_scope"}
|
||||
|
||||
t.Run("Default Append Behavior", func(t *testing.T) {
|
||||
// Create config with user scopes but overrideScopes=false
|
||||
config := &Config{
|
||||
Scopes: userScopes,
|
||||
OverrideScopes: false,
|
||||
}
|
||||
|
||||
// Simulate middleware initialization
|
||||
var result []string
|
||||
if config.OverrideScopes {
|
||||
result = append([]string(nil), config.Scopes...)
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, config.Scopes)
|
||||
}
|
||||
|
||||
// Expect defaultScopes + userScopes with deduplication
|
||||
expectedScopes := []string{"openid", "profile", "email", "roles", "custom_scope"}
|
||||
if !reflect.DeepEqual(result, expectedScopes) {
|
||||
t.Errorf("Expected %v, got %v", expectedScopes, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Override Behavior", func(t *testing.T) {
|
||||
// Create config with user scopes and overrideScopes=true
|
||||
config := &Config{
|
||||
Scopes: userScopes,
|
||||
OverrideScopes: true,
|
||||
}
|
||||
|
||||
// Simulate middleware initialization
|
||||
var result []string
|
||||
if config.OverrideScopes {
|
||||
result = append([]string(nil), config.Scopes...)
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, config.Scopes)
|
||||
}
|
||||
|
||||
// Expect only userScopes
|
||||
if !reflect.DeepEqual(result, userScopes) {
|
||||
t.Errorf("Expected %v, got %v", userScopes, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty Scopes with Override", func(t *testing.T) {
|
||||
// Create config with empty scopes and overrideScopes=true
|
||||
config := &Config{
|
||||
Scopes: []string{},
|
||||
OverrideScopes: true,
|
||||
}
|
||||
|
||||
// Simulate middleware initialization
|
||||
var result []string
|
||||
if config.OverrideScopes {
|
||||
result = append([]string(nil), config.Scopes...)
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, config.Scopes)
|
||||
}
|
||||
|
||||
// Expect empty scopes - check length instead of DeepEqual
|
||||
if len(result) != 0 {
|
||||
t.Errorf("Expected empty slice, got %v with length %d", result, len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
+58
-36
@@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -25,7 +26,7 @@ func TestJWTAlgorithmConfusionAttack(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a standard JWT with RS256 algorithm
|
||||
validRS256JWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validRS256JWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -51,7 +52,7 @@ func TestJWTAlgorithmConfusionAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Parse header
|
||||
var header map[string]interface{}
|
||||
var header map[string]any
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
t.Fatalf("Failed to unmarshal header: %v", err)
|
||||
}
|
||||
@@ -90,7 +91,7 @@ func TestJWTNoneAlgorithmAttack(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a standard JWT
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -116,7 +117,7 @@ func TestJWTNoneAlgorithmAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Parse header
|
||||
var header map[string]interface{}
|
||||
var header map[string]any
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
t.Fatalf("Failed to unmarshal header: %v", err)
|
||||
}
|
||||
@@ -154,7 +155,7 @@ func TestJWTTokenTampering(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a standard JWT
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -180,7 +181,7 @@ func TestJWTTokenTampering(t *testing.T) {
|
||||
}
|
||||
|
||||
// Parse claims
|
||||
var claims map[string]interface{}
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(claimsBytes, &claims); err != nil {
|
||||
t.Fatalf("Failed to unmarshal claims: %v", err)
|
||||
}
|
||||
@@ -219,7 +220,7 @@ func TestJWTExpiredToken(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a JWT that is already expired
|
||||
expiredJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
expiredJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(-1 * time.Hour).Unix()), // Expired 1 hour ago
|
||||
@@ -252,7 +253,7 @@ func TestJWTFutureToken(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a JWT with a future issuance time
|
||||
futureJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
futureJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(2 * time.Hour).Unix()),
|
||||
@@ -315,7 +316,7 @@ func TestJWTReplayAttack(t *testing.T) {
|
||||
fixedJTI := "fixed-test-jti-for-replay-" + generateRandomString(8)
|
||||
|
||||
// Create a JWT with the fixed JTI
|
||||
replayJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
replayJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -389,8 +390,8 @@ func TestMissingClaims(t *testing.T) {
|
||||
// Test cases for missing claims
|
||||
testCases := []struct {
|
||||
name string
|
||||
omittedClaims []string
|
||||
expectedError string
|
||||
omittedClaims []string
|
||||
}{
|
||||
{
|
||||
name: "Missing Issuer",
|
||||
@@ -422,7 +423,7 @@ func TestMissingClaims(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create standard claims
|
||||
claims := map[string]interface{}{
|
||||
claims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -479,8 +480,8 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
// Set up the attacker's session with malicious data
|
||||
attackerSession.SetAuthenticated(true)
|
||||
attackerSession.SetEmail("attacker@evil.com")
|
||||
attackerSession.SetIDToken("fake-id-token")
|
||||
attackerSession.SetAccessToken("fake-access-token")
|
||||
attackerSession.SetIDToken(ValidIDToken)
|
||||
attackerSession.SetAccessToken(ValidAccessToken)
|
||||
|
||||
// Save the session to get cookies
|
||||
if err := attackerSession.Save(req, resp); err != nil {
|
||||
@@ -510,6 +511,31 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create keys for JWT verification
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
// Create JWK
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // 65537 in bytes
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
// Create mock JWK cache
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
// Create the TraefikOidc middleware
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
@@ -519,6 +545,8 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
tokenBlacklist: NewCache(),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
@@ -528,7 +556,13 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
|
||||
// Set up the token verifier and JWT verifier
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
close(tOidc.initComplete)
|
||||
|
||||
// Now create a victim's request with the attacker's cookies
|
||||
@@ -565,13 +599,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
// - The response is unauthorized (401), OR
|
||||
// - The token verification failed
|
||||
expectedCodes := []int{http.StatusFound, http.StatusUnauthorized, http.StatusForbidden}
|
||||
codeFound := false
|
||||
for _, code := range expectedCodes {
|
||||
if victimResp.Code == code {
|
||||
codeFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
codeFound := slices.Contains(expectedCodes, victimResp.Code)
|
||||
|
||||
if !codeFound {
|
||||
t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, victimResp.Code)
|
||||
@@ -796,7 +824,7 @@ func TestTokenBlacklisting(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a valid JWT
|
||||
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -886,7 +914,7 @@ func TestDifferentSigningAlgorithms(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Define standard claims with unique JTI for each test
|
||||
standardClaims := map[string]interface{}{
|
||||
standardClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -989,9 +1017,9 @@ func TestDifferentSigningAlgorithms(t *testing.T) {
|
||||
}
|
||||
|
||||
// createTestJWTWithECKey creates a JWT signed with an EC private key
|
||||
func createTestJWTWithECKey(privateKey *ecdsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
|
||||
func createTestJWTWithECKey(privateKey *ecdsa.PrivateKey, alg, kid string, claims map[string]any) (string, error) {
|
||||
// Create the header
|
||||
header := map[string]interface{}{
|
||||
header := map[string]any{
|
||||
"alg": alg,
|
||||
"typ": "JWT",
|
||||
"kid": kid,
|
||||
@@ -1248,7 +1276,7 @@ func TestRateLimiting(t *testing.T) {
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Create a valid JWT token
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -1268,7 +1296,7 @@ func TestRateLimiting(t *testing.T) {
|
||||
}
|
||||
|
||||
// Second request should succeed
|
||||
validJWT2, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validJWT2, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -1287,7 +1315,7 @@ func TestRateLimiting(t *testing.T) {
|
||||
}
|
||||
|
||||
// Third request should be rate limited
|
||||
validJWT3, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
validJWT3, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -1371,13 +1399,7 @@ func TestAuthorizationHeaderBypass(t *testing.T) {
|
||||
|
||||
// Verify that the response is a redirect to authentication (302) or unauthorized (401)
|
||||
expectedCodes := []int{http.StatusFound, http.StatusUnauthorized}
|
||||
codeFound := false
|
||||
for _, code := range expectedCodes {
|
||||
if resp.Code == code {
|
||||
codeFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
codeFound := slices.Contains(expectedCodes, resp.Code)
|
||||
|
||||
if !codeFound {
|
||||
t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, resp.Code)
|
||||
@@ -1390,7 +1412,7 @@ func TestEmptyAudience(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a JWT with empty audience
|
||||
emptyAudJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
emptyAudJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "", // Empty audience
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -1423,7 +1445,7 @@ func TestEmptyIssuer(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a JWT with empty issuer
|
||||
emptyIssJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
emptyIssJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "", // Empty issuer
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
|
||||
+170
-176
@@ -6,73 +6,96 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityEventType represents different types of security events
|
||||
type SecurityEventType string
|
||||
|
||||
const (
|
||||
// AuthFailure represents an authentication failure event
|
||||
AuthFailure SecurityEventType = "authentication_failure"
|
||||
// TokenValidFailure represents a token validation failure event
|
||||
TokenValidFailure SecurityEventType = "token_validation_failure"
|
||||
// RateLimitHit represents a rate limit hit event
|
||||
RateLimitHit SecurityEventType = "rate_limit_hit"
|
||||
// SuspiciousActivity represents a suspicious activity event
|
||||
SuspiciousActivity SecurityEventType = "suspicious_activity"
|
||||
)
|
||||
|
||||
// DefaultSeverity returns the default severity level for a security event type
|
||||
func (t SecurityEventType) DefaultSeverity() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "medium"
|
||||
case TokenValidFailure:
|
||||
return "medium"
|
||||
case RateLimitHit:
|
||||
return "low"
|
||||
case SuspiciousActivity:
|
||||
return "high"
|
||||
default:
|
||||
return "medium"
|
||||
}
|
||||
}
|
||||
|
||||
// IPFailureType returns the IP failure tracking type for a security event type
|
||||
func (t SecurityEventType) IPFailureType() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "auth_failure"
|
||||
case TokenValidFailure:
|
||||
return "token_failure"
|
||||
case SuspiciousActivity:
|
||||
return "suspicious"
|
||||
default:
|
||||
return "general"
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityEvent represents a security-related event that should be logged and monitored
|
||||
type SecurityEvent struct {
|
||||
Type string `json:"type"`
|
||||
Severity string `json:"severity"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
RequestPath string `json:"request_path"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Severity string `json:"severity"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
RequestPath string `json:"request_path"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// SecurityMonitor tracks security events and suspicious activity patterns
|
||||
type SecurityMonitor struct {
|
||||
// Event counters
|
||||
authFailures int64
|
||||
tokenValidationFails int64
|
||||
rateLimitHits int64
|
||||
suspiciousRequests int64
|
||||
|
||||
// IP-based tracking
|
||||
ipFailures map[string]*IPFailureTracker
|
||||
ipMutex sync.RWMutex
|
||||
|
||||
// Pattern detection
|
||||
ipFailures map[string]*IPFailureTracker
|
||||
patternDetector *SuspiciousPatternDetector
|
||||
|
||||
// Event handlers
|
||||
eventHandlers []SecurityEventHandler
|
||||
|
||||
// Configuration
|
||||
config SecurityMonitorConfig
|
||||
|
||||
// Logger
|
||||
logger *Logger
|
||||
logger *Logger
|
||||
eventHandlers []SecurityEventHandler
|
||||
config SecurityMonitorConfig
|
||||
ipMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// IPFailureTracker tracks failures for a specific IP address
|
||||
type IPFailureTracker struct {
|
||||
FailureCount int64
|
||||
LastFailure time.Time
|
||||
FirstFailure time.Time
|
||||
FailureTypes map[string]int64
|
||||
IsBlocked bool
|
||||
BlockedUntil time.Time
|
||||
FailureTypes map[string]int64
|
||||
FailureCount int64
|
||||
mutex sync.RWMutex
|
||||
IsBlocked bool
|
||||
}
|
||||
|
||||
// SuspiciousPatternDetector identifies patterns that may indicate attacks
|
||||
type SuspiciousPatternDetector struct {
|
||||
// Time-based windows for pattern detection
|
||||
shortWindow time.Duration // 1 minute
|
||||
mediumWindow time.Duration // 5 minutes
|
||||
longWindow time.Duration // 15 minutes
|
||||
|
||||
// Pattern thresholds
|
||||
rapidFailureThreshold int // failures in short window
|
||||
distributedAttackThreshold int // failures across IPs in medium window
|
||||
persistentAttackThreshold int // failures in long window
|
||||
|
||||
// Pattern tracking
|
||||
recentEvents []SecurityEvent
|
||||
eventsMutex sync.RWMutex
|
||||
recentEvents []SecurityEvent
|
||||
shortWindow time.Duration
|
||||
mediumWindow time.Duration
|
||||
longWindow time.Duration
|
||||
rapidFailureThreshold int
|
||||
distributedAttackThreshold int
|
||||
persistentAttackThreshold int
|
||||
eventsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// SecurityEventHandler defines the interface for handling security events
|
||||
@@ -82,22 +105,15 @@ type SecurityEventHandler interface {
|
||||
|
||||
// SecurityMonitorConfig contains configuration for the security monitor
|
||||
type SecurityMonitorConfig struct {
|
||||
// Failure thresholds
|
||||
MaxFailuresPerIP int `json:"max_failures_per_ip"`
|
||||
FailureWindowMinutes int `json:"failure_window_minutes"`
|
||||
BlockDurationMinutes int `json:"block_duration_minutes"`
|
||||
|
||||
// Pattern detection settings
|
||||
EnablePatternDetection bool `json:"enable_pattern_detection"`
|
||||
MaxFailuresPerIP int `json:"max_failures_per_ip"`
|
||||
FailureWindowMinutes int `json:"failure_window_minutes"`
|
||||
BlockDurationMinutes int `json:"block_duration_minutes"`
|
||||
RapidFailureThreshold int `json:"rapid_failure_threshold"`
|
||||
|
||||
// Monitoring settings
|
||||
EnableDetailedLogging bool `json:"enable_detailed_logging"`
|
||||
LogSuspiciousOnly bool `json:"log_suspicious_only"`
|
||||
|
||||
// Cleanup settings
|
||||
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
|
||||
RetentionHours int `json:"retention_hours"`
|
||||
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
|
||||
RetentionHours int `json:"retention_hours"`
|
||||
EnablePatternDetection bool `json:"enable_pattern_detection"`
|
||||
EnableDetailedLogging bool `json:"enable_detailed_logging"`
|
||||
LogSuspiciousOnly bool `json:"log_suspicious_only"`
|
||||
}
|
||||
|
||||
// DefaultSecurityMonitorConfig returns a default configuration
|
||||
@@ -115,6 +131,9 @@ func DefaultSecurityMonitorConfig() SecurityMonitorConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupTask holds the BackgroundTask for security cleanup
|
||||
var cleanupTask *BackgroundTask
|
||||
|
||||
// NewSecurityMonitor creates a new security monitor instance
|
||||
func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityMonitor {
|
||||
sm := &SecurityMonitor{
|
||||
@@ -126,7 +145,7 @@ func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityM
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
go sm.startCleanupRoutine()
|
||||
sm.startCleanupRoutine()
|
||||
|
||||
return sm
|
||||
}
|
||||
@@ -144,89 +163,106 @@ func NewSuspiciousPatternDetector() *SuspiciousPatternDetector {
|
||||
}
|
||||
}
|
||||
|
||||
// RecordAuthenticationFailure records an authentication failure event
|
||||
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) {
|
||||
atomic.AddInt64(&sm.authFailures, 1)
|
||||
// RecordSecurityEvent is a generic method to record any type of security event
|
||||
func (sm *SecurityMonitor) RecordSecurityEvent(
|
||||
eventType SecurityEventType,
|
||||
clientIP, userAgent, requestPath string,
|
||||
message string,
|
||||
details map[string]any,
|
||||
trackIPFailure bool) {
|
||||
|
||||
// Create event with default values for the event type
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
Severity: "medium",
|
||||
Type: string(eventType),
|
||||
Severity: eventType.DefaultSeverity(),
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
RequestPath: requestPath,
|
||||
Message: fmt.Sprintf("Authentication failed: %s", reason),
|
||||
Message: message,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
sm.recordIPFailure(clientIP, "auth_failure")
|
||||
// Track IP failures if requested
|
||||
if trackIPFailure {
|
||||
sm.recordIPFailure(clientIP, eventType.IPFailureType())
|
||||
}
|
||||
|
||||
// Process the event
|
||||
sm.processSecurityEvent(event)
|
||||
}
|
||||
|
||||
// RecordAuthenticationFailure records an authentication failure event
|
||||
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]any) {
|
||||
if details == nil {
|
||||
details = make(map[string]any)
|
||||
}
|
||||
details["reason"] = reason
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
AuthFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Authentication failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordTokenValidationFailure records a token validation failure
|
||||
func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) {
|
||||
atomic.AddInt64(&sm.tokenValidationFails, 1)
|
||||
|
||||
details := map[string]interface{}{
|
||||
details := map[string]any{
|
||||
"reason": reason,
|
||||
}
|
||||
if tokenPrefix != "" {
|
||||
details["token_prefix"] = tokenPrefix
|
||||
}
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: "token_validation_failure",
|
||||
Severity: "medium",
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
RequestPath: requestPath,
|
||||
Message: fmt.Sprintf("Token validation failed: %s", reason),
|
||||
Details: details,
|
||||
}
|
||||
|
||||
sm.recordIPFailure(clientIP, "token_failure")
|
||||
sm.processSecurityEvent(event)
|
||||
sm.RecordSecurityEvent(
|
||||
TokenValidFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Token validation failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordRateLimitHit records when rate limiting is triggered
|
||||
func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) {
|
||||
atomic.AddInt64(&sm.rateLimitHits, 1)
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: "rate_limit_hit",
|
||||
Severity: "low",
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
RequestPath: requestPath,
|
||||
Message: "Rate limit exceeded",
|
||||
Details: map[string]interface{}{
|
||||
"limit_type": "token_verification",
|
||||
},
|
||||
details := map[string]any{
|
||||
"limit_type": "token_verification",
|
||||
}
|
||||
|
||||
sm.recordIPFailure(clientIP, "rate_limit")
|
||||
sm.processSecurityEvent(event)
|
||||
sm.RecordSecurityEvent(
|
||||
RateLimitHit,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
"Rate limit exceeded",
|
||||
details,
|
||||
true, // Track IP failure for rate limiting
|
||||
)
|
||||
}
|
||||
|
||||
// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories
|
||||
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]interface{}) {
|
||||
atomic.AddInt64(&sm.suspiciousRequests, 1)
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: "suspicious_activity",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
RequestPath: requestPath,
|
||||
Message: fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
|
||||
Details: details,
|
||||
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]any) {
|
||||
if details == nil {
|
||||
details = make(map[string]any)
|
||||
}
|
||||
details["activity_type"] = activityType
|
||||
|
||||
sm.recordIPFailure(clientIP, "suspicious")
|
||||
sm.processSecurityEvent(event)
|
||||
sm.RecordSecurityEvent(
|
||||
SuspiciousActivity,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// recordIPFailure tracks failures for a specific IP address
|
||||
@@ -266,7 +302,7 @@ func (sm *SecurityMonitor) recordIPFailure(clientIP, failureType string) {
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
Message: fmt.Sprintf("IP blocked due to %d failures in %d minutes", tracker.FailureCount, sm.config.FailureWindowMinutes),
|
||||
Details: map[string]interface{}{
|
||||
Details: map[string]any{
|
||||
"failure_count": tracker.FailureCount,
|
||||
"failure_types": tracker.FailureTypes,
|
||||
"blocked_until": tracker.BlockedUntil,
|
||||
@@ -319,7 +355,7 @@ func (sm *SecurityMonitor) processSecurityEvent(event SecurityEvent) {
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
Message: fmt.Sprintf("Suspicious pattern detected: %s", pattern),
|
||||
Details: map[string]interface{}{
|
||||
Details: map[string]any{
|
||||
"pattern_type": pattern,
|
||||
"trigger_event": event,
|
||||
},
|
||||
@@ -351,30 +387,11 @@ func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) {
|
||||
sm.eventHandlers = append(sm.eventHandlers, handler)
|
||||
}
|
||||
|
||||
// GetSecurityMetrics returns current security metrics
|
||||
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]interface{} {
|
||||
sm.ipMutex.RLock()
|
||||
defer sm.ipMutex.RUnlock()
|
||||
|
||||
blockedIPs := 0
|
||||
totalTrackedIPs := len(sm.ipFailures)
|
||||
|
||||
for _, tracker := range sm.ipFailures {
|
||||
tracker.mutex.RLock()
|
||||
if tracker.IsBlocked && time.Now().Before(tracker.BlockedUntil) {
|
||||
blockedIPs++
|
||||
}
|
||||
tracker.mutex.RUnlock()
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"auth_failures": atomic.LoadInt64(&sm.authFailures),
|
||||
"token_validation_fails": atomic.LoadInt64(&sm.tokenValidationFails),
|
||||
"rate_limit_hits": atomic.LoadInt64(&sm.rateLimitHits),
|
||||
"suspicious_requests": atomic.LoadInt64(&sm.suspiciousRequests),
|
||||
"blocked_ips": blockedIPs,
|
||||
"tracked_ips": totalTrackedIPs,
|
||||
"uptime_hours": time.Since(time.Now().Add(-24 * time.Hour)).Hours(), // Placeholder
|
||||
// GetSecurityMetrics returns minimal security metrics
|
||||
// This is kept for API compatibility but doesn't collect actual metrics
|
||||
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]any {
|
||||
return map[string]any{
|
||||
"tracked_ips": 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -456,11 +473,20 @@ func (spd *SuspiciousPatternDetector) DetectSuspiciousPatterns() []string {
|
||||
|
||||
// startCleanupRoutine starts the background cleanup routine
|
||||
func (sm *SecurityMonitor) startCleanupRoutine() {
|
||||
ticker := time.NewTicker(time.Duration(sm.config.CleanupIntervalMinutes) * time.Minute)
|
||||
defer ticker.Stop()
|
||||
// Use BackgroundTask abstraction for consistent management
|
||||
cleanupTask = NewBackgroundTask(
|
||||
"security-monitor-cleanup",
|
||||
time.Duration(sm.config.CleanupIntervalMinutes)*time.Minute,
|
||||
sm.cleanup,
|
||||
sm.logger)
|
||||
cleanupTask.Start()
|
||||
}
|
||||
|
||||
for range ticker.C {
|
||||
sm.cleanup()
|
||||
// StopCleanupRoutine stops the background cleanup routine
|
||||
func (sm *SecurityMonitor) StopCleanupRoutine() {
|
||||
if cleanupTask != nil {
|
||||
cleanupTask.Stop()
|
||||
cleanupTask = nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -537,36 +563,4 @@ func (h *LoggingSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsSecurityEventHandler tracks security metrics
|
||||
type MetricsSecurityEventHandler struct {
|
||||
eventCounts map[string]int64
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMetricsSecurityEventHandler creates a new metrics event handler
|
||||
func NewMetricsSecurityEventHandler() *MetricsSecurityEventHandler {
|
||||
return &MetricsSecurityEventHandler{
|
||||
eventCounts: make(map[string]int64),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSecurityEvent implements SecurityEventHandler
|
||||
func (h *MetricsSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.eventCounts[event.Type]++
|
||||
h.eventCounts[fmt.Sprintf("%s_%s", event.Type, event.Severity)]++
|
||||
}
|
||||
|
||||
// GetMetrics returns the current metrics
|
||||
func (h *MetricsSecurityEventHandler) GetMetrics() map[string]int64 {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
metrics := make(map[string]int64)
|
||||
for k, v := range h.eventCounts {
|
||||
metrics[k] = v
|
||||
}
|
||||
return metrics
|
||||
}
|
||||
// Note: MetricsSecurityEventHandler has been removed as part of metrics cleanup
|
||||
|
||||
+14
-77
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -42,42 +43,19 @@ func TestSecurityMonitor(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Token validation failure", func(t *testing.T) {
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordTokenValidationFailure("192.168.1.3", "test-agent", "/api", "invalid token", "abc123")
|
||||
|
||||
metrics := monitor.GetSecurityMetrics()
|
||||
if metrics["token_validation_fails"].(int64) == 0 {
|
||||
t.Error("Expected token validation failures to be recorded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Rate limit hit", func(t *testing.T) {
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordRateLimitHit("192.168.1.4", "test-agent", "/api")
|
||||
|
||||
metrics := monitor.GetSecurityMetrics()
|
||||
if metrics["rate_limit_hits"].(int64) == 0 {
|
||||
t.Error("Expected rate limit hits to be recorded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Suspicious activity", func(t *testing.T) {
|
||||
details := map[string]interface{}{"pattern": "unusual"}
|
||||
details := map[string]any{"pattern": "unusual"}
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordSuspiciousActivity("192.168.1.5", "test-agent", "/admin", "unusual pattern", "high frequency requests", details)
|
||||
|
||||
metrics := monitor.GetSecurityMetrics()
|
||||
if metrics["suspicious_requests"].(int64) == 0 {
|
||||
t.Error("Expected suspicious activities to be recorded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Get security metrics", func(t *testing.T) {
|
||||
metrics := monitor.GetSecurityMetrics()
|
||||
|
||||
if metrics["auth_failures"].(int64) == 0 {
|
||||
t.Error("Expected some authentication failures")
|
||||
}
|
||||
if metrics["blocked_ips"] == nil {
|
||||
t.Error("Expected blocked IPs count to be present")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -86,7 +64,7 @@ func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
|
||||
t.Run("Add events and detect patterns", func(t *testing.T) {
|
||||
// Add multiple events from same IP
|
||||
for i := 0; i < 10; i++ {
|
||||
for range 10 {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1.100",
|
||||
@@ -97,13 +75,7 @@ func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := false
|
||||
for _, pattern := range patterns {
|
||||
if pattern == "rapid_failures_from_ip_192.168.1.100" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(patterns, "rapid_failures_from_ip_192.168.1.100")
|
||||
if !found {
|
||||
t.Error("Expected to detect rapid failure pattern")
|
||||
}
|
||||
@@ -111,7 +83,7 @@ func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
|
||||
t.Run("Detect distributed attack pattern", func(t *testing.T) {
|
||||
// Add failures from many different IPs
|
||||
for i := 0; i < 25; i++ {
|
||||
for i := range 25 {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1." + strconv.Itoa(100+i),
|
||||
@@ -122,13 +94,7 @@ func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := false
|
||||
for _, pattern := range patterns {
|
||||
if pattern == "distributed_attack_pattern" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
found := slices.Contains(patterns, "distributed_attack_pattern")
|
||||
if !found {
|
||||
t.Error("Expected to detect distributed attack pattern")
|
||||
}
|
||||
@@ -204,24 +170,7 @@ func TestSecurityEventHandlers(t *testing.T) {
|
||||
handler.HandleSecurityEvent(event)
|
||||
})
|
||||
|
||||
t.Run("Metrics security event handler", func(t *testing.T) {
|
||||
handler := NewMetricsSecurityEventHandler()
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1.1",
|
||||
Timestamp: time.Now(),
|
||||
Message: "Test failure",
|
||||
Severity: "medium",
|
||||
}
|
||||
|
||||
handler.HandleSecurityEvent(event)
|
||||
|
||||
metrics := handler.GetMetrics()
|
||||
if metrics["authentication_failure"] != 1 {
|
||||
t.Errorf("Expected 1 authentication failure, got %v", metrics["authentication_failure"])
|
||||
}
|
||||
})
|
||||
// Metrics security event handler test removed as part of metrics cleanup
|
||||
}
|
||||
|
||||
func TestSecurityMonitorEventHandlers(t *testing.T) {
|
||||
@@ -312,26 +261,14 @@ func TestSecurityEventTypes(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Test different event types
|
||||
// Test different event types - just verify they don't panic
|
||||
monitor.RecordAuthenticationFailure("192.168.1.200", "test-agent", "/login", "invalid password", nil)
|
||||
monitor.RecordTokenValidationFailure("192.168.1.200", "test-agent", "/api", "expired token", "abc123")
|
||||
monitor.RecordRateLimitHit("192.168.1.200", "test-agent", "/api")
|
||||
|
||||
details := map[string]interface{}{"pattern": "test"}
|
||||
details := map[string]any{"pattern": "test"}
|
||||
monitor.RecordSuspiciousActivity("192.168.1.200", "test-agent", "/admin", "unusual pattern", "multiple failed logins", details)
|
||||
|
||||
metrics := monitor.GetSecurityMetrics()
|
||||
|
||||
if metrics["auth_failures"].(int64) == 0 {
|
||||
t.Error("Expected authentication failures to be recorded")
|
||||
}
|
||||
if metrics["token_validation_fails"].(int64) == 0 {
|
||||
t.Error("Expected token validation failures to be recorded")
|
||||
}
|
||||
if metrics["rate_limit_hits"].(int64) == 0 {
|
||||
t.Error("Expected rate limit hits to be recorded")
|
||||
}
|
||||
if metrics["suspicious_requests"].(int64) == 0 {
|
||||
t.Error("Expected suspicious activities to be recorded")
|
||||
}
|
||||
// Just verify GetSecurityMetrics doesn't panic
|
||||
_ = monitor.GetSecurityMetrics()
|
||||
}
|
||||
|
||||
+1266
-232
File diff suppressed because it is too large
Load Diff
+473
-3
@@ -1,12 +1,16 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
func TestSessionPoolMemoryLeak(t *testing.T) {
|
||||
@@ -144,7 +148,7 @@ func getPooledObjects(sm *SessionManager) int {
|
||||
var objects []*SessionData
|
||||
maxAttempts := 100 // Safety limit to prevent infinite loops
|
||||
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
for range maxAttempts {
|
||||
obj := sm.sessionPool.Get()
|
||||
if obj == nil {
|
||||
break
|
||||
@@ -191,7 +195,7 @@ func TestSessionObjectTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create and discard 5 sessions
|
||||
for i := 0; i < 5; i++ {
|
||||
for range 5 {
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession failed: %v", err)
|
||||
@@ -218,4 +222,470 @@ func TestSessionObjectTracking(t *testing.T) {
|
||||
t.Log("Session pool handling verified")
|
||||
}
|
||||
|
||||
// This is intentionally left empty to remove unused code
|
||||
// TestTokenCompressionIntegrity tests that token compression and decompression maintains JWT integrity
|
||||
func TestTokenCompressionIntegrity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantFail bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JWT - Small",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.signature",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT - Large",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + strings.Repeat("eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9", 100) + ".signature",
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT - Wrong dot count",
|
||||
token: "invalid.token",
|
||||
wantFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT - No dots",
|
||||
token: "invalidtoken",
|
||||
wantFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT - Too many dots",
|
||||
token: "part1.part2.part3.part4",
|
||||
wantFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty token",
|
||||
token: "",
|
||||
wantFail: false, // Empty tokens are handled gracefully
|
||||
},
|
||||
{
|
||||
name: "Oversized token (>50KB)",
|
||||
token: "part1." + strings.Repeat("A", 51*1024) + ".part3",
|
||||
wantFail: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
compressed := compressToken(tt.token)
|
||||
|
||||
if tt.wantFail {
|
||||
// For invalid tokens, compression should return original
|
||||
if compressed != tt.token {
|
||||
t.Errorf("Expected compression to return original for invalid token, got different result")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For valid tokens, test round-trip integrity
|
||||
decompressed := decompressToken(compressed)
|
||||
if decompressed != tt.token {
|
||||
t.Errorf("Token integrity lost: original=%q, compressed=%q, decompressed=%q",
|
||||
tt.token, compressed, decompressed)
|
||||
}
|
||||
|
||||
// Test that decompression is idempotent
|
||||
decompressed2 := decompressToken(decompressed)
|
||||
if decompressed2 != tt.token {
|
||||
t.Errorf("Decompression not idempotent: %q != %q", decompressed2, tt.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenCompressionCorruptionDetection tests that gzip corruption is detected and handled
|
||||
func TestTokenCompressionCorruptionDetection(t *testing.T) {
|
||||
validJWT := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.signature"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
corruptedInput string
|
||||
expectOriginal bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid base64",
|
||||
corruptedInput: "!@#$%^&*()",
|
||||
expectOriginal: true,
|
||||
},
|
||||
{
|
||||
name: "Valid base64 but invalid gzip",
|
||||
corruptedInput: base64.StdEncoding.EncodeToString([]byte("not gzip data")),
|
||||
expectOriginal: true,
|
||||
},
|
||||
{
|
||||
name: "Truncated gzip data",
|
||||
corruptedInput: "H4sI", // Incomplete gzip header
|
||||
expectOriginal: true,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
corruptedInput: "",
|
||||
expectOriginal: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := decompressToken(tt.corruptedInput)
|
||||
if tt.expectOriginal && result != tt.corruptedInput {
|
||||
t.Errorf("Expected decompression to return original corrupted input, got: %q", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test that valid compression still works
|
||||
compressed := compressToken(validJWT)
|
||||
decompressed := decompressToken(compressed)
|
||||
if decompressed != validJWT {
|
||||
t.Errorf("Valid compression/decompression failed: %q != %q", decompressed, validJWT)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenChunkingIntegrity tests that large tokens are properly chunked and reassembled
|
||||
func TestTokenChunkingIntegrity(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create tokens of various sizes to test chunking
|
||||
testTokens := NewTestTokens()
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenSize int
|
||||
expectChunked bool
|
||||
}{
|
||||
{
|
||||
name: "Small token (no chunking)",
|
||||
tokenSize: 100,
|
||||
expectChunked: false,
|
||||
},
|
||||
{
|
||||
name: "Medium token (no chunking)",
|
||||
tokenSize: 800, // FIXED: Reduced further to account for new conservative chunk size (1200 bytes)
|
||||
expectChunked: false,
|
||||
},
|
||||
{
|
||||
name: "Large token (chunking required)",
|
||||
tokenSize: 5000,
|
||||
expectChunked: true,
|
||||
},
|
||||
{
|
||||
name: "Very large token (multiple chunks)",
|
||||
tokenSize: 10000,
|
||||
expectChunked: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// FIXED: Use incompressible tokens to ensure chunking occurs
|
||||
var token string
|
||||
if tt.expectChunked {
|
||||
token = testTokens.CreateIncompressibleToken(tt.tokenSize)
|
||||
} else {
|
||||
token = testTokens.CreateLargeValidJWT(tt.tokenSize)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Store the token
|
||||
session.SetAccessToken(token)
|
||||
|
||||
// Retrieve the token
|
||||
retrievedToken := session.GetAccessToken()
|
||||
|
||||
// Verify integrity
|
||||
if retrievedToken != token {
|
||||
t.Errorf("Token integrity lost:\nOriginal: %q\nRetrieved: %q", token, retrievedToken)
|
||||
}
|
||||
|
||||
// Check if chunking occurred as expected
|
||||
hasChunks := len(session.accessTokenChunks) > 0
|
||||
if tt.expectChunked != hasChunks {
|
||||
t.Errorf("Chunking expectation mismatch: expected chunked=%v, has chunks=%v", tt.expectChunked, hasChunks)
|
||||
}
|
||||
|
||||
session.ReturnToPool()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenChunkingCorruptionResistance tests handling of corrupted chunks
|
||||
func TestTokenChunkingCorruptionResistance(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a large token that will be chunked
|
||||
largeToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
|
||||
base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, `{"sub":"test","data":"%s"}`, strings.Repeat("A", 5000))) +
|
||||
".signature"
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Store the token (this should create chunks)
|
||||
session.SetAccessToken(largeToken)
|
||||
if len(session.accessTokenChunks) == 0 {
|
||||
t.Skip("Token was not chunked, skipping corruption test")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
corruptChunk func(chunks map[int]*sessions.Session)
|
||||
name string
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "Missing chunk in sequence",
|
||||
corruptChunk: func(chunks map[int]*sessions.Session) {
|
||||
// Remove a middle chunk
|
||||
if len(chunks) > 1 {
|
||||
delete(chunks, 1)
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Empty chunk data",
|
||||
corruptChunk: func(chunks map[int]*sessions.Session) {
|
||||
// Set first chunk to empty
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = ""
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Wrong data type in chunk",
|
||||
corruptChunk: func(chunks map[int]*sessions.Session) {
|
||||
// Set chunk data to wrong type
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = 123 // Should be string
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Oversized chunk",
|
||||
corruptChunk: func(chunks map[int]*sessions.Session) {
|
||||
// Set chunk to oversized data
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = strings.Repeat("A", maxCookieSize+200)
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Get a fresh session
|
||||
freshSession, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get fresh session: %v", err)
|
||||
}
|
||||
|
||||
// Store the token again
|
||||
freshSession.SetAccessToken(largeToken)
|
||||
|
||||
// Apply corruption
|
||||
tt.corruptChunk(freshSession.accessTokenChunks)
|
||||
|
||||
// Try to retrieve the token
|
||||
retrievedToken := freshSession.GetAccessToken()
|
||||
|
||||
if tt.expectEmpty {
|
||||
if retrievedToken != "" {
|
||||
t.Errorf("Expected empty token due to corruption, got: %q", retrievedToken)
|
||||
}
|
||||
} else {
|
||||
if retrievedToken != largeToken {
|
||||
t.Errorf("Expected original token despite corruption, got: %q", retrievedToken)
|
||||
}
|
||||
}
|
||||
|
||||
freshSession.ReturnToPool()
|
||||
})
|
||||
}
|
||||
|
||||
session.ReturnToPool()
|
||||
}
|
||||
|
||||
// TestTokenSizeLimits tests that token size limits are enforced
|
||||
func TestTokenSizeLimits(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
testTokens := NewTestTokens()
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenSize int
|
||||
expectStored bool
|
||||
}{
|
||||
{
|
||||
name: "Normal size token",
|
||||
tokenSize: 1000,
|
||||
expectStored: true,
|
||||
},
|
||||
{
|
||||
name: "Large but acceptable token",
|
||||
tokenSize: 30000, // FIXED: 30KB to ensure final size < 100KB limit
|
||||
expectStored: true,
|
||||
},
|
||||
{
|
||||
name: "Oversized token (>100KB)",
|
||||
tokenSize: 120000, // FIXED: 120KB to ensure rejection after compression
|
||||
expectStored: false, // Should be rejected
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// FIXED: Use proper token generation that accounts for base64 encoding
|
||||
var token string
|
||||
if tt.expectStored {
|
||||
token = testTokens.CreateLargeValidJWT(tt.tokenSize)
|
||||
} else {
|
||||
token = testTokens.CreateIncompressibleToken(tt.tokenSize)
|
||||
}
|
||||
|
||||
// Store the token
|
||||
session.SetAccessToken(token)
|
||||
|
||||
// Try to retrieve it
|
||||
retrievedToken := session.GetAccessToken()
|
||||
|
||||
if tt.expectStored {
|
||||
if retrievedToken != token {
|
||||
t.Errorf("Expected token to be stored and retrieved, but got different token")
|
||||
}
|
||||
} else {
|
||||
if retrievedToken == token {
|
||||
t.Errorf("Expected oversized token to be rejected, but it was stored")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentTokenOperations tests thread safety of token operations
|
||||
func TestConcurrentTokenOperations(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
const numGoroutines = 10
|
||||
const numOperations = 100
|
||||
|
||||
// Test concurrent access and refresh token operations
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := range numGoroutines {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := range numOperations {
|
||||
// Create unique tokens for each goroutine/operation
|
||||
accessToken := ValidAccessToken
|
||||
refreshToken := fmt.Sprintf("refresh_token_%d_%d", id, j)
|
||||
|
||||
// Concurrent operations
|
||||
session.SetAccessToken(accessToken)
|
||||
session.SetRefreshToken(refreshToken)
|
||||
|
||||
retrievedAccess := session.GetAccessToken()
|
||||
retrievedRefresh := session.GetRefreshToken()
|
||||
|
||||
// Verify tokens are still valid (should be one of the tokens set by any goroutine)
|
||||
if retrievedAccess != "" && strings.Count(retrievedAccess, ".") != 2 {
|
||||
t.Errorf("Retrieved access token has invalid format: %q", retrievedAccess)
|
||||
}
|
||||
if retrievedRefresh != "" && len(retrievedRefresh) < 10 {
|
||||
t.Errorf("Retrieved refresh token is too short: %q", retrievedRefresh)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for range numGoroutines {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionValidationAndCleanup tests session validation and orphan cleanup
|
||||
func TestSessionValidationAndCleanup(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set tokens that will create chunks
|
||||
largeToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
|
||||
base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat(`{"data":"large"}`, 500))) +
|
||||
".signature"
|
||||
|
||||
session.SetAccessToken(largeToken)
|
||||
session.SetRefreshToken("refresh_token_test")
|
||||
|
||||
// Save session to create cookies
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify chunks were created
|
||||
if len(session.accessTokenChunks) == 0 {
|
||||
t.Log("No chunks created, large token test may not be applicable")
|
||||
}
|
||||
|
||||
// Test cleanup by clearing session
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.Logf("Clear returned error (may be expected): %v", err)
|
||||
}
|
||||
|
||||
// Verify tokens are cleared
|
||||
if token := session.GetAccessToken(); token != "" {
|
||||
t.Errorf("Access token should be empty after clear, got: %q", token)
|
||||
}
|
||||
if token := session.GetRefreshToken(); token != "" {
|
||||
t.Errorf("Refresh token should be empty after clear, got: %q", token)
|
||||
}
|
||||
}
|
||||
|
||||
+40
-96
@@ -26,96 +26,28 @@ type TemplatedHeader struct {
|
||||
// It provides all necessary settings to configure OpenID Connect authentication
|
||||
// with various providers like Auth0, Logto, or any standard OIDC provider.
|
||||
type Config struct {
|
||||
// ProviderURL is the base URL of the OIDC provider (required)
|
||||
// Example: https://accounts.google.com
|
||||
ProviderURL string `json:"providerURL"`
|
||||
|
||||
// RevocationURL is the endpoint for revoking tokens (optional)
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
|
||||
// EnablePKCE enables Proof Key for Code Exchange (PKCE) for the authorization code flow (optional)
|
||||
// This enhances security but might not be supported by all OIDC providers
|
||||
// Default: false
|
||||
EnablePKCE bool `json:"enablePKCE"`
|
||||
|
||||
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
|
||||
// Example: /oauth2/callback
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
|
||||
// LogoutURL is the path for handling logout requests (optional)
|
||||
// If not provided, it will be set to CallbackURL + "/logout"
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
|
||||
// ClientID is the OAuth 2.0 client identifier (required)
|
||||
ClientID string `json:"clientID"`
|
||||
|
||||
// ClientSecret is the OAuth 2.0 client secret (required)
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
|
||||
// Scopes defines the OAuth 2.0 scopes to request (optional)
|
||||
// Defaults to ["openid", "profile", "email"] if not provided
|
||||
Scopes []string `json:"scopes"`
|
||||
|
||||
// LogLevel sets the logging verbosity (optional)
|
||||
// Valid values: "debug", "info", "error"
|
||||
// Default: "info"
|
||||
LogLevel string `json:"logLevel"`
|
||||
|
||||
// SessionEncryptionKey is used to encrypt session data (required)
|
||||
// Must be a secure random string
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
|
||||
// ForceHTTPS forces the use of HTTPS for all URLs (optional)
|
||||
// Default: false
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
|
||||
// RateLimit sets the maximum number of requests per second (optional)
|
||||
// Default: 100
|
||||
RateLimit int `json:"rateLimit"`
|
||||
|
||||
// ExcludedURLs lists paths that bypass authentication (optional)
|
||||
// Example: ["/health", "/metrics"]
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
|
||||
// AllowedUserDomains restricts access to specific email domains (optional)
|
||||
// Example: ["company.com", "subsidiary.com"]
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
|
||||
// AllowedUsers restricts access to specific email addresses (optional)
|
||||
// Example: ["user1@example.com", "user2@example.com"]
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
|
||||
// AllowedRolesAndGroups restricts access to users with specific roles or groups (optional)
|
||||
// Example: ["admin", "developer"]
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
|
||||
// OIDCEndSessionURL is the provider's end session endpoint (optional)
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
|
||||
// PostLogoutRedirectURI is the URL to redirect to after logout (optional)
|
||||
// Default: "/"
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
|
||||
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
|
||||
HTTPClient *http.Client
|
||||
|
||||
// RefreshGracePeriodSeconds defines how many seconds before a token expires
|
||||
// the plugin should attempt to refresh it proactively (optional)
|
||||
// Default: 60
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
// Headers defines custom HTTP headers to set with templated values (optional)
|
||||
// Values can reference tokens and claims using Go templates with the following variables:
|
||||
// - {{.AccessToken}} - The access token (ID token)
|
||||
// - {{.IdToken}} - Same as AccessToken (for consistency)
|
||||
// - {{.RefreshToken}} - The refresh token
|
||||
// - {{.Claims.email}} - Access token claims (use proper case for claim names)
|
||||
// Examples:
|
||||
//
|
||||
// [{Name: "X-Forwarded-Email", Value: "{{.Claims.email}}"}]
|
||||
// [{Name: "Authorization", Value: "Bearer {{.AccessToken}}"}]
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
HTTPClient *http.Client
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
EnablePKCE bool `json:"enablePKCE"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -156,6 +88,7 @@ func CreateConfig() *Config {
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
OverrideScopes: false, // Default to appending scopes, not overriding
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
}
|
||||
|
||||
@@ -481,7 +414,7 @@ func NewLogger(logLevel string) *Logger {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
func (l *Logger) Info(format string, args ...any) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -491,7 +424,7 @@ func (l *Logger) Info(format string, args ...interface{}) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
func (l *Logger) Debug(format string, args ...any) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -501,7 +434,7 @@ func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
func (l *Logger) Error(format string, args ...any) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -512,7 +445,7 @@ func (l *Logger) Error(format string, args ...interface{}) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
func (l *Logger) Infof(format string, args ...any) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -523,7 +456,7 @@ func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
func (l *Logger) Debugf(format string, args ...any) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -534,10 +467,21 @@ func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
func (l *Logger) Errorf(format string, args ...any) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// newNoOpLogger creates a silent logger that doesn't output anything.
|
||||
// This is useful for internal components that need a logger instance
|
||||
// but should not produce any output by default.
|
||||
func newNoOpLogger() *Logger {
|
||||
return &Logger{
|
||||
logError: log.New(io.Discard, "", 0),
|
||||
logInfo: log.New(io.Discard, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
}
|
||||
|
||||
// handleError logs an error message using the provided logger and sends an HTTP error
|
||||
// response to the client with the specified message and status code.
|
||||
//
|
||||
|
||||
+32
-10
@@ -7,6 +7,19 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Helper function to compare string slices
|
||||
func equalSlices(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestCreateConfig(t *testing.T) {
|
||||
t.Run("Default Values", func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
@@ -36,27 +49,36 @@ func TestCreateConfig(t *testing.T) {
|
||||
if !config.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to be true by default")
|
||||
}
|
||||
|
||||
// Check OverrideScopes default
|
||||
if config.OverrideScopes {
|
||||
t.Error("Expected OverrideScopes to be false by default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Custom Values Preserved", func(t *testing.T) {
|
||||
t.Run("Config Can Hold Custom Values", func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.Scopes = []string{"custom_scope"}
|
||||
config.LogLevel = "debug"
|
||||
config.RateLimit = 50
|
||||
config.ForceHTTPS = false
|
||||
config.OverrideScopes = true
|
||||
|
||||
// Verify custom values are not overwritten
|
||||
// Verify config struct can hold custom values
|
||||
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
|
||||
t.Error("Custom scopes were overwritten")
|
||||
t.Error("Config struct cannot hold custom scopes")
|
||||
}
|
||||
if config.LogLevel != "debug" {
|
||||
t.Error("Custom log level was overwritten")
|
||||
t.Error("Config struct cannot hold custom log level")
|
||||
}
|
||||
if config.RateLimit != 50 {
|
||||
t.Error("Custom rate limit was overwritten")
|
||||
t.Error("Config struct cannot hold custom rate limit")
|
||||
}
|
||||
if config.ForceHTTPS {
|
||||
t.Error("Custom ForceHTTPS value was overwritten")
|
||||
t.Error("Config struct cannot hold custom ForceHTTPS value")
|
||||
}
|
||||
if !config.OverrideScopes {
|
||||
t.Error("Config struct cannot hold custom OverrideScopes value")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -241,10 +263,10 @@ func TestLogger(t *testing.T) {
|
||||
var debugBuf, infoBuf, errorBuf bytes.Buffer
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logLevel string
|
||||
testFunc func(*Logger)
|
||||
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
|
||||
name string
|
||||
logLevel string
|
||||
}{
|
||||
{
|
||||
name: "Debug Level",
|
||||
@@ -392,9 +414,9 @@ func TestHandleError(t *testing.T) {
|
||||
|
||||
// Test helper types
|
||||
type testResponseRecorder struct {
|
||||
statusCode int
|
||||
body string
|
||||
headers map[string][]string
|
||||
body string
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) Header() http.Header {
|
||||
|
||||
@@ -11,15 +11,15 @@ func TestTemplateExecution(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
data map[string]any
|
||||
expectedValue string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "String Claim",
|
||||
templateText: "{{.Claims.email}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
@@ -29,8 +29,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Number Claim",
|
||||
templateText: "{{.Claims.age}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"age": 30,
|
||||
},
|
||||
},
|
||||
@@ -40,8 +40,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Boolean Claim",
|
||||
templateText: "{{.Claims.admin}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"admin": true,
|
||||
},
|
||||
},
|
||||
@@ -51,8 +51,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Array Claim",
|
||||
templateText: "{{index .Claims.roles 0}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"roles": []string{"admin", "user"},
|
||||
},
|
||||
},
|
||||
@@ -62,9 +62,9 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Nested Object Claim",
|
||||
templateText: "{{.Claims.user.name}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"user": map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"user": map[string]any{
|
||||
"name": "John Doe",
|
||||
},
|
||||
},
|
||||
@@ -75,7 +75,7 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Access Token",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
data: map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"AccessToken": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
},
|
||||
expectedValue: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
@@ -84,7 +84,7 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "ID Token",
|
||||
templateText: "{{.IdToken}}",
|
||||
data: map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"IdToken": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
},
|
||||
expectedValue: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
@@ -93,7 +93,7 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Refresh Token",
|
||||
templateText: "{{.RefreshToken}}",
|
||||
data: map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"RefreshToken": "refresh-token-value",
|
||||
},
|
||||
expectedValue: "refresh-token-value",
|
||||
@@ -102,8 +102,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Conditional Template",
|
||||
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"admin": true,
|
||||
},
|
||||
},
|
||||
@@ -113,8 +113,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Multiple Claims",
|
||||
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
"email": "john.doe@example.com",
|
||||
@@ -126,8 +126,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Missing Claim",
|
||||
templateText: "{{.Claims.missing}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{},
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{},
|
||||
},
|
||||
expectedValue: "<no value>",
|
||||
expectError: false, // Go templates don't error on missing values
|
||||
@@ -135,14 +135,67 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Invalid Template Syntax",
|
||||
templateText: "{{.Claims.email",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectError: true, // Parsing should fail
|
||||
},
|
||||
{
|
||||
name: "Custom Claims",
|
||||
templateText: "Role: {{.Claims.role}}, Department: {{.Claims.department}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"role": "admin",
|
||||
"department": "engineering",
|
||||
},
|
||||
},
|
||||
expectedValue: "Role: admin, Department: engineering",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nested Custom Claims",
|
||||
templateText: "Org: {{.Claims.metadata.organization}}, Team: {{.Claims.metadata.team}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"metadata": map[string]any{
|
||||
"organization": "company-name",
|
||||
"team": "platform",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "Org: company-name, Team: platform",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Email Claims",
|
||||
templateText: "Email: {{.Claims.email}}, Verified: {{.Claims.email_verified}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
},
|
||||
},
|
||||
expectedValue: "Email: user@example.com, Verified: true",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "User Identity Claims",
|
||||
templateText: "Name: {{.Claims.name}}, Subject: {{.Claims.sub}}, Username: {{.Claims.preferred_username}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"name": "John Doe",
|
||||
"sub": "user123",
|
||||
"preferred_username": "johndoe",
|
||||
},
|
||||
},
|
||||
expectedValue: "Name: John Doe, Subject: user123, Username: johndoe",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -176,46 +229,203 @@ func TestTemplateExecution(t *testing.T) {
|
||||
|
||||
// TestTemplateExecutionContext tests the specific template data context used in processAuthorizedRequest
|
||||
func TestTemplateExecutionContext(t *testing.T) {
|
||||
// Define a test struct that matches the one used in processAuthorizedRequest
|
||||
// Test cases for map-based template data, matching the new implementation
|
||||
mapTests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]any
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token distinction with map",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{},
|
||||
"RefreshToken": "refresh-token-value",
|
||||
},
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
},
|
||||
{
|
||||
name: "Combining tokens and claims with map",
|
||||
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token",
|
||||
"IdToken": "id-token",
|
||||
"Claims": map[string]any{
|
||||
"sub": "user123",
|
||||
},
|
||||
"RefreshToken": "refresh-token",
|
||||
},
|
||||
expectedValue: "User: user123 Token: access-token",
|
||||
},
|
||||
{
|
||||
name: "Authorization header with Bearer token",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "jwt-access-token",
|
||||
"IdToken": "id-token",
|
||||
"Claims": map[string]any{},
|
||||
},
|
||||
expectedValue: "Bearer jwt-access-token",
|
||||
},
|
||||
{
|
||||
name: "Boolean template data with AccessToken",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": true, // Test boolean values to ensure they render correctly
|
||||
},
|
||||
expectedValue: "Bearer true",
|
||||
},
|
||||
{
|
||||
name: "Custom non-standard claims in ID token",
|
||||
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"role": "admin",
|
||||
"permissions": "read:all,write:own",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
|
||||
},
|
||||
{
|
||||
name: "Deeply nested custom claims",
|
||||
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"Claims": map[string]any{
|
||||
"app_metadata": map[string]any{
|
||||
"organization": map[string]any{
|
||||
"name": "acme-corp",
|
||||
"id": "org-123",
|
||||
},
|
||||
"team": "platform",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Organization: acme-corp, X-Team: platform",
|
||||
},
|
||||
{
|
||||
name: "Email in claims",
|
||||
templateText: "X-User-Email: {{.Claims.email}}, X-Email-Verified: {{.Claims.email_verified}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-Email: user@example.com, X-Email-Verified: true",
|
||||
},
|
||||
{
|
||||
name: "User info from claims",
|
||||
templateText: "X-User-ID: {{.Claims.sub}}, X-User-Name: {{.Claims.name}}, X-Username: {{.Claims.preferred_username}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{
|
||||
"sub": "user123456",
|
||||
"name": "Jane Doe",
|
||||
"preferred_username": "jane.doe",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-ID: user123456, X-User-Name: Jane Doe, X-Username: jane.doe",
|
||||
},
|
||||
}
|
||||
|
||||
// Run map-based tests (matching the new implementation)
|
||||
for _, tc := range mapTests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
|
||||
result := buf.String()
|
||||
if result != tc.expectedValue {
|
||||
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// For backward compatibility, also test the original struct-based implementation
|
||||
type templateData struct {
|
||||
Claims map[string]any
|
||||
AccessToken string
|
||||
IdToken string
|
||||
RefreshToken string
|
||||
Claims map[string]interface{}
|
||||
}
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
// Test cases for struct-based template data (original implementation)
|
||||
structTests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data templateData
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token distinction",
|
||||
name: "Access and ID token distinction with struct",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token-value",
|
||||
IdToken: "id-token-value", // Now these should be distinct values
|
||||
Claims: map[string]interface{}{},
|
||||
Claims: map[string]any{},
|
||||
},
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
},
|
||||
{
|
||||
name: "Combining tokens and claims",
|
||||
name: "Combining tokens and claims with struct",
|
||||
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "access-token",
|
||||
Claims: map[string]interface{}{
|
||||
Claims: map[string]any{
|
||||
"sub": "user123",
|
||||
},
|
||||
},
|
||||
expectedValue: "User: user123 Token: access-token",
|
||||
},
|
||||
{
|
||||
name: "Custom claims with struct",
|
||||
templateText: "X-Custom: {{.Claims.custom_field}}, X-Group: {{.Claims.group}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "id-token",
|
||||
Claims: map[string]any{
|
||||
"sub": "user123",
|
||||
"custom_field": "custom-value",
|
||||
"group": "admins",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Custom: custom-value, X-Group: admins",
|
||||
},
|
||||
{
|
||||
name: "Email claim in struct context",
|
||||
templateText: "X-Email: {{.Claims.email}}, X-Name: {{.Claims.name}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "id-token",
|
||||
Claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "John Smith",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Email: user@example.com, X-Name: John Smith",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
for _, tc := range structTests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
if err != nil {
|
||||
@@ -235,3 +445,162 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegressionBooleanAccessToken specifically tests the regression case where
|
||||
// a boolean value was causing "can't evaluate field AccessToken in type bool" error
|
||||
func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
// Test the specific case where we execute a template referencing AccessToken
|
||||
// using a boolean context value
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
dataContext any
|
||||
expectedValue string
|
||||
expectError bool // Added to skip the test that demonstrates the error
|
||||
}{
|
||||
{
|
||||
name: "Map with boolean as root",
|
||||
templateText: "{{.AccessToken}}",
|
||||
dataContext: map[string]any{"AccessToken": "token-value"},
|
||||
expectedValue: "token-value",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Boolean as root context",
|
||||
templateText: "{{.AccessToken}}",
|
||||
dataContext: true,
|
||||
expectedValue: "<no value>",
|
||||
expectError: true, // Skip this test as it demonstrates the error we're fixing
|
||||
},
|
||||
{
|
||||
name: "Bearer with map context",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
dataContext: map[string]any{"AccessToken": "token-value"},
|
||||
expectedValue: "Bearer token-value",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Complex nesting with authorization",
|
||||
templateText: "Authorization: Bearer {{.AccessToken}}",
|
||||
dataContext: map[string]any{
|
||||
"AccessToken": "jwt-token-123",
|
||||
"something": true,
|
||||
"anotherField": map[string]any{
|
||||
"nested": "value",
|
||||
},
|
||||
},
|
||||
expectedValue: "Authorization: Bearer jwt-token-123",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Custom claims access",
|
||||
templateText: "X-User-Role: {{.Claims.role}}, X-User-Groups: {{.Claims.groups}}",
|
||||
dataContext: map[string]any{
|
||||
"AccessToken": "jwt-token-xyz",
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"role": "admin",
|
||||
"groups": "group1,group2,group3",
|
||||
"custom_data": map[string]any{
|
||||
"organization": "company-name",
|
||||
"department": "engineering",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-Role: admin, X-User-Groups: group1,group2,group3",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nested custom claims access",
|
||||
templateText: "X-Organization: {{.Claims.custom_data.organization}}, X-Department: {{.Claims.custom_data.department}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"custom_data": map[string]any{
|
||||
"organization": "company-name",
|
||||
"department": "engineering",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Organization: company-name, X-Department: engineering",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Azure AD specific claims",
|
||||
templateText: "X-TenantID: {{.Claims.tid}}, X-Roles: {{.Claims.roles}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"tid": "tenant-id-12345",
|
||||
"roles": "User,Admin,Developer",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-TenantID: tenant-id-12345, X-Roles: User,Admin,Developer",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Auth0 specific claims",
|
||||
templateText: "X-Permissions: {{.Claims.permissions}}, X-AppMetadata: {{.Claims.app_metadata.plan}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"permissions": "read:products,write:orders",
|
||||
"app_metadata": map[string]any{
|
||||
"plan": "premium",
|
||||
"status": "active",
|
||||
"trial_ended": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Permissions: read:products,write:orders, X-AppMetadata: premium",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Standard claims with email",
|
||||
templateText: "X-Email: {{.Claims.email}}, X-Name: {{.Claims.name}}, X-Subject: {{.Claims.sub}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "John Doe",
|
||||
"sub": "auth0|12345",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Email: user@example.com, X-Name: John Doe, X-Subject: auth0|12345",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Verified email claim",
|
||||
templateText: "X-Email: {{.Claims.email}}, X-Email-Verified: {{.Claims.email_verified}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Email: user@example.com, X-Email-Verified: true",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
// Skip tests that demonstrate the error
|
||||
if tc.expectError {
|
||||
t.Skip("Skipping test that demonstrates the error we're fixing")
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.dataContext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
|
||||
result := buf.String()
|
||||
if result != tc.expectedValue {
|
||||
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -19,19 +20,19 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
sessionSetup func(*SessionData)
|
||||
claims map[string]interface{}
|
||||
claims map[string]any
|
||||
expectedHeaders map[string]string
|
||||
interceptedHeaders map[string]string
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
}{
|
||||
{
|
||||
name: "Basic Email Header",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -45,7 +46,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
{Name: "X-User-Name", Value: "{{.Claims.given_name}} {{.Claims.family_name}}"},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"sub": "user123",
|
||||
"given_name": "John",
|
||||
@@ -94,7 +95,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Role", Value: "{{.Claims.role}}"},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
// role claim is missing
|
||||
},
|
||||
@@ -107,7 +108,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Admin", Value: "{{if .Claims.is_admin}}true{{else}}false{{end}}"},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
claims: map[string]any{
|
||||
"email": "admin@example.com",
|
||||
"is_admin": true,
|
||||
},
|
||||
@@ -120,7 +121,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Auth-Info", Value: "User={{.Claims.email}}, Token={{.AccessToken}}"},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -133,7 +134,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-AccessToken", Value: "{{.AccessToken}}"},
|
||||
},
|
||||
claims: map[string]interface{}{ // For ID Token
|
||||
claims: map[string]any{ // For ID Token
|
||||
"email": "opaque_user@example.com",
|
||||
"sub": "opaque_sub_for_id_token",
|
||||
},
|
||||
@@ -149,7 +150,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
token := ts.token
|
||||
if len(tc.claims) > 0 {
|
||||
var err error
|
||||
baseClaims := map[string]interface{}{
|
||||
baseClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000), // Far future timestamp
|
||||
@@ -161,9 +162,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add the test-specific claims
|
||||
for k, v := range tc.claims {
|
||||
baseClaims[k] = v
|
||||
}
|
||||
maps.Copy(baseClaims, tc.claims)
|
||||
|
||||
token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", baseClaims)
|
||||
if err != nil {
|
||||
@@ -267,7 +266,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
|
||||
if tc.name == "ID Token Header" || tc.name == "Both Token Types" {
|
||||
idTokenClaims := map[string]interface{}{
|
||||
idTokenClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": float64(3000000000),
|
||||
"iat": float64(1000000000), "nbf": float64(1000000000), "sub": "test-subject",
|
||||
"nonce": "test-nonce", "jti": generateRandomString(16), "type": "id_token",
|
||||
@@ -285,7 +284,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
t.Fatalf("Failed to create test ID JWT: %v", idErr)
|
||||
}
|
||||
|
||||
accessTokenClaims := map[string]interface{}{
|
||||
accessTokenClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": float64(3000000000),
|
||||
"iat": float64(1000000000), "nbf": float64(1000000000), "sub": "test-subject",
|
||||
"jti": generateRandomString(16), "type": "access_token", "scope": "openid email profile",
|
||||
@@ -316,15 +315,13 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
tc.expectedHeaders["X-Access-Token"] = accessTokenForSession
|
||||
}
|
||||
} else if tc.name == "Opaque Access Token with AccessTokenField" {
|
||||
idTokenClaims := map[string]interface{}{
|
||||
idTokenClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": float64(3000000000),
|
||||
"iat": float64(1000000000), "nbf": float64(1000000000), "sub": "test-subject", // Default sub
|
||||
"nonce": "test-nonce", "jti": generateRandomString(16), "type": "id_token",
|
||||
}
|
||||
// Populate ID token claims from tc.claims
|
||||
for k, v := range tc.claims {
|
||||
idTokenClaims[k] = v
|
||||
}
|
||||
maps.Copy(idTokenClaims, tc.claims)
|
||||
// Ensure email from tc.claims is used for the ID token
|
||||
session.SetEmail(tc.claims["email"].(string)) // Also set it directly for initial session state
|
||||
|
||||
@@ -426,9 +423,9 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
claims map[string]interface{}
|
||||
shouldExecuteCheck bool
|
||||
}{
|
||||
{
|
||||
@@ -447,8 +444,8 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Roles", Value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"roles": []interface{}{"admin", "user", "manager"},
|
||||
claims: map[string]any{
|
||||
"roles": []any{"admin", "user", "manager"},
|
||||
},
|
||||
shouldExecuteCheck: true,
|
||||
},
|
||||
@@ -457,7 +454,7 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create token with the test claims
|
||||
claims := map[string]interface{}{
|
||||
claims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000), // Far future timestamp
|
||||
@@ -469,9 +466,7 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add the test-specific claims
|
||||
for k, v := range tc.claims {
|
||||
claims[k] = v
|
||||
}
|
||||
maps.Copy(claims, tc.claims)
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
@@ -576,7 +571,7 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
// createLargeTemplate creates a template with many variable references
|
||||
func createLargeTemplate(size int) string {
|
||||
template := "{{with .Claims}}"
|
||||
for i := 0; i < size; i++ {
|
||||
for i := range size {
|
||||
if i > 0 {
|
||||
template += ","
|
||||
}
|
||||
@@ -587,9 +582,10 @@ func createLargeTemplate(size int) string {
|
||||
}
|
||||
|
||||
// createLargeClaims creates a map with many claims for testing large templates
|
||||
func createLargeClaims(size int) map[string]interface{} {
|
||||
claims := make(map[string]interface{})
|
||||
for i := 0; i < size; i++ {
|
||||
func createLargeClaims(size int) map[string]any {
|
||||
claims := make(map[string]any)
|
||||
for i := range size {
|
||||
claims["email"] = "largeclaimsuser@example.com" // Add email claim
|
||||
key := "field" + string(rune('a'+i%26)) + string(rune('0'+i%10))
|
||||
claims[key] = "value" + string(rune('a'+i%26)) + string(rune('0'+i%10))
|
||||
}
|
||||
|
||||
+391
@@ -0,0 +1,391 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestTokens provides a comprehensive set of standardized test tokens
|
||||
// for consistent testing across the entire codebase.
|
||||
type TestTokens struct{}
|
||||
|
||||
// NewTestTokens creates a new TestTokens instance
|
||||
func NewTestTokens() *TestTokens {
|
||||
return &TestTokens{}
|
||||
}
|
||||
|
||||
// Valid JWT tokens for testing
|
||||
const (
|
||||
// ValidAccessToken - A properly formatted JWT access token for testing
|
||||
ValidAccessToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTc1MDI5NDYyOCwiaWF0IjoxNzUwMjkxMDI4LCJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJqdGkiOiJlNDcxN2RhZDBmZjAyOTNkIiwibmJmIjoxNzUwMjkxMDI4LCJub25jZSI6Im5vbmNlMTIzIiwic3ViIjoidGVzdC1zdWJqZWN0In0.bmwp-vk0B7Ir9UiUkzib8L7yJbebJ00o3U9QrB6gP2H9-RfqyCbN8M9Rkx7Rb8Vdh3YzqkBBoLS_G0i414rs2I9uABnTC4E6-63qkGdUrLB7p-XbjcRW2RoIBwXHk7lfumi8eX0uWzBsJ9CY0__UECVsex5XORfBb4Bcqj0LK4y-glxkpI51I7BPySfciWC_PkdaQ1Qe5pCAlxeNs2E9NMGXp-Ox6vAufUzoC2cws1LswGPPP6icQ-Zlzd5WMCIWhdIkN4yTxk8FMqsTC52k2zskRHNSSd4DDVETonfzawZNqDcMpnTyN53sCJ9UHiQTl9mCm61ttYW-W9Gc-ze4Xw"
|
||||
|
||||
// ValidIDToken - A properly formatted JWT ID token for testing
|
||||
ValidIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTc1MDI5NDYyOCwiaWF0IjoxNzUwMjkxMDI4LCJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJqdGkiOiI2YzBjZTZmMTM4Y2EzMzc2IiwibmJmIjoxNzUwMjkxMDI4LCJub25jZSI6Im5vbmNlMTIzIiwic3ViIjoidGVzdC1zdWJqZWN0In0.RBQYejA9vP4lnh2EhFqWerePWaCyDTF0ZE1jlU2xm4g2wWVeaEHpv5SNg92_gwk633N9xx7ugS0UrlEu4qbT7wSb1HBDR00q_andyYnyFk4OoxPpD0AqHkVr-pjS-Z7UCGF3sLgQ4ECmU9695PIys3XvgUGMzEn_mK-PHcpY5AnbBGFsbj7epUld_sb6WfjjjwAa8kKfKObPvaIpuJ4TlxI1Uf0wYOoIA0zh5ipeAn-i8Ud-GErxis1Hp8UQK7IRolXpToiXnFcnf3vI3eCS7Yu3oPl7LRxTxKMCI9h0MCwu25ZNsOg2C9ohyebpU0jbURX9Q74GNOaphv-Lz9rCRA"
|
||||
|
||||
// ValidRefreshToken - A properly formatted refresh token for testing
|
||||
ValidRefreshToken = "valid-refresh-token-12345"
|
||||
|
||||
// MinimalValidJWT - The shortest valid JWT for testing
|
||||
MinimalValidJWT = "h.p.s"
|
||||
|
||||
// ValidRefreshTokenGoogle - A Google-style refresh token for testing
|
||||
ValidRefreshTokenGoogle = "google_refresh_token_12345"
|
||||
)
|
||||
|
||||
// Invalid tokens for testing validation
|
||||
const (
|
||||
// InvalidTokenNoDots - Token with no dots (invalid JWT format)
|
||||
InvalidTokenNoDots = "notajwttoken"
|
||||
|
||||
// InvalidTokenOneDot - Token with one dot (invalid JWT format)
|
||||
InvalidTokenOneDot = "header.payload"
|
||||
|
||||
// InvalidTokenThreeDots - Token with three dots (invalid JWT format)
|
||||
InvalidTokenThreeDots = "header.payload.signature.extra"
|
||||
|
||||
// EmptyToken - Empty token
|
||||
EmptyToken = ""
|
||||
|
||||
// CorruptedBase64Token - Token with invalid base64 data for chunking tests
|
||||
CorruptedBase64Token = "corrupted_base64_!@#$"
|
||||
)
|
||||
|
||||
// CreateLargeValidJWT creates a JWT of approximately the specified size
|
||||
// This replaces the ad-hoc createLargeValidJWT function in tests
|
||||
func (tt *TestTokens) CreateLargeValidJWT(targetSize int) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
signature := "signature_" + tt.generateRandomString(32)
|
||||
|
||||
// Calculate required payload size
|
||||
usedSize := len(header) + len(signature) + 2 // account for dots
|
||||
payloadSize := max(targetSize-usedSize, 50)
|
||||
|
||||
// Create a payload with realistic JWT claims
|
||||
claims := map[string]any{
|
||||
"sub": "user123",
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
"exp": 9999999999,
|
||||
"iat": 1000000000,
|
||||
}
|
||||
|
||||
// FIXED: Calculate data size safely
|
||||
dataSize := max(
|
||||
// Account for other claims and base64 encoding
|
||||
payloadSize-100,
|
||||
// Minimum data size
|
||||
10)
|
||||
|
||||
claims["data"] = tt.generateRandomString(dataSize)
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
// CreateLargeRefreshToken creates a refresh token of approximately the specified size
|
||||
func (tt *TestTokens) CreateLargeRefreshToken(targetSize int) string {
|
||||
baseToken := "refresh_token_"
|
||||
padding := tt.generateRandomString(targetSize - len(baseToken))
|
||||
return baseToken + padding
|
||||
}
|
||||
|
||||
// CreateExpiredJWT creates an expired JWT token for testing
|
||||
func (tt *TestTokens) CreateExpiredJWT() string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
|
||||
// Create claims with expired timestamp
|
||||
claims := map[string]any{
|
||||
"sub": "user123",
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
"exp": time.Now().Unix() - 3600, // Expired 1 hour ago
|
||||
"iat": time.Now().Unix() - 7200, // Issued 2 hours ago
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signature := "expired_signature"
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
// CreateUniqueValidJWT creates a unique valid JWT for concurrent testing
|
||||
func (tt *TestTokens) CreateUniqueValidJWT(id string) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
|
||||
claims := map[string]any{
|
||||
"sub": "user_" + id,
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
"exp": 9999999999,
|
||||
"iat": 1000000000,
|
||||
"jti": id,
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signature := "sig_" + id
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
// CreateIncompressibleToken creates a token that cannot be compressed effectively
|
||||
// This is useful for testing chunking scenarios where compression doesn't help
|
||||
func (tt *TestTokens) CreateIncompressibleToken(targetSize int) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
signature := "incompressible_signature_" + tt.generateRandomString(32)
|
||||
|
||||
// Calculate required payload size
|
||||
usedSize := len(header) + len(signature) + 2 // account for dots
|
||||
payloadSize := max(targetSize-usedSize, 100)
|
||||
|
||||
// Generate multiple random fields to prevent compression
|
||||
randomFields := make(map[string]any)
|
||||
randomFields["sub"] = "user123"
|
||||
randomFields["iss"] = "https://example.com"
|
||||
randomFields["aud"] = "client123"
|
||||
randomFields["exp"] = 9999999999
|
||||
randomFields["iat"] = 1000000000
|
||||
|
||||
// Add many random fields with random data to prevent compression
|
||||
remainingSize := payloadSize - 200 // Account for base64 encoding and other fields
|
||||
fieldCount := max(
|
||||
// ~100 bytes per field
|
||||
remainingSize/100, 1)
|
||||
|
||||
for i := range fieldCount {
|
||||
// Generate truly random data for each field
|
||||
randomBytes := make([]byte, 50)
|
||||
rand.Read(randomBytes)
|
||||
fieldName := fmt.Sprintf("random_field_%d_%s", i, tt.generateRandomString(8))
|
||||
randomFields[fieldName] = base64.StdEncoding.EncodeToString(randomBytes)
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(randomFields)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
token := fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
|
||||
// If still too small, pad with more random data
|
||||
if len(token) < targetSize {
|
||||
padding := targetSize - len(token)
|
||||
extraRandomBytes := make([]byte, padding/2)
|
||||
rand.Read(extraRandomBytes)
|
||||
randomFields["padding"] = base64.StdEncoding.EncodeToString(extraRandomBytes)
|
||||
claimsJSON, _ = json.Marshal(randomFields)
|
||||
payload = base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
token = fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// GetValidTokenSet returns a complete set of valid tokens for testing
|
||||
func (tt *TestTokens) GetValidTokenSet() TokenSet {
|
||||
return TokenSet{
|
||||
AccessToken: ValidAccessToken,
|
||||
IDToken: ValidIDToken,
|
||||
RefreshToken: ValidRefreshToken,
|
||||
}
|
||||
}
|
||||
|
||||
// GetGoogleTokenSet returns tokens that simulate Google OIDC provider responses
|
||||
func (tt *TestTokens) GetGoogleTokenSet() TokenSet {
|
||||
return TokenSet{
|
||||
AccessToken: ValidAccessToken,
|
||||
IDToken: ValidIDToken,
|
||||
RefreshToken: ValidRefreshTokenGoogle,
|
||||
}
|
||||
}
|
||||
|
||||
// GetLargeTokenSet returns a set of large tokens for chunking tests
|
||||
func (tt *TestTokens) GetLargeTokenSet() TokenSet {
|
||||
return TokenSet{
|
||||
AccessToken: tt.CreateLargeValidJWT(5000),
|
||||
IDToken: tt.CreateLargeValidJWT(2000),
|
||||
RefreshToken: tt.CreateLargeRefreshToken(3000),
|
||||
}
|
||||
}
|
||||
|
||||
// GetInvalidTokens returns various invalid tokens for validation testing
|
||||
func (tt *TestTokens) GetInvalidTokens() InvalidTokenSet {
|
||||
return InvalidTokenSet{
|
||||
NoDots: InvalidTokenNoDots,
|
||||
OneDot: InvalidTokenOneDot,
|
||||
ThreeDots: InvalidTokenThreeDots,
|
||||
Empty: EmptyToken,
|
||||
Corrupted: CorruptedBase64Token,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandomString creates a random string of the specified length
|
||||
func (tt *TestTokens) generateRandomString(length int) string {
|
||||
// FIXED: Handle negative or zero lengths safely
|
||||
if length <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
randomByte := make([]byte, 1)
|
||||
rand.Read(randomByte)
|
||||
b[i] = charset[int(randomByte[0])%len(charset)]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// TokenSet represents a complete set of tokens for testing
|
||||
type TokenSet struct {
|
||||
AccessToken string
|
||||
IDToken string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// InvalidTokenSet represents various invalid tokens for validation testing
|
||||
type InvalidTokenSet struct {
|
||||
NoDots string // Token with 0 dots
|
||||
OneDot string // Token with 1 dot
|
||||
ThreeDots string // Token with 3 dots
|
||||
Empty string // Empty token
|
||||
Corrupted string // Corrupted/invalid characters
|
||||
}
|
||||
|
||||
// TestScenarios provides predefined test scenarios
|
||||
type TestScenarios struct {
|
||||
tokens *TestTokens
|
||||
}
|
||||
|
||||
// NewTestScenarios creates a new TestScenarios instance
|
||||
func NewTestScenarios() *TestScenarios {
|
||||
return &TestScenarios{
|
||||
tokens: NewTestTokens(),
|
||||
}
|
||||
}
|
||||
|
||||
// NormalFlow returns tokens for normal authentication flow testing
|
||||
func (ts *TestScenarios) NormalFlow() TokenSet {
|
||||
return ts.tokens.GetValidTokenSet()
|
||||
}
|
||||
|
||||
// GoogleFlow returns tokens simulating Google OIDC provider
|
||||
func (ts *TestScenarios) GoogleFlow() TokenSet {
|
||||
return ts.tokens.GetGoogleTokenSet()
|
||||
}
|
||||
|
||||
// ChunkingRequired returns large tokens that require chunking
|
||||
func (ts *TestScenarios) ChunkingRequired() TokenSet {
|
||||
return ts.tokens.GetLargeTokenSet()
|
||||
}
|
||||
|
||||
// CorruptionTest returns tokens and corruption scenarios for testing
|
||||
func (ts *TestScenarios) CorruptionTest() CorruptionTestSet {
|
||||
return CorruptionTestSet{
|
||||
ValidTokens: ts.tokens.GetValidTokenSet(),
|
||||
InvalidTokens: ts.tokens.GetInvalidTokens(),
|
||||
LargeTokens: ts.tokens.GetLargeTokenSet(),
|
||||
CorruptedToken: CorruptedBase64Token,
|
||||
}
|
||||
}
|
||||
|
||||
// ConcurrentTest returns unique tokens for concurrent testing
|
||||
func (ts *TestScenarios) ConcurrentTest(count int) []TokenSet {
|
||||
sets := make([]TokenSet, count)
|
||||
for i := range count {
|
||||
sets[i] = TokenSet{
|
||||
AccessToken: ts.tokens.CreateUniqueValidJWT(fmt.Sprintf("concurrent_%d", i)),
|
||||
IDToken: ts.tokens.CreateUniqueValidJWT(fmt.Sprintf("id_%d", i)),
|
||||
RefreshToken: fmt.Sprintf("refresh_concurrent_%d", i),
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
// CorruptionTestSet represents tokens and scenarios for corruption testing
|
||||
type CorruptionTestSet struct {
|
||||
ValidTokens TokenSet
|
||||
InvalidTokens InvalidTokenSet
|
||||
LargeTokens TokenSet
|
||||
CorruptedToken string
|
||||
}
|
||||
|
||||
// TokenValidationTestCases returns test cases for token validation
|
||||
func (tt *TestTokens) TokenValidationTestCases() []ValidationTestCase {
|
||||
return []ValidationTestCase{
|
||||
{
|
||||
Name: "Empty token",
|
||||
Token: EmptyToken,
|
||||
ExpectStored: true, // Empty tokens are allowed for clearing
|
||||
ExpectRetrieved: false, // But return as empty
|
||||
},
|
||||
{
|
||||
Name: "Single dot",
|
||||
Token: InvalidTokenOneDot,
|
||||
ExpectStored: false, // Invalid JWT format
|
||||
ExpectRetrieved: false,
|
||||
},
|
||||
{
|
||||
Name: "No dots",
|
||||
Token: InvalidTokenNoDots,
|
||||
ExpectStored: false, // Invalid JWT format
|
||||
ExpectRetrieved: false,
|
||||
},
|
||||
{
|
||||
Name: "Too many dots",
|
||||
Token: InvalidTokenThreeDots,
|
||||
ExpectStored: false, // Invalid JWT format
|
||||
ExpectRetrieved: false,
|
||||
},
|
||||
{
|
||||
Name: "Valid minimal JWT",
|
||||
Token: MinimalValidJWT,
|
||||
ExpectStored: true,
|
||||
ExpectRetrieved: true,
|
||||
},
|
||||
{
|
||||
Name: "Valid standard JWT",
|
||||
Token: ValidAccessToken,
|
||||
ExpectStored: true,
|
||||
ExpectRetrieved: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidationTestCase represents a single token validation test case
|
||||
type ValidationTestCase struct {
|
||||
Name string
|
||||
Token string
|
||||
ExpectStored bool
|
||||
ExpectRetrieved bool
|
||||
}
|
||||
|
||||
// Helper functions for common test patterns
|
||||
|
||||
// AssertValidTokenStorage verifies that a valid token can be stored and retrieved
|
||||
func AssertValidTokenStorage(t TestingInterface, session *SessionData, token string) {
|
||||
session.SetAccessToken(token)
|
||||
retrieved := session.GetAccessToken()
|
||||
if retrieved != token {
|
||||
t.Errorf("Token storage failed: expected %q, got %q", token, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertInvalidTokenRejection verifies that an invalid token is rejected
|
||||
func AssertInvalidTokenRejection(t TestingInterface, session *SessionData, token string) {
|
||||
original := session.GetAccessToken()
|
||||
session.SetAccessToken(token)
|
||||
after := session.GetAccessToken()
|
||||
if after != original {
|
||||
t.Errorf("Invalid token was not rejected: expected %q, got %q", original, after)
|
||||
}
|
||||
}
|
||||
|
||||
// TestingInterface provides the minimal interface needed for testing
|
||||
type TestingInterface interface {
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
@@ -0,0 +1,502 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// TestTokenCorruptionScenario reproduces the exact failure pattern from GitHub issue #53:
|
||||
// Token verified successfully multiple times, then fails with "signature verification failed"
|
||||
func TestTokenCorruptionScenario(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a valid JWT token
|
||||
validJWT := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImV4cCI6OTk5OTk5OTk5OX0.signature"
|
||||
|
||||
tests := []struct {
|
||||
corruptionScenario func(*SessionData)
|
||||
name string
|
||||
tokenSize int
|
||||
iterations int
|
||||
expectConsistent bool
|
||||
}{
|
||||
{
|
||||
name: "Small token - multiple retrievals",
|
||||
tokenSize: len(validJWT),
|
||||
iterations: 10,
|
||||
expectConsistent: true,
|
||||
},
|
||||
{
|
||||
name: "Large chunked token - multiple retrievals",
|
||||
tokenSize: 5000,
|
||||
iterations: 10,
|
||||
expectConsistent: true,
|
||||
},
|
||||
{
|
||||
name: "Compression corruption simulation",
|
||||
tokenSize: 2000,
|
||||
iterations: 5,
|
||||
expectConsistent: false, // Will be corrupted intentionally
|
||||
corruptionScenario: func(session *SessionData) {
|
||||
// Simulate corruption by directly modifying session values
|
||||
if session.accessSession != nil {
|
||||
// Simulate corrupted compressed data
|
||||
session.accessSession.Values["token"] = "corrupted_base64_!@#$"
|
||||
session.accessSession.Values["compressed"] = true
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Chunk reassembly corruption simulation",
|
||||
tokenSize: 25000, // Large enough to force chunking even after compression
|
||||
iterations: 5,
|
||||
expectConsistent: false, // Will be corrupted intentionally
|
||||
corruptionScenario: func(session *SessionData) {
|
||||
// Simulate chunk corruption with invalid base64 characters
|
||||
if len(session.accessTokenChunks) > 0 {
|
||||
if chunk, exists := session.accessTokenChunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = "invalid_base64_!@#$%"
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Create token of specified size
|
||||
token := createTokenOfSize(validJWT, tt.tokenSize)
|
||||
|
||||
// 1. Store the token
|
||||
session.SetAccessToken(token)
|
||||
t.Logf("Stored token of size %d bytes", len(token))
|
||||
|
||||
// 2. Verify token can be retrieved multiple times successfully
|
||||
var retrievedTokens []string
|
||||
for i := 0; i < tt.iterations; i++ {
|
||||
retrieved := session.GetAccessToken()
|
||||
retrievedTokens = append(retrievedTokens, retrieved)
|
||||
|
||||
if tt.expectConsistent && retrieved != token {
|
||||
t.Errorf("Iteration %d: Token mismatch, expected consistency", i)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Apply corruption scenario if specified
|
||||
if tt.corruptionScenario != nil {
|
||||
tt.corruptionScenario(session)
|
||||
}
|
||||
|
||||
// 4. Retrieve token after potential corruption
|
||||
finalRetrieved := session.GetAccessToken()
|
||||
|
||||
if tt.expectConsistent {
|
||||
// With fixes, token should still be retrievable correctly
|
||||
if finalRetrieved != token {
|
||||
t.Errorf("Final retrieval failed - corruption not handled correctly")
|
||||
t.Logf("Expected: %q", token)
|
||||
t.Logf("Got: %q", finalRetrieved)
|
||||
}
|
||||
} else {
|
||||
// For corruption scenarios, expect empty string (graceful failure)
|
||||
if finalRetrieved != "" {
|
||||
t.Errorf("Expected corruption to result in empty token, got: %q", finalRetrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Verify all previous retrievals were consistent (if expected)
|
||||
if tt.expectConsistent {
|
||||
for i, retrieved := range retrievedTokens {
|
||||
if retrieved != token {
|
||||
t.Errorf("Iteration %d produced inconsistent result", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompressionIntegrityFailure tests scenarios where compression fails integrity checks
|
||||
func TestCompressionIntegrityFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectSame bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JWT",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig",
|
||||
expectSame: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT - wrong dots",
|
||||
token: "invalid.token",
|
||||
expectSame: true, // Should return unchanged
|
||||
},
|
||||
{
|
||||
name: "Oversized token",
|
||||
token: "header." + strings.Repeat("A", 60000) + ".sig",
|
||||
expectSame: true, // Should return unchanged due to size limit
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
compressed := compressToken(tt.token)
|
||||
|
||||
if tt.expectSame && compressed != tt.token {
|
||||
// If we expect the token to remain the same but it was compressed,
|
||||
// verify round-trip integrity
|
||||
decompressed := decompressToken(compressed)
|
||||
if decompressed != tt.token {
|
||||
t.Errorf("Compression integrity failed: original=%q, decompressed=%q", tt.token, decompressed)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestChunkReassemblyEdgeCases tests edge cases in chunk reassembly that could cause corruption
|
||||
func TestChunkReassemblyEdgeCases(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Create a large token that will definitely be chunked
|
||||
largeToken := createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig", 8000)
|
||||
|
||||
// Store the token to create chunks
|
||||
session.SetAccessToken(largeToken)
|
||||
|
||||
if len(session.accessTokenChunks) == 0 {
|
||||
t.Skip("Token was not chunked, skipping reassembly tests")
|
||||
}
|
||||
|
||||
t.Logf("Token was split into %d chunks", len(session.accessTokenChunks))
|
||||
|
||||
// Test various corruption scenarios
|
||||
corruptionTests := []struct {
|
||||
corruption func(map[int]*sessions.Session)
|
||||
name string
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "Gap in chunk sequence",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
// Remove chunk 1 if it exists
|
||||
delete(chunks, 1)
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Chunk with nil value",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = nil
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Chunk with wrong type",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = 12345 // Should be string
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Empty chunk data",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = ""
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Excessive chunk count",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
// This test simulates having too many chunks (>50 limit)
|
||||
// We'll create a scenario by adding many fake chunks
|
||||
for i := range 60 {
|
||||
fakeSession := &sessions.Session{Values: make(map[any]any)}
|
||||
fakeSession.Values["token_chunk"] = "fake_chunk_data"
|
||||
chunks[i] = fakeSession
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, ct := range corruptionTests {
|
||||
t.Run(ct.name, func(t *testing.T) {
|
||||
// Get a fresh session for each test
|
||||
freshReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
freshSession, err := sm.GetSession(freshReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get fresh session: %v", err)
|
||||
}
|
||||
defer freshSession.ReturnToPool()
|
||||
|
||||
// Store the large token again
|
||||
freshSession.SetAccessToken(largeToken)
|
||||
|
||||
// Apply corruption
|
||||
ct.corruption(freshSession.accessTokenChunks)
|
||||
|
||||
// Try to retrieve the token
|
||||
retrieved := freshSession.GetAccessToken()
|
||||
|
||||
if ct.expectEmpty {
|
||||
if retrieved != "" {
|
||||
t.Errorf("Expected empty token due to corruption, got: %q", retrieved)
|
||||
}
|
||||
} else {
|
||||
if retrieved != largeToken {
|
||||
t.Errorf("Expected original token, got: %q", retrieved)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRaceConditionProtection tests that concurrent access doesn't cause corruption
|
||||
func TestRaceConditionProtection(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
const numGoroutines = 20
|
||||
const numOperations = 50
|
||||
|
||||
// Create tokens of different sizes
|
||||
tokens := []string{
|
||||
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig1",
|
||||
createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig2", 3000),
|
||||
createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig3", 6000),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := range numOperations {
|
||||
tokenIndex := (goroutineID + j) % len(tokens)
|
||||
expectedToken := tokens[tokenIndex]
|
||||
|
||||
// Set token
|
||||
session.SetAccessToken(expectedToken)
|
||||
|
||||
// Retrieve token
|
||||
retrieved := session.GetAccessToken()
|
||||
|
||||
// Verify it's a valid JWT (should have exactly 2 dots)
|
||||
if retrieved != "" && strings.Count(retrieved, ".") != 2 {
|
||||
errChan <- fmt.Errorf("goroutine %d, op %d: invalid JWT format in retrieved token: %q",
|
||||
goroutineID, j, retrieved)
|
||||
continue
|
||||
}
|
||||
|
||||
// The retrieved token should be one of the valid tokens we set
|
||||
// (due to concurrent access, it might not be the exact one we just set)
|
||||
isValidToken := slices.Contains(tokens, retrieved)
|
||||
|
||||
if retrieved != "" && !isValidToken {
|
||||
errChan <- fmt.Errorf("goroutine %d, op %d: retrieved unknown token: %q",
|
||||
goroutineID, j, retrieved)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// Check for any errors
|
||||
for err := range errChan {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMemoryExhaustionProtection tests protection against memory exhaustion attacks
|
||||
func TestMemoryExhaustionProtection(t *testing.T) {
|
||||
tests := []struct {
|
||||
setupCorruption func() string
|
||||
name string
|
||||
expectRejection bool
|
||||
}{
|
||||
{
|
||||
name: "Extremely large compressed data",
|
||||
setupCorruption: func() string {
|
||||
return base64.StdEncoding.EncodeToString(bytes.Repeat([]byte("A"), 200*1024)) // 200KB
|
||||
},
|
||||
expectRejection: true,
|
||||
},
|
||||
{
|
||||
name: "Malformed gzip bomb attempt",
|
||||
setupCorruption: func() string {
|
||||
// Create data that looks like gzip but would decompress to huge size
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
gz.Write(bytes.Repeat([]byte("A"), 10*1024)) // 10KB that compresses well
|
||||
gz.Close()
|
||||
|
||||
compressed := buf.Bytes()
|
||||
// Modify to make it potentially dangerous
|
||||
return base64.StdEncoding.EncodeToString(compressed)
|
||||
},
|
||||
expectRejection: false, // Our decompression has size limits
|
||||
},
|
||||
{
|
||||
name: "Token with excessive chunk simulation",
|
||||
setupCorruption: func() string {
|
||||
// This will be tested in the session layer
|
||||
return strings.Repeat("chunk.", 100) + "final"
|
||||
},
|
||||
expectRejection: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
corruptedData := tt.setupCorruption()
|
||||
|
||||
result := decompressToken(corruptedData)
|
||||
|
||||
if tt.expectRejection {
|
||||
// Should return original corrupted data, not attempt decompression
|
||||
if result != corruptedData {
|
||||
t.Errorf("Expected rejection of dangerous data, but decompression was attempted")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no excessive memory was used (this test would catch OOM in practice)
|
||||
// The fact that we reach this point means memory limits were effective
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackwardCompatibility ensures that sessions created before the fixes still work
|
||||
func TestBackwardCompatibility(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Simulate old-style session data (without new validation fields)
|
||||
oldStyleToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.oldsig"
|
||||
|
||||
// Manually set token without going through new SetAccessToken validation
|
||||
session.accessSession.Values["token"] = oldStyleToken
|
||||
session.accessSession.Values["compressed"] = false
|
||||
|
||||
// Should still be retrievable
|
||||
retrieved := session.GetAccessToken()
|
||||
if retrieved != oldStyleToken {
|
||||
t.Errorf("Backward compatibility failed: expected %q, got %q", oldStyleToken, retrieved)
|
||||
}
|
||||
|
||||
// Test with simulated old compressed token
|
||||
oldCompressed := compressToken(oldStyleToken)
|
||||
session.accessSession.Values["token"] = oldCompressed
|
||||
session.accessSession.Values["compressed"] = true
|
||||
|
||||
retrieved2 := session.GetAccessToken()
|
||||
if retrieved2 != oldStyleToken {
|
||||
t.Errorf("Backward compatibility with compression failed: expected %q, got %q", oldStyleToken, retrieved2)
|
||||
}
|
||||
}
|
||||
|
||||
// createTokenOfSize creates a JWT token of approximately the specified size
|
||||
func createTokenOfSize(baseToken string, targetSize int) string {
|
||||
parts := strings.Split(baseToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return baseToken
|
||||
}
|
||||
|
||||
header, payload, signature := parts[0], parts[1], parts[2]
|
||||
currentSize := len(baseToken)
|
||||
|
||||
if currentSize >= targetSize {
|
||||
return baseToken
|
||||
}
|
||||
|
||||
// Expand the payload to reach target size
|
||||
paddingNeeded := targetSize - len(header) - len(signature) - 2 // Account for dots
|
||||
if paddingNeeded > 0 {
|
||||
// Decode current payload, add padding, re-encode
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// If we can't decode, just pad with random base64-safe characters to resist compression
|
||||
randomBytes := make([]byte, paddingNeeded)
|
||||
rand.Read(randomBytes)
|
||||
// Encode as base64 to make it base64-safe
|
||||
padData := base64.RawURLEncoding.EncodeToString(randomBytes)
|
||||
payload = payload + padData
|
||||
} else {
|
||||
// Add padding to the JSON - use random data to resist compression
|
||||
randomBytes := make([]byte, paddingNeeded/2)
|
||||
rand.Read(randomBytes)
|
||||
// Encode as base64 to make it JSON-safe
|
||||
padData := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
newPayload := fmt.Sprintf(`{"original":%s,"padding":"%s"}`, string(decoded), padData)
|
||||
payload = base64.RawURLEncoding.EncodeToString([]byte(newPayload))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
+368
-8
@@ -2,8 +2,12 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
"time"
|
||||
@@ -15,17 +19,17 @@ import (
|
||||
func TestTokenTypeDistinction(t *testing.T) {
|
||||
// Define test data where AccessToken and IdToken are deliberately different
|
||||
type templateData struct {
|
||||
Claims map[string]any
|
||||
AccessToken string
|
||||
IdToken string
|
||||
RefreshToken string
|
||||
Claims map[string]interface{}
|
||||
}
|
||||
|
||||
testData := templateData{
|
||||
AccessToken: "test-access-token-abc123",
|
||||
IdToken: "test-id-token-xyz789",
|
||||
RefreshToken: "test-refresh-token",
|
||||
Claims: map[string]interface{}{
|
||||
Claims: map[string]any{
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
@@ -87,7 +91,7 @@ func TestTokenTypeIntegration(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create different tokens for ID and access tokens
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
@@ -103,7 +107,7 @@ func TestTokenTypeIntegration(t *testing.T) {
|
||||
t.Fatalf("Failed to create test ID JWT: %v", err)
|
||||
}
|
||||
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
@@ -257,10 +261,10 @@ func TestSessionIDTokenAccessToken(t *testing.T) {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set test tokens
|
||||
idToken := "test-id-token-123"
|
||||
accessToken := "test-access-token-456"
|
||||
refreshToken := "test-refresh-token-789"
|
||||
// Set test tokens using standardized tokens
|
||||
idToken := ValidIDToken
|
||||
accessToken := ValidAccessToken
|
||||
refreshToken := ValidRefreshToken
|
||||
|
||||
// Store tokens in session
|
||||
session.SetIDToken(idToken)
|
||||
@@ -309,3 +313,359 @@ func TestSessionIDTokenAccessToken(t *testing.T) {
|
||||
t.Errorf("ID token and Access token should be different, but both are %q", retrievedIDToken)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenCorruptionIntegrationFlows tests the complete token handling flow with corruption scenarios
|
||||
func TestTokenCorruptionIntegrationFlows(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
corruptAction func(*SessionData)
|
||||
name string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
idToken string
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "Normal flow - small tokens",
|
||||
accessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.access_sig",
|
||||
refreshToken: "refresh_token_12345",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
|
||||
expectSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "Normal flow - large tokens (chunked)",
|
||||
accessToken: createLargeValidJWT(5000),
|
||||
refreshToken: createLargeRefreshToken(3000),
|
||||
idToken: createLargeValidJWT(2000),
|
||||
expectSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "Corrupted access token compression",
|
||||
accessToken: createLargeValidJWT(3000),
|
||||
refreshToken: "refresh_token_12345",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
|
||||
expectSuccess: false,
|
||||
corruptAction: func(session *SessionData) {
|
||||
// Corrupt compressed access token
|
||||
if session.accessSession != nil {
|
||||
session.accessSession.Values["token"] = "corrupted_compressed_data_!@#"
|
||||
session.accessSession.Values["compressed"] = true
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Corrupted chunk in large token",
|
||||
accessToken: createLargeValidJWT(8000), // Force chunking
|
||||
refreshToken: "refresh_token_12345",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
|
||||
expectSuccess: false,
|
||||
corruptAction: func(session *SessionData) {
|
||||
// Corrupt first chunk
|
||||
if len(session.accessTokenChunks) > 0 {
|
||||
if chunk, exists := session.accessTokenChunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = "corrupted_chunk_data"
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get session
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Store tokens
|
||||
session.SetAccessToken(tt.accessToken)
|
||||
session.SetRefreshToken(tt.refreshToken)
|
||||
session.SetIDToken(tt.idToken)
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Apply corruption if specified
|
||||
if tt.corruptAction != nil {
|
||||
tt.corruptAction(session)
|
||||
}
|
||||
|
||||
// Test token retrieval after corruption
|
||||
retrievedAccess := session.GetAccessToken()
|
||||
retrievedRefresh := session.GetRefreshToken()
|
||||
retrievedID := session.GetIDToken()
|
||||
|
||||
if tt.expectSuccess {
|
||||
if retrievedAccess != tt.accessToken {
|
||||
t.Errorf("Access token corruption: expected %q, got %q", tt.accessToken, retrievedAccess)
|
||||
}
|
||||
if retrievedRefresh != tt.refreshToken {
|
||||
t.Errorf("Refresh token corruption: expected %q, got %q", tt.refreshToken, retrievedRefresh)
|
||||
}
|
||||
if retrievedID != tt.idToken {
|
||||
t.Errorf("ID token corruption: expected %q, got %q", tt.idToken, retrievedID)
|
||||
}
|
||||
} else {
|
||||
// For corruption scenarios, access token should be empty (graceful failure)
|
||||
if retrievedAccess != "" {
|
||||
t.Errorf("Expected corrupted access token to return empty, got: %q", retrievedAccess)
|
||||
}
|
||||
// Other tokens should still work
|
||||
if retrievedRefresh != tt.refreshToken {
|
||||
t.Errorf("Refresh token should not be affected by access token corruption: expected %q, got %q",
|
||||
tt.refreshToken, retrievedRefresh)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionPersistenceWithCorruption tests that session corruption is handled across requests
|
||||
func TestSessionPersistenceWithCorruption(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// First request - store tokens
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
rr1 := httptest.NewRecorder()
|
||||
|
||||
session1, err := sm.GetSession(req1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
largeToken := createLargeValidJWT(6000)
|
||||
session1.SetAccessToken(largeToken)
|
||||
session1.SetAuthenticated(true)
|
||||
|
||||
if err := session1.Save(req1, rr1); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies from first response
|
||||
cookies := rr1.Result().Cookies()
|
||||
session1.ReturnToPool()
|
||||
|
||||
// Second request - retrieve tokens with cookies
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
session2, err := sm.GetSession(req2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session from cookies: %v", err)
|
||||
}
|
||||
defer session2.ReturnToPool()
|
||||
|
||||
// Verify token can be retrieved
|
||||
retrieved := session2.GetAccessToken()
|
||||
if retrieved != largeToken {
|
||||
t.Errorf("Token persistence failed: expected %q, got %q", largeToken, retrieved)
|
||||
}
|
||||
|
||||
// Simulate corruption by modifying chunks
|
||||
if len(session2.accessTokenChunks) > 0 {
|
||||
// Corrupt a middle chunk
|
||||
chunkIndex := len(session2.accessTokenChunks) / 2
|
||||
if chunk, exists := session2.accessTokenChunks[chunkIndex]; exists {
|
||||
chunk.Values["token_chunk"] = "corrupted"
|
||||
}
|
||||
|
||||
// Try to retrieve again - should detect corruption and return empty
|
||||
retrievedAfterCorruption := session2.GetAccessToken()
|
||||
if retrievedAfterCorruption != "" {
|
||||
t.Errorf("Expected corruption to be detected, but got token: %q", retrievedAfterCorruption)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentTokenOperationsWithCorruption tests concurrent access with intentional corruption
|
||||
func TestConcurrentTokenOperationsWithCorruption(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
const numGoroutines = 10
|
||||
const numOperations = 20
|
||||
|
||||
done := make(chan bool, numGoroutines)
|
||||
errorChan := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
// Start concurrent operations
|
||||
for i := range numGoroutines {
|
||||
go func(goroutineID int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := range numOperations {
|
||||
// Create a unique valid token for each operation
|
||||
token := fmt.Sprintf("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwib3AiOiIxMjMifQ.sig_%d_%d",
|
||||
goroutineID, j)
|
||||
|
||||
// Store token
|
||||
session.SetAccessToken(token)
|
||||
|
||||
// Retrieve token
|
||||
retrieved := session.GetAccessToken()
|
||||
|
||||
// Validate retrieved token format
|
||||
if retrieved != "" {
|
||||
if strings.Count(retrieved, ".") != 2 {
|
||||
errorChan <- fmt.Errorf("goroutine %d, op %d: invalid JWT format: %q",
|
||||
goroutineID, j, retrieved)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's a reasonable length
|
||||
if len(retrieved) < 10 || len(retrieved) > 100000 {
|
||||
errorChan <- fmt.Errorf("goroutine %d, op %d: suspicious token length %d: %q",
|
||||
goroutineID, j, len(retrieved), retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// Occasionally simulate corruption to test error handling
|
||||
if j%5 == 0 && len(session.accessTokenChunks) > 0 {
|
||||
// Intentionally corrupt a random chunk
|
||||
for chunkID, chunk := range session.accessTokenChunks {
|
||||
if chunkID%2 == 0 {
|
||||
chunk.Values["token_chunk"] = "intentionally_corrupted"
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for range numGoroutines {
|
||||
<-done
|
||||
}
|
||||
close(errorChan)
|
||||
|
||||
// Check for any unexpected errors
|
||||
errorCount := 0
|
||||
for err := range errorChan {
|
||||
t.Logf("Concurrent operation error: %v", err)
|
||||
errorCount++
|
||||
}
|
||||
|
||||
// We expect some corruption-related "errors" due to intentional corruption,
|
||||
// but not format-related errors which would indicate actual corruption bugs
|
||||
if errorCount > numGoroutines*numOperations/4 { // Allow up to 25% corruption-related issues
|
||||
t.Errorf("Too many errors during concurrent operations: %d", errorCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenValidationEdgeCases tests edge cases in token validation
|
||||
func TestTokenValidationEdgeCases(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Use standardized test tokens
|
||||
testTokens := NewTestTokens()
|
||||
edgeCases := testTokens.TokenValidationTestCases()
|
||||
|
||||
for _, ec := range edgeCases {
|
||||
t.Run(ec.Name, func(t *testing.T) {
|
||||
// Clear any previous token
|
||||
session.SetAccessToken("")
|
||||
|
||||
// Store the test token
|
||||
originalToken := session.GetAccessToken()
|
||||
session.SetAccessToken(ec.Token)
|
||||
afterStoreToken := session.GetAccessToken()
|
||||
|
||||
if ec.ExpectStored {
|
||||
if afterStoreToken != ec.Token {
|
||||
t.Errorf("Expected token to be stored, but got different value")
|
||||
}
|
||||
} else {
|
||||
if afterStoreToken != originalToken {
|
||||
t.Errorf("Expected invalid token to be rejected, but it was stored")
|
||||
}
|
||||
}
|
||||
|
||||
// Test retrieval
|
||||
finalToken := session.GetAccessToken()
|
||||
if ec.ExpectRetrieved {
|
||||
if finalToken != ec.Token {
|
||||
t.Errorf("Expected token to be retrievable: %q, got: %q", ec.Token, finalToken)
|
||||
}
|
||||
} else {
|
||||
if finalToken != "" {
|
||||
t.Errorf("Expected empty token due to invalid format, got: %q", finalToken)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for test data creation
|
||||
|
||||
// createLargeValidJWT creates a JWT of approximately the specified size
|
||||
func createLargeValidJWT(targetSize int) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
signature := "signature_" + generateRandomString(32)
|
||||
|
||||
// Calculate required payload size
|
||||
usedSize := len(header) + len(signature) + 2 // account for dots
|
||||
payloadSize := max(targetSize-usedSize, 50)
|
||||
|
||||
// Create a payload with realistic JWT claims
|
||||
claims := map[string]any{
|
||||
"sub": "user123",
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
"exp": 9999999999,
|
||||
"iat": 1000000000,
|
||||
"data": generateRandomString(payloadSize - 100), // Account for other claims
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
// createLargeRefreshToken creates a refresh token of approximately the specified size
|
||||
func createLargeRefreshToken(targetSize int) string {
|
||||
baseToken := "refresh_token_"
|
||||
padding := generateRandomString(targetSize - len(baseToken))
|
||||
return baseToken + padding
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user