mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 46745f5b54 | |||
| a54ae71279 | |||
| ae2a2877e9 | |||
| c2a81bc2df | |||
| dbe3455f49 | |||
| 0dfc252c95 | |||
| 71574090bf | |||
| de91edb514 | |||
| 667b4213fe |
@@ -0,0 +1,2 @@
|
||||
docker/
|
||||
.claude/
|
||||
@@ -76,6 +76,99 @@ testData:
|
||||
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
||||
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
|
||||
|
||||
# --- Provider Specific Configuration Examples ---
|
||||
#
|
||||
# Below are example configurations tailored for specific OIDC providers.
|
||||
# Uncomment and adapt the relevant section for your provider.
|
||||
# Remember to replace placeholder values (like client IDs, secrets, domains)
|
||||
# with your actual credentials and settings.
|
||||
#
|
||||
# For all providers, ensure claims like email, roles, and groups are
|
||||
# configured to be included in the ID TOKEN. This plugin validates ID tokens.
|
||||
|
||||
# --- Keycloak Example ---
|
||||
# testDataKeycloak:
|
||||
# providerURL: https://your-keycloak-domain/realms/your-realm # e.g., http://localhost:8080/realms/master
|
||||
# clientID: your-keycloak-client-id
|
||||
# clientSecret: your-keycloak-client-secret # Store securely, e.g., urn:k8s:secret:namespace:secret-name:key
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-keycloak"
|
||||
# scopes: # Default ["openid", "profile", "email"] are usually sufficient. Add others if mappers depend on them.
|
||||
# - roles # Example: if you mapped Keycloak roles to a 'roles' claim in the ID token
|
||||
# - groups # Example: if you mapped Keycloak groups to a 'groups' claim in the ID token
|
||||
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
|
||||
# - admin
|
||||
# - editor
|
||||
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
|
||||
# # See README.md "Provider Configuration Recommendations" for Keycloak.
|
||||
|
||||
# --- Azure AD (Microsoft Entra ID) Example ---
|
||||
# testDataAzureAD:
|
||||
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0 # Replace your-tenant-id
|
||||
# clientID: your-azure-ad-client-id
|
||||
# clientSecret: your-azure-ad-client-secret # Store securely
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
|
||||
# scopes: # Defaults ["openid", "profile", "email"] are good.
|
||||
# # Azure AD may require specific scopes for certain graph API permissions if you were to use the access token,
|
||||
# # but for ID token claims, defaults are often enough.
|
||||
# # Group claims need to be configured in Azure AD App Registration -> Token Configuration -> Add groups claim.
|
||||
# allowedUserDomains:
|
||||
# - yourcompany.com
|
||||
# allowedRolesAndGroups: # If you configured group claims (typically 'groups') or app roles in Azure AD
|
||||
# - "group-object-id-1" # Azure AD group claims can be Object IDs by default
|
||||
# - "AppRoleName"
|
||||
# # See README.md "Provider Configuration Recommendations" for Azure AD.
|
||||
|
||||
# --- Google Workspace / Google Cloud Identity Example ---
|
||||
# testDataGoogle:
|
||||
# providerURL: https://accounts.google.com # This is standard for Google
|
||||
# clientID: your-google-client-id.apps.googleusercontent.com
|
||||
# clientSecret: your-google-client-secret # Store securely
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-google"
|
||||
# scopes: # Defaults ["openid", "profile", "email"] are handled. Plugin manages Google-specifics.
|
||||
# # Do NOT add 'offline_access' - plugin handles this.
|
||||
# allowedUserDomains: # Useful for Google Workspace users
|
||||
# - your-gsuite-domain.com
|
||||
# # Google includes 'hd' (hosted domain) claim which can be used with allowedUserDomains.
|
||||
# # Other claims like 'email', 'sub', 'name' are standard.
|
||||
# # See README.md "Provider Configuration Recommendations" for Google.
|
||||
|
||||
# --- Auth0 Example ---
|
||||
# testDataAuth0:
|
||||
# providerURL: https://your-auth0-domain.auth0.com # Replace with your Auth0 domain
|
||||
# clientID: your-auth0-client-id
|
||||
# clientSecret: your-auth0-client-secret # Store securely
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-auth0"
|
||||
# scopes: # Defaults ["openid", "profile", "email"]. Add custom scopes if your Auth0 Rules/Actions require them.
|
||||
# - read:custom_data # Example custom scope
|
||||
# allowedRolesAndGroups: # Based on claims added via Auth0 Rules or Actions (e.g. namespaced claims)
|
||||
# - "https://your-app.com/roles:admin"
|
||||
# - editor
|
||||
# # Use Auth0 Rules or Actions to add custom claims (roles, permissions) to the ID Token.
|
||||
# # Ensure postLogoutRedirectURI is in Auth0 app's "Allowed Logout URLs".
|
||||
# # See README.md "Provider Configuration Recommendations" for Auth0.
|
||||
|
||||
# --- Generic OIDC Provider Example ---
|
||||
# testDataGenericOIDC:
|
||||
# providerURL: https://your-generic-oidc-provider.com/oidc # Issuer URL for your provider
|
||||
# clientID: your-generic-client-id
|
||||
# clientSecret: your-generic-client-secret # Store securely
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-generic"
|
||||
# scopes: # Must include "openid". "profile" and "email" are common.
|
||||
# - openid
|
||||
# - profile
|
||||
# - email
|
||||
# - custom_scope_for_claims # If your provider needs specific scopes for ID token claims
|
||||
# allowedRolesAndGroups:
|
||||
# - user_role_from_id_token
|
||||
# # Consult your provider's documentation on how to map attributes/roles/groups to ID Token claims.
|
||||
# # Verify ID Token contents (e.g. jwt.io) to see available claims.
|
||||
# # See README.md "Provider Configuration Recommendations" for Generic OIDC.
|
||||
|
||||
# Configuration documentation
|
||||
configuration:
|
||||
providerURL:
|
||||
|
||||
@@ -13,6 +13,8 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit
|
||||
- Rate limiting
|
||||
- Excluded paths (public URLs)
|
||||
|
||||
**Important Note on Token Validation:** This middleware performs authentication and claim extraction based on the **ID Token** provided by the OIDC provider. It does not primarily use the Access Token for these purposes (though the Access Token is available for templated headers if needed). Therefore, ensure that all necessary claims (e.g., email, roles, custom attributes) are included in the ID Token by your OIDC provider's configuration.
|
||||
|
||||
The middleware has been tested with Auth0, Logto, Google and other standard OIDC providers. It includes special handling for Google's OAuth implementation.
|
||||
|
||||
## Traefik Version Compatibility
|
||||
@@ -702,6 +704,89 @@ The middleware also sets the following security headers:
|
||||
- `X-XSS-Protection: 1; mode=block`
|
||||
- `Referrer-Policy: strict-origin-when-cross-origin`
|
||||
|
||||
## Provider Configuration Recommendations
|
||||
|
||||
**Important: ID Token Validation**
|
||||
|
||||
This Traefik OIDC plugin performs authentication and extracts user claims (like email, roles, groups) exclusively from the **ID Token** provided by your OIDC provider. It does not primarily use the Access Token for these critical functions. Therefore, it is crucial to ensure that all necessary claims are included in the ID Token itself. A common issue is that some OIDC providers might, by default, place certain claims only in the Access Token or UserInfo endpoint.
|
||||
|
||||
This section provides guidance on configuring popular OIDC providers to work optimally with this plugin.
|
||||
|
||||
### Keycloak
|
||||
|
||||
Keycloak is highly configurable, which means you need to ensure your client mappers are set up correctly to include necessary claims in the ID Token.
|
||||
|
||||
* **Ensure Claims in ID Token**:
|
||||
* **Email**: Navigate to your Keycloak realm -> Clients -> Your Client ID -> Mappers. Ensure there's a mapper for 'email' (e.g., a "User Property" mapper for the `email` property) and that "Add to ID token" is **ON**.
|
||||
* **Roles**: For client roles or realm roles, create or edit mappers (e.g., "User Client Role" or "User Realm Role"). Ensure "Add to ID token" is **ON**. You might want to customize the "Token Claim Name" (e.g., to `roles` or `groups`).
|
||||
* **Groups**: Similarly, for group membership, use a "Group Membership" mapper and ensure "Add to ID token" is **ON**. Customize the "Token Claim Name" as needed (e.g., `groups`).
|
||||
* **Scopes**: Ensure your client requests appropriate scopes that trigger the inclusion of these claims if your mappers are scope-dependent. The default `openid`, `profile`, `email` scopes are a good starting point.
|
||||
* **Troubleshooting**: If claims are missing, double-check the "Mappers" tab for your client in Keycloak. The "Token Claim Name" you define here is what you'll use in the `allowedRolesAndGroups` or `headers` configuration in this plugin. (See also the [Troubleshooting](#troubleshooting) section for Keycloak).
|
||||
|
||||
### Azure AD (Microsoft Entra ID)
|
||||
|
||||
Azure AD generally works well with standard OIDC configurations.
|
||||
|
||||
* **ID Token Claims**: Azure AD typically includes standard claims like `email`, `name`, `preferred_username`, and `oid` (Object ID) in the ID Token by default when `openid profile email` scopes are requested.
|
||||
* **Group Claims**: To include group claims in the ID Token, you need to configure this in the Azure AD application registration:
|
||||
* Go to your App Registration -> Token configuration -> Add groups claim.
|
||||
* You can choose which types of groups (Security groups, Directory roles, All groups) to include.
|
||||
* Be aware of the "overage" issue: If a user is a member of too many groups, Azure AD will send a link to fetch groups instead of embedding them. This plugin currently expects group claims to be directly in the ID token. For users with many groups, consider alternative role/permission management strategies.
|
||||
* The claim name for groups is typically `groups`.
|
||||
* **Optional Claims**: You can add other optional claims via the "Token configuration" section of your App Registration. Ensure these are configured for the ID token.
|
||||
* **Endpoints**: The `providerURL` should be `https://login.microsoftonline.com/{your-tenant-id}/v2.0`. The plugin will auto-discover the necessary endpoints.
|
||||
* **Optimization**: Ensure your application manifest in Azure AD is configured for the desired token version (v1.0 or v2.0). This plugin works with v2.0 endpoints.
|
||||
|
||||
### Google Workspace / Google Cloud Identity
|
||||
|
||||
Google's OIDC implementation is well-supported.
|
||||
|
||||
* **Optimal Configuration**: The plugin automatically handles Google-specific requirements, such as using `access_type=offline` and `prompt=consent` to ensure refresh tokens are issued for long-lived sessions. You do not need to add `offline_access` to scopes.
|
||||
* **ID Token Claims**: Google includes standard claims like `email`, `sub`, `name`, `given_name`, `family_name`, `picture` in the ID Token by default with `openid profile email` scopes.
|
||||
* **Hosted Domain (hd claim)**: If you are using Google Workspace and want to restrict access to users within your organization's domain, Google includes an `hd` (hosted domain) claim in the ID Token. You can use this with the `allowedUserDomains` setting or for custom header logic.
|
||||
* **Best Practices**:
|
||||
* Use the `providerURL`: `https://accounts.google.com`.
|
||||
* Ensure your OAuth consent screen in Google Cloud Console is configured correctly and published. For production, it should be "External" and in "Production" status. "Testing" status limits refresh token lifetime.
|
||||
* Refer to the [Google OAuth Compatibility Fix](#google-oauth-compatibility-fix) section for more details on how the plugin handles Google's specifics.
|
||||
|
||||
### Auth0
|
||||
|
||||
Auth0 is generally OIDC compliant and works well.
|
||||
|
||||
* **ID Token Claims**:
|
||||
* To add custom claims or standard claims not included by default (like roles or permissions) to the ID Token, you'll need to use Auth0 Rules or Actions.
|
||||
* **Using Actions (Recommended)**: Create a custom Action that runs after login to add claims to the ID Token. Example:
|
||||
```javascript
|
||||
// Auth0 Action to add email and roles to ID Token
|
||||
exports.onExecutePostLogin = async (event, api) => {
|
||||
const namespace = 'https://your-app.com/'; // Or your custom namespace
|
||||
if (event.authorization) {
|
||||
api.idToken.setCustomClaim(namespace + 'roles', event.authorization.roles);
|
||||
api.idToken.setCustomClaim('email', event.user.email); // Standard claim, ensure it's there
|
||||
// Add other claims as needed
|
||||
}
|
||||
};
|
||||
```
|
||||
* Ensure the claims you add (e.g., `https://your-app.com/roles`) are then used in the plugin's `allowedRolesAndGroups` or `headers` configuration.
|
||||
* **Scopes**: Request appropriate scopes. You might need custom scopes if your Actions/Rules depend on them to add specific claims.
|
||||
* **Endpoints**: Your `providerURL` will be `https://your-auth0-domain.auth0.com`.
|
||||
* **Logout**: Ensure `postLogoutRedirectURI` is registered in your Auth0 application settings under "Allowed Logout URLs".
|
||||
|
||||
### Generic OIDC Providers
|
||||
|
||||
For other OIDC providers (e.g., Okta, Zitadel, self-hosted solutions):
|
||||
|
||||
* **ID Token is Key**: The primary requirement is that all claims needed for authentication decisions (email, roles, groups, custom attributes for headers) **must** be included in the ID Token.
|
||||
* **Check Provider Documentation**: Consult your OIDC provider's documentation on how to:
|
||||
* Configure client applications.
|
||||
* Map user attributes, roles, or group memberships to claims in the ID Token.
|
||||
* Define custom scopes if they are necessary to include certain claims.
|
||||
* **Standard Endpoints**: Ensure your provider exposes a standard OIDC discovery document (`.well-known/openid-configuration`) at the `providerURL`. The plugin uses this to find authorization, token, JWKS, and end_session endpoints.
|
||||
* **Scopes**: Always include `openid` in your scopes. `profile` and `email` are generally recommended. Add other scopes as required by your provider to release specific claims to the ID Token.
|
||||
* **Troubleshooting**: If the plugin isn't working as expected (e.g., access denied, claims missing), the first step is to decode the ID Token received from your provider (e.g., using jwt.io) to verify its contents. This will show you exactly what claims the plugin is seeing.
|
||||
|
||||
For common issues and general troubleshooting, please refer to the [Troubleshooting](#troubleshooting) section.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Logging
|
||||
@@ -726,6 +811,12 @@ logLevel: debug
|
||||
- Verify you're using a version of the middleware that includes the Google OAuth compatibility fix.
|
||||
- For more details, see the [Google OAuth Compatibility Fix](#google-oauth-compatibility-fix) section or the [detailed documentation](docs/google-oauth-fix.md).
|
||||
|
||||
7. **Keycloak: Claims Missing from ID Token (e.g., email, roles)**
|
||||
|
||||
If you are using Keycloak and claims like `email`, `roles`, or `groups` are missing from the ID Token, this plugin may not function as expected (e.g., for domain restrictions or RBAC).
|
||||
* **Solution**: This plugin validates the **ID Token**. You **must** configure Keycloak client mappers to add all necessary claims (email, roles, groups, etc.) to the ID Token.
|
||||
* For detailed instructions, please see the [Keycloak](#keycloak) section under [Provider Configuration Recommendations](#provider-configuration-recommendations).
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
### TODO / wishlist
|
||||
|
||||
- [] Improve test coverage
|
||||
- [x] Improve caching mechanism
|
||||
- [x] Add automatic release and semver generation
|
||||
+8
-2
@@ -37,7 +37,10 @@ func (bt *BackgroundTask) run() {
|
||||
ticker := time.NewTicker(bt.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
bt.logger.Debug("Starting background task: %s", bt.name)
|
||||
// Only log startup if debug level is enabled
|
||||
if bt.logger != nil {
|
||||
bt.logger.Info("Starting background task: %s", bt.name)
|
||||
}
|
||||
|
||||
// Run task immediately on startup
|
||||
bt.taskFunc()
|
||||
@@ -47,7 +50,10 @@ func (bt *BackgroundTask) run() {
|
||||
case <-ticker.C:
|
||||
bt.taskFunc()
|
||||
case <-bt.stopChan:
|
||||
bt.logger.Debug("Stopping background task: %s", bt.name)
|
||||
// Only log shutdown
|
||||
if bt.logger != nil {
|
||||
bt.logger.Info("Stopping background task: %s", bt.name)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
+26
-42
@@ -17,7 +17,7 @@ type mockTraefikOidc struct {
|
||||
// Override VerifyToken to avoid JWKS lookup in tests
|
||||
func (m *mockTraefikOidc) VerifyToken(token string) error {
|
||||
// Cache test claims to avoid "claims not found" errors
|
||||
testClaims := map[string]any{
|
||||
testClaims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
@@ -30,7 +30,7 @@ func (m *mockTraefikOidc) VerifyToken(token string) error {
|
||||
// Override VerifyJWTSignatureAndClaims to avoid JWKS lookup in tests
|
||||
func (m *mockTraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
// Cache test claims to avoid "claims not found" errors
|
||||
testClaims := map[string]any{
|
||||
testClaims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
@@ -80,7 +80,7 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
// For test tokens, always return success and cache claims
|
||||
if strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
|
||||
// Cache test claims for JWT tokens
|
||||
testClaims := map[string]any{
|
||||
testClaims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
@@ -94,7 +94,7 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
// For JWT tokens, cache basic claims to avoid cache lookup issues
|
||||
testClaims := map[string]any{
|
||||
testClaims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
@@ -109,7 +109,7 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
tOidc.jwtVerifier = &mockJWTVerifier{
|
||||
verifyFunc: func(jwt *JWT, token string) error {
|
||||
// Also cache claims here to ensure they're available
|
||||
testClaims := map[string]any{
|
||||
testClaims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
@@ -162,20 +162,12 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
|
||||
// Create a valid JWT access token for testing
|
||||
accessTokenClaims := map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
accessToken, _ := createMockJWT(accessTokenClaims)
|
||||
// Use standardized test tokens with valid future expiration dates
|
||||
accessToken := ValidAccessToken // This token expires in 2065
|
||||
session.SetAccessToken(accessToken)
|
||||
|
||||
// Create an invalid/expired ID token
|
||||
idTokenClaims := map[string]any{
|
||||
// Create an expired ID token using a mock JWT with past expiration
|
||||
idTokenClaims := map[string]interface{}{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired
|
||||
@@ -192,7 +184,7 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
verifyFunc: func(token string) error {
|
||||
if token == accessToken {
|
||||
// Access token validation succeeds - cache claims
|
||||
testClaims := map[string]any{
|
||||
testClaims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
@@ -234,30 +226,21 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
session, _ := tOidc.sessionManager.GetSession(req)
|
||||
|
||||
// Set up session with opaque access token (non-JWT)
|
||||
// Set up session with JWT access token (not opaque for this test)
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(ValidAccessToken)
|
||||
session.SetAccessToken(ValidAccessToken) // This is actually a JWT token
|
||||
|
||||
// Create a valid ID token for claims extraction
|
||||
idTokenClaims := map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
idToken, _ := createMockJWT(idTokenClaims)
|
||||
session.SetIDToken(idToken)
|
||||
// Use a valid ID token from test tokens
|
||||
session.SetIDToken(ValidIDToken) // This token expires in 2065
|
||||
|
||||
// Mock the token verification
|
||||
originalTokenVerifier := tOidc.tokenVerifier
|
||||
tOidc.tokenVerifier = &mockTokenVerifier{
|
||||
verifyFunc: func(token string) error {
|
||||
if token == idToken {
|
||||
if token == ValidIDToken {
|
||||
// ID token is valid - cache claims
|
||||
testClaims := map[string]any{
|
||||
testClaims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
@@ -336,17 +319,18 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
}
|
||||
|
||||
// createMockJWT creates a basic JWT token for testing purposes
|
||||
func createMockJWT(claims map[string]any) (string, error) {
|
||||
// Simple mock JWT - in real tests you'd use a proper JWT library
|
||||
// For this test, we'll create a basic three-part token structure
|
||||
header := "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0" // {"alg":"RS256","kid":"test-key-id","typ":"JWT"}
|
||||
func createMockJWT(claims map[string]interface{}) (string, error) {
|
||||
// For testing purposes, create a JWT with expired claims when needed
|
||||
// Use the test tokens infrastructure for most cases, but allow expired tokens for specific tests
|
||||
testTokens := NewTestTokens()
|
||||
|
||||
// Create a simple payload with test claims
|
||||
payload := "eyJpc3MiOiJ0ZXN0LWlzc3VlciIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjoxNjM4MzYwMDAwLCJpYXQiOjE2MzgzNTY0MDAsInN1YiI6InVzZXIxMjMiLCJlbWFpbCI6InVzZXJAZXhhbXBsZS5jb20ifQ" // Basic claims
|
||||
// Check if this is meant to be an expired token
|
||||
if exp, ok := claims["exp"].(int64); ok && exp < time.Now().Unix() {
|
||||
return testTokens.CreateExpiredJWT(), nil
|
||||
}
|
||||
|
||||
signature := "test-signature"
|
||||
|
||||
return header + "." + payload + "." + signature, nil
|
||||
// Otherwise return a valid token
|
||||
return ValidIDToken, nil
|
||||
}
|
||||
|
||||
// Mock error type for testing
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
// CacheItem represents an item stored in the cache with its associated metadata.
|
||||
type CacheItem struct {
|
||||
// Value is the cached data of any type.
|
||||
Value any
|
||||
Value interface{}
|
||||
|
||||
// ExpiresAt is the timestamp when this item should be considered expired.
|
||||
ExpiresAt time.Time
|
||||
@@ -66,7 +66,7 @@ func NewCacheWithLogger(logger *Logger) *Cache {
|
||||
// If the key does not exist and the cache is full, the least recently used item is evicted
|
||||
// before adding the new item.
|
||||
// The expiration duration is relative to the time Set is called.
|
||||
func (c *Cache) Set(key string, value any, expiration time.Duration) {
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@@ -104,7 +104,7 @@ func (c *Cache) Set(key string, value any, expiration time.Duration) {
|
||||
// Accessing an item moves it to the most recently used position in the LRU list.
|
||||
// If the item does not exist or has expired, nil and false are returned, and the
|
||||
// expired item is removed from the cache.
|
||||
func (c *Cache) Get(key string) (any, bool) {
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MaxKeyLength defines the maximum allowed length for cache keys
|
||||
const MaxKeyLength = 256
|
||||
|
||||
// OptimizedCacheEntry represents a single cache entry with embedded LRU linked list
|
||||
// This eliminates the need for separate data structures and reduces memory overhead by ~66%
|
||||
type OptimizedCacheEntry struct {
|
||||
Value interface{}
|
||||
ExpiresAt time.Time
|
||||
Key string
|
||||
|
||||
// Embedded doubly-linked list pointers for LRU ordering
|
||||
prev, next *OptimizedCacheEntry
|
||||
}
|
||||
|
||||
// OptimizedCache provides a memory-efficient thread-safe cache with LRU eviction
|
||||
// Uses only a single map with embedded doubly-linked list to reduce memory overhead
|
||||
type OptimizedCache struct {
|
||||
items map[string]*OptimizedCacheEntry
|
||||
head, tail *OptimizedCacheEntry // LRU sentinel nodes
|
||||
cleanupTask *BackgroundTask
|
||||
logger *Logger
|
||||
maxSize int
|
||||
maxMemoryBytes int64 // Memory budget limit
|
||||
currentMemoryBytes int64 // Current estimated memory usage
|
||||
autoCleanupInterval time.Duration
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewOptimizedCache creates a new memory-efficient cache with default settings
|
||||
func NewOptimizedCache() *OptimizedCache {
|
||||
return NewOptimizedCacheWithConfig(DefaultMaxSize, 0, nil)
|
||||
}
|
||||
|
||||
// NewOptimizedCacheWithConfig creates a cache with specified configuration
|
||||
func NewOptimizedCacheWithConfig(maxSize int, maxMemoryMB int, logger *Logger) *OptimizedCache {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
// Create sentinel nodes for the doubly-linked list
|
||||
head := &OptimizedCacheEntry{}
|
||||
tail := &OptimizedCacheEntry{}
|
||||
head.next = tail
|
||||
tail.prev = head
|
||||
|
||||
maxMemoryBytes := int64(maxMemoryMB) * 1024 * 1024 // Convert MB to bytes
|
||||
if maxMemoryBytes == 0 {
|
||||
maxMemoryBytes = 64 * 1024 * 1024 // Default 64MB
|
||||
}
|
||||
|
||||
c := &OptimizedCache{
|
||||
items: make(map[string]*OptimizedCacheEntry, maxSize),
|
||||
head: head,
|
||||
tail: tail,
|
||||
maxSize: maxSize,
|
||||
maxMemoryBytes: maxMemoryBytes,
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
c.startAutoCleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
// Set adds or updates an item in the cache with memory and key validation
|
||||
func (c *OptimizedCache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
// Validate key length to prevent memory bloat
|
||||
if len(key) > MaxKeyLength {
|
||||
c.logger.Debugf("Cache key too long (%d > %d), ignoring", len(key), MaxKeyLength)
|
||||
return
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
expTime := now.Add(expiration)
|
||||
|
||||
// Update existing item
|
||||
if entry, exists := c.items[key]; exists {
|
||||
oldSize := c.estimateEntrySize(entry)
|
||||
entry.Value = value
|
||||
entry.ExpiresAt = expTime
|
||||
newSize := c.estimateEntrySize(entry)
|
||||
c.currentMemoryBytes += newSize - oldSize
|
||||
c.moveToTail(entry)
|
||||
return
|
||||
}
|
||||
|
||||
// Create new entry
|
||||
entry := &OptimizedCacheEntry{
|
||||
Value: value,
|
||||
ExpiresAt: expTime,
|
||||
Key: key,
|
||||
}
|
||||
|
||||
entrySize := c.estimateEntrySize(entry)
|
||||
|
||||
// Check memory budget and evict if necessary
|
||||
for (c.currentMemoryBytes+entrySize > c.maxMemoryBytes || len(c.items) >= c.maxSize) && len(c.items) > 0 {
|
||||
if !c.evictOldest() {
|
||||
break // No more items to evict
|
||||
}
|
||||
}
|
||||
|
||||
// Add new entry
|
||||
c.items[key] = entry
|
||||
c.currentMemoryBytes += entrySize
|
||||
c.addToTail(entry)
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache with memory-efficient access tracking
|
||||
func (c *OptimizedCache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
entry, exists := c.items[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check for expiration
|
||||
if time.Now().After(entry.ExpiresAt) {
|
||||
c.removeEntry(entry)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Move to tail (most recently used)
|
||||
c.moveToTail(entry)
|
||||
return entry.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache
|
||||
func (c *OptimizedCache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if entry, exists := c.items[key]; exists {
|
||||
c.removeEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup removes expired items and performs memory optimization
|
||||
func (c *OptimizedCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
toRemove := make([]*OptimizedCacheEntry, 0, len(c.items)/10) // Pre-allocate for efficiency
|
||||
|
||||
// Collect expired entries (start from head - oldest items)
|
||||
for entry := c.head.next; entry != c.tail; entry = entry.next {
|
||||
if now.After(entry.ExpiresAt) {
|
||||
toRemove = append(toRemove, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove expired entries
|
||||
for _, entry := range toRemove {
|
||||
c.removeEntry(entry)
|
||||
}
|
||||
|
||||
// Perform memory pressure eviction if needed
|
||||
for c.currentMemoryBytes > c.maxMemoryBytes && len(c.items) > 0 {
|
||||
if !c.evictOldest() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldest removes the least recently used item
|
||||
// Returns false if no items to evict
|
||||
func (c *OptimizedCache) evictOldest() bool {
|
||||
if c.head.next == c.tail {
|
||||
return false // Empty cache
|
||||
}
|
||||
|
||||
oldest := c.head.next
|
||||
c.removeEntry(oldest)
|
||||
return true
|
||||
}
|
||||
|
||||
// removeEntry removes an entry from both the map and linked list
|
||||
func (c *OptimizedCache) removeEntry(entry *OptimizedCacheEntry) {
|
||||
// Remove from map
|
||||
delete(c.items, entry.Key)
|
||||
|
||||
// Update memory usage
|
||||
c.currentMemoryBytes -= c.estimateEntrySize(entry)
|
||||
|
||||
// Remove from linked list
|
||||
entry.prev.next = entry.next
|
||||
entry.next.prev = entry.prev
|
||||
|
||||
// Clear references to help GC
|
||||
entry.prev = nil
|
||||
entry.next = nil
|
||||
entry.Value = nil
|
||||
}
|
||||
|
||||
// addToTail adds an entry to the tail (most recently used position)
|
||||
func (c *OptimizedCache) addToTail(entry *OptimizedCacheEntry) {
|
||||
entry.prev = c.tail.prev
|
||||
entry.next = c.tail
|
||||
c.tail.prev.next = entry
|
||||
c.tail.prev = entry
|
||||
}
|
||||
|
||||
// moveToTail moves an existing entry to the tail (mark as most recently used)
|
||||
func (c *OptimizedCache) moveToTail(entry *OptimizedCacheEntry) {
|
||||
// Remove from current position
|
||||
entry.prev.next = entry.next
|
||||
entry.next.prev = entry.prev
|
||||
|
||||
// Add to tail
|
||||
c.addToTail(entry)
|
||||
}
|
||||
|
||||
// estimateEntrySize estimates the memory usage of a cache entry
|
||||
// Uses conservative estimates since unsafe.Sizeof is not allowed in Yaegi
|
||||
func (c *OptimizedCache) estimateEntrySize(entry *OptimizedCacheEntry) int64 {
|
||||
// Conservative estimate for OptimizedCacheEntry struct overhead
|
||||
// (3 pointers + time.Time + string) ≈ 80 bytes on 64-bit systems
|
||||
size := int64(80) + int64(len(entry.Key))
|
||||
|
||||
// Estimate value size based on type
|
||||
if entry.Value != nil {
|
||||
switch v := entry.Value.(type) {
|
||||
case string:
|
||||
size += int64(len(v))
|
||||
case []byte:
|
||||
size += int64(len(v))
|
||||
case map[string]interface{}:
|
||||
// Rough estimate for map overhead + keys + values
|
||||
size += int64(len(v)) * 64 // 64 bytes per entry estimate
|
||||
for key, val := range v {
|
||||
size += int64(len(key))
|
||||
// Estimate value size
|
||||
switch val := val.(type) {
|
||||
case string:
|
||||
size += int64(len(val))
|
||||
case []byte:
|
||||
size += int64(len(val))
|
||||
default:
|
||||
size += 32 // Default estimate for other types
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, s := range v {
|
||||
size += int64(len(s)) + 16 // 16 bytes slice overhead per string
|
||||
}
|
||||
default:
|
||||
// Generic estimate for unknown types
|
||||
size += 64
|
||||
}
|
||||
}
|
||||
|
||||
return size
|
||||
}
|
||||
|
||||
// SetMaxSize changes the maximum number of items the cache can hold
|
||||
func (c *OptimizedCache) SetMaxSize(size int) {
|
||||
if size <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.maxSize = size
|
||||
|
||||
// Evict excess items if necessary
|
||||
for len(c.items) > c.maxSize && len(c.items) > 0 {
|
||||
if !c.evictOldest() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxMemory sets the maximum memory budget in MB
|
||||
func (c *OptimizedCache) SetMaxMemory(maxMemoryMB int) {
|
||||
if maxMemoryMB <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.maxMemoryBytes = int64(maxMemoryMB) * 1024 * 1024
|
||||
|
||||
// Evict items if over memory budget
|
||||
for c.currentMemoryBytes > c.maxMemoryBytes && len(c.items) > 0 {
|
||||
if !c.evictOldest() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startAutoCleanup starts the background cleanup task
|
||||
func (c *OptimizedCache) startAutoCleanup() {
|
||||
c.cleanupTask = NewBackgroundTask("optimized-cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger)
|
||||
c.cleanupTask.Start()
|
||||
}
|
||||
|
||||
// Close stops the automatic cleanup task
|
||||
func (c *OptimizedCache) Close() {
|
||||
if c.cleanupTask != nil {
|
||||
c.cleanupTask.Stop()
|
||||
c.cleanupTask = nil
|
||||
}
|
||||
}
|
||||
+1
-1
@@ -49,7 +49,7 @@ func TestCache_SetMaxSize(t *testing.T) {
|
||||
newMaxSize := 3
|
||||
|
||||
// Add more items than the new max size
|
||||
for i := range originalMaxSize {
|
||||
for i := 0; i < originalMaxSize; i++ {
|
||||
key := "key" + string(rune('A'+i))
|
||||
c.Set(key, i, 1*time.Hour)
|
||||
}
|
||||
|
||||
+31
-27
@@ -3,7 +3,6 @@ package traefikoidc
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
@@ -17,7 +16,7 @@ type ErrorRecoveryMechanism interface {
|
||||
// ExecuteWithContext executes a function with error recovery
|
||||
ExecuteWithContext(ctx context.Context, fn func() error) error
|
||||
// GetMetrics returns metrics about the error recovery mechanism
|
||||
GetMetrics() map[string]any
|
||||
GetMetrics() map[string]interface{}
|
||||
// Reset resets the state of the error recovery mechanism
|
||||
Reset()
|
||||
// IsAvailable returns whether the mechanism is available for use
|
||||
@@ -74,11 +73,11 @@ func (b *BaseRecoveryMechanism) RecordFailure() {
|
||||
}
|
||||
|
||||
// GetBaseMetrics returns base metrics common to all recovery mechanisms
|
||||
func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]any {
|
||||
func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
|
||||
metrics := map[string]any{
|
||||
metrics := map[string]interface{}{
|
||||
"total_requests": atomic.LoadInt64(&b.totalRequests),
|
||||
"total_failures": atomic.LoadInt64(&b.totalFailures),
|
||||
"total_successes": atomic.LoadInt64(&b.totalSuccesses),
|
||||
@@ -108,23 +107,23 @@ func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]any {
|
||||
}
|
||||
|
||||
// LogInfo logs an informational message
|
||||
func (b *BaseRecoveryMechanism) LogInfo(format string, args ...any) {
|
||||
func (b *BaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
|
||||
if b.logger != nil {
|
||||
b.logger.Infof("%s: "+format, append([]any{b.name}, args...)...)
|
||||
b.logger.Infof("%s: "+format, append([]interface{}{b.name}, args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogError logs an error message
|
||||
func (b *BaseRecoveryMechanism) LogError(format string, args ...any) {
|
||||
func (b *BaseRecoveryMechanism) LogError(format string, args ...interface{}) {
|
||||
if b.logger != nil {
|
||||
b.logger.Errorf("%s: "+format, append([]any{b.name}, args...)...)
|
||||
b.logger.Errorf("%s: "+format, append([]interface{}{b.name}, args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogDebug logs a debug message
|
||||
func (b *BaseRecoveryMechanism) LogDebug(format string, args ...any) {
|
||||
func (b *BaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
|
||||
if b.logger != nil {
|
||||
b.logger.Debugf("%s: "+format, append([]any{b.name}, args...)...)
|
||||
b.logger.Debugf("%s: "+format, append([]interface{}{b.name}, args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,7 +295,7 @@ func (cb *CircuitBreaker) IsAvailable() bool {
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the circuit breaker
|
||||
func (cb *CircuitBreaker) GetMetrics() map[string]any {
|
||||
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
|
||||
cb.mutex.RLock()
|
||||
state := cb.state
|
||||
failures := cb.failures
|
||||
@@ -375,7 +374,7 @@ func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error
|
||||
err := fn()
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
re.LogInfo("Operation succeeded on attempt %d", attempt)
|
||||
re.LogInfo("Operation succeeded after %d attempts", attempt)
|
||||
}
|
||||
re.RecordSuccess()
|
||||
return nil
|
||||
@@ -385,7 +384,7 @@ func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error
|
||||
|
||||
// Check if error is retryable
|
||||
if !re.isRetryableError(err) {
|
||||
re.LogDebug("Non-retryable error on attempt %d: %v", attempt, err)
|
||||
// Only log non-retryable errors once
|
||||
re.RecordFailure()
|
||||
return err
|
||||
}
|
||||
@@ -398,8 +397,11 @@ func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error
|
||||
|
||||
// Calculate delay with exponential backoff
|
||||
delay := re.calculateDelay(attempt)
|
||||
re.LogDebug("Retrying operation after %v (attempt %d/%d): %v",
|
||||
delay, attempt, re.config.MaxAttempts, err)
|
||||
// Only log on first retry and then every 3rd attempt to reduce spam
|
||||
if attempt == 1 || attempt%3 == 0 {
|
||||
re.LogDebug("Retrying operation after %v (attempt %d/%d): %v",
|
||||
delay, attempt, re.config.MaxAttempts, err)
|
||||
}
|
||||
|
||||
// Wait with context cancellation support
|
||||
select {
|
||||
@@ -498,7 +500,7 @@ func (re *RetryExecutor) IsAvailable() bool {
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the retry executor
|
||||
func (re *RetryExecutor) GetMetrics() map[string]any {
|
||||
func (re *RetryExecutor) GetMetrics() map[string]interface{} {
|
||||
metrics := re.GetBaseMetrics()
|
||||
|
||||
// Add retry executor specific metrics
|
||||
@@ -526,7 +528,7 @@ func (e *HTTPError) Error() string {
|
||||
// GracefulDegradation implements graceful degradation patterns
|
||||
type GracefulDegradation struct {
|
||||
*BaseRecoveryMechanism
|
||||
fallbacks map[string]func() (any, error)
|
||||
fallbacks map[string]func() (interface{}, error)
|
||||
healthChecks map[string]func() bool
|
||||
degradedServices map[string]time.Time
|
||||
config GracefulDegradationConfig
|
||||
@@ -553,7 +555,7 @@ func DefaultGracefulDegradationConfig() GracefulDegradationConfig {
|
||||
func NewGracefulDegradation(config GracefulDegradationConfig, logger *Logger) *GracefulDegradation {
|
||||
gd := &GracefulDegradation{
|
||||
BaseRecoveryMechanism: NewBaseRecoveryMechanism("graceful-degradation", logger),
|
||||
fallbacks: make(map[string]func() (any, error)),
|
||||
fallbacks: make(map[string]func() (interface{}, error)),
|
||||
healthChecks: make(map[string]func() bool),
|
||||
degradedServices: make(map[string]time.Time),
|
||||
config: config,
|
||||
@@ -566,7 +568,7 @@ func NewGracefulDegradation(config GracefulDegradationConfig, logger *Logger) *G
|
||||
}
|
||||
|
||||
// RegisterFallback registers a fallback function for a service
|
||||
func (gd *GracefulDegradation) RegisterFallback(serviceName string, fallback func() (any, error)) {
|
||||
func (gd *GracefulDegradation) RegisterFallback(serviceName string, fallback func() (interface{}, error)) {
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
gd.fallbacks[serviceName] = fallback
|
||||
@@ -584,7 +586,7 @@ func (gd *GracefulDegradation) ExecuteWithContext(ctx context.Context, fn func()
|
||||
gd.RecordRequest()
|
||||
|
||||
// Execute with a simple wrapper
|
||||
_, err := gd.ExecuteWithFallback("default", func() (any, error) {
|
||||
_, err := gd.ExecuteWithFallback("default", func() (interface{}, error) {
|
||||
return nil, fn()
|
||||
})
|
||||
|
||||
@@ -598,7 +600,7 @@ func (gd *GracefulDegradation) ExecuteWithContext(ctx context.Context, fn func()
|
||||
}
|
||||
|
||||
// ExecuteWithFallback executes a function with fallback support
|
||||
func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary func() (any, error)) (any, error) {
|
||||
func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary func() (interface{}, error)) (interface{}, error) {
|
||||
// Check if service is degraded
|
||||
if gd.isServiceDegraded(serviceName) {
|
||||
gd.LogInfo("Service %s is degraded, using fallback", serviceName)
|
||||
@@ -656,7 +658,7 @@ func (gd *GracefulDegradation) markServiceDegraded(serviceName string) {
|
||||
}
|
||||
|
||||
// executeFallback executes the fallback function for a service
|
||||
func (gd *GracefulDegradation) executeFallback(serviceName string) (any, error) {
|
||||
func (gd *GracefulDegradation) executeFallback(serviceName string) (interface{}, error) {
|
||||
gd.mutex.RLock()
|
||||
fallback, exists := gd.fallbacks[serviceName]
|
||||
gd.mutex.RUnlock()
|
||||
@@ -684,7 +686,9 @@ func (gd *GracefulDegradation) startHealthCheckRoutine() {
|
||||
func (gd *GracefulDegradation) performHealthChecks() {
|
||||
gd.mutex.RLock()
|
||||
healthChecks := make(map[string]func() bool)
|
||||
maps.Copy(healthChecks, gd.healthChecks)
|
||||
for k, v := range gd.healthChecks {
|
||||
healthChecks[k] = v
|
||||
}
|
||||
gd.mutex.RUnlock()
|
||||
|
||||
for serviceName, healthCheck := range healthChecks {
|
||||
@@ -732,7 +736,7 @@ func (gd *GracefulDegradation) IsAvailable() bool {
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the graceful degradation mechanism
|
||||
func (gd *GracefulDegradation) GetMetrics() map[string]any {
|
||||
func (gd *GracefulDegradation) GetMetrics() map[string]interface{} {
|
||||
gd.mutex.RLock()
|
||||
degradedCount := len(gd.degradedServices)
|
||||
|
||||
@@ -805,14 +809,14 @@ func (erm *ErrorRecoveryManager) ExecuteWithRecovery(ctx context.Context, servic
|
||||
}
|
||||
|
||||
// GetRecoveryMetrics returns metrics for all recovery mechanisms
|
||||
func (erm *ErrorRecoveryManager) GetRecoveryMetrics() map[string]any {
|
||||
func (erm *ErrorRecoveryManager) GetRecoveryMetrics() map[string]interface{} {
|
||||
erm.mutex.RLock()
|
||||
defer erm.mutex.RUnlock()
|
||||
|
||||
metrics := make(map[string]any)
|
||||
metrics := make(map[string]interface{})
|
||||
|
||||
// Circuit breaker metrics
|
||||
cbMetrics := make(map[string]any)
|
||||
cbMetrics := make(map[string]interface{})
|
||||
for name, cb := range erm.circuitBreakers {
|
||||
cbMetrics[name] = cb.GetMetrics()
|
||||
}
|
||||
|
||||
+11
-6
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -208,7 +207,7 @@ func TestGracefulDegradation(t *testing.T) {
|
||||
}()
|
||||
|
||||
t.Run("Register fallback and health check", func(t *testing.T) {
|
||||
gd.RegisterFallback("test-service", func() (any, error) {
|
||||
gd.RegisterFallback("test-service", func() (interface{}, error) {
|
||||
return "fallback-result", nil
|
||||
})
|
||||
|
||||
@@ -223,12 +222,12 @@ func TestGracefulDegradation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Execute with fallback on failure", func(t *testing.T) {
|
||||
gd.RegisterFallback("failing-service", func() (any, error) {
|
||||
gd.RegisterFallback("failing-service", func() (interface{}, error) {
|
||||
return "fallback-result", nil
|
||||
})
|
||||
|
||||
// First call should fail and mark service as degraded
|
||||
result, err := gd.ExecuteWithFallback("failing-service", func() (any, error) {
|
||||
result, err := gd.ExecuteWithFallback("failing-service", func() (interface{}, error) {
|
||||
return nil, errors.New("service failure")
|
||||
})
|
||||
if err != nil {
|
||||
@@ -245,7 +244,7 @@ func TestGracefulDegradation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("No fallback available", func(t *testing.T) {
|
||||
_, err := gd.ExecuteWithFallback("no-fallback-service", func() (any, error) {
|
||||
_, err := gd.ExecuteWithFallback("no-fallback-service", func() (interface{}, error) {
|
||||
return nil, errors.New("service failure")
|
||||
})
|
||||
|
||||
@@ -256,7 +255,13 @@ func TestGracefulDegradation(t *testing.T) {
|
||||
|
||||
t.Run("Get degraded services", func(t *testing.T) {
|
||||
degraded := gd.GetDegradedServices()
|
||||
found := slices.Contains(degraded, "failing-service")
|
||||
found := false
|
||||
for _, s := range degraded {
|
||||
if s == "failing-service" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected failing-service to be in degraded list")
|
||||
}
|
||||
|
||||
+11
-6
@@ -9,7 +9,6 @@ import (
|
||||
"math/big"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -132,9 +131,9 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tOidc.extractClaimsFunc = func(token string) (map[string]any, error) {
|
||||
tOidc.extractClaimsFunc = func(token string) (map[string]interface{}, error) {
|
||||
// Return mock claims
|
||||
return map[string]any{
|
||||
return map[string]interface{}{
|
||||
"email": "test@example.com",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
}, nil
|
||||
@@ -304,7 +303,13 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
scopeList := strings.Split(scope, " ")
|
||||
expectedScopes := []string{"openid", "profile", "email"}
|
||||
for _, expectedScope := range expectedScopes {
|
||||
found := slices.Contains(scopeList, expectedScope)
|
||||
found := false
|
||||
for _, s := range scopeList {
|
||||
if s == expectedScope {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected scope '%s' not found in scope parameter: %s", expectedScope, scope)
|
||||
}
|
||||
@@ -380,7 +385,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
nbf := now.Unix()
|
||||
|
||||
// Create initial ID token
|
||||
initialIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
initialIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://accounts.google.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -396,7 +401,7 @@ func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create refresh ID token
|
||||
refreshedIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
refreshedIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://accounts.google.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
|
||||
+22
-5
@@ -187,7 +187,7 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
// Returns:
|
||||
// - A map representing the JSON claims extracted from the token payload.
|
||||
// - An error if the token format is invalid, decoding fails, or JSON unmarshaling fails.
|
||||
func extractClaims(tokenString string) (map[string]any, error) {
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
@@ -198,7 +198,7 @@ func extractClaims(tokenString string) (map[string]any, error) {
|
||||
return nil, fmt.Errorf("failed to decode token payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
|
||||
}
|
||||
@@ -236,7 +236,7 @@ func NewTokenCache() *TokenCache {
|
||||
// - token: The raw token string (used as the key).
|
||||
// - claims: The map of claims associated with the token.
|
||||
// - expiration: The duration for which the cache entry should be valid.
|
||||
func (tc *TokenCache) Set(token string, claims map[string]any, expiration time.Duration) {
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
}
|
||||
@@ -250,13 +250,13 @@ func (tc *TokenCache) Set(token string, claims map[string]any, expiration time.D
|
||||
// Returns:
|
||||
// - The cached claims map if found and valid.
|
||||
// - A boolean indicating whether the token was found in the cache (true if found, false otherwise).
|
||||
func (tc *TokenCache) Get(token string) (map[string]any, bool) {
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
claims, ok := value.(map[string]any)
|
||||
claims, ok := value.(map[string]interface{})
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
@@ -404,3 +404,20 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// deduplicateScopes removes duplicate strings from a slice while preserving order.
|
||||
// The first occurrence of each scope is kept.
|
||||
func deduplicateScopes(scopes []string) []string {
|
||||
if len(scopes) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
result := []string{}
|
||||
for _, scope := range scopes {
|
||||
if _, ok := seen[scope]; !ok {
|
||||
seen[scope] = struct{}{}
|
||||
result = append(result, scope)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
+1
-1
@@ -614,7 +614,7 @@ func (iv *InputValidator) SanitizeInput(input string, maxLength int) string {
|
||||
}
|
||||
|
||||
// ValidateBoundaryValues validates numeric boundary values
|
||||
func (iv *InputValidator) ValidateBoundaryValues(value any, min, max int64) ValidationResult {
|
||||
func (iv *InputValidator) ValidateBoundaryValues(value interface{}, min, max int64) ValidationResult {
|
||||
result := ValidationResult{IsValid: true, Errors: []string{}, Warnings: []string{}}
|
||||
|
||||
var numValue int64
|
||||
|
||||
@@ -246,7 +246,7 @@ func TestValidateBoundaryValues(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("Valid boundary values", func(t *testing.T) {
|
||||
validValues := []any{
|
||||
validValues := []interface{}{
|
||||
int(50),
|
||||
int64(100),
|
||||
float64(75.5),
|
||||
@@ -261,7 +261,7 @@ func TestValidateBoundaryValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Invalid boundary values", func(t *testing.T) {
|
||||
invalidValues := []any{
|
||||
invalidValues := []interface{}{
|
||||
int(-1),
|
||||
int64(2000),
|
||||
"not a number",
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Adapter facilitates communication between the legacy TraefikOIDC struct and the new provider system.
|
||||
type Adapter struct {
|
||||
provider OIDCProvider
|
||||
legacySettings LegacySettings
|
||||
tokenVerifier TokenVerifier
|
||||
tokenCache TokenCache
|
||||
}
|
||||
|
||||
// LegacySettings provides the adapter with access to the original configuration values.
|
||||
type LegacySettings interface {
|
||||
GetIssuerURL() string
|
||||
GetAuthURL() string
|
||||
GetScopes() []string
|
||||
IsPKCEEnabled() bool
|
||||
GetClientID() string
|
||||
GetRefreshGracePeriod() time.Duration
|
||||
IsOverrideScopes() bool
|
||||
}
|
||||
|
||||
// NewAdapter creates a new adapter for a given provider and legacy settings.
|
||||
func NewAdapter(provider OIDCProvider, settings LegacySettings, tokenVerifier TokenVerifier, tokenCache TokenCache) *Adapter {
|
||||
return &Adapter{
|
||||
provider: provider,
|
||||
legacySettings: settings,
|
||||
tokenVerifier: tokenVerifier,
|
||||
tokenCache: tokenCache,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildAuthURL constructs the authentication URL using the adapted provider.
|
||||
func (a *Adapter) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", a.legacySettings.GetClientID())
|
||||
params.Set("response_type", "code")
|
||||
params.Set("redirect_uri", redirectURL)
|
||||
params.Set("state", state)
|
||||
params.Set("nonce", nonce)
|
||||
|
||||
if a.legacySettings.IsPKCEEnabled() && codeChallenge != "" {
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
scopes := a.legacySettings.GetScopes()
|
||||
|
||||
// When overrideScopes is true, use exactly the scopes provided without modification
|
||||
if a.legacySettings.IsOverrideScopes() {
|
||||
// Use scopes as-is, don't let provider add anything
|
||||
finalParams := params
|
||||
finalParams.Set("scope", strings.Join(scopes, " "))
|
||||
|
||||
// For provider-specific parameters, we still need to check the provider type
|
||||
switch a.provider.GetType() {
|
||||
case ProviderTypeGoogle:
|
||||
// Google-specific parameters
|
||||
finalParams.Set("access_type", "offline")
|
||||
finalParams.Set("prompt", "consent")
|
||||
case ProviderTypeAzure:
|
||||
// Azure-specific parameters
|
||||
finalParams.Set("response_mode", "query")
|
||||
}
|
||||
|
||||
return a.buildURLWithParams(a.legacySettings.GetAuthURL(), finalParams)
|
||||
}
|
||||
|
||||
// When overrideScopes is false, let the provider add necessary scopes
|
||||
authParams, err := a.provider.BuildAuthParams(params, scopes)
|
||||
if err != nil {
|
||||
// Log the error appropriately
|
||||
return ""
|
||||
}
|
||||
|
||||
finalParams := authParams.URLValues
|
||||
finalParams.Set("scope", strings.Join(authParams.Scopes, " "))
|
||||
|
||||
// Build the full URL with params
|
||||
return a.buildURLWithParams(a.legacySettings.GetAuthURL(), finalParams)
|
||||
}
|
||||
|
||||
// buildURLWithParams takes a base URL and query parameters and constructs a full URL string.
|
||||
// If the baseURL is relative (doesn't start with http/https), it prepends the scheme and host
|
||||
// from the configured issuerURL.
|
||||
func (a *Adapter) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
// Relative URL - resolve against issuer URL
|
||||
issuerURLParsed, err := url.Parse(a.legacySettings.GetIssuerURL())
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseURLParsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
|
||||
resolvedURL.RawQuery = params.Encode()
|
||||
return resolvedURL.String()
|
||||
}
|
||||
|
||||
// Absolute URL
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
u.RawQuery = params.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// ValidateTokens validates tokens using the adapted provider.
|
||||
func (a *Adapter) ValidateTokens(session Session) (*ValidationResult, error) {
|
||||
return a.provider.ValidateTokens(session, a.tokenVerifier, a.tokenCache, a.legacySettings.GetRefreshGracePeriod())
|
||||
}
|
||||
|
||||
// GetType returns the underlying provider's type.
|
||||
func (a *Adapter) GetType() ProviderType {
|
||||
return a.provider.GetType()
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AzureProvider encapsulates Azure AD-specific OIDC logic.
|
||||
type AzureProvider struct {
|
||||
*BaseProvider
|
||||
}
|
||||
|
||||
// NewAzureProvider creates a new instance of the AzureProvider.
|
||||
func NewAzureProvider() *AzureProvider {
|
||||
return &AzureProvider{
|
||||
BaseProvider: NewBaseProvider(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetType returns the provider's type.
|
||||
func (p *AzureProvider) GetType() ProviderType {
|
||||
return ProviderTypeAzure
|
||||
}
|
||||
|
||||
// GetCapabilities returns the specific capabilities of the Azure provider.
|
||||
func (p *AzureProvider) GetCapabilities() ProviderCapabilities {
|
||||
return ProviderCapabilities{
|
||||
SupportsRefreshTokens: true,
|
||||
RequiresOfflineAccessScope: true,
|
||||
PreferredTokenValidation: "access", // Azure AD prefers access token validation
|
||||
}
|
||||
}
|
||||
|
||||
// BuildAuthParams configures Azure-specific authentication parameters.
|
||||
func (p *AzureProvider) BuildAuthParams(baseParams url.Values, scopes []string) (*AuthParams, error) {
|
||||
baseParams.Set("response_mode", "query")
|
||||
|
||||
// Ensure "offline_access" scope is present for refresh tokens
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
URLValues: baseParams,
|
||||
Scopes: scopes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateTokens overrides the default token validation to implement Azure-specific logic.
|
||||
// Azure may use access tokens for validation, and this method ensures that behavior is preserved.
|
||||
func (p *AzureProvider) ValidateTokens(session Session, verifier TokenVerifier, tokenCache TokenCache, refreshGracePeriod time.Duration) (*ValidationResult, error) {
|
||||
if !session.GetAuthenticated() {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{IsExpired: true}, nil
|
||||
}
|
||||
|
||||
accessToken := session.GetAccessToken()
|
||||
idToken := session.GetIDToken()
|
||||
|
||||
if accessToken != "" {
|
||||
if strings.Count(accessToken, ".") == 2 {
|
||||
if err := verifier.VerifyToken(accessToken); err != nil {
|
||||
if idToken != "" {
|
||||
return p.ValidateTokenExpiry(session, idToken, tokenCache, refreshGracePeriod)
|
||||
}
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{IsExpired: true}, nil
|
||||
}
|
||||
return p.ValidateTokenExpiry(session, accessToken, tokenCache, refreshGracePeriod)
|
||||
}
|
||||
if idToken != "" {
|
||||
return p.ValidateTokenExpiry(session, idToken, tokenCache, refreshGracePeriod)
|
||||
}
|
||||
return &ValidationResult{Authenticated: true}, nil
|
||||
}
|
||||
|
||||
if idToken != "" {
|
||||
if err := verifier.VerifyToken(idToken); err != nil {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{IsExpired: true}, nil
|
||||
}
|
||||
return p.ValidateTokenExpiry(session, idToken, tokenCache, refreshGracePeriod)
|
||||
}
|
||||
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{IsExpired: true}, nil
|
||||
}
|
||||
|
||||
// ValidateConfig validates Azure-specific configuration requirements.
|
||||
// Azure requires specific tenant configuration and scope handling.
|
||||
func (p *AzureProvider) ValidateConfig() error {
|
||||
// Azure provider validation - ensure we have the necessary configuration
|
||||
// In a real implementation, this might check for tenant ID, proper issuer URL format, etc.
|
||||
return p.BaseProvider.ValidateConfig()
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BaseProvider provides a common foundation for OIDC provider implementations.
|
||||
// It can be embedded in specific provider structs to share common logic.
|
||||
type BaseProvider struct {
|
||||
// Common configuration or dependencies can be added here.
|
||||
}
|
||||
|
||||
// GetType returns the default provider type, which is Generic.
|
||||
// This should be overridden by specific provider implementations.
|
||||
func (p *BaseProvider) GetType() ProviderType {
|
||||
return ProviderTypeGeneric
|
||||
}
|
||||
|
||||
// GetCapabilities returns a default set of capabilities for a generic OIDC provider.
|
||||
// This can be overridden by specific providers to declare their unique features.
|
||||
func (p *BaseProvider) GetCapabilities() ProviderCapabilities {
|
||||
return ProviderCapabilities{
|
||||
SupportsRefreshTokens: true,
|
||||
RequiresOfflineAccessScope: true,
|
||||
PreferredTokenValidation: "id",
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateTokens provides a default token validation implementation.
|
||||
// This method can be extended or replaced by specific providers.
|
||||
func (p *BaseProvider) ValidateTokens(session Session, verifier TokenVerifier, tokenCache TokenCache, refreshGracePeriod time.Duration) (*ValidationResult, error) {
|
||||
if !session.GetAuthenticated() {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{}, nil
|
||||
}
|
||||
|
||||
accessToken := session.GetAccessToken()
|
||||
if accessToken == "" {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{IsExpired: true}, nil
|
||||
}
|
||||
|
||||
idToken := session.GetIDToken()
|
||||
if idToken == "" {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{Authenticated: true, NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{Authenticated: true}, nil
|
||||
}
|
||||
|
||||
if err := verifier.VerifyToken(idToken); err != nil {
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{IsExpired: true}, nil
|
||||
}
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{IsExpired: true}, nil
|
||||
}
|
||||
|
||||
return p.ValidateTokenExpiry(session, idToken, tokenCache, refreshGracePeriod)
|
||||
}
|
||||
|
||||
// ValidateTokenExpiry provides common token expiry validation logic that can be used by all providers.
|
||||
// This method is now exported so provider implementations can reuse this logic without duplication.
|
||||
func (p *BaseProvider) ValidateTokenExpiry(session Session, token string, tokenCache TokenCache, refreshGracePeriod time.Duration) (*ValidationResult, error) {
|
||||
cachedClaims, found := tokenCache.Get(token)
|
||||
if !found {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{IsExpired: true}, nil
|
||||
}
|
||||
|
||||
expClaim, ok := cachedClaims["exp"].(float64)
|
||||
if !ok {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{IsExpired: true}, nil
|
||||
}
|
||||
|
||||
expTime := time.Unix(int64(expClaim), 0)
|
||||
if expTime.Before(time.Now().Add(refreshGracePeriod)) {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return &ValidationResult{Authenticated: true, NeedsRefresh: true}, nil
|
||||
}
|
||||
return &ValidationResult{Authenticated: true}, nil
|
||||
}
|
||||
|
||||
return &ValidationResult{Authenticated: true}, nil
|
||||
}
|
||||
|
||||
// BuildAuthParams provides a default implementation for building authorization parameters.
|
||||
// It includes the "offline_access" scope by default.
|
||||
func (p *BaseProvider) BuildAuthParams(baseParams url.Values, scopes []string) (*AuthParams, error) {
|
||||
// Ensure offline_access is included if not already present
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
URLValues: baseParams,
|
||||
Scopes: scopes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HandleTokenRefresh provides a default implementation for token refresh handling.
|
||||
// By default, it does nothing and assumes the standard token response is sufficient.
|
||||
func (p *BaseProvider) HandleTokenRefresh(tokenData *TokenResult) error {
|
||||
// No provider-specific refresh handling by default.
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConfig provides a default implementation for configuration validation.
|
||||
// By default, it assumes the configuration is valid.
|
||||
func (p *BaseProvider) ValidateConfig() error {
|
||||
// No provider-specific config validation by default.
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewBaseProvider creates a new BaseProvider.
|
||||
func NewBaseProvider() *BaseProvider {
|
||||
return &BaseProvider{}
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProviderFactory encapsulates the logic for creating and configuring OIDC providers.
|
||||
type ProviderFactory struct {
|
||||
registry *ProviderRegistry
|
||||
}
|
||||
|
||||
// NewProviderFactory creates a new factory with a pre-configured registry.
|
||||
func NewProviderFactory() *ProviderFactory {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Register all available providers
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
|
||||
return &ProviderFactory{
|
||||
registry: registry,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProvider creates and returns the appropriate provider for the given issuer URL.
|
||||
// It automatically detects the provider type and returns a configured instance.
|
||||
func (f *ProviderFactory) CreateProvider(issuerURL string) (OIDCProvider, error) {
|
||||
if issuerURL == "" {
|
||||
return nil, fmt.Errorf("issuer URL cannot be empty")
|
||||
}
|
||||
|
||||
// Validate URL format
|
||||
if _, err := url.Parse(issuerURL); err != nil {
|
||||
return nil, fmt.Errorf("invalid issuer URL format: %w", err)
|
||||
}
|
||||
|
||||
provider := f.registry.DetectProvider(issuerURL)
|
||||
if provider == nil {
|
||||
return nil, fmt.Errorf("unable to detect provider for issuer URL: %s", issuerURL)
|
||||
}
|
||||
|
||||
// Validate the provider configuration if it implements config validation
|
||||
if err := provider.ValidateConfig(); err != nil {
|
||||
return nil, fmt.Errorf("provider configuration validation failed: %w", err)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// CreateProviderByType creates a provider instance for a specific provider type.
|
||||
// This is useful when you want to force a specific provider type regardless of URL.
|
||||
func (f *ProviderFactory) CreateProviderByType(providerType ProviderType) (OIDCProvider, error) {
|
||||
var provider OIDCProvider
|
||||
|
||||
switch providerType {
|
||||
case ProviderTypeGeneric:
|
||||
provider = NewGenericProvider()
|
||||
case ProviderTypeGoogle:
|
||||
provider = NewGoogleProvider()
|
||||
case ProviderTypeAzure:
|
||||
provider = NewAzureProvider()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider type: %d", providerType)
|
||||
}
|
||||
|
||||
if err := provider.ValidateConfig(); err != nil {
|
||||
return nil, fmt.Errorf("provider configuration validation failed: %w", err)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// GetSupportedProviders returns a list of all supported provider types and their detection patterns.
|
||||
func (f *ProviderFactory) GetSupportedProviders() map[ProviderType][]string {
|
||||
return map[ProviderType][]string{
|
||||
ProviderTypeGeneric: {"*"}, // Generic supports any issuer
|
||||
ProviderTypeGoogle: {"accounts.google.com"},
|
||||
ProviderTypeAzure: {"login.microsoftonline.com", "sts.windows.net"},
|
||||
}
|
||||
}
|
||||
|
||||
// DetectProviderType returns the provider type that would be used for a given issuer URL.
|
||||
// This is useful for diagnostic purposes or UI display.
|
||||
func (f *ProviderFactory) DetectProviderType(issuerURL string) (ProviderType, error) {
|
||||
provider, err := f.CreateProvider(issuerURL)
|
||||
if err != nil {
|
||||
return ProviderTypeGeneric, err
|
||||
}
|
||||
return provider.GetType(), nil
|
||||
}
|
||||
|
||||
// IsProviderSupported checks if a given issuer URL is supported by any registered provider.
|
||||
func (f *ProviderFactory) IsProviderSupported(issuerURL string) bool {
|
||||
if issuerURL == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
normalizedURL, err := url.Parse(issuerURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
host := strings.ToLower(normalizedURL.Host)
|
||||
supportedProviders := f.GetSupportedProviders()
|
||||
|
||||
for _, patterns := range supportedProviders {
|
||||
for _, pattern := range patterns {
|
||||
if pattern == "*" || strings.Contains(host, strings.ToLower(pattern)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package providers
|
||||
|
||||
// GenericProvider encapsulates standard OIDC logic for any compliant provider.
|
||||
type GenericProvider struct {
|
||||
*BaseProvider
|
||||
}
|
||||
|
||||
// NewGenericProvider creates a new instance of the GenericProvider.
|
||||
func NewGenericProvider() *GenericProvider {
|
||||
return &GenericProvider{
|
||||
BaseProvider: NewBaseProvider(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetType returns the provider's type.
|
||||
func (p *GenericProvider) GetType() ProviderType {
|
||||
return ProviderTypeGeneric
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// GoogleProvider encapsulates Google-specific OIDC logic.
|
||||
type GoogleProvider struct {
|
||||
*BaseProvider
|
||||
}
|
||||
|
||||
// NewGoogleProvider creates a new instance of the GoogleProvider.
|
||||
func NewGoogleProvider() *GoogleProvider {
|
||||
return &GoogleProvider{
|
||||
BaseProvider: NewBaseProvider(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetType returns the provider's type.
|
||||
func (p *GoogleProvider) GetType() ProviderType {
|
||||
return ProviderTypeGoogle
|
||||
}
|
||||
|
||||
// GetCapabilities returns the specific capabilities of the Google provider.
|
||||
func (p *GoogleProvider) GetCapabilities() ProviderCapabilities {
|
||||
return ProviderCapabilities{
|
||||
SupportsRefreshTokens: true,
|
||||
RequiresOfflineAccessScope: false, // Google uses access_type=offline instead
|
||||
RequiresPromptConsent: true,
|
||||
PreferredTokenValidation: "id",
|
||||
}
|
||||
}
|
||||
|
||||
// BuildAuthParams configures Google-specific authentication parameters.
|
||||
func (p *GoogleProvider) BuildAuthParams(baseParams url.Values, scopes []string) (*AuthParams, error) {
|
||||
baseParams.Set("access_type", "offline")
|
||||
baseParams.Set("prompt", "consent")
|
||||
|
||||
// Google does not use the "offline_access" scope, so we remove it if present.
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
URLValues: baseParams,
|
||||
Scopes: filteredScopes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateConfig validates Google-specific configuration requirements.
|
||||
// Google requires specific scopes and client configuration for proper operation.
|
||||
func (p *GoogleProvider) ValidateConfig() error {
|
||||
// Google provider doesn't require additional validation beyond the base implementation
|
||||
// All Google-specific requirements are handled in BuildAuthParams
|
||||
return p.BaseProvider.ValidateConfig()
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
// Package providers implements a universal OIDC provider abstraction system.
|
||||
// It provides a clean interface for different OIDC providers (Google, Azure, Generic)
|
||||
// with provider-specific logic encapsulated in separate implementations.
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenVerifier defines the interface for token verification.
|
||||
type TokenVerifier interface {
|
||||
VerifyToken(token string) error
|
||||
}
|
||||
|
||||
// TokenCache defines the interface for a token cache.
|
||||
type TokenCache interface {
|
||||
Get(key string) (map[string]interface{}, bool)
|
||||
}
|
||||
|
||||
// ProviderType is an enumeration for identifying different OIDC providers.
|
||||
type ProviderType int
|
||||
|
||||
const (
|
||||
// ProviderTypeGeneric represents a standard, compliant OIDC provider.
|
||||
ProviderTypeGeneric ProviderType = iota
|
||||
// ProviderTypeGoogle represents Google as the OIDC provider.
|
||||
ProviderTypeGoogle
|
||||
// ProviderTypeAzure represents Microsoft Azure AD as the OIDC provider.
|
||||
ProviderTypeAzure
|
||||
)
|
||||
|
||||
// ProviderCapabilities defines the specific features and behaviors of an OIDC provider.
|
||||
type ProviderCapabilities struct {
|
||||
// SupportsRefreshTokens indicates if the provider issues refresh tokens.
|
||||
SupportsRefreshTokens bool
|
||||
// RequiresOfflineAccessScope indicates if the "offline_access" scope is needed for refresh tokens.
|
||||
RequiresOfflineAccessScope bool
|
||||
// RequiresPromptConsent indicates if "prompt=consent" is needed to ensure a refresh token is issued.
|
||||
RequiresPromptConsent bool
|
||||
// PreferredTokenValidation specifies the recommended token type to validate (e.g., "access" or "id").
|
||||
PreferredTokenValidation string
|
||||
}
|
||||
|
||||
// ValidationResult holds the outcome of a token validation check.
|
||||
type ValidationResult struct {
|
||||
// Authenticated is true if the token is valid and the user is authenticated.
|
||||
Authenticated bool
|
||||
// NeedsRefresh is true if the token is approaching its expiry and should be refreshed.
|
||||
NeedsRefresh bool
|
||||
// IsExpired is true if the token has expired or is invalid.
|
||||
IsExpired bool
|
||||
}
|
||||
|
||||
// AuthParams contains the provider-specific parameters for building the authorization URL.
|
||||
type AuthParams struct {
|
||||
// URLValues are the query parameters to be added to the authorization URL.
|
||||
URLValues url.Values
|
||||
// Scopes is the list of scopes to be requested.
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
// TokenResult holds the tokens returned by the provider.
|
||||
type TokenResult struct {
|
||||
// IDToken is the OIDC ID token.
|
||||
IDToken string
|
||||
// AccessToken is the OAuth2 access token.
|
||||
AccessToken string
|
||||
// RefreshToken is the OAuth2 refresh token.
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// OIDCProvider defines the interface for an OIDC provider implementation.
|
||||
// This abstraction allows for provider-specific logic to be encapsulated.
|
||||
type OIDCProvider interface {
|
||||
// GetType returns the type of the provider (e.g., Google, Azure, Generic).
|
||||
GetType() ProviderType
|
||||
|
||||
// GetCapabilities returns the feature set of the provider.
|
||||
GetCapabilities() ProviderCapabilities
|
||||
|
||||
// ValidateTokens performs token validation according to the provider's specific rules.
|
||||
// It should check the validity of the access and/or ID tokens from the session.
|
||||
ValidateTokens(session Session, verifier TokenVerifier, tokenCache TokenCache, refreshGracePeriod time.Duration) (*ValidationResult, error)
|
||||
|
||||
// BuildAuthParams modifies the authorization URL parameters for the provider.
|
||||
// This can be used to add provider-specific parameters like "access_type" for Google.
|
||||
BuildAuthParams(baseParams url.Values, scopes []string) (*AuthParams, error)
|
||||
|
||||
// HandleTokenRefresh manages the token refresh process for the provider.
|
||||
// It can modify the token request or handle the response as needed.
|
||||
HandleTokenRefresh(tokenData *TokenResult) error
|
||||
|
||||
// ValidateConfig checks if the user's configuration is valid for this provider.
|
||||
ValidateConfig() error
|
||||
}
|
||||
|
||||
// Session represents the session data required by providers for validation.
|
||||
// This interface decouples the providers from the main session management implementation.
|
||||
type Session interface {
|
||||
GetIDToken() string
|
||||
GetAccessToken() string
|
||||
GetRefreshToken() string
|
||||
GetAuthenticated() bool
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ProviderRegistry holds and manages the available OIDC provider implementations.
|
||||
// It provides thread-safe access to provider instances and caches detection results.
|
||||
type ProviderRegistry struct {
|
||||
mu sync.RWMutex
|
||||
providers []OIDCProvider
|
||||
cache map[string]OIDCProvider
|
||||
typeMap map[ProviderType]OIDCProvider // Maps provider type to instance
|
||||
}
|
||||
|
||||
// NewProviderRegistry creates and initializes a new ProviderRegistry.
|
||||
func NewProviderRegistry() *ProviderRegistry {
|
||||
return &ProviderRegistry{
|
||||
providers: make([]OIDCProvider, 0),
|
||||
cache: make(map[string]OIDCProvider),
|
||||
typeMap: make(map[ProviderType]OIDCProvider),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterProvider adds a new provider to the registry.
|
||||
// It maintains both a list of providers and a type-to-provider mapping for efficient lookups.
|
||||
func (r *ProviderRegistry) RegisterProvider(provider OIDCProvider) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.providers = append(r.providers, provider)
|
||||
r.typeMap[provider.GetType()] = provider
|
||||
}
|
||||
|
||||
// GetProviderByType returns a provider instance for the specified type.
|
||||
// Returns nil if the provider type is not registered.
|
||||
func (r *ProviderRegistry) GetProviderByType(providerType ProviderType) OIDCProvider {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.typeMap[providerType]
|
||||
}
|
||||
|
||||
// GetRegisteredProviders returns a slice of all registered provider types.
|
||||
func (r *ProviderRegistry) GetRegisteredProviders() []ProviderType {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
types := make([]ProviderType, 0, len(r.typeMap))
|
||||
for providerType := range r.typeMap {
|
||||
types = append(types, providerType)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
// ClearCache removes all cached provider detection results.
|
||||
// This can be useful for testing or when provider configuration changes.
|
||||
func (r *ProviderRegistry) ClearCache() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.cache = make(map[string]OIDCProvider)
|
||||
}
|
||||
|
||||
// DetectProvider determines the most appropriate provider for a given issuer URL.
|
||||
// It iterates through the registered providers and returns the first one that matches.
|
||||
// Detection is based on URL patterns and other provider-specific criteria.
|
||||
func (r *ProviderRegistry) DetectProvider(issuerURL string) OIDCProvider {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Check cache first for performance
|
||||
if provider, found := r.cache[issuerURL]; found {
|
||||
return provider
|
||||
}
|
||||
|
||||
// Normalize issuer URL for consistent matching
|
||||
normalizedURL, err := url.Parse(issuerURL)
|
||||
if err != nil {
|
||||
// Log error or handle it appropriately
|
||||
return nil
|
||||
}
|
||||
host := normalizedURL.Host
|
||||
|
||||
// Iterate through registered providers to find a match
|
||||
for _, p := range r.providers {
|
||||
switch p.GetType() {
|
||||
case ProviderTypeGoogle:
|
||||
if strings.Contains(host, "accounts.google.com") {
|
||||
r.cache[issuerURL] = p
|
||||
return p
|
||||
}
|
||||
case ProviderTypeAzure:
|
||||
if strings.Contains(host, "login.microsoftonline.com") || strings.Contains(host, "sts.windows.net") {
|
||||
r.cache[issuerURL] = p
|
||||
return p
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to the generic provider if no specific provider is detected
|
||||
for _, p := range r.providers {
|
||||
if p.GetType() == ProviderTypeGeneric {
|
||||
r.cache[issuerURL] = p
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ConfigValidator provides common configuration validation utilities for providers.
|
||||
type ConfigValidator struct{}
|
||||
|
||||
// NewConfigValidator creates a new configuration validator.
|
||||
func NewConfigValidator() *ConfigValidator {
|
||||
return &ConfigValidator{}
|
||||
}
|
||||
|
||||
// ValidateIssuerURL validates that an issuer URL is properly formatted and accessible.
|
||||
func (v *ConfigValidator) ValidateIssuerURL(issuerURL string) error {
|
||||
if issuerURL == "" {
|
||||
return fmt.Errorf("issuer URL cannot be empty")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(issuerURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid issuer URL format: %w", err)
|
||||
}
|
||||
|
||||
if parsedURL.Scheme == "" {
|
||||
return fmt.Errorf("issuer URL must include scheme (http/https)")
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("issuer URL scheme must be http or https")
|
||||
}
|
||||
|
||||
if parsedURL.Host == "" {
|
||||
return fmt.Errorf("issuer URL must include host")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateClientID validates that a client ID is properly formatted.
|
||||
func (v *ConfigValidator) ValidateClientID(clientID string) error {
|
||||
if clientID == "" {
|
||||
return fmt.Errorf("client ID cannot be empty")
|
||||
}
|
||||
|
||||
if len(clientID) < 3 {
|
||||
return fmt.Errorf("client ID appears to be too short")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateScopes validates that the provided scopes are reasonable.
|
||||
func (v *ConfigValidator) ValidateScopes(scopes []string) error {
|
||||
if len(scopes) == 0 {
|
||||
return fmt.Errorf("at least one scope must be provided")
|
||||
}
|
||||
|
||||
// Check for required OIDC scope
|
||||
hasOpenIDScope := false
|
||||
for _, scope := range scopes {
|
||||
if strings.TrimSpace(scope) == "openid" {
|
||||
hasOpenIDScope = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasOpenIDScope {
|
||||
return fmt.Errorf("'openid' scope is required for OIDC authentication")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRedirectURL validates that a redirect URL is properly formatted.
|
||||
func (v *ConfigValidator) ValidateRedirectURL(redirectURL string) error {
|
||||
if redirectURL == "" {
|
||||
return fmt.Errorf("redirect URL cannot be empty")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid redirect URL format: %w", err)
|
||||
}
|
||||
|
||||
if parsedURL.Scheme == "" {
|
||||
return fmt.Errorf("redirect URL must include scheme (http/https)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateProviderSpecificConfig performs provider-specific validation.
|
||||
func (v *ConfigValidator) ValidateProviderSpecificConfig(provider OIDCProvider, config map[string]interface{}) error {
|
||||
switch provider.GetType() {
|
||||
case ProviderTypeGoogle:
|
||||
return v.validateGoogleConfig(config)
|
||||
case ProviderTypeAzure:
|
||||
return v.validateAzureConfig(config)
|
||||
case ProviderTypeGeneric:
|
||||
return v.validateGenericConfig(config)
|
||||
default:
|
||||
return fmt.Errorf("unknown provider type: %d", provider.GetType())
|
||||
}
|
||||
}
|
||||
|
||||
// validateGoogleConfig validates Google-specific configuration.
|
||||
func (v *ConfigValidator) validateGoogleConfig(config map[string]interface{}) error {
|
||||
// Google-specific validation logic
|
||||
if issuerURL, ok := config["issuer_url"].(string); ok {
|
||||
if !strings.Contains(issuerURL, "accounts.google.com") {
|
||||
return fmt.Errorf("google provider requires issuer URL to contain accounts.google.com")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateAzureConfig validates Azure-specific configuration.
|
||||
func (v *ConfigValidator) validateAzureConfig(config map[string]interface{}) error {
|
||||
// Azure-specific validation logic
|
||||
if issuerURL, ok := config["issuer_url"].(string); ok {
|
||||
if !strings.Contains(issuerURL, "login.microsoftonline.com") && !strings.Contains(issuerURL, "sts.windows.net") {
|
||||
return fmt.Errorf("azure provider requires issuer URL to contain login.microsoftonline.com or sts.windows.net")
|
||||
}
|
||||
}
|
||||
|
||||
// Check for tenant ID in the URL
|
||||
if issuerURL, ok := config["issuer_url"].(string); ok {
|
||||
parsedURL, err := url.Parse(issuerURL)
|
||||
if err == nil {
|
||||
pathParts := strings.Split(parsedURL.Path, "/")
|
||||
hasTenantID := false
|
||||
for _, part := range pathParts {
|
||||
// Simple check for GUID-like structure (tenant ID)
|
||||
if len(part) == 36 && strings.Count(part, "-") == 4 {
|
||||
hasTenantID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasTenantID {
|
||||
return fmt.Errorf("azure issuer URL should include tenant ID")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateGenericConfig validates generic OIDC provider configuration.
|
||||
func (v *ConfigValidator) validateGenericConfig(config map[string]interface{}) error {
|
||||
// Generic provider validation - basic checks only
|
||||
return nil
|
||||
}
|
||||
@@ -62,6 +62,9 @@ type JWKCacheInterface interface {
|
||||
// Returns:
|
||||
// - A pointer to the JWKSet containing the keys.
|
||||
// - An error if fetching fails or the response cannot be decoded.
|
||||
|
||||
// NewJWKCache creates a new JWK cache with default configuration.
|
||||
// It initializes a cache with a 1-hour lifetime and maximum size of 100 entries.
|
||||
func NewJWKCache() *JWKCache {
|
||||
cache := &JWKCache{
|
||||
CacheLifetime: 1 * time.Hour,
|
||||
|
||||
@@ -65,6 +65,7 @@ func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
|
||||
|
||||
replayCacheMu.RLock()
|
||||
if replayCache != nil {
|
||||
replayCache.Cleanup()
|
||||
}
|
||||
replayCacheMu.RUnlock()
|
||||
|
||||
@@ -87,8 +88,8 @@ var ClockSkewTolerance = ClockSkewToleranceFuture
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
type JWT struct {
|
||||
Header map[string]any
|
||||
Claims map[string]any
|
||||
Header map[string]interface{}
|
||||
Claims map[string]interface{}
|
||||
Token string
|
||||
Signature []byte
|
||||
}
|
||||
@@ -111,14 +112,29 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// ENHANCED: Use memory pool for JWT parsing buffers
|
||||
pools := GetGlobalMemoryPools()
|
||||
jwtBuf := pools.GetJWTParsingBuffer()
|
||||
defer pools.PutJWTParsingBuffer(jwtBuf)
|
||||
|
||||
jwt := &JWT{
|
||||
Token: tokenString,
|
||||
}
|
||||
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
// Decode header using pooled buffer
|
||||
headerLen := base64.RawURLEncoding.DecodedLen(len(parts[0]))
|
||||
if headerLen > cap(jwtBuf.HeaderBuf) {
|
||||
jwtBuf.HeaderBuf = make([]byte, headerLen)
|
||||
} else {
|
||||
jwtBuf.HeaderBuf = jwtBuf.HeaderBuf[:headerLen]
|
||||
}
|
||||
|
||||
n, err := base64.RawURLEncoding.Decode(jwtBuf.HeaderBuf, []byte(parts[0]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
||||
}
|
||||
headerBytes := jwtBuf.HeaderBuf[:n]
|
||||
|
||||
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
||||
}
|
||||
@@ -127,10 +143,19 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: header is nil after unmarshaling")
|
||||
}
|
||||
|
||||
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
// Decode claims using pooled buffer
|
||||
claimsLen := base64.RawURLEncoding.DecodedLen(len(parts[1]))
|
||||
if claimsLen > cap(jwtBuf.PayloadBuf) {
|
||||
jwtBuf.PayloadBuf = make([]byte, claimsLen)
|
||||
} else {
|
||||
jwtBuf.PayloadBuf = jwtBuf.PayloadBuf[:claimsLen]
|
||||
}
|
||||
|
||||
n, err = base64.RawURLEncoding.Decode(jwtBuf.PayloadBuf, []byte(parts[1]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
||||
}
|
||||
claimsBytes := jwtBuf.PayloadBuf[:n]
|
||||
|
||||
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
||||
@@ -140,11 +165,22 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: claims is nil after unmarshaling")
|
||||
}
|
||||
|
||||
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
// Decode signature using pooled buffer
|
||||
sigLen := base64.RawURLEncoding.DecodedLen(len(parts[2]))
|
||||
if sigLen > cap(jwtBuf.SignatureBuf) {
|
||||
jwtBuf.SignatureBuf = make([]byte, sigLen)
|
||||
} else {
|
||||
jwtBuf.SignatureBuf = jwtBuf.SignatureBuf[:sigLen]
|
||||
}
|
||||
|
||||
n, err = base64.RawURLEncoding.Decode(jwtBuf.SignatureBuf, []byte(parts[2]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
||||
}
|
||||
jwt.Signature = signatureBytes
|
||||
|
||||
// Copy signature to JWT struct (create new slice to avoid pool retention)
|
||||
jwt.Signature = make([]byte, n)
|
||||
copy(jwt.Signature, jwtBuf.SignatureBuf[:n])
|
||||
|
||||
return jwt, nil
|
||||
}
|
||||
@@ -276,13 +312,13 @@ func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error
|
||||
// Returns:
|
||||
// - nil if the expected audience is found.
|
||||
// - An error if the claim type is invalid or the expected audience is not present.
|
||||
func verifyAudience(tokenAudience any, expectedAudience string) error {
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
if aud != expectedAudience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
case []any:
|
||||
case []interface{}:
|
||||
found := false
|
||||
for _, v := range aud {
|
||||
if str, ok := v.(string); ok && str == expectedAudience {
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// It supports multiple OIDC providers including Google, Azure AD, and generic OIDC providers
|
||||
// with features like token refresh, session management, and provider-specific optimizations.
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
@@ -6,15 +9,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
@@ -36,28 +36,32 @@ func createDefaultHTTPClient() *http.Client {
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second, // Reduced timeout
|
||||
KeepAlive: 15 * time.Second, // Reduced keepalive
|
||||
Timeout: 10 * time.Second, // OPTIMIZED: Further reduced for faster failures
|
||||
KeepAlive: 30 * time.Second, // OPTIMIZED: Increased for better connection reuse
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 30, // Reduced from 100
|
||||
MaxIdleConnsPerHost: 10, // Reduced from 100
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
|
||||
TLSHandshakeTimeout: 3 * time.Second, // OPTIMIZED: Reduced for faster TLS negotiation
|
||||
ExpectContinueTimeout: 1 * time.Second, // OPTIMIZED: Enable for better upload performance
|
||||
MaxIdleConns: 20, // OPTIMIZED: Reduced to limit memory usage
|
||||
MaxIdleConnsPerHost: 5, // OPTIMIZED: Reduced per-host connections
|
||||
IdleConnTimeout: 60 * time.Second, // OPTIMIZED: Increased for better reuse
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 50, // Limit max connections
|
||||
MaxConnsPerHost: 20, // OPTIMIZED: Reduced to limit memory usage
|
||||
ResponseHeaderTimeout: 5 * time.Second, // OPTIMIZED: Added response header timeout
|
||||
DisableCompression: false, // Enable compression for bandwidth efficiency
|
||||
WriteBufferSize: 4096, // OPTIMIZED: Set optimal buffer size
|
||||
ReadBufferSize: 4096, // OPTIMIZED: Set optimal buffer size
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Timeout: time.Second * 15, // Reduced timeout
|
||||
Timeout: time.Second * 10, // OPTIMIZED: Reduced timeout for faster failures
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
// OPTIMIZED: Reduced redirect limit to prevent abuse
|
||||
if len(via) >= 10 {
|
||||
return fmt.Errorf("stopped after 10 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -209,7 +213,7 @@ type TraefikOidc struct {
|
||||
sessionManager *SessionManager
|
||||
tokenCleanupStopChan chan struct{}
|
||||
excludedURLs map[string]struct{}
|
||||
extractClaimsFunc func(tokenString string) (map[string]any, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
|
||||
metadataCache *MetadataCache
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
@@ -221,6 +225,7 @@ type TraefikOidc struct {
|
||||
logger *Logger
|
||||
metadataRefreshStopChan chan struct{}
|
||||
cancelFunc context.CancelFunc
|
||||
errorRecoveryManager *ErrorRecoveryManager
|
||||
clientSecret string
|
||||
clientID string
|
||||
name string
|
||||
@@ -335,12 +340,8 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr)
|
||||
}
|
||||
|
||||
// DIAGNOSTIC: Determine token type for debugging
|
||||
// Determine token type for debugging
|
||||
tokenType := "UNKNOWN"
|
||||
tokenPrefix := token
|
||||
if len(token) > 20 {
|
||||
tokenPrefix = token[:20] + "..."
|
||||
}
|
||||
if aud, ok := parsedJWT.Claims["aud"]; ok {
|
||||
if audStr, ok := aud.(string); ok && audStr == t.clientID {
|
||||
tokenType = "ID_TOKEN"
|
||||
@@ -352,9 +353,7 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.logger.Debugf("DIAGNOSTIC: Verifying %s token (prefix: %s)", tokenType, tokenPrefix)
|
||||
}
|
||||
// Removed verbose diagnostic logging on every token verification
|
||||
|
||||
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
|
||||
if !strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
|
||||
@@ -366,9 +365,7 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
|
||||
// Check cache for efficiency AFTER blacklist checks
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.logger.Debugf("DIAGNOSTIC: %s token found in cache with valid claims; skipping signature verification", tokenType)
|
||||
}
|
||||
// Token found in cache, skip signature verification
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -377,17 +374,16 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.logger.Debugf("DIAGNOSTIC: %s token NOT in cache, performing full verification", tokenType)
|
||||
}
|
||||
// Token not in cache, perform full verification
|
||||
|
||||
// Use the already parsed JWT to avoid parsing twice
|
||||
jwt := parsedJWT
|
||||
|
||||
// Verify JWT signature and standard claims
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.logger.Errorf("DIAGNOSTIC: %s token verification failed: %v", tokenType, err)
|
||||
// Only log actual security-relevant verification failures
|
||||
if !strings.Contains(err.Error(), "token has expired") {
|
||||
t.logger.Errorf("%s token verification failed: %v", tokenType, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -441,7 +437,7 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
// Parameters:
|
||||
// - token: The raw token string (used as the cache key).
|
||||
// - claims: The map of claims extracted from the verified token.
|
||||
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]any) {
|
||||
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) {
|
||||
expClaim, ok := claims["exp"].(float64)
|
||||
if !ok {
|
||||
t.logger.Errorf("Failed to cache token: invalid 'exp' claim type")
|
||||
@@ -616,6 +612,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
|
||||
// Initialize logger
|
||||
logger := NewLogger(config.LogLevel)
|
||||
// Log the scopes received from Traefik to help diagnose duplication issues
|
||||
// Ensure key meets minimum length requirement
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
if runtime.Compiler == "yaegi" {
|
||||
@@ -662,10 +659,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
enablePKCE: config.EnablePKCE,
|
||||
overrideScopes: config.OverrideScopes,
|
||||
scopes: func() []string {
|
||||
// Deduplicate user-provided scopes from the configuration.
|
||||
userProvidedScopes := deduplicateScopes(config.Scopes)
|
||||
|
||||
if config.OverrideScopes {
|
||||
return append([]string(nil), config.Scopes...)
|
||||
// When overriding, only the explicitly user-provided scopes are used.
|
||||
// Default scopes like "openid", "profile", "email" are NOT added.
|
||||
return userProvidedScopes
|
||||
}
|
||||
return mergeScopes([]string{"openid", "profile", "email"}, config.Scopes)
|
||||
|
||||
// When not overriding (overrideScopes is false), merge user-provided scopes
|
||||
// with the system's default scopes.
|
||||
defaultSystemScopes := []string{"openid", "profile", "email"}
|
||||
return deduplicateScopes(mergeScopes(defaultSystemScopes, userProvidedScopes))
|
||||
}(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: cacheManager.GetSharedTokenCache(),
|
||||
@@ -691,6 +697,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}
|
||||
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
t.errorRecoveryManager = NewErrorRecoveryManager(t.logger)
|
||||
t.extractClaimsFunc = extractClaims
|
||||
// t.exchangeCodeForTokenFunc = t.exchangeCodeForToken // Removed, using interface now
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
@@ -698,7 +705,9 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}
|
||||
|
||||
// Add default excluded URLs
|
||||
maps.Copy(t.excludedURLs, defaultExcludedURLs)
|
||||
for k, v := range defaultExcludedURLs {
|
||||
t.excludedURLs[k] = v
|
||||
}
|
||||
|
||||
t.tokenVerifier = t
|
||||
t.jwtVerifier = t
|
||||
@@ -723,7 +732,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}
|
||||
|
||||
startReplayCacheCleanup(pluginCtx, logger)
|
||||
|
||||
logger.Debugf("TraefikOidc.New: Final t.scopes initialized to: %v", t.scopes)
|
||||
go t.initializeMetadata(config.ProviderURL)
|
||||
|
||||
return t, nil
|
||||
@@ -741,8 +750,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
t.logger.Debug("Starting provider metadata discovery")
|
||||
|
||||
// Get metadata from cache or fetch it
|
||||
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
// Get metadata from cache or fetch it with error recovery if available
|
||||
var metadata *ProviderMetadata
|
||||
var err error
|
||||
if t.errorRecoveryManager != nil {
|
||||
metadata, err = t.metadataCache.GetMetadataWithRecovery(providerURL, t.httpClient, t.logger, t.errorRecoveryManager)
|
||||
} else {
|
||||
// Fallback for test scenarios without error recovery manager
|
||||
metadata, err = t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
}
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to get provider metadata: %v", err)
|
||||
// Consider retrying or handling this more gracefully
|
||||
@@ -799,7 +815,14 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.logger.Debug("Refreshing OIDC metadata")
|
||||
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
var metadata *ProviderMetadata
|
||||
var err error
|
||||
if t.errorRecoveryManager != nil {
|
||||
metadata, err = t.metadataCache.GetMetadataWithRecovery(providerURL, t.httpClient, t.logger, t.errorRecoveryManager)
|
||||
} else {
|
||||
// Fallback for test scenarios without error recovery manager
|
||||
metadata, err = t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
}
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to refresh metadata: %v", err)
|
||||
continue
|
||||
@@ -839,45 +862,45 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
|
||||
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
// Use shorter delays for tests to prevent timeouts
|
||||
maxRetries := 4 // Increased to 4 to allow for recovery after 3 failures
|
||||
baseDelay := 10 * time.Millisecond
|
||||
maxDelay := 100 * time.Millisecond
|
||||
totalTimeout := 5 * time.Second
|
||||
|
||||
start := time.Now()
|
||||
|
||||
var lastErr error
|
||||
for attempt := range maxRetries {
|
||||
if time.Since(start) > totalTimeout {
|
||||
l.Errorf("Timeout exceeded while fetching provider metadata")
|
||||
return nil, fmt.Errorf("timeout exceeded while fetching provider metadata: %w", lastErr)
|
||||
}
|
||||
|
||||
metadata, err := fetchMetadata(wellKnownURL, httpClient)
|
||||
if err == nil {
|
||||
l.Debug("Provider metadata fetched successfully")
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// Don't sleep after the last attempt
|
||||
if attempt < maxRetries-1 {
|
||||
// Exponential backoff
|
||||
delay := time.Duration(math.Pow(2, float64(attempt))) * baseDelay
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
l.Debugf("Failed to fetch provider metadata (attempt %d/%d), retrying in %s. Error: %v", attempt+1, maxRetries, delay, err)
|
||||
time.Sleep(delay)
|
||||
} else {
|
||||
l.Debugf("Failed to fetch provider metadata (attempt %d/%d). Error: %v", attempt+1, maxRetries, err)
|
||||
}
|
||||
// Create retry executor with configuration optimized for test and production environments
|
||||
retryConfig := RetryConfig{
|
||||
MaxAttempts: 4,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true,
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"no route to host",
|
||||
"connection reset",
|
||||
"status code 500",
|
||||
"status code 502",
|
||||
"status code 503",
|
||||
"status code 504",
|
||||
},
|
||||
}
|
||||
|
||||
l.Errorf("Max retries exceeded while fetching provider metadata")
|
||||
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
|
||||
retryExecutor := NewRetryExecutor(retryConfig, l)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var metadata *ProviderMetadata
|
||||
err := retryExecutor.ExecuteWithContext(ctx, func() error {
|
||||
var fetchErr error
|
||||
metadata, fetchErr = fetchMetadata(wellKnownURL, httpClient)
|
||||
return fetchErr
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
l.Errorf("Failed to fetch provider metadata after retries: %v", err)
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
|
||||
}
|
||||
|
||||
l.Debug("Provider metadata fetched successfully")
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// fetchMetadata performs a single attempt to fetch and decode the OIDC provider metadata
|
||||
@@ -1048,7 +1071,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if idToken != "" {
|
||||
jwt, err := parseJWT(idToken)
|
||||
if err == nil {
|
||||
// jwt.Claims is already map[string]any, no type assertion needed
|
||||
// jwt.Claims is already map[string]interface{}, no type assertion needed
|
||||
claims := jwt.Claims
|
||||
// STABILITY FIX: Safe type assertion with proper error handling
|
||||
if expClaim, ok := claims["exp"].(float64); ok {
|
||||
@@ -1092,7 +1115,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
// Refresh failed
|
||||
t.logger.Infof("Token refresh failed (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent)
|
||||
t.logger.Debug("Token refresh failed, requiring re-authentication")
|
||||
// Handle refresh failure (401 for API, re-auth for browser)
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "application/json") {
|
||||
@@ -1124,7 +1147,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Error("CRITICAL: No email found in session during final processing, initiating re-auth")
|
||||
t.logger.Info("No email found in session during final processing, initiating re-auth")
|
||||
// This case should ideally not happen if checks are done correctly before calling this,
|
||||
// but as a safeguard, initiate re-authentication.
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
@@ -1205,9 +1228,9 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
} else {
|
||||
// Create template data context with available tokens and claims
|
||||
// Fields must be exported (uppercase) to be accessible in templates
|
||||
templateData := map[string]any{
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": session.GetAccessToken(),
|
||||
"IdToken": session.GetIDToken(),
|
||||
"IDToken": session.GetIDToken(),
|
||||
"RefreshToken": session.GetRefreshToken(),
|
||||
"Claims": claims,
|
||||
}
|
||||
@@ -1267,6 +1290,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
|
||||
// Process the request
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
|
||||
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
@@ -1616,6 +1640,8 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
// Build and redirect to authentication URL
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
|
||||
t.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
|
||||
|
||||
// ENHANCED: Record authorization request metrics
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -1680,25 +1706,47 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
|
||||
hasOfflineAccess := false
|
||||
|
||||
if slices.Contains(scopes, "offline_access") {
|
||||
hasOfflineAccess = true
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
t.logger.Debug("Azure AD provider detected, added offline_access scope for refresh tokens")
|
||||
// For Azure AD, add offline_access scope if not overriding or if overriding with no user scopes
|
||||
if !t.overrideScopes || (t.overrideScopes && len(t.scopes) == 0) {
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
t.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", t.overrideScopes, len(t.scopes))
|
||||
}
|
||||
} else {
|
||||
t.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(t.scopes))
|
||||
}
|
||||
} else {
|
||||
// For other providers, use the standard offline_access scope
|
||||
hasOfflineAccess := slices.Contains(scopes, "offline_access")
|
||||
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
// Only add offline_access if overrideScopes is false,
|
||||
// or if overrideScopes is true AND no scopes were provided by the user (edge case, effectively defaults)
|
||||
if !t.overrideScopes || (t.overrideScopes && len(t.scopes) == 0) {
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
t.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", t.overrideScopes, len(t.scopes))
|
||||
}
|
||||
} else {
|
||||
t.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(t.scopes))
|
||||
}
|
||||
}
|
||||
|
||||
if len(scopes) > 0 {
|
||||
params.Set("scope", strings.Join(scopes, " "))
|
||||
finalScopeString := strings.Join(scopes, " ")
|
||||
params.Set("scope", finalScopeString)
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
|
||||
}
|
||||
|
||||
// Use buildURLWithParams which handles potential relative authURL from metadata
|
||||
@@ -1850,8 +1898,15 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
|
||||
t.goroutineWG.Add(1) // Track this goroutine
|
||||
go func() {
|
||||
defer t.goroutineWG.Done() // Signal completion when goroutine exits
|
||||
defer ticker.Stop() // Ensure ticker is always stopped
|
||||
defer func() {
|
||||
t.goroutineWG.Done() // Signal completion when goroutine exits
|
||||
ticker.Stop() // Ensure ticker is always stopped
|
||||
|
||||
// ENHANCED: Recover from panics and log them
|
||||
if r := recover(); r != nil {
|
||||
t.logger.Errorf("Token cleanup goroutine panic recovered: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -1871,10 +1926,16 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
// Based on New(), t.jwkCache = &JWKCache{}, which has a Cleanup method.
|
||||
t.jwkCache.Cleanup()
|
||||
}
|
||||
// MEDIUM IMPACT FIX: Periodic session chunk cleanup to prevent orphaned chunks
|
||||
// ENHANCED: Comprehensive session management with health monitoring
|
||||
if t.sessionManager != nil {
|
||||
t.sessionManager.PeriodicChunkCleanup()
|
||||
|
||||
// Periodic session health monitoring
|
||||
t.logger.Debug("Running session health monitoring")
|
||||
// Note: Session health monitoring is performed on individual sessions
|
||||
// during GetSession() and Save() operations to avoid overhead here
|
||||
}
|
||||
|
||||
case <-t.tokenCleanupStopChan:
|
||||
t.logger.Debug("Token cleanup goroutine stopped.")
|
||||
return
|
||||
@@ -1951,8 +2012,19 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json") // Prefer JSON response if available
|
||||
|
||||
// Send the request
|
||||
resp, err := t.httpClient.Do(req)
|
||||
// Send the request with circuit breaker protection if available
|
||||
var resp *http.Response
|
||||
if t.errorRecoveryManager != nil {
|
||||
serviceName := fmt.Sprintf("token-revocation-%s", t.issuerURL)
|
||||
err = t.errorRecoveryManager.ExecuteWithRecovery(context.Background(), serviceName, func() error {
|
||||
var reqErr error
|
||||
resp, reqErr = t.httpClient.Do(req)
|
||||
return reqErr
|
||||
})
|
||||
} else {
|
||||
// Fallback for test scenarios without error recovery manager
|
||||
resp, err = t.httpClient.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send token revocation request: %w", err)
|
||||
}
|
||||
@@ -1998,7 +2070,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
|
||||
initialRefreshToken := session.GetRefreshToken()
|
||||
if initialRefreshToken == "" {
|
||||
t.logger.Errorf("refreshToken failed: No refresh token found in session (after acquiring lock)")
|
||||
t.logger.Debug("No refresh token found in session")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2019,13 +2091,10 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
// Attempt to refresh the token
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
|
||||
if err != nil {
|
||||
// Log detailed error information
|
||||
t.logger.Errorf("refreshToken failed: Error from token refresh operation: %v", err)
|
||||
|
||||
// Check for specific error patterns
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
|
||||
t.logger.Errorf("Refresh token appears to be expired or revoked: %v", err)
|
||||
t.logger.Debug("Refresh token expired or revoked: %v", err)
|
||||
// Don't keep trying with an invalid refresh token
|
||||
session.SetRefreshToken("")
|
||||
if err = session.Save(req, rw); err != nil {
|
||||
@@ -2035,6 +2104,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
t.logger.Errorf("Client credentials rejected: %v - check client_id and client_secret configuration", err)
|
||||
} else if t.isGoogleProvider() && strings.Contains(errMsg, "invalid_request") {
|
||||
t.logger.Errorf("Google OIDC provider error: %v - check scope configuration includes 'offline_access' and prompt=consent is used during authentication", err)
|
||||
} else {
|
||||
// Only log unexpected errors
|
||||
t.logger.Errorf("Token refresh failed: %v", err)
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -2042,17 +2114,13 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
|
||||
// Handle potentially missing tokens in the response
|
||||
if newToken.IDToken == "" {
|
||||
t.logger.Errorf("refreshToken failed: Provider did not return a new ID token")
|
||||
t.logger.Info("Provider did not return a new ID token during refresh")
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify the new ID token
|
||||
if err = t.verifyToken(newToken.IDToken); err != nil {
|
||||
truncatedToken := newToken.IDToken
|
||||
if len(newToken.IDToken) > 10 {
|
||||
truncatedToken = newToken.IDToken[:10]
|
||||
}
|
||||
t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedToken, err)
|
||||
t.logger.Debug("Failed to verify newly obtained ID token: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2214,7 +2282,7 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
|
||||
// Extract groups with type checking
|
||||
if groupsClaim, exists := claims["groups"]; exists {
|
||||
groupsSlice, ok := groupsClaim.([]any)
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
// Strictly expect an array
|
||||
return nil, nil, fmt.Errorf("groups claim is not an array")
|
||||
@@ -2232,7 +2300,7 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
|
||||
// Extract roles with type checking
|
||||
if rolesClaim, exists := claims["roles"]; exists {
|
||||
rolesSlice, ok := rolesClaim.([]any)
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
// Strictly expect an array
|
||||
return nil, nil, fmt.Errorf("roles claim is not an array")
|
||||
@@ -2316,7 +2384,7 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(code)
|
||||
// Use a simple error structure - ensure this matches the expected response format in tests
|
||||
json.NewEncoder(rw).Encode(map[string]any{
|
||||
json.NewEncoder(rw).Encode(map[string]interface{}{
|
||||
"error": http.StatusText(code), // Use standard text for the code
|
||||
"error_description": message, // Provide specific detail here
|
||||
"status_code": code,
|
||||
@@ -2518,7 +2586,7 @@ func (t *TraefikOidc) validateTokenExpiry(session *SessionData, token string) (b
|
||||
// Get cached claims from verified token
|
||||
cachedClaims, found := t.tokenCache.Get(token)
|
||||
if !found {
|
||||
t.logger.Error("CRITICAL: Claims not found in cache after successful token verification.")
|
||||
t.logger.Debug("Claims not found in cache after successful token verification")
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Claims missing post-verification, attempting refresh to recover.")
|
||||
return false, true, false
|
||||
|
||||
+259
-77
@@ -74,7 +74,7 @@ func (ts *TestSuite) Setup() {
|
||||
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
||||
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
||||
|
||||
ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -205,8 +205,8 @@ func (m *MockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) er
|
||||
}
|
||||
|
||||
// Helper function to create a JWT token
|
||||
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]any) (string, error) {
|
||||
header := map[string]any{
|
||||
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
|
||||
header := map[string]interface{}{
|
||||
"alg": alg,
|
||||
"kid": kid,
|
||||
"typ": "JWT",
|
||||
@@ -333,7 +333,7 @@ func TestVerifyToken(t *testing.T) {
|
||||
|
||||
if tc.cacheToken {
|
||||
// Use more realistic claims for cached token
|
||||
ts.tOidc.tokenCache.Set(tc.token, map[string]any{
|
||||
ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"sub": "test-subject",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -376,7 +376,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
exp := time.Now().Add(-1 * time.Hour).Unix() // Expired 1 hour ago
|
||||
iat := time.Now().Add(-2 * time.Hour).Unix()
|
||||
nbf := time.Now().Add(-2 * time.Hour).Unix()
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -395,7 +395,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
exp := time.Now().Add(1 * time.Hour).Unix() // Valid for 1 hour
|
||||
iat := time.Now().Unix()
|
||||
nbf := time.Now().Unix()
|
||||
newToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
newToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -410,7 +410,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
sessionValues map[any]any
|
||||
sessionValues map[interface{}]interface{}
|
||||
setupSession func(*SessionData)
|
||||
mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error)
|
||||
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager)
|
||||
@@ -481,7 +481,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
// Generate a fresh valid token for this test case to avoid replay issues
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
@@ -504,7 +504,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
||||
session.SetEmail("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
@@ -561,7 +561,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
@@ -579,7 +579,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
@@ -607,7 +607,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
@@ -635,7 +635,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
exp := time.Now().Add(30 * time.Second).Unix()
|
||||
iat := time.Now().Add(-1 * time.Minute).Unix()
|
||||
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
||||
nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
@@ -666,7 +666,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
exp := time.Now().Add(10 * time.Minute).Unix()
|
||||
iat := time.Now().Add(-1 * time.Minute).Unix()
|
||||
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
||||
validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
@@ -693,7 +693,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
@@ -715,7 +715,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
@@ -967,11 +967,11 @@ func TestJWTVerify_MissingClaims(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
jwt := &JWT{
|
||||
Header: map[string]any{
|
||||
Header: map[string]interface{}{
|
||||
"alg": "RS256",
|
||||
"kid": "test-key-id",
|
||||
},
|
||||
Claims: map[string]any{
|
||||
Claims: map[string]interface{}{
|
||||
// Missing 'iss', 'aud', 'exp', 'iat', 'sub'
|
||||
},
|
||||
}
|
||||
@@ -990,7 +990,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]any, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
sessionSetupFunc func(*SessionData)
|
||||
name string
|
||||
queryParams string
|
||||
@@ -1005,8 +1005,8 @@ func TestHandleCallback(t *testing.T) {
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
},
|
||||
extractClaimsFunc: func(tokenString string) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
}, nil
|
||||
@@ -1060,7 +1060,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -1102,8 +1102,8 @@ func TestHandleCallback(t *testing.T) {
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
},
|
||||
extractClaimsFunc: func(tokenString string) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
}, nil
|
||||
@@ -1123,8 +1123,8 @@ func TestHandleCallback(t *testing.T) {
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
},
|
||||
extractClaimsFunc: func(tokenString string) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"nonce": "invalid-nonce",
|
||||
}, nil
|
||||
@@ -1144,8 +1144,8 @@ func TestHandleCallback(t *testing.T) {
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
},
|
||||
extractClaimsFunc: func(tokenString string) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
// Missing nonce
|
||||
}, nil
|
||||
@@ -1344,7 +1344,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]any, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
sessionSetupFunc func(session *sessions.Session)
|
||||
name string
|
||||
queryParams string
|
||||
@@ -1368,9 +1368,9 @@ func TestOIDCHandler(t *testing.T) {
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
},
|
||||
extractClaimsFunc: func(tokenString string) (map[string]any, error) {
|
||||
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
// Simulate extraction of claims with invalid nonce
|
||||
return map[string]any{
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"nonce": "invalid-nonce",
|
||||
}, nil
|
||||
@@ -1392,9 +1392,9 @@ func TestOIDCHandler(t *testing.T) {
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
},
|
||||
extractClaimsFunc: func(tokenString string) (map[string]any, error) {
|
||||
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
// Simulate extraction of claims without nonce
|
||||
return map[string]any{
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
}, nil
|
||||
},
|
||||
@@ -1415,9 +1415,9 @@ func TestOIDCHandler(t *testing.T) {
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
},
|
||||
extractClaimsFunc: func(tokenString string) (map[string]any, error) {
|
||||
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
// Simulate extraction of claims
|
||||
return map[string]any{
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
}, nil
|
||||
@@ -1439,9 +1439,9 @@ func TestOIDCHandler(t *testing.T) {
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
},
|
||||
extractClaimsFunc: func(tokenString string) (map[string]any, error) {
|
||||
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
// Simulate extraction of claims with mismatched nonce
|
||||
return map[string]any{
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"nonce": "invalid-nonce",
|
||||
}, nil
|
||||
@@ -1471,7 +1471,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
|
||||
if tc.cacheToken {
|
||||
// Cache the token with dummy claims
|
||||
ts.tOidc.tokenCache.Set(ts.token, map[string]any{
|
||||
ts.tOidc.tokenCache.Set(ts.token, map[string]interface{}{
|
||||
"empty": "claim",
|
||||
}, 60)
|
||||
}
|
||||
@@ -1732,7 +1732,7 @@ func TestRevokeToken(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
token := "test.token.with.claims"
|
||||
claims := map[string]any{
|
||||
claims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
}
|
||||
|
||||
@@ -1833,7 +1833,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
@@ -1848,7 +1848,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
@@ -1933,16 +1933,16 @@ func TestExtractGroupsAndRoles(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims map[string]any
|
||||
claims map[string]interface{}
|
||||
expectGroups []string
|
||||
expectRoles []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid groups and roles",
|
||||
claims: map[string]any{
|
||||
"groups": []any{"group1", "group2"},
|
||||
"roles": []any{"role1", "role2"},
|
||||
claims: map[string]interface{}{
|
||||
"groups": []interface{}{"group1", "group2"},
|
||||
"roles": []interface{}{"role1", "role2"},
|
||||
},
|
||||
expectGroups: []string{"group1", "group2"},
|
||||
expectRoles: []string{"role1", "role2"},
|
||||
@@ -1950,9 +1950,9 @@ func TestExtractGroupsAndRoles(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Empty groups and roles",
|
||||
claims: map[string]any{
|
||||
"groups": []any{},
|
||||
"roles": []any{},
|
||||
claims: map[string]interface{}{
|
||||
"groups": []interface{}{},
|
||||
"roles": []interface{}{},
|
||||
},
|
||||
expectGroups: []string{},
|
||||
expectRoles: []string{},
|
||||
@@ -1960,9 +1960,9 @@ func TestExtractGroupsAndRoles(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Invalid groups format",
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"groups": "not-an-array",
|
||||
"roles": []any{"role1"},
|
||||
"roles": []interface{}{"role1"},
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
@@ -2105,7 +2105,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
claims map[string]any
|
||||
claims map[string]interface{}
|
||||
setupSession func(*SessionData)
|
||||
expectedHeaders map[string]string
|
||||
name string
|
||||
@@ -2116,15 +2116,15 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"roles": []any{"admin", "user"},
|
||||
"groups": []any{"group1"},
|
||||
"roles": []interface{}{"admin", "user"},
|
||||
"groups": []interface{}{"group1"},
|
||||
"jti": generateRandomString(16),
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
@@ -2142,15 +2142,15 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"roles": []any{"user"},
|
||||
"groups": []any{"allowed-group"},
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"allowed-group"},
|
||||
"jti": generateRandomString(16),
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
@@ -2169,15 +2169,15 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
"admin": {},
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"roles": []any{"user"},
|
||||
"groups": []any{"regular-group"},
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
"jti": generateRandomString(16),
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
@@ -2189,15 +2189,15 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
{
|
||||
name: "No role/group restrictions",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"roles": []any{"user"},
|
||||
"groups": []any{"regular-group"},
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
"jti": generateRandomString(16),
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
@@ -2213,7 +2213,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
{
|
||||
name: "Claims without roles and groups",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -2861,7 +2861,7 @@ func TestJWTVerifyWithSkipReplayCheck(t *testing.T) {
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -2954,7 +2954,7 @@ func TestJWTVerifyBackwardCompatibility(t *testing.T) {
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3007,7 +3007,7 @@ func TestTokenReplayDetectionFalsePositiveFix(t *testing.T) {
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3080,7 +3080,7 @@ func TestAuthenticationFlowReplayDetection(t *testing.T) {
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3155,7 +3155,7 @@ func TestActualReplayAttackDetection(t *testing.T) {
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3241,7 +3241,7 @@ func TestConcurrentTokenValidation(t *testing.T) {
|
||||
jti := generateRandomString(16)
|
||||
jtis = append(jtis, jti)
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3322,7 +3322,7 @@ func TestJTIBlacklistBehavior(t *testing.T) {
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3422,7 +3422,7 @@ func TestSessionBasedTokenRevalidation(t *testing.T) {
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3495,7 +3495,7 @@ func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
|
||||
nbf := now.Unix()
|
||||
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
claims map[string]interface{}
|
||||
name string
|
||||
tokenType string
|
||||
expectError bool
|
||||
@@ -3503,7 +3503,7 @@ func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
|
||||
{
|
||||
name: "ID Token with JTI",
|
||||
tokenType: "id_token",
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3520,7 +3520,7 @@ func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
|
||||
{
|
||||
name: "Access Token with JTI",
|
||||
tokenType: "access_token",
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3536,7 +3536,7 @@ func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
|
||||
{
|
||||
name: "Token without JTI",
|
||||
tokenType: "no_jti",
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
@@ -3998,9 +3998,12 @@ func TestBuildAuthURLWithMergedScopes(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Configure the test instance with specific scopes
|
||||
tOidc := ts.tOidc
|
||||
tOidc.scopes = tc.scopes
|
||||
tOidc.scopes = tc.scopes // These scopes are already deduplicated by New()
|
||||
tOidc.authURL = "https://auth.example.com/oauth/authorize"
|
||||
tOidc.issuerURL = "https://auth.example.com"
|
||||
// Reset overrideScopes for each test case, as it's part of tOidc state
|
||||
// Default to false, specific tests will set it.
|
||||
tOidc.overrideScopes = false
|
||||
|
||||
// Build auth URL
|
||||
result := tOidc.buildAuthURL("https://app.example.com/callback", "test-state", "test-nonce", "")
|
||||
@@ -4019,3 +4022,182 @@ func TestBuildAuthURLWithMergedScopes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildAuthURL_OverrideScopes_And_OfflineAccess tests the offline_access logic in buildAuthURL
|
||||
// considering the overrideScopes flag.
|
||||
func TestBuildAuthURL_OverrideScopes_And_OfflineAccess(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup() // Sets up ts.tOidc
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialScopes []string // Scopes as they would be in tOidc.scopes (after New processing)
|
||||
overrideScopes bool
|
||||
isGoogle bool // To test Google-specific handling
|
||||
isAzure bool // To test Azure-specific handling
|
||||
expectedParams map[string]string
|
||||
expectedScope string // The final scope string expected in the URL
|
||||
}{
|
||||
{
|
||||
name: "Override false, no user scopes, non-Google/Azure",
|
||||
initialScopes: []string{"openid", "profile", "email"}, // Defaults from New() when config.Scopes is empty
|
||||
overrideScopes: false,
|
||||
expectedScope: "openid profile email offline_access",
|
||||
},
|
||||
{
|
||||
name: "Override false, user scopes without offline_access, non-Google/Azure",
|
||||
initialScopes: []string{"openid", "profile", "email", "custom1"}, // Merged and deduplicated by New()
|
||||
overrideScopes: false,
|
||||
expectedScope: "openid profile email custom1 offline_access",
|
||||
},
|
||||
{
|
||||
name: "Override false, user scopes with offline_access, non-Google/Azure",
|
||||
initialScopes: []string{"openid", "profile", "email", "offline_access", "custom1"},
|
||||
overrideScopes: false,
|
||||
expectedScope: "openid profile email offline_access custom1", // Order might vary based on merge, but offline_access present
|
||||
},
|
||||
{
|
||||
name: "Override true, user scopes without offline_access, non-Google/Azure",
|
||||
initialScopes: []string{"custom1", "custom2"}, // Directly from config.Scopes, deduplicated
|
||||
overrideScopes: true,
|
||||
expectedScope: "custom1 custom2", // offline_access NOT added
|
||||
},
|
||||
{
|
||||
name: "Override true, user scopes with offline_access, non-Google/Azure",
|
||||
initialScopes: []string{"custom1", "offline_access", "custom2"},
|
||||
overrideScopes: true,
|
||||
expectedScope: "custom1 offline_access custom2", // User explicitly included it
|
||||
},
|
||||
{
|
||||
name: "Override true, no user scopes (edge case), non-Google/Azure",
|
||||
initialScopes: []string{}, // config.Scopes was empty
|
||||
overrideScopes: true,
|
||||
// In this edge case, buildAuthURL's logic `(t.overrideScopes && len(t.scopes) == 0)`
|
||||
// will lead to offline_access being added, as it behaves like defaults.
|
||||
expectedScope: "offline_access",
|
||||
},
|
||||
// Google Provider Tests (access_type=offline, prompt=consent)
|
||||
{
|
||||
name: "Google, Override false, no user scopes",
|
||||
initialScopes: []string{"openid", "profile", "email"},
|
||||
overrideScopes: false,
|
||||
isGoogle: true,
|
||||
expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"},
|
||||
expectedScope: "openid profile email", // No offline_access scope for Google
|
||||
},
|
||||
{
|
||||
name: "Google, Override true, user scopes",
|
||||
initialScopes: []string{"custom1", "custom2"},
|
||||
overrideScopes: true,
|
||||
isGoogle: true,
|
||||
expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"},
|
||||
expectedScope: "custom1 custom2", // No offline_access scope for Google
|
||||
},
|
||||
// Azure Provider Tests (response_mode=query, offline_access scope added if not present by user)
|
||||
{
|
||||
name: "Azure, Override false, no user scopes",
|
||||
initialScopes: []string{"openid", "profile", "email"},
|
||||
overrideScopes: false,
|
||||
isAzure: true,
|
||||
expectedParams: map[string]string{"response_mode": "query"},
|
||||
expectedScope: "openid profile email offline_access",
|
||||
},
|
||||
{
|
||||
name: "Azure, Override true, user scopes without offline_access",
|
||||
initialScopes: []string{"custom1", "custom2"},
|
||||
overrideScopes: true,
|
||||
isAzure: true,
|
||||
expectedParams: map[string]string{"response_mode": "query"},
|
||||
expectedScope: "custom1 custom2", // offline_access NOT added by default when override is true
|
||||
},
|
||||
{
|
||||
name: "Azure, Override true, user scopes with offline_access",
|
||||
initialScopes: []string{"custom1", "offline_access"},
|
||||
overrideScopes: true,
|
||||
isAzure: true,
|
||||
expectedParams: map[string]string{"response_mode": "query"},
|
||||
expectedScope: "custom1 offline_access",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tOidc := ts.tOidc
|
||||
tOidc.scopes = tc.initialScopes // Set the scopes as if they came from New()
|
||||
tOidc.overrideScopes = tc.overrideScopes
|
||||
|
||||
// Adjust issuerURL for provider-specific tests
|
||||
originalIssuerURL := tOidc.issuerURL
|
||||
if tc.isGoogle {
|
||||
tOidc.issuerURL = "https://accounts.google.com"
|
||||
} else if tc.isAzure {
|
||||
tOidc.issuerURL = "https://login.microsoftonline.com/common"
|
||||
} else {
|
||||
tOidc.issuerURL = "https://generic-provider.com" // Non-Google/Azure
|
||||
}
|
||||
|
||||
authURLString := tOidc.buildAuthURL("http://localhost/callback", "state123", "nonce123", "challenge123")
|
||||
parsedAuthURL, err := url.Parse(authURLString)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
query := parsedAuthURL.Query()
|
||||
|
||||
actualScope := query.Get("scope")
|
||||
if actualScope != tc.expectedScope {
|
||||
t.Errorf("Expected scope string %q, got %q", tc.expectedScope, actualScope)
|
||||
}
|
||||
|
||||
if tc.expectedParams != nil {
|
||||
for k, v := range tc.expectedParams {
|
||||
if query.Get(k) != v {
|
||||
t.Errorf("Expected param %s=%s, got %s", k, v, query.Get(k))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Restore original issuerURL for next test
|
||||
tOidc.issuerURL = originalIssuerURL
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildAuthURL_SpecificUserCase tests the buildAuthURL function with the specific user-reported scenario.
|
||||
func TestBuildAuthURL_SpecificUserCase(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup() // Basic setup for tOidc
|
||||
|
||||
// Configure the TraefikOidc instance for the specific scenario
|
||||
tOidc := ts.tOidc
|
||||
tOidc.scopes = []string{"email", "test3"} // This is what t.scopes should be after New()
|
||||
tOidc.overrideScopes = true
|
||||
tOidc.issuerURL = "https://generic-provider.com" // Non-Google/Azure
|
||||
tOidc.authURL = "https://generic-provider.com/auth" // Dummy auth URL
|
||||
tOidc.clientID = "test-client-id"
|
||||
|
||||
// Expected scope string in the URL
|
||||
expectedScopeString := "email test3"
|
||||
|
||||
// Call buildAuthURL
|
||||
authURLString := tOidc.buildAuthURL("http://localhost/callback", "test-state", "test-nonce", "")
|
||||
|
||||
// Parse the resulting URL
|
||||
parsedAuthURL, err := url.Parse(authURLString)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse generated auth URL %q: %v", authURLString, err)
|
||||
}
|
||||
|
||||
// Get the 'scope' query parameter
|
||||
actualScopeString := parsedAuthURL.Query().Get("scope")
|
||||
|
||||
// Assert that the scope string is as expected
|
||||
if actualScopeString != expectedScopeString {
|
||||
t.Errorf("Expected scope parameter to be %q, but got %q. Full URL: %s",
|
||||
expectedScopeString, actualScopeString, authURLString)
|
||||
}
|
||||
|
||||
// Additionally, ensure 'offline_access' was not added
|
||||
if strings.Contains(actualScopeString, "offline_access") {
|
||||
t.Errorf("Scope parameter %q should not contain 'offline_access' when overrideScopes is true and it's not in tOidc.scopes", actualScopeString)
|
||||
}
|
||||
}
|
||||
|
||||
+217
@@ -0,0 +1,217 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MemoryPoolManager manages various memory pools for high-frequency allocations
|
||||
type MemoryPoolManager struct {
|
||||
compressionBufferPool *sync.Pool
|
||||
jwtParsingPool *sync.Pool
|
||||
httpResponsePool *sync.Pool
|
||||
stringBuilderPool *sync.Pool
|
||||
}
|
||||
|
||||
// JWTParsingBuffer contains reusable buffers for JWT parsing operations
|
||||
type JWTParsingBuffer struct {
|
||||
HeaderBuf []byte
|
||||
PayloadBuf []byte
|
||||
SignatureBuf []byte
|
||||
}
|
||||
|
||||
// NewMemoryPoolManager creates and initializes all memory pools
|
||||
func NewMemoryPoolManager() *MemoryPoolManager {
|
||||
return &MemoryPoolManager{
|
||||
// Pool for compression/decompression buffers (4KB default)
|
||||
compressionBufferPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bytes.NewBuffer(make([]byte, 0, 4096))
|
||||
},
|
||||
},
|
||||
|
||||
// Pool for JWT parsing buffers
|
||||
jwtParsingPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &JWTParsingBuffer{
|
||||
HeaderBuf: make([]byte, 0, 512), // JWT headers are typically small
|
||||
PayloadBuf: make([]byte, 0, 2048), // Payloads can be larger
|
||||
SignatureBuf: make([]byte, 0, 512), // Signatures are fixed size
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
// Pool for HTTP response buffers (8KB default)
|
||||
httpResponsePool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, 0, 8192)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
|
||||
// Pool for string builders
|
||||
stringBuilderPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
var sb strings.Builder
|
||||
sb.Grow(1024) // Pre-allocate 1KB
|
||||
return &sb
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetCompressionBuffer retrieves a buffer from the compression pool
|
||||
func (m *MemoryPoolManager) GetCompressionBuffer() *bytes.Buffer {
|
||||
return m.compressionBufferPool.Get().(*bytes.Buffer)
|
||||
}
|
||||
|
||||
// PutCompressionBuffer returns a buffer to the compression pool
|
||||
func (m *MemoryPoolManager) PutCompressionBuffer(buf *bytes.Buffer) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset buffer but keep capacity if reasonable size
|
||||
if buf.Cap() <= 16384 { // Don't pool buffers larger than 16KB
|
||||
buf.Reset()
|
||||
m.compressionBufferPool.Put(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// GetJWTParsingBuffer retrieves buffers for JWT parsing
|
||||
func (m *MemoryPoolManager) GetJWTParsingBuffer() *JWTParsingBuffer {
|
||||
return m.jwtParsingPool.Get().(*JWTParsingBuffer)
|
||||
}
|
||||
|
||||
// PutJWTParsingBuffer returns JWT parsing buffers to the pool
|
||||
func (m *MemoryPoolManager) PutJWTParsingBuffer(buf *JWTParsingBuffer) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset buffers but keep capacity if reasonable
|
||||
if cap(buf.HeaderBuf) <= 2048 && cap(buf.PayloadBuf) <= 8192 && cap(buf.SignatureBuf) <= 2048 {
|
||||
buf.HeaderBuf = buf.HeaderBuf[:0]
|
||||
buf.PayloadBuf = buf.PayloadBuf[:0]
|
||||
buf.SignatureBuf = buf.SignatureBuf[:0]
|
||||
m.jwtParsingPool.Put(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// GetHTTPResponseBuffer retrieves a buffer for HTTP responses
|
||||
func (m *MemoryPoolManager) GetHTTPResponseBuffer() []byte {
|
||||
return *m.httpResponsePool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
// PutHTTPResponseBuffer returns an HTTP response buffer to the pool
|
||||
func (m *MemoryPoolManager) PutHTTPResponseBuffer(buf []byte) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Don't pool extremely large buffers
|
||||
if cap(buf) <= 32768 { // 32KB limit
|
||||
buf = buf[:0] // Reset length but keep capacity
|
||||
m.httpResponsePool.Put(&buf)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStringBuilder retrieves a string builder from the pool
|
||||
func (m *MemoryPoolManager) GetStringBuilder() *strings.Builder {
|
||||
return m.stringBuilderPool.Get().(*strings.Builder)
|
||||
}
|
||||
|
||||
// PutStringBuilder returns a string builder to the pool
|
||||
func (m *MemoryPoolManager) PutStringBuilder(sb *strings.Builder) {
|
||||
if sb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Don't pool extremely large builders
|
||||
if sb.Cap() <= 16384 { // 16KB limit
|
||||
sb.Reset()
|
||||
m.stringBuilderPool.Put(sb)
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCompressionPool manages memory pools for token compression operations
|
||||
type TokenCompressionPool struct {
|
||||
compressionBuffers sync.Pool
|
||||
decompressionBuffers sync.Pool
|
||||
stringBuilders sync.Pool
|
||||
}
|
||||
|
||||
// NewTokenCompressionPool creates a specialized pool for token operations
|
||||
func NewTokenCompressionPool() *TokenCompressionPool {
|
||||
return &TokenCompressionPool{
|
||||
compressionBuffers: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bytes.NewBuffer(make([]byte, 0, 4096))
|
||||
},
|
||||
},
|
||||
decompressionBuffers: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bytes.NewBuffer(make([]byte, 0, 8192))
|
||||
},
|
||||
},
|
||||
stringBuilders: sync.Pool{
|
||||
New: func() interface{} {
|
||||
var sb strings.Builder
|
||||
sb.Grow(2048) // Pre-allocate for token operations
|
||||
return &sb
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetCompressionBuffer gets a buffer for compression
|
||||
func (p *TokenCompressionPool) GetCompressionBuffer() *bytes.Buffer {
|
||||
return p.compressionBuffers.Get().(*bytes.Buffer)
|
||||
}
|
||||
|
||||
// PutCompressionBuffer returns a compression buffer
|
||||
func (p *TokenCompressionPool) PutCompressionBuffer(buf *bytes.Buffer) {
|
||||
if buf != nil && buf.Cap() <= 16384 {
|
||||
buf.Reset()
|
||||
p.compressionBuffers.Put(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// GetDecompressionBuffer gets a buffer for decompression
|
||||
func (p *TokenCompressionPool) GetDecompressionBuffer() *bytes.Buffer {
|
||||
return p.decompressionBuffers.Get().(*bytes.Buffer)
|
||||
}
|
||||
|
||||
// PutDecompressionBuffer returns a decompression buffer
|
||||
func (p *TokenCompressionPool) PutDecompressionBuffer(buf *bytes.Buffer) {
|
||||
if buf != nil && buf.Cap() <= 32768 {
|
||||
buf.Reset()
|
||||
p.decompressionBuffers.Put(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStringBuilder gets a string builder for token operations
|
||||
func (p *TokenCompressionPool) GetStringBuilder() *strings.Builder {
|
||||
return p.stringBuilders.Get().(*strings.Builder)
|
||||
}
|
||||
|
||||
// PutStringBuilder returns a string builder
|
||||
func (p *TokenCompressionPool) PutStringBuilder(sb *strings.Builder) {
|
||||
if sb != nil && sb.Cap() <= 16384 {
|
||||
sb.Reset()
|
||||
p.stringBuilders.Put(sb)
|
||||
}
|
||||
}
|
||||
|
||||
// Global memory pool manager instance
|
||||
var globalMemoryPools *MemoryPoolManager
|
||||
var memoryPoolOnce sync.Once
|
||||
|
||||
// GetGlobalMemoryPools returns the singleton memory pool manager
|
||||
func GetGlobalMemoryPools() *MemoryPoolManager {
|
||||
memoryPoolOnce.Do(func() {
|
||||
globalMemoryPools = NewMemoryPoolManager()
|
||||
})
|
||||
return globalMemoryPools
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -55,6 +56,90 @@ func (c *MetadataCache) isCacheValid() bool {
|
||||
return c.metadata != nil && time.Now().Before(c.expiresAt)
|
||||
}
|
||||
|
||||
// GetMetadataWithRecovery retrieves the OIDC provider metadata with comprehensive error recovery.
|
||||
// It uses circuit breaker protection and graceful degradation patterns.
|
||||
// Similar to GetMetadata but with enhanced error handling capabilities.
|
||||
//
|
||||
// Parameters:
|
||||
// - providerURL: The base URL of the OIDC provider.
|
||||
// - httpClient: The HTTP client to use for fetching metadata.
|
||||
// - logger: The logger instance for recording errors or warnings.
|
||||
// - errorRecoveryManager: The error recovery manager for circuit breaker and retry handling.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the ProviderMetadata struct.
|
||||
// - An error if metadata cannot be retrieved from cache or fetched from the provider.
|
||||
func (c *MetadataCache) GetMetadataWithRecovery(providerURL string, httpClient *http.Client, logger *Logger, errorRecoveryManager *ErrorRecoveryManager) (*ProviderMetadata, error) {
|
||||
c.mutex.RLock()
|
||||
if c.isCacheValid() {
|
||||
defer c.mutex.RUnlock()
|
||||
return c.metadata, nil
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if c.isCacheValid() {
|
||||
return c.metadata, nil
|
||||
}
|
||||
|
||||
// Use error recovery manager for fetching metadata with circuit breaker protection
|
||||
serviceName := fmt.Sprintf("metadata-provider-%s", providerURL)
|
||||
|
||||
// Register fallback function for graceful degradation
|
||||
errorRecoveryManager.gracefulDegradation.RegisterFallback(serviceName, func() (interface{}, error) {
|
||||
if c.metadata != nil {
|
||||
logger.Infof("Using cached metadata as fallback for service %s", serviceName)
|
||||
// Extend cache by 10 minutes when using fallback
|
||||
c.expiresAt = time.Now().Add(10 * time.Minute)
|
||||
return c.metadata, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no cached metadata available for fallback")
|
||||
})
|
||||
|
||||
// Register health check function
|
||||
errorRecoveryManager.gracefulDegradation.RegisterHealthCheck(serviceName, func() bool {
|
||||
// Simple health check by attempting a quick metadata fetch
|
||||
_, err := discoverProviderMetadata(providerURL, httpClient, logger)
|
||||
return err == nil
|
||||
})
|
||||
|
||||
// Execute metadata discovery with circuit breaker and retry protection
|
||||
ctx := context.Background()
|
||||
var metadata *ProviderMetadata
|
||||
err := errorRecoveryManager.ExecuteWithRecovery(ctx, serviceName, func() error {
|
||||
var fetchErr error
|
||||
metadata, fetchErr = discoverProviderMetadata(providerURL, httpClient, logger)
|
||||
return fetchErr
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Try graceful degradation fallback
|
||||
fallbackResult, fallbackErr := errorRecoveryManager.gracefulDegradation.ExecuteWithFallback(serviceName, func() (interface{}, error) {
|
||||
return discoverProviderMetadata(providerURL, httpClient, logger)
|
||||
})
|
||||
|
||||
if fallbackErr == nil {
|
||||
if fallbackMetadata, ok := fallbackResult.(*ProviderMetadata); ok {
|
||||
logger.Infof("Successfully used fallback metadata for service %s", serviceName)
|
||||
c.metadata = fallbackMetadata
|
||||
// Cache fallback result for 10 minutes
|
||||
c.expiresAt = time.Now().Add(10 * time.Minute)
|
||||
return fallbackMetadata, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata with error recovery and fallback: %w", err)
|
||||
}
|
||||
|
||||
c.metadata = metadata
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// GetMetadata retrieves the OIDC provider metadata.
|
||||
// It first checks the cache for valid, non-expired metadata. If found, it's returned immediately.
|
||||
// If the cache is empty or expired, it attempts to fetch the metadata from the provider's
|
||||
|
||||
+21
-21
@@ -22,8 +22,8 @@ func TestConcurrentTokenVerification(t *testing.T) {
|
||||
|
||||
// Create multiple valid tokens to avoid replay detection
|
||||
tokens := make([]string, 10)
|
||||
for i := range 10 {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
for i := 0; i < 10; i++ {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -71,11 +71,11 @@ func TestConcurrentTokenVerification(t *testing.T) {
|
||||
var errorCount int64
|
||||
errors := make(chan error, numGoroutines*verificationsPerGoroutine)
|
||||
|
||||
for i := range numGoroutines {
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for j := range verificationsPerGoroutine {
|
||||
for j := 0; j < verificationsPerGoroutine; j++ {
|
||||
tokenIndex := (goroutineID*verificationsPerGoroutine + j) % len(tokens)
|
||||
err := tOidc.VerifyToken(tokens[tokenIndex])
|
||||
if err != nil {
|
||||
@@ -144,8 +144,8 @@ func TestCacheMemoryExhaustion(t *testing.T) {
|
||||
const numTokens = 500
|
||||
tokens := make([]string, numTokens)
|
||||
|
||||
for i := range numTokens {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
for i := 0; i < numTokens; i++ {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -161,7 +161,7 @@ func TestCacheMemoryExhaustion(t *testing.T) {
|
||||
tokens[i] = token
|
||||
|
||||
// Add to cache
|
||||
claims := map[string]any{
|
||||
claims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -210,7 +210,7 @@ func TestSessionConcurrencyProtection(t *testing.T) {
|
||||
var successCount int64
|
||||
var errorCount int64
|
||||
|
||||
for i := range numGoroutines {
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
@@ -218,7 +218,7 @@ func TestSessionConcurrencyProtection(t *testing.T) {
|
||||
// Each goroutine gets its own request and session
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
for j := range operationsPerGoroutine {
|
||||
for j := 0; j < operationsPerGoroutine; j++ {
|
||||
// Get a fresh session for each operation
|
||||
s, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -276,11 +276,11 @@ func TestParallelCacheOperations(t *testing.T) {
|
||||
var deleteCount int64
|
||||
|
||||
// Start multiple goroutines performing cache operations
|
||||
for i := range numGoroutines {
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for j := range operationsPerGoroutine {
|
||||
for j := 0; j < operationsPerGoroutine; j++ {
|
||||
key := fmt.Sprintf("key-%d-%d", goroutineID, j)
|
||||
value := fmt.Sprintf("value-%d-%d", goroutineID, j)
|
||||
|
||||
@@ -377,7 +377,7 @@ func TestOversizedTokenHandling(t *testing.T) {
|
||||
|
||||
// Create an oversized token with large claims
|
||||
largeClaim := strings.Repeat("x", 10000) // 10KB claim
|
||||
oversizedClaims := map[string]any{
|
||||
oversizedClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -411,7 +411,7 @@ func TestOversizedTokenHandling(t *testing.T) {
|
||||
|
||||
// Test extremely long token (beyond reasonable limits)
|
||||
extremelyLongClaim := strings.Repeat("y", 100000) // 100KB claim
|
||||
extremeClaims := map[string]any{
|
||||
extremeClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -528,7 +528,7 @@ func TestMaliciousInputValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify the system is still functional after malicious input
|
||||
validToken, createErr := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validToken, createErr := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -609,7 +609,7 @@ func TestResourceLimits(t *testing.T) {
|
||||
defer cache.Close()
|
||||
|
||||
// Try to overwhelm the cache
|
||||
for i := range 1000 {
|
||||
for i := 0; i < 1000; i++ {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
value := fmt.Sprintf("value-%d", i)
|
||||
cache.Set(key, value, time.Minute)
|
||||
@@ -627,7 +627,7 @@ func TestResourceLimits(t *testing.T) {
|
||||
denied := 0
|
||||
|
||||
// Make many requests quickly
|
||||
for range 100 {
|
||||
for i := 0; i < 100; i++ {
|
||||
if limiter.Allow() {
|
||||
allowed++
|
||||
} else {
|
||||
@@ -655,7 +655,7 @@ func TestErrorRecoveryPatterns(t *testing.T) {
|
||||
ts.tOidc.tokenCache.cache.Set("corrupted", "invalid-data", time.Hour)
|
||||
|
||||
// System should handle corrupted cache gracefully
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -681,7 +681,7 @@ func TestErrorRecoveryPatterns(t *testing.T) {
|
||||
ts.tOidc.tokenBlacklist.Set("corrupted-entry", "invalid-data", time.Hour)
|
||||
|
||||
// System should still function
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -713,8 +713,8 @@ func TestPerformanceUnderLoad(t *testing.T) {
|
||||
// Create multiple valid tokens
|
||||
const numTokens = 100
|
||||
tokens := make([]string, numTokens)
|
||||
for i := range numTokens {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
for i := 0; i < numTokens; i++ {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -757,7 +757,7 @@ func TestPerformanceUnderLoad(t *testing.T) {
|
||||
const iterations = 1000
|
||||
start := time.Now()
|
||||
|
||||
for i := range iterations {
|
||||
for i := 0; i < iterations; i++ {
|
||||
tokenIndex := i % numTokens
|
||||
err := tOidc.VerifyToken(tokens[tokenIndex])
|
||||
if err != nil {
|
||||
|
||||
+300
-62
@@ -1,6 +1,7 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
@@ -54,71 +55,308 @@ func TestMergeScopes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeduplicateScopes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputScopes []string
|
||||
expectedScopes []string
|
||||
}{
|
||||
{
|
||||
name: "No duplicates",
|
||||
inputScopes: []string{"openid", "profile", "email"},
|
||||
expectedScopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "Simple duplicates",
|
||||
inputScopes: []string{"openid", "profile", "openid", "email"},
|
||||
expectedScopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "Multiple duplicates",
|
||||
inputScopes: []string{"scope1", "scope2", "scope1", "scope2", "scope1"},
|
||||
expectedScopes: []string{"scope1", "scope2"},
|
||||
},
|
||||
{
|
||||
name: "Empty input",
|
||||
inputScopes: []string{},
|
||||
expectedScopes: []string{},
|
||||
},
|
||||
{
|
||||
name: "Nil input",
|
||||
inputScopes: nil,
|
||||
expectedScopes: []string{},
|
||||
},
|
||||
{
|
||||
name: "Single element",
|
||||
inputScopes: []string{"openid"},
|
||||
expectedScopes: []string{"openid"},
|
||||
},
|
||||
{
|
||||
name: "All duplicates",
|
||||
inputScopes: []string{"test", "test", "test"},
|
||||
expectedScopes: []string{"test"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := deduplicateScopes(tc.inputScopes)
|
||||
if !reflect.DeepEqual(result, tc.expectedScopes) {
|
||||
t.Errorf("Expected %v, got %v", tc.expectedScopes, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopesConfiguration(t *testing.T) {
|
||||
defaultScopes := []string{"openid", "profile", "email"}
|
||||
userScopes := []string{"roles", "custom_scope"}
|
||||
|
||||
t.Run("Default Append Behavior", func(t *testing.T) {
|
||||
// Create config with user scopes but overrideScopes=false
|
||||
config := &Config{
|
||||
Scopes: userScopes,
|
||||
OverrideScopes: false,
|
||||
}
|
||||
testCases := []struct {
|
||||
name string
|
||||
configScopes []string // Scopes from Traefik config
|
||||
overrideScopes bool
|
||||
expectedResult []string
|
||||
}{
|
||||
{
|
||||
name: "Default Append Behavior - No user scopes",
|
||||
configScopes: []string{},
|
||||
overrideScopes: false,
|
||||
expectedResult: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "Default Append Behavior - With user scopes",
|
||||
configScopes: []string{"roles", "custom_scope"},
|
||||
overrideScopes: false,
|
||||
expectedResult: []string{"openid", "profile", "email", "roles", "custom_scope"},
|
||||
},
|
||||
{
|
||||
name: "Default Append Behavior - With duplicate user scopes",
|
||||
configScopes: []string{"roles", "custom_scope", "roles"},
|
||||
overrideScopes: false,
|
||||
expectedResult: []string{"openid", "profile", "email", "roles", "custom_scope"},
|
||||
},
|
||||
{
|
||||
name: "Default Append Behavior - User scopes overlap with defaults",
|
||||
configScopes: []string{"openid", "roles", "profile"},
|
||||
overrideScopes: false,
|
||||
expectedResult: []string{"openid", "profile", "email", "roles"},
|
||||
},
|
||||
{
|
||||
name: "Override Behavior - With user scopes",
|
||||
configScopes: []string{"roles", "custom_scope"},
|
||||
overrideScopes: true,
|
||||
expectedResult: []string{"roles", "custom_scope"},
|
||||
},
|
||||
{
|
||||
name: "Override Behavior - With duplicate user scopes",
|
||||
configScopes: []string{"roles", "custom_scope", "roles"},
|
||||
overrideScopes: true,
|
||||
expectedResult: []string{"roles", "custom_scope"},
|
||||
},
|
||||
{
|
||||
name: "Override Behavior - Empty user scopes",
|
||||
configScopes: []string{},
|
||||
overrideScopes: true,
|
||||
expectedResult: []string{},
|
||||
},
|
||||
{
|
||||
name: "Override Behavior - Nil user scopes",
|
||||
configScopes: nil,
|
||||
overrideScopes: true,
|
||||
expectedResult: []string{}, // Deduplicate will handle nil as empty
|
||||
},
|
||||
{
|
||||
name: "Override Behavior - Single user scope",
|
||||
configScopes: []string{"email"},
|
||||
overrideScopes: true,
|
||||
expectedResult: []string{"email"},
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate middleware initialization
|
||||
var result []string
|
||||
if config.OverrideScopes {
|
||||
result = append([]string(nil), config.Scopes...)
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, config.Scopes)
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Simulate the logic within TraefikOidc.New for setting t.scopes
|
||||
var result []string
|
||||
uniqueConfigScopes := deduplicateScopes(tc.configScopes)
|
||||
if tc.overrideScopes {
|
||||
result = uniqueConfigScopes
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, uniqueConfigScopes)
|
||||
}
|
||||
|
||||
// Expect defaultScopes + userScopes with deduplication
|
||||
expectedScopes := []string{"openid", "profile", "email", "roles", "custom_scope"}
|
||||
if !reflect.DeepEqual(result, expectedScopes) {
|
||||
t.Errorf("Expected %v, got %v", expectedScopes, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Override Behavior", func(t *testing.T) {
|
||||
// Create config with user scopes and overrideScopes=true
|
||||
config := &Config{
|
||||
Scopes: userScopes,
|
||||
OverrideScopes: true,
|
||||
}
|
||||
|
||||
// Simulate middleware initialization
|
||||
var result []string
|
||||
if config.OverrideScopes {
|
||||
result = append([]string(nil), config.Scopes...)
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, config.Scopes)
|
||||
}
|
||||
|
||||
// Expect only userScopes
|
||||
if !reflect.DeepEqual(result, userScopes) {
|
||||
t.Errorf("Expected %v, got %v", userScopes, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty Scopes with Override", func(t *testing.T) {
|
||||
// Create config with empty scopes and overrideScopes=true
|
||||
config := &Config{
|
||||
Scopes: []string{},
|
||||
OverrideScopes: true,
|
||||
}
|
||||
|
||||
// Simulate middleware initialization
|
||||
var result []string
|
||||
if config.OverrideScopes {
|
||||
result = append([]string(nil), config.Scopes...)
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, config.Scopes)
|
||||
}
|
||||
|
||||
// Expect empty scopes - check length instead of DeepEqual
|
||||
if len(result) != 0 {
|
||||
t.Errorf("Expected empty slice, got %v with length %d", result, len(result))
|
||||
}
|
||||
})
|
||||
if !reflect.DeepEqual(result, tc.expectedResult) {
|
||||
t.Errorf("Expected scopes %v, got %v", tc.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthURLScopeHandling(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup() // Basic setup for TraefikOidc instance
|
||||
|
||||
// Default scopes expected if not overridden and no user scopes provided
|
||||
defaultInitialScopes := []string{"openid", "profile", "email"}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
configScopes []string // Scopes from Traefik config
|
||||
overrideScopes bool
|
||||
isGoogle bool
|
||||
isAzure bool
|
||||
expectedScopeString string // Expected final scope string in the auth URL
|
||||
expectedParams map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Deduplication: Default append, duplicate in user scopes",
|
||||
configScopes: []string{"openid", "custom", "profile", "custom"},
|
||||
overrideScopes: false,
|
||||
expectedScopeString: "openid profile email custom offline_access",
|
||||
},
|
||||
{
|
||||
name: "Deduplication: Override, duplicate in user scopes",
|
||||
configScopes: []string{"openid", "custom", "profile", "custom"},
|
||||
overrideScopes: true,
|
||||
expectedScopeString: "openid custom profile", // offline_access not added
|
||||
},
|
||||
{
|
||||
name: "Override True: No automatic offline_access",
|
||||
configScopes: []string{"scope1", "scope2"},
|
||||
overrideScopes: true,
|
||||
expectedScopeString: "scope1 scope2",
|
||||
},
|
||||
{
|
||||
name: "Override True: User includes offline_access",
|
||||
configScopes: []string{"scope1", "offline_access", "scope2"},
|
||||
overrideScopes: true,
|
||||
expectedScopeString: "scope1 offline_access scope2",
|
||||
},
|
||||
{
|
||||
name: "Override False: Automatic offline_access added",
|
||||
configScopes: []string{"scope1", "scope2"},
|
||||
overrideScopes: false,
|
||||
expectedScopeString: "openid profile email scope1 scope2 offline_access",
|
||||
},
|
||||
{
|
||||
name: "Override False: User includes offline_access (deduplicated)",
|
||||
configScopes: []string{"scope1", "offline_access", "scope2"},
|
||||
overrideScopes: false,
|
||||
expectedScopeString: "openid profile email scope1 offline_access scope2",
|
||||
},
|
||||
{
|
||||
name: "Integration: Duplicate scopes in config, override true",
|
||||
configScopes: []string{"scope1", "scope1", "scope2"},
|
||||
overrideScopes: true,
|
||||
expectedScopeString: "scope1 scope2",
|
||||
},
|
||||
{
|
||||
name: "Integration: No auto offline_access with override true",
|
||||
configScopes: []string{"scope1", "scope2"},
|
||||
overrideScopes: true,
|
||||
expectedScopeString: "scope1 scope2",
|
||||
},
|
||||
{
|
||||
name: "Integration: Duplicates and no auto offline_access with override true",
|
||||
configScopes: []string{"scope1", "scope1", "scope2"},
|
||||
overrideScopes: true,
|
||||
expectedScopeString: "scope1 scope2",
|
||||
},
|
||||
{
|
||||
name: "Integration: Google provider, override false, no user scopes",
|
||||
configScopes: []string{},
|
||||
overrideScopes: false,
|
||||
isGoogle: true,
|
||||
expectedScopeString: "openid profile email", // Google uses access_type=offline param
|
||||
expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"},
|
||||
},
|
||||
{
|
||||
name: "Integration: Google provider, override true, user scopes",
|
||||
configScopes: []string{"custom1", "custom2"},
|
||||
overrideScopes: true,
|
||||
isGoogle: true,
|
||||
expectedScopeString: "custom1 custom2", // Google uses access_type=offline param
|
||||
expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"},
|
||||
},
|
||||
{
|
||||
name: "Integration: Azure provider, override false, no user scopes",
|
||||
configScopes: []string{},
|
||||
overrideScopes: false,
|
||||
isAzure: true,
|
||||
expectedScopeString: "openid profile email offline_access", // Azure adds offline_access scope
|
||||
expectedParams: map[string]string{"response_mode": "query"},
|
||||
},
|
||||
{
|
||||
name: "Integration: Azure provider, override true, user scopes without offline_access",
|
||||
configScopes: []string{"custom1", "custom2"},
|
||||
overrideScopes: true,
|
||||
isAzure: true,
|
||||
expectedScopeString: "custom1 custom2", // Azure respects override
|
||||
expectedParams: map[string]string{"response_mode": "query"},
|
||||
},
|
||||
{
|
||||
name: "Integration: Azure provider, override true, user scopes with offline_access",
|
||||
configScopes: []string{"custom1", "offline_access"},
|
||||
overrideScopes: true,
|
||||
isAzure: true,
|
||||
expectedScopeString: "custom1 offline_access", // Azure respects override
|
||||
expectedParams: map[string]string{"response_mode": "query"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Simulate the TraefikOidc instance's scope initialization
|
||||
var initializedScopes []string
|
||||
uniqueConfigScopes := deduplicateScopes(tc.configScopes)
|
||||
if tc.overrideScopes {
|
||||
initializedScopes = uniqueConfigScopes
|
||||
} else {
|
||||
initializedScopes = mergeScopes(defaultInitialScopes, uniqueConfigScopes)
|
||||
}
|
||||
|
||||
// Create a new TraefikOidc instance for this test case
|
||||
// to ensure proper isolation of 'scopes' and 'overrideScopes' fields.
|
||||
// We use parts of the TestSuite's tOidc for common setup like logger, clientID etc.
|
||||
// but override the scope-related fields.
|
||||
testOidc := &TraefikOidc{
|
||||
clientID: ts.tOidc.clientID,
|
||||
logger: ts.tOidc.logger,
|
||||
scopes: initializedScopes, // Use scopes processed as New() would
|
||||
overrideScopes: tc.overrideScopes,
|
||||
// Set other necessary fields for buildAuthURL to function
|
||||
authURL: "https://provider.com/auth", // Dummy authURL
|
||||
issuerURL: "https://provider.com", // Dummy issuerURL
|
||||
httpClient: ts.tOidc.httpClient, // Reuse from TestSuite
|
||||
}
|
||||
|
||||
originalIssuerURL := testOidc.issuerURL
|
||||
if tc.isGoogle {
|
||||
testOidc.issuerURL = "https://accounts.google.com"
|
||||
} else if tc.isAzure {
|
||||
testOidc.issuerURL = "https://login.microsoftonline.com/common"
|
||||
}
|
||||
|
||||
authURLString := testOidc.buildAuthURL("http://localhost/callback", "state", "nonce", "challenge")
|
||||
parsedURL, err := url.Parse(authURLString)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
actualScopeString := query.Get("scope")
|
||||
|
||||
if actualScopeString != tc.expectedScopeString {
|
||||
t.Errorf("Expected scope string %q, got %q", tc.expectedScopeString, actualScopeString)
|
||||
}
|
||||
|
||||
if tc.expectedParams != nil {
|
||||
for k, v := range tc.expectedParams {
|
||||
if query.Get(k) != v {
|
||||
t.Errorf("Expected param %s=%s, got %s", k, v, query.Get(k))
|
||||
}
|
||||
}
|
||||
}
|
||||
testOidc.issuerURL = originalIssuerURL // Restore
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+33
-22
@@ -11,7 +11,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -26,7 +25,7 @@ func TestJWTAlgorithmConfusionAttack(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a standard JWT with RS256 algorithm
|
||||
validRS256JWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validRS256JWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -52,7 +51,7 @@ func TestJWTAlgorithmConfusionAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Parse header
|
||||
var header map[string]any
|
||||
var header map[string]interface{}
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
t.Fatalf("Failed to unmarshal header: %v", err)
|
||||
}
|
||||
@@ -91,7 +90,7 @@ func TestJWTNoneAlgorithmAttack(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a standard JWT
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -117,7 +116,7 @@ func TestJWTNoneAlgorithmAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Parse header
|
||||
var header map[string]any
|
||||
var header map[string]interface{}
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
t.Fatalf("Failed to unmarshal header: %v", err)
|
||||
}
|
||||
@@ -155,7 +154,7 @@ func TestJWTTokenTampering(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a standard JWT
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -181,7 +180,7 @@ func TestJWTTokenTampering(t *testing.T) {
|
||||
}
|
||||
|
||||
// Parse claims
|
||||
var claims map[string]any
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(claimsBytes, &claims); err != nil {
|
||||
t.Fatalf("Failed to unmarshal claims: %v", err)
|
||||
}
|
||||
@@ -220,7 +219,7 @@ func TestJWTExpiredToken(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a JWT that is already expired
|
||||
expiredJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
expiredJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(-1 * time.Hour).Unix()), // Expired 1 hour ago
|
||||
@@ -253,7 +252,7 @@ func TestJWTFutureToken(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a JWT with a future issuance time
|
||||
futureJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
futureJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(2 * time.Hour).Unix()),
|
||||
@@ -316,7 +315,7 @@ func TestJWTReplayAttack(t *testing.T) {
|
||||
fixedJTI := "fixed-test-jti-for-replay-" + generateRandomString(8)
|
||||
|
||||
// Create a JWT with the fixed JTI
|
||||
replayJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
replayJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -423,7 +422,7 @@ func TestMissingClaims(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create standard claims
|
||||
claims := map[string]any{
|
||||
claims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -599,7 +598,13 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
// - The response is unauthorized (401), OR
|
||||
// - The token verification failed
|
||||
expectedCodes := []int{http.StatusFound, http.StatusUnauthorized, http.StatusForbidden}
|
||||
codeFound := slices.Contains(expectedCodes, victimResp.Code)
|
||||
codeFound := false
|
||||
for _, code := range expectedCodes {
|
||||
if victimResp.Code == code {
|
||||
codeFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !codeFound {
|
||||
t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, victimResp.Code)
|
||||
@@ -824,7 +829,7 @@ func TestTokenBlacklisting(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a valid JWT
|
||||
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -914,7 +919,7 @@ func TestDifferentSigningAlgorithms(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Define standard claims with unique JTI for each test
|
||||
standardClaims := map[string]any{
|
||||
standardClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -1017,9 +1022,9 @@ func TestDifferentSigningAlgorithms(t *testing.T) {
|
||||
}
|
||||
|
||||
// createTestJWTWithECKey creates a JWT signed with an EC private key
|
||||
func createTestJWTWithECKey(privateKey *ecdsa.PrivateKey, alg, kid string, claims map[string]any) (string, error) {
|
||||
func createTestJWTWithECKey(privateKey *ecdsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
|
||||
// Create the header
|
||||
header := map[string]any{
|
||||
header := map[string]interface{}{
|
||||
"alg": alg,
|
||||
"typ": "JWT",
|
||||
"kid": kid,
|
||||
@@ -1276,7 +1281,7 @@ func TestRateLimiting(t *testing.T) {
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Create a valid JWT token
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -1296,7 +1301,7 @@ func TestRateLimiting(t *testing.T) {
|
||||
}
|
||||
|
||||
// Second request should succeed
|
||||
validJWT2, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validJWT2, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -1315,7 +1320,7 @@ func TestRateLimiting(t *testing.T) {
|
||||
}
|
||||
|
||||
// Third request should be rate limited
|
||||
validJWT3, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
validJWT3, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -1399,7 +1404,13 @@ func TestAuthorizationHeaderBypass(t *testing.T) {
|
||||
|
||||
// Verify that the response is a redirect to authentication (302) or unauthorized (401)
|
||||
expectedCodes := []int{http.StatusFound, http.StatusUnauthorized}
|
||||
codeFound := slices.Contains(expectedCodes, resp.Code)
|
||||
codeFound := false
|
||||
for _, code := range expectedCodes {
|
||||
if resp.Code == code {
|
||||
codeFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !codeFound {
|
||||
t.Errorf("Expected status code to be one of %v, but got %d", expectedCodes, resp.Code)
|
||||
@@ -1412,7 +1423,7 @@ func TestEmptyAudience(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a JWT with empty audience
|
||||
emptyAudJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
emptyAudJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "", // Empty audience
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
@@ -1445,7 +1456,7 @@ func TestEmptyIssuer(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create a JWT with empty issuer
|
||||
emptyIssJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
emptyIssJWT, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "", // Empty issuer
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
|
||||
+26
-21
@@ -55,14 +55,14 @@ func (t SecurityEventType) IPFailureType() string {
|
||||
|
||||
// SecurityEvent represents a security-related event that should be logged and monitored
|
||||
type SecurityEvent struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Severity string `json:"severity"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
RequestPath string `json:"request_path"`
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Severity string `json:"severity"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
RequestPath string `json:"request_path"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// SecurityMonitor tracks security events and suspicious activity patterns
|
||||
@@ -168,7 +168,7 @@ func (sm *SecurityMonitor) RecordSecurityEvent(
|
||||
eventType SecurityEventType,
|
||||
clientIP, userAgent, requestPath string,
|
||||
message string,
|
||||
details map[string]any,
|
||||
details map[string]interface{},
|
||||
trackIPFailure bool) {
|
||||
|
||||
// Create event with default values for the event type
|
||||
@@ -193,9 +193,9 @@ func (sm *SecurityMonitor) RecordSecurityEvent(
|
||||
}
|
||||
|
||||
// RecordAuthenticationFailure records an authentication failure event
|
||||
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]any) {
|
||||
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) {
|
||||
if details == nil {
|
||||
details = make(map[string]any)
|
||||
details = make(map[string]interface{})
|
||||
}
|
||||
details["reason"] = reason
|
||||
|
||||
@@ -212,7 +212,7 @@ func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requ
|
||||
|
||||
// RecordTokenValidationFailure records a token validation failure
|
||||
func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) {
|
||||
details := map[string]any{
|
||||
details := map[string]interface{}{
|
||||
"reason": reason,
|
||||
}
|
||||
if tokenPrefix != "" {
|
||||
@@ -232,7 +232,7 @@ func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, req
|
||||
|
||||
// RecordRateLimitHit records when rate limiting is triggered
|
||||
func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) {
|
||||
details := map[string]any{
|
||||
details := map[string]interface{}{
|
||||
"limit_type": "token_verification",
|
||||
}
|
||||
|
||||
@@ -248,9 +248,9 @@ func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath s
|
||||
}
|
||||
|
||||
// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories
|
||||
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]any) {
|
||||
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]interface{}) {
|
||||
if details == nil {
|
||||
details = make(map[string]any)
|
||||
details = make(map[string]interface{})
|
||||
}
|
||||
details["activity_type"] = activityType
|
||||
|
||||
@@ -302,7 +302,7 @@ func (sm *SecurityMonitor) recordIPFailure(clientIP, failureType string) {
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
Message: fmt.Sprintf("IP blocked due to %d failures in %d minutes", tracker.FailureCount, sm.config.FailureWindowMinutes),
|
||||
Details: map[string]any{
|
||||
Details: map[string]interface{}{
|
||||
"failure_count": tracker.FailureCount,
|
||||
"failure_types": tracker.FailureTypes,
|
||||
"blocked_until": tracker.BlockedUntil,
|
||||
@@ -347,15 +347,20 @@ func (sm *SecurityMonitor) processSecurityEvent(event SecurityEvent) {
|
||||
|
||||
// Check for suspicious patterns
|
||||
if patterns := sm.patternDetector.DetectSuspiciousPatterns(); len(patterns) > 0 {
|
||||
for _, pattern := range patterns {
|
||||
sm.logger.Errorf("Suspicious pattern detected: %s", pattern)
|
||||
// Log once with all patterns instead of logging each pattern
|
||||
if len(patterns) == 1 {
|
||||
sm.logger.Errorf("Suspicious pattern detected: %s", patterns[0])
|
||||
} else {
|
||||
sm.logger.Errorf("Multiple suspicious patterns detected: %v", patterns)
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
patternEvent := SecurityEvent{
|
||||
Type: "suspicious_pattern",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
Message: fmt.Sprintf("Suspicious pattern detected: %s", pattern),
|
||||
Details: map[string]any{
|
||||
Details: map[string]interface{}{
|
||||
"pattern_type": pattern,
|
||||
"trigger_event": event,
|
||||
},
|
||||
@@ -389,8 +394,8 @@ func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) {
|
||||
|
||||
// GetSecurityMetrics returns minimal security metrics
|
||||
// This is kept for API compatibility but doesn't collect actual metrics
|
||||
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]any {
|
||||
return map[string]any{
|
||||
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"tracked_ips": 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -53,7 +52,7 @@ func TestSecurityMonitor(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Suspicious activity", func(t *testing.T) {
|
||||
details := map[string]any{"pattern": "unusual"}
|
||||
details := map[string]interface{}{"pattern": "unusual"}
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordSuspiciousActivity("192.168.1.5", "test-agent", "/admin", "unusual pattern", "high frequency requests", details)
|
||||
})
|
||||
@@ -64,7 +63,7 @@ func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
|
||||
t.Run("Add events and detect patterns", func(t *testing.T) {
|
||||
// Add multiple events from same IP
|
||||
for range 10 {
|
||||
for i := 0; i < 10; i++ {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1.100",
|
||||
@@ -75,7 +74,13 @@ func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := slices.Contains(patterns, "rapid_failures_from_ip_192.168.1.100")
|
||||
found := false
|
||||
for _, p := range patterns {
|
||||
if p == "rapid_failures_from_ip_192.168.1.100" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected to detect rapid failure pattern")
|
||||
}
|
||||
@@ -83,7 +88,7 @@ func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
|
||||
t.Run("Detect distributed attack pattern", func(t *testing.T) {
|
||||
// Add failures from many different IPs
|
||||
for i := range 25 {
|
||||
for i := 0; i < 25; i++ {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1." + strconv.Itoa(100+i),
|
||||
@@ -94,7 +99,13 @@ func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := slices.Contains(patterns, "distributed_attack_pattern")
|
||||
found := false
|
||||
for _, p := range patterns {
|
||||
if p == "distributed_attack_pattern" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected to detect distributed attack pattern")
|
||||
}
|
||||
@@ -266,7 +277,7 @@ func TestSecurityEventTypes(t *testing.T) {
|
||||
monitor.RecordTokenValidationFailure("192.168.1.200", "test-agent", "/api", "expired token", "abc123")
|
||||
monitor.RecordRateLimitHit("192.168.1.200", "test-agent", "/api")
|
||||
|
||||
details := map[string]any{"pattern": "test"}
|
||||
details := map[string]interface{}{"pattern": "test"}
|
||||
monitor.RecordSuspiciousActivity("192.168.1.200", "test-agent", "/admin", "unusual pattern", "multiple failed logins", details)
|
||||
|
||||
// Just verify GetSecurityMetrics doesn't panic
|
||||
|
||||
+312
-519
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,844 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// TokenConfig holds validation rules for different token types
|
||||
type TokenConfig struct {
|
||||
Type string
|
||||
MinLength int
|
||||
MaxLength int
|
||||
MaxChunks int // Maximum number of chunks allowed
|
||||
MaxChunkSize int // Maximum size per chunk
|
||||
AllowOpaqueTokens bool
|
||||
RequireJWTFormat bool
|
||||
}
|
||||
|
||||
// Predefined configurations for each token type
|
||||
var (
|
||||
AccessTokenConfig = TokenConfig{
|
||||
Type: "access",
|
||||
MinLength: 5,
|
||||
MaxLength: 100 * 1024, // 100KB total limit
|
||||
MaxChunks: 25, // Maximum 25 chunks
|
||||
MaxChunkSize: maxCookieSize, // Use global chunk size limit
|
||||
AllowOpaqueTokens: true,
|
||||
RequireJWTFormat: false,
|
||||
}
|
||||
|
||||
RefreshTokenConfig = TokenConfig{
|
||||
Type: "refresh",
|
||||
MinLength: 5,
|
||||
MaxLength: 50 * 1024, // 50KB total limit (refresh tokens are typically smaller)
|
||||
MaxChunks: 15, // Maximum 15 chunks
|
||||
MaxChunkSize: maxCookieSize,
|
||||
AllowOpaqueTokens: true,
|
||||
RequireJWTFormat: false,
|
||||
}
|
||||
|
||||
IDTokenConfig = TokenConfig{
|
||||
Type: "id",
|
||||
MinLength: 5,
|
||||
MaxLength: 75 * 1024, // 75KB total limit
|
||||
MaxChunks: 20, // Maximum 20 chunks
|
||||
MaxChunkSize: maxCookieSize,
|
||||
AllowOpaqueTokens: false,
|
||||
RequireJWTFormat: true,
|
||||
}
|
||||
)
|
||||
|
||||
// TokenRetrievalResult encapsulates the result of token retrieval
|
||||
type TokenRetrievalResult struct {
|
||||
Token string
|
||||
Error error
|
||||
}
|
||||
|
||||
// ChunkManager handles token chunking operations
|
||||
type ChunkManager struct {
|
||||
logger *Logger
|
||||
mutex *sync.RWMutex
|
||||
}
|
||||
|
||||
// NewChunkManager creates a new ChunkManager instance
|
||||
func NewChunkManager(logger *Logger) *ChunkManager {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
return &ChunkManager{
|
||||
logger: logger,
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// GetToken retrieves and validates a token from either single storage or chunks
|
||||
func (cm *ChunkManager) GetToken(
|
||||
singleToken string,
|
||||
compressed bool,
|
||||
chunks map[int]*sessions.Session,
|
||||
config TokenConfig,
|
||||
) TokenRetrievalResult {
|
||||
cm.mutex.RLock()
|
||||
defer cm.mutex.RUnlock()
|
||||
|
||||
// Handle single-token storage
|
||||
if singleToken != "" {
|
||||
return cm.processSingleToken(singleToken, compressed, config)
|
||||
}
|
||||
|
||||
// Handle chunked storage
|
||||
if len(chunks) == 0 {
|
||||
return TokenRetrievalResult{Token: "", Error: nil}
|
||||
}
|
||||
|
||||
return cm.processChunkedToken(chunks, config)
|
||||
}
|
||||
|
||||
// processSingleToken handles tokens stored in a single cookie
|
||||
func (cm *ChunkManager) processSingleToken(token string, compressed bool, config TokenConfig) TokenRetrievalResult {
|
||||
// Detect corruption markers
|
||||
if isCorruptionMarker(token) {
|
||||
err := fmt.Errorf("%s token contains corruption marker", config.Type)
|
||||
// Only log if not a known test scenario
|
||||
if !strings.Contains(token, "TEST_CORRUPTION") {
|
||||
cm.logger.Debug("Token corruption detected for %s", config.Type)
|
||||
}
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
var finalToken string
|
||||
if compressed {
|
||||
decompressed := decompressToken(token)
|
||||
if isCorruptionMarker(decompressed) {
|
||||
err := fmt.Errorf("decompressed %s token contains corruption marker", config.Type)
|
||||
cm.logger.Debug("Decompressed token corruption detected for %s", config.Type)
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
finalToken = decompressed
|
||||
} else {
|
||||
finalToken = token
|
||||
}
|
||||
|
||||
return cm.validateToken(finalToken, config)
|
||||
}
|
||||
|
||||
// validateToken performs comprehensive token validation
|
||||
func (cm *ChunkManager) validateToken(token string, config TokenConfig) TokenRetrievalResult {
|
||||
// Enhanced size validation
|
||||
if sizeErr := cm.validateTokenSize(token, config); sizeErr != nil {
|
||||
return TokenRetrievalResult{Token: "", Error: sizeErr}
|
||||
}
|
||||
|
||||
// Chunking efficiency validation (for pre-storage analysis)
|
||||
if chunkErr := cm.validateChunkingEfficiency(token, config); chunkErr != nil {
|
||||
return TokenRetrievalResult{Token: "", Error: chunkErr}
|
||||
}
|
||||
|
||||
// Comprehensive content validation
|
||||
if contentErr := cm.validateTokenContent(token, config); contentErr != nil {
|
||||
return TokenRetrievalResult{Token: "", Error: contentErr}
|
||||
}
|
||||
|
||||
// Token expiration validation
|
||||
if expErr := cm.validateTokenExpiration(token, config); expErr != nil {
|
||||
return TokenRetrievalResult{Token: "", Error: expErr}
|
||||
}
|
||||
|
||||
// Token freshness validation
|
||||
if freshnessErr := cm.validateTokenFreshness(token, config); freshnessErr != nil {
|
||||
return TokenRetrievalResult{Token: "", Error: freshnessErr}
|
||||
}
|
||||
|
||||
// Enhanced JWT format validation
|
||||
if config.RequireJWTFormat && !config.AllowOpaqueTokens {
|
||||
if validationErr := cm.validateJWTFormat(token, config.Type); validationErr != nil {
|
||||
return TokenRetrievalResult{Token: "", Error: validationErr}
|
||||
}
|
||||
} else if config.RequireJWTFormat && config.AllowOpaqueTokens {
|
||||
// For tokens that can be either JWT or opaque, validate JWT format only if it has dots
|
||||
dotCount := strings.Count(token, ".")
|
||||
if dotCount > 0 {
|
||||
if validationErr := cm.validateJWTFormat(token, config.Type); validationErr != nil {
|
||||
return TokenRetrievalResult{Token: "", Error: validationErr}
|
||||
}
|
||||
} else {
|
||||
// Validate as opaque token
|
||||
if validationErr := cm.validateOpaqueToken(token, config.Type); validationErr != nil {
|
||||
return TokenRetrievalResult{Token: "", Error: validationErr}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return TokenRetrievalResult{Token: token, Error: nil}
|
||||
}
|
||||
|
||||
// processChunkedToken handles tokens stored across multiple chunks
|
||||
func (cm *ChunkManager) processChunkedToken(chunks map[int]*sessions.Session, config TokenConfig) TokenRetrievalResult {
|
||||
// Enhanced chunk count validation using config limits
|
||||
if len(chunks) > config.MaxChunks {
|
||||
err := fmt.Errorf("too many %s token chunks (%d, max: %d)", config.Type, len(chunks), config.MaxChunks)
|
||||
cm.logger.Info("Token chunk count exceeded for %s: %d chunks", config.Type, len(chunks))
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
// Additional safety check for extremely large chunk counts
|
||||
if len(chunks) > 100 {
|
||||
err := fmt.Errorf("excessive %s token chunks (%d), potential security issue", config.Type, len(chunks))
|
||||
cm.logger.Error("Security: Excessive token chunks detected for %s: %d", config.Type, len(chunks))
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
// Sequential chunk validation and assembly
|
||||
var tokenParts []string
|
||||
totalSize := 0
|
||||
|
||||
for i := 0; i < len(chunks); i++ {
|
||||
session, ok := chunks[i]
|
||||
if !ok {
|
||||
err := fmt.Errorf("%s token chunk %d missing", config.Type, i)
|
||||
// Only log once for missing chunks, not for each missing chunk
|
||||
if i == 0 {
|
||||
cm.logger.Debug("Token chunks missing for %s starting at index %d", config.Type, i)
|
||||
}
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
chunk, chunkOk := session.Values["token_chunk"].(string)
|
||||
if !chunkOk || chunk == "" {
|
||||
err := fmt.Errorf("%s token chunk %d invalid", config.Type, i)
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
if isCorruptionMarker(chunk) {
|
||||
err := fmt.Errorf("%s token chunk %d corrupted", config.Type, i)
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
// Enhanced chunk size validation using config limits
|
||||
if len(chunk) > config.MaxChunkSize {
|
||||
err := fmt.Errorf("%s token chunk %d exceeds size limit (%d bytes, max: %d)",
|
||||
config.Type, i, len(chunk), config.MaxChunkSize)
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
// Additional safety check for extremely large chunks
|
||||
if len(chunk) > maxBrowserCookieSize {
|
||||
err := fmt.Errorf("%s token chunk %d exceeds browser limit (%d bytes)",
|
||||
config.Type, i, len(chunk))
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
totalSize += len(chunk)
|
||||
if totalSize > config.MaxLength {
|
||||
err := fmt.Errorf("%s token total size exceeds limit", config.Type)
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
tokenParts = append(tokenParts, chunk)
|
||||
}
|
||||
|
||||
// Reassemble token
|
||||
reassembledToken := strings.Join(tokenParts, "")
|
||||
|
||||
// Check compression flag from first chunk
|
||||
compressed, _ := chunks[0].Values["compressed"].(bool)
|
||||
|
||||
if compressed {
|
||||
decompressed := decompressToken(reassembledToken)
|
||||
if isCorruptionMarker(decompressed) {
|
||||
err := fmt.Errorf("decompressed chunked %s token corrupted", config.Type)
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
return cm.validateToken(decompressed, config)
|
||||
}
|
||||
|
||||
return cm.validateToken(reassembledToken, config)
|
||||
}
|
||||
|
||||
// validateJWTFormat performs enhanced JWT format validation
|
||||
func (cm *ChunkManager) validateJWTFormat(token string, tokenType string) error {
|
||||
// Check for exactly 2 dots
|
||||
dotCount := strings.Count(token, ".")
|
||||
if dotCount != 2 {
|
||||
err := fmt.Errorf("%s token invalid JWT format (dots: %d)", tokenType, dotCount)
|
||||
return err
|
||||
}
|
||||
|
||||
// Split into parts
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
err := fmt.Errorf("%s token invalid JWT structure", tokenType)
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate each part is non-empty and contains valid base64url characters
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
err := fmt.Errorf("%s token has empty JWT part %d", tokenType, i)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for valid base64url characters only (RFC 4648)
|
||||
// Valid characters: A-Z, a-z, 0-9, -, _, and = for padding
|
||||
for _, char := range part {
|
||||
if !((char >= 'A' && char <= 'Z') ||
|
||||
(char >= 'a' && char <= 'z') ||
|
||||
(char >= '0' && char <= '9') ||
|
||||
char == '-' || char == '_' || char == '=') {
|
||||
err := fmt.Errorf("%s token contains invalid base64url character in part %d", tokenType, i)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate base64url padding rules
|
||||
if strings.Contains(part, "=") {
|
||||
// Padding can only be at the end
|
||||
paddingIndex := strings.Index(part, "=")
|
||||
if paddingIndex != len(part)-1 && paddingIndex != len(part)-2 {
|
||||
err := fmt.Errorf("%s token has invalid base64url padding in part %d", tokenType, i)
|
||||
return err
|
||||
}
|
||||
// Check that after padding, no other characters exist
|
||||
for j := paddingIndex; j < len(part); j++ {
|
||||
if part[j] != '=' {
|
||||
err := fmt.Errorf("%s token has characters after padding in part %d", tokenType, i)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Additional length checks for JWT parts
|
||||
if len(parts[0]) < 10 { // Header too short
|
||||
err := fmt.Errorf("%s token header too short", tokenType)
|
||||
return err
|
||||
}
|
||||
if len(parts[1]) < 10 { // Payload too short
|
||||
err := fmt.Errorf("%s token payload too short", tokenType)
|
||||
return err
|
||||
}
|
||||
if len(parts[2]) < 10 { // Signature too short
|
||||
err := fmt.Errorf("%s token signature too short", tokenType)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpaqueToken performs validation for opaque (non-JWT) tokens
|
||||
func (cm *ChunkManager) validateOpaqueToken(token string, tokenType string) error {
|
||||
// Check for obviously invalid characters for opaque tokens
|
||||
if strings.Contains(token, " ") {
|
||||
err := fmt.Errorf("%s opaque token contains spaces", tokenType)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for control characters
|
||||
for _, char := range token {
|
||||
if char < 32 || char == 127 {
|
||||
err := fmt.Errorf("%s opaque token contains control characters", tokenType)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure minimum entropy for opaque tokens (basic check)
|
||||
if len(token) >= 20 {
|
||||
uniqueChars := make(map[rune]bool)
|
||||
for _, char := range token {
|
||||
uniqueChars[char] = true
|
||||
}
|
||||
// Require at least 8 unique characters for reasonable entropy
|
||||
if len(uniqueChars) < 8 {
|
||||
err := fmt.Errorf("%s opaque token has insufficient entropy", tokenType)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTokenSize performs comprehensive token size validation
|
||||
func (cm *ChunkManager) validateTokenSize(token string, config TokenConfig) error {
|
||||
tokenLen := len(token)
|
||||
|
||||
// Basic length validation
|
||||
if tokenLen < config.MinLength {
|
||||
err := fmt.Errorf("%s token below minimum length (%d bytes, min: %d)",
|
||||
config.Type, tokenLen, config.MinLength)
|
||||
return err
|
||||
}
|
||||
|
||||
if tokenLen > config.MaxLength {
|
||||
err := fmt.Errorf("%s token exceeds maximum length (%d bytes, max: %d)",
|
||||
config.Type, tokenLen, config.MaxLength)
|
||||
return err
|
||||
}
|
||||
|
||||
// JWT-specific size validation
|
||||
if config.RequireJWTFormat || (config.AllowOpaqueTokens && strings.Contains(token, ".")) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) == 3 {
|
||||
// Validate individual JWT part sizes
|
||||
headerLen := len(parts[0])
|
||||
payloadLen := len(parts[1])
|
||||
signatureLen := len(parts[2])
|
||||
|
||||
// Check for unreasonably large JWT parts (potential security issue)
|
||||
if headerLen > 5*1024 { // 5KB header limit
|
||||
err := fmt.Errorf("%s token header too large (%d bytes)", config.Type, headerLen)
|
||||
return err
|
||||
}
|
||||
|
||||
if payloadLen > config.MaxLength-10*1024 { // Leave room for header and signature
|
||||
err := fmt.Errorf("%s token payload too large (%d bytes)", config.Type, payloadLen)
|
||||
return err
|
||||
}
|
||||
|
||||
if signatureLen > 2*1024 { // 2KB signature limit
|
||||
err := fmt.Errorf("%s token signature too large (%d bytes)", config.Type, signatureLen)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Opaque token size validation
|
||||
if config.AllowOpaqueTokens && !strings.Contains(token, ".") {
|
||||
// For opaque tokens, check for reasonable size limits
|
||||
if tokenLen > 8*1024 { // 8KB limit for opaque tokens
|
||||
err := fmt.Errorf("%s opaque token unusually large (%d bytes)", config.Type, tokenLen)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateChunkingEfficiency ensures that chunking is used appropriately
|
||||
func (cm *ChunkManager) validateChunkingEfficiency(token string, config TokenConfig) error {
|
||||
tokenLen := len(token)
|
||||
|
||||
// If token is small enough to fit in a single chunk, warn about unnecessary chunking
|
||||
if tokenLen <= config.MaxChunkSize && tokenLen <= maxCookieSize {
|
||||
// This is just informational - not an error, but helps with monitoring
|
||||
// Token could fit in single chunk - this is fine, just informational
|
||||
}
|
||||
|
||||
// Calculate expected number of chunks
|
||||
expectedChunks := (tokenLen + config.MaxChunkSize - 1) / config.MaxChunkSize
|
||||
if expectedChunks > config.MaxChunks {
|
||||
err := fmt.Errorf("%s token would require %d chunks (max: %d)",
|
||||
config.Type, expectedChunks, config.MaxChunks)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for potential storage efficiency issues
|
||||
if expectedChunks > 10 && tokenLen < 50*1024 {
|
||||
cm.logger.Info("%s token requires many chunks (%d) for size (%d bytes) - consider token optimization",
|
||||
config.Type, expectedChunks, tokenLen)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTokenContent performs comprehensive token content validation
|
||||
func (cm *ChunkManager) validateTokenContent(token string, config TokenConfig) error {
|
||||
// Basic content sanitization checks
|
||||
if err := cm.validateTokenSanitization(token, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// JWT-specific content validation
|
||||
if config.RequireJWTFormat || (config.AllowOpaqueTokens && strings.Contains(token, ".")) {
|
||||
if err := cm.validateJWTContent(token, config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Opaque token content validation
|
||||
if config.AllowOpaqueTokens && !strings.Contains(token, ".") {
|
||||
if err := cm.validateOpaqueTokenContent(token, config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTokenSanitization checks for basic security issues in token content
|
||||
func (cm *ChunkManager) validateTokenSanitization(token string, config TokenConfig) error {
|
||||
// Check for null bytes (potential injection attacks)
|
||||
if strings.Contains(token, "\x00") {
|
||||
err := fmt.Errorf("%s token contains null bytes", config.Type)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for line feed/carriage return (header injection attacks)
|
||||
if strings.ContainsAny(token, "\r\n") {
|
||||
err := fmt.Errorf("%s token contains line breaks", config.Type)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for suspicious escape sequences
|
||||
suspiciousPatterns := []string{
|
||||
"\\x", "\\u", "\\n", "\\r", "\\t", "\\0",
|
||||
"<script", "</script", "javascript:", "data:",
|
||||
"file://", "ftp://", "ldap://",
|
||||
}
|
||||
|
||||
tokenLower := strings.ToLower(token)
|
||||
for _, pattern := range suspiciousPatterns {
|
||||
if strings.Contains(tokenLower, pattern) {
|
||||
err := fmt.Errorf("%s token contains suspicious pattern: %s", config.Type, pattern)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check for excessive repeated characters (potential buffer overflow attempts)
|
||||
if err := cm.detectRepeatedCharacters(token, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateJWTContent performs JWT-specific content validation
|
||||
func (cm *ChunkManager) validateJWTContent(token string, config TokenConfig) error {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
err := fmt.Errorf("%s JWT token malformed for content validation", config.Type)
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate header content
|
||||
if err := cm.validateJWTHeader(parts[0], config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate payload content
|
||||
if err := cm.validateJWTPayload(parts[1], config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate signature content
|
||||
if err := cm.validateJWTSignature(parts[2], config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateJWTHeader validates JWT header content
|
||||
func (cm *ChunkManager) validateJWTHeader(header string, config TokenConfig) error {
|
||||
// Basic header structure validation
|
||||
if len(header) == 0 {
|
||||
err := fmt.Errorf("%s JWT header is empty", config.Type)
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate base64url encoding
|
||||
if _, err := base64.RawURLEncoding.DecodeString(header); err != nil {
|
||||
err := fmt.Errorf("%s JWT header not valid base64url", config.Type)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateJWTPayload validates JWT payload content
|
||||
func (cm *ChunkManager) validateJWTPayload(payload string, config TokenConfig) error {
|
||||
// Basic payload structure validation
|
||||
if len(payload) == 0 {
|
||||
err := fmt.Errorf("%s JWT payload is empty", config.Type)
|
||||
return err
|
||||
}
|
||||
|
||||
// Payload should be decodable (basic structural check)
|
||||
if _, err := base64.RawURLEncoding.DecodeString(payload); err != nil {
|
||||
err := fmt.Errorf("%s JWT payload not valid base64url", config.Type)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateJWTSignature validates JWT signature content
|
||||
func (cm *ChunkManager) validateJWTSignature(signature string, config TokenConfig) error {
|
||||
// Basic signature structure validation
|
||||
if len(signature) == 0 {
|
||||
err := fmt.Errorf("%s JWT signature is empty", config.Type)
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate base64url encoding
|
||||
if _, err := base64.RawURLEncoding.DecodeString(signature); err != nil {
|
||||
err := fmt.Errorf("%s JWT signature not valid base64url", config.Type)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpaqueTokenContent validates opaque token content
|
||||
func (cm *ChunkManager) validateOpaqueTokenContent(token string, config TokenConfig) error {
|
||||
// Check for reasonable character distribution in opaque tokens
|
||||
if len(token) >= 10 {
|
||||
alphabetic := 0
|
||||
numeric := 0
|
||||
special := 0
|
||||
|
||||
for _, char := range token {
|
||||
if (char >= 'A' && char <= 'Z') || (char >= 'a' && char <= 'z') {
|
||||
alphabetic++
|
||||
} else if char >= '0' && char <= '9' {
|
||||
numeric++
|
||||
} else {
|
||||
special++
|
||||
}
|
||||
}
|
||||
|
||||
total := alphabetic + numeric + special
|
||||
if total > 0 {
|
||||
// Require some distribution of character types for legitimate tokens
|
||||
alphaRatio := float64(alphabetic) / float64(total)
|
||||
numericRatio := float64(numeric) / float64(total)
|
||||
|
||||
// Opaque tokens should have reasonable character distribution
|
||||
if alphaRatio < 0.1 && numericRatio < 0.1 {
|
||||
err := fmt.Errorf("%s opaque token has suspicious character distribution", config.Type)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for common token prefixes/suffixes that might indicate legitimate tokens
|
||||
legitimatePrefixes := []string{
|
||||
"Bearer ", "bearer ", "eyJ", // JWT prefix
|
||||
"refresh_", "access_", "id_",
|
||||
"token_", "oauth_", "oidc_",
|
||||
}
|
||||
|
||||
hasLegitimatePrefix := false
|
||||
for _, prefix := range legitimatePrefixes {
|
||||
if strings.HasPrefix(token, prefix) {
|
||||
hasLegitimatePrefix = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// For longer tokens without legitimate prefixes, be more suspicious
|
||||
if len(token) > 50 && !hasLegitimatePrefix {
|
||||
// Opaque token without common prefixes - this is fine
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectRepeatedCharacters detects potential buffer overflow attempts
|
||||
func (cm *ChunkManager) detectRepeatedCharacters(token string, config TokenConfig) error {
|
||||
if len(token) < 10 {
|
||||
return nil // Too short to analyze meaningfully
|
||||
}
|
||||
|
||||
// Count consecutive repeated characters
|
||||
maxRepeated := 0
|
||||
currentRepeated := 1
|
||||
var lastChar rune
|
||||
|
||||
for i, char := range token {
|
||||
if i > 0 && char == lastChar {
|
||||
currentRepeated++
|
||||
if currentRepeated > maxRepeated {
|
||||
maxRepeated = currentRepeated
|
||||
}
|
||||
} else {
|
||||
currentRepeated = 1
|
||||
}
|
||||
lastChar = char
|
||||
}
|
||||
|
||||
// Flag tokens with excessive character repetition
|
||||
threshold := 20 // Allow up to 20 consecutive identical characters
|
||||
if maxRepeated > threshold {
|
||||
err := fmt.Errorf("%s token has excessive repeated characters (%d consecutive)",
|
||||
config.Type, maxRepeated)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for overall character frequency (detect padding attacks)
|
||||
charFreq := make(map[rune]int)
|
||||
for _, char := range token {
|
||||
charFreq[char]++
|
||||
}
|
||||
|
||||
tokenLen := len(token)
|
||||
for char, count := range charFreq {
|
||||
frequency := float64(count) / float64(tokenLen)
|
||||
|
||||
// Flag if any single character makes up more than 70% of the token
|
||||
if frequency > 0.7 && tokenLen > 20 {
|
||||
err := fmt.Errorf("%s token has suspicious character frequency (char '%c': %.1f%%)",
|
||||
config.Type, char, frequency*100)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTokenExpiration validates token expiration during storage/retrieval
|
||||
func (cm *ChunkManager) validateTokenExpiration(token string, config TokenConfig) error {
|
||||
// Only validate expiration for JWT tokens
|
||||
if !strings.Contains(token, ".") {
|
||||
return nil // Opaque tokens don't have embedded expiration
|
||||
}
|
||||
|
||||
// Parse JWT expiration claim
|
||||
expiration, err := cm.extractJWTExpiration(token)
|
||||
if err != nil {
|
||||
// If we can't parse expiration, log it but don't fail - the token might be valid but malformed
|
||||
cm.logger.Debugf("Could not extract expiration from %s token: %v", config.Type, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if expiration != nil && time.Now().After(*expiration) {
|
||||
err := fmt.Errorf("%s token is expired (expired at: %v)", config.Type, expiration.Format(time.RFC3339))
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if token expires too far in the future (potential security issue)
|
||||
if expiration != nil {
|
||||
maxFutureTime := time.Now().Add(10 * 365 * 24 * time.Hour) // 10 years
|
||||
if expiration.After(maxFutureTime) {
|
||||
cm.logger.Info("%s token expires very far in future (%v) - potential security issue",
|
||||
config.Type, expiration.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractJWTExpiration extracts the expiration time from a JWT token
|
||||
func (cm *ChunkManager) extractJWTExpiration(token string) (*time.Time, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON payload
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
// Extract expiration claim
|
||||
exp, exists := claims["exp"]
|
||||
if !exists {
|
||||
return nil, nil // No expiration claim
|
||||
}
|
||||
|
||||
// Convert expiration to time.Time
|
||||
var expTime time.Time
|
||||
switch v := exp.(type) {
|
||||
case float64:
|
||||
expTime = time.Unix(int64(v), 0)
|
||||
case int64:
|
||||
expTime = time.Unix(v, 0)
|
||||
case int:
|
||||
expTime = time.Unix(int64(v), 0)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid expiration format: %T", exp)
|
||||
}
|
||||
|
||||
return &expTime, nil
|
||||
}
|
||||
|
||||
// validateTokenFreshness checks if token is fresh enough for storage
|
||||
func (cm *ChunkManager) validateTokenFreshness(token string, config TokenConfig) error {
|
||||
// Only validate freshness for JWT tokens
|
||||
if !strings.Contains(token, ".") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract issued at time
|
||||
issuedAt, err := cm.extractJWTIssuedAt(token)
|
||||
if err != nil {
|
||||
cm.logger.Debugf("Could not extract issued time from %s token: %v", config.Type, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if issuedAt != nil {
|
||||
now := time.Now()
|
||||
|
||||
// Check if token was issued in the future (clock skew tolerance: 5 minutes)
|
||||
if issuedAt.After(now.Add(5 * time.Minute)) {
|
||||
err := fmt.Errorf("%s token issued in future (issued at: %v)",
|
||||
config.Type, issuedAt.Format(time.RFC3339))
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if token is too old (potential replay attack)
|
||||
maxAge := 24 * time.Hour // Tokens older than 24 hours are suspicious
|
||||
if now.Sub(*issuedAt) > maxAge {
|
||||
cm.logger.Info("%s token is quite old (issued: %v) - potential replay",
|
||||
config.Type, issuedAt.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractJWTIssuedAt extracts the issued at time from a JWT token
|
||||
func (cm *ChunkManager) extractJWTIssuedAt(token string) (*time.Time, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON payload
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
// Extract issued at claim
|
||||
iat, exists := claims["iat"]
|
||||
if !exists {
|
||||
return nil, nil // No issued at claim
|
||||
}
|
||||
|
||||
// Convert issued at to time.Time
|
||||
var iatTime time.Time
|
||||
switch v := iat.(type) {
|
||||
case float64:
|
||||
iatTime = time.Unix(int64(v), 0)
|
||||
case int64:
|
||||
iatTime = time.Unix(v, 0)
|
||||
case int:
|
||||
iatTime = time.Unix(int64(v), 0)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid issued at format: %T", iat)
|
||||
}
|
||||
|
||||
return &iatTime, nil
|
||||
}
|
||||
+6
-6
@@ -148,7 +148,7 @@ func getPooledObjects(sm *SessionManager) int {
|
||||
var objects []*SessionData
|
||||
maxAttempts := 100 // Safety limit to prevent infinite loops
|
||||
|
||||
for range maxAttempts {
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
obj := sm.sessionPool.Get()
|
||||
if obj == nil {
|
||||
break
|
||||
@@ -195,7 +195,7 @@ func TestSessionObjectTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create and discard 5 sessions
|
||||
for range 5 {
|
||||
for i := 0; i < 5; i++ {
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession failed: %v", err)
|
||||
@@ -549,7 +549,7 @@ func TestTokenSizeLimits(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Large but acceptable token",
|
||||
tokenSize: 30000, // FIXED: 30KB to ensure final size < 100KB limit
|
||||
tokenSize: 20000, // 20KB to ensure it fits within chunk limits (≤25 chunks)
|
||||
expectStored: true,
|
||||
},
|
||||
{
|
||||
@@ -609,11 +609,11 @@ func TestConcurrentTokenOperations(t *testing.T) {
|
||||
// Test concurrent access and refresh token operations
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := range numGoroutines {
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := range numOperations {
|
||||
for j := 0; j < numOperations; j++ {
|
||||
// Create unique tokens for each goroutine/operation
|
||||
accessToken := ValidAccessToken
|
||||
refreshToken := fmt.Sprintf("refresh_token_%d_%d", id, j)
|
||||
@@ -637,7 +637,7 @@ func TestConcurrentTokenOperations(t *testing.T) {
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for range numGoroutines {
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
+6
-6
@@ -414,7 +414,7 @@ func NewLogger(logLevel string) *Logger {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Info(format string, args ...any) {
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -424,7 +424,7 @@ func (l *Logger) Info(format string, args ...any) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debug(format string, args ...any) {
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -434,7 +434,7 @@ func (l *Logger) Debug(format string, args ...any) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Error(format string, args ...any) {
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -445,7 +445,7 @@ func (l *Logger) Error(format string, args ...any) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Infof(format string, args ...any) {
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -456,7 +456,7 @@ func (l *Logger) Infof(format string, args ...any) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debugf(format string, args ...any) {
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
@@ -467,7 +467,7 @@ func (l *Logger) Debugf(format string, args ...any) {
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Errorf(format string, args ...any) {
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,15 +11,15 @@ func TestTemplateExecution(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]any
|
||||
data map[string]interface{}
|
||||
expectedValue string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "String Claim",
|
||||
templateText: "{{.Claims.email}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
@@ -29,8 +29,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Number Claim",
|
||||
templateText: "{{.Claims.age}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"age": 30,
|
||||
},
|
||||
},
|
||||
@@ -40,8 +40,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Boolean Claim",
|
||||
templateText: "{{.Claims.admin}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"admin": true,
|
||||
},
|
||||
},
|
||||
@@ -51,8 +51,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Array Claim",
|
||||
templateText: "{{index .Claims.roles 0}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"roles": []string{"admin", "user"},
|
||||
},
|
||||
},
|
||||
@@ -62,9 +62,9 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Nested Object Claim",
|
||||
templateText: "{{.Claims.user.name}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"user": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"user": map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
},
|
||||
},
|
||||
@@ -75,7 +75,7 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Access Token",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
},
|
||||
expectedValue: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
@@ -83,9 +83,9 @@ func TestTemplateExecution(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "ID Token",
|
||||
templateText: "{{.IdToken}}",
|
||||
data: map[string]any{
|
||||
"IdToken": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
templateText: "{{.IDToken}}",
|
||||
data: map[string]interface{}{
|
||||
"IDToken": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
},
|
||||
expectedValue: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
expectError: false,
|
||||
@@ -93,7 +93,7 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Refresh Token",
|
||||
templateText: "{{.RefreshToken}}",
|
||||
data: map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"RefreshToken": "refresh-token-value",
|
||||
},
|
||||
expectedValue: "refresh-token-value",
|
||||
@@ -102,8 +102,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Conditional Template",
|
||||
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"admin": true,
|
||||
},
|
||||
},
|
||||
@@ -113,8 +113,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Multiple Claims",
|
||||
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
"email": "john.doe@example.com",
|
||||
@@ -126,8 +126,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Missing Claim",
|
||||
templateText: "{{.Claims.missing}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{},
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{},
|
||||
},
|
||||
expectedValue: "<no value>",
|
||||
expectError: false, // Go templates don't error on missing values
|
||||
@@ -135,8 +135,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Invalid Template Syntax",
|
||||
templateText: "{{.Claims.email",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
@@ -146,8 +146,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Custom Claims",
|
||||
templateText: "Role: {{.Claims.role}}, Department: {{.Claims.department}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"role": "admin",
|
||||
"department": "engineering",
|
||||
@@ -159,10 +159,10 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Nested Custom Claims",
|
||||
templateText: "Org: {{.Claims.metadata.organization}}, Team: {{.Claims.metadata.team}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"metadata": map[string]any{
|
||||
"metadata": map[string]interface{}{
|
||||
"organization": "company-name",
|
||||
"team": "platform",
|
||||
},
|
||||
@@ -174,8 +174,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "Email Claims",
|
||||
templateText: "Email: {{.Claims.email}}, Verified: {{.Claims.email_verified}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
},
|
||||
@@ -186,8 +186,8 @@ func TestTemplateExecution(t *testing.T) {
|
||||
{
|
||||
name: "User Identity Claims",
|
||||
templateText: "Name: {{.Claims.name}}, Subject: {{.Claims.sub}}, Username: {{.Claims.preferred_username}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
"sub": "user123",
|
||||
"preferred_username": "johndoe",
|
||||
@@ -233,16 +233,16 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
mapTests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]any
|
||||
data map[string]interface{}
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token distinction with map",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
data: map[string]any{
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{},
|
||||
"IDToken": "id-token-value",
|
||||
"Claims": map[string]interface{}{},
|
||||
"RefreshToken": "refresh-token-value",
|
||||
},
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
@@ -250,10 +250,10 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
{
|
||||
name: "Combining tokens and claims with map",
|
||||
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token",
|
||||
"IdToken": "id-token",
|
||||
"Claims": map[string]any{
|
||||
"IDToken": "id-token",
|
||||
"Claims": map[string]interface{}{
|
||||
"sub": "user123",
|
||||
},
|
||||
"RefreshToken": "refresh-token",
|
||||
@@ -263,17 +263,17 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
{
|
||||
name: "Authorization header with Bearer token",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "jwt-access-token",
|
||||
"IdToken": "id-token",
|
||||
"Claims": map[string]any{},
|
||||
"IDToken": "id-token",
|
||||
"Claims": map[string]interface{}{},
|
||||
},
|
||||
expectedValue: "Bearer jwt-access-token",
|
||||
},
|
||||
{
|
||||
name: "Boolean template data with AccessToken",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": true, // Test boolean values to ensure they render correctly
|
||||
},
|
||||
expectedValue: "Bearer true",
|
||||
@@ -281,10 +281,10 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
{
|
||||
name: "Custom non-standard claims in ID token",
|
||||
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
|
||||
data: map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{
|
||||
"IDToken": "id-token-value",
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"role": "admin",
|
||||
"permissions": "read:all,write:own",
|
||||
@@ -295,11 +295,11 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
{
|
||||
name: "Deeply nested custom claims",
|
||||
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
|
||||
data: map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"Claims": map[string]any{
|
||||
"app_metadata": map[string]any{
|
||||
"organization": map[string]any{
|
||||
"Claims": map[string]interface{}{
|
||||
"app_metadata": map[string]interface{}{
|
||||
"organization": map[string]interface{}{
|
||||
"name": "acme-corp",
|
||||
"id": "org-123",
|
||||
},
|
||||
@@ -312,10 +312,10 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
{
|
||||
name: "Email in claims",
|
||||
templateText: "X-User-Email: {{.Claims.email}}, X-Email-Verified: {{.Claims.email_verified}}",
|
||||
data: map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{
|
||||
"IDToken": "id-token-value",
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
},
|
||||
@@ -325,10 +325,10 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
{
|
||||
name: "User info from claims",
|
||||
templateText: "X-User-ID: {{.Claims.sub}}, X-User-Name: {{.Claims.name}}, X-Username: {{.Claims.preferred_username}}",
|
||||
data: map[string]any{
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{
|
||||
"IDToken": "id-token-value",
|
||||
"Claims": map[string]interface{}{
|
||||
"sub": "user123456",
|
||||
"name": "Jane Doe",
|
||||
"preferred_username": "jane.doe",
|
||||
@@ -361,9 +361,9 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
|
||||
// For backward compatibility, also test the original struct-based implementation
|
||||
type templateData struct {
|
||||
Claims map[string]any
|
||||
Claims map[string]interface{}
|
||||
AccessToken string
|
||||
IdToken string
|
||||
IDToken string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
@@ -376,11 +376,11 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token distinction with struct",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token-value",
|
||||
IdToken: "id-token-value", // Now these should be distinct values
|
||||
Claims: map[string]any{},
|
||||
IDToken: "id-token-value", // Now these should be distinct values
|
||||
Claims: map[string]interface{}{},
|
||||
},
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
},
|
||||
@@ -389,8 +389,8 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "access-token",
|
||||
Claims: map[string]any{
|
||||
IDToken: "access-token",
|
||||
Claims: map[string]interface{}{
|
||||
"sub": "user123",
|
||||
},
|
||||
},
|
||||
@@ -401,8 +401,8 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
templateText: "X-Custom: {{.Claims.custom_field}}, X-Group: {{.Claims.group}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "id-token",
|
||||
Claims: map[string]any{
|
||||
IDToken: "id-token",
|
||||
Claims: map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"custom_field": "custom-value",
|
||||
"group": "admins",
|
||||
@@ -415,8 +415,8 @@ func TestTemplateExecutionContext(t *testing.T) {
|
||||
templateText: "X-Email: {{.Claims.email}}, X-Name: {{.Claims.name}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "id-token",
|
||||
Claims: map[string]any{
|
||||
IDToken: "id-token",
|
||||
Claims: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"name": "John Smith",
|
||||
},
|
||||
@@ -454,14 +454,14 @@ func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
dataContext any
|
||||
dataContext interface{}
|
||||
expectedValue string
|
||||
expectError bool // Added to skip the test that demonstrates the error
|
||||
}{
|
||||
{
|
||||
name: "Map with boolean as root",
|
||||
templateText: "{{.AccessToken}}",
|
||||
dataContext: map[string]any{"AccessToken": "token-value"},
|
||||
dataContext: map[string]interface{}{"AccessToken": "token-value"},
|
||||
expectedValue: "token-value",
|
||||
expectError: false,
|
||||
},
|
||||
@@ -475,17 +475,17 @@ func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
{
|
||||
name: "Bearer with map context",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
dataContext: map[string]any{"AccessToken": "token-value"},
|
||||
dataContext: map[string]interface{}{"AccessToken": "token-value"},
|
||||
expectedValue: "Bearer token-value",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Complex nesting with authorization",
|
||||
templateText: "Authorization: Bearer {{.AccessToken}}",
|
||||
dataContext: map[string]any{
|
||||
dataContext: map[string]interface{}{
|
||||
"AccessToken": "jwt-token-123",
|
||||
"something": true,
|
||||
"anotherField": map[string]any{
|
||||
"anotherField": map[string]interface{}{
|
||||
"nested": "value",
|
||||
},
|
||||
},
|
||||
@@ -495,13 +495,13 @@ func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
{
|
||||
name: "Custom claims access",
|
||||
templateText: "X-User-Role: {{.Claims.role}}, X-User-Groups: {{.Claims.groups}}",
|
||||
dataContext: map[string]any{
|
||||
dataContext: map[string]interface{}{
|
||||
"AccessToken": "jwt-token-xyz",
|
||||
"Claims": map[string]any{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"role": "admin",
|
||||
"groups": "group1,group2,group3",
|
||||
"custom_data": map[string]any{
|
||||
"custom_data": map[string]interface{}{
|
||||
"organization": "company-name",
|
||||
"department": "engineering",
|
||||
},
|
||||
@@ -513,9 +513,9 @@ func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
{
|
||||
name: "Nested custom claims access",
|
||||
templateText: "X-Organization: {{.Claims.custom_data.organization}}, X-Department: {{.Claims.custom_data.department}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"custom_data": map[string]any{
|
||||
dataContext: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"custom_data": map[string]interface{}{
|
||||
"organization": "company-name",
|
||||
"department": "engineering",
|
||||
},
|
||||
@@ -527,8 +527,8 @@ func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
{
|
||||
name: "Azure AD specific claims",
|
||||
templateText: "X-TenantID: {{.Claims.tid}}, X-Roles: {{.Claims.roles}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
dataContext: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"tid": "tenant-id-12345",
|
||||
"roles": "User,Admin,Developer",
|
||||
},
|
||||
@@ -539,10 +539,10 @@ func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
{
|
||||
name: "Auth0 specific claims",
|
||||
templateText: "X-Permissions: {{.Claims.permissions}}, X-AppMetadata: {{.Claims.app_metadata.plan}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
dataContext: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"permissions": "read:products,write:orders",
|
||||
"app_metadata": map[string]any{
|
||||
"app_metadata": map[string]interface{}{
|
||||
"plan": "premium",
|
||||
"status": "active",
|
||||
"trial_ended": false,
|
||||
@@ -555,8 +555,8 @@ func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
{
|
||||
name: "Standard claims with email",
|
||||
templateText: "X-Email: {{.Claims.email}}, X-Name: {{.Claims.name}}, X-Subject: {{.Claims.sub}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
dataContext: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"name": "John Doe",
|
||||
"sub": "auth0|12345",
|
||||
@@ -568,8 +568,8 @@ func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
{
|
||||
name: "Verified email claim",
|
||||
templateText: "X-Email: {{.Claims.email}}, X-Email-Verified: {{.Claims.email_verified}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
dataContext: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
},
|
||||
|
||||
@@ -2,7 +2,6 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -21,7 +20,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
sessionSetup func(*SessionData)
|
||||
claims map[string]any
|
||||
claims map[string]interface{}
|
||||
expectedHeaders map[string]string
|
||||
interceptedHeaders map[string]string
|
||||
name string
|
||||
@@ -32,7 +31,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -46,7 +45,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
{Name: "X-User-Name", Value: "{{.Claims.given_name}} {{.Claims.family_name}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"sub": "user123",
|
||||
"given_name": "John",
|
||||
@@ -71,7 +70,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
{
|
||||
name: "ID Token Header",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
|
||||
{Name: "X-ID-Token", Value: "{{.IDToken}}"},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
// We'll update this dynamically after generating the token
|
||||
@@ -82,7 +81,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
name: "Both Token Types",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Access-Token", Value: "{{.AccessToken}}"},
|
||||
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
|
||||
{Name: "X-ID-Token", Value: "{{.IDToken}}"},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
// We'll update these dynamically after generating the tokens
|
||||
@@ -95,7 +94,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Role", Value: "{{.Claims.role}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
// role claim is missing
|
||||
},
|
||||
@@ -108,7 +107,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Admin", Value: "{{if .Claims.is_admin}}true{{else}}false{{end}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"email": "admin@example.com",
|
||||
"is_admin": true,
|
||||
},
|
||||
@@ -121,7 +120,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Auth-Info", Value: "User={{.Claims.email}}, Token={{.AccessToken}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
claims: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -134,7 +133,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-AccessToken", Value: "{{.AccessToken}}"},
|
||||
},
|
||||
claims: map[string]any{ // For ID Token
|
||||
claims: map[string]interface{}{ // For ID Token
|
||||
"email": "opaque_user@example.com",
|
||||
"sub": "opaque_sub_for_id_token",
|
||||
},
|
||||
@@ -150,7 +149,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
token := ts.token
|
||||
if len(tc.claims) > 0 {
|
||||
var err error
|
||||
baseClaims := map[string]any{
|
||||
baseClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000), // Far future timestamp
|
||||
@@ -162,7 +161,9 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add the test-specific claims
|
||||
maps.Copy(baseClaims, tc.claims)
|
||||
for k, v := range tc.claims {
|
||||
baseClaims[k] = v
|
||||
}
|
||||
|
||||
token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", baseClaims)
|
||||
if err != nil {
|
||||
@@ -266,7 +267,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
|
||||
if tc.name == "ID Token Header" || tc.name == "Both Token Types" {
|
||||
idTokenClaims := map[string]any{
|
||||
idTokenClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": float64(3000000000),
|
||||
"iat": float64(1000000000), "nbf": float64(1000000000), "sub": "test-subject",
|
||||
"nonce": "test-nonce", "jti": generateRandomString(16), "type": "id_token",
|
||||
@@ -284,7 +285,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
t.Fatalf("Failed to create test ID JWT: %v", idErr)
|
||||
}
|
||||
|
||||
accessTokenClaims := map[string]any{
|
||||
accessTokenClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": float64(3000000000),
|
||||
"iat": float64(1000000000), "nbf": float64(1000000000), "sub": "test-subject",
|
||||
"jti": generateRandomString(16), "type": "access_token", "scope": "openid email profile",
|
||||
@@ -315,13 +316,15 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
tc.expectedHeaders["X-Access-Token"] = accessTokenForSession
|
||||
}
|
||||
} else if tc.name == "Opaque Access Token with AccessTokenField" {
|
||||
idTokenClaims := map[string]any{
|
||||
idTokenClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": float64(3000000000),
|
||||
"iat": float64(1000000000), "nbf": float64(1000000000), "sub": "test-subject", // Default sub
|
||||
"nonce": "test-nonce", "jti": generateRandomString(16), "type": "id_token",
|
||||
}
|
||||
// Populate ID token claims from tc.claims
|
||||
maps.Copy(idTokenClaims, tc.claims)
|
||||
for k, v := range tc.claims {
|
||||
idTokenClaims[k] = v
|
||||
}
|
||||
// Ensure email from tc.claims is used for the ID token
|
||||
session.SetEmail(tc.claims["email"].(string)) // Also set it directly for initial session state
|
||||
|
||||
@@ -386,6 +389,7 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
// The current test expects the literal string "<no value>".
|
||||
// Let's assume for now that if it's missing, it's an error unless specifically handled.
|
||||
// The test as written expects "<no value>" to be present.
|
||||
t.Logf("Header %s not set, but expected '<no value>' for missing claim", name)
|
||||
}
|
||||
t.Errorf("Expected header %s was not set", name)
|
||||
|
||||
@@ -423,7 +427,7 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
claims map[string]interface{}
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
shouldExecuteCheck bool
|
||||
@@ -444,8 +448,8 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Roles", Value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
"roles": []any{"admin", "user", "manager"},
|
||||
claims: map[string]interface{}{
|
||||
"roles": []interface{}{"admin", "user", "manager"},
|
||||
},
|
||||
shouldExecuteCheck: true,
|
||||
},
|
||||
@@ -454,7 +458,7 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create token with the test claims
|
||||
claims := map[string]any{
|
||||
claims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000), // Far future timestamp
|
||||
@@ -466,7 +470,9 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add the test-specific claims
|
||||
maps.Copy(claims, tc.claims)
|
||||
for k, v := range tc.claims {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
@@ -571,7 +577,8 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
// createLargeTemplate creates a template with many variable references
|
||||
func createLargeTemplate(size int) string {
|
||||
template := "{{with .Claims}}"
|
||||
for i := range size {
|
||||
for i := 0; i < size; i++ {
|
||||
|
||||
if i > 0 {
|
||||
template += ","
|
||||
}
|
||||
@@ -582,9 +589,9 @@ func createLargeTemplate(size int) string {
|
||||
}
|
||||
|
||||
// createLargeClaims creates a map with many claims for testing large templates
|
||||
func createLargeClaims(size int) map[string]any {
|
||||
claims := make(map[string]any)
|
||||
for i := range size {
|
||||
func createLargeClaims(size int) map[string]interface{} {
|
||||
claims := make(map[string]interface{})
|
||||
for i := 0; i < size; i++ {
|
||||
claims["email"] = "largeclaimsuser@example.com" // Add email claim
|
||||
key := "field" + string(rune('a'+i%26)) + string(rune('0'+i%10))
|
||||
claims[key] = "value" + string(rune('a'+i%26)) + string(rune('0'+i%10))
|
||||
|
||||
+47
-26
@@ -20,16 +20,16 @@ func NewTestTokens() *TestTokens {
|
||||
// Valid JWT tokens for testing
|
||||
const (
|
||||
// ValidAccessToken - A properly formatted JWT access token for testing
|
||||
ValidAccessToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTc1MDI5NDYyOCwiaWF0IjoxNzUwMjkxMDI4LCJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJqdGkiOiJlNDcxN2RhZDBmZjAyOTNkIiwibmJmIjoxNzUwMjkxMDI4LCJub25jZSI6Im5vbmNlMTIzIiwic3ViIjoidGVzdC1zdWJqZWN0In0.bmwp-vk0B7Ir9UiUkzib8L7yJbebJ00o3U9QrB6gP2H9-RfqyCbN8M9Rkx7Rb8Vdh3YzqkBBoLS_G0i414rs2I9uABnTC4E6-63qkGdUrLB7p-XbjcRW2RoIBwXHk7lfumi8eX0uWzBsJ9CY0__UECVsex5XORfBb4Bcqj0LK4y-glxkpI51I7BPySfciWC_PkdaQ1Qe5pCAlxeNs2E9NMGXp-Ox6vAufUzoC2cws1LswGPPP6icQ-Zlzd5WMCIWhdIkN4yTxk8FMqsTC52k2zskRHNSSd4DDVETonfzawZNqDcMpnTyN53sCJ9UHiQTl9mCm61ttYW-W9Gc-ze4Xw"
|
||||
ValidAccessToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MzAwMDAwMDAwMCwiaWF0IjoxMDAwMDAwMDAwLCJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImp0aSI6ImU0NzE3ZGFkMGZmMDI5M2QiLCJuYmYiOjEwMDAwMDAwMDAsIm5vbmNlIjoibm9uY2UxMjMiLCJzdWIiOiJ0ZXN0LXN1YmplY3QifQ.bmwp-vk0B7Ir9UiUkzib8L7yJbebJ00o3U9QrB6gP2H9-RfqyCbN8M9Rkx7Rb8Vdh3YzqkBBoLS_G0i414rs2I9uABnTC4E6-63qkGdUrLB7p-XbjcRW2RoIBwXHk7lfumi8eX0uWzBsJ9CY0__UECVsex5XORfBb4Bcqj0LK4y-glxkpI51I7BPySfciWC_PkdaQ1Qe5pCAlxeNs2E9NMGXp-Ox6vAufUzoC2cws1LswGPPP6icQ-Zlzd5WMCIWhdIkN4yTxk8FMqsTC52k2zskRHNSSd4DDVETonfzawZNqDcMpnTyN53sCJ9UHiQTl9mCm61ttYW-W9Gc-ze4Xw"
|
||||
|
||||
// ValidIDToken - A properly formatted JWT ID token for testing
|
||||
ValidIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTc1MDI5NDYyOCwiaWF0IjoxNzUwMjkxMDI4LCJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJqdGkiOiI2YzBjZTZmMTM4Y2EzMzc2IiwibmJmIjoxNzUwMjkxMDI4LCJub25jZSI6Im5vbmNlMTIzIiwic3ViIjoidGVzdC1zdWJqZWN0In0.RBQYejA9vP4lnh2EhFqWerePWaCyDTF0ZE1jlU2xm4g2wWVeaEHpv5SNg92_gwk633N9xx7ugS0UrlEu4qbT7wSb1HBDR00q_andyYnyFk4OoxPpD0AqHkVr-pjS-Z7UCGF3sLgQ4ECmU9695PIys3XvgUGMzEn_mK-PHcpY5AnbBGFsbj7epUld_sb6WfjjjwAa8kKfKObPvaIpuJ4TlxI1Uf0wYOoIA0zh5ipeAn-i8Ud-GErxis1Hp8UQK7IRolXpToiXnFcnf3vI3eCS7Yu3oPl7LRxTxKMCI9h0MCwu25ZNsOg2C9ohyebpU0jbURX9Q74GNOaphv-Lz9rCRA"
|
||||
ValidIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MzAwMDAwMDAwMCwiaWF0IjoxMDAwMDAwMDAwLCJpc3MiOiJodHRwczovL3Rlc3QtaXNzdWVyLmNvbSIsImp0aSI6IjZjMGNlNmYxMzhjYTMzNzYiLCJuYmYiOjEwMDAwMDAwMDAsIm5vbmNlIjoibm9uY2UxMjMiLCJzdWIiOiJ0ZXN0LXN1YmplY3QifQ.RBQYejA9vP4lnh2EhFqWerePWaCyDTF0ZE1jlU2xm4g2wWVeaEHpv5SNg92_gwk633N9xx7ugS0UrlEu4qbT7wSb1HBDR00q_andyYnyFk4OoxPpD0AqHkVr-pjS-Z7UCGF3sLgQ4ECmU9695PIys3XvgUGMzEn_mK-PHcpY5AnbBGFsbj7epUld_sb6WfjjjwAa8kKfKObPvaIpuJ4TlxI1Uf0wYOoIA0zh5ipeAn-i8Ud-GErxis1Hp8UQK7IRolXpToiXnFcnf3vI3eCS7Yu3oPl7LRxTxKMCI9h0MCwu25ZNsOg2C9ohyebpU0jbURX9Q74GNOaphv-Lz9rCRA"
|
||||
|
||||
// ValidRefreshToken - A properly formatted refresh token for testing
|
||||
ValidRefreshToken = "valid-refresh-token-12345"
|
||||
|
||||
// MinimalValidJWT - The shortest valid JWT for testing
|
||||
MinimalValidJWT = "h.p.s"
|
||||
// MinimalValidJWT - The shortest valid JWT for testing (actual base64url)
|
||||
MinimalValidJWT = "eyJ0eXAiOiJKV1QifQ.eyJzdWIiOiIxMjMifQ.abc123def456ghi789jkl012mno345pqr678stu901vwx234yz"
|
||||
|
||||
// ValidRefreshTokenGoogle - A Google-style refresh token for testing
|
||||
ValidRefreshTokenGoogle = "google_refresh_token_12345"
|
||||
@@ -57,14 +57,20 @@ const (
|
||||
// This replaces the ad-hoc createLargeValidJWT function in tests
|
||||
func (tt *TestTokens) CreateLargeValidJWT(targetSize int) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
signature := "signature_" + tt.generateRandomString(32)
|
||||
// Create a valid base64url signature
|
||||
signatureBytes := make([]byte, 32)
|
||||
rand.Read(signatureBytes)
|
||||
signature := base64.RawURLEncoding.EncodeToString(signatureBytes)
|
||||
|
||||
// Calculate required payload size
|
||||
usedSize := len(header) + len(signature) + 2 // account for dots
|
||||
payloadSize := max(targetSize-usedSize, 50)
|
||||
payloadSize := targetSize - usedSize
|
||||
if payloadSize < 50 {
|
||||
payloadSize = 50
|
||||
}
|
||||
|
||||
// Create a payload with realistic JWT claims
|
||||
claims := map[string]any{
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
@@ -72,12 +78,10 @@ func (tt *TestTokens) CreateLargeValidJWT(targetSize int) string {
|
||||
"iat": 1000000000,
|
||||
}
|
||||
|
||||
// FIXED: Calculate data size safely
|
||||
dataSize := max(
|
||||
// Account for other claims and base64 encoding
|
||||
payloadSize-100,
|
||||
// Minimum data size
|
||||
10)
|
||||
dataSize := payloadSize - 100 // Account for other claims and base64 encoding
|
||||
if dataSize < 10 {
|
||||
dataSize = 10 // Minimum data size
|
||||
}
|
||||
|
||||
claims["data"] = tt.generateRandomString(dataSize)
|
||||
|
||||
@@ -99,7 +103,7 @@ func (tt *TestTokens) CreateExpiredJWT() string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
|
||||
// Create claims with expired timestamp
|
||||
claims := map[string]any{
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
@@ -109,7 +113,10 @@ func (tt *TestTokens) CreateExpiredJWT() string {
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signature := "expired_signature"
|
||||
// Create a valid base64url signature
|
||||
signatureBytes := make([]byte, 16)
|
||||
rand.Read(signatureBytes)
|
||||
signature := base64.RawURLEncoding.EncodeToString(signatureBytes)
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
@@ -118,7 +125,7 @@ func (tt *TestTokens) CreateExpiredJWT() string {
|
||||
func (tt *TestTokens) CreateUniqueValidJWT(id string) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
|
||||
claims := map[string]any{
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user_" + id,
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
@@ -129,7 +136,10 @@ func (tt *TestTokens) CreateUniqueValidJWT(id string) string {
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signature := "sig_" + id
|
||||
// Create a valid base64url signature
|
||||
signatureBytes := make([]byte, 16)
|
||||
rand.Read(signatureBytes)
|
||||
signature := base64.RawURLEncoding.EncodeToString(signatureBytes)
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
@@ -138,14 +148,17 @@ func (tt *TestTokens) CreateUniqueValidJWT(id string) string {
|
||||
// This is useful for testing chunking scenarios where compression doesn't help
|
||||
func (tt *TestTokens) CreateIncompressibleToken(targetSize int) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
signature := "incompressible_signature_" + tt.generateRandomString(32)
|
||||
// Create a valid base64url signature
|
||||
signatureBytes := make([]byte, 32)
|
||||
rand.Read(signatureBytes)
|
||||
signature := base64.RawURLEncoding.EncodeToString(signatureBytes)
|
||||
|
||||
// Calculate required payload size
|
||||
usedSize := len(header) + len(signature) + 2 // account for dots
|
||||
payloadSize := max(targetSize-usedSize, 100)
|
||||
|
||||
// Generate multiple random fields to prevent compression
|
||||
randomFields := make(map[string]any)
|
||||
randomFields := make(map[string]interface{})
|
||||
randomFields["sub"] = "user123"
|
||||
randomFields["iss"] = "https://example.com"
|
||||
randomFields["aud"] = "client123"
|
||||
@@ -154,11 +167,12 @@ func (tt *TestTokens) CreateIncompressibleToken(targetSize int) string {
|
||||
|
||||
// Add many random fields with random data to prevent compression
|
||||
remainingSize := payloadSize - 200 // Account for base64 encoding and other fields
|
||||
fieldCount := max(
|
||||
// ~100 bytes per field
|
||||
remainingSize/100, 1)
|
||||
fieldCount := remainingSize / 100 // ~100 bytes per field
|
||||
if fieldCount < 1 {
|
||||
fieldCount = 1
|
||||
}
|
||||
|
||||
for i := range fieldCount {
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
// Generate truly random data for each field
|
||||
randomBytes := make([]byte, 50)
|
||||
rand.Read(randomBytes)
|
||||
@@ -232,7 +246,7 @@ func (tt *TestTokens) generateRandomString(length int) string {
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
for i := 0; i < length; i++ {
|
||||
randomByte := make([]byte, 1)
|
||||
rand.Read(randomByte)
|
||||
b[i] = charset[int(randomByte[0])%len(charset)]
|
||||
@@ -296,7 +310,7 @@ func (ts *TestScenarios) CorruptionTest() CorruptionTestSet {
|
||||
// ConcurrentTest returns unique tokens for concurrent testing
|
||||
func (ts *TestScenarios) ConcurrentTest(count int) []TokenSet {
|
||||
sets := make([]TokenSet, count)
|
||||
for i := range count {
|
||||
for i := 0; i < count; i++ {
|
||||
sets[i] = TokenSet{
|
||||
AccessToken: ts.tokens.CreateUniqueValidJWT(fmt.Sprintf("concurrent_%d", i)),
|
||||
IDToken: ts.tokens.CreateUniqueValidJWT(fmt.Sprintf("id_%d", i)),
|
||||
@@ -387,5 +401,12 @@ func AssertInvalidTokenRejection(t TestingInterface, session *SessionData, token
|
||||
|
||||
// TestingInterface provides the minimal interface needed for testing
|
||||
type TestingInterface interface {
|
||||
Errorf(format string, args ...any)
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
+26
-51
@@ -3,11 +3,9 @@ package traefikoidc
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -24,8 +22,9 @@ func TestTokenCorruptionScenario(t *testing.T) {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a valid JWT token
|
||||
validJWT := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImV4cCI6OTk5OTk5OTk5OX0.signature"
|
||||
// Create a valid JWT token with proper base64url signature
|
||||
testTokens := NewTestTokens()
|
||||
validJWT := testTokens.CreateLargeValidJWT(100) // Create a small valid token
|
||||
|
||||
tests := []struct {
|
||||
corruptionScenario func(*SessionData)
|
||||
@@ -147,7 +146,7 @@ func TestCompressionIntegrityFailure(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "Valid JWT",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig",
|
||||
token: NewTestTokens().CreateLargeValidJWT(100),
|
||||
expectSame: true,
|
||||
},
|
||||
{
|
||||
@@ -194,7 +193,8 @@ func TestChunkReassemblyEdgeCases(t *testing.T) {
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Create a large token that will definitely be chunked
|
||||
largeToken := createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig", 8000)
|
||||
testTokens := NewTestTokens()
|
||||
largeToken := testTokens.CreateLargeValidJWT(8000)
|
||||
|
||||
// Store the token to create chunks
|
||||
session.SetAccessToken(largeToken)
|
||||
@@ -251,8 +251,8 @@ func TestChunkReassemblyEdgeCases(t *testing.T) {
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
// This test simulates having too many chunks (>50 limit)
|
||||
// We'll create a scenario by adding many fake chunks
|
||||
for i := range 60 {
|
||||
fakeSession := &sessions.Session{Values: make(map[any]any)}
|
||||
for i := 0; i < 60; i++ {
|
||||
fakeSession := &sessions.Session{Values: make(map[interface{}]interface{})}
|
||||
fakeSession.Values["token_chunk"] = "fake_chunk_data"
|
||||
chunks[i] = fakeSession
|
||||
}
|
||||
@@ -312,21 +312,22 @@ func TestRaceConditionProtection(t *testing.T) {
|
||||
const numOperations = 50
|
||||
|
||||
// Create tokens of different sizes
|
||||
testTokens := NewTestTokens()
|
||||
tokens := []string{
|
||||
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig1",
|
||||
createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig2", 3000),
|
||||
createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig3", 6000),
|
||||
testTokens.CreateUniqueValidJWT("token1"),
|
||||
testTokens.CreateLargeValidJWT(3000),
|
||||
testTokens.CreateLargeValidJWT(6000),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
for i := range numGoroutines {
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := range numOperations {
|
||||
for j := 0; j < numOperations; j++ {
|
||||
tokenIndex := (goroutineID + j) % len(tokens)
|
||||
expectedToken := tokens[tokenIndex]
|
||||
|
||||
@@ -345,7 +346,13 @@ func TestRaceConditionProtection(t *testing.T) {
|
||||
|
||||
// The retrieved token should be one of the valid tokens we set
|
||||
// (due to concurrent access, it might not be the exact one we just set)
|
||||
isValidToken := slices.Contains(tokens, retrieved)
|
||||
isValidToken := false
|
||||
for _, validToken := range tokens {
|
||||
if retrieved == validToken {
|
||||
isValidToken = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if retrieved != "" && !isValidToken {
|
||||
errChan <- fmt.Errorf("goroutine %d, op %d: retrieved unknown token: %q",
|
||||
@@ -438,7 +445,8 @@ func TestBackwardCompatibility(t *testing.T) {
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Simulate old-style session data (without new validation fields)
|
||||
oldStyleToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.oldsig"
|
||||
testTokens := NewTestTokens()
|
||||
oldStyleToken := testTokens.CreateUniqueValidJWT("old")
|
||||
|
||||
// Manually set token without going through new SetAccessToken validation
|
||||
session.accessSession.Values["token"] = oldStyleToken
|
||||
@@ -462,41 +470,8 @@ func TestBackwardCompatibility(t *testing.T) {
|
||||
}
|
||||
|
||||
// createTokenOfSize creates a JWT token of approximately the specified size
|
||||
// This function is deprecated - use TestTokens.CreateLargeValidJWT instead
|
||||
func createTokenOfSize(baseToken string, targetSize int) string {
|
||||
parts := strings.Split(baseToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return baseToken
|
||||
}
|
||||
|
||||
header, payload, signature := parts[0], parts[1], parts[2]
|
||||
currentSize := len(baseToken)
|
||||
|
||||
if currentSize >= targetSize {
|
||||
return baseToken
|
||||
}
|
||||
|
||||
// Expand the payload to reach target size
|
||||
paddingNeeded := targetSize - len(header) - len(signature) - 2 // Account for dots
|
||||
if paddingNeeded > 0 {
|
||||
// Decode current payload, add padding, re-encode
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// If we can't decode, just pad with random base64-safe characters to resist compression
|
||||
randomBytes := make([]byte, paddingNeeded)
|
||||
rand.Read(randomBytes)
|
||||
// Encode as base64 to make it base64-safe
|
||||
padData := base64.RawURLEncoding.EncodeToString(randomBytes)
|
||||
payload = payload + padData
|
||||
} else {
|
||||
// Add padding to the JSON - use random data to resist compression
|
||||
randomBytes := make([]byte, paddingNeeded/2)
|
||||
rand.Read(randomBytes)
|
||||
// Encode as base64 to make it JSON-safe
|
||||
padData := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
newPayload := fmt.Sprintf(`{"original":%s,"padding":"%s"}`, string(decoded), padData)
|
||||
payload = base64.RawURLEncoding.EncodeToString([]byte(newPayload))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
testTokens := NewTestTokens()
|
||||
return testTokens.CreateLargeValidJWT(targetSize)
|
||||
}
|
||||
|
||||
+46
-33
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -15,21 +16,21 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestTokenTypeDistinction tests that AccessToken and IdToken are correctly distinguished in templates
|
||||
// TestTokenTypeDistinction tests that AccessToken and IDToken are correctly distinguished in templates
|
||||
func TestTokenTypeDistinction(t *testing.T) {
|
||||
// Define test data where AccessToken and IdToken are deliberately different
|
||||
// Define test data where AccessToken and IDToken are deliberately different
|
||||
type templateData struct {
|
||||
Claims map[string]any
|
||||
Claims map[string]interface{}
|
||||
AccessToken string
|
||||
IdToken string
|
||||
IDToken string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
testData := templateData{
|
||||
AccessToken: "test-access-token-abc123",
|
||||
IdToken: "test-id-token-xyz789",
|
||||
IDToken: "test-id-token-xyz789",
|
||||
RefreshToken: "test-refresh-token",
|
||||
Claims: map[string]any{
|
||||
Claims: map[string]interface{}{
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
@@ -48,17 +49,17 @@ func TestTokenTypeDistinction(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "ID Token Only",
|
||||
templateText: "ID: {{.IdToken}}",
|
||||
templateText: "ID: {{.IDToken}}",
|
||||
expectedValue: "ID: test-id-token-xyz789",
|
||||
},
|
||||
{
|
||||
name: "Both Tokens",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
|
||||
expectedValue: "Access: test-access-token-abc123 ID: test-id-token-xyz789",
|
||||
},
|
||||
{
|
||||
name: "Both Tokens in Authorization Format",
|
||||
templateText: "Bearer {{.AccessToken}} and Bearer {{.IdToken}}",
|
||||
templateText: "Bearer {{.AccessToken}} and Bearer {{.IDToken}}",
|
||||
expectedValue: "Bearer test-access-token-abc123 and Bearer test-id-token-xyz789",
|
||||
},
|
||||
}
|
||||
@@ -91,7 +92,7 @@ func TestTokenTypeIntegration(t *testing.T) {
|
||||
ts.Setup()
|
||||
|
||||
// Create different tokens for ID and access tokens
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
@@ -107,7 +108,7 @@ func TestTokenTypeIntegration(t *testing.T) {
|
||||
t.Fatalf("Failed to create test ID JWT: %v", err)
|
||||
}
|
||||
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
@@ -125,7 +126,7 @@ func TestTokenTypeIntegration(t *testing.T) {
|
||||
|
||||
// Define test headers that use both token types
|
||||
headers := []TemplatedHeader{
|
||||
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
|
||||
{Name: "X-ID-Token", Value: "{{.IDToken}}"},
|
||||
{Name: "X-Access-Token", Value: "{{.AccessToken}}"},
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
{Name: "X-Email-From-Claims", Value: "{{.Claims.email}}"},
|
||||
@@ -332,9 +333,9 @@ func TestTokenCorruptionIntegrationFlows(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "Normal flow - small tokens",
|
||||
accessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.access_sig",
|
||||
accessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.access_signature_data_here",
|
||||
refreshToken: "refresh_token_12345",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_token_signature_data_here",
|
||||
expectSuccess: true,
|
||||
},
|
||||
{
|
||||
@@ -348,7 +349,7 @@ func TestTokenCorruptionIntegrationFlows(t *testing.T) {
|
||||
name: "Corrupted access token compression",
|
||||
accessToken: createLargeValidJWT(3000),
|
||||
refreshToken: "refresh_token_12345",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_token_signature_data_here",
|
||||
expectSuccess: false,
|
||||
corruptAction: func(session *SessionData) {
|
||||
// Corrupt compressed access token
|
||||
@@ -360,15 +361,20 @@ func TestTokenCorruptionIntegrationFlows(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Corrupted chunk in large token",
|
||||
accessToken: createLargeValidJWT(8000), // Force chunking
|
||||
accessToken: createLargeValidJWT(15000), // Force chunking with larger size
|
||||
refreshToken: "refresh_token_12345",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_token_signature_data_here",
|
||||
expectSuccess: false,
|
||||
corruptAction: func(session *SessionData) {
|
||||
// Corrupt first chunk
|
||||
// Corrupt first chunk if chunked, otherwise corrupt single token
|
||||
if len(session.accessTokenChunks) > 0 {
|
||||
if chunk, exists := session.accessTokenChunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = "corrupted_chunk_data"
|
||||
chunk.Values["token_chunk"] = "__CORRUPTED_CHUNK_DATA__"
|
||||
}
|
||||
} else {
|
||||
// Token is stored as single compressed token - corrupt it
|
||||
if session.accessSession != nil {
|
||||
session.accessSession.Values["token"] = "__CORRUPTED_CHUNK_DATA__"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -450,7 +456,8 @@ func TestSessionPersistenceWithCorruption(t *testing.T) {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
largeToken := createLargeValidJWT(6000)
|
||||
// Use a smaller token that's less likely to accidentally contain corruption markers
|
||||
largeToken := createLargeValidJWT(2000)
|
||||
session1.SetAccessToken(largeToken)
|
||||
session1.SetAuthenticated(true)
|
||||
|
||||
@@ -474,18 +481,18 @@ func TestSessionPersistenceWithCorruption(t *testing.T) {
|
||||
}
|
||||
defer session2.ReturnToPool()
|
||||
|
||||
// Verify token can be retrieved
|
||||
// Verify token can be retrieved initially
|
||||
retrieved := session2.GetAccessToken()
|
||||
if retrieved != largeToken {
|
||||
t.Errorf("Token persistence failed: expected %q, got %q", largeToken, retrieved)
|
||||
t.Errorf("Token persistence failed: expected valid token, got empty token")
|
||||
}
|
||||
|
||||
// Simulate corruption by modifying chunks
|
||||
if len(session2.accessTokenChunks) > 0 {
|
||||
// Corrupt a middle chunk
|
||||
// Corrupt a middle chunk with a unique corruption marker
|
||||
chunkIndex := len(session2.accessTokenChunks) / 2
|
||||
if chunk, exists := session2.accessTokenChunks[chunkIndex]; exists {
|
||||
chunk.Values["token_chunk"] = "corrupted"
|
||||
chunk.Values["token_chunk"] = "__CORRUPTION_MARKER_TEST__"
|
||||
}
|
||||
|
||||
// Try to retrieve again - should detect corruption and return empty
|
||||
@@ -518,11 +525,11 @@ func TestConcurrentTokenOperationsWithCorruption(t *testing.T) {
|
||||
errorChan := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
// Start concurrent operations
|
||||
for i := range numGoroutines {
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(goroutineID int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := range numOperations {
|
||||
for j := 0; j < numOperations; j++ {
|
||||
// Create a unique valid token for each operation
|
||||
token := fmt.Sprintf("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwib3AiOiIxMjMifQ.sig_%d_%d",
|
||||
goroutineID, j)
|
||||
@@ -553,7 +560,7 @@ func TestConcurrentTokenOperationsWithCorruption(t *testing.T) {
|
||||
// Intentionally corrupt a random chunk
|
||||
for chunkID, chunk := range session.accessTokenChunks {
|
||||
if chunkID%2 == 0 {
|
||||
chunk.Values["token_chunk"] = "intentionally_corrupted"
|
||||
chunk.Values["token_chunk"] = "__CORRUPTION_MARKER_TEST__"
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -563,7 +570,7 @@ func TestConcurrentTokenOperationsWithCorruption(t *testing.T) {
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for range numGoroutines {
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
close(errorChan)
|
||||
@@ -641,20 +648,26 @@ func TestTokenValidationEdgeCases(t *testing.T) {
|
||||
// createLargeValidJWT creates a JWT of approximately the specified size
|
||||
func createLargeValidJWT(targetSize int) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
signature := "signature_" + generateRandomString(32)
|
||||
// Create a valid base64url signature
|
||||
signatureBytes := make([]byte, 32)
|
||||
rand.Read(signatureBytes)
|
||||
signature := base64.RawURLEncoding.EncodeToString(signatureBytes)
|
||||
|
||||
// Calculate required payload size
|
||||
usedSize := len(header) + len(signature) + 2 // account for dots
|
||||
payloadSize := max(targetSize-usedSize, 50)
|
||||
payloadSize := targetSize - usedSize
|
||||
if payloadSize < 50 {
|
||||
payloadSize = 50
|
||||
}
|
||||
|
||||
// Create a payload with realistic JWT claims
|
||||
claims := map[string]any{
|
||||
// Create a payload with realistic JWT claims, using safe content
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
"exp": 9999999999,
|
||||
"iat": 1000000000,
|
||||
"data": generateRandomString(payloadSize - 100), // Account for other claims
|
||||
"data": strings.Repeat("abcdef0123456789", (payloadSize-100)/16), // Safe repeating pattern
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
|
||||
Reference in New Issue
Block a user