mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
90 Commits
v0.3.6
...
v0.6.2-beta6
| Author | SHA1 | Date | |
|---|---|---|---|
| 70443f0855 | |||
| 7a443c626c | |||
| 48de8265c5 | |||
| d8d1b74175 | |||
| c233aa92ef | |||
| c400251625 | |||
| 48faf7fadf | |||
| 84d7cd3d76 | |||
| 488264028b | |||
| e23135ded0 | |||
| cd307f88a1 | |||
| efa0cd708b | |||
| 99881f5837 | |||
| 82a640cc3b | |||
| 24d8dc38e8 | |||
| 248ca018e2 | |||
| 003a3686a0 | |||
| da70e69ad1 | |||
| 81000a824d | |||
| 83693d2893 | |||
| d88ef61c5d | |||
| 075476792f | |||
| 2583266738 | |||
| 996b25ebaf | |||
| 75b5904099 | |||
| a895333964 | |||
| 983585e96e | |||
| 8a6e37f7fc | |||
| bd7eaf6dff | |||
| 3df19e6d90 | |||
| 1910cd6000 | |||
| 46c2f98a15 | |||
| 9e8634bfc0 | |||
| 23e019092a | |||
| 4322407129 | |||
| 4ce2815123 | |||
| 7d204113ea | |||
| c721913cbe | |||
| 0f8b7f7ab1 | |||
| 2743b0e024 | |||
| e6fc36937b | |||
| df051e0cfb | |||
| 5d5ce8ae5e | |||
| d194cd778a | |||
| 803a1e5e21 | |||
| 3ad8fb4518 | |||
| 9402f1bca5 | |||
| e6205b3a48 | |||
| fdb8e3233e | |||
| 33c71fd6fe | |||
| 241cb1c209 | |||
| 09daa1025c | |||
| c09e7a9228 | |||
| e5da5d4fe9 | |||
| 31db701dda | |||
| 16481afd36 | |||
| 751933ffa0 | |||
| e74153b107 | |||
| 025107fe3e | |||
| dfb9c0771e | |||
| 1107df40e7 | |||
| bf294569eb | |||
| 482c346840 | |||
| a462e44896 | |||
| 5eff0dc866 | |||
| dfc534a400 | |||
| 061c12d0a3 | |||
| 4c4fff3613 | |||
| 0dcb44c187 | |||
| cbe773d96a | |||
| 40254888d7 | |||
| ef41870c81 | |||
| 081c32925a | |||
| 17dea67229 | |||
| 8512ad6d68 | |||
| 5aa838c669 | |||
| 6f359e5ef1 | |||
| bd18d6041c | |||
| 74c620ad51 | |||
| 7e3dc46b6e | |||
| 147aa0b169 | |||
| eecb7dfc92 | |||
| a8d65688c4 | |||
| bef4212c57 | |||
| 1fee2f9e9a | |||
| 11bc6f3e31 | |||
| 2b7af88ff9 | |||
| 01ee7c4dc8 | |||
| a6fa4d8789 | |||
| 8101fb2bf6 |
+297
-20
@@ -4,27 +4,304 @@ type: middleware
|
||||
import: github.com/lukaszraczylo/traefikoidc
|
||||
|
||||
summary: |
|
||||
Middleware adding OIDC authentication to traefik routes. Does what it says on the tin.
|
||||
Middleware has been tested with Auth0 and Logto. It should work with any OIDC provider.
|
||||
Middleware adding OpenID Connect (OIDC) authentication to Traefik routes.
|
||||
|
||||
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy.
|
||||
It provides a complete OIDC authentication solution with features like domain restrictions,
|
||||
role-based access control, token caching, and more.
|
||||
|
||||
The middleware has been tested with Auth0, Logto, Google, and other standard OIDC providers.
|
||||
It includes special handling for Google's OAuth implementation to ensure compatibility.
|
||||
It supports various authentication scenarios including:
|
||||
|
||||
- Basic authentication with customizable callback and logout URLs
|
||||
- Email domain restrictions to limit access to specific organizations
|
||||
- Role and group-based access control
|
||||
- Public URLs that bypass authentication
|
||||
- Rate limiting to prevent brute force attacks
|
||||
- Custom post-logout redirect behavior
|
||||
- Secure session management with encrypted cookies
|
||||
- Automatic token validation and refresh
|
||||
|
||||
testData:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
|
||||
- raczylo.com
|
||||
allowedRolesAndGroups:
|
||||
# Required parameters
|
||||
providerURL: https://accounts.google.com # Base URL of the OIDC provider
|
||||
clientID: 1234567890.apps.googleusercontent.com # OAuth 2.0 client identifier
|
||||
clientSecret: secret # OAuth 2.0 client secret
|
||||
callbackURL: /oauth2/callback # Path where the OIDC provider will redirect after authentication
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long # Key used to encrypt session data (must be at least 32 bytes)
|
||||
|
||||
# Optional parameters with defaults
|
||||
logoutURL: /oauth2/logout # Path for handling logout requests (if not provided, it will be set to callbackURL + "/logout")
|
||||
postLogoutRedirectURI: /oidc/different-logout # URL to redirect to after logout (default: "/")
|
||||
|
||||
scopes: # Additional scopes to append to defaults ["openid", "profile", "email"]
|
||||
- roles # Result: ["openid", "profile", "email", "roles"]
|
||||
|
||||
allowedUserDomains: # Restricts access to specific email domains (if not provided, relies on OIDC provider)
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
|
||||
allowedUsers: # Restricts access to specific email addresses regardless of domain
|
||||
- specific-user@company.com
|
||||
- another-user@gmail.com
|
||||
|
||||
allowedRolesAndGroups: # Restricts access to users with specific roles or groups (if not provided, no role/group restrictions)
|
||||
- guest-endpoints
|
||||
sessionEncryptionKey: potato-secret
|
||||
forceHTTPS: false
|
||||
logLevel: debug # debug, info, warn, error
|
||||
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- admin
|
||||
- developer
|
||||
|
||||
forceHTTPS: false # Forces the use of HTTPS for all URLs (default: true for security)
|
||||
logLevel: debug # Sets logging verbosity: debug, info, error (default: info)
|
||||
rateLimit: 100 # Maximum number of requests per second (default: 100, minimum: 10)
|
||||
|
||||
excludedURLs: # Lists paths that bypass authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
- /public
|
||||
- /health
|
||||
- /metrics
|
||||
|
||||
headers: # Custom headers to set with templated values from claims and tokens
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
- name: "X-User-ID"
|
||||
value: "{{.Claims.sub}}"
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{.AccessToken}}"
|
||||
- name: "X-User-Roles"
|
||||
value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
|
||||
|
||||
# Advanced parameters (usually discovered automatically from provider metadata)
|
||||
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
|
||||
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
||||
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
|
||||
|
||||
# Configuration documentation
|
||||
configuration:
|
||||
providerURL:
|
||||
type: string
|
||||
description: |
|
||||
The base URL of the OIDC provider. This is the issuer URL that will be used to discover
|
||||
OIDC endpoints like authorization, token, and JWKS URIs.
|
||||
|
||||
Examples:
|
||||
- https://accounts.google.com
|
||||
- https://login.microsoftonline.com/tenant-id/v2.0
|
||||
- https://your-auth0-domain.auth0.com
|
||||
- https://your-logto-instance.com/oidc
|
||||
required: true
|
||||
|
||||
clientID:
|
||||
type: string
|
||||
description: |
|
||||
The OAuth 2.0 client identifier obtained from your OIDC provider.
|
||||
This is the public identifier for your application.
|
||||
required: true
|
||||
|
||||
clientSecret:
|
||||
type: string
|
||||
description: |
|
||||
The OAuth 2.0 client secret obtained from your OIDC provider.
|
||||
This should be kept confidential and not exposed in client-side code.
|
||||
|
||||
For Kubernetes deployments, you can use the secret reference format:
|
||||
urn:k8s:secret:namespace:secret-name:key
|
||||
required: true
|
||||
|
||||
callbackURL:
|
||||
type: string
|
||||
description: |
|
||||
The path where the OIDC provider will redirect after authentication.
|
||||
This must match one of the redirect URIs configured in your OIDC provider.
|
||||
|
||||
The full redirect URI will be constructed as:
|
||||
[scheme]://[host][callbackURL]
|
||||
|
||||
Example: /oauth2/callback
|
||||
required: true
|
||||
|
||||
sessionEncryptionKey:
|
||||
type: string
|
||||
description: |
|
||||
Key used to encrypt session data stored in cookies.
|
||||
Must be at least 32 bytes long for security.
|
||||
|
||||
Example: potato-secret-is-at-least-32-bytes-long
|
||||
required: true
|
||||
|
||||
logoutURL:
|
||||
type: string
|
||||
description: |
|
||||
The path for handling logout requests.
|
||||
If not provided, it will be set to callbackURL + "/logout".
|
||||
|
||||
Example: /oauth2/logout
|
||||
required: false
|
||||
|
||||
postLogoutRedirectURI:
|
||||
type: string
|
||||
description: |
|
||||
The URL to redirect to after logout.
|
||||
Default: "/"
|
||||
|
||||
Example: /logged-out-page
|
||||
required: false
|
||||
|
||||
scopes:
|
||||
type: array
|
||||
description: |
|
||||
Additional OAuth 2.0 scopes to append to the default scopes.
|
||||
Default scopes are always included: ["openid", "profile", "email"]
|
||||
|
||||
User-provided scopes are appended to defaults with automatic deduplication.
|
||||
For example, specifying ["roles", "custom_scope"] results in:
|
||||
["openid", "profile", "email", "roles", "custom_scope"]
|
||||
|
||||
Include "roles" or similar scope if you need role/group information.
|
||||
Note: For Google OAuth, the middleware automatically handles the
|
||||
proper authentication parameters and does NOT require the "offline_access"
|
||||
scope (which Google rejects as invalid). See documentation for details.
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
logLevel:
|
||||
type: string
|
||||
description: |
|
||||
Sets the logging verbosity.
|
||||
Valid values: "debug", "info", "error"
|
||||
Default: "info"
|
||||
required: false
|
||||
enum:
|
||||
- debug
|
||||
- info
|
||||
- error
|
||||
|
||||
forceHTTPS:
|
||||
type: boolean
|
||||
description: |
|
||||
Forces the use of HTTPS for all URLs.
|
||||
This is recommended for security in production environments.
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
rateLimit:
|
||||
type: integer
|
||||
description: |
|
||||
Sets the maximum number of requests per second.
|
||||
This helps prevent brute force attacks.
|
||||
Default: 100
|
||||
Minimum: 10
|
||||
required: false
|
||||
|
||||
excludedURLs:
|
||||
type: array
|
||||
description: |
|
||||
Lists paths that bypass authentication.
|
||||
These paths will be accessible without OIDC authentication.
|
||||
|
||||
The middleware uses prefix matching, so "/public" will match
|
||||
"/public", "/public/page", "/public-data", etc.
|
||||
|
||||
Examples: ["/health", "/metrics", "/public"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
allowedUserDomains:
|
||||
type: array
|
||||
description: |
|
||||
Restricts access to users with email addresses from specific domains.
|
||||
If not provided, the middleware relies entirely on the OIDC provider
|
||||
for authentication decisions.
|
||||
|
||||
Examples: ["company.com", "subsidiary.com"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
allowedUsers:
|
||||
type: array
|
||||
description: |
|
||||
Restricts access to specific email addresses.
|
||||
If provided, only users with these exact email addresses will be allowed access,
|
||||
in addition to any domain-level restrictions set by allowedUserDomains.
|
||||
|
||||
This provides fine-grained control over individual access and can be used
|
||||
together with allowedUserDomains for flexible access control strategies.
|
||||
|
||||
Examples: ["user1@example.com", "admin@company.com"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
allowedRolesAndGroups:
|
||||
type: array
|
||||
description: |
|
||||
Restricts access to users with specific roles or groups.
|
||||
If not provided, no role/group restrictions are applied.
|
||||
|
||||
The middleware checks both the "roles" and "groups" claims in the ID token.
|
||||
|
||||
Examples: ["admin", "developer"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
revocationURL:
|
||||
type: string
|
||||
description: |
|
||||
The endpoint for revoking tokens.
|
||||
If not provided, it will be discovered from provider metadata.
|
||||
|
||||
Example: https://accounts.google.com/revoke
|
||||
required: false
|
||||
|
||||
oidcEndSessionURL:
|
||||
type: string
|
||||
description: |
|
||||
The provider's end session endpoint.
|
||||
If not provided, it will be discovered from provider metadata.
|
||||
|
||||
Example: https://accounts.google.com/logout
|
||||
required: false
|
||||
|
||||
enablePKCE:
|
||||
type: boolean
|
||||
description: |
|
||||
Enables PKCE (Proof Key for Code Exchange) for the OAuth 2.0 authorization code flow.
|
||||
PKCE adds an extra layer of security to protect against authorization code interception attacks.
|
||||
|
||||
Not all OIDC providers support PKCE, so this should only be enabled if your provider supports it.
|
||||
If enabled, the middleware will generate and use a code verifier/challenge pair during authentication.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
headers:
|
||||
type: array
|
||||
description: |
|
||||
Custom HTTP headers to set with templated values derived from OIDC claims and tokens.
|
||||
Each header has a name and a value template that can access:
|
||||
- {{.Claims.field}} - Access ID token claims (e.g., email, sub, name)
|
||||
- {{.AccessToken}} - The raw access token string
|
||||
- {{.IdToken}} - The raw ID token string
|
||||
- {{.RefreshToken}} - The raw refresh token string
|
||||
|
||||
Templates support Go template syntax including conditionals and iteration.
|
||||
Variable names are case-sensitive - use .Claims not .claims.
|
||||
|
||||
Examples:
|
||||
- name: "X-User-Email", value: "{{.Claims.email}}"
|
||||
- name: "Authorization", value: "Bearer {{.AccessToken}}"
|
||||
- name: "X-User-Roles", value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
|
||||
required: false
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: The HTTP header name to set
|
||||
value:
|
||||
type: string
|
||||
description: Template string for the header value
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Lukasz Raczylo
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,150 +1,486 @@
|
||||
## Traefik OIDC middleware
|
||||
# Traefik OIDC Middleware
|
||||
|
||||
This middleware is supposed to replace the need for the forward-auth and oauth2-proxy when using traefik as a reverse proxy to support the OIDC authentication.
|
||||
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy to support OpenID Connect (OIDC) authentication.
|
||||
|
||||
Middleware has been tested with Auth0 and Logto.
|
||||
## Overview
|
||||
|
||||
### Traefik version compatibility
|
||||
The Traefik OIDC middleware provides a complete OIDC authentication solution with features like:
|
||||
- Token validation and verification
|
||||
- Session management
|
||||
- Domain restrictions
|
||||
- Role-based access control
|
||||
- Token caching and blacklisting
|
||||
- Rate limiting
|
||||
- Excluded paths (public URLs)
|
||||
|
||||
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
|
||||
The middleware has been tested with Auth0, Logto, Google and other standard OIDC providers. It includes special handling for Google's OAuth implementation.
|
||||
|
||||
### Configuration options
|
||||
## Traefik Version Compatibility
|
||||
|
||||
Middleware currently supports following scenarios:
|
||||
This middleware follows closely the current Traefik helm chart versions. If the plugin fails to load, it's time to update to the latest version of the Traefik helm chart.
|
||||
|
||||
* Setting custom callback and logout URLs via `callbackURL` and `logoutURL`
|
||||
* Allowing for access only from the listed domains if `allowedUserDomains` is set, otherwise it relies entirely on the OIDC provider
|
||||
* Using excluded URLs which do **NOT** require the OIDC authentication
|
||||
* Rate limiting requests to prevent the bruteforce attacks
|
||||
## Installation
|
||||
|
||||
#### How to configure...
|
||||
### As a Traefik Plugin
|
||||
|
||||
##### Keeping secrets secret
|
||||
|
||||
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
|
||||
1. Enable the plugin in your Traefik static configuration:
|
||||
|
||||
```yaml
|
||||
# traefik.yml
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.2.1 # Use the latest version
|
||||
```
|
||||
|
||||
2. Configure the middleware in your dynamic configuration (see examples below).
|
||||
|
||||
### Local Development with Docker Compose
|
||||
|
||||
For local development or testing, you can use the provided Docker Compose setup:
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
This will start Traefik with the OIDC middleware and two test services.
|
||||
|
||||
## Configuration Options
|
||||
|
||||
The middleware supports the following configuration options:
|
||||
|
||||
### Required Parameters
|
||||
|
||||
| Parameter | Description | Example |
|
||||
|-----------|-------------|---------|
|
||||
| `providerURL` | The base URL of the OIDC provider | `https://accounts.google.com` |
|
||||
| `clientID` | The OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
|
||||
| `clientSecret` | The OAuth 2.0 client secret | `your-client-secret` |
|
||||
| `sessionEncryptionKey` | Key used to encrypt session data (must be at least 32 bytes long) | `potato-secret-is-at-least-32-bytes-long` |
|
||||
| `callbackURL` | The path where the OIDC provider will redirect after authentication | `/oauth2/callback` |
|
||||
|
||||
### Optional Parameters
|
||||
|
||||
| Parameter | Description | Default | Example |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `logoutURL` | The path for handling logout requests | `callbackURL + "/logout"` | `/oauth2/logout` |
|
||||
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
|
||||
| `scopes` | OAuth 2.0 scopes to use for authentication | `["openid", "profile", "email"]` (always included by default) | `["roles", "custom_scope"]` (appended to defaults) |
|
||||
| `overrideScopes` | When true, replaces default scopes with provided scopes instead of appending | `false` | `true` (use only the scopes explicitly provided) |
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedUsers` | A list of specific email addresses that are allowed access | none | `["user1@example.com", "user2@another.org"]` |
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
|
||||
## Scope Configuration
|
||||
|
||||
### Scope Behavior
|
||||
|
||||
The middleware supports two modes for handling OAuth 2.0 scopes, controlled by the `overrideScopes` parameter:
|
||||
|
||||
#### Default Append Mode (`overrideScopes: false`)
|
||||
|
||||
By default, the middleware uses an **append** behavior for OAuth 2.0 scopes:
|
||||
|
||||
- **Default scopes** are always included: `["openid", "profile", "email"]`
|
||||
- **User-provided scopes** are appended to the defaults with automatic deduplication
|
||||
- The final scope list maintains the order: defaults first, then user scopes
|
||||
|
||||
#### Override Mode (`overrideScopes: true`)
|
||||
|
||||
When `overrideScopes` is set to `true`, the middleware uses **replacement** behavior:
|
||||
|
||||
- Default scopes are **not** automatically included
|
||||
- Only the scopes explicitly provided in the `scopes` field are used
|
||||
- You must include all required scopes explicitly, including `openid` if needed
|
||||
|
||||
### Examples:
|
||||
|
||||
**Default behavior (no custom scopes):**
|
||||
```yaml
|
||||
# No scopes field specified
|
||||
# Result: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
**Default append behavior:**
|
||||
```yaml
|
||||
scopes:
|
||||
- roles
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
|
||||
```
|
||||
|
||||
**Overlapping scopes with append (automatic deduplication):**
|
||||
```yaml
|
||||
scopes:
|
||||
- openid # Duplicate - will be deduplicated
|
||||
- roles
|
||||
- profile # Duplicate - will be deduplicated
|
||||
- permissions
|
||||
# Result: ["openid", "profile", "email", "roles", "permissions"]
|
||||
```
|
||||
|
||||
**Using override mode:**
|
||||
```yaml
|
||||
overrideScopes: true
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "custom_scope"]
|
||||
```
|
||||
|
||||
**Empty scopes list with default behavior:**
|
||||
```yaml
|
||||
scopes: []
|
||||
# Result: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
**Empty scopes list with override mode:**
|
||||
```yaml
|
||||
overrideScopes: true
|
||||
scopes: []
|
||||
# Result: [] (Warning: empty scopes may cause authentication to fail)
|
||||
```
|
||||
|
||||
The default append behavior ensures essential OIDC scopes are always present, while the override mode gives you complete control over the exact scopes requested from the provider.
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-basic
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
### With Excluded URLs (Public Access Paths)
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-open-urls
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
excludedURLs:
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /public-data
|
||||
- /health
|
||||
- /metrics
|
||||
```
|
||||
|
||||
### With Email Domain Restrictions
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-domain-restricted
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
```
|
||||
|
||||
### With Specific User Access
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-specific-users
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedUsers:
|
||||
- user1@example.com
|
||||
- user2@another.org
|
||||
```
|
||||
|
||||
### With Both Domain and Specific User Access
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-domain-and-users
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
allowedUsers:
|
||||
- special-user@gmail.com
|
||||
- contractor@external.org
|
||||
```
|
||||
|
||||
When configuring access control:
|
||||
- If only `allowedUsers` is set, only the specified email addresses will be granted access
|
||||
- If only `allowedUserDomains` is set, only users with email addresses from those domains will be granted access
|
||||
- If both are set, access is granted if the user's email is in `allowedUsers` OR their email's domain is in `allowedUserDomains`
|
||||
- If neither is set, any authenticated user will be granted access
|
||||
- Email matching is case-insensitive
|
||||
|
||||
### With Role-Based Access Control
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-rbac
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
```
|
||||
|
||||
### With Custom Logging and Rate Limiting
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-custom-settings
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
logLevel: debug # Options: debug, info, error (default: info)
|
||||
rateLimit: 500 # Requests per second (default: 100)
|
||||
forceHTTPS: false # Default is true for security
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
### With Custom Post-Logout Redirect
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-custom-logout
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /logged-out-page # Where to redirect after logout
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
### With Templated Headers
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-headers
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
- name: "X-User-ID"
|
||||
value: "{{.Claims.sub}}"
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{.AccessToken}}"
|
||||
- name: "X-User-Roles"
|
||||
value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
|
||||
- name: "X-Is-Admin"
|
||||
value: "{{if eq .Claims.role \"admin\"}}true{{else}}false{{end}}"
|
||||
```
|
||||
|
||||
### With PKCE Enabled
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-pkce
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
enablePKCE: true # Enables PKCE for added security
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
### Google OIDC Configuration Example
|
||||
|
||||
This example shows a configuration specifically tailored for Google OIDC:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-google
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-google-client-id.apps.googleusercontent.com # Replace with your Client ID
|
||||
clientSecret: your-google-client-secret # Replace with your Client Secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars # Replace with your key
|
||||
callbackURL: /oauth2/callback # Adjust if needed
|
||||
logoutURL: /oauth2/logout # Optional: Adjust if needed
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
# Note: DO NOT manually add offline_access scope for Google
|
||||
# The middleware automatically handles Google-specific requirements
|
||||
refreshGracePeriodSeconds: 300 # Optional: Start refresh 5 min before expiry (default 60)
|
||||
# Other optional parameters like allowedUserDomains, etc. can be added here
|
||||
```
|
||||
|
||||
The middleware automatically detects Google as the provider and applies the necessary adjustments to ensure proper authentication and token refresh. See the [Google OAuth Fix](#google-oauth-compatibility-fix) section for details.
|
||||
|
||||
### Keeping Secrets Secret in Kubernetes
|
||||
|
||||
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-secrets
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
|
||||
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
|
||||
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
##### Excluded URLs with open access
|
||||
Don't forget to create the secret:
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-open-urls
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
```bash
|
||||
kubectl create secret generic traefik-middleware-oidc \
|
||||
--from-literal=ISSUER=https://accounts.google.com \
|
||||
--from-literal=CLIENT_ID=1234567890.apps.googleusercontent.com \
|
||||
--from-literal=SECRET=your-client-secret \
|
||||
-n traefik
|
||||
```
|
||||
|
||||
## Complete Docker Compose Example
|
||||
|
||||
##### Allowed email domains
|
||||
|
||||
Assuming that your OIDC provider allows anyone to log in, you may want to limit the access to people using emains in specific domain.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-only-my-users
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /new-oidc/callback
|
||||
logoutURL: /new-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains:
|
||||
- raczylo.com
|
||||
```
|
||||
|
||||
|
||||
##### Allowed groups and roles
|
||||
|
||||
In case of multiple roles / groups and access separation for various endpoints you will need to create multiple traefik middlewares.
|
||||
Following example allows access for users who have additional role `guest-endpoints` assigned.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-guest-endpoints
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /my-oidc/callback
|
||||
logoutURL: /my-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # This line queries the OIDC provider for roles
|
||||
forceHTTPS: true
|
||||
allowedRolesAndGroups:
|
||||
- guest-endpoints # This line specifies the roles or groups allowed to access content
|
||||
allowedUserDomains:
|
||||
- raczylo.com
|
||||
```
|
||||
|
||||
|
||||
#### Docker compose example
|
||||
|
||||
`docker-compose.yaml`
|
||||
Here's a complete example of using the middleware with Docker Compose:
|
||||
|
||||
```yaml
|
||||
version: "3.7"
|
||||
|
||||
services:
|
||||
traefik:
|
||||
image: traefik:v3.0.1
|
||||
image: traefik:v3.2.1
|
||||
command:
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.2.1"
|
||||
@@ -155,7 +491,6 @@ services:
|
||||
labels:
|
||||
- "traefik.http.routers.dash.rule=Host(`dash.localhost`)"
|
||||
- "traefik.http.routers.dash.service=api@internal"
|
||||
|
||||
ports:
|
||||
- "80:80"
|
||||
|
||||
@@ -178,8 +513,7 @@ services:
|
||||
- traefik.http.routers.whoami.middlewares=my-plugin@file
|
||||
```
|
||||
|
||||
`traefik-config/traefik.yaml`
|
||||
|
||||
`traefik-config/traefik.yml`:
|
||||
```yaml
|
||||
log:
|
||||
level: INFO
|
||||
@@ -208,7 +542,7 @@ providers:
|
||||
filename: /etc/traefik/dynamic-configuration.yml
|
||||
```
|
||||
|
||||
`traefik-config/dynamic-configuration.yaml`
|
||||
`traefik-config/dynamic-configuration.yml`:
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
@@ -217,20 +551,181 @@ http:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: secret
|
||||
clientSecret: your-client-secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
|
||||
- raczylo.com
|
||||
sessionEncryptionKey: potato-secret
|
||||
postLogoutRedirectURI: /logged-out-page
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
scopes:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
allowedUsers:
|
||||
- special-user@gmail.com
|
||||
- contractor@external.org
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
forceHTTPS: false
|
||||
logLevel: debug # debug, info, warn, error
|
||||
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
logLevel: debug
|
||||
rateLimit: 100
|
||||
excludedURLs:
|
||||
- /login
|
||||
- /public
|
||||
- /health
|
||||
- /metrics
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
- name: "X-User-ID"
|
||||
value: "{{.Claims.sub}}"
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{.AccessToken}}"
|
||||
- name: "X-User-Roles"
|
||||
value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Session Management
|
||||
|
||||
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
|
||||
|
||||
### PKCE Support
|
||||
|
||||
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
|
||||
|
||||
PKCE is recommended when:
|
||||
- Your OIDC provider supports it (most modern providers do)
|
||||
- You need an additional layer of security for the authorization code flow
|
||||
- You're concerned about potential authorization code interception attacks
|
||||
|
||||
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
|
||||
|
||||
### Session Duration and Token Refresh
|
||||
|
||||
This middleware aims to provide long-lived user sessions, typically up to 24 hours, by utilizing OIDC refresh tokens.
|
||||
|
||||
**How it works:**
|
||||
- When a user authenticates, the middleware requests an access token and, if available, a refresh token from the OIDC provider.
|
||||
- The access token usually has a short lifespan (e.g., 1 hour).
|
||||
- Before the access token expires (controlled by `refreshGracePeriodSeconds`), the middleware uses the refresh token to obtain a new access token from the provider without requiring the user to log in again.
|
||||
- This process repeats, allowing the session to remain valid for as long as the refresh token is valid (often 24 hours or more, depending on the provider).
|
||||
|
||||
**Provider-Specific Considerations (e.g., Google):**
|
||||
- Some providers, like Google, issue short-lived access tokens (e.g., 1 hour) and require specific configurations for long-term sessions.
|
||||
- To enable session extension beyond the initial token expiry with Google and similar providers, the middleware automatically includes the `offline_access` scope in the authentication request. This scope is necessary to obtain a refresh token.
|
||||
- For Google specifically, the middleware also adds the `prompt=consent` parameter to the initial authorization request. This ensures Google issues a refresh token, which is crucial for extending the session.
|
||||
- If a refresh attempt fails (e.g., the refresh token is revoked or expired), the user will be required to re-authenticate. The middleware includes enhanced error handling and logging for these scenarios.
|
||||
- Ensure your OIDC provider is configured to issue refresh tokens and allows their use for extending sessions. Check your provider's documentation for details on refresh token validity periods.
|
||||
|
||||
### Google OAuth Compatibility Fix
|
||||
|
||||
The middleware includes a specific fix for Google's OAuth implementation, which differs from the standard OIDC specification in how it handles refresh tokens:
|
||||
|
||||
- **Issue**: Google does not support the standard `offline_access` scope for requesting refresh tokens and instead requires special parameters.
|
||||
|
||||
- **Automatic Solution**: The middleware detects Google as the provider based on the issuer URL and:
|
||||
- Uses `access_type=offline` query parameter instead of the `offline_access` scope
|
||||
- Adds `prompt=consent` to ensure refresh tokens are consistently issued
|
||||
- Properly handles token refresh with Google's implementation
|
||||
|
||||
You do not need any special configuration to use Google OAuth - just set `providerURL` to `https://accounts.google.com` and the middleware will automatically apply the proper parameters.
|
||||
|
||||
For detailed information on the Google OAuth fix, see the [dedicated documentation](docs/google-oauth-fix.md).
|
||||
|
||||
### Token Caching and Blacklisting
|
||||
|
||||
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
|
||||
### Templated Headers
|
||||
|
||||
The middleware supports setting custom HTTP headers with values templated from OIDC claims and tokens. This allows you to pass authentication information to downstream services in a flexible, customized format.
|
||||
|
||||
Templates can access the following variables:
|
||||
- `{{.Claims.field}}` - Access individual claims from the ID token (e.g., `{{.Claims.email}}`, `{{.Claims.sub}}`)
|
||||
- `{{.AccessToken}}` - The raw access token string
|
||||
- `{{.IdToken}}` - The raw ID token string (same as AccessToken in most configurations)
|
||||
- `{{.RefreshToken}}` - The raw refresh token string
|
||||
|
||||
**Example configuration:**
|
||||
```yaml
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
- name: "X-User-ID"
|
||||
value: "{{.Claims.sub}}"
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{.AccessToken}}"
|
||||
- name: "X-User-Name"
|
||||
value: "{{.Claims.given_name}} {{.Claims.family_name}}"
|
||||
```
|
||||
|
||||
**Advanced template examples:**
|
||||
|
||||
Conditional logic:
|
||||
```yaml
|
||||
headers:
|
||||
- name: "X-Is-Admin"
|
||||
value: "{{if eq .Claims.role \"admin\"}}true{{else}}false{{end}}"
|
||||
```
|
||||
|
||||
Array handling:
|
||||
```yaml
|
||||
headers:
|
||||
- name: "X-User-Roles"
|
||||
value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
|
||||
```
|
||||
|
||||
**Notes:**
|
||||
- Variable names are case-sensitive (use `.Claims`, not `.claims`)
|
||||
- Missing claims will result in `<no value>` in the header value
|
||||
- The middleware validates templates during startup and logs errors for invalid templates
|
||||
|
||||
### Default Headers Set for Downstream Services
|
||||
|
||||
|
||||
When a user is authenticated, the middleware sets the following headers for downstream services:
|
||||
|
||||
- `X-Forwarded-User`: The user's email address
|
||||
- `X-User-Groups`: Comma-separated list of user groups (if available)
|
||||
- `X-User-Roles`: Comma-separated list of user roles (if available)
|
||||
- `X-Auth-Request-Redirect`: The original request URI
|
||||
- `X-Auth-Request-User`: The user's email address
|
||||
- `X-Auth-Request-Token`: The user's access token
|
||||
|
||||
### Security Headers
|
||||
|
||||
The middleware also sets the following security headers:
|
||||
|
||||
- `X-Frame-Options: DENY`
|
||||
- `X-Content-Type-Options: nosniff`
|
||||
- `X-XSS-Protection: 1; mode=block`
|
||||
- `Referrer-Policy: strict-origin-when-cross-origin`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Logging
|
||||
|
||||
Set the `logLevel` to `debug` to get more detailed logs:
|
||||
|
||||
```yaml
|
||||
logLevel: debug
|
||||
```
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Token verification failed**: Check that your `providerURL` is correct and accessible.
|
||||
2. **Session encryption key too short**: Ensure your `sessionEncryptionKey` is at least 32 bytes long.
|
||||
3. **No matching public key found**: The JWKS endpoint might be unavailable or the token's key ID (kid) doesn't match any key in the JWKS.
|
||||
4. **Access denied: Your email domain is not allowed**: The user's email domain is not in the `allowedUserDomains` list.
|
||||
5. **Access denied: You do not have any of the allowed roles or groups**: The user doesn't have any of the roles or groups specified in `allowedRolesAndGroups`.
|
||||
6. **Google sessions expire after ~1 hour**: If using Google as the OIDC provider and sessions expire prematurely (around 1 hour instead of longer), ensure:
|
||||
- Do NOT manually add the `offline_access` scope. Google rejects this scope as invalid.
|
||||
- The middleware automatically applies the required Google parameters (`access_type=offline` and `prompt=consent`).
|
||||
- Your Google Cloud OAuth consent screen is set to "External" and "Production" mode. "Testing" mode often limits refresh token validity.
|
||||
- Verify you're using a version of the middleware that includes the Google OAuth compatibility fix.
|
||||
- For more details, see the [Google OAuth Compatibility Fix](#google-oauth-compatibility-fix) section or the [detailed documentation](docs/google-oauth-fix.md).
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
### TODO / wishlist
|
||||
|
||||
- [] Improve test coverage
|
||||
- [x] Improve caching mechanism
|
||||
- [x] Add automatic release and semver generation
|
||||
@@ -0,0 +1,79 @@
|
||||
package traefikoidc
|
||||
|
||||
import "time"
|
||||
|
||||
// BackgroundTask represents a recurring task that runs in the background
|
||||
type BackgroundTask struct {
|
||||
stopChan chan struct{}
|
||||
taskFunc func()
|
||||
logger *Logger
|
||||
name string
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewBackgroundTask creates a new background task
|
||||
func NewBackgroundTask(name string, interval time.Duration, taskFunc func(), logger *Logger) *BackgroundTask {
|
||||
return &BackgroundTask{
|
||||
name: name,
|
||||
interval: interval,
|
||||
stopChan: make(chan struct{}),
|
||||
taskFunc: taskFunc,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the background task execution
|
||||
func (bt *BackgroundTask) Start() {
|
||||
go bt.run()
|
||||
}
|
||||
|
||||
// Stop terminates the background task
|
||||
func (bt *BackgroundTask) Stop() {
|
||||
close(bt.stopChan)
|
||||
}
|
||||
|
||||
// run is the main execution loop for the background task
|
||||
func (bt *BackgroundTask) run() {
|
||||
ticker := time.NewTicker(bt.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
bt.logger.Debug("Starting background task: %s", bt.name)
|
||||
|
||||
// Run task immediately on startup
|
||||
bt.taskFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
bt.taskFunc()
|
||||
case <-bt.stopChan:
|
||||
bt.logger.Debug("Stopping background task: %s", bt.name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// autoCleanupRoutine periodically calls the provided cleanup function.
|
||||
// It starts a ticker with the given interval and executes the cleanup function
|
||||
// on each tick. The routine stops gracefully when a signal is received on the
|
||||
// stop channel. This is typically used for background cleanup tasks like
|
||||
// expiring cache entries.
|
||||
//
|
||||
// Parameters:
|
||||
// - interval: The time duration between cleanup calls.
|
||||
// - stop: A channel used to signal the routine to stop. Receiving any value will terminate the loop.
|
||||
// - cleanup: The function to call periodically for cleanup tasks.
|
||||
//
|
||||
// Deprecated: Use BackgroundTask instead.
|
||||
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cleanup()
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAutoCleanupRoutine(t *testing.T) {
|
||||
var counter int32
|
||||
cleanupFunc := func() {
|
||||
atomic.AddInt32(&counter, 1)
|
||||
}
|
||||
stop := make(chan struct{})
|
||||
go autoCleanupRoutine(50*time.Millisecond, stop, cleanupFunc)
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
close(stop)
|
||||
|
||||
if atomic.LoadInt32(&counter) < 3 {
|
||||
t.Errorf("Expected cleanup to be called at least 3 times, got %d", counter)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,387 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// mockTraefikOidc extends TraefikOidc to override JWT verification for testing
|
||||
type mockTraefikOidc struct {
|
||||
*TraefikOidc
|
||||
}
|
||||
|
||||
// Override VerifyToken to avoid JWKS lookup in tests
|
||||
func (m *mockTraefikOidc) VerifyToken(token string) error {
|
||||
// Cache test claims to avoid "claims not found" errors
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
m.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil // Always succeed for testing
|
||||
}
|
||||
|
||||
// Override VerifyJWTSignatureAndClaims to avoid JWKS lookup in tests
|
||||
func (m *mockTraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
// Cache test claims to avoid "claims not found" errors
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
m.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil // Always succeed for testing
|
||||
}
|
||||
|
||||
func TestAzureOIDCRegression(t *testing.T) {
|
||||
// Create a mocked TraefikOidc instance configured for Azure AD
|
||||
mockLogger := NewLogger("debug")
|
||||
|
||||
// Configure for Azure AD provider
|
||||
baseOidc := &TraefikOidc{
|
||||
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
authURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/authorize",
|
||||
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
|
||||
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
refreshGracePeriod: 60 * time.Second,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Add rate limiter
|
||||
logger: mockLogger,
|
||||
httpClient: createDefaultHTTPClient(), // Add HTTP client
|
||||
jwkCache: &JWKCache{}, // Add JWK cache
|
||||
tokenCache: NewTokenCache(),
|
||||
tokenBlacklist: NewCache(),
|
||||
allowedUserDomains: make(map[string]struct{}),
|
||||
allowedUsers: make(map[string]struct{}),
|
||||
allowedRolesAndGroups: make(map[string]struct{}),
|
||||
excludedURLs: make(map[string]struct{}),
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
|
||||
// Create the mock wrapper
|
||||
tOidc := &mockTraefikOidc{TraefikOidc: baseOidc}
|
||||
|
||||
// Initialize session manager
|
||||
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, mockLogger)
|
||||
tOidc.sessionManager = sessionManager
|
||||
|
||||
// Mock the JWT verification to avoid JWKS lookup issues
|
||||
tOidc.tokenVerifier = &mockTokenVerifier{
|
||||
verifyFunc: func(token string) error {
|
||||
// For test tokens, always return success and cache claims
|
||||
if strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
|
||||
// Cache test claims for JWT tokens
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil
|
||||
}
|
||||
// For opaque tokens (non-JWT format), return success
|
||||
if !strings.Contains(token, ".") || strings.Count(token, ".") != 2 {
|
||||
return nil
|
||||
}
|
||||
// For JWT tokens, cache basic claims to avoid cache lookup issues
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil // Always succeed for test purposes
|
||||
},
|
||||
}
|
||||
|
||||
// Mock JWT verifier to avoid JWKS lookup
|
||||
tOidc.jwtVerifier = &mockJWTVerifier{
|
||||
verifyFunc: func(jwt *JWT, token string) error {
|
||||
// Also cache claims here to ensure they're available
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil // Always succeed
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Azure provider detection works correctly", func(t *testing.T) {
|
||||
if !tOidc.isAzureProvider() {
|
||||
t.Error("Azure provider should be detected for Azure AD issuer URL")
|
||||
}
|
||||
|
||||
if tOidc.isGoogleProvider() {
|
||||
t.Error("Google provider should not be detected for Azure AD issuer URL")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure auth URL includes correct parameters", func(t *testing.T) {
|
||||
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Check that response_mode=query was added for Azure
|
||||
if !strings.Contains(authURL, "response_mode=query") {
|
||||
t.Errorf("response_mode=query not added to Azure auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Verify offline_access scope is included for Azure providers
|
||||
if !strings.Contains(authURL, "offline_access") {
|
||||
t.Errorf("offline_access scope not included in Azure auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Verify Azure doesn't get Google-specific parameters
|
||||
if strings.Contains(authURL, "access_type=offline") {
|
||||
t.Errorf("access_type=offline incorrectly added to Azure auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
if strings.Contains(authURL, "prompt=consent") {
|
||||
t.Errorf("prompt=consent incorrectly added to Azure auth URL: %s", authURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure access token validation takes priority", func(t *testing.T) {
|
||||
// Create a request and session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
session, _ := tOidc.sessionManager.GetSession(req)
|
||||
|
||||
// Set up session with Azure-style tokens
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
|
||||
// Create a valid JWT access token for testing
|
||||
accessTokenClaims := map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
accessToken, _ := createMockJWT(accessTokenClaims)
|
||||
session.SetAccessToken(accessToken)
|
||||
|
||||
// Create an invalid/expired ID token
|
||||
idTokenClaims := map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
idToken, _ := createMockJWT(idTokenClaims)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
// Mock the token verification to simulate Azure behavior
|
||||
originalTokenVerifier := tOidc.tokenVerifier
|
||||
tOidc.tokenVerifier = &mockTokenVerifier{
|
||||
verifyFunc: func(token string) error {
|
||||
if token == accessToken {
|
||||
// Access token validation succeeds - cache claims
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil
|
||||
}
|
||||
if token == idToken {
|
||||
// ID token validation fails (expired) - don't cache
|
||||
return newMockError("token has expired")
|
||||
}
|
||||
return newMockError("token validation failed")
|
||||
},
|
||||
}
|
||||
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
|
||||
|
||||
// Test Azure-specific validation
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
|
||||
|
||||
// Azure should prioritize access token, so even with expired ID token,
|
||||
// user should still be authenticated since access token is valid
|
||||
if !authenticated {
|
||||
t.Error("Azure user should be authenticated when access token is valid, even if ID token is expired")
|
||||
}
|
||||
|
||||
if expired {
|
||||
t.Error("Azure session should not be marked as expired when access token is valid")
|
||||
}
|
||||
|
||||
// May need refresh if we want to get a fresh ID token
|
||||
if !needsRefresh {
|
||||
t.Log("Azure session may not need immediate refresh if access token is still valid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure handles opaque access tokens gracefully", func(t *testing.T) {
|
||||
// Create a request and session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
session, _ := tOidc.sessionManager.GetSession(req)
|
||||
|
||||
// Set up session with opaque access token (non-JWT)
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(ValidAccessToken)
|
||||
|
||||
// Create a valid ID token for claims extraction
|
||||
idTokenClaims := map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
idToken, _ := createMockJWT(idTokenClaims)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
// Mock the token verification
|
||||
originalTokenVerifier := tOidc.tokenVerifier
|
||||
tOidc.tokenVerifier = &mockTokenVerifier{
|
||||
verifyFunc: func(token string) error {
|
||||
if token == idToken {
|
||||
// ID token is valid - cache claims
|
||||
testClaims := map[string]any{
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
tOidc.tokenCache.Set(token, testClaims, time.Hour)
|
||||
return nil
|
||||
}
|
||||
return newMockError("token validation failed")
|
||||
},
|
||||
}
|
||||
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
|
||||
|
||||
// Test Azure-specific validation with opaque token
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
|
||||
|
||||
// Azure should handle opaque access tokens gracefully
|
||||
if !authenticated {
|
||||
t.Error("Azure user should be authenticated with opaque access token")
|
||||
}
|
||||
|
||||
if expired {
|
||||
t.Error("Azure session should not be expired with valid tokens")
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
t.Log("Azure session with opaque token may signal refresh to get JWT tokens")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure CSRF handling during token validation failures", func(t *testing.T) {
|
||||
// Create a request and session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
session, _ := tOidc.sessionManager.GetSession(req)
|
||||
|
||||
// Set up session with CSRF token (simulating ongoing auth flow)
|
||||
session.SetCSRF("test-csrf-token-123")
|
||||
session.SetNonce("test-nonce-456")
|
||||
session.SetAuthenticated(false) // Not yet authenticated
|
||||
|
||||
// Save session to simulate real scenario
|
||||
session.Save(req, rw)
|
||||
|
||||
// Mock token verification to always fail (simulating Azure token issues)
|
||||
originalTokenVerifier := tOidc.tokenVerifier
|
||||
tOidc.tokenVerifier = &mockTokenVerifier{
|
||||
verifyFunc: func(token string) error {
|
||||
return newMockError("azure token validation failed")
|
||||
},
|
||||
}
|
||||
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
|
||||
|
||||
// Test that CSRF is preserved during Azure validation failures
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
|
||||
|
||||
// Should not be authenticated due to validation failure
|
||||
if authenticated {
|
||||
t.Error("Should not be authenticated when token validation fails")
|
||||
}
|
||||
|
||||
// Should be marked as expired since no tokens work
|
||||
if !expired && !needsRefresh {
|
||||
t.Error("Should be marked as needing refresh or expired when validation fails")
|
||||
}
|
||||
|
||||
// Verify CSRF token is still preserved in session
|
||||
if session.GetCSRF() != "test-csrf-token-123" {
|
||||
t.Error("CSRF token should be preserved during Azure token validation failures")
|
||||
}
|
||||
|
||||
if session.GetNonce() != "test-nonce-456" {
|
||||
t.Error("Nonce should be preserved during Azure token validation failures")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// createMockJWT creates a basic JWT token for testing purposes
|
||||
func createMockJWT(claims map[string]any) (string, error) {
|
||||
// Simple mock JWT - in real tests you'd use a proper JWT library
|
||||
// For this test, we'll create a basic three-part token structure
|
||||
header := "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0" // {"alg":"RS256","kid":"test-key-id","typ":"JWT"}
|
||||
|
||||
// Create a simple payload with test claims
|
||||
payload := "eyJpc3MiOiJ0ZXN0LWlzc3VlciIsImF1ZCI6InRlc3QtY2xpZW50LWlkIiwiZXhwIjoxNjM4MzYwMDAwLCJpYXQiOjE2MzgzNTY0MDAsInN1YiI6InVzZXIxMjMiLCJlbWFpbCI6InVzZXJAZXhhbXBsZS5jb20ifQ" // Basic claims
|
||||
|
||||
signature := "test-signature"
|
||||
|
||||
return header + "." + payload + "." + signature, nil
|
||||
}
|
||||
|
||||
// Mock error type for testing
|
||||
type mockError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *mockError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func newMockError(message string) error {
|
||||
return &mockError{message: message}
|
||||
}
|
||||
|
||||
// Mock token verifier for testing
|
||||
type mockTokenVerifier struct {
|
||||
verifyFunc func(token string) error
|
||||
}
|
||||
|
||||
func (m *mockTokenVerifier) VerifyToken(token string) error {
|
||||
if m.verifyFunc != nil {
|
||||
return m.verifyFunc(token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mock JWT verifier for testing
|
||||
type mockJWTVerifier struct {
|
||||
verifyFunc func(jwt *JWT, token string) error
|
||||
}
|
||||
|
||||
func (m *mockJWTVerifier) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
if m.verifyFunc != nil {
|
||||
return m.verifyFunc(jwt, token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,69 +1,229 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheItem represents an item in the cache
|
||||
// CacheItem represents an item stored in the cache with its associated metadata.
|
||||
type CacheItem struct {
|
||||
Value interface{}
|
||||
// Value is the cached data of any type.
|
||||
Value any
|
||||
|
||||
// ExpiresAt is the timestamp when this item should be considered expired.
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// Cache is a simple in-memory cache
|
||||
// lruEntry represents an entry in the LRU list.
|
||||
type lruEntry struct {
|
||||
key string
|
||||
}
|
||||
|
||||
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
|
||||
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
|
||||
type Cache struct {
|
||||
items map[string]CacheItem
|
||||
mutex sync.RWMutex
|
||||
items map[string]CacheItem
|
||||
order *list.List
|
||||
elems map[string]*list.Element
|
||||
cleanupTask *BackgroundTask
|
||||
logger *Logger
|
||||
maxSize int
|
||||
autoCleanupInterval time.Duration
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCache creates a new Cache
|
||||
// DefaultMaxSize is the default maximum number of items in the cache.
|
||||
const DefaultMaxSize = 500
|
||||
|
||||
// NewCache creates a new empty cache instance with default settings.
|
||||
// It initializes the internal maps and list and sets the default maximum size.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
items: make(map[string]CacheItem),
|
||||
}
|
||||
return NewCacheWithLogger(nil)
|
||||
}
|
||||
|
||||
// Set adds an item to the cache
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
// NewCacheWithLogger creates a new cache with a specified logger
|
||||
func NewCacheWithLogger(logger *Logger) *Cache {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
c := &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
order: list.New(),
|
||||
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||
maxSize: DefaultMaxSize,
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
logger: logger,
|
||||
}
|
||||
c.startAutoCleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
// Set adds or updates an item in the cache with the specified key, value, and expiration duration.
|
||||
// If the key already exists, its value and expiration time are updated, and it's moved
|
||||
// to the most recently used position in the LRU list.
|
||||
// If the key does not exist and the cache is full, the least recently used item is evicted
|
||||
// before adding the new item.
|
||||
// The expiration duration is relative to the time Set is called.
|
||||
func (c *Cache) Set(key string, value any, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
expTime := now.Add(expiration)
|
||||
|
||||
// Update existing item.
|
||||
if _, exists := c.items[key]; exists {
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: expTime,
|
||||
}
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
c.order.MoveToBack(elem)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest item if cache is full.
|
||||
if len(c.items) >= c.maxSize {
|
||||
c.evictOldest()
|
||||
}
|
||||
|
||||
// Add new item.
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: time.Now().Add(expiration),
|
||||
ExpiresAt: expTime,
|
||||
}
|
||||
elem := c.order.PushBack(lruEntry{key: key})
|
||||
c.elems[key] = elem
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
item, found := c.items[key]
|
||||
if !found {
|
||||
// Get retrieves an item from the cache by its key.
|
||||
// If the item exists and has not expired, its value and true are returned.
|
||||
// Accessing an item moves it to the most recently used position in the LRU list.
|
||||
// If the item does not exist or has expired, nil and false are returned, and the
|
||||
// expired item is removed from the cache.
|
||||
func (c *Cache) Get(key string) (any, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check for expiration.
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
delete(c.items, key)
|
||||
c.removeItem(key)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Move item to the back (most recently used).
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
c.order.MoveToBack(elem)
|
||||
}
|
||||
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache
|
||||
// Delete removes an item from the cache by its key.
|
||||
// If the key exists, the corresponding item is removed from the cache storage
|
||||
// and the LRU list.
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.items, key)
|
||||
|
||||
c.removeItem(key)
|
||||
}
|
||||
|
||||
// Cleanup removes expired items from the cache
|
||||
// Cleanup iterates through the cache and removes all items that have expired.
|
||||
// An item is considered expired if the current time is after its ExpiresAt timestamp.
|
||||
// This method is called automatically by the auto-cleanup goroutine, but can also
|
||||
// be called manually.
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, item := range c.items {
|
||||
// Remove items that are expired
|
||||
if now.After(item.ExpiresAt) {
|
||||
delete(c.items, key)
|
||||
c.removeItem(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldest removes the least recently used (oldest) item from the cache.
|
||||
// It first attempts to find and remove an expired item from the front of the LRU list.
|
||||
// If no expired items are found at the front, it removes the absolute oldest item (front of the list).
|
||||
// This method is called internally by Set when the cache reaches its maximum size.
|
||||
// Note: This function assumes the write lock is already held.
|
||||
func (c *Cache) evictOldest() {
|
||||
now := time.Now()
|
||||
elem := c.order.Front()
|
||||
|
||||
// First try to find an expired item from the front
|
||||
for elem != nil {
|
||||
entry := elem.Value.(lruEntry)
|
||||
if item, exists := c.items[entry.key]; exists {
|
||||
if now.After(item.ExpiresAt) {
|
||||
c.removeItem(entry.key)
|
||||
return
|
||||
}
|
||||
}
|
||||
elem = elem.Next()
|
||||
}
|
||||
|
||||
// If no expired items found, remove the oldest item
|
||||
if elem = c.order.Front(); elem != nil {
|
||||
entry := elem.Value.(lruEntry)
|
||||
c.removeItem(entry.key)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxSize changes the maximum number of items the cache can hold.
|
||||
// If the new size is smaller than the current number of items in the cache,
|
||||
// oldest items will be evicted until the cache size is within the new limit.
|
||||
func (c *Cache) SetMaxSize(size int) {
|
||||
if size <= 0 {
|
||||
return // Invalid size, ignore
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.maxSize = size
|
||||
|
||||
// If cache exceeds the new max size, evict oldest items
|
||||
for len(c.items) > c.maxSize {
|
||||
c.evictOldest()
|
||||
}
|
||||
}
|
||||
|
||||
// removeItem removes an item specified by the key from the cache's internal storage (items map)
|
||||
// and its corresponding entry from the LRU list (order list and elems map).
|
||||
// Note: This function assumes the write lock is already held.
|
||||
func (c *Cache) removeItem(key string) {
|
||||
delete(c.items, key)
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
c.order.Remove(elem)
|
||||
delete(c.elems, key)
|
||||
}
|
||||
}
|
||||
|
||||
// startAutoCleanup starts the background task that automatically calls the Cleanup method
|
||||
// at the interval specified by c.autoCleanupInterval.
|
||||
func (c *Cache) startAutoCleanup() {
|
||||
c.cleanupTask = NewBackgroundTask("cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger)
|
||||
c.cleanupTask.Start()
|
||||
}
|
||||
|
||||
// Close stops the automatic cleanup task associated with this cache instance.
|
||||
// It should be called when the cache is no longer needed to prevent resource leaks.
|
||||
func (c *Cache) Close() {
|
||||
if c.cleanupTask != nil {
|
||||
c.cleanupTask.Stop()
|
||||
c.cleanupTask = nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCache_Cleanup(t *testing.T) {
|
||||
c := NewCache()
|
||||
|
||||
// Add some items with different expiration times
|
||||
now := time.Now()
|
||||
pastTime := now.Add(-1 * time.Hour) // Already expired
|
||||
futureTime := now.Add(1 * time.Hour) // Not expired
|
||||
|
||||
// Create test items
|
||||
c.items["expired"] = CacheItem{
|
||||
Value: "expired-value",
|
||||
ExpiresAt: pastTime,
|
||||
}
|
||||
|
||||
c.items["valid"] = CacheItem{
|
||||
Value: "valid-value",
|
||||
ExpiresAt: futureTime,
|
||||
}
|
||||
|
||||
// Store original elements in the order list to match items
|
||||
c.elems["expired"] = c.order.PushBack(lruEntry{key: "expired"})
|
||||
c.elems["valid"] = c.order.PushBack(lruEntry{key: "valid"})
|
||||
|
||||
// Call cleanup, which should only remove expired items
|
||||
c.Cleanup()
|
||||
|
||||
// Check that only the expired item was removed
|
||||
if _, exists := c.items["expired"]; exists {
|
||||
t.Error("Expired item was not removed by Cleanup()")
|
||||
}
|
||||
|
||||
if _, exists := c.items["valid"]; !exists {
|
||||
t.Error("Valid item was incorrectly removed by Cleanup()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_SetMaxSize(t *testing.T) {
|
||||
c := NewCache()
|
||||
|
||||
// Set a lower max size
|
||||
originalMaxSize := c.maxSize
|
||||
newMaxSize := 3
|
||||
|
||||
// Add more items than the new max size
|
||||
for i := range originalMaxSize {
|
||||
key := "key" + string(rune('A'+i))
|
||||
c.Set(key, i, 1*time.Hour)
|
||||
}
|
||||
|
||||
// Verify items were added
|
||||
if len(c.items) != originalMaxSize {
|
||||
t.Errorf("Expected %d items before SetMaxSize, got %d", originalMaxSize, len(c.items))
|
||||
}
|
||||
|
||||
// Change the max size to a smaller value
|
||||
c.SetMaxSize(newMaxSize)
|
||||
|
||||
// Check that the cache was reduced to the new max size
|
||||
if len(c.items) > newMaxSize {
|
||||
t.Errorf("Cache size %d exceeds new max size %d after SetMaxSize", len(c.items), newMaxSize)
|
||||
}
|
||||
|
||||
if c.maxSize != newMaxSize {
|
||||
t.Errorf("Cache maxSize not updated, expected %d, got %d", newMaxSize, c.maxSize)
|
||||
}
|
||||
|
||||
// Check that the oldest items were evicted (should keep "keyC", "keyD", "keyE", etc.)
|
||||
if _, exists := c.items["keyA"]; exists {
|
||||
t.Error("Expected oldest item 'keyA' to be evicted, but it still exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWKCache_WithInternalCache(t *testing.T) {
|
||||
cache := NewJWKCache()
|
||||
|
||||
// Check that the internal cache is properly initialized
|
||||
if cache.internalCache == nil {
|
||||
t.Error("internalCache field was not initialized")
|
||||
}
|
||||
|
||||
// Test max size configuration
|
||||
testSize := 50
|
||||
cache.SetMaxSize(testSize)
|
||||
|
||||
if cache.maxSize != testSize {
|
||||
t.Errorf("JWKCache maxSize not updated, expected %d, got %d", testSize, cache.maxSize)
|
||||
}
|
||||
|
||||
if cache.internalCache.maxSize != testSize {
|
||||
t.Errorf("internalCache maxSize not updated, expected %d, got %d", testSize, cache.internalCache.maxSize)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
# Google OAuth Integration Fix
|
||||
|
||||
## Problem Overview
|
||||
|
||||
The Traefik OIDC plugin encountered an authentication issue when using Google as an OAuth provider. Authentication would fail with the following error:
|
||||
|
||||
```
|
||||
Some requested scopes were invalid. {valid=[openid, https://www.googleapis.com/auth/userinfo.email, https://www.googleapis.com/auth/userinfo.profile], invalid=[offline_access]}
|
||||
```
|
||||
|
||||
This occurred because Google's OAuth implementation differs from the standard OIDC specification in how it handles refresh tokens and offline access.
|
||||
|
||||
## Technical Details of the Issue
|
||||
|
||||
### Standard OIDC Provider Behavior
|
||||
|
||||
Most OpenID Connect (OIDC) providers follow the standard specification, where:
|
||||
- To obtain a refresh token, clients include the `offline_access` scope in their authorization request
|
||||
- This allows authenticated sessions to persist beyond the initial access token expiration
|
||||
|
||||
### Google's Non-Standard Approach
|
||||
|
||||
Google's OAuth implementation deviates from the standard by:
|
||||
1. Not supporting the `offline_access` scope, instead rejecting it as an invalid scope
|
||||
2. Requiring the `access_type=offline` query parameter for requesting refresh tokens
|
||||
3. Needing the `prompt=consent` parameter to consistently issue refresh tokens (especially for repeat authentications)
|
||||
|
||||
This difference caused the plugin to fail when configured for Google OAuth, as it was using a standard approach that didn't work with Google's implementation.
|
||||
|
||||
## Solution Implementation
|
||||
|
||||
The fix involved modifying the authentication flow to specifically handle Google providers:
|
||||
|
||||
1. **Google Provider Detection**: Added code to detect if the OIDC provider is Google based on the issuer URL:
|
||||
|
||||
```go
|
||||
// Check if we're dealing with a Google OIDC provider
|
||||
isGoogleProvider := strings.Contains(t.issuerURL, "google") ||
|
||||
strings.Contains(t.issuerURL, "accounts.google.com")
|
||||
```
|
||||
|
||||
2. **Provider-Specific Auth URL Building**: Modified the `buildAuthURL` function to handle Google and non-Google providers differently:
|
||||
|
||||
```go
|
||||
// Handle offline access differently for Google vs other providers
|
||||
if isGoogleProvider {
|
||||
// For Google, use access_type=offline parameter instead of offline_access scope
|
||||
params.Set("access_type", "offline")
|
||||
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
|
||||
|
||||
// Add prompt=consent for Google to ensure refresh token is issued
|
||||
params.Set("prompt", "consent")
|
||||
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
} else {
|
||||
// For non-Google providers, use the offline_access scope
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. **Token Refresh Enhancement**: Improved the token refresh logic to better handle Google's behavior, particularly when refresh tokens aren't returned in refresh responses (as Google often uses the same refresh token for multiple requests).
|
||||
|
||||
## Why This Approach Works
|
||||
|
||||
This solution aligns with Google's OAuth 2.0 documentation which specifies:
|
||||
|
||||
1. **Access Type Parameter**: Google's [OAuth 2.0 documentation](https://developers.google.com/identity/protocols/oauth2/web-server#offline) states that to request a refresh token, applications must include `access_type=offline` in the authorization request.
|
||||
|
||||
2. **Prompt Parameter**: The [`prompt=consent`](https://developers.google.com/identity/protocols/oauth2/web-server#forceapprovalprompt) parameter forces the consent screen to appear, ensuring a refresh token is issued even if the user has previously granted access.
|
||||
|
||||
3. **Scope Validation**: Google strictly validates scopes and rejects non-standard ones like `offline_access`, instead relying on the `access_type` parameter to indicate whether a refresh token should be issued.
|
||||
|
||||
By adapting to these Google-specific requirements, the OIDC plugin can now seamlessly work with both standard OIDC providers and Google's OAuth implementation.
|
||||
|
||||
## Testing and Verification
|
||||
|
||||
Comprehensive tests were implemented to verify the solution:
|
||||
|
||||
1. **Provider Detection Test**: Ensures the code correctly identifies Google providers and applies the appropriate parameters.
|
||||
|
||||
2. **Auth URL Parameter Tests**: Verifies that:
|
||||
- For Google providers: `access_type=offline` and `prompt=consent` are included; `offline_access` scope is NOT included
|
||||
- For non-Google providers: `offline_access` scope IS included; `access_type` parameter is NOT added
|
||||
|
||||
3. **Token Refresh Tests**: Validates that Google's token refresh process works correctly, including the preservation of refresh tokens when Google doesn't return a new one.
|
||||
|
||||
4. **Integration Test**: Tests the complete authentication flow with a mocked Google provider to ensure all components work together seamlessly.
|
||||
|
||||
Sample test case (simplified):
|
||||
|
||||
```go
|
||||
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
|
||||
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
|
||||
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Check that access_type=offline was added (not offline_access scope for Google)
|
||||
if !strings.Contains(authURL, "access_type=offline") {
|
||||
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Verify offline_access scope is NOT included for Google providers
|
||||
if strings.Contains(authURL, "offline_access") {
|
||||
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Check that prompt=consent was added
|
||||
if !strings.Contains(authURL, "prompt=consent") {
|
||||
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## Usage Guidance for Developers
|
||||
|
||||
When configuring the Traefik OIDC middleware for Google:
|
||||
|
||||
1. **Provider URL**: Use `https://accounts.google.com` as the `providerURL` value
|
||||
|
||||
2. **Client Configuration**: Create OAuth 2.0 credentials in the Google Cloud Console:
|
||||
- Configure the authorized redirect URI to match your `callbackURL` setting
|
||||
- Ensure your OAuth consent screen is properly configured (especially if you want long-lived refresh tokens)
|
||||
|
||||
3. **Configuration Example**:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-google
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-google-client-id.apps.googleusercontent.com
|
||||
clientSecret: your-google-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
# Note: DO NOT manually add offline_access scope for Google
|
||||
# The middleware handles this automatically and correctly
|
||||
```
|
||||
|
||||
4. **Troubleshooting**: If sessions still expire prematurely with Google (typically after 1 hour):
|
||||
- Ensure your Google Cloud OAuth consent screen is set to "External" and "Production" mode (not "Testing" mode, which limits refresh token validity)
|
||||
- Review your application logs with `logLevel: debug` to check for refresh token errors
|
||||
- Verify you're using a version of the middleware that includes this fix
|
||||
|
||||
## Conclusion
|
||||
|
||||
This fix ensures that the Traefik OIDC plugin works seamlessly with Google's OAuth implementation without requiring users to make provider-specific configuration changes. The middleware now intelligently adapts to the provider's requirements, making it more robust and user-friendly while maintaining compatibility with the standard OIDC specification for other providers.
|
||||
@@ -0,0 +1,844 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrorRecoveryMechanism defines the common interface for all error recovery mechanisms
|
||||
type ErrorRecoveryMechanism interface {
|
||||
// ExecuteWithContext executes a function with error recovery
|
||||
ExecuteWithContext(ctx context.Context, fn func() error) error
|
||||
// GetMetrics returns metrics about the error recovery mechanism
|
||||
GetMetrics() map[string]any
|
||||
// Reset resets the state of the error recovery mechanism
|
||||
Reset()
|
||||
// IsAvailable returns whether the mechanism is available for use
|
||||
IsAvailable() bool
|
||||
}
|
||||
|
||||
// BaseRecoveryMechanism provides common functionality for error recovery mechanisms
|
||||
type BaseRecoveryMechanism struct {
|
||||
startTime time.Time
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
logger *Logger
|
||||
name string
|
||||
totalRequests int64
|
||||
totalFailures int64
|
||||
totalSuccesses int64
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBaseRecoveryMechanism creates a new base recovery mechanism
|
||||
func NewBaseRecoveryMechanism(name string, logger *Logger) *BaseRecoveryMechanism {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
return &BaseRecoveryMechanism{
|
||||
name: name,
|
||||
logger: logger,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordRequest records a request to the error recovery mechanism
|
||||
func (b *BaseRecoveryMechanism) RecordRequest() {
|
||||
atomic.AddInt64(&b.totalRequests, 1)
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
func (b *BaseRecoveryMechanism) RecordSuccess() {
|
||||
atomic.AddInt64(&b.totalSuccesses, 1)
|
||||
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
b.lastSuccessTime = time.Now()
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
func (b *BaseRecoveryMechanism) RecordFailure() {
|
||||
atomic.AddInt64(&b.totalFailures, 1)
|
||||
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
b.lastFailureTime = time.Now()
|
||||
}
|
||||
|
||||
// GetBaseMetrics returns base metrics common to all recovery mechanisms
|
||||
func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]any {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
|
||||
metrics := map[string]any{
|
||||
"total_requests": atomic.LoadInt64(&b.totalRequests),
|
||||
"total_failures": atomic.LoadInt64(&b.totalFailures),
|
||||
"total_successes": atomic.LoadInt64(&b.totalSuccesses),
|
||||
"uptime_seconds": time.Since(b.startTime).Seconds(),
|
||||
"name": b.name,
|
||||
}
|
||||
|
||||
if !b.lastFailureTime.IsZero() {
|
||||
metrics["last_failure_time"] = b.lastFailureTime.Format(time.RFC3339)
|
||||
metrics["seconds_since_last_failure"] = time.Since(b.lastFailureTime).Seconds()
|
||||
}
|
||||
|
||||
if !b.lastSuccessTime.IsZero() {
|
||||
metrics["last_success_time"] = b.lastSuccessTime.Format(time.RFC3339)
|
||||
metrics["seconds_since_last_success"] = time.Since(b.lastSuccessTime).Seconds()
|
||||
}
|
||||
|
||||
// Calculate success rate
|
||||
if metrics["total_requests"].(int64) > 0 {
|
||||
successRate := float64(metrics["total_successes"].(int64)) / float64(metrics["total_requests"].(int64))
|
||||
metrics["success_rate"] = successRate
|
||||
} else {
|
||||
metrics["success_rate"] = 1.0 // Default to 100% if no requests
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// LogInfo logs an informational message
|
||||
func (b *BaseRecoveryMechanism) LogInfo(format string, args ...any) {
|
||||
if b.logger != nil {
|
||||
b.logger.Infof("%s: "+format, append([]any{b.name}, args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogError logs an error message
|
||||
func (b *BaseRecoveryMechanism) LogError(format string, args ...any) {
|
||||
if b.logger != nil {
|
||||
b.logger.Errorf("%s: "+format, append([]any{b.name}, args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogDebug logs a debug message
|
||||
func (b *BaseRecoveryMechanism) LogDebug(format string, args ...any) {
|
||||
if b.logger != nil {
|
||||
b.logger.Debugf("%s: "+format, append([]any{b.name}, args...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerState represents the current state of a circuit breaker
|
||||
type CircuitBreakerState int
|
||||
|
||||
const (
|
||||
// CircuitBreakerClosed - normal operation, requests are allowed
|
||||
CircuitBreakerClosed CircuitBreakerState = iota
|
||||
// CircuitBreakerOpen - circuit is open, requests are rejected
|
||||
CircuitBreakerOpen
|
||||
// CircuitBreakerHalfOpen - testing if service has recovered
|
||||
CircuitBreakerHalfOpen
|
||||
)
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for external service calls
|
||||
type CircuitBreaker struct {
|
||||
*BaseRecoveryMechanism
|
||||
maxFailures int
|
||||
timeout time.Duration
|
||||
resetTimeout time.Duration
|
||||
state CircuitBreakerState
|
||||
failures int64
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration for circuit breakers
|
||||
type CircuitBreakerConfig struct {
|
||||
MaxFailures int `json:"max_failures"`
|
||||
Timeout time.Duration `json:"timeout"`
|
||||
ResetTimeout time.Duration `json:"reset_timeout"`
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns default circuit breaker configuration
|
||||
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||
return CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
Timeout: 30 * time.Second,
|
||||
ResetTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker with the given configuration
|
||||
func NewCircuitBreaker(config CircuitBreakerConfig, logger *Logger) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
BaseRecoveryMechanism: NewBaseRecoveryMechanism("circuit-breaker", logger),
|
||||
maxFailures: config.MaxFailures,
|
||||
timeout: config.Timeout,
|
||||
resetTimeout: config.ResetTimeout,
|
||||
state: CircuitBreakerClosed,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteWithContext implements the ErrorRecoveryMechanism interface
|
||||
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
cb.RecordRequest()
|
||||
|
||||
// Check if circuit breaker allows the request
|
||||
if !cb.allowRequest() {
|
||||
return fmt.Errorf("circuit breaker is open")
|
||||
}
|
||||
|
||||
// Execute the function
|
||||
err := fn()
|
||||
// Record the result
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
cb.RecordFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
cb.recordSuccess()
|
||||
cb.RecordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute is the original method for backward compatibility
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
return cb.ExecuteWithContext(context.Background(), fn)
|
||||
}
|
||||
|
||||
// allowRequest checks if the circuit breaker allows the request
|
||||
func (cb *CircuitBreaker) allowRequest() bool {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerClosed:
|
||||
return true
|
||||
|
||||
case CircuitBreakerOpen:
|
||||
// Check if timeout has passed
|
||||
if now.Sub(cb.lastFailureTime) > cb.timeout {
|
||||
cb.state = CircuitBreakerHalfOpen
|
||||
cb.logger.Infof("Circuit breaker transitioning to half-open state")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failure and potentially opens the circuit
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.failures++
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerClosed:
|
||||
if cb.failures >= int64(cb.maxFailures) {
|
||||
cb.state = CircuitBreakerOpen
|
||||
cb.LogError("Circuit breaker opened after %d failures", cb.failures)
|
||||
}
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
// Go back to open state on any failure in half-open
|
||||
cb.state = CircuitBreakerOpen
|
||||
cb.LogError("Circuit breaker returned to open state after failure in half-open")
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess records a success and potentially closes the circuit
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerHalfOpen:
|
||||
// Reset failures and close circuit on success in half-open
|
||||
cb.failures = 0
|
||||
cb.state = CircuitBreakerClosed
|
||||
cb.LogInfo("Circuit breaker closed after successful request in half-open state")
|
||||
|
||||
case CircuitBreakerClosed:
|
||||
// Reset failure count on success
|
||||
cb.failures = 0
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker
|
||||
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to its initial state
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.state = CircuitBreakerClosed
|
||||
atomic.StoreInt64(&cb.failures, 0)
|
||||
cb.LogInfo("Circuit breaker has been reset")
|
||||
}
|
||||
|
||||
// IsAvailable returns whether the circuit breaker is allowing requests
|
||||
func (cb *CircuitBreaker) IsAvailable() bool {
|
||||
return cb.allowRequest()
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the circuit breaker
|
||||
func (cb *CircuitBreaker) GetMetrics() map[string]any {
|
||||
cb.mutex.RLock()
|
||||
state := cb.state
|
||||
failures := cb.failures
|
||||
cb.mutex.RUnlock()
|
||||
|
||||
metrics := cb.GetBaseMetrics()
|
||||
|
||||
// Add circuit breaker specific metrics
|
||||
stateStr := "unknown"
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
stateStr = "closed"
|
||||
case CircuitBreakerOpen:
|
||||
stateStr = "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
stateStr = "half-open"
|
||||
}
|
||||
|
||||
metrics["state"] = stateStr
|
||||
metrics["max_failures"] = cb.maxFailures
|
||||
metrics["current_failures"] = failures
|
||||
metrics["timeout_ms"] = cb.timeout.Milliseconds()
|
||||
metrics["reset_timeout_ms"] = cb.resetTimeout.Milliseconds()
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// RetryConfig holds configuration for retry mechanisms
|
||||
type RetryConfig struct {
|
||||
RetryableErrors []string `json:"retryable_errors"`
|
||||
MaxAttempts int `json:"max_attempts"`
|
||||
InitialDelay time.Duration `json:"initial_delay"`
|
||||
MaxDelay time.Duration `json:"max_delay"`
|
||||
BackoffFactor float64 `json:"backoff_factor"`
|
||||
EnableJitter bool `json:"enable_jitter"`
|
||||
}
|
||||
|
||||
// DefaultRetryConfig returns default retry configuration
|
||||
func DefaultRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true,
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RetryExecutor implements retry logic with exponential backoff
|
||||
type RetryExecutor struct {
|
||||
*BaseRecoveryMechanism
|
||||
config RetryConfig
|
||||
}
|
||||
|
||||
// NewRetryExecutor creates a new retry executor
|
||||
func NewRetryExecutor(config RetryConfig, logger *Logger) *RetryExecutor {
|
||||
return &RetryExecutor{
|
||||
BaseRecoveryMechanism: NewBaseRecoveryMechanism("retry-executor", logger),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteWithContext implements the ErrorRecoveryMechanism interface
|
||||
func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
re.RecordRequest()
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= re.config.MaxAttempts; attempt++ {
|
||||
// Execute the function
|
||||
err := fn()
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
re.LogInfo("Operation succeeded on attempt %d", attempt)
|
||||
}
|
||||
re.RecordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// Check if error is retryable
|
||||
if !re.isRetryableError(err) {
|
||||
re.LogDebug("Non-retryable error on attempt %d: %v", attempt, err)
|
||||
re.RecordFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
// Don't wait after the last attempt
|
||||
if attempt == re.config.MaxAttempts {
|
||||
re.RecordFailure()
|
||||
break
|
||||
}
|
||||
|
||||
// Calculate delay with exponential backoff
|
||||
delay := re.calculateDelay(attempt)
|
||||
re.LogDebug("Retrying operation after %v (attempt %d/%d): %v",
|
||||
delay, attempt, re.config.MaxAttempts, err)
|
||||
|
||||
// Wait with context cancellation support
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
re.RecordFailure()
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
// Continue to next attempt
|
||||
}
|
||||
}
|
||||
|
||||
finalErr := fmt.Errorf("operation failed after %d attempts: %w", re.config.MaxAttempts, lastErr)
|
||||
return finalErr
|
||||
}
|
||||
|
||||
// Execute runs the given function with retry logic (for backward compatibility)
|
||||
func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
return re.ExecuteWithContext(ctx, fn)
|
||||
}
|
||||
|
||||
// isRetryableError checks if an error should trigger a retry
|
||||
func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
// Check against configured retryable errors
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
if contains(errStr, retryableErr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for common network errors using modern Go error handling
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
// Use Timeout() method which is still valid
|
||||
if netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
// Check for specific temporary error patterns instead of deprecated Temporary()
|
||||
errStr := netErr.Error()
|
||||
temporaryPatterns := []string{
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
"network is unreachable",
|
||||
"no route to host",
|
||||
"temporary failure",
|
||||
"try again",
|
||||
"resource temporarily unavailable",
|
||||
}
|
||||
for _, pattern := range temporaryPatterns {
|
||||
if contains(errStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for HTTP status codes that are retryable
|
||||
if httpErr, ok := err.(*HTTPError); ok {
|
||||
return httpErr.StatusCode >= 500 || httpErr.StatusCode == 429
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// calculateDelay calculates the delay for the next retry attempt
|
||||
func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
// Calculate exponential backoff
|
||||
delay := float64(re.config.InitialDelay) * math.Pow(re.config.BackoffFactor, float64(attempt-1))
|
||||
|
||||
// Apply maximum delay limit
|
||||
if delay > float64(re.config.MaxDelay) {
|
||||
delay = float64(re.config.MaxDelay)
|
||||
}
|
||||
|
||||
// Add jitter to prevent thundering herd
|
||||
if re.config.EnableJitter {
|
||||
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0) // ±10% jitter
|
||||
delay += jitter
|
||||
}
|
||||
|
||||
return time.Duration(delay)
|
||||
}
|
||||
|
||||
// Reset resets the retry executor state
|
||||
func (re *RetryExecutor) Reset() {
|
||||
// Nothing to reset for RetryExecutor
|
||||
re.LogDebug("Retry executor reset")
|
||||
}
|
||||
|
||||
// IsAvailable always returns true for RetryExecutor
|
||||
func (re *RetryExecutor) IsAvailable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the retry executor
|
||||
func (re *RetryExecutor) GetMetrics() map[string]any {
|
||||
metrics := re.GetBaseMetrics()
|
||||
|
||||
// Add retry executor specific metrics
|
||||
metrics["max_attempts"] = re.config.MaxAttempts
|
||||
metrics["initial_delay_ms"] = re.config.InitialDelay.Milliseconds()
|
||||
metrics["max_delay_ms"] = re.config.MaxDelay.Milliseconds()
|
||||
metrics["backoff_factor"] = re.config.BackoffFactor
|
||||
metrics["enable_jitter"] = re.config.EnableJitter
|
||||
metrics["retryable_errors"] = re.config.RetryableErrors
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// HTTPError represents an HTTP error with status code
|
||||
type HTTPError struct {
|
||||
Message string
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *HTTPError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Message)
|
||||
}
|
||||
|
||||
// GracefulDegradation implements graceful degradation patterns
|
||||
type GracefulDegradation struct {
|
||||
*BaseRecoveryMechanism
|
||||
fallbacks map[string]func() (any, error)
|
||||
healthChecks map[string]func() bool
|
||||
degradedServices map[string]time.Time
|
||||
config GracefulDegradationConfig
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// GracefulDegradationConfig holds configuration for graceful degradation
|
||||
type GracefulDegradationConfig struct {
|
||||
HealthCheckInterval time.Duration `json:"health_check_interval"`
|
||||
RecoveryTimeout time.Duration `json:"recovery_timeout"`
|
||||
EnableFallbacks bool `json:"enable_fallbacks"`
|
||||
}
|
||||
|
||||
// DefaultGracefulDegradationConfig returns default configuration
|
||||
func DefaultGracefulDegradationConfig() GracefulDegradationConfig {
|
||||
return GracefulDegradationConfig{
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
RecoveryTimeout: 5 * time.Minute,
|
||||
EnableFallbacks: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewGracefulDegradation creates a new graceful degradation manager
|
||||
func NewGracefulDegradation(config GracefulDegradationConfig, logger *Logger) *GracefulDegradation {
|
||||
gd := &GracefulDegradation{
|
||||
BaseRecoveryMechanism: NewBaseRecoveryMechanism("graceful-degradation", logger),
|
||||
fallbacks: make(map[string]func() (any, error)),
|
||||
healthChecks: make(map[string]func() bool),
|
||||
degradedServices: make(map[string]time.Time),
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Start health check routine
|
||||
go gd.startHealthCheckRoutine()
|
||||
|
||||
return gd
|
||||
}
|
||||
|
||||
// RegisterFallback registers a fallback function for a service
|
||||
func (gd *GracefulDegradation) RegisterFallback(serviceName string, fallback func() (any, error)) {
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
gd.fallbacks[serviceName] = fallback
|
||||
}
|
||||
|
||||
// RegisterHealthCheck registers a health check function for a service
|
||||
func (gd *GracefulDegradation) RegisterHealthCheck(serviceName string, healthCheck func() bool) {
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
gd.healthChecks[serviceName] = healthCheck
|
||||
}
|
||||
|
||||
// ExecuteWithContext implements the ErrorRecoveryMechanism interface
|
||||
func (gd *GracefulDegradation) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
gd.RecordRequest()
|
||||
|
||||
// Execute with a simple wrapper
|
||||
_, err := gd.ExecuteWithFallback("default", func() (any, error) {
|
||||
return nil, fn()
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
gd.RecordFailure()
|
||||
} else {
|
||||
gd.RecordSuccess()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecuteWithFallback executes a function with fallback support
|
||||
func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary func() (any, error)) (any, error) {
|
||||
// Check if service is degraded
|
||||
if gd.isServiceDegraded(serviceName) {
|
||||
gd.LogInfo("Service %s is degraded, using fallback", serviceName)
|
||||
return gd.executeFallback(serviceName)
|
||||
}
|
||||
|
||||
// Try primary function
|
||||
result, err := primary()
|
||||
if err != nil {
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded(serviceName)
|
||||
gd.LogError("Service %s failed: %v", serviceName, err)
|
||||
|
||||
// Try fallback if available
|
||||
if gd.config.EnableFallbacks {
|
||||
gd.LogInfo("Using fallback for service %s", serviceName)
|
||||
return gd.executeFallback(serviceName)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// isServiceDegraded checks if a service is currently degraded
|
||||
func (gd *GracefulDegradation) isServiceDegraded(serviceName string) bool {
|
||||
gd.mutex.RLock()
|
||||
defer gd.mutex.RUnlock()
|
||||
|
||||
degradedTime, exists := gd.degradedServices[serviceName]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if recovery timeout has passed
|
||||
if time.Since(degradedTime) > gd.config.RecoveryTimeout {
|
||||
delete(gd.degradedServices, serviceName)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// markServiceDegraded marks a service as degraded
|
||||
func (gd *GracefulDegradation) markServiceDegraded(serviceName string) {
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
|
||||
if _, exists := gd.degradedServices[serviceName]; !exists {
|
||||
gd.LogError("Service %s marked as degraded", serviceName)
|
||||
}
|
||||
|
||||
gd.degradedServices[serviceName] = time.Now()
|
||||
}
|
||||
|
||||
// executeFallback executes the fallback function for a service
|
||||
func (gd *GracefulDegradation) executeFallback(serviceName string) (any, error) {
|
||||
gd.mutex.RLock()
|
||||
fallback, exists := gd.fallbacks[serviceName]
|
||||
gd.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no fallback available for service %s", serviceName)
|
||||
}
|
||||
|
||||
gd.LogInfo("Executing fallback for degraded service %s", serviceName)
|
||||
return fallback()
|
||||
}
|
||||
|
||||
// startHealthCheckRoutine starts the background health check routine
|
||||
func (gd *GracefulDegradation) startHealthCheckRoutine() {
|
||||
healthCheckTask := NewBackgroundTask(
|
||||
"graceful-degradation-health-check",
|
||||
gd.config.HealthCheckInterval,
|
||||
gd.performHealthChecks,
|
||||
gd.BaseRecoveryMechanism.logger,
|
||||
)
|
||||
healthCheckTask.Start()
|
||||
}
|
||||
|
||||
// performHealthChecks runs health checks for all registered services
|
||||
func (gd *GracefulDegradation) performHealthChecks() {
|
||||
gd.mutex.RLock()
|
||||
healthChecks := make(map[string]func() bool)
|
||||
maps.Copy(healthChecks, gd.healthChecks)
|
||||
gd.mutex.RUnlock()
|
||||
|
||||
for serviceName, healthCheck := range healthChecks {
|
||||
if healthCheck() {
|
||||
// Service is healthy, remove from degraded list
|
||||
gd.mutex.Lock()
|
||||
if _, wasDegraded := gd.degradedServices[serviceName]; wasDegraded {
|
||||
delete(gd.degradedServices, serviceName)
|
||||
gd.logger.Infof("Service %s recovered from degraded state", serviceName)
|
||||
}
|
||||
gd.mutex.Unlock()
|
||||
} else {
|
||||
// Service is unhealthy, mark as degraded
|
||||
gd.markServiceDegraded(serviceName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetDegradedServices returns a list of currently degraded services
|
||||
func (gd *GracefulDegradation) GetDegradedServices() []string {
|
||||
gd.mutex.RLock()
|
||||
defer gd.mutex.RUnlock()
|
||||
|
||||
var degraded []string
|
||||
for serviceName := range gd.degradedServices {
|
||||
degraded = append(degraded, serviceName)
|
||||
}
|
||||
|
||||
return degraded
|
||||
}
|
||||
|
||||
// Reset resets the state of all degraded services
|
||||
func (gd *GracefulDegradation) Reset() {
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
|
||||
// Clear degraded services
|
||||
gd.degradedServices = make(map[string]time.Time)
|
||||
gd.LogInfo("Graceful degradation state has been reset")
|
||||
}
|
||||
|
||||
// IsAvailable returns whether the mechanism is available for use
|
||||
func (gd *GracefulDegradation) IsAvailable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the graceful degradation mechanism
|
||||
func (gd *GracefulDegradation) GetMetrics() map[string]any {
|
||||
gd.mutex.RLock()
|
||||
degradedCount := len(gd.degradedServices)
|
||||
|
||||
// Get the names of degraded services
|
||||
degradedServices := make([]string, 0, degradedCount)
|
||||
for service := range gd.degradedServices {
|
||||
degradedServices = append(degradedServices, service)
|
||||
}
|
||||
|
||||
// Get total count of registered fallbacks and health checks
|
||||
fallbackCount := len(gd.fallbacks)
|
||||
healthCheckCount := len(gd.healthChecks)
|
||||
gd.mutex.RUnlock()
|
||||
|
||||
// Get base metrics
|
||||
metrics := gd.GetBaseMetrics()
|
||||
|
||||
// Add graceful degradation specific metrics
|
||||
metrics["degraded_services_count"] = degradedCount
|
||||
metrics["degraded_services"] = degradedServices
|
||||
metrics["registered_fallbacks_count"] = fallbackCount
|
||||
metrics["registered_health_checks_count"] = healthCheckCount
|
||||
metrics["health_check_interval_seconds"] = gd.config.HealthCheckInterval.Seconds()
|
||||
metrics["recovery_timeout_seconds"] = gd.config.RecoveryTimeout.Seconds()
|
||||
metrics["fallbacks_enabled"] = gd.config.EnableFallbacks
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// ErrorRecoveryManager coordinates all error recovery mechanisms
|
||||
type ErrorRecoveryManager struct {
|
||||
circuitBreakers map[string]*CircuitBreaker
|
||||
retryExecutor *RetryExecutor
|
||||
gracefulDegradation *GracefulDegradation
|
||||
logger *Logger
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewErrorRecoveryManager creates a new error recovery manager
|
||||
func NewErrorRecoveryManager(logger *Logger) *ErrorRecoveryManager {
|
||||
return &ErrorRecoveryManager{
|
||||
circuitBreakers: make(map[string]*CircuitBreaker),
|
||||
retryExecutor: NewRetryExecutor(DefaultRetryConfig(), logger),
|
||||
gracefulDegradation: NewGracefulDegradation(DefaultGracefulDegradationConfig(), logger),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCircuitBreaker gets or creates a circuit breaker for a service
|
||||
func (erm *ErrorRecoveryManager) GetCircuitBreaker(serviceName string) *CircuitBreaker {
|
||||
erm.mutex.Lock()
|
||||
defer erm.mutex.Unlock()
|
||||
|
||||
if cb, exists := erm.circuitBreakers[serviceName]; exists {
|
||||
return cb
|
||||
}
|
||||
|
||||
cb := NewCircuitBreaker(DefaultCircuitBreakerConfig(), erm.logger)
|
||||
erm.circuitBreakers[serviceName] = cb
|
||||
return cb
|
||||
}
|
||||
|
||||
// ExecuteWithRecovery executes a function with full error recovery support
|
||||
func (erm *ErrorRecoveryManager) ExecuteWithRecovery(ctx context.Context, serviceName string, fn func() error) error {
|
||||
cb := erm.GetCircuitBreaker(serviceName)
|
||||
|
||||
return erm.retryExecutor.Execute(ctx, func() error {
|
||||
return cb.Execute(fn)
|
||||
})
|
||||
}
|
||||
|
||||
// GetRecoveryMetrics returns metrics for all recovery mechanisms
|
||||
func (erm *ErrorRecoveryManager) GetRecoveryMetrics() map[string]any {
|
||||
erm.mutex.RLock()
|
||||
defer erm.mutex.RUnlock()
|
||||
|
||||
metrics := make(map[string]any)
|
||||
|
||||
// Circuit breaker metrics
|
||||
cbMetrics := make(map[string]any)
|
||||
for name, cb := range erm.circuitBreakers {
|
||||
cbMetrics[name] = cb.GetMetrics()
|
||||
}
|
||||
metrics["circuit_breakers"] = cbMetrics
|
||||
|
||||
// Degraded services
|
||||
metrics["degraded_services"] = erm.gracefulDegradation.GetDegradedServices()
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring (case-insensitive)
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) &&
|
||||
(s == substr ||
|
||||
(len(s) > len(substr) &&
|
||||
(s[:len(substr)] == substr ||
|
||||
s[len(s)-len(substr):] == substr ||
|
||||
containsSubstring(s, substr))))
|
||||
}
|
||||
|
||||
func containsSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,428 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCircuitBreaker(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
config.MaxFailures = 2
|
||||
config.Timeout = 100 * time.Millisecond
|
||||
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
t.Run("Initial state is closed", func(t *testing.T) {
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected initial state to be closed, got %v", cb.GetState())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Successful execution", func(t *testing.T) {
|
||||
err := cb.Execute(func() error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Circuit opens after max failures", func(t *testing.T) {
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < config.MaxFailures; i++ {
|
||||
cb.Execute(func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected circuit to be open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Should reject requests when open
|
||||
err := cb.Execute(func() error {
|
||||
return nil
|
||||
})
|
||||
if err == nil || err.Error() != "circuit breaker is open" {
|
||||
t.Errorf("Expected circuit breaker open error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Circuit transitions to half-open after timeout", func(t *testing.T) {
|
||||
// Wait for timeout
|
||||
time.Sleep(config.Timeout + 10*time.Millisecond)
|
||||
|
||||
// Next request should transition to half-open
|
||||
cb.Execute(func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected circuit to be closed after successful request, got %v", cb.GetState())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Get metrics", func(t *testing.T) {
|
||||
metrics := cb.GetMetrics()
|
||||
if metrics["state"] == nil {
|
||||
t.Error("Expected metrics to contain state")
|
||||
}
|
||||
if metrics["total_requests"] == nil {
|
||||
t.Error("Expected metrics to contain total_requests")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryExecutor(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
config := DefaultRetryConfig()
|
||||
config.MaxAttempts = 3
|
||||
config.InitialDelay = 10 * time.Millisecond
|
||||
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
t.Run("Successful execution on first attempt", func(t *testing.T) {
|
||||
attempts := 0
|
||||
err := re.Execute(context.Background(), func() error {
|
||||
attempts++
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if attempts != 1 {
|
||||
t.Errorf("Expected 1 attempt, got %d", attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Retry on retryable error", func(t *testing.T) {
|
||||
attempts := 0
|
||||
err := re.Execute(context.Background(), func() error {
|
||||
attempts++
|
||||
if attempts < 2 {
|
||||
return errors.New("connection refused")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error after retry, got %v", err)
|
||||
}
|
||||
if attempts != 2 {
|
||||
t.Errorf("Expected 2 attempts, got %d", attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("No retry on non-retryable error", func(t *testing.T) {
|
||||
attempts := 0
|
||||
err := re.Execute(context.Background(), func() error {
|
||||
attempts++
|
||||
return errors.New("non-retryable error")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error to be returned")
|
||||
}
|
||||
if attempts != 1 {
|
||||
t.Errorf("Expected 1 attempt, got %d", attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Max attempts reached", func(t *testing.T) {
|
||||
attempts := 0
|
||||
err := re.Execute(context.Background(), func() error {
|
||||
attempts++
|
||||
return errors.New("timeout")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error after max attempts")
|
||||
}
|
||||
if attempts != config.MaxAttempts {
|
||||
t.Errorf("Expected %d attempts, got %d", config.MaxAttempts, attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Context cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
err := re.Execute(ctx, func() error {
|
||||
return errors.New("timeout")
|
||||
})
|
||||
|
||||
if err != context.Canceled {
|
||||
t.Errorf("Expected context canceled error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Network error handling", func(t *testing.T) {
|
||||
// Test timeout error
|
||||
timeoutErr := &net.OpError{Op: "dial", Err: errors.New("timeout")}
|
||||
if !re.isRetryableError(timeoutErr) {
|
||||
t.Error("Expected timeout error to be retryable")
|
||||
}
|
||||
|
||||
// Test connection refused
|
||||
connErr := errors.New("connection refused")
|
||||
if !re.isRetryableError(connErr) {
|
||||
t.Error("Expected connection refused to be retryable")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTP error handling", func(t *testing.T) {
|
||||
// Test 500 error (retryable)
|
||||
httpErr500 := &HTTPError{StatusCode: 500, Message: "Internal Server Error"}
|
||||
if !re.isRetryableError(httpErr500) {
|
||||
t.Error("Expected 500 error to be retryable")
|
||||
}
|
||||
|
||||
// Test 429 error (retryable)
|
||||
httpErr429 := &HTTPError{StatusCode: 429, Message: "Too Many Requests"}
|
||||
if !re.isRetryableError(httpErr429) {
|
||||
t.Error("Expected 429 error to be retryable")
|
||||
}
|
||||
|
||||
// Test 400 error (not retryable)
|
||||
httpErr400 := &HTTPError{StatusCode: 400, Message: "Bad Request"}
|
||||
if re.isRetryableError(httpErr400) {
|
||||
t.Error("Expected 400 error to not be retryable")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGracefulDegradation(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.HealthCheckInterval = 50 * time.Millisecond
|
||||
config.RecoveryTimeout = 100 * time.Millisecond
|
||||
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer func() {
|
||||
// Clean up goroutine
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}()
|
||||
|
||||
t.Run("Register fallback and health check", func(t *testing.T) {
|
||||
gd.RegisterFallback("test-service", func() (any, error) {
|
||||
return "fallback-result", nil
|
||||
})
|
||||
|
||||
gd.RegisterHealthCheck("test-service", func() bool {
|
||||
return true
|
||||
})
|
||||
|
||||
// Should not be degraded initially
|
||||
if gd.isServiceDegraded("test-service") {
|
||||
t.Error("Service should not be degraded initially")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Execute with fallback on failure", func(t *testing.T) {
|
||||
gd.RegisterFallback("failing-service", func() (any, error) {
|
||||
return "fallback-result", nil
|
||||
})
|
||||
|
||||
// First call should fail and mark service as degraded
|
||||
result, err := gd.ExecuteWithFallback("failing-service", func() (any, error) {
|
||||
return nil, errors.New("service failure")
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected fallback to succeed, got error: %v", err)
|
||||
}
|
||||
if result != "fallback-result" {
|
||||
t.Errorf("Expected fallback result, got %v", result)
|
||||
}
|
||||
|
||||
// Service should now be degraded
|
||||
if !gd.isServiceDegraded("failing-service") {
|
||||
t.Error("Service should be marked as degraded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("No fallback available", func(t *testing.T) {
|
||||
_, err := gd.ExecuteWithFallback("no-fallback-service", func() (any, error) {
|
||||
return nil, errors.New("service failure")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when no fallback available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Get degraded services", func(t *testing.T) {
|
||||
degraded := gd.GetDegradedServices()
|
||||
found := slices.Contains(degraded, "failing-service")
|
||||
if !found {
|
||||
t.Error("Expected failing-service to be in degraded list")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Service recovery after timeout", func(t *testing.T) {
|
||||
// Wait for recovery timeout
|
||||
time.Sleep(config.RecoveryTimeout + 20*time.Millisecond)
|
||||
|
||||
// Service should no longer be degraded
|
||||
if gd.isServiceDegraded("failing-service") {
|
||||
t.Error("Service should have recovered after timeout")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorRecoveryManager(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
erm := NewErrorRecoveryManager(logger)
|
||||
|
||||
t.Run("Get circuit breaker", func(t *testing.T) {
|
||||
cb1 := erm.GetCircuitBreaker("service1")
|
||||
cb2 := erm.GetCircuitBreaker("service1")
|
||||
|
||||
// Should return the same instance
|
||||
if cb1 != cb2 {
|
||||
t.Error("Expected same circuit breaker instance for same service")
|
||||
}
|
||||
|
||||
cb3 := erm.GetCircuitBreaker("service2")
|
||||
if cb1 == cb3 {
|
||||
t.Error("Expected different circuit breaker instances for different services")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Execute with recovery", func(t *testing.T) {
|
||||
attempts := 0
|
||||
err := erm.ExecuteWithRecovery(context.Background(), "test-service", func() error {
|
||||
attempts++
|
||||
if attempts < 2 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected recovery to succeed, got %v", err)
|
||||
}
|
||||
if attempts < 2 {
|
||||
t.Errorf("Expected at least 2 attempts, got %d", attempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Get recovery metrics", func(t *testing.T) {
|
||||
metrics := erm.GetRecoveryMetrics()
|
||||
|
||||
if metrics["circuit_breakers"] == nil {
|
||||
t.Error("Expected circuit_breakers in metrics")
|
||||
}
|
||||
if metrics["degraded_services"] == nil {
|
||||
t.Error("Expected degraded_services in metrics")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPError(t *testing.T) {
|
||||
err := &HTTPError{StatusCode: 500, Message: "Internal Server Error"}
|
||||
expected := "HTTP 500: Internal Server Error"
|
||||
if err.Error() != expected {
|
||||
t.Errorf("Expected %q, got %q", expected, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperFunctions(t *testing.T) {
|
||||
t.Run("contains function", func(t *testing.T) {
|
||||
if !contains("hello world", "hello") {
|
||||
t.Error("Expected contains to find substring at start")
|
||||
}
|
||||
if !contains("hello world", "world") {
|
||||
t.Error("Expected contains to find substring at end")
|
||||
}
|
||||
if !contains("hello world", "lo wo") {
|
||||
t.Error("Expected contains to find substring in middle")
|
||||
}
|
||||
if contains("hello world", "xyz") {
|
||||
t.Error("Expected contains to not find non-existent substring")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("containsSubstring function", func(t *testing.T) {
|
||||
if !containsSubstring("hello world", "lo wo") {
|
||||
t.Error("Expected containsSubstring to find substring")
|
||||
}
|
||||
if containsSubstring("hello", "hello world") {
|
||||
t.Error("Expected containsSubstring to not find longer substring")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultConfigs(t *testing.T) {
|
||||
t.Run("DefaultCircuitBreakerConfig", func(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
if config.MaxFailures <= 0 {
|
||||
t.Error("Expected positive MaxFailures")
|
||||
}
|
||||
if config.Timeout <= 0 {
|
||||
t.Error("Expected positive Timeout")
|
||||
}
|
||||
if config.ResetTimeout <= 0 {
|
||||
t.Error("Expected positive ResetTimeout")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DefaultRetryConfig", func(t *testing.T) {
|
||||
config := DefaultRetryConfig()
|
||||
if config.MaxAttempts <= 0 {
|
||||
t.Error("Expected positive MaxAttempts")
|
||||
}
|
||||
if config.InitialDelay <= 0 {
|
||||
t.Error("Expected positive InitialDelay")
|
||||
}
|
||||
if config.BackoffFactor <= 1 {
|
||||
t.Error("Expected BackoffFactor > 1")
|
||||
}
|
||||
if len(config.RetryableErrors) == 0 {
|
||||
t.Error("Expected some retryable errors")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DefaultGracefulDegradationConfig", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
if config.HealthCheckInterval <= 0 {
|
||||
t.Error("Expected positive HealthCheckInterval")
|
||||
}
|
||||
if config.RecoveryTimeout <= 0 {
|
||||
t.Error("Expected positive RecoveryTimeout")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Mock network error for testing
|
||||
type mockNetError struct {
|
||||
timeout bool
|
||||
temp bool
|
||||
}
|
||||
|
||||
func (e *mockNetError) Error() string { return "mock network error" }
|
||||
func (e *mockNetError) Timeout() bool { return e.timeout }
|
||||
func (e *mockNetError) Temporary() bool { return e.temp }
|
||||
|
||||
func TestNetworkErrorHandling(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
config := DefaultRetryConfig()
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
t.Run("Timeout error is retryable", func(t *testing.T) {
|
||||
err := &mockNetError{timeout: true}
|
||||
if !re.isRetryableError(err) {
|
||||
t.Error("Expected timeout error to be retryable")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-timeout network error with retryable pattern", func(t *testing.T) {
|
||||
err := &mockNetError{timeout: false}
|
||||
// This should not be retryable since it doesn't match patterns and isn't timeout
|
||||
if re.isRetryableError(err) {
|
||||
t.Error("Expected non-timeout network error without pattern to not be retryable")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -6,7 +6,7 @@ toolchain go1.23.1
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.4.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
golang.org/x/time v0.7.0
|
||||
)
|
||||
|
||||
|
||||
@@ -6,9 +6,5 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
|
||||
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
|
||||
@@ -0,0 +1,595 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// MockJWTVerifier implements the JWTVerifier interface for testing
|
||||
type MockJWTVerifier struct {
|
||||
VerifyJWTFunc func(jwt *JWT, token string) error
|
||||
}
|
||||
|
||||
func (m *MockJWTVerifier) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
if m.VerifyJWTFunc != nil {
|
||||
return m.VerifyJWTFunc(jwt, token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
|
||||
// Create a mocked TraefikOidc instance that simulates Google provider behavior
|
||||
mockLogger := NewLogger("debug")
|
||||
|
||||
// Create a test instance with a Google-like issuer URL
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://accounts.google.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
logger: mockLogger,
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
refreshGracePeriod: 60,
|
||||
}
|
||||
|
||||
// Create a session manager
|
||||
sessionManager, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, mockLogger)
|
||||
tOidc.sessionManager = sessionManager
|
||||
|
||||
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
|
||||
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
|
||||
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Check that access_type=offline was added (not offline_access scope for Google)
|
||||
if !strings.Contains(authURL, "access_type=offline") {
|
||||
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Verify offline_access scope is NOT included for Google providers
|
||||
if strings.Contains(authURL, "offline_access") {
|
||||
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Check that prompt=consent was added
|
||||
if !strings.Contains(authURL, "prompt=consent") {
|
||||
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-Google provider doesn't add Google-specific params", func(t *testing.T) {
|
||||
// Create a test instance with a non-Google issuer URL
|
||||
nonGoogleOidc := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
logger: mockLogger,
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
}
|
||||
|
||||
// Test buildAuthURL without Google-specific parameters
|
||||
authURL := nonGoogleOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Check that prompt=consent is not automatically added
|
||||
if strings.Contains(authURL, "prompt=consent") {
|
||||
t.Errorf("prompt=consent added to non-Google auth URL: %s", authURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Session refresh with Google provider", func(t *testing.T) {
|
||||
// Create a request and response recorder
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Create a session and set a refresh token
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetAccessToken(ValidAccessToken)
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
|
||||
// Create a mock token exchanger that simulates Google's behavior
|
||||
mockTokenExchanger := &MockTokenExchanger{
|
||||
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
|
||||
// Check that the refresh token is passed correctly
|
||||
if refreshToken != "valid-refresh-token" {
|
||||
t.Errorf("Incorrect refresh token passed: %s", refreshToken)
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// Use standardized test tokens instead of ad-hoc strings
|
||||
testTokens := NewTestTokens()
|
||||
googleTokens := testTokens.GetGoogleTokenSet()
|
||||
|
||||
// Return a simulated Google token response with a new access token
|
||||
// but without a new refresh token (Google doesn't always return a new refresh token)
|
||||
return &TokenResponse{
|
||||
IDToken: googleTokens.IDToken,
|
||||
AccessToken: googleTokens.AccessToken,
|
||||
RefreshToken: "", // Google often doesn't return a new refresh token
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Set the mock token exchanger
|
||||
tOidc.tokenExchanger = mockTokenExchanger
|
||||
|
||||
// Create a struct that implements the TokenVerifier interface
|
||||
tOidc.tokenVerifier = &MockTokenVerifier{
|
||||
VerifyFunc: func(token string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
tOidc.extractClaimsFunc = func(token string) (map[string]any, error) {
|
||||
// Return mock claims
|
||||
return map[string]any{
|
||||
"email": "test@example.com",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Attempt to refresh the token
|
||||
refreshed := tOidc.refreshToken(rw, req, session)
|
||||
|
||||
// Verify the refresh was successful
|
||||
if !refreshed {
|
||||
t.Error("Token refresh failed for Google provider")
|
||||
}
|
||||
|
||||
// Check that we kept the original refresh token since Google didn't provide a new one
|
||||
if session.GetRefreshToken() != "valid-refresh-token" {
|
||||
t.Errorf("Original refresh token not preserved: got %s, expected 'valid-refresh-token'",
|
||||
session.GetRefreshToken())
|
||||
}
|
||||
|
||||
// Use the same test tokens for validation
|
||||
testTokens := NewTestTokens()
|
||||
expectedTokens := testTokens.GetGoogleTokenSet()
|
||||
|
||||
// Check that the tokens were updated correctly
|
||||
if session.GetIDToken() != expectedTokens.IDToken {
|
||||
t.Errorf("ID token not updated: got %s, expected %s",
|
||||
session.GetIDToken(), expectedTokens.IDToken)
|
||||
}
|
||||
|
||||
if session.GetAccessToken() != expectedTokens.AccessToken {
|
||||
t.Errorf("Access token not updated: got %s, expected %s",
|
||||
session.GetAccessToken(), expectedTokens.AccessToken)
|
||||
}
|
||||
})
|
||||
// Test that our fix specifically addresses the reported Google error
|
||||
t.Run("Google provider handles offline access correctly", func(t *testing.T) {
|
||||
// Build the auth URL with Google provider detection
|
||||
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Parse the URL to examine its parameters
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
params := parsedURL.Query()
|
||||
|
||||
// Verify that access_type=offline is set (Google's way of requesting refresh tokens)
|
||||
if params.Get("access_type") != "offline" {
|
||||
t.Errorf("access_type=offline not set in Google auth URL")
|
||||
}
|
||||
|
||||
// Verify that the scope parameter doesn't contain offline_access
|
||||
// (which Google reports as invalid: {invalid=[offline_access]})
|
||||
scope := params.Get("scope")
|
||||
if strings.Contains(scope, "offline_access") {
|
||||
t.Errorf("offline_access incorrectly included in scope for Google provider: %s", scope)
|
||||
}
|
||||
|
||||
// Verify that the necessary scopes are still included
|
||||
for _, requiredScope := range []string{"openid", "profile", "email"} {
|
||||
if !strings.Contains(scope, requiredScope) {
|
||||
t.Errorf("Required scope '%s' missing from auth URL", requiredScope)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Enhanced test for verifying non-Google provider includes offline_access scope
|
||||
t.Run("Non-Google provider includes offline_access scope", func(t *testing.T) {
|
||||
// Create a test instance with a non-Google issuer URL
|
||||
nonGoogleOidc := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
logger: mockLogger,
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
}
|
||||
|
||||
// Test buildAuthURL for a non-Google provider
|
||||
authURL := nonGoogleOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Parse the URL to examine its parameters
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
params := parsedURL.Query()
|
||||
|
||||
// Verify that access_type=offline is NOT set for non-Google providers
|
||||
if params.Get("access_type") == "offline" {
|
||||
t.Errorf("access_type=offline incorrectly added to non-Google auth URL")
|
||||
}
|
||||
|
||||
// Verify that offline_access scope IS included for non-Google providers
|
||||
scope := params.Get("scope")
|
||||
if !strings.Contains(scope, "offline_access") {
|
||||
t.Errorf("offline_access scope missing from non-Google auth URL scope: %s", scope)
|
||||
}
|
||||
|
||||
// Verify that the necessary scopes are still included
|
||||
for _, requiredScope := range []string{"openid", "profile", "email"} {
|
||||
if !strings.Contains(scope, requiredScope) {
|
||||
t.Errorf("Required scope '%s' missing from non-Google auth URL", requiredScope)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Additional test for complete URL construction for Google provider
|
||||
t.Run("Complete Google auth URL construction", func(t *testing.T) {
|
||||
// Build the auth URL with additional parameters
|
||||
redirectURL := "https://example.com/callback"
|
||||
state := "state123"
|
||||
nonce := "nonce123"
|
||||
codeChallenge := "code_challenge_value" // For PKCE
|
||||
|
||||
// Enable PKCE for this test
|
||||
tOidc.enablePKCE = true
|
||||
|
||||
// Build auth URL
|
||||
authURL := tOidc.buildAuthURL(redirectURL, state, nonce, codeChallenge)
|
||||
|
||||
// Parse the URL to examine its structure and parameters
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
// Verify the base URL
|
||||
expectedBaseURL := "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
if !strings.HasPrefix(authURL, expectedBaseURL) && !strings.Contains(authURL, "accounts.google.com") {
|
||||
t.Errorf("Auth URL doesn't start with expected Google OAuth endpoint: %s", authURL)
|
||||
}
|
||||
|
||||
// Check all required parameters
|
||||
params := parsedURL.Query()
|
||||
expectedParams := map[string]string{
|
||||
"client_id": "test-client-id",
|
||||
"response_type": "code",
|
||||
"redirect_uri": redirectURL,
|
||||
"state": state,
|
||||
"nonce": nonce,
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
}
|
||||
|
||||
// Also check PKCE parameters if enabled
|
||||
if tOidc.enablePKCE {
|
||||
expectedParams["code_challenge"] = codeChallenge
|
||||
expectedParams["code_challenge_method"] = "S256"
|
||||
}
|
||||
|
||||
for key, expectedValue := range expectedParams {
|
||||
if value := params.Get(key); value != expectedValue {
|
||||
t.Errorf("Parameter %s has incorrect value. Expected: %s, Got: %s",
|
||||
key, expectedValue, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify scope parameter separately due to it being space-separated values
|
||||
scope := params.Get("scope")
|
||||
if scope == "" {
|
||||
t.Error("Scope parameter missing from Google auth URL")
|
||||
}
|
||||
|
||||
// Check that all required scopes are present
|
||||
scopeList := strings.Split(scope, " ")
|
||||
expectedScopes := []string{"openid", "profile", "email"}
|
||||
for _, expectedScope := range expectedScopes {
|
||||
found := slices.Contains(scopeList, expectedScope)
|
||||
if !found {
|
||||
t.Errorf("Expected scope '%s' not found in scope parameter: %s", expectedScope, scope)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify offline_access is NOT in the scope list
|
||||
for _, actualScope := range scopeList {
|
||||
if actualScope == "offline_access" {
|
||||
t.Errorf("offline_access scope incorrectly included in Google auth URL: %s", scope)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Integration test with mocked Google provider
|
||||
t.Run("Integration test with mocked Google provider", func(t *testing.T) {
|
||||
// Generate an RSA key for signing the test JWTs
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
|
||||
// Create JWK for the RSA public key
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPrivateKey.PublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(rsaPrivateKey.PublicKey.E)))),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
// Create a mock JWK cache
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
// Create a complete test instance with all required fields
|
||||
mockLogger := NewLogger("debug")
|
||||
googleTOidc := &TraefikOidc{
|
||||
issuerURL: "https://accounts.google.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
logger: mockLogger,
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
refreshGracePeriod: 60,
|
||||
tokenCache: NewTokenCache(), // Initialize tokenCache
|
||||
tokenBlacklist: NewCache(), // Initialize tokenBlacklist
|
||||
enablePKCE: false,
|
||||
limiter: rate.NewLimiter(rate.Inf, 0), // No rate limiting for tests
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://accounts.google.com/jwks",
|
||||
}
|
||||
|
||||
// Create a session manager
|
||||
sessionManager, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, mockLogger)
|
||||
googleTOidc.sessionManager = sessionManager
|
||||
|
||||
// Create a mock token verifier
|
||||
mockTokenVerifier := &MockTokenVerifier{
|
||||
VerifyFunc: func(token string) error {
|
||||
return nil // Always verify successfully for this test
|
||||
},
|
||||
}
|
||||
googleTOidc.tokenVerifier = mockTokenVerifier
|
||||
|
||||
// Create JWT tokens for the test
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
|
||||
// Create initial ID token
|
||||
initialIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://accounts.google.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "nonce123", // For initial authentication verification
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create refresh ID token
|
||||
refreshedIDToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://accounts.google.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create refreshed test ID token: %v", err)
|
||||
}
|
||||
|
||||
// Set up token verifier with mock
|
||||
googleTOidc.tokenVerifier = &MockTokenVerifier{
|
||||
VerifyFunc: func(token string) error {
|
||||
return nil // Always verify successfully for this test
|
||||
},
|
||||
}
|
||||
|
||||
// Set up JWT verifier with mock
|
||||
googleTOidc.jwtVerifier = &MockJWTVerifier{
|
||||
VerifyJWTFunc: func(jwt *JWT, token string) error {
|
||||
return nil // Always verify successfully for this test
|
||||
},
|
||||
}
|
||||
|
||||
// Create a mock token exchanger that simulates Google's OAuth behavior
|
||||
mockTokenExchanger := &MockTokenExchanger{
|
||||
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
// Verify the correct parameters are passed
|
||||
if grantType != "authorization_code" {
|
||||
t.Errorf("Expected grant_type=authorization_code, got %s", grantType)
|
||||
}
|
||||
if codeOrToken != "test_auth_code" {
|
||||
t.Errorf("Expected code=test_auth_code, got %s", codeOrToken)
|
||||
}
|
||||
if redirectURL != "https://example.com/callback" {
|
||||
t.Errorf("Expected redirect_uri=https://example.com/callback, got %s", redirectURL)
|
||||
}
|
||||
|
||||
// Return a successful token response with a proper JWT
|
||||
return &TokenResponse{
|
||||
IDToken: initialIDToken,
|
||||
AccessToken: initialIDToken, // Use a valid JWT as the access token too
|
||||
RefreshToken: "google_refresh_token",
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
|
||||
// Verify the correct refresh token is passed
|
||||
if refreshToken != "google_refresh_token" {
|
||||
t.Errorf("Expected refresh_token=google_refresh_token, got %s", refreshToken)
|
||||
}
|
||||
|
||||
// Return a successful refresh response with a proper JWT
|
||||
return &TokenResponse{
|
||||
IDToken: refreshedIDToken,
|
||||
AccessToken: refreshedIDToken, // Use a valid JWT as the access token
|
||||
RefreshToken: "", // Google doesn't always return a new refresh token
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
googleTOidc.tokenExchanger = mockTokenExchanger
|
||||
|
||||
// Use the real extractClaimsFunc to parse the proper JWT tokens
|
||||
googleTOidc.extractClaimsFunc = extractClaims
|
||||
|
||||
// 1. Test building the authorization URL
|
||||
authURL := googleTOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Verify Google-specific parameters
|
||||
if !strings.Contains(authURL, "access_type=offline") {
|
||||
t.Errorf("Google auth URL missing access_type=offline: %s", authURL)
|
||||
}
|
||||
if !strings.Contains(authURL, "prompt=consent") {
|
||||
t.Errorf("Google auth URL missing prompt=consent: %s", authURL)
|
||||
}
|
||||
if strings.Contains(authURL, "offline_access") {
|
||||
t.Errorf("Google auth URL incorrectly includes offline_access scope: %s", authURL)
|
||||
}
|
||||
|
||||
// 2. Test handling the callback and token exchange
|
||||
// Create a request and response recorder for the callback
|
||||
req := httptest.NewRequest("GET", "/callback?code=test_auth_code&state=state123", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Create a session and set the necessary values
|
||||
session, _ := googleTOidc.sessionManager.GetSession(req)
|
||||
session.SetCSRF("state123") // Must match the state parameter
|
||||
session.SetNonce("nonce123")
|
||||
|
||||
// Save the session to the request
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies from the response and add them to a new request
|
||||
cookies := rw.Result().Cookies()
|
||||
callbackReq := httptest.NewRequest("GET", "/callback?code=test_auth_code&state=state123", nil)
|
||||
for _, cookie := range cookies {
|
||||
callbackReq.AddCookie(cookie)
|
||||
}
|
||||
callbackRw := httptest.NewRecorder()
|
||||
|
||||
// Handle the callback
|
||||
googleTOidc.handleCallback(callbackRw, callbackReq, "https://example.com/callback")
|
||||
|
||||
// Verify the response is a redirect (302 Found)
|
||||
if callbackRw.Code != 302 {
|
||||
t.Errorf("Expected 302 redirect, got %d", callbackRw.Code)
|
||||
}
|
||||
|
||||
// Create a new request to get the updated session
|
||||
newReq := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range callbackRw.Result().Cookies() {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get the updated session
|
||||
newSession, err := googleTOidc.sessionManager.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session after callback: %v", err)
|
||||
}
|
||||
|
||||
// Verify the session contains the expected values
|
||||
if !newSession.GetAuthenticated() {
|
||||
t.Error("Session not marked as authenticated after callback")
|
||||
}
|
||||
if newSession.GetEmail() != "user@example.com" {
|
||||
t.Errorf("Session email incorrect: got %s, expected user@example.com",
|
||||
newSession.GetEmail())
|
||||
}
|
||||
|
||||
// Check for non-empty access token that can be parsed as JWT
|
||||
accessToken := newSession.GetAccessToken()
|
||||
if accessToken == "" {
|
||||
t.Error("Session access token is empty")
|
||||
} else {
|
||||
claims, err := extractClaims(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to parse access token as JWT: %v", err)
|
||||
} else if email, ok := claims["email"].(string); !ok || email != "user@example.com" {
|
||||
t.Errorf("Access token JWT doesn't contain expected email claim")
|
||||
}
|
||||
}
|
||||
|
||||
// Check refresh token
|
||||
if newSession.GetRefreshToken() != "google_refresh_token" {
|
||||
t.Errorf("Session refresh token incorrect: got %s, expected google_refresh_token",
|
||||
newSession.GetRefreshToken())
|
||||
}
|
||||
|
||||
// 3. Test token refresh
|
||||
refreshReq := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range callbackRw.Result().Cookies() {
|
||||
refreshReq.AddCookie(cookie)
|
||||
}
|
||||
refreshRw := httptest.NewRecorder()
|
||||
|
||||
// Get the session for refresh
|
||||
refreshSession, _ := googleTOidc.sessionManager.GetSession(refreshReq)
|
||||
|
||||
// Refresh the token
|
||||
refreshed := googleTOidc.refreshToken(refreshRw, refreshReq, refreshSession)
|
||||
|
||||
// Verify refresh was successful
|
||||
if !refreshed {
|
||||
t.Error("Token refresh failed")
|
||||
}
|
||||
|
||||
// Verify the session data after refresh
|
||||
// Check for non-empty refreshed access token that can be parsed as JWT
|
||||
refreshedAccessToken := refreshSession.GetAccessToken()
|
||||
if refreshedAccessToken == "" {
|
||||
t.Error("Session access token is empty after refresh")
|
||||
} else {
|
||||
claims, err := extractClaims(refreshedAccessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to parse refreshed access token as JWT: %v", err)
|
||||
} else if email, ok := claims["email"].(string); !ok || email != "user@example.com" {
|
||||
t.Errorf("Refreshed access token JWT doesn't contain expected email claim")
|
||||
}
|
||||
}
|
||||
|
||||
// Since Google didn't return a new refresh token, the original should be preserved
|
||||
if refreshSession.GetRefreshToken() != "google_refresh_token" {
|
||||
t.Errorf("Original refresh token not preserved: got %s, expected google_refresh_token",
|
||||
refreshSession.GetRefreshToken())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// No need to redefine MockTokenExchanger - it's already defined in main_test.go
|
||||
+271
-308
@@ -3,21 +3,26 @@ package traefikoidc
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// generateNonce generates a random nonce
|
||||
// generateNonce creates a cryptographically secure random string suitable for use as an OIDC nonce.
|
||||
// The nonce is used during the authentication flow to mitigate replay attacks by associating
|
||||
// the ID token with the specific authentication request.
|
||||
// It generates 32 random bytes and encodes them using base64 URL encoding.
|
||||
//
|
||||
// Returns:
|
||||
// - A base64 URL encoded random string (nonce).
|
||||
// - An error if the random byte generation fails.
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
@@ -27,16 +32,70 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// buildFullURL constructs a full URL from scheme, host, and path
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
// generateCodeVerifier creates a cryptographically secure random string suitable for use as a PKCE code verifier.
|
||||
// According to RFC 7636, the verifier should be a high-entropy string between 43 and 128 characters long.
|
||||
// This function generates 32 random bytes, resulting in a 43-character base64 URL encoded string.
|
||||
//
|
||||
// Returns:
|
||||
// - A base64 URL encoded random string (code verifier).
|
||||
// - An error if the random byte generation fails.
|
||||
func generateCodeVerifier() (string, error) {
|
||||
// Using 32 bytes (256 bits) will produce a 43 character base64url string
|
||||
verifierBytes := make([]byte, 32)
|
||||
_, err := rand.Read(verifierBytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not generate code verifier: %w", err)
|
||||
}
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
|
||||
}
|
||||
|
||||
// exchangeTokens exchanges a code or refresh token for tokens
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
||||
// deriveCodeChallenge computes the PKCE code challenge from a given code verifier.
|
||||
// It uses the S256 challenge method (SHA-256 hash followed by base64 URL encoding)
|
||||
// as defined in RFC 7636.
|
||||
//
|
||||
// Parameters:
|
||||
// - codeVerifier: The high-entropy string generated by generateCodeVerifier.
|
||||
//
|
||||
// Returns:
|
||||
// - The base64 URL encoded SHA-256 hash of the code verifier (code challenge).
|
||||
func deriveCodeChallenge(codeVerifier string) string {
|
||||
// Calculate SHA-256 hash of the code verifier
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(codeVerifier))
|
||||
hash := hasher.Sum(nil)
|
||||
|
||||
// Base64url encode the hash to get the code challenge
|
||||
return base64.RawURLEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from the OIDC token endpoint.
|
||||
// It contains the various tokens and metadata returned after successful
|
||||
// code exchange or token refresh operations.
|
||||
type TokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint.
|
||||
// It handles both the "authorization_code" grant type (exchanging an authorization code for tokens)
|
||||
// and the "refresh_token" grant type (using a refresh token to obtain new tokens).
|
||||
// It includes necessary parameters like client credentials and handles PKCE verification if applicable.
|
||||
// The function follows redirects and handles potential errors during the exchange.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the outgoing HTTP request.
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token").
|
||||
// - codeOrToken: The authorization code (for "authorization_code" grant) or the refresh token (for "refresh_token" grant).
|
||||
// - redirectURL: The redirect URI that was used in the initial authorization request (required for "authorization_code" grant).
|
||||
// - codeVerifier: The PKCE code verifier (required for "authorization_code" grant if PKCE was used).
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the obtained tokens (ID, access, refresh).
|
||||
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant).
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
"client_id": {t.clientID},
|
||||
@@ -46,17 +105,38 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
if grantType == "authorization_code" {
|
||||
data.Set("code", codeOrToken)
|
||||
data.Set("redirect_uri", redirectURL)
|
||||
|
||||
// Add code_verifier if PKCE is being used
|
||||
if codeVerifier != "" {
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
} else if grantType == "refresh_token" {
|
||||
data.Set("refresh_token", codeOrToken)
|
||||
}
|
||||
|
||||
client := t.tokenHTTPClient
|
||||
if client == nil {
|
||||
jar, _ := cookiejar.New(nil)
|
||||
client = &http.Client{
|
||||
Transport: t.httpClient.Transport,
|
||||
Timeout: t.httpClient.Timeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Jar: jar,
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
|
||||
}
|
||||
@@ -75,256 +155,39 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from the token endpoint
|
||||
type TokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// getNewTokenWithRefreshToken refreshes the token using the refresh token
|
||||
// getNewTokenWithRefreshToken uses a refresh token to obtain a new set of tokens (ID, access, refresh)
|
||||
// from the OIDC provider's token endpoint. It wraps the exchangeTokens function with the
|
||||
// "refresh_token" grant type.
|
||||
//
|
||||
// Parameters:
|
||||
// - refreshToken: The refresh token previously obtained during authentication or a prior refresh.
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the newly obtained tokens.
|
||||
// - An error if the refresh operation fails.
|
||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
|
||||
t.logger.Debugf("Token response: %+v", tokenResponse)
|
||||
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// handleLogout handles the logout process
|
||||
func (t *TraefikOidc) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := t.store.Get(r, cookieName)
|
||||
if err != nil {
|
||||
handleError(w, fmt.Sprintf("Error getting session: %v", err), http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Get tokens from session
|
||||
idToken, _ := session.Values["id_token"].(string)
|
||||
refreshToken, _ := session.Values["refresh_token"].(string)
|
||||
accessToken, _ := session.Values["access_token"].(string)
|
||||
|
||||
// Revoke tokens if they exist
|
||||
if refreshToken != "" {
|
||||
t.RevokeTokenWithProvider(refreshToken, "refresh_token")
|
||||
t.RevokeToken(refreshToken)
|
||||
}
|
||||
if accessToken != "" {
|
||||
t.RevokeTokenWithProvider(accessToken, "access_token")
|
||||
t.RevokeToken(accessToken)
|
||||
}
|
||||
|
||||
// Clear session
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
if err := session.Save(r, w); err != nil {
|
||||
handleError(w, fmt.Sprintf("Error saving session: %v", err), http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Determine redirect URL
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = r.Host
|
||||
}
|
||||
scheme := "http"
|
||||
if r.Header.Get("X-Forwarded-Proto") == "https" || t.forceHTTPS {
|
||||
scheme = "https"
|
||||
}
|
||||
baseURL := fmt.Sprintf("%s://%s/", scheme, host)
|
||||
|
||||
if t.endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, baseURL)
|
||||
if err != nil {
|
||||
handleError(w, fmt.Sprintf("Invalid end session URL: %v", err), http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, logoutURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, baseURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the logout URL with proper encoding
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid end session URL: %v", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// handleExpiredToken handles the case when a token has expired
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
|
||||
// Clear the existing session
|
||||
session.Options.MaxAge = -1
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
|
||||
// Set new values
|
||||
session.Values["csrf"] = uuid.New().String()
|
||||
session.Values["incoming_path"] = req.URL.Path
|
||||
session.Values["nonce"], _ = generateNonce()
|
||||
session.Options = defaultSessionOptions
|
||||
|
||||
// Save the session before initiating authentication
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Initiate a new authentication flow
|
||||
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback handles the callback from the OIDC provider
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Check for errors in the query parameters
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the state parameter matches the session's CSRF token
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
csrfToken, ok := session.Values["csrf"].(string)
|
||||
if !ok || csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Proceed to exchange the code for tokens
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
http.Error(rw, "No code in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract id_token
|
||||
idToken := tokenResponse.IDToken
|
||||
if idToken == "" {
|
||||
t.logger.Error("No id_token in token response")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the id_token
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract claims from id_token
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the nonce claim matches the one stored in session
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sessionNonce, ok := session.Values["nonce"].(string)
|
||||
if !ok || sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the email from claims
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Store tokens and authentication status in session
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["email"] = email
|
||||
session.Values["id_token"] = idToken
|
||||
session.Values["refresh_token"] = tokenResponse.RefreshToken
|
||||
session.Options = defaultSessionOptions
|
||||
|
||||
// Remove CSRF and nonce from session
|
||||
delete(session.Values, "csrf")
|
||||
delete(session.Values, "nonce")
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Debugf("Authentication successful. User email: %s", email)
|
||||
|
||||
// Redirect to the original requested path or default to root
|
||||
redirectPath := "/"
|
||||
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
|
||||
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
|
||||
redirectPath = path
|
||||
}
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// extractClaims extracts claims from a JWT token
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
// extractClaims decodes the payload (claims set) part of a JWT string.
|
||||
// It splits the JWT into its three parts, base64 URL decodes the second part (payload),
|
||||
// and unmarshals the resulting JSON into a map.
|
||||
// Note: This function does *not* validate the token's signature or claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT string.
|
||||
//
|
||||
// Returns:
|
||||
// - A map representing the JSON claims extracted from the token payload.
|
||||
// - An error if the token format is invalid, decoding fails, or JSON unmarshaling fails.
|
||||
func extractClaims(tokenString string) (map[string]any, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
@@ -335,7 +198,7 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
return nil, fmt.Errorf("failed to decode token payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
|
||||
}
|
||||
@@ -343,97 +206,117 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// TokenBlacklist maintains a blacklist of tokens
|
||||
type TokenBlacklist struct {
|
||||
blacklist map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new TokenBlacklist
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
blacklist: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist
|
||||
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
tb.blacklist[tokenID] = expiration
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is blacklisted
|
||||
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
expiration, exists := tb.blacklist[tokenID]
|
||||
return exists && time.Now().Before(expiration)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist
|
||||
func (tb *TokenBlacklist) Cleanup() {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for tokenID, expiration := range tb.blacklist {
|
||||
if now.After(expiration) {
|
||||
delete(tb.blacklist, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCache caches tokens
|
||||
// TokenCache provides a caching mechanism for validated tokens.
|
||||
// It stores token claims to avoid repeated validation of the
|
||||
// same token, improving performance for frequently used tokens.
|
||||
type TokenCache struct {
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewTokenCache creates a new TokenCache
|
||||
const (
|
||||
defaultTokenCacheMaxSize = 1000
|
||||
defaultTokenCacheCleanupInterval = 2 * time.Minute
|
||||
)
|
||||
|
||||
// NewTokenCache creates and initializes a new TokenCache.
|
||||
func NewTokenCache() *TokenCache {
|
||||
cache := NewCache()
|
||||
cache.SetMaxSize(defaultTokenCacheMaxSize)
|
||||
|
||||
return &TokenCache{
|
||||
cache: NewCache(),
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a token in the cache
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
// Set stores the claims associated with a specific token string in the cache.
|
||||
// It prefixes the token string to avoid potential collisions with other cache types
|
||||
// and sets the provided expiration duration.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string (used as the key).
|
||||
// - claims: The map of claims associated with the token.
|
||||
// - expiration: The duration for which the cache entry should be valid.
|
||||
func (tc *TokenCache) Set(token string, claims map[string]any, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
}
|
||||
|
||||
// Get retrieves a token from the cache
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
// Get retrieves the cached claims for a given token string.
|
||||
// It prefixes the token string before querying the underlying cache.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to look up.
|
||||
//
|
||||
// Returns:
|
||||
// - The cached claims map if found and valid.
|
||||
// - A boolean indicating whether the token was found in the cache (true if found, false otherwise).
|
||||
func (tc *TokenCache) Get(token string) (map[string]any, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
claims, ok := value.(map[string]interface{})
|
||||
claims, ok := value.(map[string]any)
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// Delete removes a token from the cache
|
||||
// Delete removes the cached entry for a specific token string.
|
||||
// It prefixes the token string before calling the underlying cache's Delete method.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to remove from the cache.
|
||||
func (tc *TokenCache) Delete(token string) {
|
||||
token = "t-" + token
|
||||
tc.cache.Delete(token)
|
||||
}
|
||||
|
||||
// Cleanup cleans up expired tokens from the cache
|
||||
// Cleanup triggers the cleanup process for the underlying generic cache,
|
||||
// removing expired token entries.
|
||||
func (tc *TokenCache) Cleanup() {
|
||||
tc.cache.Cleanup()
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
|
||||
// Close stops the cleanup goroutine in the underlying cache.
|
||||
func (tc *TokenCache) Close() {
|
||||
tc.cache.Close()
|
||||
}
|
||||
|
||||
// exchangeCodeForToken is a convenience function that wraps exchangeTokens specifically
|
||||
// for the "authorization_code" grant type. It handles the conditional inclusion of the
|
||||
// PKCE code verifier based on the middleware's configuration (t.enablePKCE).
|
||||
//
|
||||
// Parameters:
|
||||
// - code: The authorization code received from the OIDC provider.
|
||||
// - redirectURL: The redirect URI used in the initial authorization request.
|
||||
// - codeVerifier: The PKCE code verifier stored in the session (if PKCE is enabled).
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the obtained tokens.
|
||||
// - An error if the code exchange fails.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
|
||||
|
||||
effectiveCodeVerifier := ""
|
||||
if t.enablePKCE && codeVerifier != "" {
|
||||
effectiveCodeVerifier = codeVerifier
|
||||
}
|
||||
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// createStringMap creates a map from a slice of strings
|
||||
// createStringMap converts a slice of strings into a map[string]struct{} (a set).
|
||||
// This is useful for creating efficient lookups (O(1) average time complexity)
|
||||
// for checking the presence of items like allowed domains, roles, or groups.
|
||||
//
|
||||
// Parameters:
|
||||
// - keys: A slice of strings to be added to the set.
|
||||
//
|
||||
// Returns:
|
||||
// - A map where the keys are the strings from the input slice and the values are empty structs.
|
||||
func createStringMap(keys []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
@@ -441,3 +324,83 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout processes requests to the configured logout path.
|
||||
// It performs the following steps:
|
||||
// 1. Retrieves the current user session.
|
||||
// 2. Gets the access token (ID token hint) from the session.
|
||||
// 3. Clears all authentication-related data from the session cookies.
|
||||
// 4. Determines the final post-logout redirect URI.
|
||||
// 5. If an OIDC end_session_endpoint is configured and an ID token hint is available,
|
||||
// it builds the OIDC logout URL and redirects the user agent to the provider for logout.
|
||||
// 6. Otherwise, it redirects the user agent directly to the post-logout redirect URI.
|
||||
//
|
||||
// It handles potential errors during session retrieval or clearing.
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
idToken := session.GetIDToken()
|
||||
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
if t.endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to build logout URL: %v", err)
|
||||
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(rw, req, logoutURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the URL for redirecting the user agent to the OIDC provider's
|
||||
// end_session_endpoint, including the required id_token_hint and optional
|
||||
// post_logout_redirect_uri parameters as query arguments.
|
||||
//
|
||||
// Parameters:
|
||||
// - endSessionURL: The URL of the OIDC provider's end session endpoint.
|
||||
// - idToken: The ID token previously issued to the user (used as id_token_hint).
|
||||
// - postLogoutRedirectURI: The optional URI where the provider should redirect the user agent after logout.
|
||||
//
|
||||
// Returns:
|
||||
// - The fully constructed logout URL string.
|
||||
// - An error if the provided endSessionURL is invalid.
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse end session URL: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
if postLogoutRedirectURI != "" {
|
||||
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// generateRandomString generates a random string of the specified length
|
||||
// This is used in tests to create unique identifiers
|
||||
func generateRandomString(length int) string {
|
||||
bytes := make([]byte, length/2)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
// In tests, fallback to a predictable string if random fails
|
||||
return "random-string-fallback"
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
@@ -0,0 +1,651 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// InputValidator provides comprehensive input validation and sanitization
|
||||
type InputValidator struct {
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
}
|
||||
|
||||
// ValidationResult represents the result of input validation
|
||||
type ValidationResult struct {
|
||||
SanitizedValue string `json:"sanitized_value,omitempty"`
|
||||
SecurityRisk string `json:"security_risk,omitempty"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
IsValid bool `json:"is_valid"`
|
||||
}
|
||||
|
||||
// InputValidationConfig holds configuration for input validation
|
||||
type InputValidationConfig struct {
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
}
|
||||
|
||||
// DefaultInputValidationConfig returns default validation configuration
|
||||
func DefaultInputValidationConfig() InputValidationConfig {
|
||||
return InputValidationConfig{
|
||||
MaxTokenLength: 50000, // 50KB for tokens
|
||||
MaxURLLength: 2048, // Standard URL length limit
|
||||
MaxHeaderLength: 8192, // 8KB for headers
|
||||
MaxClaimLength: 1024, // 1KB for individual claims
|
||||
MaxEmailLength: 254, // RFC 5321 limit
|
||||
MaxUsernameLength: 64, // Reasonable username limit
|
||||
StrictMode: true, // Enable strict validation by default
|
||||
}
|
||||
}
|
||||
|
||||
// NewInputValidator creates a new input validator with the given configuration
|
||||
func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputValidator, error) {
|
||||
// Compile regex patterns
|
||||
emailRegex, err := regexp.Compile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile email regex: %w", err)
|
||||
}
|
||||
|
||||
urlRegex, err := regexp.Compile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile URL regex: %w", err)
|
||||
}
|
||||
|
||||
tokenRegex, err := regexp.Compile(`^[A-Za-z0-9._-]+$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile token regex: %w", err)
|
||||
}
|
||||
|
||||
usernameRegex, err := regexp.Compile(`^[a-zA-Z0-9._-]+$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile username regex: %w", err)
|
||||
}
|
||||
|
||||
return &InputValidator{
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
sqlInjectionPatterns: []string{
|
||||
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
"create", "alter", "exec", "execute", "script",
|
||||
},
|
||||
xssPatterns: []string{
|
||||
"<script", "</script>", "javascript:", "vbscript:",
|
||||
"onload=", "onerror=", "onclick=", "onmouseover=",
|
||||
"<iframe", "<object", "<embed", "<link", "<meta",
|
||||
},
|
||||
pathTraversalPatterns: []string{
|
||||
"../", "..\\", "%2e%2e%2f", "%2e%2e%5c",
|
||||
"..%2f", "..%5c", "%252e%252e%252f",
|
||||
},
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates JWT tokens and similar token strings
|
||||
func (iv *InputValidator) ValidateToken(token string) ValidationResult {
|
||||
result := ValidationResult{IsValid: true, Errors: []string{}, Warnings: []string{}}
|
||||
|
||||
// Check for empty token
|
||||
if token == "" {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "token cannot be empty")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check length limits
|
||||
if len(token) > iv.maxTokenLength {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("token length %d exceeds maximum %d", len(token), iv.maxTokenLength))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for minimum reasonable length
|
||||
if len(token) < 10 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "token is too short to be valid")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for valid JWT structure (3 parts separated by dots)
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "token does not have valid JWT structure (expected 3 parts)")
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate each part is base64url encoded
|
||||
for i, part := range parts {
|
||||
if !iv.isValidBase64URL(part) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("token part %d is not valid base64url", i+1))
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
if risk := iv.detectSecurityRisk(token); risk != "" {
|
||||
result.SecurityRisk = risk
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
}
|
||||
|
||||
// Check for null bytes and control characters
|
||||
if iv.containsNullBytes(token) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "token contains null bytes")
|
||||
return result
|
||||
}
|
||||
|
||||
if iv.containsControlCharacters(token) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "token contains control characters")
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate UTF-8 encoding
|
||||
if !utf8.ValidString(token) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "token contains invalid UTF-8 sequences")
|
||||
return result
|
||||
}
|
||||
|
||||
result.SanitizedValue = token
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateEmail validates email addresses
|
||||
func (iv *InputValidator) ValidateEmail(email string) ValidationResult {
|
||||
result := ValidationResult{IsValid: true, Errors: []string{}, Warnings: []string{}}
|
||||
|
||||
// Check for empty email
|
||||
if email == "" {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "email cannot be empty")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check length limits
|
||||
if len(email) > iv.maxEmailLength {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("email length %d exceeds maximum %d", len(email), iv.maxEmailLength))
|
||||
return result
|
||||
}
|
||||
|
||||
// Sanitize email (trim whitespace, convert to lowercase)
|
||||
sanitized := strings.TrimSpace(strings.ToLower(email))
|
||||
|
||||
// Check regex pattern
|
||||
if !iv.emailRegex.MatchString(sanitized) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "email format is invalid")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
if risk := iv.detectSecurityRisk(sanitized); risk != "" {
|
||||
result.SecurityRisk = risk
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
}
|
||||
|
||||
// Additional email-specific validations
|
||||
parts := strings.Split(sanitized, "@")
|
||||
if len(parts) != 2 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "email must contain exactly one @ symbol")
|
||||
return result
|
||||
}
|
||||
|
||||
localPart, domain := parts[0], parts[1]
|
||||
|
||||
// Validate local part
|
||||
if len(localPart) == 0 || len(localPart) > 64 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "email local part length is invalid")
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate domain
|
||||
if len(domain) == 0 || len(domain) > 253 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "email domain length is invalid")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for consecutive dots
|
||||
if strings.Contains(sanitized, "..") {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "email contains consecutive dots")
|
||||
return result
|
||||
}
|
||||
|
||||
result.SanitizedValue = sanitized
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateURL validates URLs
|
||||
func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
|
||||
result := ValidationResult{IsValid: true, Errors: []string{}, Warnings: []string{}}
|
||||
|
||||
// Check for empty URL
|
||||
if urlStr == "" {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "URL cannot be empty")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check length limits
|
||||
if len(urlStr) > iv.maxURLLength {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("URL length %d exceeds maximum %d", len(urlStr), iv.maxURLLength))
|
||||
return result
|
||||
}
|
||||
|
||||
// Sanitize URL (trim whitespace)
|
||||
sanitized := strings.TrimSpace(urlStr)
|
||||
|
||||
// Parse URL
|
||||
parsedURL, err := url.Parse(sanitized)
|
||||
if err != nil {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("URL parsing failed: %v", err))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check scheme
|
||||
if parsedURL.Scheme != "https" && parsedURL.Scheme != "http" {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "URL scheme must be http or https")
|
||||
return result
|
||||
}
|
||||
|
||||
// Prefer HTTPS
|
||||
if parsedURL.Scheme == "http" {
|
||||
result.Warnings = append(result.Warnings, "HTTP URLs are less secure than HTTPS")
|
||||
}
|
||||
|
||||
// Check host
|
||||
if parsedURL.Host == "" {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "URL must have a valid host")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
if risk := iv.detectSecurityRisk(sanitized); risk != "" {
|
||||
result.SecurityRisk = risk
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
}
|
||||
|
||||
// Check for path traversal attempts
|
||||
if iv.containsPathTraversal(sanitized) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "URL contains path traversal patterns")
|
||||
return result
|
||||
}
|
||||
|
||||
result.SanitizedValue = sanitized
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateUsername validates usernames
|
||||
func (iv *InputValidator) ValidateUsername(username string) ValidationResult {
|
||||
result := ValidationResult{IsValid: true, Errors: []string{}, Warnings: []string{}}
|
||||
|
||||
// Check for empty username
|
||||
if username == "" {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "username cannot be empty")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check length limits
|
||||
if len(username) > iv.maxUsernameLength {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("username length %d exceeds maximum %d", len(username), iv.maxUsernameLength))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check minimum length
|
||||
if len(username) < 2 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "username must be at least 2 characters long")
|
||||
return result
|
||||
}
|
||||
|
||||
// Sanitize username (trim whitespace)
|
||||
sanitized := strings.TrimSpace(username)
|
||||
|
||||
// Check regex pattern
|
||||
if !iv.usernameRegex.MatchString(sanitized) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "username contains invalid characters (only letters, numbers, dots, underscores, and hyphens allowed)")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
if risk := iv.detectSecurityRisk(sanitized); risk != "" {
|
||||
result.SecurityRisk = risk
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
}
|
||||
|
||||
result.SanitizedValue = sanitized
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateClaim validates individual JWT claims
|
||||
func (iv *InputValidator) ValidateClaim(claimName, claimValue string) ValidationResult {
|
||||
result := ValidationResult{IsValid: true, Errors: []string{}, Warnings: []string{}}
|
||||
|
||||
// Check claim name
|
||||
if claimName == "" {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "claim name cannot be empty")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check claim value length
|
||||
if len(claimValue) > iv.maxClaimLength {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("claim value length %d exceeds maximum %d", len(claimValue), iv.maxClaimLength))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for null bytes and control characters
|
||||
if iv.containsNullBytes(claimValue) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "claim value contains null bytes")
|
||||
return result
|
||||
}
|
||||
|
||||
if iv.containsControlCharacters(claimValue) {
|
||||
result.Warnings = append(result.Warnings, "claim value contains control characters")
|
||||
}
|
||||
|
||||
// Validate UTF-8 encoding
|
||||
if !utf8.ValidString(claimValue) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "claim value contains invalid UTF-8 sequences")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
if risk := iv.detectSecurityRisk(claimValue); risk != "" {
|
||||
result.SecurityRisk = risk
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
}
|
||||
|
||||
// Specific validations based on claim name
|
||||
switch claimName {
|
||||
case "email":
|
||||
emailResult := iv.ValidateEmail(claimValue)
|
||||
if !emailResult.IsValid {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, emailResult.Errors...)
|
||||
}
|
||||
result.Warnings = append(result.Warnings, emailResult.Warnings...)
|
||||
result.SanitizedValue = emailResult.SanitizedValue
|
||||
|
||||
case "iss", "aud":
|
||||
urlResult := iv.ValidateURL(claimValue)
|
||||
if !urlResult.IsValid {
|
||||
// For issuer/audience, we're more lenient - just warn
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("%s claim is not a valid URL: %v", claimName, urlResult.Errors))
|
||||
}
|
||||
result.SanitizedValue = claimValue
|
||||
|
||||
case "preferred_username", "username":
|
||||
usernameResult := iv.ValidateUsername(claimValue)
|
||||
if !usernameResult.IsValid {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, usernameResult.Errors...)
|
||||
}
|
||||
result.Warnings = append(result.Warnings, usernameResult.Warnings...)
|
||||
result.SanitizedValue = usernameResult.SanitizedValue
|
||||
|
||||
default:
|
||||
// Generic string validation
|
||||
result.SanitizedValue = strings.TrimSpace(claimValue)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateHeader validates HTTP header values
|
||||
func (iv *InputValidator) ValidateHeader(headerName, headerValue string) ValidationResult {
|
||||
result := ValidationResult{IsValid: true, Errors: []string{}, Warnings: []string{}}
|
||||
|
||||
// Check header name
|
||||
if headerName == "" {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "header name cannot be empty")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for control characters in header name (including CRLF)
|
||||
if iv.containsControlCharacters(headerName) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "header name contains control characters")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for CRLF injection in header name
|
||||
if strings.Contains(headerName, "\r") || strings.Contains(headerName, "\n") {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "header name contains CRLF characters (potential header injection)")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check header value length
|
||||
if len(headerValue) > iv.maxHeaderLength {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("header value length %d exceeds maximum %d", len(headerValue), iv.maxHeaderLength))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for null bytes and control characters (except allowed ones)
|
||||
if iv.containsNullBytes(headerValue) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "header value contains null bytes")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for CRLF injection
|
||||
if strings.Contains(headerValue, "\r") || strings.Contains(headerValue, "\n") {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "header value contains CRLF characters (potential header injection)")
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate UTF-8 encoding
|
||||
if !utf8.ValidString(headerValue) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "header value contains invalid UTF-8 sequences")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
if risk := iv.detectSecurityRisk(headerValue); risk != "" {
|
||||
result.SecurityRisk = risk
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
}
|
||||
|
||||
result.SanitizedValue = strings.TrimSpace(headerValue)
|
||||
return result
|
||||
}
|
||||
|
||||
// isValidBase64URL checks if a string is valid base64url encoding
|
||||
func (iv *InputValidator) isValidBase64URL(s string) bool {
|
||||
// Base64url uses A-Z, a-z, 0-9, -, _ and no padding
|
||||
for _, r := range s {
|
||||
if !((r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') ||
|
||||
(r >= '0' && r <= '9') || r == '-' || r == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// containsNullBytes checks if a string contains null bytes
|
||||
func (iv *InputValidator) containsNullBytes(s string) bool {
|
||||
return strings.Contains(s, "\x00")
|
||||
}
|
||||
|
||||
// containsControlCharacters checks if a string contains control characters
|
||||
func (iv *InputValidator) containsControlCharacters(s string) bool {
|
||||
for _, r := range s {
|
||||
if unicode.IsControl(r) && r != '\t' && r != '\n' && r != '\r' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// containsPathTraversal checks for path traversal patterns
|
||||
func (iv *InputValidator) containsPathTraversal(s string) bool {
|
||||
lowerS := strings.ToLower(s)
|
||||
for _, pattern := range iv.pathTraversalPatterns {
|
||||
if strings.Contains(lowerS, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detectSecurityRisk detects potential security risks in input
|
||||
func (iv *InputValidator) detectSecurityRisk(input string) string {
|
||||
lowerInput := strings.ToLower(input)
|
||||
|
||||
// Check for SQL injection patterns
|
||||
for _, pattern := range iv.sqlInjectionPatterns {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return "sql_injection"
|
||||
}
|
||||
}
|
||||
|
||||
// Check for XSS patterns
|
||||
for _, pattern := range iv.xssPatterns {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return "xss"
|
||||
}
|
||||
}
|
||||
|
||||
// Check for path traversal
|
||||
if iv.containsPathTraversal(input) {
|
||||
return "path_traversal"
|
||||
}
|
||||
|
||||
// Check for excessive length (potential DoS)
|
||||
if len(input) > 10000 {
|
||||
return "excessive_length"
|
||||
}
|
||||
|
||||
// Check for suspicious character patterns
|
||||
if iv.containsNullBytes(input) {
|
||||
return "null_bytes"
|
||||
}
|
||||
|
||||
// Check for binary data patterns
|
||||
nonPrintableCount := 0
|
||||
for _, r := range input {
|
||||
if !unicode.IsPrint(r) && !unicode.IsSpace(r) {
|
||||
nonPrintableCount++
|
||||
}
|
||||
}
|
||||
if nonPrintableCount > len(input)/10 { // More than 10% non-printable
|
||||
return "binary_data"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// SanitizeInput provides general input sanitization
|
||||
func (iv *InputValidator) SanitizeInput(input string, maxLength int) string {
|
||||
// Trim whitespace
|
||||
sanitized := strings.TrimSpace(input)
|
||||
|
||||
// Truncate if too long
|
||||
if len(sanitized) > maxLength {
|
||||
sanitized = sanitized[:maxLength]
|
||||
}
|
||||
|
||||
// Remove null bytes
|
||||
sanitized = strings.ReplaceAll(sanitized, "\x00", "")
|
||||
|
||||
// Remove other control characters except tab, newline, carriage return
|
||||
var result strings.Builder
|
||||
for _, r := range sanitized {
|
||||
if !unicode.IsControl(r) || r == '\t' || r == '\n' || r == '\r' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// ValidateBoundaryValues validates numeric boundary values
|
||||
func (iv *InputValidator) ValidateBoundaryValues(value any, min, max int64) ValidationResult {
|
||||
result := ValidationResult{IsValid: true, Errors: []string{}, Warnings: []string{}}
|
||||
|
||||
var numValue int64
|
||||
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
numValue = int64(v)
|
||||
case int32:
|
||||
numValue = int64(v)
|
||||
case int64:
|
||||
numValue = v
|
||||
case float64:
|
||||
numValue = int64(v)
|
||||
if float64(numValue) != v {
|
||||
result.Warnings = append(result.Warnings, "floating point value truncated to integer")
|
||||
}
|
||||
default:
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "value is not a numeric type")
|
||||
return result
|
||||
}
|
||||
|
||||
if numValue < min {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("value %d is below minimum %d", numValue, min))
|
||||
}
|
||||
|
||||
if numValue > max {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("value %d exceeds maximum %d", numValue, max))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,421 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInputValidator(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
logger := NewLogger("debug")
|
||||
validator, err := NewInputValidator(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create validator: %v", err)
|
||||
}
|
||||
|
||||
t.Run("Valid token validation", func(t *testing.T) {
|
||||
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc"
|
||||
|
||||
result := validator.ValidateToken(validToken)
|
||||
if !result.IsValid {
|
||||
t.Errorf("Expected valid token to pass validation, got errors: %v", result.Errors)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid token validation", func(t *testing.T) {
|
||||
invalidTokens := []string{
|
||||
"", // Empty token
|
||||
"invalid.token", // Invalid format
|
||||
"a.b", // Too few parts
|
||||
"a.b.c.d", // Too many parts
|
||||
}
|
||||
|
||||
for _, token := range invalidTokens {
|
||||
result := validator.ValidateToken(token)
|
||||
if result.IsValid {
|
||||
t.Errorf("Expected invalid token '%s' to fail validation", token)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Valid email validation", func(t *testing.T) {
|
||||
validEmails := []string{
|
||||
"user@example.com",
|
||||
"test.email@domain.co.uk",
|
||||
"user123@test-domain.org",
|
||||
}
|
||||
|
||||
for _, email := range validEmails {
|
||||
result := validator.ValidateEmail(email)
|
||||
if !result.IsValid {
|
||||
t.Errorf("Expected valid email '%s' to pass validation, got errors: %v", email, result.Errors)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid email validation", func(t *testing.T) {
|
||||
invalidEmails := []string{
|
||||
"", // Empty
|
||||
"invalid", // No @ symbol
|
||||
"@domain.com", // No local part
|
||||
"user@", // No domain
|
||||
"user@domain", // No TLD
|
||||
"user..double@domain.com", // Double dots
|
||||
}
|
||||
|
||||
for _, email := range invalidEmails {
|
||||
result := validator.ValidateEmail(email)
|
||||
if result.IsValid {
|
||||
t.Errorf("Expected invalid email '%s' to fail validation", email)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Valid URL validation", func(t *testing.T) {
|
||||
validURLs := []string{
|
||||
"https://example.com",
|
||||
"https://sub.domain.com/path",
|
||||
"https://localhost:8080/callback",
|
||||
}
|
||||
|
||||
for _, url := range validURLs {
|
||||
result := validator.ValidateURL(url)
|
||||
if !result.IsValid {
|
||||
t.Errorf("Expected valid URL '%s' to pass validation, got errors: %v", url, result.Errors)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid URL validation", func(t *testing.T) {
|
||||
invalidURLs := []string{
|
||||
"", // Empty
|
||||
"not-a-url", // Invalid format
|
||||
"ftp://example.com", // Wrong scheme
|
||||
"https://", // No host
|
||||
}
|
||||
|
||||
for _, url := range invalidURLs {
|
||||
result := validator.ValidateURL(url)
|
||||
if result.IsValid {
|
||||
t.Errorf("Expected invalid URL '%s' to fail validation", url)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Valid username validation", func(t *testing.T) {
|
||||
validUsernames := []string{
|
||||
"user123",
|
||||
"test_user",
|
||||
"user-name",
|
||||
}
|
||||
|
||||
for _, username := range validUsernames {
|
||||
result := validator.ValidateUsername(username)
|
||||
if !result.IsValid {
|
||||
t.Errorf("Expected valid username '%s' to pass validation, got errors: %v", username, result.Errors)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid username validation", func(t *testing.T) {
|
||||
invalidUsernames := []string{
|
||||
"", // Empty
|
||||
"a", // Too short
|
||||
strings.Repeat("a", 100), // Too long
|
||||
"user name", // Spaces
|
||||
}
|
||||
|
||||
for _, username := range invalidUsernames {
|
||||
result := validator.ValidateUsername(username)
|
||||
if result.IsValid {
|
||||
t.Errorf("Expected invalid username '%s' to fail validation", username)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Valid claim validation", func(t *testing.T) {
|
||||
validClaims := map[string]string{
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
"name": "John Doe",
|
||||
}
|
||||
|
||||
for key, value := range validClaims {
|
||||
result := validator.ValidateClaim(key, value)
|
||||
if !result.IsValid {
|
||||
t.Errorf("Expected valid claim '%s'='%s' to pass validation, got errors: %v", key, value, result.Errors)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid claim validation", func(t *testing.T) {
|
||||
invalidClaims := map[string]string{
|
||||
"": "value", // Empty key
|
||||
"long_key": strings.Repeat("a", 10000), // Too long value
|
||||
}
|
||||
|
||||
for key, value := range invalidClaims {
|
||||
result := validator.ValidateClaim(key, value)
|
||||
if result.IsValid {
|
||||
t.Errorf("Expected invalid claim '%s'='%s' to fail validation", key, value)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Valid header validation", func(t *testing.T) {
|
||||
validHeaders := map[string]string{
|
||||
"Authorization": "Bearer token123",
|
||||
"Content-Type": "application/json",
|
||||
"X-Custom": "custom-value",
|
||||
}
|
||||
|
||||
for key, value := range validHeaders {
|
||||
result := validator.ValidateHeader(key, value)
|
||||
if !result.IsValid {
|
||||
t.Errorf("Expected valid header '%s'='%s' to pass validation, got errors: %v", key, value, result.Errors)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid header validation", func(t *testing.T) {
|
||||
invalidHeaders := map[string]string{
|
||||
"": "value", // Empty key
|
||||
"Invalid\nKey": "value", // Control characters in key
|
||||
"key": "value\r\n", // Control characters in value
|
||||
}
|
||||
|
||||
for key, value := range invalidHeaders {
|
||||
result := validator.ValidateHeader(key, value)
|
||||
if result.IsValid {
|
||||
t.Errorf("Expected invalid header '%s'='%s' to fail validation", key, value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeInput(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
logger := NewLogger("debug")
|
||||
validator, err := NewInputValidator(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create validator: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
{
|
||||
name: "Normal text",
|
||||
input: "Hello World",
|
||||
maxLen: 100,
|
||||
expected: "Hello World",
|
||||
},
|
||||
{
|
||||
name: "Control characters",
|
||||
input: "text\x00with\x01control\x02chars",
|
||||
maxLen: 100,
|
||||
expected: "textwithcontrolchars",
|
||||
},
|
||||
{
|
||||
name: "Truncation",
|
||||
input: "very long text",
|
||||
maxLen: 5,
|
||||
expected: "very ",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.SanitizeInput(tt.input, tt.maxLen)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected sanitized input '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateBoundaryValues(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
logger := NewLogger("debug")
|
||||
validator, err := NewInputValidator(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create validator: %v", err)
|
||||
}
|
||||
|
||||
t.Run("Valid boundary values", func(t *testing.T) {
|
||||
validValues := []any{
|
||||
int(50),
|
||||
int64(100),
|
||||
float64(75.5),
|
||||
}
|
||||
|
||||
for _, value := range validValues {
|
||||
result := validator.ValidateBoundaryValues(value, 1, 1000)
|
||||
if !result.IsValid {
|
||||
t.Errorf("Expected valid boundary value %v to pass validation, got errors: %v", value, result.Errors)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid boundary values", func(t *testing.T) {
|
||||
invalidValues := []any{
|
||||
int(-1),
|
||||
int64(2000),
|
||||
"not a number",
|
||||
}
|
||||
|
||||
for _, value := range invalidValues {
|
||||
result := validator.ValidateBoundaryValues(value, 1, 1000)
|
||||
if result.IsValid {
|
||||
t.Errorf("Expected invalid boundary value %v to fail validation", value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultInputValidationConfig(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
|
||||
if config.MaxTokenLength <= 0 {
|
||||
t.Error("Expected positive MaxTokenLength")
|
||||
}
|
||||
if config.MaxEmailLength <= 0 {
|
||||
t.Error("Expected positive MaxEmailLength")
|
||||
}
|
||||
if config.MaxUsernameLength <= 0 {
|
||||
t.Error("Expected positive MaxUsernameLength")
|
||||
}
|
||||
if config.MaxClaimLength <= 0 {
|
||||
t.Error("Expected positive MaxClaimLength")
|
||||
}
|
||||
if config.MaxHeaderLength <= 0 {
|
||||
t.Error("Expected positive MaxHeaderLength")
|
||||
}
|
||||
if !config.StrictMode {
|
||||
t.Error("Expected StrictMode to be true by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputValidationHelpers(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
logger := NewLogger("debug")
|
||||
validator, err := NewInputValidator(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create validator: %v", err)
|
||||
}
|
||||
|
||||
t.Run("isValidBase64URL", func(t *testing.T) {
|
||||
validBase64URL := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
if !validator.isValidBase64URL(validBase64URL) {
|
||||
t.Error("Expected valid base64url to be recognized")
|
||||
}
|
||||
|
||||
invalidBase64URL := "invalid+base64/with+padding="
|
||||
if validator.isValidBase64URL(invalidBase64URL) {
|
||||
t.Error("Expected invalid base64url to be rejected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("containsNullBytes", func(t *testing.T) {
|
||||
withNull := "text\x00with\x00null"
|
||||
if !validator.containsNullBytes(withNull) {
|
||||
t.Error("Expected string with null bytes to be detected")
|
||||
}
|
||||
|
||||
withoutNull := "normal text"
|
||||
if validator.containsNullBytes(withoutNull) {
|
||||
t.Error("Expected string without null bytes to pass")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("containsControlCharacters", func(t *testing.T) {
|
||||
withControl := "text\x01with\x02control"
|
||||
if !validator.containsControlCharacters(withControl) {
|
||||
t.Error("Expected string with control characters to be detected")
|
||||
}
|
||||
|
||||
withoutControl := "normal text"
|
||||
if validator.containsControlCharacters(withoutControl) {
|
||||
t.Error("Expected string without control characters to pass")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("containsPathTraversal", func(t *testing.T) {
|
||||
withTraversal := "../../../etc/passwd"
|
||||
if !validator.containsPathTraversal(withTraversal) {
|
||||
t.Error("Expected path traversal to be detected")
|
||||
}
|
||||
|
||||
normalPath := "/normal/path"
|
||||
if validator.containsPathTraversal(normalPath) {
|
||||
t.Error("Expected normal path to pass")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("detectSecurityRisk", func(t *testing.T) {
|
||||
riskyInputs := []string{
|
||||
"<script>alert('xss')</script>",
|
||||
"'; DROP TABLE users; --",
|
||||
"javascript:alert('xss')",
|
||||
}
|
||||
|
||||
for _, input := range riskyInputs {
|
||||
if validator.detectSecurityRisk(input) == "" {
|
||||
t.Errorf("Expected security risk to be detected in: %s", input)
|
||||
}
|
||||
}
|
||||
|
||||
safeInput := "normal safe text"
|
||||
if validator.detectSecurityRisk(safeInput) != "" {
|
||||
t.Error("Expected safe input to pass security check")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInputValidationEdgeCases(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
logger := NewLogger("debug")
|
||||
validator, err := NewInputValidator(config, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create validator: %v", err)
|
||||
}
|
||||
|
||||
t.Run("Empty inputs", func(t *testing.T) {
|
||||
// Most validations should reject empty inputs
|
||||
if result := validator.ValidateToken(""); result.IsValid {
|
||||
t.Error("Expected empty token to be rejected")
|
||||
}
|
||||
if result := validator.ValidateEmail(""); result.IsValid {
|
||||
t.Error("Expected empty email to be rejected")
|
||||
}
|
||||
if result := validator.ValidateURL(""); result.IsValid {
|
||||
t.Error("Expected empty URL to be rejected")
|
||||
}
|
||||
if result := validator.ValidateUsername(""); result.IsValid {
|
||||
t.Error("Expected empty username to be rejected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Very long inputs", func(t *testing.T) {
|
||||
longString := strings.Repeat("a", 10000)
|
||||
|
||||
if result := validator.ValidateEmail(longString + "@domain.com"); result.IsValid {
|
||||
t.Error("Expected very long email to be rejected")
|
||||
}
|
||||
if result := validator.ValidateUsername(longString); result.IsValid {
|
||||
t.Error("Expected very long username to be rejected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Unicode handling", func(t *testing.T) {
|
||||
unicodeEmail := "用户@example.com"
|
||||
// Should handle unicode gracefully
|
||||
validator.ValidateEmail(unicodeEmail) // Don't fail on unicode
|
||||
|
||||
unicodeUsername := "用户名"
|
||||
validator.ValidateUsername(unicodeUsername) // Don't fail on unicode
|
||||
})
|
||||
}
|
||||
@@ -1,22 +1,21 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWK represents a JSON Web Key
|
||||
type JWK struct {
|
||||
Kty string `json:"kty"`
|
||||
Kid string `json:"kid"`
|
||||
@@ -29,53 +28,150 @@ type JWK struct {
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
// JWKSet represents a set of JWKs
|
||||
type JWKSet struct {
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache caches the JWKs
|
||||
type JWKCache struct {
|
||||
jwks *JWKSet
|
||||
expiresAt time.Time
|
||||
mutex sync.RWMutex
|
||||
expiresAt time.Time
|
||||
jwks *JWKSet
|
||||
internalCache *Cache
|
||||
CacheLifetime time.Duration
|
||||
maxSize int
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the interface for the JWK cache
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
Cleanup()
|
||||
Close()
|
||||
}
|
||||
|
||||
// GetJWKS gets the JWKS, either from cache or by fetching it
|
||||
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// GetJWKS retrieves the JSON Web Key Set (JWKS) from the cache or fetches it from the provider.
|
||||
// It first checks if a valid, non-expired JWKS is present in the cache. If so, it returns the cached version.
|
||||
// Otherwise, it attempts to fetch the JWKS from the specified jwksURL using the provided httpClient.
|
||||
// If the fetch is successful, the JWKS is stored in the cache with an expiration time based on CacheLifetime
|
||||
// (defaulting to 1 hour if not set) and returned.
|
||||
// This method uses double-checked locking to minimize contention when the cache needs refreshing.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request if fetching is required.
|
||||
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
|
||||
// - httpClient: The HTTP client to use for fetching the JWKS.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the JWKSet containing the keys.
|
||||
// - An error if fetching fails or the response cannot be decoded.
|
||||
func NewJWKCache() *JWKCache {
|
||||
cache := &JWKCache{
|
||||
CacheLifetime: 1 * time.Hour,
|
||||
maxSize: 100, // Default maximum size
|
||||
internalCache: NewCache(),
|
||||
}
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// First check if we already have cached JWKS for this URL
|
||||
if c.internalCache != nil {
|
||||
if cachedJwks, found := c.internalCache.Get(jwksURL); found {
|
||||
return cachedJwks.(*JWKSet), nil
|
||||
}
|
||||
}
|
||||
|
||||
// STABILITY FIX: Fix race condition in double-checked locking
|
||||
// First read check with read lock
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
defer c.mutex.RUnlock()
|
||||
return c.jwks, nil
|
||||
jwks := c.jwks // Copy reference while holding read lock
|
||||
c.mutex.RUnlock()
|
||||
return jwks, nil
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
// Acquire write lock for potential update
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Second check after acquiring write lock (double-checked locking)
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
return c.jwks, nil
|
||||
}
|
||||
|
||||
jwks, err := fetchJWKS(jwksURL, httpClient)
|
||||
// Fetch new JWKS
|
||||
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// STABILITY FIX: Validate JWKS contains keys before caching
|
||||
if len(jwks.Keys) == 0 {
|
||||
return nil, fmt.Errorf("JWKS response contains no keys")
|
||||
}
|
||||
|
||||
// Update cache atomically
|
||||
c.jwks = jwks
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
lifetime := c.CacheLifetime
|
||||
if lifetime == 0 {
|
||||
lifetime = 1 * time.Hour
|
||||
}
|
||||
c.expiresAt = time.Now().Add(lifetime)
|
||||
|
||||
// Also store in the internalCache
|
||||
if c.internalCache != nil {
|
||||
c.internalCache.Set(jwksURL, jwks, lifetime)
|
||||
}
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// fetchJWKS fetches the JWKS from the provider
|
||||
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
resp, err := httpClient.Get(jwksURL)
|
||||
// Cleanup removes the cached JWKS if it has expired.
|
||||
// This is intended to be called periodically to ensure stale JWKS data is cleared.
|
||||
func (c *JWKCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if c.jwks != nil && now.After(c.expiresAt) {
|
||||
c.jwks = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the cache's auto-cleanup routine.
|
||||
func (c *JWKCache) Close() {
|
||||
// Close shuts down the internal cache's auto-cleanup routine, if the cache exists.
|
||||
if c.internalCache != nil {
|
||||
c.internalCache.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxSize sets the maximum number of items in the cache
|
||||
func (c *JWKCache) SetMaxSize(size int) {
|
||||
c.maxSize = size
|
||||
if c.internalCache != nil {
|
||||
c.internalCache.maxSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// fetchJWKS retrieves the JSON Web Key Set (JWKS) from the specified URL.
|
||||
// It uses the provided context and HTTP client to make the request.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request.
|
||||
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
|
||||
// - httpClient: The HTTP client to use for the request.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the fetched JWKSet.
|
||||
// - An error if the request fails, the status code is not OK, or the response body cannot be decoded.
|
||||
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Create a request with context to enforce timeout
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
||||
}
|
||||
@@ -93,7 +189,16 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JWK to PEM format
|
||||
// jwkToPEM converts a JWK (JSON Web Key) object into PEM (Privacy-Enhanced Mail) format.
|
||||
// It selects the appropriate conversion function based on the JWK's key type ("kty").
|
||||
// Currently supports "RSA" and "EC" key types.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the JWK object to convert.
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the public key in PEM format.
|
||||
// - An error if the key type is unsupported or conversion fails.
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
@@ -109,7 +214,17 @@ var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JWK to PEM
|
||||
// rsaJWKToPEM converts an RSA JWK into PEM format.
|
||||
// It decodes the modulus (n) and exponent (e) from base64 URL encoding,
|
||||
// constructs an rsa.PublicKey, marshals it into PKIX format, and then
|
||||
// encodes it as a PEM block.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the RSA JWK object (must have "kty": "RSA").
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the RSA public key in PEM format.
|
||||
// - An error if decoding parameters fails or key marshaling fails.
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
@@ -141,7 +256,18 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC JWK to PEM
|
||||
// ecJWKToPEM converts an EC (Elliptic Curve) JWK into PEM format.
|
||||
// It decodes the X and Y coordinates from base64 URL encoding, determines the
|
||||
// elliptic curve based on the "crv" parameter (P-256, P-384, P-521),
|
||||
// constructs an ecdsa.PublicKey, marshals it into PKIX format, and then
|
||||
// encodes it as a PEM block.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the EC JWK object (must have "kty": "EC").
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the EC public key in PEM format.
|
||||
// - An error if decoding parameters fails, the curve is unsupported, or key marshaling fails.
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,30 +1,110 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
||||
"math/big"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWT represents a JSON Web Token
|
||||
type JWT struct {
|
||||
Header map[string]interface{}
|
||||
Claims map[string]interface{}
|
||||
Signature []byte
|
||||
Token string
|
||||
var (
|
||||
replayCacheMu sync.RWMutex // Use RWMutex for better read performance
|
||||
replayCache *Cache // Replace unbounded map with bounded Cache
|
||||
replayCacheOnce sync.Once
|
||||
)
|
||||
|
||||
func initReplayCache() {
|
||||
replayCacheOnce.Do(func() {
|
||||
replayCache = NewCache()
|
||||
replayCache.SetMaxSize(10000)
|
||||
})
|
||||
}
|
||||
|
||||
// parseJWT parses a JWT token string into a JWT struct
|
||||
func cleanupReplayCache() {
|
||||
replayCacheMu.Lock()
|
||||
defer replayCacheMu.Unlock()
|
||||
|
||||
if replayCache != nil {
|
||||
replayCache.Close()
|
||||
replayCache = nil
|
||||
}
|
||||
}
|
||||
|
||||
func getReplayCacheStats() (size int, maxSize int) {
|
||||
replayCacheMu.RLock()
|
||||
defer replayCacheMu.RUnlock()
|
||||
|
||||
if replayCache == nil {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
return 0, 10000
|
||||
}
|
||||
|
||||
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
size, maxSize := getReplayCacheStats()
|
||||
if logger != nil {
|
||||
logger.Debugf("Replay cache stats: size=%d, maxSize=%d", size, maxSize)
|
||||
}
|
||||
|
||||
replayCacheMu.RLock()
|
||||
if replayCache != nil {
|
||||
}
|
||||
replayCacheMu.RUnlock()
|
||||
|
||||
case <-ctx.Done():
|
||||
cleanupReplayCache()
|
||||
if logger != nil {
|
||||
logger.Debug("Replay cache cleanup goroutine stopped due to context cancellation")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var ClockSkewToleranceFuture = 2 * time.Minute
|
||||
|
||||
var ClockSkewTolerancePast = 10 * time.Second
|
||||
|
||||
var ClockSkewTolerance = ClockSkewToleranceFuture
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
type JWT struct {
|
||||
Header map[string]any
|
||||
Claims map[string]any
|
||||
Token string
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
// parseJWT decodes a raw JWT string into its constituent parts: header, claims, and signature.
|
||||
// It splits the token string by '.', decodes each part using base64 URL decoding,
|
||||
// and unmarshals the header and claims JSON into maps. The raw signature bytes are stored.
|
||||
// It performs basic format validation (expecting 3 parts).
|
||||
// Note: This function does *not* validate the signature or the claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT string.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a JWT struct containing the decoded parts.
|
||||
// - An error if the token format is invalid or decoding/unmarshaling fails.
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -35,7 +115,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
Token: tokenString,
|
||||
}
|
||||
|
||||
// Decode and unmarshal the header
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
||||
@@ -44,16 +123,23 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
||||
}
|
||||
|
||||
// Decode and unmarshal the claims
|
||||
if jwt.Header == nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: header is nil after unmarshaling")
|
||||
}
|
||||
|
||||
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
// Decode the signature
|
||||
if jwt.Claims == nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: claims is nil after unmarshaling")
|
||||
}
|
||||
|
||||
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
||||
@@ -63,8 +149,40 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return jwt, nil
|
||||
}
|
||||
|
||||
// Verify verifies the standard claims in the JWT
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Verify performs standard claim validation on the JWT according to RFC 7519.
|
||||
// It checks the following:
|
||||
// - Algorithm ('alg') is supported.
|
||||
// - Issuer ('iss') matches the expected issuerURL.
|
||||
// - Audience ('aud') contains the expected clientID.
|
||||
// - Expiration time ('exp') is in the future (within tolerance).
|
||||
// - Issued at time ('iat') is in the past (within tolerance).
|
||||
// - Not before time ('nbf'), if present, is in the past (within tolerance).
|
||||
// - Subject ('sub') claim exists and is not empty.
|
||||
// - JWT ID ('jti'), if present, is checked against a replay cache to prevent token reuse.
|
||||
//
|
||||
// Parameters:
|
||||
// - issuerURL: The expected issuer URL (e.g., "https://accounts.google.com").
|
||||
// - clientID: The expected audience value (the client ID of this application).
|
||||
// - skipReplayCheck: If true, skips JTI replay detection (used for revalidation of cached tokens).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if all standard claims are valid.
|
||||
// - An error describing the first validation failure encountered.
|
||||
func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error {
|
||||
// Validate algorithm to prevent algorithm switching attacks
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'alg' header")
|
||||
}
|
||||
supportedAlgs := map[string]bool{
|
||||
"RS256": true, "RS384": true, "RS512": true,
|
||||
"PS256": true, "PS384": true, "PS512": true,
|
||||
"ES256": true, "ES384": true, "ES512": true,
|
||||
}
|
||||
if !supportedAlgs[alg] {
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
claims := j.Claims
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
@@ -99,6 +217,47 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if err := verifyNotBefore(nbf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
shouldSkipReplay := len(skipReplayCheck) > 0 && skipReplayCheck[0]
|
||||
|
||||
if jti, ok := claims["jti"].(string); ok && !shouldSkipReplay {
|
||||
if j.Token == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
initReplayCache()
|
||||
|
||||
replayCacheMu.RLock()
|
||||
_, exists := replayCache.Get(jti)
|
||||
replayCacheMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return fmt.Errorf("token replay detected (jti: %s)", jti)
|
||||
}
|
||||
|
||||
expFloat, ok := claims["exp"].(float64)
|
||||
var expTime time.Time
|
||||
if ok {
|
||||
expTime = time.Unix(int64(expFloat), 0)
|
||||
} else {
|
||||
expTime = time.Now().Add(10 * time.Minute)
|
||||
}
|
||||
|
||||
duration := time.Until(expTime)
|
||||
if duration > 0 {
|
||||
replayCacheMu.Lock()
|
||||
if replayCache != nil {
|
||||
replayCache.Set(jti, true, duration)
|
||||
}
|
||||
replayCacheMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
sub, ok := claims["sub"].(string)
|
||||
if !ok || sub == "" {
|
||||
return fmt.Errorf("missing or empty 'sub' claim")
|
||||
@@ -107,14 +266,23 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience verifies the audience claim
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
// verifyAudience checks if the expected audience is present in the token's 'aud' claim.
|
||||
// The 'aud' claim can be a single string or an array of strings.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenAudience: The 'aud' claim value extracted from the token (can be string or []interface{}).
|
||||
// - expectedAudience: The audience value expected for this application (client ID).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the expected audience is found.
|
||||
// - An error if the claim type is invalid or the expected audience is not present.
|
||||
func verifyAudience(tokenAudience any, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
if aud != expectedAudience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
case []interface{}:
|
||||
case []any:
|
||||
found := false
|
||||
for _, v := range aud {
|
||||
if str, ok := v.(string); ok && str == expectedAudience {
|
||||
@@ -131,62 +299,109 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer verifies the issuer claim
|
||||
// verifyIssuer checks if the token's 'iss' claim matches the expected issuer URL.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenIssuer: The 'iss' claim value from the token.
|
||||
// - expectedIssuer: The expected issuer URL configured for the OIDC provider.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the issuers match.
|
||||
// - An error if the issuers do not match.
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyExpiration checks if the token has expired
|
||||
// verifyTimeConstraint checks time-based claims ('exp', 'iat', 'nbf') against the current time,
|
||||
// allowing for configurable clock skew. It uses different tolerances for past and future checks.
|
||||
//
|
||||
// Parameters:
|
||||
// - unixTime: The timestamp value from the claim (as a float64 Unix time).
|
||||
// - claimName: The name of the claim being verified ("exp", "iat", "nbf").
|
||||
// - future: A boolean indicating the direction of the check (true for 'exp', false for 'iat'/'nbf').
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the time constraint is met within the allowed tolerance.
|
||||
// - An error describing the failure (e.g., "token has expired", "token used before issued").
|
||||
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
|
||||
claimTime := time.Unix(int64(unixTime), 0)
|
||||
now := time.Now()
|
||||
|
||||
var err error
|
||||
if future {
|
||||
allowedExpiry := claimTime.Add(ClockSkewToleranceFuture)
|
||||
if now.After(allowedExpiry) {
|
||||
err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC())
|
||||
}
|
||||
} else {
|
||||
allowedStart := claimTime.Add(-ClockSkewTolerancePast)
|
||||
if now.Before(allowedStart) {
|
||||
reason := "not yet valid"
|
||||
if claimName == "iat" {
|
||||
reason = "used before issued"
|
||||
}
|
||||
err = fmt.Errorf("token %s (%s: %v, now: %v, allowed_from: %v)", reason, claimName, claimTime.UTC(), now.UTC(), allowedStart.UTC())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// verifyExpiration checks the 'exp' (Expiration Time) claim.
|
||||
// It calls verifyTimeConstraint with future=true.
|
||||
func verifyExpiration(expiration float64) error {
|
||||
expirationTime := time.Unix(int64(expiration), 0)
|
||||
if time.Now().After(expirationTime) {
|
||||
return fmt.Errorf("token has expired")
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(expiration, "exp", true)
|
||||
}
|
||||
|
||||
// verifyIssuedAt checks if the token was issued in the future
|
||||
// verifyIssuedAt checks the 'iat' (Issued At) claim.
|
||||
// It calls verifyTimeConstraint with future=false.
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
if time.Now().Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued")
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(issuedAt, "iat", false)
|
||||
}
|
||||
|
||||
// verifySignature verifies the token signature using the provided public key and algorithm
|
||||
// verifyNotBefore checks the 'nbf' (Not Before) claim.
|
||||
// It calls verifyTimeConstraint with future=false.
|
||||
func verifyNotBefore(notBefore float64) error {
|
||||
return verifyTimeConstraint(notBefore, "nbf", false)
|
||||
}
|
||||
|
||||
// verifySignature validates the JWT's signature using the provided public key.
|
||||
// It parses the public key from PEM format, selects the appropriate hashing algorithm
|
||||
// based on the 'alg' parameter (SHA256/384/512), hashes the token's signing input
|
||||
// (header + "." + payload), and then verifies the signature against the hash using
|
||||
// the corresponding RSA (PKCS1v15 or PSS) or ECDSA verification method.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw, complete JWT string.
|
||||
// - publicKeyPEM: The public key corresponding to the private key used for signing, in PEM format.
|
||||
// - alg: The algorithm specified in the JWT header (e.g., "RS256", "ES384").
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the signature is valid.
|
||||
// - An error if the token format is invalid, decoding fails, key parsing fails,
|
||||
// the algorithm is unsupported, or the signature verification fails.
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
// Split the token into its three parts
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
|
||||
// Decode the signature from the token
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
// Decode the PEM-encoded public key
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
|
||||
// Parse the public key
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
// Determine the hash function to use based on the algorithm
|
||||
var hashFunc crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
hashFunc = crypto.SHA256
|
||||
@@ -197,27 +412,20 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
|
||||
default:
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
// Hash the signed content
|
||||
h := hashFunc.New()
|
||||
h.Write([]byte(signedContent))
|
||||
hashed := h.Sum(nil)
|
||||
|
||||
// Verify the signature based on the key type and algorithm
|
||||
switch pubKey := pubKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
// RSA PKCS#1 v1.5 signature
|
||||
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
|
||||
} else if strings.HasPrefix(alg, "PS") {
|
||||
// RSA PSS signature
|
||||
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "ES") {
|
||||
// ECDSA signature
|
||||
var r, s big.Int
|
||||
sigLen := len(signature)
|
||||
if sigLen%2 != 0 {
|
||||
|
||||
+2959
-263
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,125 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MetadataCache struct {
|
||||
expiresAt time.Time
|
||||
metadata *ProviderMetadata
|
||||
cleanupTask *BackgroundTask
|
||||
logger *Logger
|
||||
autoCleanupInterval time.Duration
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMetadataCache creates a new MetadataCache instance.
|
||||
// It initializes the cache structure and starts the background cleanup task.
|
||||
func NewMetadataCache() *MetadataCache {
|
||||
return NewMetadataCacheWithLogger(nil)
|
||||
}
|
||||
|
||||
// NewMetadataCacheWithLogger creates a new MetadataCache with a specified logger.
|
||||
func NewMetadataCacheWithLogger(logger *Logger) *MetadataCache {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
c := &MetadataCache{
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
logger: logger,
|
||||
}
|
||||
c.startAutoCleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
// Cleanup removes the cached provider metadata if it has expired.
|
||||
// This is called periodically by the auto-cleanup goroutine.
|
||||
func (c *MetadataCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if c.metadata != nil && now.After(c.expiresAt) {
|
||||
c.metadata = nil
|
||||
}
|
||||
}
|
||||
|
||||
// isCacheValid checks if the cached metadata is present and has not expired.
|
||||
// Note: This function assumes the read lock is held or it's called from a context
|
||||
// where the lock is already held (like within GetMetadata after locking).
|
||||
func (c *MetadataCache) isCacheValid() bool {
|
||||
return c.metadata != nil && time.Now().Before(c.expiresAt)
|
||||
}
|
||||
|
||||
// GetMetadata retrieves the OIDC provider metadata.
|
||||
// It first checks the cache for valid, non-expired metadata. If found, it's returned immediately.
|
||||
// If the cache is empty or expired, it attempts to fetch the metadata from the provider's
|
||||
// well-known endpoint using discoverProviderMetadata.
|
||||
// If fetching is successful, the new metadata is cached for 1 hour.
|
||||
// If fetching fails but valid metadata exists in the cache (even if expired), the cache expiry
|
||||
// is extended by 5 minutes, and the cached data is returned to prevent thundering herd issues.
|
||||
// If fetching fails and there's no cached data, an error is returned.
|
||||
// It employs double-checked locking for thread safety and performance.
|
||||
//
|
||||
// Parameters:
|
||||
// - providerURL: The base URL of the OIDC provider.
|
||||
// - httpClient: The HTTP client to use for fetching metadata.
|
||||
// - logger: The logger instance for recording errors or warnings.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the ProviderMetadata struct.
|
||||
// - An error if metadata cannot be retrieved from cache or fetched from the provider.
|
||||
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
|
||||
c.mutex.RLock()
|
||||
if c.isCacheValid() {
|
||||
defer c.mutex.RUnlock()
|
||||
return c.metadata, nil
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if c.isCacheValid() {
|
||||
return c.metadata, nil
|
||||
}
|
||||
|
||||
metadata, err := discoverProviderMetadata(providerURL, httpClient, logger)
|
||||
if err != nil {
|
||||
if c.metadata != nil {
|
||||
// On error, extend current cache by 5 minutes to prevent thundering herd
|
||||
c.expiresAt = time.Now().Add(5 * time.Minute)
|
||||
logger.Errorf("Failed to refresh metadata, using cached version for 5 more minutes: %v", err)
|
||||
return c.metadata, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
|
||||
}
|
||||
|
||||
c.metadata = metadata
|
||||
// Set a fixed cache lifetime (e.g., 1 hour)
|
||||
// Consider making this configurable or respecting HTTP cache headers
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
|
||||
// End of GetMetadata
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// startAutoCleanup starts the background task that periodically calls Cleanup
|
||||
// to remove expired metadata from the cache.
|
||||
func (c *MetadataCache) startAutoCleanup() {
|
||||
c.cleanupTask = NewBackgroundTask("metadata-cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger)
|
||||
c.cleanupTask.Start()
|
||||
}
|
||||
|
||||
// Close stops the automatic cleanup task associated with this metadata cache.
|
||||
func (c *MetadataCache) Close() {
|
||||
if c.cleanupTask != nil {
|
||||
c.cleanupTask.Stop()
|
||||
c.cleanupTask = nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsCacheValid(t *testing.T) {
|
||||
// Setup with a dummy ProviderMetadata.
|
||||
pm := &ProviderMetadata{}
|
||||
mc := &MetadataCache{
|
||||
metadata: pm,
|
||||
expiresAt: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
if !mc.isCacheValid() {
|
||||
t.Errorf("Expected cache to be valid")
|
||||
}
|
||||
mc.expiresAt = time.Now().Add(-1 * time.Hour)
|
||||
if mc.isCacheValid() {
|
||||
t.Errorf("Expected cache to be invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanup(t *testing.T) {
|
||||
pm := &ProviderMetadata{}
|
||||
mc := &MetadataCache{
|
||||
metadata: pm,
|
||||
expiresAt: time.Now().Add(-1 * time.Hour),
|
||||
}
|
||||
mc.Cleanup()
|
||||
if mc.metadata != nil {
|
||||
t.Errorf("Expected metadata to be nil after cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetadata_Cached(t *testing.T) {
|
||||
dummyData := &ProviderMetadata{}
|
||||
// Construct MetadataCache manually to avoid interference from auto cleanup.
|
||||
mc := &MetadataCache{
|
||||
metadata: dummyData,
|
||||
expiresAt: time.Now().Add(1 * time.Hour),
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
logger: newNoOpLogger(),
|
||||
}
|
||||
// Use NewLogger to create a logger that writes errors only.
|
||||
logger := NewLogger("error")
|
||||
result, err := mc.GetMetadata("http://example.com", http.DefaultClient, logger)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if result != dummyData {
|
||||
t.Errorf("Expected cached metadata to be returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataCacheAutoCleanup(t *testing.T) {
|
||||
mc := &MetadataCache{
|
||||
autoCleanupInterval: 50 * time.Millisecond,
|
||||
logger: newNoOpLogger(),
|
||||
}
|
||||
// Start auto cleanup.
|
||||
mc.startAutoCleanup()
|
||||
mc.mutex.Lock()
|
||||
mc.metadata = &ProviderMetadata{}
|
||||
mc.expiresAt = time.Now().Add(-50 * time.Millisecond)
|
||||
mc.mutex.Unlock()
|
||||
|
||||
// Wait enough time for the auto cleanup to run.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
mc.Close()
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
if mc.metadata != nil {
|
||||
t.Errorf("Expected metadata to be cleared by auto cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
type errorRoundTripper struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e errorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return nil, e.err
|
||||
}
|
||||
|
||||
func TestGetMetadata_FetchError(t *testing.T) {
|
||||
// Create an HTTP client that always returns an error.
|
||||
errorClient := &http.Client{
|
||||
Transport: errorRoundTripper{err: fmt.Errorf("fake fetch error")},
|
||||
}
|
||||
|
||||
// Case 1: Cache is empty.
|
||||
mc := &MetadataCache{
|
||||
logger: newNoOpLogger(),
|
||||
}
|
||||
logger := NewLogger("error")
|
||||
metadata, err := mc.GetMetadata("http://example.com", errorClient, logger)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
if metadata != nil {
|
||||
t.Errorf("Expected nil metadata, got %v", metadata)
|
||||
}
|
||||
|
||||
// Case 2: Cache has old metadata.
|
||||
dummy := &ProviderMetadata{}
|
||||
mc.metadata = dummy
|
||||
mc.expiresAt = time.Now().Add(-1 * time.Minute)
|
||||
logger2 := NewLogger("error")
|
||||
metadata, err = mc.GetMetadata("http://example.com", errorClient, logger2)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error when cached metadata exists, got %v", err)
|
||||
}
|
||||
if metadata != dummy {
|
||||
t.Errorf("Expected cached metadata to be returned")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,778 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestConcurrentTokenVerification tests race conditions in token verification
|
||||
func TestConcurrentTokenVerification(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create multiple valid tokens to avoid replay detection
|
||||
tokens := make([]string, 10)
|
||||
for i := range 10 {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"nbf": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token %d: %v", i, err)
|
||||
}
|
||||
tokens[i] = token
|
||||
}
|
||||
|
||||
// Create a fresh instance for this test
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
tokenBlacklist: NewCache(),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Microsecond), 10000), // Very high rate limit
|
||||
logger: NewLogger("debug"),
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
httpClient: &http.Client{},
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
tOidc.tokenVerifier = tOidc
|
||||
tOidc.jwtVerifier = tOidc
|
||||
|
||||
// Ensure cleanup when test finishes
|
||||
defer func() {
|
||||
if err := tOidc.Close(); err != nil {
|
||||
t.Logf("Error closing TraefikOidc instance: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test concurrent verification
|
||||
const numGoroutines = 50
|
||||
const verificationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount int64
|
||||
var errorCount int64
|
||||
errors := make(chan error, numGoroutines*verificationsPerGoroutine)
|
||||
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for j := range verificationsPerGoroutine {
|
||||
tokenIndex := (goroutineID*verificationsPerGoroutine + j) % len(tokens)
|
||||
err := tOidc.VerifyToken(tokens[tokenIndex])
|
||||
if err != nil {
|
||||
atomic.AddInt64(&errorCount, 1)
|
||||
select {
|
||||
case errors <- fmt.Errorf("goroutine %d, verification %d: %w", goroutineID, j, err):
|
||||
default:
|
||||
}
|
||||
} else {
|
||||
atomic.AddInt64(&successCount, 1)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check results
|
||||
totalOperations := int64(numGoroutines * verificationsPerGoroutine)
|
||||
t.Logf("Concurrent verification results: %d successes, %d errors out of %d total operations",
|
||||
successCount, errorCount, totalOperations)
|
||||
|
||||
// Collect and log errors
|
||||
var errorList []error
|
||||
for err := range errors {
|
||||
errorList = append(errorList, err)
|
||||
}
|
||||
|
||||
if len(errorList) > 0 {
|
||||
t.Logf("Errors encountered during concurrent verification:")
|
||||
for i, err := range errorList {
|
||||
if i < 10 { // Log first 10 errors
|
||||
t.Logf(" %d: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
if len(errorList) > 10 {
|
||||
t.Logf(" ... and %d more errors", len(errorList)-10)
|
||||
}
|
||||
}
|
||||
|
||||
// We expect most operations to succeed
|
||||
if successCount < totalOperations/2 {
|
||||
t.Errorf("Too many failures in concurrent verification: %d successes out of %d operations", successCount, totalOperations)
|
||||
}
|
||||
|
||||
// Check for data races by verifying cache consistency
|
||||
cacheSize := len(tOidc.tokenCache.cache.items)
|
||||
blacklistSize := len(tOidc.tokenBlacklist.items)
|
||||
t.Logf("Final cache sizes: token cache=%d, blacklist=%d", cacheSize, blacklistSize)
|
||||
}
|
||||
|
||||
// TestCacheMemoryExhaustion tests cache behavior under memory pressure
|
||||
func TestCacheMemoryExhaustion(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create a cache with limited size
|
||||
cache := NewTokenCache()
|
||||
cache.cache.SetMaxSize(100) // Small cache size
|
||||
|
||||
// Ensure cleanup when test finishes
|
||||
defer cache.Close()
|
||||
|
||||
// Create many tokens to exceed cache capacity
|
||||
const numTokens = 500
|
||||
tokens := make([]string, numTokens)
|
||||
|
||||
for i := range numTokens {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"nbf": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": fmt.Sprintf("jti-%d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create token %d: %v", i, err)
|
||||
}
|
||||
tokens[i] = token
|
||||
|
||||
// Add to cache
|
||||
claims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": fmt.Sprintf("jti-%d", i),
|
||||
}
|
||||
cache.Set(token, claims, time.Hour)
|
||||
}
|
||||
|
||||
// Verify cache size is within limits
|
||||
cacheSize := len(cache.cache.items)
|
||||
if cacheSize > 100 {
|
||||
t.Errorf("Cache size exceeded limit: got %d, expected <= 100", cacheSize)
|
||||
}
|
||||
|
||||
// Verify LRU eviction works
|
||||
// The first tokens should have been evicted
|
||||
firstToken := tokens[0]
|
||||
if _, exists := cache.Get(firstToken); exists {
|
||||
t.Errorf("First token should have been evicted from cache")
|
||||
}
|
||||
|
||||
// The last tokens should still be in cache
|
||||
lastToken := tokens[numTokens-1]
|
||||
if _, exists := cache.Get(lastToken); !exists {
|
||||
t.Errorf("Last token should still be in cache")
|
||||
}
|
||||
|
||||
t.Logf("Cache memory exhaustion test passed: cache size=%d", cacheSize)
|
||||
}
|
||||
|
||||
// TestSessionConcurrencyProtection tests session safety under concurrent access
|
||||
func TestSessionConcurrencyProtection(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sessionManager, err := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Test concurrent session access with separate requests
|
||||
const numGoroutines = 20
|
||||
const operationsPerGoroutine = 10 // Reduced to avoid overwhelming
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount int64
|
||||
var errorCount int64
|
||||
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each goroutine gets its own request and session
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
for j := range operationsPerGoroutine {
|
||||
// Get a fresh session for each operation
|
||||
s, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
atomic.AddInt64(&errorCount, 1)
|
||||
continue
|
||||
}
|
||||
|
||||
// Perform operations on session
|
||||
s.SetEmail(fmt.Sprintf("user%d-%d@example.com", goroutineID, j))
|
||||
s.SetAuthenticated(true)
|
||||
s.SetAccessToken(ValidAccessToken)
|
||||
|
||||
// Save session
|
||||
testRR := httptest.NewRecorder()
|
||||
if err := s.Save(req, testRR); err != nil {
|
||||
atomic.AddInt64(&errorCount, 1)
|
||||
} else {
|
||||
atomic.AddInt64(&successCount, 1)
|
||||
}
|
||||
|
||||
// Copy cookies back to request for next iteration
|
||||
for _, cookie := range testRR.Result().Cookies() {
|
||||
req.Header.Set("Cookie", cookie.String())
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
totalOperations := int64(numGoroutines * operationsPerGoroutine)
|
||||
t.Logf("Session concurrency test results: %d successes, %d errors out of %d operations",
|
||||
successCount, errorCount, totalOperations)
|
||||
|
||||
// Most operations should succeed
|
||||
if successCount < totalOperations/2 {
|
||||
t.Errorf("Too many session operation failures: %d successes out of %d operations", successCount, totalOperations)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParallelCacheOperations tests cache thread safety
|
||||
func TestParallelCacheOperations(t *testing.T) {
|
||||
cache := NewCache()
|
||||
cache.SetMaxSize(1000)
|
||||
|
||||
// Ensure cleanup when test finishes
|
||||
defer cache.Close()
|
||||
|
||||
const numGoroutines = 10
|
||||
const operationsPerGoroutine = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var setCount int64
|
||||
var getCount int64
|
||||
var deleteCount int64
|
||||
|
||||
// Start multiple goroutines performing cache operations
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
for j := range operationsPerGoroutine {
|
||||
key := fmt.Sprintf("key-%d-%d", goroutineID, j)
|
||||
value := fmt.Sprintf("value-%d-%d", goroutineID, j)
|
||||
|
||||
// Set operation
|
||||
cache.Set(key, value, time.Minute)
|
||||
atomic.AddInt64(&setCount, 1)
|
||||
|
||||
// Get operation
|
||||
if _, exists := cache.Get(key); exists {
|
||||
atomic.AddInt64(&getCount, 1)
|
||||
}
|
||||
|
||||
// Delete some items
|
||||
if j%10 == 0 {
|
||||
cache.Delete(key)
|
||||
atomic.AddInt64(&deleteCount, 1)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
t.Logf("Parallel cache operations completed: %d sets, %d gets, %d deletes",
|
||||
setCount, getCount, deleteCount)
|
||||
|
||||
// Verify cache is still functional
|
||||
cache.Set("test-key", "test-value", time.Minute)
|
||||
if value, exists := cache.Get("test-key"); !exists || value != "test-value" {
|
||||
t.Errorf("Cache corrupted after parallel operations")
|
||||
}
|
||||
|
||||
// Check cache size is reasonable
|
||||
cacheSize := len(cache.items)
|
||||
expectedSize := int(setCount - deleteCount)
|
||||
if cacheSize > expectedSize {
|
||||
t.Logf("Cache size after operations: %d (expected around %d)", cacheSize, expectedSize)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderFailureRecovery tests network failure scenarios
|
||||
func TestProviderFailureRecovery(t *testing.T) {
|
||||
// Create a server that fails initially then recovers
|
||||
var requestCount int64
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
count := atomic.AddInt64(&requestCount, 1)
|
||||
if count <= 3 {
|
||||
// Fail first 3 requests
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Succeed after 3 failures
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://test-issuer.com",
|
||||
AuthURL: "https://test-issuer.com/auth",
|
||||
TokenURL: "https://test-issuer.com/token",
|
||||
JWKSURL: "https://test-issuer.com/jwks",
|
||||
RevokeURL: "https://test-issuer.com/revoke",
|
||||
EndSessionURL: "https://test-issuer.com/end-session",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Test metadata discovery with retries
|
||||
logger := NewLogger("debug")
|
||||
httpClient := createDefaultHTTPClient()
|
||||
|
||||
start := time.Now()
|
||||
metadata, err := discoverProviderMetadata(server.URL, httpClient, logger)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Provider metadata discovery failed after retries: %v", err)
|
||||
}
|
||||
|
||||
if metadata == nil {
|
||||
t.Errorf("Expected metadata to be returned after recovery")
|
||||
}
|
||||
|
||||
// Should have taken some time due to retries (at least the sum of delays: 10ms + 20ms + 40ms = 70ms)
|
||||
expectedMinDuration := 70 * time.Millisecond
|
||||
if duration < expectedMinDuration {
|
||||
t.Errorf("Expected discovery to take at least %v due to retries, but took %v", expectedMinDuration, duration)
|
||||
}
|
||||
|
||||
t.Logf("Provider failure recovery test passed: %d requests, duration: %v", requestCount, duration)
|
||||
}
|
||||
|
||||
// TestOversizedTokenHandling tests boundary value handling
|
||||
func TestOversizedTokenHandling(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create an oversized token with large claims
|
||||
largeClaim := strings.Repeat("x", 10000) // 10KB claim
|
||||
oversizedClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"nbf": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
"large_data": largeClaim,
|
||||
}
|
||||
|
||||
oversizedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", oversizedClaims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create oversized token: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Created oversized token of length: %d bytes", len(oversizedToken))
|
||||
|
||||
// Test verification of oversized token
|
||||
err = ts.tOidc.VerifyToken(oversizedToken)
|
||||
if err != nil {
|
||||
t.Logf("Oversized token verification failed as expected: %v", err)
|
||||
// This is acceptable - oversized tokens should be rejected
|
||||
} else {
|
||||
t.Logf("Oversized token verification succeeded")
|
||||
// Verify it was cached properly
|
||||
if _, exists := ts.tOidc.tokenCache.Get(oversizedToken); !exists {
|
||||
t.Errorf("Oversized token was not cached after successful verification")
|
||||
}
|
||||
}
|
||||
|
||||
// Test extremely long token (beyond reasonable limits)
|
||||
extremelyLongClaim := strings.Repeat("y", 100000) // 100KB claim
|
||||
extremeClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"nbf": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
"extreme_data": extremelyLongClaim,
|
||||
}
|
||||
|
||||
extremeToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", extremeClaims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create extreme token: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Created extreme token of length: %d bytes", len(extremeToken))
|
||||
|
||||
// This should likely fail due to size limits
|
||||
err = ts.tOidc.VerifyToken(extremeToken)
|
||||
if err != nil {
|
||||
t.Logf("Extreme token verification failed as expected: %v", err)
|
||||
} else {
|
||||
t.Logf("Warning: Extreme token verification succeeded - consider adding size limits")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMaliciousInputValidation tests security input validation
|
||||
func TestMaliciousInputValidation(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
maliciousInputs := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "Empty token",
|
||||
token: "",
|
||||
},
|
||||
{
|
||||
name: "Single dot",
|
||||
token: ".",
|
||||
},
|
||||
{
|
||||
name: "Two dots only",
|
||||
token: "..",
|
||||
},
|
||||
{
|
||||
name: "SQL injection attempt",
|
||||
token: "'; DROP TABLE users; --",
|
||||
},
|
||||
{
|
||||
name: "Script injection attempt",
|
||||
token: "<script>alert('xss')</script>",
|
||||
},
|
||||
{
|
||||
name: "Path traversal attempt",
|
||||
token: "../../../etc/passwd",
|
||||
},
|
||||
{
|
||||
name: "Null bytes",
|
||||
token: "token\x00with\x00nulls",
|
||||
},
|
||||
{
|
||||
name: "Unicode control characters",
|
||||
token: "token\u0000\u0001\u0002",
|
||||
},
|
||||
{
|
||||
name: "Extremely long string",
|
||||
token: strings.Repeat("a", 1000000), // 1MB string
|
||||
},
|
||||
{
|
||||
name: "Invalid base64 characters",
|
||||
token: "header.payload!@#$%^&*().signature",
|
||||
},
|
||||
{
|
||||
name: "Binary data",
|
||||
token: string([]byte{0x00, 0x01, 0x02, 0x03, 0xFF, 0xFE, 0xFD}),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range maliciousInputs {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Create a fresh instance for each test to avoid rate limiting issues
|
||||
freshOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
tokenBlacklist: NewCache(),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Microsecond), 10000), // Very high rate limit
|
||||
logger: NewLogger("debug"),
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
httpClient: &http.Client{},
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
freshOidc.tokenVerifier = freshOidc
|
||||
freshOidc.jwtVerifier = freshOidc
|
||||
|
||||
// Ensure cleanup when test finishes
|
||||
defer func() {
|
||||
if err := freshOidc.Close(); err != nil {
|
||||
t.Logf("Error closing TraefikOidc instance: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// All malicious inputs should be safely rejected
|
||||
err := freshOidc.VerifyToken(test.token)
|
||||
if err == nil {
|
||||
t.Errorf("Malicious input '%s' was not rejected", test.name)
|
||||
} else {
|
||||
t.Logf("Malicious input '%s' correctly rejected: %v", test.name, err)
|
||||
}
|
||||
|
||||
// Verify the system is still functional after malicious input
|
||||
validToken, createErr := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"nbf": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if createErr != nil {
|
||||
t.Fatalf("Failed to create valid token for recovery test: %v", createErr)
|
||||
}
|
||||
|
||||
// System should still work with valid tokens
|
||||
if verifyErr := freshOidc.VerifyToken(validToken); verifyErr != nil {
|
||||
t.Errorf("System failed to process valid token after malicious input: %v", verifyErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNetworkErrorCleanup tests resource cleanup on network errors
|
||||
func TestNetworkErrorCleanup(t *testing.T) {
|
||||
// Create a server that times out
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate network timeout by sleeping
|
||||
time.Sleep(2 * time.Second)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create HTTP client with short timeout
|
||||
httpClient := &http.Client{
|
||||
Timeout: 100 * time.Millisecond, // Very short timeout
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
|
||||
// Track goroutines before test
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Attempt metadata discovery that should timeout
|
||||
start := time.Now()
|
||||
_, err := discoverProviderMetadata(server.URL, httpClient, logger)
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should fail due to timeout
|
||||
if err == nil {
|
||||
t.Errorf("Expected timeout error, but request succeeded")
|
||||
}
|
||||
|
||||
// Should fail quickly due to timeout
|
||||
if duration > time.Second {
|
||||
t.Errorf("Request took too long despite timeout: %v", duration)
|
||||
}
|
||||
|
||||
// Give time for cleanup
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check for goroutine leaks
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 { // Allow some tolerance
|
||||
t.Errorf("Potential goroutine leak: started with %d, ended with %d goroutines",
|
||||
initialGoroutines, finalGoroutines)
|
||||
}
|
||||
|
||||
t.Logf("Network error cleanup test passed: duration=%v, goroutines=%d->%d",
|
||||
duration, initialGoroutines, finalGoroutines)
|
||||
}
|
||||
|
||||
// TestResourceLimits tests system behavior under resource constraints
|
||||
func TestResourceLimits(t *testing.T) {
|
||||
// Test memory allocation limits
|
||||
cache := NewCache()
|
||||
cache.SetMaxSize(10) // Very small cache
|
||||
|
||||
// Ensure cleanup when test finishes
|
||||
defer cache.Close()
|
||||
|
||||
// Try to overwhelm the cache
|
||||
for i := range 1000 {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
value := fmt.Sprintf("value-%d", i)
|
||||
cache.Set(key, value, time.Minute)
|
||||
}
|
||||
|
||||
// Cache should not exceed its limit
|
||||
if len(cache.items) > 10 {
|
||||
t.Errorf("Cache exceeded size limit: got %d items, expected <= 10", len(cache.items))
|
||||
}
|
||||
|
||||
// Test rate limiting under load
|
||||
limiter := rate.NewLimiter(rate.Every(time.Second), 5) // 5 requests per second
|
||||
|
||||
allowed := 0
|
||||
denied := 0
|
||||
|
||||
// Make many requests quickly
|
||||
for range 100 {
|
||||
if limiter.Allow() {
|
||||
allowed++
|
||||
} else {
|
||||
denied++
|
||||
}
|
||||
}
|
||||
|
||||
// Most should be denied due to rate limiting
|
||||
if denied < 90 {
|
||||
t.Errorf("Rate limiting not effective: allowed=%d, denied=%d", allowed, denied)
|
||||
}
|
||||
|
||||
t.Logf("Resource limits test passed: cache size=%d, rate limiting: allowed=%d, denied=%d",
|
||||
len(cache.items), allowed, denied)
|
||||
}
|
||||
|
||||
// TestErrorRecoveryPatterns tests various error recovery scenarios
|
||||
func TestErrorRecoveryPatterns(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Test recovery from cache corruption
|
||||
t.Run("CacheCorruption", func(t *testing.T) {
|
||||
// Corrupt the cache by using the Set method to avoid data race
|
||||
ts.tOidc.tokenCache.cache.Set("corrupted", "invalid-data", time.Hour)
|
||||
|
||||
// System should handle corrupted cache gracefully
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"nbf": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create valid token: %v", err)
|
||||
}
|
||||
|
||||
// Should still work despite cache corruption
|
||||
if err := ts.tOidc.VerifyToken(validToken); err != nil {
|
||||
t.Errorf("Token verification failed despite cache corruption: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test recovery from blacklist corruption
|
||||
t.Run("BlacklistCorruption", func(t *testing.T) {
|
||||
// Add invalid data to blacklist
|
||||
ts.tOidc.tokenBlacklist.Set("corrupted-entry", "invalid-data", time.Hour)
|
||||
|
||||
// System should still function
|
||||
validToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"nbf": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create valid token: %v", err)
|
||||
}
|
||||
|
||||
if err := ts.tOidc.VerifyToken(validToken); err != nil {
|
||||
t.Errorf("Token verification failed despite blacklist corruption: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPerformanceUnderLoad tests system performance under high load
|
||||
func TestPerformanceUnderLoad(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping performance test in short mode")
|
||||
}
|
||||
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create multiple valid tokens
|
||||
const numTokens = 100
|
||||
tokens := make([]string, numTokens)
|
||||
for i := range numTokens {
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"nbf": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": fmt.Sprintf("jti-%d", i),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create token %d: %v", i, err)
|
||||
}
|
||||
tokens[i] = token
|
||||
}
|
||||
|
||||
// Create fresh instance with high rate limit
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
tokenBlacklist: NewCache(),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Microsecond), 10000), // Very high limit
|
||||
logger: NewLogger("info"), // Reduce logging for performance
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
httpClient: &http.Client{},
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
tOidc.tokenVerifier = tOidc
|
||||
tOidc.jwtVerifier = tOidc
|
||||
|
||||
// Ensure cleanup when test finishes
|
||||
defer func() {
|
||||
if err := tOidc.Close(); err != nil {
|
||||
t.Logf("Error closing TraefikOidc instance: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Performance test
|
||||
const iterations = 1000
|
||||
start := time.Now()
|
||||
|
||||
for i := range iterations {
|
||||
tokenIndex := i % numTokens
|
||||
err := tOidc.VerifyToken(tokens[tokenIndex])
|
||||
if err != nil {
|
||||
t.Errorf("Token verification failed at iteration %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
opsPerSecond := float64(iterations) / duration.Seconds()
|
||||
|
||||
t.Logf("Performance test completed: %d operations in %v (%.2f ops/sec)",
|
||||
iterations, duration, opsPerSecond)
|
||||
|
||||
// Should achieve reasonable performance
|
||||
if opsPerSecond < 100 {
|
||||
t.Errorf("Performance too low: %.2f ops/sec (expected > 100)", opsPerSecond)
|
||||
}
|
||||
}
|
||||
+124
@@ -0,0 +1,124 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMergeScopes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
defaultScopes []string
|
||||
userScopes []string
|
||||
expectedScopes []string
|
||||
}{
|
||||
{
|
||||
name: "Empty user scopes",
|
||||
defaultScopes: []string{"openid", "profile", "email"},
|
||||
userScopes: []string{},
|
||||
expectedScopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "Non-overlapping scopes",
|
||||
defaultScopes: []string{"openid", "profile", "email"},
|
||||
userScopes: []string{"roles", "custom_scope"},
|
||||
expectedScopes: []string{"openid", "profile", "email", "roles", "custom_scope"},
|
||||
},
|
||||
{
|
||||
name: "Overlapping scopes",
|
||||
defaultScopes: []string{"openid", "profile", "email"},
|
||||
userScopes: []string{"openid", "roles", "profile", "permissions"},
|
||||
expectedScopes: []string{"openid", "profile", "email", "roles", "permissions"},
|
||||
},
|
||||
{
|
||||
name: "Nil user scopes",
|
||||
defaultScopes: []string{"openid", "profile", "email"},
|
||||
userScopes: nil,
|
||||
expectedScopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "Nil default scopes",
|
||||
defaultScopes: nil,
|
||||
userScopes: []string{"roles", "custom_scope"},
|
||||
expectedScopes: []string{"roles", "custom_scope"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := mergeScopes(tc.defaultScopes, tc.userScopes)
|
||||
if !reflect.DeepEqual(result, tc.expectedScopes) {
|
||||
t.Errorf("Expected %v, got %v", tc.expectedScopes, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopesConfiguration(t *testing.T) {
|
||||
defaultScopes := []string{"openid", "profile", "email"}
|
||||
userScopes := []string{"roles", "custom_scope"}
|
||||
|
||||
t.Run("Default Append Behavior", func(t *testing.T) {
|
||||
// Create config with user scopes but overrideScopes=false
|
||||
config := &Config{
|
||||
Scopes: userScopes,
|
||||
OverrideScopes: false,
|
||||
}
|
||||
|
||||
// Simulate middleware initialization
|
||||
var result []string
|
||||
if config.OverrideScopes {
|
||||
result = append([]string(nil), config.Scopes...)
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, config.Scopes)
|
||||
}
|
||||
|
||||
// Expect defaultScopes + userScopes with deduplication
|
||||
expectedScopes := []string{"openid", "profile", "email", "roles", "custom_scope"}
|
||||
if !reflect.DeepEqual(result, expectedScopes) {
|
||||
t.Errorf("Expected %v, got %v", expectedScopes, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Override Behavior", func(t *testing.T) {
|
||||
// Create config with user scopes and overrideScopes=true
|
||||
config := &Config{
|
||||
Scopes: userScopes,
|
||||
OverrideScopes: true,
|
||||
}
|
||||
|
||||
// Simulate middleware initialization
|
||||
var result []string
|
||||
if config.OverrideScopes {
|
||||
result = append([]string(nil), config.Scopes...)
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, config.Scopes)
|
||||
}
|
||||
|
||||
// Expect only userScopes
|
||||
if !reflect.DeepEqual(result, userScopes) {
|
||||
t.Errorf("Expected %v, got %v", userScopes, result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty Scopes with Override", func(t *testing.T) {
|
||||
// Create config with empty scopes and overrideScopes=true
|
||||
config := &Config{
|
||||
Scopes: []string{},
|
||||
OverrideScopes: true,
|
||||
}
|
||||
|
||||
// Simulate middleware initialization
|
||||
var result []string
|
||||
if config.OverrideScopes {
|
||||
result = append([]string(nil), config.Scopes...)
|
||||
} else {
|
||||
result = mergeScopes(defaultScopes, config.Scopes)
|
||||
}
|
||||
|
||||
// Expect empty scopes - check length instead of DeepEqual
|
||||
if len(result) != 0 {
|
||||
t.Errorf("Expected empty slice, got %v with length %d", result, len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,566 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityEventType represents different types of security events
|
||||
type SecurityEventType string
|
||||
|
||||
const (
|
||||
// AuthFailure represents an authentication failure event
|
||||
AuthFailure SecurityEventType = "authentication_failure"
|
||||
// TokenValidFailure represents a token validation failure event
|
||||
TokenValidFailure SecurityEventType = "token_validation_failure"
|
||||
// RateLimitHit represents a rate limit hit event
|
||||
RateLimitHit SecurityEventType = "rate_limit_hit"
|
||||
// SuspiciousActivity represents a suspicious activity event
|
||||
SuspiciousActivity SecurityEventType = "suspicious_activity"
|
||||
)
|
||||
|
||||
// DefaultSeverity returns the default severity level for a security event type
|
||||
func (t SecurityEventType) DefaultSeverity() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "medium"
|
||||
case TokenValidFailure:
|
||||
return "medium"
|
||||
case RateLimitHit:
|
||||
return "low"
|
||||
case SuspiciousActivity:
|
||||
return "high"
|
||||
default:
|
||||
return "medium"
|
||||
}
|
||||
}
|
||||
|
||||
// IPFailureType returns the IP failure tracking type for a security event type
|
||||
func (t SecurityEventType) IPFailureType() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "auth_failure"
|
||||
case TokenValidFailure:
|
||||
return "token_failure"
|
||||
case SuspiciousActivity:
|
||||
return "suspicious"
|
||||
default:
|
||||
return "general"
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityEvent represents a security-related event that should be logged and monitored
|
||||
type SecurityEvent struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Severity string `json:"severity"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
RequestPath string `json:"request_path"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// SecurityMonitor tracks security events and suspicious activity patterns
|
||||
type SecurityMonitor struct {
|
||||
ipFailures map[string]*IPFailureTracker
|
||||
patternDetector *SuspiciousPatternDetector
|
||||
logger *Logger
|
||||
eventHandlers []SecurityEventHandler
|
||||
config SecurityMonitorConfig
|
||||
ipMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// IPFailureTracker tracks failures for a specific IP address
|
||||
type IPFailureTracker struct {
|
||||
LastFailure time.Time
|
||||
FirstFailure time.Time
|
||||
BlockedUntil time.Time
|
||||
FailureTypes map[string]int64
|
||||
FailureCount int64
|
||||
mutex sync.RWMutex
|
||||
IsBlocked bool
|
||||
}
|
||||
|
||||
// SuspiciousPatternDetector identifies patterns that may indicate attacks
|
||||
type SuspiciousPatternDetector struct {
|
||||
recentEvents []SecurityEvent
|
||||
shortWindow time.Duration
|
||||
mediumWindow time.Duration
|
||||
longWindow time.Duration
|
||||
rapidFailureThreshold int
|
||||
distributedAttackThreshold int
|
||||
persistentAttackThreshold int
|
||||
eventsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// SecurityEventHandler defines the interface for handling security events
|
||||
type SecurityEventHandler interface {
|
||||
HandleSecurityEvent(event SecurityEvent)
|
||||
}
|
||||
|
||||
// SecurityMonitorConfig contains configuration for the security monitor
|
||||
type SecurityMonitorConfig struct {
|
||||
MaxFailuresPerIP int `json:"max_failures_per_ip"`
|
||||
FailureWindowMinutes int `json:"failure_window_minutes"`
|
||||
BlockDurationMinutes int `json:"block_duration_minutes"`
|
||||
RapidFailureThreshold int `json:"rapid_failure_threshold"`
|
||||
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
|
||||
RetentionHours int `json:"retention_hours"`
|
||||
EnablePatternDetection bool `json:"enable_pattern_detection"`
|
||||
EnableDetailedLogging bool `json:"enable_detailed_logging"`
|
||||
LogSuspiciousOnly bool `json:"log_suspicious_only"`
|
||||
}
|
||||
|
||||
// DefaultSecurityMonitorConfig returns a default configuration
|
||||
func DefaultSecurityMonitorConfig() SecurityMonitorConfig {
|
||||
return SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 10,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 60,
|
||||
EnablePatternDetection: true,
|
||||
RapidFailureThreshold: 5,
|
||||
EnableDetailedLogging: true,
|
||||
LogSuspiciousOnly: false,
|
||||
CleanupIntervalMinutes: 30,
|
||||
RetentionHours: 24,
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupTask holds the BackgroundTask for security cleanup
|
||||
var cleanupTask *BackgroundTask
|
||||
|
||||
// NewSecurityMonitor creates a new security monitor instance
|
||||
func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityMonitor {
|
||||
sm := &SecurityMonitor{
|
||||
ipFailures: make(map[string]*IPFailureTracker),
|
||||
eventHandlers: make([]SecurityEventHandler, 0),
|
||||
config: config,
|
||||
logger: logger,
|
||||
patternDetector: NewSuspiciousPatternDetector(),
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
sm.startCleanupRoutine()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// NewSuspiciousPatternDetector creates a new pattern detector
|
||||
func NewSuspiciousPatternDetector() *SuspiciousPatternDetector {
|
||||
return &SuspiciousPatternDetector{
|
||||
shortWindow: 1 * time.Minute,
|
||||
mediumWindow: 5 * time.Minute,
|
||||
longWindow: 15 * time.Minute,
|
||||
rapidFailureThreshold: 5,
|
||||
distributedAttackThreshold: 20,
|
||||
persistentAttackThreshold: 50,
|
||||
recentEvents: make([]SecurityEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSecurityEvent is a generic method to record any type of security event
|
||||
func (sm *SecurityMonitor) RecordSecurityEvent(
|
||||
eventType SecurityEventType,
|
||||
clientIP, userAgent, requestPath string,
|
||||
message string,
|
||||
details map[string]any,
|
||||
trackIPFailure bool) {
|
||||
|
||||
// Create event with default values for the event type
|
||||
event := SecurityEvent{
|
||||
Type: string(eventType),
|
||||
Severity: eventType.DefaultSeverity(),
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
RequestPath: requestPath,
|
||||
Message: message,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
// Track IP failures if requested
|
||||
if trackIPFailure {
|
||||
sm.recordIPFailure(clientIP, eventType.IPFailureType())
|
||||
}
|
||||
|
||||
// Process the event
|
||||
sm.processSecurityEvent(event)
|
||||
}
|
||||
|
||||
// RecordAuthenticationFailure records an authentication failure event
|
||||
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]any) {
|
||||
if details == nil {
|
||||
details = make(map[string]any)
|
||||
}
|
||||
details["reason"] = reason
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
AuthFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Authentication failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordTokenValidationFailure records a token validation failure
|
||||
func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) {
|
||||
details := map[string]any{
|
||||
"reason": reason,
|
||||
}
|
||||
if tokenPrefix != "" {
|
||||
details["token_prefix"] = tokenPrefix
|
||||
}
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
TokenValidFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Token validation failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordRateLimitHit records when rate limiting is triggered
|
||||
func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) {
|
||||
details := map[string]any{
|
||||
"limit_type": "token_verification",
|
||||
}
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
RateLimitHit,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
"Rate limit exceeded",
|
||||
details,
|
||||
true, // Track IP failure for rate limiting
|
||||
)
|
||||
}
|
||||
|
||||
// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories
|
||||
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]any) {
|
||||
if details == nil {
|
||||
details = make(map[string]any)
|
||||
}
|
||||
details["activity_type"] = activityType
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
SuspiciousActivity,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// recordIPFailure tracks failures for a specific IP address
|
||||
func (sm *SecurityMonitor) recordIPFailure(clientIP, failureType string) {
|
||||
sm.ipMutex.Lock()
|
||||
defer sm.ipMutex.Unlock()
|
||||
|
||||
tracker, exists := sm.ipFailures[clientIP]
|
||||
if !exists {
|
||||
tracker = &IPFailureTracker{
|
||||
FailureTypes: make(map[string]int64),
|
||||
FirstFailure: time.Now(),
|
||||
}
|
||||
sm.ipFailures[clientIP] = tracker
|
||||
}
|
||||
|
||||
tracker.mutex.Lock()
|
||||
defer tracker.mutex.Unlock()
|
||||
|
||||
tracker.FailureCount++
|
||||
tracker.LastFailure = time.Now()
|
||||
tracker.FailureTypes[failureType]++
|
||||
|
||||
// Check if IP should be blocked
|
||||
windowStart := time.Now().Add(-time.Duration(sm.config.FailureWindowMinutes) * time.Minute)
|
||||
if tracker.FirstFailure.After(windowStart) && tracker.FailureCount >= int64(sm.config.MaxFailuresPerIP) {
|
||||
if !tracker.IsBlocked {
|
||||
tracker.IsBlocked = true
|
||||
tracker.BlockedUntil = time.Now().Add(time.Duration(sm.config.BlockDurationMinutes) * time.Minute)
|
||||
|
||||
sm.logger.Errorf("IP %s blocked due to %d failures (types: %v)", clientIP, tracker.FailureCount, tracker.FailureTypes)
|
||||
|
||||
// Record blocking event
|
||||
blockEvent := SecurityEvent{
|
||||
Type: "ip_blocked",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
Message: fmt.Sprintf("IP blocked due to %d failures in %d minutes", tracker.FailureCount, sm.config.FailureWindowMinutes),
|
||||
Details: map[string]any{
|
||||
"failure_count": tracker.FailureCount,
|
||||
"failure_types": tracker.FailureTypes,
|
||||
"blocked_until": tracker.BlockedUntil,
|
||||
},
|
||||
}
|
||||
sm.processSecurityEvent(blockEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsIPBlocked checks if an IP address is currently blocked
|
||||
func (sm *SecurityMonitor) IsIPBlocked(clientIP string) bool {
|
||||
sm.ipMutex.RLock()
|
||||
defer sm.ipMutex.RUnlock()
|
||||
|
||||
tracker, exists := sm.ipFailures[clientIP]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
tracker.mutex.RLock()
|
||||
defer tracker.mutex.RUnlock()
|
||||
|
||||
if tracker.IsBlocked && time.Now().Before(tracker.BlockedUntil) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Unblock if time has passed
|
||||
if tracker.IsBlocked && time.Now().After(tracker.BlockedUntil) {
|
||||
tracker.IsBlocked = false
|
||||
sm.logger.Infof("IP %s automatically unblocked", clientIP)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// processSecurityEvent processes a security event through all handlers and pattern detection
|
||||
func (sm *SecurityMonitor) processSecurityEvent(event SecurityEvent) {
|
||||
// Add to pattern detector
|
||||
if sm.config.EnablePatternDetection {
|
||||
sm.patternDetector.AddEvent(event)
|
||||
|
||||
// Check for suspicious patterns
|
||||
if patterns := sm.patternDetector.DetectSuspiciousPatterns(); len(patterns) > 0 {
|
||||
for _, pattern := range patterns {
|
||||
sm.logger.Errorf("Suspicious pattern detected: %s", pattern)
|
||||
|
||||
patternEvent := SecurityEvent{
|
||||
Type: "suspicious_pattern",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
Message: fmt.Sprintf("Suspicious pattern detected: %s", pattern),
|
||||
Details: map[string]any{
|
||||
"pattern_type": pattern,
|
||||
"trigger_event": event,
|
||||
},
|
||||
}
|
||||
sm.handleSecurityEvent(patternEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sm.handleSecurityEvent(event)
|
||||
}
|
||||
|
||||
// handleSecurityEvent sends the event to all registered handlers
|
||||
func (sm *SecurityMonitor) handleSecurityEvent(event SecurityEvent) {
|
||||
// Log the event
|
||||
if sm.config.EnableDetailedLogging && (!sm.config.LogSuspiciousOnly || event.Severity == "high") {
|
||||
sm.logger.Infof("Security Event [%s/%s]: %s (IP: %s, Path: %s)",
|
||||
event.Type, event.Severity, event.Message, event.ClientIP, event.RequestPath)
|
||||
}
|
||||
|
||||
// Send to all handlers
|
||||
for _, handler := range sm.eventHandlers {
|
||||
go handler.HandleSecurityEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// AddEventHandler adds a security event handler
|
||||
func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) {
|
||||
sm.eventHandlers = append(sm.eventHandlers, handler)
|
||||
}
|
||||
|
||||
// GetSecurityMetrics returns minimal security metrics
|
||||
// This is kept for API compatibility but doesn't collect actual metrics
|
||||
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]any {
|
||||
return map[string]any{
|
||||
"tracked_ips": 0,
|
||||
}
|
||||
}
|
||||
|
||||
// AddEvent adds an event to the pattern detector
|
||||
func (spd *SuspiciousPatternDetector) AddEvent(event SecurityEvent) {
|
||||
spd.eventsMutex.Lock()
|
||||
defer spd.eventsMutex.Unlock()
|
||||
|
||||
spd.recentEvents = append(spd.recentEvents, event)
|
||||
|
||||
// Clean old events
|
||||
cutoff := time.Now().Add(-spd.longWindow)
|
||||
var filteredEvents []SecurityEvent
|
||||
for _, e := range spd.recentEvents {
|
||||
if e.Timestamp.After(cutoff) {
|
||||
filteredEvents = append(filteredEvents, e)
|
||||
}
|
||||
}
|
||||
spd.recentEvents = filteredEvents
|
||||
}
|
||||
|
||||
// DetectSuspiciousPatterns analyzes recent events for suspicious patterns
|
||||
func (spd *SuspiciousPatternDetector) DetectSuspiciousPatterns() []string {
|
||||
spd.eventsMutex.RLock()
|
||||
defer spd.eventsMutex.RUnlock()
|
||||
|
||||
var patterns []string
|
||||
now := time.Now()
|
||||
|
||||
// Check for rapid failures from single IP
|
||||
ipCounts := make(map[string]int)
|
||||
shortWindowStart := now.Add(-spd.shortWindow)
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(shortWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
ipCounts[event.ClientIP]++
|
||||
}
|
||||
}
|
||||
|
||||
for ip, count := range ipCounts {
|
||||
if count >= spd.rapidFailureThreshold {
|
||||
patterns = append(patterns, fmt.Sprintf("rapid_failures_from_ip_%s", ip))
|
||||
}
|
||||
}
|
||||
|
||||
// Check for distributed attack (many IPs failing)
|
||||
mediumWindowStart := now.Add(-spd.mediumWindow)
|
||||
uniqueFailingIPs := make(map[string]bool)
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(mediumWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
uniqueFailingIPs[event.ClientIP] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(uniqueFailingIPs) >= spd.distributedAttackThreshold {
|
||||
patterns = append(patterns, "distributed_attack_pattern")
|
||||
}
|
||||
|
||||
// Check for persistent attack
|
||||
longWindowStart := now.Add(-spd.longWindow)
|
||||
persistentFailures := 0
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(longWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
persistentFailures++
|
||||
}
|
||||
}
|
||||
|
||||
if persistentFailures >= spd.persistentAttackThreshold {
|
||||
patterns = append(patterns, "persistent_attack_pattern")
|
||||
}
|
||||
|
||||
return patterns
|
||||
}
|
||||
|
||||
// startCleanupRoutine starts the background cleanup routine
|
||||
func (sm *SecurityMonitor) startCleanupRoutine() {
|
||||
// Use BackgroundTask abstraction for consistent management
|
||||
cleanupTask = NewBackgroundTask(
|
||||
"security-monitor-cleanup",
|
||||
time.Duration(sm.config.CleanupIntervalMinutes)*time.Minute,
|
||||
sm.cleanup,
|
||||
sm.logger)
|
||||
cleanupTask.Start()
|
||||
}
|
||||
|
||||
// StopCleanupRoutine stops the background cleanup routine
|
||||
func (sm *SecurityMonitor) StopCleanupRoutine() {
|
||||
if cleanupTask != nil {
|
||||
cleanupTask.Stop()
|
||||
cleanupTask = nil
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes old tracking data
|
||||
func (sm *SecurityMonitor) cleanup() {
|
||||
sm.ipMutex.Lock()
|
||||
defer sm.ipMutex.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-time.Duration(sm.config.RetentionHours) * time.Hour)
|
||||
|
||||
for ip, tracker := range sm.ipFailures {
|
||||
tracker.mutex.RLock()
|
||||
shouldRemove := tracker.LastFailure.Before(cutoff) && !tracker.IsBlocked
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
if shouldRemove {
|
||||
delete(sm.ipFailures, ip)
|
||||
}
|
||||
}
|
||||
|
||||
sm.logger.Debugf("Security monitor cleanup completed, tracking %d IPs", len(sm.ipFailures))
|
||||
}
|
||||
|
||||
// ExtractClientIP extracts the client IP from the request, considering proxy headers
|
||||
func ExtractClientIP(r *http.Request) string {
|
||||
// Check X-Real-IP header first (highest priority)
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
if net.ParseIP(xri) != nil {
|
||||
return xri
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Forwarded-For header second
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the chain
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
ip := strings.TrimSpace(ips[0])
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// LoggingSecurityEventHandler logs security events to the standard logger
|
||||
type LoggingSecurityEventHandler struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewLoggingSecurityEventHandler creates a new logging event handler
|
||||
func NewLoggingSecurityEventHandler(logger *Logger) *LoggingSecurityEventHandler {
|
||||
return &LoggingSecurityEventHandler{logger: logger}
|
||||
}
|
||||
|
||||
// HandleSecurityEvent implements SecurityEventHandler
|
||||
func (h *LoggingSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
|
||||
switch event.Severity {
|
||||
case "high":
|
||||
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
case "medium":
|
||||
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
case "low":
|
||||
h.logger.Infof("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
default:
|
||||
h.logger.Debugf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
}
|
||||
}
|
||||
|
||||
// Note: MetricsSecurityEventHandler has been removed as part of metrics cleanup
|
||||
@@ -0,0 +1,274 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSecurityMonitor(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
config.MaxFailuresPerIP = 3
|
||||
config.BlockDurationMinutes = 1 // 1 minute for testing
|
||||
config.CleanupIntervalMinutes = 1
|
||||
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
defer func() {
|
||||
// Allow cleanup goroutine to finish
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}()
|
||||
|
||||
t.Run("Record authentication failure", func(t *testing.T) {
|
||||
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "invalid credentials", nil)
|
||||
|
||||
// Should not be blocked after first failure
|
||||
if monitor.IsIPBlocked("192.168.1.1") {
|
||||
t.Error("IP should not be blocked after first failure")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IP blocked after max failures", func(t *testing.T) {
|
||||
// Record multiple failures
|
||||
for i := 0; i < config.MaxFailuresPerIP; i++ {
|
||||
monitor.RecordAuthenticationFailure("192.168.1.2", "test-agent", "/login", "invalid credentials", nil)
|
||||
}
|
||||
|
||||
// Should be blocked now
|
||||
if !monitor.IsIPBlocked("192.168.1.2") {
|
||||
t.Error("IP should be blocked after max failures")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Token validation failure", func(t *testing.T) {
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordTokenValidationFailure("192.168.1.3", "test-agent", "/api", "invalid token", "abc123")
|
||||
})
|
||||
|
||||
t.Run("Rate limit hit", func(t *testing.T) {
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordRateLimitHit("192.168.1.4", "test-agent", "/api")
|
||||
})
|
||||
|
||||
t.Run("Suspicious activity", func(t *testing.T) {
|
||||
details := map[string]any{"pattern": "unusual"}
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordSuspiciousActivity("192.168.1.5", "test-agent", "/admin", "unusual pattern", "high frequency requests", details)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
detector := NewSuspiciousPatternDetector()
|
||||
|
||||
t.Run("Add events and detect patterns", func(t *testing.T) {
|
||||
// Add multiple events from same IP
|
||||
for range 10 {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1.100",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
detector.AddEvent(event)
|
||||
}
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := slices.Contains(patterns, "rapid_failures_from_ip_192.168.1.100")
|
||||
if !found {
|
||||
t.Error("Expected to detect rapid failure pattern")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Detect distributed attack pattern", func(t *testing.T) {
|
||||
// Add failures from many different IPs
|
||||
for i := range 25 {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1." + strconv.Itoa(100+i),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
detector.AddEvent(event)
|
||||
}
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := slices.Contains(patterns, "distributed_attack_pattern")
|
||||
if !found {
|
||||
t.Error("Expected to detect distributed attack pattern")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
headers map[string]string
|
||||
expectedIP string
|
||||
}{
|
||||
{
|
||||
name: "Direct connection",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
expectedIP: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For header",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{"X-Forwarded-For": "203.0.113.1, 10.0.0.1"},
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP header",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{"X-Real-IP": "203.0.113.2"},
|
||||
expectedIP: "203.0.113.2",
|
||||
},
|
||||
{
|
||||
name: "Multiple headers - X-Real-IP takes precedence",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-For": "203.0.113.1",
|
||||
"X-Real-IP": "203.0.113.2",
|
||||
},
|
||||
expectedIP: "203.0.113.2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
|
||||
for key, value := range tt.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
ip := ExtractClientIP(req)
|
||||
if ip != tt.expectedIP {
|
||||
t.Errorf("Expected IP %s, got %s", tt.expectedIP, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityEventHandlers(t *testing.T) {
|
||||
t.Run("Logging security event handler", func(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
handler := NewLoggingSecurityEventHandler(logger)
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1.1",
|
||||
Timestamp: time.Now(),
|
||||
Message: "Test failure",
|
||||
Severity: "medium",
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
handler.HandleSecurityEvent(event)
|
||||
})
|
||||
|
||||
// Metrics security event handler test removed as part of metrics cleanup
|
||||
}
|
||||
|
||||
func TestSecurityMonitorEventHandlers(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Add event handler with proper synchronization
|
||||
handlerCalled := make(chan bool, 1)
|
||||
handler := &testSecurityEventHandler{
|
||||
callback: func(event SecurityEvent) {
|
||||
select {
|
||||
case handlerCalled <- true:
|
||||
default:
|
||||
// Channel already has a value, don't block
|
||||
}
|
||||
},
|
||||
}
|
||||
monitor.AddEventHandler(handler)
|
||||
|
||||
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "test failure", nil)
|
||||
|
||||
// Wait for event handler to be called with timeout
|
||||
select {
|
||||
case <-handlerCalled:
|
||||
// Success - handler was called
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Expected event handler to be called within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper for security event handler
|
||||
type testSecurityEventHandler struct {
|
||||
callback func(SecurityEvent)
|
||||
}
|
||||
|
||||
func (h *testSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
|
||||
h.callback(event)
|
||||
}
|
||||
|
||||
func TestDefaultSecurityMonitorConfig(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
|
||||
if config.MaxFailuresPerIP <= 0 {
|
||||
t.Error("Expected positive MaxFailuresPerIP")
|
||||
}
|
||||
if config.BlockDurationMinutes <= 0 {
|
||||
t.Error("Expected positive BlockDurationMinutes")
|
||||
}
|
||||
if config.CleanupIntervalMinutes <= 0 {
|
||||
t.Error("Expected positive CleanupIntervalMinutes")
|
||||
}
|
||||
if config.FailureWindowMinutes <= 0 {
|
||||
t.Error("Expected positive FailureWindowMinutes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityMonitorCleanup(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
config.CleanupIntervalMinutes = 1
|
||||
config.BlockDurationMinutes = 1
|
||||
config.RetentionHours = 1
|
||||
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Block an IP
|
||||
for i := 0; i < config.MaxFailuresPerIP; i++ {
|
||||
monitor.RecordAuthenticationFailure("192.168.1.99", "test-agent", "/login", "test", nil)
|
||||
}
|
||||
|
||||
// Verify it's blocked
|
||||
if !monitor.IsIPBlocked("192.168.1.99") {
|
||||
t.Error("IP should be blocked")
|
||||
}
|
||||
|
||||
// Wait a bit and check if it gets unblocked automatically
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// The IP should still be blocked since we haven't waited long enough
|
||||
if !monitor.IsIPBlocked("192.168.1.99") {
|
||||
t.Error("IP should still be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityEventTypes(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Test different event types - just verify they don't panic
|
||||
monitor.RecordAuthenticationFailure("192.168.1.200", "test-agent", "/login", "invalid password", nil)
|
||||
monitor.RecordTokenValidationFailure("192.168.1.200", "test-agent", "/api", "expired token", "abc123")
|
||||
monitor.RecordRateLimitHit("192.168.1.200", "test-agent", "/api")
|
||||
|
||||
details := map[string]any{"pattern": "test"}
|
||||
monitor.RecordSuspiciousActivity("192.168.1.200", "test-agent", "/admin", "unusual pattern", "multiple failed logins", details)
|
||||
|
||||
// Just verify GetSecurityMetrics doesn't panic
|
||||
_ = monitor.GetSecurityMetrics()
|
||||
}
|
||||
+10
@@ -0,0 +1,10 @@
|
||||
version: 1
|
||||
force:
|
||||
existing: true
|
||||
wording:
|
||||
patch:
|
||||
- patch-release
|
||||
minor:
|
||||
- minor-release
|
||||
major:
|
||||
- breaking
|
||||
+2173
File diff suppressed because it is too large
Load Diff
+691
@@ -0,0 +1,691 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
func TestSessionPoolMemoryLeak(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a fake request
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
|
||||
// Test 1: Successful session creation and return
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession failed: %v", err)
|
||||
}
|
||||
|
||||
// Clear the session which should return it to the pool
|
||||
session.Clear(req, nil)
|
||||
|
||||
// Test 2: ReturnToPool explicit method
|
||||
session, err = sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession failed: %v", err)
|
||||
}
|
||||
|
||||
// Call ReturnToPool directly
|
||||
session.ReturnToPool()
|
||||
|
||||
// Test 3: Error path in GetSession
|
||||
// Modify the session store to force an error - use a different encryption key
|
||||
badSM, _ := NewSessionManager("different0123456789abcdef0123456789abcdef0123456789", false, logger)
|
||||
|
||||
// Get session using mismatched manager/request to force error
|
||||
_, err = badSM.GetSession(req)
|
||||
if err == nil {
|
||||
// We don't test the exact error since it could vary, just that we get one
|
||||
t.Log("Note: Expected error when using mismatched encryption keys")
|
||||
}
|
||||
|
||||
// Force GC to ensure any objects are cleaned up
|
||||
runtime.GC()
|
||||
|
||||
// Wait a moment for GC to complete
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check if we have objects in the pool
|
||||
// This is just a simple check; in a real scenario, we'd have to
|
||||
// consider that sync.Pool can discard objects at any time.
|
||||
pooledCount := getPooledObjects(sm)
|
||||
t.Logf("Pooled objects count: %d", pooledCount)
|
||||
}
|
||||
|
||||
func TestSessionErrorHandling(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a fake request
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
|
||||
// Call the GetSession method, corrupting the cookie to force an error
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: mainCookieName,
|
||||
Value: "corrupt-value",
|
||||
})
|
||||
|
||||
_, err = sm.GetSession(req)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
// Check that the error message contains our expected prefix
|
||||
if err != nil && !strings.Contains(err.Error(), "failed to get main session:") {
|
||||
t.Fatalf("Unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionClearAlwaysReturnsToPool(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a test request with the special header that will trigger an error
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
req.Header.Set("X-Test-Error", "true") // This will trigger the error in session.Clear
|
||||
|
||||
// Get a session
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession failed: %v", err)
|
||||
}
|
||||
|
||||
// Create a response writer
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call Clear with the test request (with X-Test-Error header) and response writer
|
||||
// This should trigger the serialization error in Save
|
||||
clearErr := session.Clear(req, w)
|
||||
|
||||
// Verify that Clear returned the error from Save
|
||||
if clearErr == nil {
|
||||
t.Error("Expected an error from Clear with X-Test-Error header, but got nil")
|
||||
} else {
|
||||
t.Logf("Received expected error from Clear: %v", clearErr)
|
||||
}
|
||||
|
||||
// Force GC to ensure any objects are cleaned up
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create and clear another session (without the error header) to verify the pool is still working
|
||||
normalReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session2, err := sm.GetSession(normalReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Second GetSession failed: %v", err)
|
||||
}
|
||||
session2.Clear(normalReq, nil)
|
||||
|
||||
// If we got here without panics, the test is successful
|
||||
t.Log("Session returned to pool despite errors")
|
||||
}
|
||||
|
||||
// This placeholder comment is intentionally left empty since we're removing redundant code
|
||||
|
||||
// Helper function to count objects in the session pool for a given manager
|
||||
func getPooledObjects(sm *SessionManager) int {
|
||||
// Collect objects until we can't get any more from the pool
|
||||
// Set a max limit to avoid potential infinite loops
|
||||
var objects []*SessionData
|
||||
maxAttempts := 100 // Safety limit to prevent infinite loops
|
||||
|
||||
for range maxAttempts {
|
||||
obj := sm.sessionPool.Get()
|
||||
if obj == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Type assertion with validation
|
||||
sessionData, ok := obj.(*SessionData)
|
||||
if !ok {
|
||||
// Return the object even if it's not the right type to avoid leaks
|
||||
sm.sessionPool.Put(obj)
|
||||
break
|
||||
}
|
||||
|
||||
objects = append(objects, sessionData)
|
||||
}
|
||||
|
||||
// Count how many objects we found
|
||||
count := len(objects)
|
||||
|
||||
// Return all objects back to the pool to preserve the pool state
|
||||
for _, obj := range objects {
|
||||
sm.sessionPool.Put(obj)
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// TestSessionObjectTracking verifies that session objects are properly
|
||||
// returned to the pool in various scenarios including normal usage and error paths
|
||||
func TestSessionObjectTracking(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a fake request
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
|
||||
// Test that the session pool is used as expected
|
||||
hasNew := sm.sessionPool.New != nil
|
||||
if !hasNew {
|
||||
t.Error("Expected sessionPool.New function to be set")
|
||||
}
|
||||
|
||||
// Create and discard 5 sessions
|
||||
for range 5 {
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession failed: %v", err)
|
||||
}
|
||||
session.ReturnToPool()
|
||||
}
|
||||
|
||||
// Create a session and get an error when trying to clear it
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession failed: %v", err)
|
||||
}
|
||||
|
||||
// Deliberately cause bad state in the session object
|
||||
session.mainSession = nil // This will cause an error in Clear
|
||||
|
||||
// Even with an error, the pool should not leak
|
||||
session.ReturnToPool()
|
||||
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Success - if we got here without crashing, the pool is working as expected
|
||||
t.Log("Session pool handling verified")
|
||||
}
|
||||
|
||||
// TestTokenCompressionIntegrity tests that token compression and decompression maintains JWT integrity
|
||||
func TestTokenCompressionIntegrity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantFail bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JWT - Small",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.signature",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT - Large",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + strings.Repeat("eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9", 100) + ".signature",
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT - Wrong dot count",
|
||||
token: "invalid.token",
|
||||
wantFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT - No dots",
|
||||
token: "invalidtoken",
|
||||
wantFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT - Too many dots",
|
||||
token: "part1.part2.part3.part4",
|
||||
wantFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty token",
|
||||
token: "",
|
||||
wantFail: false, // Empty tokens are handled gracefully
|
||||
},
|
||||
{
|
||||
name: "Oversized token (>50KB)",
|
||||
token: "part1." + strings.Repeat("A", 51*1024) + ".part3",
|
||||
wantFail: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
compressed := compressToken(tt.token)
|
||||
|
||||
if tt.wantFail {
|
||||
// For invalid tokens, compression should return original
|
||||
if compressed != tt.token {
|
||||
t.Errorf("Expected compression to return original for invalid token, got different result")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For valid tokens, test round-trip integrity
|
||||
decompressed := decompressToken(compressed)
|
||||
if decompressed != tt.token {
|
||||
t.Errorf("Token integrity lost: original=%q, compressed=%q, decompressed=%q",
|
||||
tt.token, compressed, decompressed)
|
||||
}
|
||||
|
||||
// Test that decompression is idempotent
|
||||
decompressed2 := decompressToken(decompressed)
|
||||
if decompressed2 != tt.token {
|
||||
t.Errorf("Decompression not idempotent: %q != %q", decompressed2, tt.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenCompressionCorruptionDetection tests that gzip corruption is detected and handled
|
||||
func TestTokenCompressionCorruptionDetection(t *testing.T) {
|
||||
validJWT := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.signature"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
corruptedInput string
|
||||
expectOriginal bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid base64",
|
||||
corruptedInput: "!@#$%^&*()",
|
||||
expectOriginal: true,
|
||||
},
|
||||
{
|
||||
name: "Valid base64 but invalid gzip",
|
||||
corruptedInput: base64.StdEncoding.EncodeToString([]byte("not gzip data")),
|
||||
expectOriginal: true,
|
||||
},
|
||||
{
|
||||
name: "Truncated gzip data",
|
||||
corruptedInput: "H4sI", // Incomplete gzip header
|
||||
expectOriginal: true,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
corruptedInput: "",
|
||||
expectOriginal: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := decompressToken(tt.corruptedInput)
|
||||
if tt.expectOriginal && result != tt.corruptedInput {
|
||||
t.Errorf("Expected decompression to return original corrupted input, got: %q", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test that valid compression still works
|
||||
compressed := compressToken(validJWT)
|
||||
decompressed := decompressToken(compressed)
|
||||
if decompressed != validJWT {
|
||||
t.Errorf("Valid compression/decompression failed: %q != %q", decompressed, validJWT)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenChunkingIntegrity tests that large tokens are properly chunked and reassembled
|
||||
func TestTokenChunkingIntegrity(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create tokens of various sizes to test chunking
|
||||
testTokens := NewTestTokens()
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenSize int
|
||||
expectChunked bool
|
||||
}{
|
||||
{
|
||||
name: "Small token (no chunking)",
|
||||
tokenSize: 100,
|
||||
expectChunked: false,
|
||||
},
|
||||
{
|
||||
name: "Medium token (no chunking)",
|
||||
tokenSize: 800, // FIXED: Reduced further to account for new conservative chunk size (1200 bytes)
|
||||
expectChunked: false,
|
||||
},
|
||||
{
|
||||
name: "Large token (chunking required)",
|
||||
tokenSize: 5000,
|
||||
expectChunked: true,
|
||||
},
|
||||
{
|
||||
name: "Very large token (multiple chunks)",
|
||||
tokenSize: 10000,
|
||||
expectChunked: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// FIXED: Use incompressible tokens to ensure chunking occurs
|
||||
var token string
|
||||
if tt.expectChunked {
|
||||
token = testTokens.CreateIncompressibleToken(tt.tokenSize)
|
||||
} else {
|
||||
token = testTokens.CreateLargeValidJWT(tt.tokenSize)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Store the token
|
||||
session.SetAccessToken(token)
|
||||
|
||||
// Retrieve the token
|
||||
retrievedToken := session.GetAccessToken()
|
||||
|
||||
// Verify integrity
|
||||
if retrievedToken != token {
|
||||
t.Errorf("Token integrity lost:\nOriginal: %q\nRetrieved: %q", token, retrievedToken)
|
||||
}
|
||||
|
||||
// Check if chunking occurred as expected
|
||||
hasChunks := len(session.accessTokenChunks) > 0
|
||||
if tt.expectChunked != hasChunks {
|
||||
t.Errorf("Chunking expectation mismatch: expected chunked=%v, has chunks=%v", tt.expectChunked, hasChunks)
|
||||
}
|
||||
|
||||
session.ReturnToPool()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenChunkingCorruptionResistance tests handling of corrupted chunks
|
||||
func TestTokenChunkingCorruptionResistance(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a large token that will be chunked
|
||||
largeToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
|
||||
base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, `{"sub":"test","data":"%s"}`, strings.Repeat("A", 5000))) +
|
||||
".signature"
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Store the token (this should create chunks)
|
||||
session.SetAccessToken(largeToken)
|
||||
if len(session.accessTokenChunks) == 0 {
|
||||
t.Skip("Token was not chunked, skipping corruption test")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
corruptChunk func(chunks map[int]*sessions.Session)
|
||||
name string
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "Missing chunk in sequence",
|
||||
corruptChunk: func(chunks map[int]*sessions.Session) {
|
||||
// Remove a middle chunk
|
||||
if len(chunks) > 1 {
|
||||
delete(chunks, 1)
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Empty chunk data",
|
||||
corruptChunk: func(chunks map[int]*sessions.Session) {
|
||||
// Set first chunk to empty
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = ""
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Wrong data type in chunk",
|
||||
corruptChunk: func(chunks map[int]*sessions.Session) {
|
||||
// Set chunk data to wrong type
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = 123 // Should be string
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Oversized chunk",
|
||||
corruptChunk: func(chunks map[int]*sessions.Session) {
|
||||
// Set chunk to oversized data
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = strings.Repeat("A", maxCookieSize+200)
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Get a fresh session
|
||||
freshSession, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get fresh session: %v", err)
|
||||
}
|
||||
|
||||
// Store the token again
|
||||
freshSession.SetAccessToken(largeToken)
|
||||
|
||||
// Apply corruption
|
||||
tt.corruptChunk(freshSession.accessTokenChunks)
|
||||
|
||||
// Try to retrieve the token
|
||||
retrievedToken := freshSession.GetAccessToken()
|
||||
|
||||
if tt.expectEmpty {
|
||||
if retrievedToken != "" {
|
||||
t.Errorf("Expected empty token due to corruption, got: %q", retrievedToken)
|
||||
}
|
||||
} else {
|
||||
if retrievedToken != largeToken {
|
||||
t.Errorf("Expected original token despite corruption, got: %q", retrievedToken)
|
||||
}
|
||||
}
|
||||
|
||||
freshSession.ReturnToPool()
|
||||
})
|
||||
}
|
||||
|
||||
session.ReturnToPool()
|
||||
}
|
||||
|
||||
// TestTokenSizeLimits tests that token size limits are enforced
|
||||
func TestTokenSizeLimits(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
testTokens := NewTestTokens()
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenSize int
|
||||
expectStored bool
|
||||
}{
|
||||
{
|
||||
name: "Normal size token",
|
||||
tokenSize: 1000,
|
||||
expectStored: true,
|
||||
},
|
||||
{
|
||||
name: "Large but acceptable token",
|
||||
tokenSize: 30000, // FIXED: 30KB to ensure final size < 100KB limit
|
||||
expectStored: true,
|
||||
},
|
||||
{
|
||||
name: "Oversized token (>100KB)",
|
||||
tokenSize: 120000, // FIXED: 120KB to ensure rejection after compression
|
||||
expectStored: false, // Should be rejected
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// FIXED: Use proper token generation that accounts for base64 encoding
|
||||
var token string
|
||||
if tt.expectStored {
|
||||
token = testTokens.CreateLargeValidJWT(tt.tokenSize)
|
||||
} else {
|
||||
token = testTokens.CreateIncompressibleToken(tt.tokenSize)
|
||||
}
|
||||
|
||||
// Store the token
|
||||
session.SetAccessToken(token)
|
||||
|
||||
// Try to retrieve it
|
||||
retrievedToken := session.GetAccessToken()
|
||||
|
||||
if tt.expectStored {
|
||||
if retrievedToken != token {
|
||||
t.Errorf("Expected token to be stored and retrieved, but got different token")
|
||||
}
|
||||
} else {
|
||||
if retrievedToken == token {
|
||||
t.Errorf("Expected oversized token to be rejected, but it was stored")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentTokenOperations tests thread safety of token operations
|
||||
func TestConcurrentTokenOperations(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
const numGoroutines = 10
|
||||
const numOperations = 100
|
||||
|
||||
// Test concurrent access and refresh token operations
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := range numGoroutines {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := range numOperations {
|
||||
// Create unique tokens for each goroutine/operation
|
||||
accessToken := ValidAccessToken
|
||||
refreshToken := fmt.Sprintf("refresh_token_%d_%d", id, j)
|
||||
|
||||
// Concurrent operations
|
||||
session.SetAccessToken(accessToken)
|
||||
session.SetRefreshToken(refreshToken)
|
||||
|
||||
retrievedAccess := session.GetAccessToken()
|
||||
retrievedRefresh := session.GetRefreshToken()
|
||||
|
||||
// Verify tokens are still valid (should be one of the tokens set by any goroutine)
|
||||
if retrievedAccess != "" && strings.Count(retrievedAccess, ".") != 2 {
|
||||
t.Errorf("Retrieved access token has invalid format: %q", retrievedAccess)
|
||||
}
|
||||
if retrievedRefresh != "" && len(retrievedRefresh) < 10 {
|
||||
t.Errorf("Retrieved refresh token is too short: %q", retrievedRefresh)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for range numGoroutines {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionValidationAndCleanup tests session validation and orphan cleanup
|
||||
func TestSessionValidationAndCleanup(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set tokens that will create chunks
|
||||
largeToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
|
||||
base64.RawURLEncoding.EncodeToString([]byte(strings.Repeat(`{"data":"large"}`, 500))) +
|
||||
".signature"
|
||||
|
||||
session.SetAccessToken(largeToken)
|
||||
session.SetRefreshToken("refresh_token_test")
|
||||
|
||||
// Save session to create cookies
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify chunks were created
|
||||
if len(session.accessTokenChunks) == 0 {
|
||||
t.Log("No chunks created, large token test may not be applicable")
|
||||
}
|
||||
|
||||
// Test cleanup by clearing session
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.Logf("Clear returned error (may be expected): %v", err)
|
||||
}
|
||||
|
||||
// Verify tokens are cleared
|
||||
if token := session.GetAccessToken(); token != "" {
|
||||
t.Errorf("Access token should be empty after clear, got: %q", token)
|
||||
}
|
||||
if token := session.GetRefreshToken(); token != "" {
|
||||
t.Errorf("Refresh token should be empty after clear, got: %q", token)
|
||||
}
|
||||
}
|
||||
+413
-66
@@ -5,102 +5,398 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TemplatedHeader represents a custom HTTP header with a templated value.
|
||||
// The value can contain template expressions that will be evaluated for each
|
||||
// authenticated request, such as {{.claims.email}} or {{.accessToken}}.
|
||||
type TemplatedHeader struct {
|
||||
// Name is the HTTP header name to set (e.g., "X-Forwarded-Email")
|
||||
Name string `json:"name"`
|
||||
|
||||
// Value is the template string for the header value
|
||||
// Example: "{{.claims.email}}", "Bearer {{.accessToken}}"
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// Config holds the configuration for the OIDC middleware.
|
||||
// It provides all necessary settings to configure OpenID Connect authentication
|
||||
// with various providers like Auth0, Logto, or any standard OIDC provider.
|
||||
type Config struct {
|
||||
HTTPClient *http.Client
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
EnablePKCE bool `json:"enablePKCE"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
}
|
||||
|
||||
const (
|
||||
cookieName = "_raczylo_oidc"
|
||||
// DefaultRateLimit defines the default rate limit for requests per second
|
||||
DefaultRateLimit = 100
|
||||
|
||||
// MinRateLimit defines the minimum allowed rate limit to prevent DOS
|
||||
MinRateLimit = 10
|
||||
|
||||
// DefaultLogLevel defines the default logging level
|
||||
DefaultLogLevel = "info"
|
||||
|
||||
// MinSessionEncryptionKeyLength defines the minimum length for session encryption key
|
||||
MinSessionEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// Config holds the configuration for the OIDC middleware
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
Scopes []string `json:"scopes"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
var defaultSessionOptions = &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
|
||||
// CreateConfig creates a new Config with default values
|
||||
// CreateConfig creates a new Config with secure default values.
|
||||
// Default values are set for optional fields:
|
||||
// - Scopes: ["openid", "profile", "email"]
|
||||
// - LogLevel: "info"
|
||||
// - LogoutURL: CallbackURL + "/logout"
|
||||
// - RateLimit: 100 requests per second
|
||||
// - PostLogoutRedirectURI: "/"
|
||||
// - ForceHTTPS: true (for security)
|
||||
// - EnablePKCE: false (PKCE is opt-in)
|
||||
//
|
||||
// CreateConfig initializes a new Config struct with default values for optional fields.
|
||||
// It sets default scopes, log level, rate limit, enables ForceHTTPS, and sets the
|
||||
// default refresh grace period. Required fields like ProviderURL, ClientID, ClientSecret,
|
||||
// CallbackURL, and SessionEncryptionKey must be set explicitly after creation.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a new Config struct with default settings applied.
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{}
|
||||
|
||||
if c.Scopes == nil {
|
||||
c.Scopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
|
||||
if c.LogLevel == "" {
|
||||
c.LogLevel = "info"
|
||||
}
|
||||
|
||||
if c.LogoutURL == "" {
|
||||
c.LogoutURL = c.CallbackURL + "/logout"
|
||||
}
|
||||
|
||||
if c.RateLimit == 0 {
|
||||
c.RateLimit = 100
|
||||
c := &Config{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
OverrideScopes: false, // Default to appending scopes, not overriding
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Validate validates the Config
|
||||
// Validate checks the configuration settings for validity.
|
||||
// It ensures that required fields (ProviderURL, CallbackURL, ClientID, ClientSecret, SessionEncryptionKey)
|
||||
// are present and that URLs are well-formed (HTTPS where required). It also validates
|
||||
// the session key length, log level, rate limit, and refresh grace period.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the configuration is valid.
|
||||
// - An error describing the first validation failure encountered.
|
||||
func (c *Config) Validate() error {
|
||||
// Validate provider URL
|
||||
if c.ProviderURL == "" {
|
||||
return fmt.Errorf("providerURL is required")
|
||||
}
|
||||
if !isValidSecureURL(c.ProviderURL) {
|
||||
return fmt.Errorf("providerURL must be a valid HTTPS URL")
|
||||
}
|
||||
|
||||
// Validate callback URL
|
||||
if c.CallbackURL == "" {
|
||||
return fmt.Errorf("callbackURL is required")
|
||||
}
|
||||
if !strings.HasPrefix(c.CallbackURL, "/") {
|
||||
return fmt.Errorf("callbackURL must start with /")
|
||||
}
|
||||
|
||||
// Validate client credentials
|
||||
if c.ClientID == "" {
|
||||
return fmt.Errorf("clientID is required")
|
||||
}
|
||||
if c.ClientSecret == "" {
|
||||
return fmt.Errorf("clientSecret is required")
|
||||
}
|
||||
|
||||
// Validate session encryption key
|
||||
if c.SessionEncryptionKey == "" {
|
||||
return fmt.Errorf("sessionEncryptionKey is required")
|
||||
}
|
||||
if len(c.SessionEncryptionKey) < MinSessionEncryptionKeyLength {
|
||||
return fmt.Errorf("sessionEncryptionKey must be at least %d characters long", MinSessionEncryptionKeyLength)
|
||||
}
|
||||
|
||||
// Validate log level
|
||||
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
|
||||
return fmt.Errorf("logLevel must be one of: debug, info, error")
|
||||
}
|
||||
|
||||
// Validate excluded URLs
|
||||
for _, url := range c.ExcludedURLs {
|
||||
if !strings.HasPrefix(url, "/") {
|
||||
return fmt.Errorf("excluded URL must start with /: %s", url)
|
||||
}
|
||||
if strings.Contains(url, "..") {
|
||||
return fmt.Errorf("excluded URL must not contain path traversal: %s", url)
|
||||
}
|
||||
if strings.Contains(url, "*") {
|
||||
return fmt.Errorf("excluded URL must not contain wildcards: %s", url)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate revocation URL if set
|
||||
if c.RevocationURL != "" && !isValidSecureURL(c.RevocationURL) {
|
||||
return fmt.Errorf("revocationURL must be a valid HTTPS URL")
|
||||
}
|
||||
|
||||
// Validate end session URL if set
|
||||
if c.OIDCEndSessionURL != "" && !isValidSecureURL(c.OIDCEndSessionURL) {
|
||||
return fmt.Errorf("oidcEndSessionURL must be a valid HTTPS URL")
|
||||
}
|
||||
|
||||
// Validate post-logout redirect URI if set
|
||||
if c.PostLogoutRedirectURI != "" && c.PostLogoutRedirectURI != "/" {
|
||||
if !isValidSecureURL(c.PostLogoutRedirectURI) && !strings.HasPrefix(c.PostLogoutRedirectURI, "/") {
|
||||
return fmt.Errorf("postLogoutRedirectURI must be either a valid HTTPS URL or start with /")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate rate limit
|
||||
if c.RateLimit < MinRateLimit {
|
||||
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
|
||||
}
|
||||
|
||||
// Validate refresh grace period
|
||||
if c.RefreshGracePeriodSeconds < 0 {
|
||||
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// SECURITY FIX: Validate headers configuration with enhanced template security
|
||||
for _, header := range c.Headers {
|
||||
if header.Name == "" {
|
||||
return fmt.Errorf("header name cannot be empty")
|
||||
}
|
||||
if header.Value == "" {
|
||||
return fmt.Errorf("header value template cannot be empty")
|
||||
}
|
||||
if !strings.Contains(header.Value, "{{") || !strings.Contains(header.Value, "}}") {
|
||||
return fmt.Errorf("header value '%s' does not appear to be a valid template (missing {{ }})", header.Value)
|
||||
}
|
||||
|
||||
// Provide more helpful guidance for common template errors BEFORE security validation
|
||||
if strings.Contains(header.Value, "{{.claims") {
|
||||
return fmt.Errorf("header template '%s' appears to use lowercase 'claims' - use '{{.Claims...' instead (case sensitive)", header.Value)
|
||||
}
|
||||
if strings.Contains(header.Value, "{{.accessToken") {
|
||||
return fmt.Errorf("header template '%s' appears to use lowercase 'accessToken' - use '{{.AccessToken...' instead (case sensitive)", header.Value)
|
||||
}
|
||||
if strings.Contains(header.Value, "{{.idToken") {
|
||||
return fmt.Errorf("header template '%s' appears to use lowercase 'idToken' - use '{{.IdToken...' instead (case sensitive)", header.Value)
|
||||
}
|
||||
if strings.Contains(header.Value, "{{.refreshToken") {
|
||||
return fmt.Errorf("header template '%s' appears to use lowercase 'refreshToken' - use '{{.RefreshToken...' instead (case sensitive)", header.Value)
|
||||
}
|
||||
|
||||
// SECURITY FIX: Implement template sandboxing and validation
|
||||
if err := validateTemplateSecure(header.Value); err != nil {
|
||||
return fmt.Errorf("header template '%s' failed security validation: %w", header.Value, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger is a simple logger with different levels
|
||||
// SECURITY FIX: validateTemplateSecure implements template sandboxing and validation
|
||||
func validateTemplateSecure(templateStr string) error {
|
||||
// SECURITY FIX: Restrict dangerous template functions and patterns
|
||||
dangerousPatterns := []string{
|
||||
"{{call", // Function calls
|
||||
"{{range", // Range over arbitrary data
|
||||
"{{with", // With statements that could access unexpected data
|
||||
"{{define", // Template definitions
|
||||
"{{template", // Template inclusions
|
||||
"{{block", // Block definitions
|
||||
"{{/*", // Comments that could hide malicious code
|
||||
"{{-", // Trim whitespace (could be used to obfuscate)
|
||||
"-}}", // Trim whitespace (could be used to obfuscate)
|
||||
"{{printf", // Printf functions
|
||||
"{{print", // Print functions
|
||||
"{{println", // Println functions
|
||||
"{{html", // HTML functions
|
||||
"{{js", // JavaScript functions
|
||||
"{{urlquery", // URL query functions
|
||||
"{{index", // Index access to arbitrary data
|
||||
"{{slice", // Slice operations
|
||||
"{{len", // Length operations on arbitrary data
|
||||
"{{eq", // Comparison operations
|
||||
"{{ne", // Comparison operations
|
||||
"{{lt", // Comparison operations
|
||||
"{{le", // Comparison operations
|
||||
"{{gt", // Comparison operations
|
||||
"{{ge", // Comparison operations
|
||||
"{{and", // Logical operations
|
||||
"{{or", // Logical operations
|
||||
"{{not", // Logical operations
|
||||
}
|
||||
|
||||
templateLower := strings.ToLower(templateStr)
|
||||
for _, pattern := range dangerousPatterns {
|
||||
if strings.Contains(templateLower, pattern) {
|
||||
return fmt.Errorf("dangerous template pattern detected: %s", pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// SECURITY FIX: Whitelist allowed template variables and functions
|
||||
allowedPatterns := []string{
|
||||
"{{.AccessToken}}",
|
||||
"{{.IdToken}}",
|
||||
"{{.RefreshToken}}",
|
||||
"{{.Claims.",
|
||||
}
|
||||
|
||||
// Check if template contains only allowed patterns
|
||||
hasAllowedPattern := false
|
||||
for _, pattern := range allowedPatterns {
|
||||
if strings.Contains(templateStr, pattern) {
|
||||
hasAllowedPattern = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAllowedPattern {
|
||||
return fmt.Errorf("template must use only allowed variables: AccessToken, IdToken, RefreshToken, or Claims.*")
|
||||
}
|
||||
|
||||
// SECURITY FIX: Validate Claims access patterns
|
||||
if strings.Contains(templateStr, "{{.Claims.") {
|
||||
// Simple validation - ensure claims access is to known safe fields
|
||||
safeClaimsFields := map[string]bool{
|
||||
"email": true,
|
||||
"name": true,
|
||||
"given_name": true,
|
||||
"family_name": true,
|
||||
"preferred_username": true,
|
||||
"sub": true,
|
||||
"iss": true,
|
||||
"aud": true,
|
||||
"exp": true,
|
||||
"iat": true,
|
||||
"groups": true,
|
||||
"roles": true,
|
||||
}
|
||||
|
||||
// Extract field names from Claims access
|
||||
start := strings.Index(templateStr, "{{.Claims.")
|
||||
for start != -1 {
|
||||
end := strings.Index(templateStr[start:], "}}")
|
||||
if end == -1 {
|
||||
return fmt.Errorf("malformed Claims template syntax")
|
||||
}
|
||||
|
||||
// Extract the content between "{{.Claims." and "}}"
|
||||
// start+10 skips "{{.Claims." and start+end is the position of "}}"
|
||||
claimsContent := templateStr[start+10 : start+end]
|
||||
|
||||
// Get the field name (first part before any dots)
|
||||
fieldName := strings.Split(claimsContent, ".")[0]
|
||||
|
||||
if !safeClaimsFields[fieldName] {
|
||||
return fmt.Errorf("access to Claims.%s is not allowed for security reasons", fieldName)
|
||||
}
|
||||
|
||||
// Fix the search for next occurrence
|
||||
nextStart := strings.Index(templateStr[start+end+2:], "{{.Claims.")
|
||||
if nextStart != -1 {
|
||||
start = start + end + 2 + nextStart
|
||||
} else {
|
||||
start = -1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SECURITY FIX: Prevent code injection through template syntax
|
||||
if strings.Contains(templateStr, "{{") && strings.Contains(templateStr, "}}") {
|
||||
// Count opening and closing braces
|
||||
openCount := strings.Count(templateStr, "{{")
|
||||
closeCount := strings.Count(templateStr, "}}")
|
||||
if openCount != closeCount {
|
||||
return fmt.Errorf("unbalanced template braces")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidSecureURL checks if a given string represents a valid, absolute HTTPS URL.
|
||||
// It uses url.Parse and checks for a nil error, an "https" scheme, and a non-empty host.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The URL string to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the string is a valid HTTPS URL, false otherwise.
|
||||
func isValidSecureURL(s string) bool {
|
||||
u, err := url.Parse(s)
|
||||
return err == nil && u.Scheme == "https" && u.Host != ""
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error").
|
||||
//
|
||||
// Parameters:
|
||||
// - level: The log level string to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the log level is valid, false otherwise.
|
||||
func isValidLogLevel(level string) bool {
|
||||
return level == "debug" || level == "info" || level == "error"
|
||||
}
|
||||
|
||||
// Logger provides structured logging capabilities with different severity levels.
|
||||
// It supports error, info, and debug levels with appropriate output streams
|
||||
// and formatting for each level.
|
||||
type Logger struct {
|
||||
// logError handles error-level messages, writing to stderr
|
||||
logError *log.Logger
|
||||
logInfo *log.Logger
|
||||
// logInfo handles informational messages, writing to stdout
|
||||
logInfo *log.Logger
|
||||
// logDebug handles debug-level messages, writing to stdout when debug is enabled
|
||||
logDebug *log.Logger
|
||||
}
|
||||
|
||||
// NewLogger creates a new Logger
|
||||
// NewLogger creates and configures a new Logger instance based on the provided log level.
|
||||
// It initializes loggers for ERROR (stderr), INFO (stdout), and DEBUG (stdout) levels,
|
||||
// enabling output based on the specified level:
|
||||
// - "error": Only ERROR messages are output.
|
||||
// - "info": INFO and ERROR messages are output.
|
||||
// - "debug": DEBUG, INFO, and ERROR messages are output.
|
||||
//
|
||||
// If an invalid level is provided, it defaults to behavior similar to "error".
|
||||
//
|
||||
// Parameters:
|
||||
// - logLevel: The desired logging level ("debug", "info", or "error").
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the configured Logger instance.
|
||||
func NewLogger(logLevel string) *Logger {
|
||||
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
|
||||
logError.SetOutput(os.Stderr)
|
||||
logInfo.SetOutput(os.Stdout)
|
||||
|
||||
if logLevel == "debug" || logLevel == "info" {
|
||||
logInfo.SetOutput(os.Stdout)
|
||||
}
|
||||
if logLevel == "debug" {
|
||||
logDebug.SetOutput(os.Stdout)
|
||||
}
|
||||
@@ -112,37 +408,88 @@ func NewLogger(logLevel string) *Logger {
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
// Info logs a message at the INFO level using Printf style formatting.
|
||||
// Output is directed to stdout if the configured log level is "info" or "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Info(format string, args ...any) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
// Debug logs a message at the DEBUG level using Printf style formatting.
|
||||
// Output is directed to stdout only if the configured log level is "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debug(format string, args ...any) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
// Error logs a message at the ERROR level using Printf style formatting.
|
||||
// Output is always directed to stderr, regardless of the configured log level.
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Error(format string, args ...any) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs an info message
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
// Infof logs a message at the INFO level using Printf style formatting.
|
||||
// Equivalent to calling l.Info(format, args...).
|
||||
// Output is directed to stdout if the configured log level is "info" or "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Infof(format string, args ...any) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debugf logs a debug message
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
// Debugf logs a message at the DEBUG level using Printf style formatting.
|
||||
// Equivalent to calling l.Debug(format, args...).
|
||||
// Output is directed to stdout only if the configured log level is "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debugf(format string, args ...any) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs an error message
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
// Errorf logs a message at the ERROR level using Printf style formatting.
|
||||
// Equivalent to calling l.Error(format, args...).
|
||||
// Output is always directed to stderr, regardless of the configured log level.
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Errorf(format string, args ...any) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// handleError writes an error message to the response and logs it
|
||||
// newNoOpLogger creates a silent logger that doesn't output anything.
|
||||
// This is useful for internal components that need a logger instance
|
||||
// but should not produce any output by default.
|
||||
func newNoOpLogger() *Logger {
|
||||
return &Logger{
|
||||
logError: log.New(io.Discard, "", 0),
|
||||
logInfo: log.New(io.Discard, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
}
|
||||
|
||||
// handleError logs an error message using the provided logger and sends an HTTP error
|
||||
// response to the client with the specified message and status code.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The http.ResponseWriter to send the error response to.
|
||||
// - message: The error message string.
|
||||
// - code: The HTTP status code for the response.
|
||||
// - logger: The Logger instance to use for logging the error.
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Error(message)
|
||||
http.Error(w, message, code)
|
||||
|
||||
@@ -0,0 +1,433 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Helper function to compare string slices
|
||||
func equalSlices(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestCreateConfig(t *testing.T) {
|
||||
t.Run("Default Values", func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
|
||||
// Check default scopes
|
||||
expectedScopes := []string{"openid", "profile", "email"}
|
||||
if len(config.Scopes) != len(expectedScopes) {
|
||||
t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes))
|
||||
}
|
||||
for i, scope := range expectedScopes {
|
||||
if config.Scopes[i] != scope {
|
||||
t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Check default log level
|
||||
if config.LogLevel != DefaultLogLevel {
|
||||
t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel)
|
||||
}
|
||||
|
||||
// Check default rate limit
|
||||
if config.RateLimit != DefaultRateLimit {
|
||||
t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit)
|
||||
}
|
||||
|
||||
// Check ForceHTTPS default
|
||||
if !config.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to be true by default")
|
||||
}
|
||||
|
||||
// Check OverrideScopes default
|
||||
if config.OverrideScopes {
|
||||
t.Error("Expected OverrideScopes to be false by default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Config Can Hold Custom Values", func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.Scopes = []string{"custom_scope"}
|
||||
config.LogLevel = "debug"
|
||||
config.RateLimit = 50
|
||||
config.ForceHTTPS = false
|
||||
config.OverrideScopes = true
|
||||
|
||||
// Verify config struct can hold custom values
|
||||
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
|
||||
t.Error("Config struct cannot hold custom scopes")
|
||||
}
|
||||
if config.LogLevel != "debug" {
|
||||
t.Error("Config struct cannot hold custom log level")
|
||||
}
|
||||
if config.RateLimit != 50 {
|
||||
t.Error("Config struct cannot hold custom rate limit")
|
||||
}
|
||||
if config.ForceHTTPS {
|
||||
t.Error("Config struct cannot hold custom ForceHTTPS value")
|
||||
}
|
||||
if !config.OverrideScopes {
|
||||
t.Error("Config struct cannot hold custom OverrideScopes value")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Empty Config",
|
||||
config: &Config{},
|
||||
expectedError: "providerURL is required",
|
||||
},
|
||||
{
|
||||
name: "Missing CallbackURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
},
|
||||
expectedError: "callbackURL is required",
|
||||
},
|
||||
{
|
||||
name: "Missing ClientID",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
},
|
||||
expectedError: "clientID is required",
|
||||
},
|
||||
{
|
||||
name: "Missing ClientSecret",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedError: "clientSecret is required",
|
||||
},
|
||||
{
|
||||
name: "Missing SessionEncryptionKey",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
},
|
||||
expectedError: "sessionEncryptionKey is required",
|
||||
},
|
||||
{
|
||||
name: "Non-HTTPS ProviderURL",
|
||||
config: &Config{
|
||||
ProviderURL: "http://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "encryption-key",
|
||||
},
|
||||
expectedError: "providerURL must be a valid HTTPS URL",
|
||||
},
|
||||
{
|
||||
name: "Invalid CallbackURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "callback", // Missing leading slash
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "encryption-key",
|
||||
},
|
||||
expectedError: "callbackURL must start with /",
|
||||
},
|
||||
{
|
||||
name: "Short SessionEncryptionKey",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "short",
|
||||
},
|
||||
expectedError: "sessionEncryptionKey must be at least 32 characters long",
|
||||
},
|
||||
{
|
||||
name: "Low RateLimit",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
RateLimit: 5,
|
||||
},
|
||||
expectedError: "rateLimit must be at least 10",
|
||||
},
|
||||
{
|
||||
name: "Invalid LogLevel",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
LogLevel: "invalid",
|
||||
},
|
||||
expectedError: "logLevel must be one of: debug, info, error",
|
||||
},
|
||||
{
|
||||
name: "Non-HTTPS RevocationURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
RevocationURL: "http://revoke.com",
|
||||
},
|
||||
expectedError: "revocationURL must be a valid HTTPS URL",
|
||||
},
|
||||
{
|
||||
name: "Non-HTTPS OIDCEndSessionURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
OIDCEndSessionURL: "http://endsession.com",
|
||||
},
|
||||
expectedError: "oidcEndSessionURL must be a valid HTTPS URL",
|
||||
},
|
||||
{
|
||||
name: "Valid Config",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
LogLevel: "debug",
|
||||
RateLimit: 100,
|
||||
RevocationURL: "https://revoke.com",
|
||||
OIDCEndSessionURL: "https://endsession.com",
|
||||
},
|
||||
expectedError: "",
|
||||
},
|
||||
{
|
||||
name: "Valid Config With AllowedUsers",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
LogLevel: "debug",
|
||||
RateLimit: 100,
|
||||
AllowedUsers: []string{"user1@example.com", "user2@example.com"},
|
||||
},
|
||||
expectedError: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.config.Validate()
|
||||
if tc.expectedError == "" {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
|
||||
} else if err.Error() != tc.expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
// Capture log output
|
||||
var debugBuf, infoBuf, errorBuf bytes.Buffer
|
||||
|
||||
tests := []struct {
|
||||
testFunc func(*Logger)
|
||||
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
|
||||
name string
|
||||
logLevel string
|
||||
}{
|
||||
{
|
||||
name: "Debug Level",
|
||||
logLevel: "debug",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debug("debug message")
|
||||
l.Info("info message")
|
||||
l.Error("error message")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if debugOut == "" {
|
||||
t.Error("Expected debug message in output")
|
||||
}
|
||||
if infoOut == "" {
|
||||
t.Error("Expected info message in output")
|
||||
}
|
||||
if errorOut == "" {
|
||||
t.Error("Expected error message in output")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Info Level",
|
||||
logLevel: "info",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debug("debug message")
|
||||
l.Info("info message")
|
||||
l.Error("error message")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if debugOut != "" {
|
||||
t.Error("Did not expect debug message in output")
|
||||
}
|
||||
if infoOut == "" {
|
||||
t.Error("Expected info message in output")
|
||||
}
|
||||
if errorOut == "" {
|
||||
t.Error("Expected error message in output")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error Level",
|
||||
logLevel: "error",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debug("debug message")
|
||||
l.Info("info message")
|
||||
l.Error("error message")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if debugOut != "" {
|
||||
t.Error("Did not expect debug message in output")
|
||||
}
|
||||
if infoOut != "" {
|
||||
t.Error("Did not expect info message in output")
|
||||
}
|
||||
if errorOut == "" {
|
||||
t.Error("Expected error message in output")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Printf Methods",
|
||||
logLevel: "debug",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debugf("debug %s", "formatted")
|
||||
l.Infof("info %s", "formatted")
|
||||
l.Errorf("error %s", "formatted")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) {
|
||||
t.Error("Expected formatted debug message")
|
||||
}
|
||||
if !bytes.Contains([]byte(infoOut), []byte("info formatted")) {
|
||||
t.Error("Expected formatted info message")
|
||||
}
|
||||
if !bytes.Contains([]byte(errorOut), []byte("error formatted")) {
|
||||
t.Error("Expected formatted error message")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Reset buffers
|
||||
debugBuf.Reset()
|
||||
infoBuf.Reset()
|
||||
errorBuf.Reset()
|
||||
|
||||
// Create logger with test buffers
|
||||
logger := NewLogger(tc.logLevel)
|
||||
logger.logError.SetOutput(&errorBuf)
|
||||
|
||||
if tc.logLevel == "debug" || tc.logLevel == "info" {
|
||||
logger.logInfo.SetOutput(&infoBuf)
|
||||
}
|
||||
if tc.logLevel == "debug" {
|
||||
logger.logDebug.SetOutput(&debugBuf)
|
||||
}
|
||||
|
||||
// Run test
|
||||
tc.testFunc(logger)
|
||||
|
||||
// Check results
|
||||
tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleError(t *testing.T) {
|
||||
// Create a test logger with captured output
|
||||
var errorBuf bytes.Buffer
|
||||
logger := &Logger{
|
||||
logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime),
|
||||
}
|
||||
logger.logError.SetOutput(&errorBuf)
|
||||
|
||||
// Create a test response recorder
|
||||
rr := &testResponseRecorder{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
|
||||
// Test error handling
|
||||
message := "test error message"
|
||||
code := 400
|
||||
handleError(rr, message, code, logger)
|
||||
|
||||
// Check response code
|
||||
if rr.statusCode != code {
|
||||
t.Errorf("Expected status code %d, got %d", code, rr.statusCode)
|
||||
}
|
||||
|
||||
// Check response body
|
||||
expectedBody := message + "\n"
|
||||
if rr.body != expectedBody {
|
||||
t.Errorf("Expected body %q, got %q", expectedBody, rr.body)
|
||||
}
|
||||
|
||||
// Check error was logged
|
||||
if !bytes.Contains(errorBuf.Bytes(), []byte(message)) {
|
||||
t.Error("Error message was not logged")
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper types
|
||||
type testResponseRecorder struct {
|
||||
headers map[string][]string
|
||||
body string
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) Header() http.Header {
|
||||
return r.headers
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) Write(b []byte) (int, error) {
|
||||
r.body = string(b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) WriteHeader(code int) {
|
||||
r.statusCode = code
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
func TestTemplatedHeaderValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header TemplatedHeader
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Empty Name",
|
||||
header: TemplatedHeader{Name: "", Value: "{{.Claims.email}}"},
|
||||
expectedError: "header name cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "Empty Value",
|
||||
header: TemplatedHeader{Name: "X-Email", Value: ""},
|
||||
expectedError: "header value template cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "Not a Template",
|
||||
header: TemplatedHeader{Name: "X-Email", Value: "static-value"},
|
||||
expectedError: "header value 'static-value' does not appear to be a valid template (missing {{ }})",
|
||||
},
|
||||
{
|
||||
name: "Lowercase claims",
|
||||
header: TemplatedHeader{Name: "X-Email", Value: "{{.claims.email}}"},
|
||||
expectedError: "header template '{{.claims.email}}' appears to use lowercase 'claims' - use '{{.Claims...' instead (case sensitive)",
|
||||
},
|
||||
{
|
||||
name: "Lowercase accessToken",
|
||||
header: TemplatedHeader{Name: "X-Token", Value: "Bearer {{.accessToken}}"},
|
||||
expectedError: "header template 'Bearer {{.accessToken}}' appears to use lowercase 'accessToken' - use '{{.AccessToken...' instead (case sensitive)",
|
||||
},
|
||||
{
|
||||
name: "Lowercase idToken",
|
||||
header: TemplatedHeader{Name: "X-Token", Value: "Bearer {{.idToken}}"},
|
||||
expectedError: "header template 'Bearer {{.idToken}}' appears to use lowercase 'idToken' - use '{{.IdToken...' instead (case sensitive)",
|
||||
},
|
||||
{
|
||||
name: "Lowercase refreshToken",
|
||||
header: TemplatedHeader{Name: "X-Refresh", Value: "Bearer {{.refreshToken}}"},
|
||||
expectedError: "header template 'Bearer {{.refreshToken}}' appears to use lowercase 'refreshToken' - use '{{.RefreshToken...' instead (case sensitive)",
|
||||
},
|
||||
{
|
||||
name: "Valid Template",
|
||||
header: TemplatedHeader{Name: "X-Email", Value: "{{.Claims.email}}"},
|
||||
expectedError: "",
|
||||
},
|
||||
{
|
||||
name: "Valid Bearer Token Template",
|
||||
header: TemplatedHeader{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
expectedError: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
RateLimit: 10, // Adding minimum required rate limit
|
||||
Headers: []TemplatedHeader{tc.header},
|
||||
}
|
||||
|
||||
err := config.Validate()
|
||||
if tc.expectedError == "" {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error: %s, got nil", tc.expectedError)
|
||||
} else if err.Error() != tc.expectedError {
|
||||
t.Errorf("Expected error: %s, got: %s", tc.expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateParsingInNew(t *testing.T) {
|
||||
// Test successful parsing of templates during middleware creation
|
||||
tests := []struct {
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
expectedTemplates int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Single Valid Template",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Email", Value: "{{.Claims.email}}"},
|
||||
},
|
||||
expectedTemplates: 1,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple Valid Templates",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
expectedTemplates: 3,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid Template",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Email", Value: "{{.Claims.email"}, // Missing closing braces
|
||||
},
|
||||
expectedTemplates: 0,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Mix of Valid and Invalid Templates",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-Invalid", Value: "{{if .Claims.admin}}Admin{{end"}, // Invalid template
|
||||
},
|
||||
expectedTemplates: 1, // Only the valid template should be parsed
|
||||
expectError: true, // We expect an error for the invalid template, but we'll handle it
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// For testing template parsing, we'll directly try to parse the templates instead of using New()
|
||||
// This avoids the provider discovery that would fail in tests
|
||||
headerTemplates := make(map[string]*template.Template)
|
||||
|
||||
// Special handling for the mixed valid/invalid templates case
|
||||
if tc.name == "Mix of Valid and Invalid Templates" {
|
||||
// Process templates one at a time so we can still have valid templates
|
||||
for _, header := range tc.headers {
|
||||
tmpl, err := template.New(header.Name).Parse(header.Value)
|
||||
if err != nil {
|
||||
// We expect an error for the invalid template
|
||||
if !tc.expectError {
|
||||
t.Errorf("Unexpected error parsing template %s: %v", header.Name, err)
|
||||
}
|
||||
// Skip this template but continue processing others
|
||||
continue
|
||||
}
|
||||
headerTemplates[header.Name] = tmpl
|
||||
}
|
||||
} else {
|
||||
// Normal handling for other test cases
|
||||
var parseErr error
|
||||
for _, header := range tc.headers {
|
||||
tmpl, err := template.New(header.Name).Parse(header.Value)
|
||||
if err != nil {
|
||||
parseErr = err
|
||||
break
|
||||
}
|
||||
headerTemplates[header.Name] = tmpl
|
||||
}
|
||||
|
||||
if tc.expectError {
|
||||
if parseErr == nil {
|
||||
t.Error("Expected error parsing templates but got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if parseErr != nil {
|
||||
t.Fatalf("Unexpected error: %v", parseErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Check the number of parsed templates
|
||||
if len(headerTemplates) != tc.expectedTemplates {
|
||||
t.Errorf("Expected %d parsed templates, got %d", tc.expectedTemplates, len(headerTemplates))
|
||||
}
|
||||
|
||||
// Check each template was parsed
|
||||
for _, header := range tc.headers {
|
||||
// Skip the known invalid templates
|
||||
if header.Value == "{{.Claims.email" || header.Value == "{{if .Claims.admin}}Admin{{end" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := headerTemplates[header.Name]; !ok {
|
||||
t.Errorf("Template for header %s was not parsed", header.Name)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,606 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// TestTemplateExecution tests that templates are executed correctly with different types of claims
|
||||
func TestTemplateExecution(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]any
|
||||
expectedValue string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "String Claim",
|
||||
templateText: "{{.Claims.email}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
expectedValue: "user@example.com",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Number Claim",
|
||||
templateText: "{{.Claims.age}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"age": 30,
|
||||
},
|
||||
},
|
||||
expectedValue: "30",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Boolean Claim",
|
||||
templateText: "{{.Claims.admin}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"admin": true,
|
||||
},
|
||||
},
|
||||
expectedValue: "true",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Array Claim",
|
||||
templateText: "{{index .Claims.roles 0}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"roles": []string{"admin", "user"},
|
||||
},
|
||||
},
|
||||
expectedValue: "admin",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nested Object Claim",
|
||||
templateText: "{{.Claims.user.name}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"user": map[string]any{
|
||||
"name": "John Doe",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "John Doe",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Access Token",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
},
|
||||
expectedValue: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ID Token",
|
||||
templateText: "{{.IdToken}}",
|
||||
data: map[string]any{
|
||||
"IdToken": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
},
|
||||
expectedValue: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Refresh Token",
|
||||
templateText: "{{.RefreshToken}}",
|
||||
data: map[string]any{
|
||||
"RefreshToken": "refresh-token-value",
|
||||
},
|
||||
expectedValue: "refresh-token-value",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Conditional Template",
|
||||
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"admin": true,
|
||||
},
|
||||
},
|
||||
expectedValue: "Admin User",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple Claims",
|
||||
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
"email": "john.doe@example.com",
|
||||
},
|
||||
},
|
||||
expectedValue: "John Doe <john.doe@example.com>",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Missing Claim",
|
||||
templateText: "{{.Claims.missing}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{},
|
||||
},
|
||||
expectedValue: "<no value>",
|
||||
expectError: false, // Go templates don't error on missing values
|
||||
},
|
||||
{
|
||||
name: "Invalid Template Syntax",
|
||||
templateText: "{{.Claims.email",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
expectedValue: "",
|
||||
expectError: true, // Parsing should fail
|
||||
},
|
||||
{
|
||||
name: "Custom Claims",
|
||||
templateText: "Role: {{.Claims.role}}, Department: {{.Claims.department}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"role": "admin",
|
||||
"department": "engineering",
|
||||
},
|
||||
},
|
||||
expectedValue: "Role: admin, Department: engineering",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nested Custom Claims",
|
||||
templateText: "Org: {{.Claims.metadata.organization}}, Team: {{.Claims.metadata.team}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"metadata": map[string]any{
|
||||
"organization": "company-name",
|
||||
"team": "platform",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "Org: company-name, Team: platform",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Email Claims",
|
||||
templateText: "Email: {{.Claims.email}}, Verified: {{.Claims.email_verified}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
},
|
||||
},
|
||||
expectedValue: "Email: user@example.com, Verified: true",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "User Identity Claims",
|
||||
templateText: "Name: {{.Claims.name}}, Subject: {{.Claims.sub}}, Username: {{.Claims.preferred_username}}",
|
||||
data: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"name": "John Doe",
|
||||
"sub": "user123",
|
||||
"preferred_username": "johndoe",
|
||||
},
|
||||
},
|
||||
expectedValue: "Name: John Doe, Subject: user123, Username: johndoe",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Fatal("Expected template parsing error, but got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
|
||||
result := buf.String()
|
||||
if result != tc.expectedValue {
|
||||
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTemplateExecutionContext tests the specific template data context used in processAuthorizedRequest
|
||||
func TestTemplateExecutionContext(t *testing.T) {
|
||||
// Test cases for map-based template data, matching the new implementation
|
||||
mapTests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]any
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token distinction with map",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{},
|
||||
"RefreshToken": "refresh-token-value",
|
||||
},
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
},
|
||||
{
|
||||
name: "Combining tokens and claims with map",
|
||||
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token",
|
||||
"IdToken": "id-token",
|
||||
"Claims": map[string]any{
|
||||
"sub": "user123",
|
||||
},
|
||||
"RefreshToken": "refresh-token",
|
||||
},
|
||||
expectedValue: "User: user123 Token: access-token",
|
||||
},
|
||||
{
|
||||
name: "Authorization header with Bearer token",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "jwt-access-token",
|
||||
"IdToken": "id-token",
|
||||
"Claims": map[string]any{},
|
||||
},
|
||||
expectedValue: "Bearer jwt-access-token",
|
||||
},
|
||||
{
|
||||
name: "Boolean template data with AccessToken",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": true, // Test boolean values to ensure they render correctly
|
||||
},
|
||||
expectedValue: "Bearer true",
|
||||
},
|
||||
{
|
||||
name: "Custom non-standard claims in ID token",
|
||||
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"role": "admin",
|
||||
"permissions": "read:all,write:own",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
|
||||
},
|
||||
{
|
||||
name: "Deeply nested custom claims",
|
||||
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"Claims": map[string]any{
|
||||
"app_metadata": map[string]any{
|
||||
"organization": map[string]any{
|
||||
"name": "acme-corp",
|
||||
"id": "org-123",
|
||||
},
|
||||
"team": "platform",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Organization: acme-corp, X-Team: platform",
|
||||
},
|
||||
{
|
||||
name: "Email in claims",
|
||||
templateText: "X-User-Email: {{.Claims.email}}, X-Email-Verified: {{.Claims.email_verified}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-Email: user@example.com, X-Email-Verified: true",
|
||||
},
|
||||
{
|
||||
name: "User info from claims",
|
||||
templateText: "X-User-ID: {{.Claims.sub}}, X-User-Name: {{.Claims.name}}, X-Username: {{.Claims.preferred_username}}",
|
||||
data: map[string]any{
|
||||
"AccessToken": "access-token-value",
|
||||
"IdToken": "id-token-value",
|
||||
"Claims": map[string]any{
|
||||
"sub": "user123456",
|
||||
"name": "Jane Doe",
|
||||
"preferred_username": "jane.doe",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-ID: user123456, X-User-Name: Jane Doe, X-Username: jane.doe",
|
||||
},
|
||||
}
|
||||
|
||||
// Run map-based tests (matching the new implementation)
|
||||
for _, tc := range mapTests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
|
||||
result := buf.String()
|
||||
if result != tc.expectedValue {
|
||||
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// For backward compatibility, also test the original struct-based implementation
|
||||
type templateData struct {
|
||||
Claims map[string]any
|
||||
AccessToken string
|
||||
IdToken string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// Test cases for struct-based template data (original implementation)
|
||||
structTests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data templateData
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token distinction with struct",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token-value",
|
||||
IdToken: "id-token-value", // Now these should be distinct values
|
||||
Claims: map[string]any{},
|
||||
},
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
},
|
||||
{
|
||||
name: "Combining tokens and claims with struct",
|
||||
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "access-token",
|
||||
Claims: map[string]any{
|
||||
"sub": "user123",
|
||||
},
|
||||
},
|
||||
expectedValue: "User: user123 Token: access-token",
|
||||
},
|
||||
{
|
||||
name: "Custom claims with struct",
|
||||
templateText: "X-Custom: {{.Claims.custom_field}}, X-Group: {{.Claims.group}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "id-token",
|
||||
Claims: map[string]any{
|
||||
"sub": "user123",
|
||||
"custom_field": "custom-value",
|
||||
"group": "admins",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Custom: custom-value, X-Group: admins",
|
||||
},
|
||||
{
|
||||
name: "Email claim in struct context",
|
||||
templateText: "X-Email: {{.Claims.email}}, X-Name: {{.Claims.name}}",
|
||||
data: templateData{
|
||||
AccessToken: "access-token",
|
||||
IdToken: "id-token",
|
||||
Claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "John Smith",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Email: user@example.com, X-Name: John Smith",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range structTests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
|
||||
result := buf.String()
|
||||
if result != tc.expectedValue {
|
||||
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegressionBooleanAccessToken specifically tests the regression case where
|
||||
// a boolean value was causing "can't evaluate field AccessToken in type bool" error
|
||||
func TestRegressionBooleanAccessToken(t *testing.T) {
|
||||
// Test the specific case where we execute a template referencing AccessToken
|
||||
// using a boolean context value
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
dataContext any
|
||||
expectedValue string
|
||||
expectError bool // Added to skip the test that demonstrates the error
|
||||
}{
|
||||
{
|
||||
name: "Map with boolean as root",
|
||||
templateText: "{{.AccessToken}}",
|
||||
dataContext: map[string]any{"AccessToken": "token-value"},
|
||||
expectedValue: "token-value",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Boolean as root context",
|
||||
templateText: "{{.AccessToken}}",
|
||||
dataContext: true,
|
||||
expectedValue: "<no value>",
|
||||
expectError: true, // Skip this test as it demonstrates the error we're fixing
|
||||
},
|
||||
{
|
||||
name: "Bearer with map context",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
dataContext: map[string]any{"AccessToken": "token-value"},
|
||||
expectedValue: "Bearer token-value",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Complex nesting with authorization",
|
||||
templateText: "Authorization: Bearer {{.AccessToken}}",
|
||||
dataContext: map[string]any{
|
||||
"AccessToken": "jwt-token-123",
|
||||
"something": true,
|
||||
"anotherField": map[string]any{
|
||||
"nested": "value",
|
||||
},
|
||||
},
|
||||
expectedValue: "Authorization: Bearer jwt-token-123",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Custom claims access",
|
||||
templateText: "X-User-Role: {{.Claims.role}}, X-User-Groups: {{.Claims.groups}}",
|
||||
dataContext: map[string]any{
|
||||
"AccessToken": "jwt-token-xyz",
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"role": "admin",
|
||||
"groups": "group1,group2,group3",
|
||||
"custom_data": map[string]any{
|
||||
"organization": "company-name",
|
||||
"department": "engineering",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-Role: admin, X-User-Groups: group1,group2,group3",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nested custom claims access",
|
||||
templateText: "X-Organization: {{.Claims.custom_data.organization}}, X-Department: {{.Claims.custom_data.department}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"custom_data": map[string]any{
|
||||
"organization": "company-name",
|
||||
"department": "engineering",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Organization: company-name, X-Department: engineering",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Azure AD specific claims",
|
||||
templateText: "X-TenantID: {{.Claims.tid}}, X-Roles: {{.Claims.roles}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"tid": "tenant-id-12345",
|
||||
"roles": "User,Admin,Developer",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-TenantID: tenant-id-12345, X-Roles: User,Admin,Developer",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Auth0 specific claims",
|
||||
templateText: "X-Permissions: {{.Claims.permissions}}, X-AppMetadata: {{.Claims.app_metadata.plan}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"permissions": "read:products,write:orders",
|
||||
"app_metadata": map[string]any{
|
||||
"plan": "premium",
|
||||
"status": "active",
|
||||
"trial_ended": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Permissions: read:products,write:orders, X-AppMetadata: premium",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Standard claims with email",
|
||||
templateText: "X-Email: {{.Claims.email}}, X-Name: {{.Claims.name}}, X-Subject: {{.Claims.sub}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "John Doe",
|
||||
"sub": "auth0|12345",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Email: user@example.com, X-Name: John Doe, X-Subject: auth0|12345",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Verified email claim",
|
||||
templateText: "X-Email: {{.Claims.email}}, X-Email-Verified: {{.Claims.email_verified}}",
|
||||
dataContext: map[string]any{
|
||||
"Claims": map[string]any{
|
||||
"email": "user@example.com",
|
||||
"email_verified": true,
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Email: user@example.com, X-Email-Verified: true",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
// Skip tests that demonstrate the error
|
||||
if tc.expectError {
|
||||
t.Skip("Skipping test that demonstrates the error we're fixing")
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.dataContext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
|
||||
result := buf.String()
|
||||
if result != tc.expectedValue {
|
||||
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,593 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestTemplatedHeadersIntegration tests that templated headers are correctly added to requests
|
||||
// in the actual middleware flow
|
||||
func TestTemplatedHeadersIntegration(t *testing.T) {
|
||||
// Create a TestSuite to use its helper methods and fields
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
sessionSetup func(*SessionData)
|
||||
claims map[string]any
|
||||
expectedHeaders map[string]string
|
||||
interceptedHeaders map[string]string
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
}{
|
||||
{
|
||||
name: "Basic Email Header",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Email": "user@example.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple Headers",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
{Name: "X-User-Name", Value: "{{.Claims.given_name}} {{.Claims.family_name}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"sub": "user123",
|
||||
"given_name": "John",
|
||||
"family_name": "Doe",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Email": "user@example.com",
|
||||
"X-User-ID": "user123",
|
||||
"X-User-Name": "John Doe",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Authorization Header with Bearer Token",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
// We'll update this dynamically after generating the token
|
||||
"Authorization": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ID Token Header",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
// We'll update this dynamically after generating the token
|
||||
"X-ID-Token": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Both Token Types",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Access-Token", Value: "{{.AccessToken}}"},
|
||||
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
// We'll update these dynamically after generating the tokens
|
||||
"X-Access-Token": "",
|
||||
"X-ID-Token": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Missing Claim",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Role", Value: "{{.Claims.role}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
// role claim is missing
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Role": "<no value>", // Go templates provide <no value> for missing fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Conditional Header",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Admin", Value: "{{if .Claims.is_admin}}true{{else}}false{{end}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
"email": "admin@example.com",
|
||||
"is_admin": true,
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Admin": "true",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Combined Token and Claim",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Auth-Info", Value: "User={{.Claims.email}}, Token={{.AccessToken}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
// We'll update this dynamically after generating the token
|
||||
"X-Auth-Info": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Opaque Access Token with AccessTokenField",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-AccessToken", Value: "{{.AccessToken}}"},
|
||||
},
|
||||
claims: map[string]any{ // For ID Token
|
||||
"email": "opaque_user@example.com",
|
||||
"sub": "opaque_sub_for_id_token",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-AccessToken": "this_is_an_opaque_access_token",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create token with the test claims
|
||||
token := ts.token
|
||||
if len(tc.claims) > 0 {
|
||||
var err error
|
||||
baseClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000), // Far future timestamp
|
||||
"iat": float64(1000000000),
|
||||
"nbf": float64(1000000000),
|
||||
"sub": "test-subject",
|
||||
"nonce": "test-nonce",
|
||||
"jti": generateRandomString(16),
|
||||
}
|
||||
|
||||
// Add the test-specific claims
|
||||
maps.Copy(baseClaims, tc.claims)
|
||||
|
||||
token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", baseClaims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update expectedHeaders for the token-based tests after token generation
|
||||
if tc.name == "Authorization Header with Bearer Token" {
|
||||
tc.expectedHeaders["Authorization"] = "Bearer " + token
|
||||
}
|
||||
|
||||
if tc.name == "Combined Token and Claim" {
|
||||
// If this test case uses specific ID/Access tokens, 'token' here might be just the ID token.
|
||||
// This part might need adjustment if AccessToken is different and opaque.
|
||||
// For now, assuming 'token' is the one to be used if not overridden later.
|
||||
// The specific test "Opaque Access Token with AccessTokenField" will handle its AccessToken.
|
||||
// This generic 'token' is used as a fallback if specific logic isn't hit.
|
||||
// Let's ensure this test case uses the JWT access token if not otherwise specified.
|
||||
accessTokenForHeader := token // Default to the generated JWT 'token'
|
||||
if sessionVal, ok := tc.claims["_accessToken"]; ok { // Check if a specific access token is provided for this test
|
||||
accessTokenForHeader = sessionVal.(string)
|
||||
}
|
||||
tc.expectedHeaders["X-Auth-Info"] = "User=" + tc.claims["email"].(string) + ", Token=" + accessTokenForHeader
|
||||
}
|
||||
|
||||
// Store intercepted headers for verification
|
||||
interceptedHeaders := make(map[string]string)
|
||||
|
||||
// Create a test next handler that captures the headers
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture headers for verification
|
||||
for name := range tc.expectedHeaders {
|
||||
if value := r.Header.Get(name); value != "" {
|
||||
interceptedHeaders[name] = value
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
tokenBlacklist: NewCache(),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: NewLogger("debug"),
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}, "opaque_user@example.com": {}}, // Ensure domain for opaque test is allowed
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: ts.sessionManager,
|
||||
extractClaimsFunc: extractClaims,
|
||||
headerTemplates: make(map[string]*template.Template),
|
||||
// Default to true, which means PopulateSessionWithIdTokenClaims is true
|
||||
// UseIdTokenForSession: true, // Explicitly can be set if needed
|
||||
}
|
||||
tOidc.tokenVerifier = tOidc
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenExchanger = tOidc
|
||||
|
||||
// Initialize and parse header templates
|
||||
for _, header := range tc.headers {
|
||||
tmpl, err := template.New(header.Name).Parse(header.Value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse header template for %s: %v", header.Name, err)
|
||||
}
|
||||
tOidc.headerTemplates[header.Name] = tmpl
|
||||
}
|
||||
|
||||
close(tOidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "example.com")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
session, err := tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
// Set a default email; specific tests might override or rely on ID token population
|
||||
defaultEmail := "user@example.com"
|
||||
if emailClaim, ok := tc.claims["email"].(string); ok {
|
||||
defaultEmail = emailClaim // Use email from claims if available for initial setup
|
||||
}
|
||||
session.SetEmail(defaultEmail)
|
||||
|
||||
// Default token setup (can be overridden by specific test cases below)
|
||||
session.SetIDToken(token)
|
||||
session.SetAccessToken(token)
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
|
||||
if tc.name == "ID Token Header" || tc.name == "Both Token Types" {
|
||||
idTokenClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": float64(3000000000),
|
||||
"iat": float64(1000000000), "nbf": float64(1000000000), "sub": "test-subject",
|
||||
"nonce": "test-nonce", "jti": generateRandomString(16), "type": "id_token",
|
||||
"email": tc.claims["email"], // Ensure email from test case claims is in ID token
|
||||
}
|
||||
// Add other claims from tc.claims to idTokenClaims
|
||||
for k, v := range tc.claims {
|
||||
if _, exists := idTokenClaims[k]; !exists {
|
||||
idTokenClaims[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
idTokenForSession, idErr := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", idTokenClaims)
|
||||
if idErr != nil {
|
||||
t.Fatalf("Failed to create test ID JWT: %v", idErr)
|
||||
}
|
||||
|
||||
accessTokenClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": float64(3000000000),
|
||||
"iat": float64(1000000000), "nbf": float64(1000000000), "sub": "test-subject",
|
||||
"jti": generateRandomString(16), "type": "access_token", "scope": "openid email profile",
|
||||
"email": tc.claims["email"], // Include email in access token too for these tests
|
||||
}
|
||||
accessTokenForSession, accessErr := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", accessTokenClaims)
|
||||
if accessErr != nil {
|
||||
t.Fatalf("Failed to create test access JWT: %v", accessErr)
|
||||
}
|
||||
|
||||
session.SetIDToken(idTokenForSession)
|
||||
session.SetAccessToken(accessTokenForSession)
|
||||
|
||||
tOidc.tokenExchanger = &MockTokenExchanger{
|
||||
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: idTokenForSession, AccessToken: accessTokenForSession,
|
||||
RefreshToken: refreshToken, ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
tOidc.tokenVerifier = &MockTokenVerifier{VerifyFunc: func(token string) error { return nil }}
|
||||
|
||||
if tc.name == "ID Token Header" {
|
||||
tc.expectedHeaders["X-ID-Token"] = idTokenForSession
|
||||
} else if tc.name == "Both Token Types" {
|
||||
tc.expectedHeaders["X-ID-Token"] = idTokenForSession
|
||||
tc.expectedHeaders["X-Access-Token"] = accessTokenForSession
|
||||
}
|
||||
} else if tc.name == "Opaque Access Token with AccessTokenField" {
|
||||
idTokenClaims := map[string]any{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": float64(3000000000),
|
||||
"iat": float64(1000000000), "nbf": float64(1000000000), "sub": "test-subject", // Default sub
|
||||
"nonce": "test-nonce", "jti": generateRandomString(16), "type": "id_token",
|
||||
}
|
||||
// Populate ID token claims from tc.claims
|
||||
maps.Copy(idTokenClaims, tc.claims)
|
||||
// Ensure email from tc.claims is used for the ID token
|
||||
session.SetEmail(tc.claims["email"].(string)) // Also set it directly for initial session state
|
||||
|
||||
idTokenForSession, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", idTokenClaims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test ID JWT for opaque test: %v", err)
|
||||
}
|
||||
|
||||
opaqueAccessToken := "this_is_an_opaque_access_token"
|
||||
|
||||
session.SetIDToken(idTokenForSession)
|
||||
session.SetAccessToken(opaqueAccessToken)
|
||||
|
||||
tOidc.tokenExchanger = &MockTokenExchanger{
|
||||
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: idTokenForSession,
|
||||
AccessToken: opaqueAccessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
tOidc.tokenVerifier = &MockTokenVerifier{
|
||||
VerifyFunc: func(tokenToVerify string) error {
|
||||
if tokenToVerify == idTokenForSession {
|
||||
return nil // ID token is expected to be verified
|
||||
}
|
||||
if tokenToVerify == opaqueAccessToken {
|
||||
t.Errorf("TokenVerifier was incorrectly called with the opaque access token.")
|
||||
return errors.New("opaque access token should not be verified by this path")
|
||||
}
|
||||
t.Logf("TokenVerifier called with unexpected token: %s", tokenToVerify)
|
||||
return errors.New("unexpected token passed to verifier for this test case")
|
||||
},
|
||||
}
|
||||
// Expected header X-User-AccessToken is already set in tc.expectedHeaders
|
||||
}
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
tOidc.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d. Body: %s", http.StatusOK, rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
for name, expectedValue := range tc.expectedHeaders {
|
||||
if value, exists := interceptedHeaders[name]; !exists {
|
||||
// For <no value> case, it might not be set if template resolves to empty and header is omitted.
|
||||
// However, Go templates usually insert "<no value>" string.
|
||||
if expectedValue == "<no value>" && tc.name == "Missing Claim" { // Special handling for <no value>
|
||||
// If the template {{.Claims.role}} results in an empty string because role is missing,
|
||||
// and the header is not set, this is also acceptable for "<no value>".
|
||||
// The current test expects the literal string "<no value>".
|
||||
// Let's assume for now that if it's missing, it's an error unless specifically handled.
|
||||
// The test as written expects "<no value>" to be present.
|
||||
}
|
||||
t.Errorf("Expected header %s was not set", name)
|
||||
|
||||
} else if value != expectedValue {
|
||||
t.Errorf("Header %s expected value %q, got %q", name, expectedValue, value)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.name == "Opaque Access Token with AccessTokenField" {
|
||||
postReq := httptest.NewRequest("GET", "/protected", nil)
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
postReq.AddCookie(cookie)
|
||||
}
|
||||
updatedSession, err := tOidc.sessionManager.GetSession(postReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated session for opaque test: %v", err)
|
||||
}
|
||||
|
||||
expectedEmail := tc.claims["email"].(string)
|
||||
if updatedSession.GetEmail() != expectedEmail {
|
||||
t.Errorf("Expected session email to be %q (from ID token), got %q", expectedEmail, updatedSession.GetEmail())
|
||||
}
|
||||
if !updatedSession.GetAuthenticated() {
|
||||
t.Errorf("Session should be authenticated after successful flow for opaque test")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEdgeCaseTemplatedHeaders tests edge cases for templated headers
|
||||
func TestEdgeCaseTemplatedHeaders(t *testing.T) {
|
||||
// Create a TestSuite to use its helper methods and fields
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
shouldExecuteCheck bool
|
||||
}{
|
||||
{
|
||||
name: "Very Large Template",
|
||||
headers: []TemplatedHeader{
|
||||
{
|
||||
Name: "X-Large-Header",
|
||||
Value: createLargeTemplate(500), // Template with 500 variable references
|
||||
},
|
||||
},
|
||||
claims: createLargeClaims(500), // Map with 500 claims
|
||||
shouldExecuteCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Array Claim Access",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-Roles", Value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"},
|
||||
},
|
||||
claims: map[string]any{
|
||||
"roles": []any{"admin", "user", "manager"},
|
||||
},
|
||||
shouldExecuteCheck: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create token with the test claims
|
||||
claims := map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000), // Far future timestamp
|
||||
"iat": float64(1000000000),
|
||||
"nbf": float64(1000000000),
|
||||
"sub": "test-subject",
|
||||
"nonce": "test-nonce",
|
||||
"jti": generateRandomString(16),
|
||||
}
|
||||
|
||||
// Add the test-specific claims
|
||||
maps.Copy(claims, tc.claims)
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// Create a test next handler
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
tokenBlacklist: NewCache(),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: NewLogger("debug"),
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: ts.sessionManager,
|
||||
extractClaimsFunc: extractClaims,
|
||||
headerTemplates: make(map[string]*template.Template),
|
||||
}
|
||||
tOidc.tokenVerifier = tOidc
|
||||
tOidc.jwtVerifier = tOidc
|
||||
|
||||
// Initialize and parse header templates
|
||||
for _, header := range tc.headers {
|
||||
tmpl, err := template.New(header.Name).Parse(header.Value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse header template for %s: %v", header.Name, err)
|
||||
}
|
||||
tOidc.headerTemplates[header.Name] = tmpl
|
||||
}
|
||||
|
||||
close(tOidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "example.com")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
session, err := tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetIDToken(token) // Use the new method
|
||||
session.SetAccessToken(token) // Also set access token to match
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
|
||||
tOidc.extractClaimsFunc = extractClaims
|
||||
tOidc.tokenExchanger = &MockTokenExchanger{
|
||||
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: token,
|
||||
AccessToken: token,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
tOidc.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
// The "Array Claim Access" check previously here was problematic as it didn't correctly
|
||||
// intercept headers in TestEdgeCaseTemplatedHeaders. The primary goal of this
|
||||
// function is to test edge cases for panics/errors, and robust header value
|
||||
// checking is already covered in TestTemplatedHeadersIntegration.
|
||||
// Removing the ineffective check to resolve the "declared and not used" error.
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for edge case tests
|
||||
|
||||
// createLargeTemplate creates a template with many variable references
|
||||
func createLargeTemplate(size int) string {
|
||||
template := "{{with .Claims}}"
|
||||
for i := range size {
|
||||
if i > 0 {
|
||||
template += ","
|
||||
}
|
||||
template += "{{.field" + string(rune('a'+i%26)) + string(rune('0'+i%10)) + "}}"
|
||||
}
|
||||
template += "{{end}}"
|
||||
return template
|
||||
}
|
||||
|
||||
// createLargeClaims creates a map with many claims for testing large templates
|
||||
func createLargeClaims(size int) map[string]any {
|
||||
claims := make(map[string]any)
|
||||
for i := range size {
|
||||
claims["email"] = "largeclaimsuser@example.com" // Add email claim
|
||||
key := "field" + string(rune('a'+i%26)) + string(rune('0'+i%10))
|
||||
claims[key] = "value" + string(rune('a'+i%26)) + string(rune('0'+i%10))
|
||||
}
|
||||
return claims
|
||||
}
|
||||
+391
@@ -0,0 +1,391 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestTokens provides a comprehensive set of standardized test tokens
|
||||
// for consistent testing across the entire codebase.
|
||||
type TestTokens struct{}
|
||||
|
||||
// NewTestTokens creates a new TestTokens instance
|
||||
func NewTestTokens() *TestTokens {
|
||||
return &TestTokens{}
|
||||
}
|
||||
|
||||
// Valid JWT tokens for testing
|
||||
const (
|
||||
// ValidAccessToken - A properly formatted JWT access token for testing
|
||||
ValidAccessToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTc1MDI5NDYyOCwiaWF0IjoxNzUwMjkxMDI4LCJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJqdGkiOiJlNDcxN2RhZDBmZjAyOTNkIiwibmJmIjoxNzUwMjkxMDI4LCJub25jZSI6Im5vbmNlMTIzIiwic3ViIjoidGVzdC1zdWJqZWN0In0.bmwp-vk0B7Ir9UiUkzib8L7yJbebJ00o3U9QrB6gP2H9-RfqyCbN8M9Rkx7Rb8Vdh3YzqkBBoLS_G0i414rs2I9uABnTC4E6-63qkGdUrLB7p-XbjcRW2RoIBwXHk7lfumi8eX0uWzBsJ9CY0__UECVsex5XORfBb4Bcqj0LK4y-glxkpI51I7BPySfciWC_PkdaQ1Qe5pCAlxeNs2E9NMGXp-Ox6vAufUzoC2cws1LswGPPP6icQ-Zlzd5WMCIWhdIkN4yTxk8FMqsTC52k2zskRHNSSd4DDVETonfzawZNqDcMpnTyN53sCJ9UHiQTl9mCm61ttYW-W9Gc-ze4Xw"
|
||||
|
||||
// ValidIDToken - A properly formatted JWT ID token for testing
|
||||
ValidIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTc1MDI5NDYyOCwiaWF0IjoxNzUwMjkxMDI4LCJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJqdGkiOiI2YzBjZTZmMTM4Y2EzMzc2IiwibmJmIjoxNzUwMjkxMDI4LCJub25jZSI6Im5vbmNlMTIzIiwic3ViIjoidGVzdC1zdWJqZWN0In0.RBQYejA9vP4lnh2EhFqWerePWaCyDTF0ZE1jlU2xm4g2wWVeaEHpv5SNg92_gwk633N9xx7ugS0UrlEu4qbT7wSb1HBDR00q_andyYnyFk4OoxPpD0AqHkVr-pjS-Z7UCGF3sLgQ4ECmU9695PIys3XvgUGMzEn_mK-PHcpY5AnbBGFsbj7epUld_sb6WfjjjwAa8kKfKObPvaIpuJ4TlxI1Uf0wYOoIA0zh5ipeAn-i8Ud-GErxis1Hp8UQK7IRolXpToiXnFcnf3vI3eCS7Yu3oPl7LRxTxKMCI9h0MCwu25ZNsOg2C9ohyebpU0jbURX9Q74GNOaphv-Lz9rCRA"
|
||||
|
||||
// ValidRefreshToken - A properly formatted refresh token for testing
|
||||
ValidRefreshToken = "valid-refresh-token-12345"
|
||||
|
||||
// MinimalValidJWT - The shortest valid JWT for testing
|
||||
MinimalValidJWT = "h.p.s"
|
||||
|
||||
// ValidRefreshTokenGoogle - A Google-style refresh token for testing
|
||||
ValidRefreshTokenGoogle = "google_refresh_token_12345"
|
||||
)
|
||||
|
||||
// Invalid tokens for testing validation
|
||||
const (
|
||||
// InvalidTokenNoDots - Token with no dots (invalid JWT format)
|
||||
InvalidTokenNoDots = "notajwttoken"
|
||||
|
||||
// InvalidTokenOneDot - Token with one dot (invalid JWT format)
|
||||
InvalidTokenOneDot = "header.payload"
|
||||
|
||||
// InvalidTokenThreeDots - Token with three dots (invalid JWT format)
|
||||
InvalidTokenThreeDots = "header.payload.signature.extra"
|
||||
|
||||
// EmptyToken - Empty token
|
||||
EmptyToken = ""
|
||||
|
||||
// CorruptedBase64Token - Token with invalid base64 data for chunking tests
|
||||
CorruptedBase64Token = "corrupted_base64_!@#$"
|
||||
)
|
||||
|
||||
// CreateLargeValidJWT creates a JWT of approximately the specified size
|
||||
// This replaces the ad-hoc createLargeValidJWT function in tests
|
||||
func (tt *TestTokens) CreateLargeValidJWT(targetSize int) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
signature := "signature_" + tt.generateRandomString(32)
|
||||
|
||||
// Calculate required payload size
|
||||
usedSize := len(header) + len(signature) + 2 // account for dots
|
||||
payloadSize := max(targetSize-usedSize, 50)
|
||||
|
||||
// Create a payload with realistic JWT claims
|
||||
claims := map[string]any{
|
||||
"sub": "user123",
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
"exp": 9999999999,
|
||||
"iat": 1000000000,
|
||||
}
|
||||
|
||||
// FIXED: Calculate data size safely
|
||||
dataSize := max(
|
||||
// Account for other claims and base64 encoding
|
||||
payloadSize-100,
|
||||
// Minimum data size
|
||||
10)
|
||||
|
||||
claims["data"] = tt.generateRandomString(dataSize)
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
// CreateLargeRefreshToken creates a refresh token of approximately the specified size
|
||||
func (tt *TestTokens) CreateLargeRefreshToken(targetSize int) string {
|
||||
baseToken := "refresh_token_"
|
||||
padding := tt.generateRandomString(targetSize - len(baseToken))
|
||||
return baseToken + padding
|
||||
}
|
||||
|
||||
// CreateExpiredJWT creates an expired JWT token for testing
|
||||
func (tt *TestTokens) CreateExpiredJWT() string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
|
||||
// Create claims with expired timestamp
|
||||
claims := map[string]any{
|
||||
"sub": "user123",
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
"exp": time.Now().Unix() - 3600, // Expired 1 hour ago
|
||||
"iat": time.Now().Unix() - 7200, // Issued 2 hours ago
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signature := "expired_signature"
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
// CreateUniqueValidJWT creates a unique valid JWT for concurrent testing
|
||||
func (tt *TestTokens) CreateUniqueValidJWT(id string) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
|
||||
claims := map[string]any{
|
||||
"sub": "user_" + id,
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
"exp": 9999999999,
|
||||
"iat": 1000000000,
|
||||
"jti": id,
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signature := "sig_" + id
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
// CreateIncompressibleToken creates a token that cannot be compressed effectively
|
||||
// This is useful for testing chunking scenarios where compression doesn't help
|
||||
func (tt *TestTokens) CreateIncompressibleToken(targetSize int) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
signature := "incompressible_signature_" + tt.generateRandomString(32)
|
||||
|
||||
// Calculate required payload size
|
||||
usedSize := len(header) + len(signature) + 2 // account for dots
|
||||
payloadSize := max(targetSize-usedSize, 100)
|
||||
|
||||
// Generate multiple random fields to prevent compression
|
||||
randomFields := make(map[string]any)
|
||||
randomFields["sub"] = "user123"
|
||||
randomFields["iss"] = "https://example.com"
|
||||
randomFields["aud"] = "client123"
|
||||
randomFields["exp"] = 9999999999
|
||||
randomFields["iat"] = 1000000000
|
||||
|
||||
// Add many random fields with random data to prevent compression
|
||||
remainingSize := payloadSize - 200 // Account for base64 encoding and other fields
|
||||
fieldCount := max(
|
||||
// ~100 bytes per field
|
||||
remainingSize/100, 1)
|
||||
|
||||
for i := range fieldCount {
|
||||
// Generate truly random data for each field
|
||||
randomBytes := make([]byte, 50)
|
||||
rand.Read(randomBytes)
|
||||
fieldName := fmt.Sprintf("random_field_%d_%s", i, tt.generateRandomString(8))
|
||||
randomFields[fieldName] = base64.StdEncoding.EncodeToString(randomBytes)
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(randomFields)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
token := fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
|
||||
// If still too small, pad with more random data
|
||||
if len(token) < targetSize {
|
||||
padding := targetSize - len(token)
|
||||
extraRandomBytes := make([]byte, padding/2)
|
||||
rand.Read(extraRandomBytes)
|
||||
randomFields["padding"] = base64.StdEncoding.EncodeToString(extraRandomBytes)
|
||||
claimsJSON, _ = json.Marshal(randomFields)
|
||||
payload = base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
token = fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// GetValidTokenSet returns a complete set of valid tokens for testing
|
||||
func (tt *TestTokens) GetValidTokenSet() TokenSet {
|
||||
return TokenSet{
|
||||
AccessToken: ValidAccessToken,
|
||||
IDToken: ValidIDToken,
|
||||
RefreshToken: ValidRefreshToken,
|
||||
}
|
||||
}
|
||||
|
||||
// GetGoogleTokenSet returns tokens that simulate Google OIDC provider responses
|
||||
func (tt *TestTokens) GetGoogleTokenSet() TokenSet {
|
||||
return TokenSet{
|
||||
AccessToken: ValidAccessToken,
|
||||
IDToken: ValidIDToken,
|
||||
RefreshToken: ValidRefreshTokenGoogle,
|
||||
}
|
||||
}
|
||||
|
||||
// GetLargeTokenSet returns a set of large tokens for chunking tests
|
||||
func (tt *TestTokens) GetLargeTokenSet() TokenSet {
|
||||
return TokenSet{
|
||||
AccessToken: tt.CreateLargeValidJWT(5000),
|
||||
IDToken: tt.CreateLargeValidJWT(2000),
|
||||
RefreshToken: tt.CreateLargeRefreshToken(3000),
|
||||
}
|
||||
}
|
||||
|
||||
// GetInvalidTokens returns various invalid tokens for validation testing
|
||||
func (tt *TestTokens) GetInvalidTokens() InvalidTokenSet {
|
||||
return InvalidTokenSet{
|
||||
NoDots: InvalidTokenNoDots,
|
||||
OneDot: InvalidTokenOneDot,
|
||||
ThreeDots: InvalidTokenThreeDots,
|
||||
Empty: EmptyToken,
|
||||
Corrupted: CorruptedBase64Token,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandomString creates a random string of the specified length
|
||||
func (tt *TestTokens) generateRandomString(length int) string {
|
||||
// FIXED: Handle negative or zero lengths safely
|
||||
if length <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
randomByte := make([]byte, 1)
|
||||
rand.Read(randomByte)
|
||||
b[i] = charset[int(randomByte[0])%len(charset)]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// TokenSet represents a complete set of tokens for testing
|
||||
type TokenSet struct {
|
||||
AccessToken string
|
||||
IDToken string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// InvalidTokenSet represents various invalid tokens for validation testing
|
||||
type InvalidTokenSet struct {
|
||||
NoDots string // Token with 0 dots
|
||||
OneDot string // Token with 1 dot
|
||||
ThreeDots string // Token with 3 dots
|
||||
Empty string // Empty token
|
||||
Corrupted string // Corrupted/invalid characters
|
||||
}
|
||||
|
||||
// TestScenarios provides predefined test scenarios
|
||||
type TestScenarios struct {
|
||||
tokens *TestTokens
|
||||
}
|
||||
|
||||
// NewTestScenarios creates a new TestScenarios instance
|
||||
func NewTestScenarios() *TestScenarios {
|
||||
return &TestScenarios{
|
||||
tokens: NewTestTokens(),
|
||||
}
|
||||
}
|
||||
|
||||
// NormalFlow returns tokens for normal authentication flow testing
|
||||
func (ts *TestScenarios) NormalFlow() TokenSet {
|
||||
return ts.tokens.GetValidTokenSet()
|
||||
}
|
||||
|
||||
// GoogleFlow returns tokens simulating Google OIDC provider
|
||||
func (ts *TestScenarios) GoogleFlow() TokenSet {
|
||||
return ts.tokens.GetGoogleTokenSet()
|
||||
}
|
||||
|
||||
// ChunkingRequired returns large tokens that require chunking
|
||||
func (ts *TestScenarios) ChunkingRequired() TokenSet {
|
||||
return ts.tokens.GetLargeTokenSet()
|
||||
}
|
||||
|
||||
// CorruptionTest returns tokens and corruption scenarios for testing
|
||||
func (ts *TestScenarios) CorruptionTest() CorruptionTestSet {
|
||||
return CorruptionTestSet{
|
||||
ValidTokens: ts.tokens.GetValidTokenSet(),
|
||||
InvalidTokens: ts.tokens.GetInvalidTokens(),
|
||||
LargeTokens: ts.tokens.GetLargeTokenSet(),
|
||||
CorruptedToken: CorruptedBase64Token,
|
||||
}
|
||||
}
|
||||
|
||||
// ConcurrentTest returns unique tokens for concurrent testing
|
||||
func (ts *TestScenarios) ConcurrentTest(count int) []TokenSet {
|
||||
sets := make([]TokenSet, count)
|
||||
for i := range count {
|
||||
sets[i] = TokenSet{
|
||||
AccessToken: ts.tokens.CreateUniqueValidJWT(fmt.Sprintf("concurrent_%d", i)),
|
||||
IDToken: ts.tokens.CreateUniqueValidJWT(fmt.Sprintf("id_%d", i)),
|
||||
RefreshToken: fmt.Sprintf("refresh_concurrent_%d", i),
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
// CorruptionTestSet represents tokens and scenarios for corruption testing
|
||||
type CorruptionTestSet struct {
|
||||
ValidTokens TokenSet
|
||||
InvalidTokens InvalidTokenSet
|
||||
LargeTokens TokenSet
|
||||
CorruptedToken string
|
||||
}
|
||||
|
||||
// TokenValidationTestCases returns test cases for token validation
|
||||
func (tt *TestTokens) TokenValidationTestCases() []ValidationTestCase {
|
||||
return []ValidationTestCase{
|
||||
{
|
||||
Name: "Empty token",
|
||||
Token: EmptyToken,
|
||||
ExpectStored: true, // Empty tokens are allowed for clearing
|
||||
ExpectRetrieved: false, // But return as empty
|
||||
},
|
||||
{
|
||||
Name: "Single dot",
|
||||
Token: InvalidTokenOneDot,
|
||||
ExpectStored: false, // Invalid JWT format
|
||||
ExpectRetrieved: false,
|
||||
},
|
||||
{
|
||||
Name: "No dots",
|
||||
Token: InvalidTokenNoDots,
|
||||
ExpectStored: false, // Invalid JWT format
|
||||
ExpectRetrieved: false,
|
||||
},
|
||||
{
|
||||
Name: "Too many dots",
|
||||
Token: InvalidTokenThreeDots,
|
||||
ExpectStored: false, // Invalid JWT format
|
||||
ExpectRetrieved: false,
|
||||
},
|
||||
{
|
||||
Name: "Valid minimal JWT",
|
||||
Token: MinimalValidJWT,
|
||||
ExpectStored: true,
|
||||
ExpectRetrieved: true,
|
||||
},
|
||||
{
|
||||
Name: "Valid standard JWT",
|
||||
Token: ValidAccessToken,
|
||||
ExpectStored: true,
|
||||
ExpectRetrieved: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidationTestCase represents a single token validation test case
|
||||
type ValidationTestCase struct {
|
||||
Name string
|
||||
Token string
|
||||
ExpectStored bool
|
||||
ExpectRetrieved bool
|
||||
}
|
||||
|
||||
// Helper functions for common test patterns
|
||||
|
||||
// AssertValidTokenStorage verifies that a valid token can be stored and retrieved
|
||||
func AssertValidTokenStorage(t TestingInterface, session *SessionData, token string) {
|
||||
session.SetAccessToken(token)
|
||||
retrieved := session.GetAccessToken()
|
||||
if retrieved != token {
|
||||
t.Errorf("Token storage failed: expected %q, got %q", token, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertInvalidTokenRejection verifies that an invalid token is rejected
|
||||
func AssertInvalidTokenRejection(t TestingInterface, session *SessionData, token string) {
|
||||
original := session.GetAccessToken()
|
||||
session.SetAccessToken(token)
|
||||
after := session.GetAccessToken()
|
||||
if after != original {
|
||||
t.Errorf("Invalid token was not rejected: expected %q, got %q", original, after)
|
||||
}
|
||||
}
|
||||
|
||||
// TestingInterface provides the minimal interface needed for testing
|
||||
type TestingInterface interface {
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
@@ -0,0 +1,502 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// TestTokenCorruptionScenario reproduces the exact failure pattern from GitHub issue #53:
|
||||
// Token verified successfully multiple times, then fails with "signature verification failed"
|
||||
func TestTokenCorruptionScenario(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a valid JWT token
|
||||
validJWT := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImV4cCI6OTk5OTk5OTk5OX0.signature"
|
||||
|
||||
tests := []struct {
|
||||
corruptionScenario func(*SessionData)
|
||||
name string
|
||||
tokenSize int
|
||||
iterations int
|
||||
expectConsistent bool
|
||||
}{
|
||||
{
|
||||
name: "Small token - multiple retrievals",
|
||||
tokenSize: len(validJWT),
|
||||
iterations: 10,
|
||||
expectConsistent: true,
|
||||
},
|
||||
{
|
||||
name: "Large chunked token - multiple retrievals",
|
||||
tokenSize: 5000,
|
||||
iterations: 10,
|
||||
expectConsistent: true,
|
||||
},
|
||||
{
|
||||
name: "Compression corruption simulation",
|
||||
tokenSize: 2000,
|
||||
iterations: 5,
|
||||
expectConsistent: false, // Will be corrupted intentionally
|
||||
corruptionScenario: func(session *SessionData) {
|
||||
// Simulate corruption by directly modifying session values
|
||||
if session.accessSession != nil {
|
||||
// Simulate corrupted compressed data
|
||||
session.accessSession.Values["token"] = "corrupted_base64_!@#$"
|
||||
session.accessSession.Values["compressed"] = true
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Chunk reassembly corruption simulation",
|
||||
tokenSize: 25000, // Large enough to force chunking even after compression
|
||||
iterations: 5,
|
||||
expectConsistent: false, // Will be corrupted intentionally
|
||||
corruptionScenario: func(session *SessionData) {
|
||||
// Simulate chunk corruption with invalid base64 characters
|
||||
if len(session.accessTokenChunks) > 0 {
|
||||
if chunk, exists := session.accessTokenChunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = "invalid_base64_!@#$%"
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Create token of specified size
|
||||
token := createTokenOfSize(validJWT, tt.tokenSize)
|
||||
|
||||
// 1. Store the token
|
||||
session.SetAccessToken(token)
|
||||
t.Logf("Stored token of size %d bytes", len(token))
|
||||
|
||||
// 2. Verify token can be retrieved multiple times successfully
|
||||
var retrievedTokens []string
|
||||
for i := 0; i < tt.iterations; i++ {
|
||||
retrieved := session.GetAccessToken()
|
||||
retrievedTokens = append(retrievedTokens, retrieved)
|
||||
|
||||
if tt.expectConsistent && retrieved != token {
|
||||
t.Errorf("Iteration %d: Token mismatch, expected consistency", i)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Apply corruption scenario if specified
|
||||
if tt.corruptionScenario != nil {
|
||||
tt.corruptionScenario(session)
|
||||
}
|
||||
|
||||
// 4. Retrieve token after potential corruption
|
||||
finalRetrieved := session.GetAccessToken()
|
||||
|
||||
if tt.expectConsistent {
|
||||
// With fixes, token should still be retrievable correctly
|
||||
if finalRetrieved != token {
|
||||
t.Errorf("Final retrieval failed - corruption not handled correctly")
|
||||
t.Logf("Expected: %q", token)
|
||||
t.Logf("Got: %q", finalRetrieved)
|
||||
}
|
||||
} else {
|
||||
// For corruption scenarios, expect empty string (graceful failure)
|
||||
if finalRetrieved != "" {
|
||||
t.Errorf("Expected corruption to result in empty token, got: %q", finalRetrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Verify all previous retrievals were consistent (if expected)
|
||||
if tt.expectConsistent {
|
||||
for i, retrieved := range retrievedTokens {
|
||||
if retrieved != token {
|
||||
t.Errorf("Iteration %d produced inconsistent result", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompressionIntegrityFailure tests scenarios where compression fails integrity checks
|
||||
func TestCompressionIntegrityFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectSame bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JWT",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig",
|
||||
expectSame: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT - wrong dots",
|
||||
token: "invalid.token",
|
||||
expectSame: true, // Should return unchanged
|
||||
},
|
||||
{
|
||||
name: "Oversized token",
|
||||
token: "header." + strings.Repeat("A", 60000) + ".sig",
|
||||
expectSame: true, // Should return unchanged due to size limit
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
compressed := compressToken(tt.token)
|
||||
|
||||
if tt.expectSame && compressed != tt.token {
|
||||
// If we expect the token to remain the same but it was compressed,
|
||||
// verify round-trip integrity
|
||||
decompressed := decompressToken(compressed)
|
||||
if decompressed != tt.token {
|
||||
t.Errorf("Compression integrity failed: original=%q, decompressed=%q", tt.token, decompressed)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestChunkReassemblyEdgeCases tests edge cases in chunk reassembly that could cause corruption
|
||||
func TestChunkReassemblyEdgeCases(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Create a large token that will definitely be chunked
|
||||
largeToken := createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig", 8000)
|
||||
|
||||
// Store the token to create chunks
|
||||
session.SetAccessToken(largeToken)
|
||||
|
||||
if len(session.accessTokenChunks) == 0 {
|
||||
t.Skip("Token was not chunked, skipping reassembly tests")
|
||||
}
|
||||
|
||||
t.Logf("Token was split into %d chunks", len(session.accessTokenChunks))
|
||||
|
||||
// Test various corruption scenarios
|
||||
corruptionTests := []struct {
|
||||
corruption func(map[int]*sessions.Session)
|
||||
name string
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "Gap in chunk sequence",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
// Remove chunk 1 if it exists
|
||||
delete(chunks, 1)
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Chunk with nil value",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = nil
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Chunk with wrong type",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = 12345 // Should be string
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Empty chunk data",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
if chunk, exists := chunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = ""
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "Excessive chunk count",
|
||||
corruption: func(chunks map[int]*sessions.Session) {
|
||||
// This test simulates having too many chunks (>50 limit)
|
||||
// We'll create a scenario by adding many fake chunks
|
||||
for i := range 60 {
|
||||
fakeSession := &sessions.Session{Values: make(map[any]any)}
|
||||
fakeSession.Values["token_chunk"] = "fake_chunk_data"
|
||||
chunks[i] = fakeSession
|
||||
}
|
||||
},
|
||||
expectEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, ct := range corruptionTests {
|
||||
t.Run(ct.name, func(t *testing.T) {
|
||||
// Get a fresh session for each test
|
||||
freshReq := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
freshSession, err := sm.GetSession(freshReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get fresh session: %v", err)
|
||||
}
|
||||
defer freshSession.ReturnToPool()
|
||||
|
||||
// Store the large token again
|
||||
freshSession.SetAccessToken(largeToken)
|
||||
|
||||
// Apply corruption
|
||||
ct.corruption(freshSession.accessTokenChunks)
|
||||
|
||||
// Try to retrieve the token
|
||||
retrieved := freshSession.GetAccessToken()
|
||||
|
||||
if ct.expectEmpty {
|
||||
if retrieved != "" {
|
||||
t.Errorf("Expected empty token due to corruption, got: %q", retrieved)
|
||||
}
|
||||
} else {
|
||||
if retrieved != largeToken {
|
||||
t.Errorf("Expected original token, got: %q", retrieved)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRaceConditionProtection tests that concurrent access doesn't cause corruption
|
||||
func TestRaceConditionProtection(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
const numGoroutines = 20
|
||||
const numOperations = 50
|
||||
|
||||
// Create tokens of different sizes
|
||||
tokens := []string{
|
||||
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig1",
|
||||
createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig2", 3000),
|
||||
createTokenOfSize("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig3", 6000),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := range numOperations {
|
||||
tokenIndex := (goroutineID + j) % len(tokens)
|
||||
expectedToken := tokens[tokenIndex]
|
||||
|
||||
// Set token
|
||||
session.SetAccessToken(expectedToken)
|
||||
|
||||
// Retrieve token
|
||||
retrieved := session.GetAccessToken()
|
||||
|
||||
// Verify it's a valid JWT (should have exactly 2 dots)
|
||||
if retrieved != "" && strings.Count(retrieved, ".") != 2 {
|
||||
errChan <- fmt.Errorf("goroutine %d, op %d: invalid JWT format in retrieved token: %q",
|
||||
goroutineID, j, retrieved)
|
||||
continue
|
||||
}
|
||||
|
||||
// The retrieved token should be one of the valid tokens we set
|
||||
// (due to concurrent access, it might not be the exact one we just set)
|
||||
isValidToken := slices.Contains(tokens, retrieved)
|
||||
|
||||
if retrieved != "" && !isValidToken {
|
||||
errChan <- fmt.Errorf("goroutine %d, op %d: retrieved unknown token: %q",
|
||||
goroutineID, j, retrieved)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// Check for any errors
|
||||
for err := range errChan {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMemoryExhaustionProtection tests protection against memory exhaustion attacks
|
||||
func TestMemoryExhaustionProtection(t *testing.T) {
|
||||
tests := []struct {
|
||||
setupCorruption func() string
|
||||
name string
|
||||
expectRejection bool
|
||||
}{
|
||||
{
|
||||
name: "Extremely large compressed data",
|
||||
setupCorruption: func() string {
|
||||
return base64.StdEncoding.EncodeToString(bytes.Repeat([]byte("A"), 200*1024)) // 200KB
|
||||
},
|
||||
expectRejection: true,
|
||||
},
|
||||
{
|
||||
name: "Malformed gzip bomb attempt",
|
||||
setupCorruption: func() string {
|
||||
// Create data that looks like gzip but would decompress to huge size
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
gz.Write(bytes.Repeat([]byte("A"), 10*1024)) // 10KB that compresses well
|
||||
gz.Close()
|
||||
|
||||
compressed := buf.Bytes()
|
||||
// Modify to make it potentially dangerous
|
||||
return base64.StdEncoding.EncodeToString(compressed)
|
||||
},
|
||||
expectRejection: false, // Our decompression has size limits
|
||||
},
|
||||
{
|
||||
name: "Token with excessive chunk simulation",
|
||||
setupCorruption: func() string {
|
||||
// This will be tested in the session layer
|
||||
return strings.Repeat("chunk.", 100) + "final"
|
||||
},
|
||||
expectRejection: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
corruptedData := tt.setupCorruption()
|
||||
|
||||
result := decompressToken(corruptedData)
|
||||
|
||||
if tt.expectRejection {
|
||||
// Should return original corrupted data, not attempt decompression
|
||||
if result != corruptedData {
|
||||
t.Errorf("Expected rejection of dangerous data, but decompression was attempted")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no excessive memory was used (this test would catch OOM in practice)
|
||||
// The fact that we reach this point means memory limits were effective
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackwardCompatibility ensures that sessions created before the fixes still work
|
||||
func TestBackwardCompatibility(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Simulate old-style session data (without new validation fields)
|
||||
oldStyleToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.oldsig"
|
||||
|
||||
// Manually set token without going through new SetAccessToken validation
|
||||
session.accessSession.Values["token"] = oldStyleToken
|
||||
session.accessSession.Values["compressed"] = false
|
||||
|
||||
// Should still be retrievable
|
||||
retrieved := session.GetAccessToken()
|
||||
if retrieved != oldStyleToken {
|
||||
t.Errorf("Backward compatibility failed: expected %q, got %q", oldStyleToken, retrieved)
|
||||
}
|
||||
|
||||
// Test with simulated old compressed token
|
||||
oldCompressed := compressToken(oldStyleToken)
|
||||
session.accessSession.Values["token"] = oldCompressed
|
||||
session.accessSession.Values["compressed"] = true
|
||||
|
||||
retrieved2 := session.GetAccessToken()
|
||||
if retrieved2 != oldStyleToken {
|
||||
t.Errorf("Backward compatibility with compression failed: expected %q, got %q", oldStyleToken, retrieved2)
|
||||
}
|
||||
}
|
||||
|
||||
// createTokenOfSize creates a JWT token of approximately the specified size
|
||||
func createTokenOfSize(baseToken string, targetSize int) string {
|
||||
parts := strings.Split(baseToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return baseToken
|
||||
}
|
||||
|
||||
header, payload, signature := parts[0], parts[1], parts[2]
|
||||
currentSize := len(baseToken)
|
||||
|
||||
if currentSize >= targetSize {
|
||||
return baseToken
|
||||
}
|
||||
|
||||
// Expand the payload to reach target size
|
||||
paddingNeeded := targetSize - len(header) - len(signature) - 2 // Account for dots
|
||||
if paddingNeeded > 0 {
|
||||
// Decode current payload, add padding, re-encode
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// If we can't decode, just pad with random base64-safe characters to resist compression
|
||||
randomBytes := make([]byte, paddingNeeded)
|
||||
rand.Read(randomBytes)
|
||||
// Encode as base64 to make it base64-safe
|
||||
padData := base64.RawURLEncoding.EncodeToString(randomBytes)
|
||||
payload = payload + padData
|
||||
} else {
|
||||
// Add padding to the JSON - use random data to resist compression
|
||||
randomBytes := make([]byte, paddingNeeded/2)
|
||||
rand.Read(randomBytes)
|
||||
// Encode as base64 to make it JSON-safe
|
||||
padData := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
newPayload := fmt.Sprintf(`{"original":%s,"padding":"%s"}`, string(decoded), padData)
|
||||
payload = base64.RawURLEncoding.EncodeToString([]byte(newPayload))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
@@ -0,0 +1,671 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestTokenTypeDistinction tests that AccessToken and IdToken are correctly distinguished in templates
|
||||
func TestTokenTypeDistinction(t *testing.T) {
|
||||
// Define test data where AccessToken and IdToken are deliberately different
|
||||
type templateData struct {
|
||||
Claims map[string]any
|
||||
AccessToken string
|
||||
IdToken string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
testData := templateData{
|
||||
AccessToken: "test-access-token-abc123",
|
||||
IdToken: "test-id-token-xyz789",
|
||||
RefreshToken: "test-refresh-token",
|
||||
Claims: map[string]any{
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
templateText string
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access Token Only",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
expectedValue: "Bearer test-access-token-abc123",
|
||||
},
|
||||
{
|
||||
name: "ID Token Only",
|
||||
templateText: "ID: {{.IdToken}}",
|
||||
expectedValue: "ID: test-id-token-xyz789",
|
||||
},
|
||||
{
|
||||
name: "Both Tokens",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IdToken}}",
|
||||
expectedValue: "Access: test-access-token-abc123 ID: test-id-token-xyz789",
|
||||
},
|
||||
{
|
||||
name: "Both Tokens in Authorization Format",
|
||||
templateText: "Bearer {{.AccessToken}} and Bearer {{.IdToken}}",
|
||||
expectedValue: "Bearer test-access-token-abc123 and Bearer test-id-token-xyz789",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, testData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute template: %v", err)
|
||||
}
|
||||
|
||||
result := buf.String()
|
||||
if result != tc.expectedValue {
|
||||
t.Errorf("Expected template output %q, got %q", tc.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenTypeIntegration tests the integration of ID and access tokens with the middleware
|
||||
func TestTokenTypeIntegration(t *testing.T) {
|
||||
// Create a TestSuite to use its helper methods and fields
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create different tokens for ID and access tokens
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
"iat": float64(1000000000),
|
||||
"nbf": float64(1000000000),
|
||||
"sub": "test-subject",
|
||||
"nonce": "test-nonce",
|
||||
"jti": generateRandomString(16),
|
||||
"token_type": "id_token",
|
||||
"email": "user@example.com",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test ID JWT: %v", err)
|
||||
}
|
||||
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]any{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": float64(3000000000),
|
||||
"iat": float64(1000000000),
|
||||
"nbf": float64(1000000000),
|
||||
"sub": "test-subject",
|
||||
"jti": generateRandomString(16),
|
||||
"token_type": "access_token",
|
||||
"scope": "openid profile email",
|
||||
"email": "user@example.com", // Add email to access token so it's available in claims
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test access JWT: %v", err)
|
||||
}
|
||||
|
||||
// Define test headers that use both token types
|
||||
headers := []TemplatedHeader{
|
||||
{Name: "X-ID-Token", Value: "{{.IdToken}}"},
|
||||
{Name: "X-Access-Token", Value: "{{.AccessToken}}"},
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
{Name: "X-Email-From-Claims", Value: "{{.Claims.email}}"},
|
||||
}
|
||||
|
||||
// Store intercepted headers for verification
|
||||
interceptedHeaders := make(map[string]string)
|
||||
|
||||
// Create a test next handler that captures the headers
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Capture headers for verification
|
||||
for _, header := range headers {
|
||||
if value := r.Header.Get(header.Name); value != "" {
|
||||
interceptedHeaders[header.Name] = value
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create the TraefikOidc instance
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
tokenBlacklist: NewCache(),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: NewLogger("debug"),
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: ts.sessionManager,
|
||||
extractClaimsFunc: extractClaims,
|
||||
headerTemplates: make(map[string]*template.Template),
|
||||
}
|
||||
tOidc.tokenVerifier = tOidc
|
||||
tOidc.jwtVerifier = tOidc
|
||||
|
||||
// Initialize and parse header templates
|
||||
for _, header := range headers {
|
||||
tmpl, err := template.New(header.Name).Parse(header.Value)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse header template for %s: %v", header.Name, err)
|
||||
}
|
||||
tOidc.headerTemplates[header.Name] = tmpl
|
||||
}
|
||||
|
||||
// Close the initComplete channel to bypass the waiting
|
||||
close(tOidc.initComplete)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "example.com")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Create a session
|
||||
session, err := tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Setup the session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetIDToken(idToken) // Set the ID token
|
||||
session.SetAccessToken(accessToken) // Set the access token
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Add session cookies to the request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Reset the response recorder for the main test
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
// Process the request
|
||||
tOidc.ServeHTTP(rr, req)
|
||||
|
||||
// Check status code
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
// Verify headers were set correctly
|
||||
expectedHeaders := map[string]string{
|
||||
"X-ID-Token": idToken,
|
||||
"X-Access-Token": accessToken,
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
"X-Email-From-Claims": "user@example.com",
|
||||
}
|
||||
|
||||
for name, expectedValue := range expectedHeaders {
|
||||
if value, exists := interceptedHeaders[name]; !exists {
|
||||
t.Errorf("Expected header %s was not set", name)
|
||||
} else if value != expectedValue {
|
||||
t.Errorf("Header %s expected value %q, got %q", name, expectedValue, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionIDTokenAccessToken tests that the SessionData correctly stores and retrieves
|
||||
// both ID tokens and access tokens separately
|
||||
func TestSessionIDTokenAccessToken(t *testing.T) {
|
||||
// Create a logger for the session manager
|
||||
logger := NewLogger("debug")
|
||||
|
||||
// Create a session manager
|
||||
sessionManager, err := NewSessionManager("test-session-encryption-key-at-least-32-bytes", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get a session
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set test tokens using standardized tokens
|
||||
idToken := ValidIDToken
|
||||
accessToken := ValidAccessToken
|
||||
refreshToken := ValidRefreshToken
|
||||
|
||||
// Store tokens in session
|
||||
session.SetIDToken(idToken)
|
||||
session.SetAccessToken(accessToken)
|
||||
session.SetRefreshToken(refreshToken)
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies from response
|
||||
cookies := rr.Result().Cookies()
|
||||
|
||||
// Create a new request with those cookies
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get the session again
|
||||
session2, err := sessionManager.GetSession(req2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session from request with cookies: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the tokens were correctly stored and retrieved
|
||||
retrievedIDToken := session2.GetIDToken()
|
||||
retrievedAccessToken := session2.GetAccessToken()
|
||||
retrievedRefreshToken := session2.GetRefreshToken()
|
||||
|
||||
if retrievedIDToken != idToken {
|
||||
t.Errorf("ID token mismatch: expected %q, got %q", idToken, retrievedIDToken)
|
||||
}
|
||||
|
||||
if retrievedAccessToken != accessToken {
|
||||
t.Errorf("Access token mismatch: expected %q, got %q", accessToken, retrievedAccessToken)
|
||||
}
|
||||
|
||||
if retrievedRefreshToken != refreshToken {
|
||||
t.Errorf("Refresh token mismatch: expected %q, got %q", refreshToken, retrievedRefreshToken)
|
||||
}
|
||||
|
||||
// Verify that the tokens are distinct
|
||||
if retrievedIDToken == retrievedAccessToken {
|
||||
t.Errorf("ID token and Access token should be different, but both are %q", retrievedIDToken)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenCorruptionIntegrationFlows tests the complete token handling flow with corruption scenarios
|
||||
func TestTokenCorruptionIntegrationFlows(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
corruptAction func(*SessionData)
|
||||
name string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
idToken string
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "Normal flow - small tokens",
|
||||
accessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.access_sig",
|
||||
refreshToken: "refresh_token_12345",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
|
||||
expectSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "Normal flow - large tokens (chunked)",
|
||||
accessToken: createLargeValidJWT(5000),
|
||||
refreshToken: createLargeRefreshToken(3000),
|
||||
idToken: createLargeValidJWT(2000),
|
||||
expectSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "Corrupted access token compression",
|
||||
accessToken: createLargeValidJWT(3000),
|
||||
refreshToken: "refresh_token_12345",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
|
||||
expectSuccess: false,
|
||||
corruptAction: func(session *SessionData) {
|
||||
// Corrupt compressed access token
|
||||
if session.accessSession != nil {
|
||||
session.accessSession.Values["token"] = "corrupted_compressed_data_!@#"
|
||||
session.accessSession.Values["compressed"] = true
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Corrupted chunk in large token",
|
||||
accessToken: createLargeValidJWT(8000), // Force chunking
|
||||
refreshToken: "refresh_token_12345",
|
||||
idToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.id_sig",
|
||||
expectSuccess: false,
|
||||
corruptAction: func(session *SessionData) {
|
||||
// Corrupt first chunk
|
||||
if len(session.accessTokenChunks) > 0 {
|
||||
if chunk, exists := session.accessTokenChunks[0]; exists {
|
||||
chunk.Values["token_chunk"] = "corrupted_chunk_data"
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get session
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Store tokens
|
||||
session.SetAccessToken(tt.accessToken)
|
||||
session.SetRefreshToken(tt.refreshToken)
|
||||
session.SetIDToken(tt.idToken)
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Apply corruption if specified
|
||||
if tt.corruptAction != nil {
|
||||
tt.corruptAction(session)
|
||||
}
|
||||
|
||||
// Test token retrieval after corruption
|
||||
retrievedAccess := session.GetAccessToken()
|
||||
retrievedRefresh := session.GetRefreshToken()
|
||||
retrievedID := session.GetIDToken()
|
||||
|
||||
if tt.expectSuccess {
|
||||
if retrievedAccess != tt.accessToken {
|
||||
t.Errorf("Access token corruption: expected %q, got %q", tt.accessToken, retrievedAccess)
|
||||
}
|
||||
if retrievedRefresh != tt.refreshToken {
|
||||
t.Errorf("Refresh token corruption: expected %q, got %q", tt.refreshToken, retrievedRefresh)
|
||||
}
|
||||
if retrievedID != tt.idToken {
|
||||
t.Errorf("ID token corruption: expected %q, got %q", tt.idToken, retrievedID)
|
||||
}
|
||||
} else {
|
||||
// For corruption scenarios, access token should be empty (graceful failure)
|
||||
if retrievedAccess != "" {
|
||||
t.Errorf("Expected corrupted access token to return empty, got: %q", retrievedAccess)
|
||||
}
|
||||
// Other tokens should still work
|
||||
if retrievedRefresh != tt.refreshToken {
|
||||
t.Errorf("Refresh token should not be affected by access token corruption: expected %q, got %q",
|
||||
tt.refreshToken, retrievedRefresh)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionPersistenceWithCorruption tests that session corruption is handled across requests
|
||||
func TestSessionPersistenceWithCorruption(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
// First request - store tokens
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
rr1 := httptest.NewRecorder()
|
||||
|
||||
session1, err := sm.GetSession(req1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
largeToken := createLargeValidJWT(6000)
|
||||
session1.SetAccessToken(largeToken)
|
||||
session1.SetAuthenticated(true)
|
||||
|
||||
if err := session1.Save(req1, rr1); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies from first response
|
||||
cookies := rr1.Result().Cookies()
|
||||
session1.ReturnToPool()
|
||||
|
||||
// Second request - retrieve tokens with cookies
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
req2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
session2, err := sm.GetSession(req2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session from cookies: %v", err)
|
||||
}
|
||||
defer session2.ReturnToPool()
|
||||
|
||||
// Verify token can be retrieved
|
||||
retrieved := session2.GetAccessToken()
|
||||
if retrieved != largeToken {
|
||||
t.Errorf("Token persistence failed: expected %q, got %q", largeToken, retrieved)
|
||||
}
|
||||
|
||||
// Simulate corruption by modifying chunks
|
||||
if len(session2.accessTokenChunks) > 0 {
|
||||
// Corrupt a middle chunk
|
||||
chunkIndex := len(session2.accessTokenChunks) / 2
|
||||
if chunk, exists := session2.accessTokenChunks[chunkIndex]; exists {
|
||||
chunk.Values["token_chunk"] = "corrupted"
|
||||
}
|
||||
|
||||
// Try to retrieve again - should detect corruption and return empty
|
||||
retrievedAfterCorruption := session2.GetAccessToken()
|
||||
if retrievedAfterCorruption != "" {
|
||||
t.Errorf("Expected corruption to be detected, but got token: %q", retrievedAfterCorruption)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentTokenOperationsWithCorruption tests concurrent access with intentional corruption
|
||||
func TestConcurrentTokenOperationsWithCorruption(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
const numGoroutines = 10
|
||||
const numOperations = 20
|
||||
|
||||
done := make(chan bool, numGoroutines)
|
||||
errorChan := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
// Start concurrent operations
|
||||
for i := range numGoroutines {
|
||||
go func(goroutineID int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := range numOperations {
|
||||
// Create a unique valid token for each operation
|
||||
token := fmt.Sprintf("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwib3AiOiIxMjMifQ.sig_%d_%d",
|
||||
goroutineID, j)
|
||||
|
||||
// Store token
|
||||
session.SetAccessToken(token)
|
||||
|
||||
// Retrieve token
|
||||
retrieved := session.GetAccessToken()
|
||||
|
||||
// Validate retrieved token format
|
||||
if retrieved != "" {
|
||||
if strings.Count(retrieved, ".") != 2 {
|
||||
errorChan <- fmt.Errorf("goroutine %d, op %d: invalid JWT format: %q",
|
||||
goroutineID, j, retrieved)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's a reasonable length
|
||||
if len(retrieved) < 10 || len(retrieved) > 100000 {
|
||||
errorChan <- fmt.Errorf("goroutine %d, op %d: suspicious token length %d: %q",
|
||||
goroutineID, j, len(retrieved), retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// Occasionally simulate corruption to test error handling
|
||||
if j%5 == 0 && len(session.accessTokenChunks) > 0 {
|
||||
// Intentionally corrupt a random chunk
|
||||
for chunkID, chunk := range session.accessTokenChunks {
|
||||
if chunkID%2 == 0 {
|
||||
chunk.Values["token_chunk"] = "intentionally_corrupted"
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for range numGoroutines {
|
||||
<-done
|
||||
}
|
||||
close(errorChan)
|
||||
|
||||
// Check for any unexpected errors
|
||||
errorCount := 0
|
||||
for err := range errorChan {
|
||||
t.Logf("Concurrent operation error: %v", err)
|
||||
errorCount++
|
||||
}
|
||||
|
||||
// We expect some corruption-related "errors" due to intentional corruption,
|
||||
// but not format-related errors which would indicate actual corruption bugs
|
||||
if errorCount > numGoroutines*numOperations/4 { // Allow up to 25% corruption-related issues
|
||||
t.Errorf("Too many errors during concurrent operations: %d", errorCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenValidationEdgeCases tests edge cases in token validation
|
||||
func TestTokenValidationEdgeCases(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Use standardized test tokens
|
||||
testTokens := NewTestTokens()
|
||||
edgeCases := testTokens.TokenValidationTestCases()
|
||||
|
||||
for _, ec := range edgeCases {
|
||||
t.Run(ec.Name, func(t *testing.T) {
|
||||
// Clear any previous token
|
||||
session.SetAccessToken("")
|
||||
|
||||
// Store the test token
|
||||
originalToken := session.GetAccessToken()
|
||||
session.SetAccessToken(ec.Token)
|
||||
afterStoreToken := session.GetAccessToken()
|
||||
|
||||
if ec.ExpectStored {
|
||||
if afterStoreToken != ec.Token {
|
||||
t.Errorf("Expected token to be stored, but got different value")
|
||||
}
|
||||
} else {
|
||||
if afterStoreToken != originalToken {
|
||||
t.Errorf("Expected invalid token to be rejected, but it was stored")
|
||||
}
|
||||
}
|
||||
|
||||
// Test retrieval
|
||||
finalToken := session.GetAccessToken()
|
||||
if ec.ExpectRetrieved {
|
||||
if finalToken != ec.Token {
|
||||
t.Errorf("Expected token to be retrievable: %q, got: %q", ec.Token, finalToken)
|
||||
}
|
||||
} else {
|
||||
if finalToken != "" {
|
||||
t.Errorf("Expected empty token due to invalid format, got: %q", finalToken)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for test data creation
|
||||
|
||||
// createLargeValidJWT creates a JWT of approximately the specified size
|
||||
func createLargeValidJWT(targetSize int) string {
|
||||
header := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
signature := "signature_" + generateRandomString(32)
|
||||
|
||||
// Calculate required payload size
|
||||
usedSize := len(header) + len(signature) + 2 // account for dots
|
||||
payloadSize := max(targetSize-usedSize, 50)
|
||||
|
||||
// Create a payload with realistic JWT claims
|
||||
claims := map[string]any{
|
||||
"sub": "user123",
|
||||
"iss": "https://example.com",
|
||||
"aud": "client123",
|
||||
"exp": 9999999999,
|
||||
"iat": 1000000000,
|
||||
"data": generateRandomString(payloadSize - 100), // Account for other claims
|
||||
}
|
||||
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
|
||||
// createLargeRefreshToken creates a refresh token of approximately the specified size
|
||||
func createLargeRefreshToken(targetSize int) string {
|
||||
baseToken := "refresh_token_"
|
||||
padding := generateRandomString(targetSize - len(baseToken))
|
||||
return baseToken + padding
|
||||
}
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
|
||||
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
|
||||
+1
-5
@@ -1,7 +1,4 @@
|
||||
# Gorilla Sessions
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
|
||||
# sessions
|
||||
|
||||

|
||||
[](https://codecov.io/github/gorilla/sessions)
|
||||
@@ -77,7 +74,6 @@ Other implementations of the `sessions.Store` interface:
|
||||
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
|
||||
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
|
||||
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
|
||||
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
|
||||
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
|
||||
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
|
||||
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
|
||||
|
||||
+9
-12
@@ -1,6 +1,5 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
//go:build !go1.11
|
||||
// +build !go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
@@ -9,15 +8,13 @@ import "net/http"
|
||||
// newCookieFromOptions returns an http.Cookie with the options set.
|
||||
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: options.Path,
|
||||
Domain: options.Domain,
|
||||
MaxAge: options.MaxAge,
|
||||
Secure: options.Secure,
|
||||
HttpOnly: options.HttpOnly,
|
||||
Partitioned: options.Partitioned,
|
||||
SameSite: options.SameSite,
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: options.Path,
|
||||
Domain: options.Domain,
|
||||
MaxAge: options.MaxAge,
|
||||
Secure: options.Secure,
|
||||
HttpOnly: options.HttpOnly,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
//go:build go1.11
|
||||
// +build go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
import "net/http"
|
||||
|
||||
// newCookieFromOptions returns an http.Cookie with the options set.
|
||||
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: options.Path,
|
||||
Domain: options.Domain,
|
||||
MaxAge: options.MaxAge,
|
||||
Secure: options.Secure,
|
||||
HttpOnly: options.HttpOnly,
|
||||
SameSite: options.SameSite,
|
||||
}
|
||||
|
||||
}
|
||||
+5
-10
@@ -1,11 +1,8 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
//go:build !go1.11
|
||||
// +build !go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Options stores configuration for a session or session store.
|
||||
//
|
||||
// Fields are a subset of http.Cookie fields.
|
||||
@@ -16,9 +13,7 @@ type Options struct {
|
||||
// deleted after the browser session ends.
|
||||
// MaxAge<0 means delete cookie immediately.
|
||||
// MaxAge>0 means Max-Age attribute present and given in seconds.
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
Partitioned bool
|
||||
SameSite http.SameSite
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
}
|
||||
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
//go:build go1.11
|
||||
// +build go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Options stores configuration for a session or session store.
|
||||
//
|
||||
// Fields are a subset of http.Cookie fields.
|
||||
type Options struct {
|
||||
Path string
|
||||
Domain string
|
||||
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
|
||||
// deleted after the browser session ends.
|
||||
// MaxAge<0 means delete cookie immediately.
|
||||
// MaxAge>0 means Max-Age attribute present and given in seconds.
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
// Defaults to http.SameSiteDefaultMode
|
||||
SameSite http.SameSite
|
||||
}
|
||||
Vendored
+2
-2
@@ -4,8 +4,8 @@ github.com/google/uuid
|
||||
# github.com/gorilla/securecookie v1.1.2
|
||||
## explicit; go 1.20
|
||||
github.com/gorilla/securecookie
|
||||
# github.com/gorilla/sessions v1.4.0
|
||||
## explicit; go 1.23
|
||||
# github.com/gorilla/sessions v1.3.0
|
||||
## explicit; go 1.20
|
||||
github.com/gorilla/sessions
|
||||
# golang.org/x/time v0.7.0
|
||||
## explicit; go 1.18
|
||||
|
||||
Reference in New Issue
Block a user