mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4ce2815123 | |||
| 7d204113ea | |||
| c721913cbe | |||
| 0f8b7f7ab1 | |||
| 2743b0e024 | |||
| e6fc36937b | |||
| df051e0cfb | |||
| 5d5ce8ae5e | |||
| d194cd778a | |||
| 803a1e5e21 | |||
| 3ad8fb4518 | |||
| 9402f1bca5 | |||
| e6205b3a48 | |||
| fdb8e3233e | |||
| 33c71fd6fe | |||
| 241cb1c209 | |||
| 09daa1025c | |||
| c09e7a9228 | |||
| e5da5d4fe9 | |||
| 31db701dda | |||
| 16481afd36 | |||
| 751933ffa0 | |||
| e74153b107 | |||
| 025107fe3e | |||
| dfb9c0771e | |||
| 1107df40e7 | |||
| bf294569eb | |||
| 482c346840 | |||
| a462e44896 | |||
| 5eff0dc866 | |||
| dfc534a400 | |||
| 061c12d0a3 | |||
| 4c4fff3613 | |||
| 0dcb44c187 | |||
| cbe773d96a | |||
| 40254888d7 | |||
| ef41870c81 | |||
| 081c32925a | |||
| 17dea67229 | |||
| 8512ad6d68 | |||
| 5aa838c669 | |||
| 6f359e5ef1 | |||
| bd18d6041c | |||
| 74c620ad51 | |||
| 7e3dc46b6e | |||
| 147aa0b169 | |||
| eecb7dfc92 | |||
| a8d65688c4 | |||
| bef4212c57 | |||
| 1fee2f9e9a | |||
| 11bc6f3e31 |
+219
-18
@@ -4,28 +4,229 @@ type: middleware
|
||||
import: github.com/lukaszraczylo/traefikoidc
|
||||
|
||||
summary: |
|
||||
Middleware adding OIDC authentication to traefik routes. Does what it says on the tin.
|
||||
Middleware has been tested with Auth0 and Logto. It should work with any OIDC provider.
|
||||
Middleware adding OpenID Connect (OIDC) authentication to Traefik routes.
|
||||
|
||||
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy.
|
||||
It provides a complete OIDC authentication solution with features like domain restrictions,
|
||||
role-based access control, token caching, and more.
|
||||
|
||||
The middleware has been tested with Auth0, Logto, Google, and other standard OIDC providers.
|
||||
It supports various authentication scenarios including:
|
||||
|
||||
- Basic authentication with customizable callback and logout URLs
|
||||
- Email domain restrictions to limit access to specific organizations
|
||||
- Role and group-based access control
|
||||
- Public URLs that bypass authentication
|
||||
- Rate limiting to prevent brute force attacks
|
||||
- Custom post-logout redirect behavior
|
||||
- Secure session management with encrypted cookies
|
||||
- Automatic token validation and refresh
|
||||
|
||||
testData:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
# Required parameters
|
||||
providerURL: https://accounts.google.com # Base URL of the OIDC provider
|
||||
clientID: 1234567890.apps.googleusercontent.com # OAuth 2.0 client identifier
|
||||
clientSecret: secret # OAuth 2.0 client secret
|
||||
callbackURL: /oauth2/callback # Path where the OIDC provider will redirect after authentication
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long # Key used to encrypt session data (must be at least 32 bytes)
|
||||
|
||||
# Optional parameters with defaults
|
||||
logoutURL: /oauth2/logout # Path for handling logout requests (if not provided, it will be set to callbackURL + "/logout")
|
||||
postLogoutRedirectURI: /oidc/different-logout # URL to redirect to after logout (default: "/")
|
||||
|
||||
scopes: # OAuth 2.0 scopes to request (default: ["openid", "email", "profile"])
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
|
||||
- raczylo.com
|
||||
allowedRolesAndGroups:
|
||||
- roles # Include this to get role information from the provider
|
||||
|
||||
allowedUserDomains: # Restricts access to specific email domains (if not provided, relies on OIDC provider)
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
|
||||
allowedRolesAndGroups: # Restricts access to users with specific roles or groups (if not provided, no role/group restrictions)
|
||||
- guest-endpoints
|
||||
sessionEncryptionKey: potato-secret
|
||||
forceHTTPS: false
|
||||
logLevel: debug # debug, info, warn, error
|
||||
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- admin
|
||||
- developer
|
||||
|
||||
forceHTTPS: false # Forces the use of HTTPS for all URLs (default: true for security)
|
||||
logLevel: debug # Sets logging verbosity: debug, info, error (default: info)
|
||||
rateLimit: 100 # Maximum number of requests per second (default: 100, minimum: 10)
|
||||
|
||||
excludedURLs: # Lists paths that bypass authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
- /public
|
||||
- /health
|
||||
- /metrics
|
||||
|
||||
# Advanced parameters (usually discovered automatically from provider metadata)
|
||||
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
|
||||
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
||||
|
||||
# Configuration documentation
|
||||
configuration:
|
||||
providerURL:
|
||||
type: string
|
||||
description: |
|
||||
The base URL of the OIDC provider. This is the issuer URL that will be used to discover
|
||||
OIDC endpoints like authorization, token, and JWKS URIs.
|
||||
|
||||
Examples:
|
||||
- https://accounts.google.com
|
||||
- https://login.microsoftonline.com/tenant-id/v2.0
|
||||
- https://your-auth0-domain.auth0.com
|
||||
- https://your-logto-instance.com/oidc
|
||||
required: true
|
||||
|
||||
clientID:
|
||||
type: string
|
||||
description: |
|
||||
The OAuth 2.0 client identifier obtained from your OIDC provider.
|
||||
This is the public identifier for your application.
|
||||
required: true
|
||||
|
||||
clientSecret:
|
||||
type: string
|
||||
description: |
|
||||
The OAuth 2.0 client secret obtained from your OIDC provider.
|
||||
This should be kept confidential and not exposed in client-side code.
|
||||
|
||||
For Kubernetes deployments, you can use the secret reference format:
|
||||
urn:k8s:secret:namespace:secret-name:key
|
||||
required: true
|
||||
|
||||
callbackURL:
|
||||
type: string
|
||||
description: |
|
||||
The path where the OIDC provider will redirect after authentication.
|
||||
This must match one of the redirect URIs configured in your OIDC provider.
|
||||
|
||||
The full redirect URI will be constructed as:
|
||||
[scheme]://[host][callbackURL]
|
||||
|
||||
Example: /oauth2/callback
|
||||
required: true
|
||||
|
||||
sessionEncryptionKey:
|
||||
type: string
|
||||
description: |
|
||||
Key used to encrypt session data stored in cookies.
|
||||
Must be at least 32 bytes long for security.
|
||||
|
||||
Example: potato-secret-is-at-least-32-bytes-long
|
||||
required: true
|
||||
|
||||
logoutURL:
|
||||
type: string
|
||||
description: |
|
||||
The path for handling logout requests.
|
||||
If not provided, it will be set to callbackURL + "/logout".
|
||||
|
||||
Example: /oauth2/logout
|
||||
required: false
|
||||
|
||||
postLogoutRedirectURI:
|
||||
type: string
|
||||
description: |
|
||||
The URL to redirect to after logout.
|
||||
Default: "/"
|
||||
|
||||
Example: /logged-out-page
|
||||
required: false
|
||||
|
||||
scopes:
|
||||
type: array
|
||||
description: |
|
||||
The OAuth 2.0 scopes to request from the OIDC provider.
|
||||
Default: ["openid", "profile", "email"]
|
||||
|
||||
Include "roles" or similar scope if you need role/group information.
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
logLevel:
|
||||
type: string
|
||||
description: |
|
||||
Sets the logging verbosity.
|
||||
Valid values: "debug", "info", "error"
|
||||
Default: "info"
|
||||
required: false
|
||||
enum:
|
||||
- debug
|
||||
- info
|
||||
- error
|
||||
|
||||
forceHTTPS:
|
||||
type: boolean
|
||||
description: |
|
||||
Forces the use of HTTPS for all URLs.
|
||||
This is recommended for security in production environments.
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
rateLimit:
|
||||
type: integer
|
||||
description: |
|
||||
Sets the maximum number of requests per second.
|
||||
This helps prevent brute force attacks.
|
||||
Default: 100
|
||||
Minimum: 10
|
||||
required: false
|
||||
|
||||
excludedURLs:
|
||||
type: array
|
||||
description: |
|
||||
Lists paths that bypass authentication.
|
||||
These paths will be accessible without OIDC authentication.
|
||||
|
||||
The middleware uses prefix matching, so "/public" will match
|
||||
"/public", "/public/page", "/public-data", etc.
|
||||
|
||||
Examples: ["/health", "/metrics", "/public"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
allowedUserDomains:
|
||||
type: array
|
||||
description: |
|
||||
Restricts access to users with email addresses from specific domains.
|
||||
If not provided, the middleware relies entirely on the OIDC provider
|
||||
for authentication decisions.
|
||||
|
||||
Examples: ["company.com", "subsidiary.com"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
allowedRolesAndGroups:
|
||||
type: array
|
||||
description: |
|
||||
Restricts access to users with specific roles or groups.
|
||||
If not provided, no role/group restrictions are applied.
|
||||
|
||||
The middleware checks both the "roles" and "groups" claims in the ID token.
|
||||
|
||||
Examples: ["admin", "developer"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
revocationURL:
|
||||
type: string
|
||||
description: |
|
||||
The endpoint for revoking tokens.
|
||||
If not provided, it will be discovered from provider metadata.
|
||||
|
||||
Example: https://accounts.google.com/revoke
|
||||
required: false
|
||||
|
||||
oidcEndSessionURL:
|
||||
type: string
|
||||
description: |
|
||||
The provider's end session endpoint.
|
||||
If not provided, it will be discovered from provider metadata.
|
||||
|
||||
Example: https://accounts.google.com/logout
|
||||
required: false
|
||||
|
||||
@@ -1,151 +1,283 @@
|
||||
## Traefik OIDC middleware
|
||||
# Traefik OIDC Middleware
|
||||
|
||||
This middleware is supposed to replace the need for the forward-auth and oauth2-proxy when using traefik as a reverse proxy to support the OIDC authentication.
|
||||
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy to support OpenID Connect (OIDC) authentication.
|
||||
|
||||
Middleware has been tested with Auth0 and Logto.
|
||||
## Overview
|
||||
|
||||
### Traefik version compatibility
|
||||
The Traefik OIDC middleware provides a complete OIDC authentication solution with features like:
|
||||
- Token validation and verification
|
||||
- Session management
|
||||
- Domain restrictions
|
||||
- Role-based access control
|
||||
- Token caching and blacklisting
|
||||
- Rate limiting
|
||||
- Excluded paths (public URLs)
|
||||
|
||||
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
|
||||
The middleware has been tested with Auth0 and Logto, but should work with any standard OIDC provider.
|
||||
|
||||
### Configuration options
|
||||
## Traefik Version Compatibility
|
||||
|
||||
Middleware currently supports following scenarios:
|
||||
This middleware follows closely the current Traefik helm chart versions. If the plugin fails to load, it's time to update to the latest version of the Traefik helm chart.
|
||||
|
||||
* Setting custom callback and logout URLs via `callbackURL` and `logoutURL`
|
||||
* Allowing for access only from the listed domains if `allowedUserDomains` is set, otherwise it relies entirely on the OIDC provider
|
||||
* Using excluded URLs which do **NOT** require the OIDC authentication
|
||||
* Rate limiting requests to prevent the bruteforce attacks
|
||||
## Installation
|
||||
|
||||
#### How to configure...
|
||||
### As a Traefik Plugin
|
||||
|
||||
##### Keeping secrets secret
|
||||
|
||||
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
|
||||
1. Enable the plugin in your Traefik static configuration:
|
||||
|
||||
```yaml
|
||||
# traefik.yml
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.2.1 # Use the latest version
|
||||
```
|
||||
|
||||
2. Configure the middleware in your dynamic configuration (see examples below).
|
||||
|
||||
### Local Development with Docker Compose
|
||||
|
||||
For local development or testing, you can use the provided Docker Compose setup:
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
This will start Traefik with the OIDC middleware and two test services.
|
||||
|
||||
## Configuration Options
|
||||
|
||||
The middleware supports the following configuration options:
|
||||
|
||||
### Required Parameters
|
||||
|
||||
| Parameter | Description | Example |
|
||||
|-----------|-------------|---------|
|
||||
| `providerURL` | The base URL of the OIDC provider | `https://accounts.google.com` |
|
||||
| `clientID` | The OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
|
||||
| `clientSecret` | The OAuth 2.0 client secret | `your-client-secret` |
|
||||
| `sessionEncryptionKey` | Key used to encrypt session data (must be at least 32 bytes long) | `potato-secret-is-at-least-32-bytes-long` |
|
||||
| `callbackURL` | The path where the OIDC provider will redirect after authentication | `/oauth2/callback` |
|
||||
|
||||
### Optional Parameters
|
||||
|
||||
| Parameter | Description | Default | Example |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `logoutURL` | The path for handling logout requests | `callbackURL + "/logout"` | `/oauth2/logout` |
|
||||
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
|
||||
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-basic
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
```
|
||||
|
||||
### With Excluded URLs (Public Access Paths)
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-open-urls
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs:
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /public-data
|
||||
- /health
|
||||
- /metrics
|
||||
```
|
||||
|
||||
### With Email Domain Restrictions
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-domain-restricted
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
```
|
||||
|
||||
### With Role-Based Access Control
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-rbac
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Include this to get role information from the provider
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
```
|
||||
|
||||
### With Custom Logging and Rate Limiting
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-custom-settings
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
logLevel: debug # Options: debug, info, error (default: info)
|
||||
rateLimit: 500 # Requests per second (default: 100)
|
||||
forceHTTPS: false # Default is true for security
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
```
|
||||
|
||||
### With Custom Post-Logout Redirect
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-custom-logout
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /logged-out-page # Where to redirect after logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
```
|
||||
|
||||
### Keeping Secrets Secret in Kubernetes
|
||||
|
||||
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-secrets
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
|
||||
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
|
||||
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
```
|
||||
|
||||
##### Excluded URLs with open access
|
||||
Don't forget to create the secret:
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-open-urls
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
```bash
|
||||
kubectl create secret generic traefik-middleware-oidc \
|
||||
--from-literal=ISSUER=https://accounts.google.com \
|
||||
--from-literal=CLIENT_ID=1234567890.apps.googleusercontent.com \
|
||||
--from-literal=SECRET=your-client-secret \
|
||||
-n traefik
|
||||
```
|
||||
|
||||
## Complete Docker Compose Example
|
||||
|
||||
##### Allowed email domains
|
||||
|
||||
Assuming that your OIDC provider allows anyone to log in, you may want to limit the access to people using emains in specific domain.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-only-my-users
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /new-oidc/callback
|
||||
logoutURL: /new-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains:
|
||||
- raczylo.com
|
||||
```
|
||||
|
||||
|
||||
##### Allowed groups and roles
|
||||
|
||||
In case of multiple roles / groups and access separation for various endpoints you will need to create multiple traefik middlewares.
|
||||
Following example allows access for users who have additional role `guest-endpoints` assigned.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-guest-endpoints
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /my-oidc/callback
|
||||
logoutURL: /my-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # This line queries the OIDC provider for roles
|
||||
forceHTTPS: true
|
||||
allowedRolesAndGroups:
|
||||
- guest-endpoints # This line specifies the roles or groups allowed to access content
|
||||
allowedUserDomains:
|
||||
- raczylo.com
|
||||
```
|
||||
|
||||
|
||||
#### Docker compose example
|
||||
|
||||
`docker-compose.yaml`
|
||||
Here's a complete example of using the middleware with Docker Compose:
|
||||
|
||||
```yaml
|
||||
version: "3.7"
|
||||
|
||||
services:
|
||||
traefik:
|
||||
image: traefik:v3.0.1
|
||||
image: traefik:v3.2.1
|
||||
command:
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.2.1"
|
||||
@@ -156,7 +288,6 @@ services:
|
||||
labels:
|
||||
- "traefik.http.routers.dash.rule=Host(`dash.localhost`)"
|
||||
- "traefik.http.routers.dash.service=api@internal"
|
||||
|
||||
ports:
|
||||
- "80:80"
|
||||
|
||||
@@ -179,8 +310,7 @@ services:
|
||||
- traefik.http.routers.whoami.middlewares=my-plugin@file
|
||||
```
|
||||
|
||||
`traefik-config/traefik.yaml`
|
||||
|
||||
`traefik-config/traefik.yml`:
|
||||
```yaml
|
||||
log:
|
||||
level: INFO
|
||||
@@ -209,7 +339,7 @@ providers:
|
||||
filename: /etc/traefik/dynamic-configuration.yml
|
||||
```
|
||||
|
||||
`traefik-config/dynamic-configuration.yaml`
|
||||
`traefik-config/dynamic-configuration.yml`:
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
@@ -218,20 +348,78 @@ http:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: secret
|
||||
clientSecret: your-client-secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
postLogoutRedirectURI: /logged-out-page
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
|
||||
- raczylo.com
|
||||
sessionEncryptionKey: potato-secret
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
forceHTTPS: false
|
||||
logLevel: debug # debug, info, warn, error
|
||||
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
logLevel: debug
|
||||
rateLimit: 100
|
||||
excludedURLs:
|
||||
- /login
|
||||
- /public
|
||||
- /health
|
||||
- /metrics
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Session Management
|
||||
|
||||
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
|
||||
|
||||
### Token Caching and Blacklisting
|
||||
|
||||
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
|
||||
|
||||
### Headers Set for Downstream Services
|
||||
|
||||
When a user is authenticated, the middleware sets the following headers for downstream services:
|
||||
|
||||
- `X-Forwarded-User`: The user's email address
|
||||
- `X-User-Groups`: Comma-separated list of user groups (if available)
|
||||
- `X-User-Roles`: Comma-separated list of user roles (if available)
|
||||
- `X-Auth-Request-Redirect`: The original request URI
|
||||
- `X-Auth-Request-User`: The user's email address
|
||||
- `X-Auth-Request-Token`: The user's access token
|
||||
|
||||
### Security Headers
|
||||
|
||||
The middleware also sets the following security headers:
|
||||
|
||||
- `X-Frame-Options: DENY`
|
||||
- `X-Content-Type-Options: nosniff`
|
||||
- `X-XSS-Protection: 1; mode=block`
|
||||
- `Referrer-Policy: strict-origin-when-cross-origin`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Logging
|
||||
|
||||
Set the `logLevel` to `debug` to get more detailed logs:
|
||||
|
||||
```yaml
|
||||
logLevel: debug
|
||||
```
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Token verification failed**: Check that your `providerURL` is correct and accessible.
|
||||
2. **Session encryption key too short**: Ensure your `sessionEncryptionKey` is at least 32 bytes long.
|
||||
3. **No matching public key found**: The JWKS endpoint might be unavailable or the token's key ID (kid) doesn't match any key in the JWKS.
|
||||
4. **Access denied: Your email domain is not allowed**: The user's email domain is not in the `allowedUserDomains` list.
|
||||
5. **Access denied: You do not have any of the allowed roles or groups**: The user doesn't have any of the roles or groups specified in `allowedRolesAndGroups`.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
### TODO / wishlist
|
||||
|
||||
- [] Improve test coverage
|
||||
- [x] Improve caching mechanism
|
||||
- [x] Add automatic release and semver generation
|
||||
@@ -0,0 +1,18 @@
|
||||
package traefikoidc
|
||||
|
||||
import "time"
|
||||
|
||||
// autoCleanupRoutine runs a ticker that calls the provided cleanup function at the specified interval.
|
||||
// It stops when a value is received on the stop channel.
|
||||
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cleanup()
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAutoCleanupRoutine(t *testing.T) {
|
||||
var counter int32
|
||||
cleanupFunc := func() {
|
||||
atomic.AddInt32(&counter, 1)
|
||||
}
|
||||
stop := make(chan struct{})
|
||||
go autoCleanupRoutine(50*time.Millisecond, stop, cleanupFunc)
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
close(stop)
|
||||
|
||||
if atomic.LoadInt32(&counter) < 3 {
|
||||
t.Errorf("Expected cleanup to be called at least 3 times, got %d", counter)
|
||||
}
|
||||
}
|
||||
+110
@@ -0,0 +1,110 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenBlacklist manages a thread-safe list of revoked tokens with expiration.
|
||||
type TokenBlacklist struct {
|
||||
tokens map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new token blacklist instance.
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
tokens: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist with an expiration time.
|
||||
func (b *TokenBlacklist) Add(token string, expiry time.Time) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
// Clean up expired tokens if we're at capacity
|
||||
if len(b.tokens) >= 1000 {
|
||||
now := time.Now()
|
||||
futureThreshold := now.Add(time.Minute)
|
||||
for t, exp := range b.tokens {
|
||||
if now.After(exp) || futureThreshold.After(exp) {
|
||||
delete(b.tokens, t)
|
||||
}
|
||||
}
|
||||
|
||||
// If still at capacity, remove oldest token
|
||||
if len(b.tokens) >= 1000 {
|
||||
var oldestToken string
|
||||
var oldestTime time.Time
|
||||
first := true
|
||||
for t, exp := range b.tokens {
|
||||
if first || exp.Before(oldestTime) {
|
||||
oldestToken = t
|
||||
oldestTime = exp
|
||||
first = false
|
||||
}
|
||||
}
|
||||
if oldestToken != "" {
|
||||
delete(b.tokens, oldestToken)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.tokens[token] = expiry
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is in the blacklist and not expired.
|
||||
func (b *TokenBlacklist) IsBlacklisted(token string) bool {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
|
||||
expiry, exists := b.tokens[token]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// If token is expired, remove it and return false
|
||||
if time.Now().After(expiry) {
|
||||
// Switch to write lock to remove expired token
|
||||
b.mutex.RUnlock()
|
||||
b.mutex.Lock()
|
||||
delete(b.tokens, token)
|
||||
b.mutex.Unlock()
|
||||
b.mutex.RLock()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist.
|
||||
// Also removes tokens that will expire within the next minute to prevent edge cases.
|
||||
func (b *TokenBlacklist) Cleanup() {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
futureThreshold := now.Add(time.Minute)
|
||||
|
||||
for token, expiry := range b.tokens {
|
||||
// Remove tokens that are expired or will expire soon
|
||||
if now.After(expiry) || futureThreshold.After(expiry) {
|
||||
delete(b.tokens, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove removes a token from the blacklist regardless of its expiration.
|
||||
func (b *TokenBlacklist) Remove(token string) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
delete(b.tokens, token)
|
||||
}
|
||||
|
||||
// Count returns the current number of tokens in the blacklist.
|
||||
func (b *TokenBlacklist) Count() int {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
return len(b.tokens)
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTokenBlacklist_Add(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
|
||||
if !blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be blacklisted, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_IsBlacklisted(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
|
||||
if !blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be blacklisted, but it was not")
|
||||
}
|
||||
|
||||
if blacklist.IsBlacklisted("nonExistentToken") {
|
||||
t.Errorf("Expected non-existent token to not be blacklisted, but it was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Cleanup(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(-time.Hour) // Expired token
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
blacklist.Cleanup()
|
||||
|
||||
if blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected expired token to be removed after cleanup, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Remove(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
blacklist.Remove(token)
|
||||
|
||||
if blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be removed, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Count(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token1 := "token1"
|
||||
token2 := "token2"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token1, expiry)
|
||||
blacklist.Add(token2, expiry)
|
||||
|
||||
if blacklist.Count() != 2 {
|
||||
t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count())
|
||||
}
|
||||
}
|
||||
@@ -1,69 +1,187 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheItem represents an item in the cache
|
||||
// CacheItem represents an item stored in the cache with its associated metadata.
|
||||
type CacheItem struct {
|
||||
Value interface{}
|
||||
// Value is the cached data of any type.
|
||||
Value interface{}
|
||||
|
||||
// ExpiresAt is the timestamp when this item should be considered expired.
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// Cache is a simple in-memory cache
|
||||
// lruEntry represents an entry in the LRU list.
|
||||
type lruEntry struct {
|
||||
key string
|
||||
}
|
||||
|
||||
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
|
||||
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
|
||||
type Cache struct {
|
||||
// items stores the cached data with string keys.
|
||||
items map[string]CacheItem
|
||||
|
||||
// order maintains the usage order; most recently used items are at the back.
|
||||
order *list.List
|
||||
|
||||
// elems maps keys to their corresponding list elements for O(1) access.
|
||||
elems map[string]*list.Element
|
||||
|
||||
// mutex protects concurrent access to the cache.
|
||||
mutex sync.RWMutex
|
||||
|
||||
// maxSize is the maximum number of items allowed in the cache.
|
||||
maxSize int
|
||||
// autoCleanupInterval defines how often Cleanup is called automatically.
|
||||
autoCleanupInterval time.Duration
|
||||
// stopCleanup channel to terminate the auto cleanup goroutine.
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
// NewCache creates a new Cache
|
||||
// DefaultMaxSize is the default maximum number of items in the cache.
|
||||
const DefaultMaxSize = 500
|
||||
|
||||
// NewCache creates a new empty cache instance that is ready for use.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
items: make(map[string]CacheItem),
|
||||
c := &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
order: list.New(),
|
||||
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||
maxSize: DefaultMaxSize,
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
go c.startAutoCleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
// Set adds an item to the cache
|
||||
// Set adds or updates an item in the cache with the specified expiration duration.
|
||||
// It moves the item to the most recently used position.
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
expTime := now.Add(expiration)
|
||||
|
||||
// Update existing item.
|
||||
if _, exists := c.items[key]; exists {
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: expTime,
|
||||
}
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
c.order.MoveToBack(elem)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest item if cache is full.
|
||||
if len(c.items) >= c.maxSize {
|
||||
c.evictOldest()
|
||||
}
|
||||
|
||||
// Add new item.
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: time.Now().Add(expiration),
|
||||
ExpiresAt: expTime,
|
||||
}
|
||||
elem := c.order.PushBack(lruEntry{key: key})
|
||||
c.elems[key] = elem
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache
|
||||
// Get retrieves an item from the cache if it exists and hasn't expired.
|
||||
// Moving the accessed item to the most recently used position.
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
item, found := c.items[key]
|
||||
if !found {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check for expiration.
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
delete(c.items, key)
|
||||
c.removeItem(key)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Move item to the back (most recently used).
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
c.order.MoveToBack(elem)
|
||||
}
|
||||
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache
|
||||
// Delete removes an item from the cache.
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.items, key)
|
||||
|
||||
c.removeItem(key)
|
||||
}
|
||||
|
||||
// Cleanup removes expired items from the cache
|
||||
// Cleanup removes all expired items from the cache. This should be called periodically
|
||||
// to prevent memory bloat from expired entries.
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, item := range c.items {
|
||||
if now.After(item.ExpiresAt) {
|
||||
delete(c.items, key)
|
||||
// Remove items that are expired or within 10% of expiration
|
||||
if now.After(item.ExpiresAt) || now.Add(time.Duration(float64(item.ExpiresAt.Sub(now))*0.1)).After(item.ExpiresAt) {
|
||||
c.removeItem(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldest removes the least recently used item from the cache.
|
||||
func (c *Cache) evictOldest() {
|
||||
now := time.Now()
|
||||
elem := c.order.Front()
|
||||
|
||||
// First try to find an expired item from the front
|
||||
for elem != nil {
|
||||
entry := elem.Value.(lruEntry)
|
||||
if item, exists := c.items[entry.key]; exists {
|
||||
if now.After(item.ExpiresAt) {
|
||||
c.removeItem(entry.key)
|
||||
return
|
||||
}
|
||||
}
|
||||
elem = elem.Next()
|
||||
}
|
||||
|
||||
// If no expired items found, remove the oldest item
|
||||
if elem = c.order.Front(); elem != nil {
|
||||
entry := elem.Value.(lruEntry)
|
||||
c.removeItem(entry.key)
|
||||
}
|
||||
}
|
||||
|
||||
// removeItem removes an item from both the cache and the LRU tracking structures.
|
||||
func (c *Cache) removeItem(key string) {
|
||||
delete(c.items, key)
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
c.order.Remove(elem)
|
||||
delete(c.elems, key)
|
||||
}
|
||||
}
|
||||
|
||||
// startAutoCleanup initiates a goroutine that periodically cleans up expired cache items.
|
||||
func (c *Cache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
// Close terminates the auto cleanup goroutine.
|
||||
func (c *Cache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
|
||||
+306
@@ -0,0 +1,306 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
t.Run("Basic Set and Get", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Test Set
|
||||
cache.Set(key, value, expiration)
|
||||
|
||||
// Test Get
|
||||
got, found := cache.Get(key)
|
||||
if !found {
|
||||
t.Error("Expected to find key in cache")
|
||||
}
|
||||
if got != value {
|
||||
t.Errorf("Expected value %v, got %v", value, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Expiration", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
expiration := 10 * time.Millisecond
|
||||
|
||||
// Set with short expiration
|
||||
cache.Set(key, value, expiration)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Should not find expired key
|
||||
_, found := cache.Get(key)
|
||||
if found {
|
||||
t.Error("Expected key to be expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Set and then delete
|
||||
cache.Set(key, value, expiration)
|
||||
cache.Delete(key)
|
||||
|
||||
// Should not find deleted key
|
||||
_, found := cache.Get(key)
|
||||
if found {
|
||||
t.Error("Expected key to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Cleanup", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
// Add multiple items with different expirations
|
||||
cache.Set("expired1", "value1", 10*time.Millisecond)
|
||||
cache.Set("expired2", "value2", 10*time.Millisecond)
|
||||
cache.Set("valid", "value3", 1*time.Second)
|
||||
|
||||
// Wait for some items to expire
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Run cleanup
|
||||
cache.Cleanup()
|
||||
|
||||
// Check expired items are removed
|
||||
_, found1 := cache.Get("expired1")
|
||||
_, found2 := cache.Get("expired2")
|
||||
_, found3 := cache.Get("valid")
|
||||
|
||||
if found1 {
|
||||
t.Error("Expected expired1 to be cleaned up")
|
||||
}
|
||||
if found2 {
|
||||
t.Error("Expected expired2 to be cleaned up")
|
||||
}
|
||||
if !found3 {
|
||||
t.Error("Expected valid item to remain in cache")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Concurrent Access", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
done := make(chan bool)
|
||||
|
||||
// Start multiple goroutines to access cache concurrently
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
key := "key"
|
||||
value := "value"
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Perform multiple operations
|
||||
cache.Set(key, value, expiration)
|
||||
cache.Get(key)
|
||||
cache.Delete(key)
|
||||
cache.Cleanup()
|
||||
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Zero Expiration", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
|
||||
// Set with zero expiration
|
||||
cache.Set(key, value, 0)
|
||||
|
||||
// Should not find the key
|
||||
_, found := cache.Get(key)
|
||||
if found {
|
||||
t.Error("Expected key with zero expiration to be immediately expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Negative Expiration", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
|
||||
// Set with negative expiration
|
||||
cache.Set(key, value, -1*time.Second)
|
||||
|
||||
// Should not find the key
|
||||
_, found := cache.Get(key)
|
||||
if found {
|
||||
t.Error("Expected key with negative expiration to be immediately expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Update Existing Key", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
key := "test-key"
|
||||
value1 := "value1"
|
||||
value2 := "value2"
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Set initial value
|
||||
cache.Set(key, value1, expiration)
|
||||
|
||||
// Update value
|
||||
cache.Set(key, value2, expiration)
|
||||
|
||||
// Check updated value
|
||||
got, found := cache.Get(key)
|
||||
if !found {
|
||||
t.Error("Expected to find key in cache")
|
||||
}
|
||||
if got != value2 {
|
||||
t.Errorf("Expected updated value %v, got %v", value2, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Different Value Types", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Test with different value types
|
||||
testCases := []struct {
|
||||
key string
|
||||
value interface{}
|
||||
}{
|
||||
{"string", "test"},
|
||||
{"int", 42},
|
||||
{"float", 3.14},
|
||||
{"bool", true},
|
||||
{"slice", []string{"a", "b", "c"}},
|
||||
{"map", map[string]int{"a": 1, "b": 2}},
|
||||
{"struct", struct{ Name string }{"test"}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.key, func(t *testing.T) {
|
||||
cache.Set(tc.key, tc.value, expiration)
|
||||
got, found := cache.Get(tc.key)
|
||||
if !found {
|
||||
t.Error("Expected to find key in cache")
|
||||
}
|
||||
// Use reflect.DeepEqual for comparing complex types like slices and maps
|
||||
if !reflect.DeepEqual(got, tc.value) {
|
||||
t.Errorf("Expected value %v, got %v", tc.value, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenCache(t *testing.T) {
|
||||
t.Run("Basic Operations", func(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
token := "test-token"
|
||||
claims := map[string]interface{}{
|
||||
"sub": "1234567890",
|
||||
"name": "John Doe",
|
||||
"admin": true,
|
||||
}
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Test Set and Get
|
||||
tc.Set(token, claims, expiration)
|
||||
gotClaims, found := tc.Get(token)
|
||||
if !found {
|
||||
t.Error("Expected to find token in cache")
|
||||
}
|
||||
if len(gotClaims) != len(claims) {
|
||||
t.Errorf("Expected %d claims, got %d", len(claims), len(gotClaims))
|
||||
}
|
||||
for k, v := range claims {
|
||||
if gotClaims[k] != v {
|
||||
t.Errorf("Expected claim %s to be %v, got %v", k, v, gotClaims[k])
|
||||
}
|
||||
}
|
||||
|
||||
// Test Delete
|
||||
tc.Delete(token)
|
||||
_, found = tc.Get(token)
|
||||
if found {
|
||||
t.Error("Expected token to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Expiration", func(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
token := "test-token"
|
||||
claims := map[string]interface{}{"sub": "1234567890"}
|
||||
expiration := 10 * time.Millisecond
|
||||
|
||||
// Set with short expiration
|
||||
tc.Set(token, claims, expiration)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Should not find expired token
|
||||
_, found := tc.Get(token)
|
||||
if found {
|
||||
t.Error("Expected token to be expired")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Cleanup", func(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
|
||||
// Add multiple tokens with different expirations
|
||||
tc.Set("expired1", map[string]interface{}{"sub": "1"}, 10*time.Millisecond)
|
||||
tc.Set("expired2", map[string]interface{}{"sub": "2"}, 10*time.Millisecond)
|
||||
tc.Set("valid", map[string]interface{}{"sub": "3"}, 1*time.Second)
|
||||
|
||||
// Wait for some tokens to expire
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Run cleanup
|
||||
tc.Cleanup()
|
||||
|
||||
// Check expired tokens are removed
|
||||
_, found1 := tc.Get("expired1")
|
||||
_, found2 := tc.Get("expired2")
|
||||
_, found3 := tc.Get("valid")
|
||||
|
||||
if found1 {
|
||||
t.Error("Expected expired1 to be cleaned up")
|
||||
}
|
||||
if found2 {
|
||||
t.Error("Expected expired2 to be cleaned up")
|
||||
}
|
||||
if !found3 {
|
||||
t.Error("Expected valid token to remain in cache")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Token Prefix", func(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
token := "test-token"
|
||||
claims := map[string]interface{}{"sub": "1234567890"}
|
||||
expiration := 1 * time.Second
|
||||
|
||||
// Set token
|
||||
tc.Set(token, claims, expiration)
|
||||
|
||||
// Verify internal storage uses prefix
|
||||
_, found := tc.cache.Get("t-" + token)
|
||||
if !found {
|
||||
t.Error("Expected to find prefixed token in underlying cache")
|
||||
}
|
||||
})
|
||||
}
|
||||
+93
-98
@@ -8,25 +8,16 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
func newSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// generateNonce generates a random nonce
|
||||
// generateNonce creates a cryptographically secure random nonce
|
||||
// for use in the OIDC authentication flow. The nonce is used to
|
||||
// prevent replay attacks by ensuring the token received matches
|
||||
// the authentication request.
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
@@ -36,7 +27,33 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// exchangeTokens exchanges a code or refresh token for tokens
|
||||
// TokenResponse represents the response from the OIDC token endpoint.
|
||||
// It contains the various tokens and metadata returned after successful
|
||||
// code exchange or token refresh operations.
|
||||
type TokenResponse struct {
|
||||
// IDToken is the OIDC ID token containing user claims
|
||||
IDToken string `json:"id_token"`
|
||||
|
||||
// AccessToken is the OAuth 2.0 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
|
||||
// RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
|
||||
// ExpiresIn is the lifetime in seconds of the access token
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// TokenType is the type of token, typically "Bearer"
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
|
||||
// It supports both authorization code and refresh token grant types.
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
|
||||
// - codeOrToken: Either the authorization code or refresh token
|
||||
// - redirectURL: The callback URL for authorization code grant
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
@@ -51,13 +68,28 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
data.Set("refresh_token", codeOrToken)
|
||||
}
|
||||
|
||||
// Create a cookie jar for this request to handle redirects with cookies
|
||||
jar, _ := cookiejar.New(nil)
|
||||
client := &http.Client{
|
||||
Transport: t.httpClient.Transport,
|
||||
Timeout: t.httpClient.Timeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Jar: jar,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := t.httpClient.Do(req)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
|
||||
}
|
||||
@@ -76,16 +108,8 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from the token endpoint
|
||||
type TokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// getNewTokenWithRefreshToken refreshes the token using the refresh token
|
||||
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
|
||||
// This is used to refresh access tokens before they expire.
|
||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
||||
@@ -94,24 +118,31 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
}
|
||||
|
||||
t.logger.Debugf("Token response: %+v", tokenResponse)
|
||||
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// handleExpiredToken handles the case when a token has expired
|
||||
// handleExpiredToken manages token expiration by clearing the session
|
||||
// and initiating a new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Clear the existing session
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to clear session: %v", err)
|
||||
// Clear authentication data but preserve CSRF state
|
||||
session.SetAuthenticated(false)
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
|
||||
// Save the cleared session state
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save cleared session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize new authentication
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback handles the callback from the OIDC provider
|
||||
// handleCallback processes the authentication callback from the OIDC provider.
|
||||
// It validates the callback parameters, exchanges the authorization code for
|
||||
// tokens, verifies the tokens, and establishes the user's session.
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -122,7 +153,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Check for errors in the query parameters
|
||||
// Check for errors in the callback
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
@@ -130,7 +161,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state parameter matches the session's CSRF token
|
||||
// Validate CSRF state
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
@@ -166,7 +197,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Verify and process tokens
|
||||
// Verify tokens and claims
|
||||
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
@@ -180,7 +211,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Verify nonce
|
||||
// Verify nonce to prevent replay attacks
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
@@ -201,7 +232,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Process email
|
||||
// Validate user's email domain
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
@@ -209,13 +240,12 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Update session with new values
|
||||
// Update session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
@@ -231,7 +261,8 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// extractClaims extracts claims from a JWT token
|
||||
// extractClaims parses a JWT token and extracts its claims.
|
||||
// It handles base64url decoding and JSON parsing of the token payload.
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -251,65 +282,29 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// TokenBlacklist maintains a blacklist of tokens
|
||||
type TokenBlacklist struct {
|
||||
blacklist map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new TokenBlacklist
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
blacklist: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist
|
||||
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
tb.blacklist[tokenID] = expiration
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is blacklisted
|
||||
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
expiration, exists := tb.blacklist[tokenID]
|
||||
return exists && time.Now().Before(expiration)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist
|
||||
func (tb *TokenBlacklist) Cleanup() {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for tokenID, expiration := range tb.blacklist {
|
||||
if now.After(expiration) {
|
||||
delete(tb.blacklist, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCache caches tokens
|
||||
// TokenCache provides a caching mechanism for validated tokens.
|
||||
// It stores token claims to avoid repeated validation of the
|
||||
// same token, improving performance for frequently used tokens.
|
||||
type TokenCache struct {
|
||||
// cache is the underlying cache implementation
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewTokenCache creates a new TokenCache
|
||||
// NewTokenCache creates a new TokenCache instance.
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
cache: NewCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a token in the cache
|
||||
// Set stores a token's claims in the cache with an expiration time.
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
}
|
||||
|
||||
// Get retrieves a token from the cache
|
||||
// Get retrieves a token's claims from the cache.
|
||||
// Returns the claims and a boolean indicating if the token was found.
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
@@ -320,18 +315,18 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// Delete removes a token from the cache
|
||||
// Delete removes a token from the cache.
|
||||
func (tc *TokenCache) Delete(token string) {
|
||||
token = "t-" + token
|
||||
tc.cache.Delete(token)
|
||||
}
|
||||
|
||||
// Cleanup cleans up expired tokens from the cache
|
||||
// Cleanup removes expired tokens from the cache.
|
||||
func (tc *TokenCache) Cleanup() {
|
||||
tc.cache.Cleanup()
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens
|
||||
// exchangeCodeForToken exchanges an authorization code for tokens.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
|
||||
@@ -341,7 +336,8 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*To
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// createStringMap creates a map from a slice of strings
|
||||
// createStringMap creates a map from a slice of strings.
|
||||
// Used for efficient lookups in allowed domains and roles.
|
||||
func createStringMap(keys []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
@@ -350,7 +346,9 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout handles the logout request
|
||||
// handleLogout manages the OIDC logout process.
|
||||
// It clears the session and redirects either to the OIDC provider's
|
||||
// end session endpoint (if available) or to the configured post-logout URL.
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -359,22 +357,18 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get the access token before clearing session
|
||||
accessToken := session.GetAccessToken()
|
||||
|
||||
// Clear all session data
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the base URL for redirects
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
// Determine post logout redirect URI
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
@@ -382,7 +376,6 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
// If we have an end session endpoint and an access token, use OIDC end session
|
||||
if t.endSessionURL != "" && accessToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
@@ -394,11 +387,14 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, redirect to post logout URI
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the OIDC end session URL
|
||||
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
|
||||
// Parameters:
|
||||
// - endSessionURL: The OIDC provider's end session endpoint
|
||||
// - idToken: The ID token to be invalidated
|
||||
// - postLogoutRedirectURI: Where to redirect after logout completes
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
@@ -408,7 +404,6 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
if postLogoutRedirectURI != "" {
|
||||
// Ensure postLogoutRedirectURI is properly URL encoded
|
||||
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
+227
@@ -0,0 +1,227 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTokenBlacklistSizeLimit(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
// Add tokens up to maxSize
|
||||
for i := 0; i < 1000; i++ {
|
||||
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
|
||||
}
|
||||
|
||||
// Verify size is at max
|
||||
if tb.Count() != 1000 {
|
||||
t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count())
|
||||
}
|
||||
|
||||
// Add one more token, should trigger cleanup/eviction
|
||||
tb.Add("newtoken", time.Now().Add(time.Hour))
|
||||
|
||||
// Size should still be at max
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
// Add some expired tokens
|
||||
for i := 0; i < 500; i++ {
|
||||
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
|
||||
}
|
||||
|
||||
// Add some valid tokens
|
||||
for i := 0; i < 500; i++ {
|
||||
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
|
||||
}
|
||||
|
||||
// Force cleanup
|
||||
tb.Cleanup()
|
||||
|
||||
// Only valid tokens should remain
|
||||
if tb.Count() != 500 {
|
||||
t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count())
|
||||
}
|
||||
|
||||
// Verify only valid tokens remain
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
for token, expiry := range tb.tokens {
|
||||
if time.Now().After(expiry) {
|
||||
t.Errorf("Found expired token after cleanup: %s", token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistOldestEviction(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
// Add tokens at capacity with different expiration times
|
||||
baseTime := time.Now()
|
||||
oldestToken := "oldest"
|
||||
|
||||
// Add oldest token first
|
||||
tb.Add(oldestToken, baseTime.Add(time.Hour))
|
||||
|
||||
// Fill up to capacity with newer tokens
|
||||
for i := 0; i < 999; i++ {
|
||||
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
|
||||
}
|
||||
|
||||
// Add a new token that should evict the oldest
|
||||
newToken := "newest"
|
||||
tb.Add(newToken, baseTime.Add(time.Hour*3))
|
||||
|
||||
// Verify oldest token was evicted
|
||||
if tb.IsBlacklisted(oldestToken) {
|
||||
t.Error("Oldest token should have been evicted")
|
||||
}
|
||||
|
||||
// Verify newest token is present
|
||||
if !tb.IsBlacklisted(newToken) {
|
||||
t.Error("Newest token should be present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistMemoryUsage(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
iterations := 10000
|
||||
|
||||
// Force initial GC
|
||||
runtime.GC()
|
||||
|
||||
// Record initial memory stats
|
||||
var m1, m2 runtime.MemStats
|
||||
runtime.ReadMemStats(&m1)
|
||||
|
||||
// Simulate heavy usage
|
||||
for i := 0; i < iterations; i++ {
|
||||
// Add new token
|
||||
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
|
||||
|
||||
// Periodically check blacklisted status
|
||||
if i%100 == 0 {
|
||||
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
|
||||
}
|
||||
|
||||
// Periodically cleanup
|
||||
if i%1000 == 0 {
|
||||
tb.Cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// Force GC and wait for it to complete
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
runtime.ReadMemStats(&m2)
|
||||
|
||||
// Check memory growth (using HeapAlloc for more accurate measurement)
|
||||
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
|
||||
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
|
||||
|
||||
if memoryGrowth > maxAllowedGrowth {
|
||||
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
|
||||
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
|
||||
}
|
||||
|
||||
// Verify size stayed within limits
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
iterations := 1000
|
||||
concurrency := 10
|
||||
done := make(chan bool)
|
||||
|
||||
// Start multiple goroutines performing operations
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(id int) {
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Add tokens
|
||||
token := fmt.Sprintf("token%d-%d", id, j)
|
||||
tb.Add(token, time.Now().Add(time.Hour))
|
||||
|
||||
// Check blacklist status
|
||||
tb.IsBlacklisted(token)
|
||||
|
||||
// Periodic cleanup
|
||||
if j%100 == 0 {
|
||||
tb.Cleanup()
|
||||
}
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify size constraints were maintained
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenCacheMemoryUsage(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
iterations := 10000
|
||||
|
||||
// Force initial GC
|
||||
runtime.GC()
|
||||
|
||||
// Record initial memory stats
|
||||
var m1, m2 runtime.MemStats
|
||||
runtime.ReadMemStats(&m1)
|
||||
|
||||
// Simulate heavy cache usage
|
||||
for i := 0; i < iterations; i++ {
|
||||
claims := map[string]interface{}{
|
||||
"sub": fmt.Sprintf("user%d", i),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
// Add to cache
|
||||
tc.Set(fmt.Sprintf("token%d", i), claims, time.Hour)
|
||||
|
||||
// Periodically retrieve
|
||||
if i%100 == 0 {
|
||||
tc.Get(fmt.Sprintf("token%d", i-50))
|
||||
}
|
||||
|
||||
// Periodically cleanup
|
||||
if i%1000 == 0 {
|
||||
tc.Cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// Force GC and wait for it to complete
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
runtime.ReadMemStats(&m2)
|
||||
|
||||
// Check memory growth (using HeapAlloc for more accurate measurement)
|
||||
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
|
||||
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
|
||||
|
||||
if memoryGrowth > maxAllowedGrowth {
|
||||
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
|
||||
t.Errorf("Excessive cache memory growth: %d bytes", memoryGrowth)
|
||||
}
|
||||
|
||||
// Verify cache size stayed within limits
|
||||
if len(tc.cache.items) > tc.cache.maxSize {
|
||||
t.Errorf("Cache exceeded max size: %d", len(tc.cache.items))
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,21 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWK represents a JSON Web Key
|
||||
type JWK struct {
|
||||
Kty string `json:"kty"`
|
||||
Kid string `json:"kid"`
|
||||
@@ -29,25 +28,24 @@ type JWK struct {
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
// JWKSet represents a set of JWKs
|
||||
type JWKSet struct {
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache caches the JWKs
|
||||
type JWKCache struct {
|
||||
jwks *JWKSet
|
||||
expiresAt time.Time
|
||||
mutex sync.RWMutex
|
||||
// CacheLifetime is configurable to determine how long the JWKS is cached.
|
||||
CacheLifetime time.Duration
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the interface for the JWK cache
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
Cleanup()
|
||||
}
|
||||
|
||||
// GetJWKS gets the JWKS, either from cache or by fetching it
|
||||
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
defer c.mutex.RUnlock()
|
||||
@@ -57,25 +55,43 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
return c.jwks, nil
|
||||
}
|
||||
|
||||
jwks, err := fetchJWKS(jwksURL, httpClient)
|
||||
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.jwks = jwks
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
lifetime := c.CacheLifetime
|
||||
if lifetime == 0 {
|
||||
lifetime = 1 * time.Hour
|
||||
}
|
||||
c.expiresAt = time.Now().Add(lifetime)
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// fetchJWKS fetches the JWKS from the provider
|
||||
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
resp, err := httpClient.Get(jwksURL)
|
||||
func (c *JWKCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if c.jwks != nil && now.After(c.expiresAt) {
|
||||
c.jwks = nil
|
||||
}
|
||||
}
|
||||
|
||||
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Create a request with context to enforce timeout
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
||||
}
|
||||
@@ -93,7 +109,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JWK to PEM format
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
@@ -109,7 +124,6 @@ var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JWK to PEM
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
@@ -141,7 +155,6 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC JWK to PEM
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,19 +4,33 @@ import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
||||
"math/big"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWT represents a JSON Web Token
|
||||
var replayCacheMu sync.Mutex
|
||||
var replayCache = make(map[string]time.Time)
|
||||
|
||||
func cleanupReplayCache() {
|
||||
now := time.Now()
|
||||
for token, expiry := range replayCache {
|
||||
if expiry.Before(now) {
|
||||
delete(replayCache, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClockSkewTolerance is configurable to adjust time-based validations.
|
||||
var ClockSkewTolerance = 2 * time.Minute
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
type JWT struct {
|
||||
Header map[string]interface{}
|
||||
Claims map[string]interface{}
|
||||
@@ -24,7 +38,7 @@ type JWT struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
// parseJWT parses a JWT token string into a JWT struct
|
||||
// parseJWT parses a JWT token string into a JWT struct.
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -35,7 +49,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
Token: tokenString,
|
||||
}
|
||||
|
||||
// Decode and unmarshal the header
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
||||
@@ -44,7 +57,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
||||
}
|
||||
|
||||
// Decode and unmarshal the claims
|
||||
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
||||
@@ -53,7 +65,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
// Decode the signature
|
||||
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
||||
@@ -63,8 +74,23 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return jwt, nil
|
||||
}
|
||||
|
||||
// Verify verifies the standard claims in the JWT
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Validate algorithm to prevent algorithm switching attacks
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'alg' header")
|
||||
}
|
||||
supportedAlgs := map[string]bool{
|
||||
"RS256": true, "RS384": true, "RS512": true,
|
||||
"PS256": true, "PS384": true, "PS512": true,
|
||||
"ES256": true, "ES384": true, "ES512": true,
|
||||
}
|
||||
if !supportedAlgs[alg] {
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
claims := j.Claims
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
@@ -99,6 +125,38 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if err := verifyNotBefore(nbf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Implement replay protection by checking the jti (JWT ID)
|
||||
if jti, ok := claims["jti"].(string); ok {
|
||||
// Skip replay detection for tokens that are being verified from the cache
|
||||
if j.Token == "" {
|
||||
// This is a parsed JWT without the original token string,
|
||||
// which means it's likely from a cached token verification
|
||||
return nil
|
||||
}
|
||||
|
||||
replayCacheMu.Lock()
|
||||
cleanupReplayCache()
|
||||
if _, exists := replayCache[jti]; exists {
|
||||
replayCacheMu.Unlock()
|
||||
return fmt.Errorf("token replay detected")
|
||||
}
|
||||
expFloat, ok := claims["exp"].(float64)
|
||||
var expTime time.Time
|
||||
if ok {
|
||||
expTime = time.Unix(int64(expFloat), 0)
|
||||
} else {
|
||||
expTime = time.Now().Add(10 * time.Minute)
|
||||
}
|
||||
replayCache[jti] = expTime
|
||||
replayCacheMu.Unlock()
|
||||
}
|
||||
|
||||
sub, ok := claims["sub"].(string)
|
||||
if !ok || sub == "" {
|
||||
return fmt.Errorf("missing or empty 'sub' claim")
|
||||
@@ -106,8 +164,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience verifies the audience claim
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
@@ -131,62 +187,80 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer verifies the issuer claim
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyExpiration checks if the token has expired
|
||||
// verifyTimeConstraint is a generic function to verify time-based claims
|
||||
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
|
||||
claimTime := time.Unix(int64(unixTime), 0)
|
||||
now := time.Now().Truncate(time.Second)
|
||||
|
||||
// For expiration (future=true), we add skew to now (making now later)
|
||||
// For iat/nbf (future=false), we subtract skew from now (making now earlier)
|
||||
skewDirection := 1
|
||||
if !future {
|
||||
skewDirection = -1
|
||||
}
|
||||
skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance)
|
||||
|
||||
if claimTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For expiration: if skewedNow (later) is after expiration, token expired
|
||||
// For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid
|
||||
if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) {
|
||||
var reason string
|
||||
if future {
|
||||
reason = "has expired"
|
||||
} else {
|
||||
if claimName == "iat" {
|
||||
reason = "used before issued"
|
||||
} else {
|
||||
reason = "not yet valid"
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyExpiration(expiration float64) error {
|
||||
expirationTime := time.Unix(int64(expiration), 0)
|
||||
if time.Now().After(expirationTime) {
|
||||
return fmt.Errorf("token has expired")
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(expiration, "exp", true)
|
||||
}
|
||||
|
||||
// verifyIssuedAt checks if the token was issued in the future
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
if time.Now().Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued")
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(issuedAt, "iat", false)
|
||||
}
|
||||
|
||||
func verifyNotBefore(notBefore float64) error {
|
||||
return verifyTimeConstraint(notBefore, "nbf", false)
|
||||
}
|
||||
|
||||
// verifySignature verifies the token signature using the provided public key and algorithm
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
// Split the token into its three parts
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
|
||||
// Decode the signature from the token
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
// Decode the PEM-encoded public key
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
|
||||
// Parse the public key
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
// Determine the hash function to use based on the algorithm
|
||||
var hashFunc crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
hashFunc = crypto.SHA256
|
||||
@@ -197,27 +271,20 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
|
||||
default:
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
// Hash the signed content
|
||||
h := hashFunc.New()
|
||||
h.Write([]byte(signedContent))
|
||||
hashed := h.Sum(nil)
|
||||
|
||||
// Verify the signature based on the key type and algorithm
|
||||
switch pubKey := pubKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
// RSA PKCS#1 v1.5 signature
|
||||
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
|
||||
} else if strings.HasPrefix(alg, "PS") {
|
||||
// RSA PSS signature
|
||||
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "ES") {
|
||||
// ECDSA signature
|
||||
var r, s big.Int
|
||||
sigLen := len(signature)
|
||||
if sigLen%2 != 0 {
|
||||
|
||||
@@ -10,13 +10,48 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"runtime"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// createDefaultHTTPClient creates an HTTP client with optimized settings for OIDC
|
||||
func createDefaultHTTPClient() *http.Client {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second, // Reduced timeout
|
||||
KeepAlive: 15 * time.Second, // Reduced keepalive
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 30, // Reduced from 100
|
||||
MaxIdleConnsPerHost: 10, // Reduced from 100
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 50, // Limit max connections
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Timeout: time.Second * 15, // Reduced timeout
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const ConstSessionTimeout = 86400 // Session timeout in seconds
|
||||
|
||||
// TokenVerifier interface for token verification
|
||||
@@ -38,6 +73,7 @@ type TraefikOidc struct {
|
||||
issuerURL string
|
||||
revocationURL string
|
||||
jwkCache JWKCacheInterface
|
||||
metadataCache *MetadataCache
|
||||
tokenBlacklist *TokenBlacklist
|
||||
jwksURL string
|
||||
clientID string
|
||||
@@ -61,7 +97,6 @@ type TraefikOidc struct {
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
initComplete chan struct{}
|
||||
endSessionURL string
|
||||
baseURL string
|
||||
postLogoutRedirectURI string
|
||||
sessionManager *SessionManager
|
||||
}
|
||||
@@ -81,21 +116,50 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon": {},
|
||||
}
|
||||
|
||||
var newTicker = time.NewTicker
|
||||
|
||||
var (
|
||||
globalMetadataCache struct {
|
||||
sync.Once
|
||||
metadata *ProviderMetadata
|
||||
err error
|
||||
}
|
||||
)
|
||||
|
||||
// VerifyToken verifies the provided JWT token
|
||||
// VerifyToken implements the TokenVerifier interface to verify an OIDC token.
|
||||
// It performs a complete verification process including:
|
||||
// 1. Checking the token cache to avoid redundant verifications
|
||||
// 2. Performing rate limiting and blacklist checks
|
||||
// 3. Parsing the JWT structure
|
||||
// 4. Verifying the JWT signature against the JWKS from the provider
|
||||
// 5. Validating standard JWT claims (iss, aud, exp, etc.)
|
||||
// 6. Caching the verified token for future requests
|
||||
//
|
||||
// Returns nil if the token is valid, or an error describing the validation failure.
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
// Check cache first
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
t.logger.Debugf("Token found in cache with valid claims; skipping verification")
|
||||
return nil
|
||||
}
|
||||
|
||||
t.logger.Debugf("Verifying token")
|
||||
|
||||
// Rate limiting
|
||||
// Perform pre-verification checks
|
||||
if err := t.performPreVerificationChecks(token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
// Verify JWT signature and standard claims
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache the verified token
|
||||
t.cacheVerifiedToken(token, jwt.Claims)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// performPreVerificationChecks performs rate limiting and blacklist checks
|
||||
func (t *TraefikOidc) performPreVerificationChecks(token string) error {
|
||||
// Enforce rate limiting
|
||||
if !t.limiter.Allow() {
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
@@ -105,30 +169,15 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
return fmt.Errorf("token is blacklisted")
|
||||
}
|
||||
|
||||
// Check if token is cached
|
||||
if _, exists := t.tokenCache.Get(token); exists {
|
||||
t.logger.Debugf("Token is valid and cached")
|
||||
return nil // Token is valid and cached
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
// Verify JWT signature and claims
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache the token until it expires
|
||||
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
|
||||
// cacheVerifiedToken caches a verified token until its expiration time
|
||||
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) {
|
||||
expirationTime := time.Unix(int64(claims["exp"].(float64)), 0)
|
||||
now := time.Now()
|
||||
duration := expirationTime.Sub(now)
|
||||
t.tokenCache.Set(token, jwt.Claims, duration)
|
||||
|
||||
return nil
|
||||
t.tokenCache.Set(token, claims, duration)
|
||||
}
|
||||
|
||||
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
|
||||
@@ -136,7 +185,7 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
t.logger.Debugf("Verifying JWT signature and claims")
|
||||
|
||||
// Get JWKS
|
||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||
jwks, err := t.jwkCache.GetJWKS(context.Background(), t.jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
@@ -182,34 +231,53 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates a new instance of the OIDC middleware
|
||||
// New creates a new instance of the OIDC middleware.
|
||||
// This is the main entry point for the middleware and is called by Traefik when loading the plugin.
|
||||
// It initializes all components needed for OIDC authentication:
|
||||
// - Session management for storing user state
|
||||
// - Token caching and blacklisting
|
||||
// - JWK caching for signature verification
|
||||
// - Rate limiting to prevent abuse
|
||||
// - Metadata discovery for OIDC provider endpoints
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for initialization operations
|
||||
// - next: The next handler in the middleware chain
|
||||
// - config: Configuration options for the middleware
|
||||
// - name: Identifier for this middleware instance
|
||||
//
|
||||
// Returns:
|
||||
// - An http.Handler that implements the middleware
|
||||
// - An error if initialization fails
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
// Setup HTTP client
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
if config == nil {
|
||||
config = CreateConfig()
|
||||
}
|
||||
|
||||
// Generate default session encryption key if not provided
|
||||
if config.SessionEncryptionKey == "" {
|
||||
// Generate a fixed key for Traefik Hub testing
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
logger := NewLogger(config.LogLevel)
|
||||
// Ensure key meets minimum length requirement
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
if runtime.Compiler == "yaegi" {
|
||||
// Set default encryption key for Yaegi (Traefik Plugin Analyzer)
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
logger.Infof("Session encryption key is too short; using default key for analyzer")
|
||||
} else {
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
}
|
||||
// Setup HTTP client
|
||||
var httpClient *http.Client
|
||||
if config.HTTPClient != nil {
|
||||
httpClient = config.HTTPClient
|
||||
} else {
|
||||
httpClient = &http.Client{
|
||||
Timeout: time.Second * 30,
|
||||
Transport: transport,
|
||||
}
|
||||
httpClient = createDefaultHTTPClient()
|
||||
}
|
||||
|
||||
t := &TraefikOidc{
|
||||
@@ -230,6 +298,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}(),
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
jwkCache: &JWKCache{},
|
||||
metadataCache: NewMetadataCache(),
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
@@ -237,14 +306,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: NewTokenCache(),
|
||||
httpClient: httpClient,
|
||||
logger: NewLogger(config.LogLevel),
|
||||
excludedURLs: createStringMap(config.ExcludedURLs),
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||
initComplete: make(chan struct{}),
|
||||
}
|
||||
// Assign the initialized logger
|
||||
t.logger = logger
|
||||
|
||||
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
@@ -266,26 +336,58 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
|
||||
// initializeMetadata discovers and initializes the provider metadata
|
||||
func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
globalMetadataCache.Once.Do(func() {
|
||||
t.logger.Debug("Starting global provider metadata discovery")
|
||||
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
|
||||
globalMetadataCache.metadata = metadata
|
||||
globalMetadataCache.err = err
|
||||
})
|
||||
t.logger.Debug("Starting provider metadata discovery")
|
||||
|
||||
if globalMetadataCache.err != nil {
|
||||
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
|
||||
} else if globalMetadataCache.metadata != nil {
|
||||
t.logger.Debug("Using cached provider metadata")
|
||||
t.jwksURL = globalMetadataCache.metadata.JWKSURL
|
||||
t.authURL = globalMetadataCache.metadata.AuthURL
|
||||
t.tokenURL = globalMetadataCache.metadata.TokenURL
|
||||
t.issuerURL = globalMetadataCache.metadata.Issuer
|
||||
t.revocationURL = globalMetadataCache.metadata.RevokeURL
|
||||
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
|
||||
// Get metadata from cache or fetch it
|
||||
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to get provider metadata: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
close(t.initComplete)
|
||||
if metadata != nil {
|
||||
t.logger.Debug("Successfully initialized provider metadata")
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
|
||||
// Start metadata refresh goroutine
|
||||
go t.startMetadataRefresh(providerURL)
|
||||
|
||||
// Only close channel on success
|
||||
close(t.initComplete)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Error("Received nil metadata")
|
||||
}
|
||||
|
||||
// updateMetadataEndpoints updates the middleware with metadata endpoints
|
||||
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
}
|
||||
|
||||
// startMetadataRefresh periodically refreshes the OIDC metadata
|
||||
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
t.logger.Debug("Refreshing OIDC metadata")
|
||||
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to refresh metadata: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
t.logger.Debug("Successfully refreshed metadata")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// discoverProviderMetadata fetches the OIDC provider metadata
|
||||
@@ -350,19 +452,34 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// ServeHTTP is the main handler for the middleware
|
||||
// ServeHTTP is the main handler for the middleware that processes all HTTP requests.
|
||||
// It implements the http.Handler interface and performs the following operations:
|
||||
// 1. Waits for OIDC provider metadata initialization to complete
|
||||
// 2. Checks if the requested URL is in the excluded list (bypassing authentication)
|
||||
// 3. Retrieves or creates a user session
|
||||
// 4. Handles special paths like callback and logout URLs
|
||||
// 5. Verifies authentication status and token validity
|
||||
// 6. Refreshes tokens that are about to expire
|
||||
// 7. Validates user email domains, roles, and groups against configured restrictions
|
||||
// 8. Sets appropriate headers for downstream services
|
||||
// 9. Applies security headers to responses
|
||||
// 10. Forwards the authenticated request to the next handler
|
||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
if t.issuerURL == "" {
|
||||
t.logger.Debug("OIDC middleware not yet initialized")
|
||||
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
|
||||
t.logger.Error("OIDC provider metadata initialization failed")
|
||||
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
case <-req.Context().Done():
|
||||
t.logger.Debug("Request cancelled")
|
||||
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
|
||||
return
|
||||
case <-time.After(30 * time.Second):
|
||||
t.logger.Error("Timeout waiting for OIDC initialization")
|
||||
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if URL is excluded
|
||||
@@ -375,7 +492,18 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
|
||||
// Obtain a new session and clear any residual session cookies
|
||||
session, _ = t.sessionManager.GetSession(req)
|
||||
session.Clear(req, rw)
|
||||
|
||||
// Build redirect URL
|
||||
scheme := t.determineScheme(req)
|
||||
host := t.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
// Initiate authentication
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -430,9 +558,65 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
if len(roles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed roles and groups
|
||||
if len(t.allowedRolesAndGroups) > 0 {
|
||||
allowed := false
|
||||
for _, roleOrGroup := range append(groups, roles...) {
|
||||
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set user information in headers
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
// Set OIDC-specific headers
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetAccessToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
|
||||
// Set security headers
|
||||
rw.Header().Set("X-Frame-Options", "DENY")
|
||||
rw.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
rw.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Set CORS headers
|
||||
origin := req.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
rw.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
|
||||
// Handle preflight requests
|
||||
if req.Method == "OPTIONS" {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Process the request
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
@@ -451,9 +635,6 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
|
||||
// determineScheme determines the scheme (http or https) of the request
|
||||
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
if t.forceHTTPS {
|
||||
return "https"
|
||||
}
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
@@ -471,7 +652,17 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
return req.Host
|
||||
}
|
||||
|
||||
// isUserAuthenticated checks if the user is authenticated
|
||||
// isUserAuthenticated checks if the user is authenticated by validating their session and token.
|
||||
// It performs a comprehensive check of the authentication state including:
|
||||
// 1. Verifying the session's authenticated flag
|
||||
// 2. Checking for the presence of an access token
|
||||
// 3. Validating the token's signature and claims
|
||||
// 4. Checking the token's expiration time
|
||||
//
|
||||
// Returns three boolean values:
|
||||
// - authenticated: Whether the user is currently authenticated
|
||||
// - needsRefresh: Whether the token is valid but will expire soon (within grace period)
|
||||
// - expired: Whether the token has expired or is otherwise invalid
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debug("User is not authenticated according to session")
|
||||
@@ -519,20 +710,35 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// defaultInitiateAuthentication initiates the authentication process
|
||||
// defaultInitiateAuthentication initiates the OIDC authentication process.
|
||||
// This function prepares and starts a new authentication flow by:
|
||||
// 1. Generating security tokens (CSRF token and nonce) to prevent attacks
|
||||
// 2. Clearing any existing session data to avoid state conflicts
|
||||
// 3. Storing the original request path to redirect back after authentication
|
||||
// 4. Building the authorization URL with all required OIDC parameters
|
||||
// 5. Redirecting the user to the OIDC provider's authorization endpoint
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer for sending the redirect
|
||||
// - req: The original HTTP request that triggered authentication
|
||||
// - session: The user's session data for storing authentication state
|
||||
// - redirectURL: The callback URL where the OIDC provider will redirect after authentication
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Generate CSRF token and nonce
|
||||
csrfToken := uuid.New().String()
|
||||
csrfToken := uuid.NewString()
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session values
|
||||
// Clear any existing session data to avoid stale state causing redirect loops
|
||||
session.Clear(req, rw)
|
||||
|
||||
// Set new session values
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
session.SetIncomingPath(req.URL.Path)
|
||||
session.SetIncomingPath(req.URL.RequestURI())
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
@@ -541,12 +747,15 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
return
|
||||
}
|
||||
|
||||
// Build and redirect to auth URL
|
||||
// Build and redirect to authentication URL
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// verifyToken verifies the token using the token verifier
|
||||
// verifyToken verifies the token using the token verifier interface.
|
||||
// This function delegates to the configured token verifier implementation,
|
||||
// which by default is the TraefikOidc instance itself (implementing the VerifyToken method).
|
||||
// This design allows for easy mocking in tests and potential future extension.
|
||||
func (t *TraefikOidc) verifyToken(token string) error {
|
||||
return t.tokenVerifier.VerifyToken(token)
|
||||
}
|
||||
@@ -562,17 +771,38 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||
if len(t.scopes) > 0 {
|
||||
params.Set("scope", strings.Join(t.scopes, " "))
|
||||
}
|
||||
return t.authURL + "?" + params.Encode()
|
||||
|
||||
return t.buildURLWithParams(t.authURL, params)
|
||||
}
|
||||
|
||||
// buildURLWithParams ensures a URL is absolute and appends query parameters
|
||||
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
// Ensure URL is absolute
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
// Extract issuer base URL
|
||||
issuerURL, err := url.Parse(t.issuerURL)
|
||||
if err == nil {
|
||||
return fmt.Sprintf("%s://%s%s?%s",
|
||||
issuerURL.Scheme,
|
||||
issuerURL.Host,
|
||||
baseURL,
|
||||
params.Encode())
|
||||
}
|
||||
}
|
||||
return baseURL + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// startTokenCleanup starts the token cleanup goroutine
|
||||
func (t *TraefikOidc) startTokenCleanup() {
|
||||
ticker := newTicker(1 * time.Minute)
|
||||
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
t.logger.Debug("Cleaning up token cache")
|
||||
t.logger.Debug("Starting token cleanup cycle")
|
||||
t.tokenCache.Cleanup()
|
||||
t.tokenBlacklist.Cleanup()
|
||||
t.jwkCache.Cleanup() // Assuming jwkCache is the cache from cache.go
|
||||
// Removed runtime.GC() call
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
+554
-10
@@ -1,6 +1,7 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -66,21 +68,29 @@ func (ts *TestSuite) Setup() {
|
||||
}
|
||||
|
||||
// Create a test JWT token signed with the RSA private key
|
||||
// Create timestamps with proper clock skew
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
||||
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
||||
|
||||
ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
ts.t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
logger := NewLogger("info")
|
||||
ts.sessionManager = NewSessionManager("test-secret-key", false, logger)
|
||||
ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
|
||||
// Common TraefikOidc instance
|
||||
ts.tOidc = &TraefikOidc{
|
||||
@@ -121,10 +131,16 @@ type MockJWKCache struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) Cleanup() {
|
||||
// Mock cleanup implementation
|
||||
m.JWKS = nil
|
||||
m.Err = nil
|
||||
}
|
||||
|
||||
// Helper function to create a JWT token
|
||||
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
|
||||
header := map[string]interface{}{
|
||||
@@ -211,7 +227,7 @@ func TestVerifyToken(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Reset token blacklist and cache
|
||||
// Reset token blacklist and cache for each test
|
||||
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
|
||||
ts.tOidc.tokenCache = NewTokenCache()
|
||||
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
||||
@@ -227,9 +243,20 @@ func TestVerifyToken(t *testing.T) {
|
||||
}
|
||||
|
||||
if tc.cacheToken {
|
||||
// Use more realistic claims for cached token
|
||||
ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{
|
||||
"empty": "claim",
|
||||
}, 60)
|
||||
"iss": "https://test-issuer.com",
|
||||
"sub": "test-subject",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"jti": generateRandomString(16), // Add a JTI claim to prevent replay detection
|
||||
}, time.Minute)
|
||||
|
||||
// Verify the token is actually in the cache
|
||||
if claims, exists := ts.tOidc.tokenCache.Get(tc.token); exists {
|
||||
t.Logf("Token found in cache with claims: %v", claims)
|
||||
} else {
|
||||
t.Logf("Token NOT found in cache despite cacheToken=true")
|
||||
}
|
||||
}
|
||||
|
||||
err := ts.tOidc.VerifyToken(tc.token)
|
||||
@@ -610,7 +637,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
sessionManager := NewSessionManager("test-secret-key", false, logger)
|
||||
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
|
||||
// Create a new instance for each test to avoid state carryover
|
||||
tOidc := &TraefikOidc{
|
||||
@@ -915,7 +942,7 @@ func TestHandleLogout(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
sessionManager := NewSessionManager("test-secret-key", false, logger)
|
||||
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
tOidc := &TraefikOidc{
|
||||
revocationURL: mockRevocationServer.URL,
|
||||
endSessionURL: tc.endSessionURL,
|
||||
@@ -1204,7 +1231,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
sessionManager := NewSessionManager("test-secret-key", false, logger)
|
||||
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
sessionManager: sessionManager,
|
||||
@@ -1342,7 +1369,407 @@ func TestExtractGroupsAndRoles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleMiddlewareInstances verifies that multiple middleware instances
|
||||
// can be created and initialized properly for different routes
|
||||
func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
// Create mock provider metadata server
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://test-issuer.com",
|
||||
AuthURL: "https://test-issuer.com/auth",
|
||||
TokenURL: "https://test-issuer.com/token",
|
||||
JWKSURL: "https://test-issuer.com/jwks",
|
||||
RevokeURL: "https://test-issuer.com/revoke",
|
||||
EndSessionURL: "https://test-issuer.com/end-session",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Create base config
|
||||
config := &Config{
|
||||
ProviderURL: mockServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
}
|
||||
|
||||
// Create multiple middleware instances
|
||||
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
|
||||
var middlewares []*TraefikOidc
|
||||
|
||||
for _, route := range routes {
|
||||
config.CallbackURL = route + "/callback"
|
||||
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}), config, "test")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
|
||||
}
|
||||
|
||||
// Type assert to access internal fields
|
||||
if m, ok := middleware.(*TraefikOidc); ok {
|
||||
middlewares = append(middlewares, m)
|
||||
} else {
|
||||
t.Fatalf("Middleware is not of type *TraefikOidc")
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all instances to initialize
|
||||
for i, m := range middlewares {
|
||||
select {
|
||||
case <-m.initComplete:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("Middleware instance %d failed to initialize", i)
|
||||
}
|
||||
|
||||
// Verify each instance has its own unique configuration
|
||||
if m.issuerURL != "https://test-issuer.com" {
|
||||
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
|
||||
}
|
||||
if m.authURL != "https://test-issuer.com/auth" {
|
||||
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
|
||||
}
|
||||
if m.tokenURL != "https://test-issuer.com/token" {
|
||||
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
|
||||
}
|
||||
if m.jwksURL != "https://test-issuer.com/jwks" {
|
||||
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
|
||||
}
|
||||
if m.redirURLPath != routes[i]+"/callback" {
|
||||
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that each instance can handle requests independently
|
||||
for i, m := range middlewares {
|
||||
req := httptest.NewRequest("GET", routes[i]+"/protected", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
m.ServeHTTP(rr, req)
|
||||
|
||||
// Should redirect to auth URL since not authenticated
|
||||
if rr.Code != http.StatusFound {
|
||||
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
|
||||
}
|
||||
|
||||
location := rr.Header().Get("Location")
|
||||
if !strings.Contains(location, "https://test-issuer.com/auth") {
|
||||
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create consistent timestamps for all test cases
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
||||
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
claims map[string]interface{}
|
||||
setupSession func(*SessionData)
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "User with allowed role",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"admin", "user"},
|
||||
"groups": []interface{}{"group1"},
|
||||
"jti": generateRandomString(16),
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "admin,user",
|
||||
"X-User-Groups": "group1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "User with allowed group",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"allowed-group"},
|
||||
"jti": generateRandomString(16),
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "user",
|
||||
"X-User-Groups": "allowed-group",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "User without allowed roles or groups",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
"jti": generateRandomString(16),
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "No role/group restrictions",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
"jti": generateRandomString(16),
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "user",
|
||||
"X-User-Groups": "regular-group",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Claims without roles and groups",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"jti": generateRandomString(16),
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create token with claims
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
// Create test handler
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Configure OIDC middleware
|
||||
tOidc := ts.tOidc
|
||||
tOidc.next = nextHandler
|
||||
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Set up session
|
||||
session, err := tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
tc.setupSession(session)
|
||||
session.SetAccessToken(token)
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Copy cookies to the new request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Reset response recorder
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
// Serve request
|
||||
tOidc.ServeHTTP(rr, req)
|
||||
|
||||
// Check status code
|
||||
if rr.Code != tc.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
||||
}
|
||||
|
||||
// Check headers if status is OK
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
for header, expectedValue := range tc.expectedHeaders {
|
||||
if value := req.Header.Get(header); value != expectedValue {
|
||||
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare string slices
|
||||
|
||||
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
|
||||
func TestExchangeTokensWithRedirects(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupServer func() *httptest.Server
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Successful token exchange with redirects",
|
||||
setupServer: func() *httptest.Server {
|
||||
redirectCount := 0
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if redirectCount < 3 {
|
||||
// Set a cookie before redirecting
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
|
||||
Value: "test-value",
|
||||
})
|
||||
redirectCount++
|
||||
w.Header().Set("Location", r.URL.String())
|
||||
w.WriteHeader(http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify all cookies from previous redirects are present
|
||||
cookies := r.Cookies()
|
||||
if len(cookies) != 3 {
|
||||
t.Errorf("Expected 3 cookies, got %d", len(cookies))
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
found := false
|
||||
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == expectedName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Cookie %s not found", expectedName)
|
||||
}
|
||||
}
|
||||
|
||||
// Return successful token response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
IDToken: "test.id.token",
|
||||
AccessToken: "test-access-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "test-refresh-token",
|
||||
})
|
||||
}))
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Too many redirects",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", r.URL.String())
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}))
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "stopped after 50 redirects",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := tc.setupServer()
|
||||
defer server.Close()
|
||||
|
||||
// Configure the test instance
|
||||
tOidc := ts.tOidc
|
||||
tOidc.tokenURL = server.URL
|
||||
|
||||
// Test token exchange
|
||||
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
|
||||
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
||||
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if response == nil {
|
||||
t.Error("Expected token response but got nil")
|
||||
} else if response.IDToken != "test.id.token" {
|
||||
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
@@ -1354,3 +1781,120 @@ func stringSliceEqual(a, b []string) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
|
||||
func TestBuildAuthURL(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authURL string
|
||||
issuerURL string
|
||||
redirectURL string
|
||||
state string
|
||||
nonce string
|
||||
expectedPrefix string
|
||||
}{
|
||||
{
|
||||
name: "Absolute Auth URL",
|
||||
authURL: "https://auth.example.com/oauth/authorize",
|
||||
issuerURL: "https://auth.example.com",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||
},
|
||||
{
|
||||
name: "Relative Auth URL",
|
||||
authURL: "/oidc/auth",
|
||||
issuerURL: "https://logto.example.com",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
expectedPrefix: "https://logto.example.com/oidc/auth?",
|
||||
},
|
||||
{
|
||||
name: "Relative Auth URL with Different Issuer",
|
||||
authURL: "/sign-in",
|
||||
issuerURL: "https://auth.example.com:8443",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
expectedPrefix: "https://auth.example.com:8443/sign-in?",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Configure the test instance
|
||||
tOidc := ts.tOidc
|
||||
tOidc.authURL = tc.authURL
|
||||
tOidc.issuerURL = tc.issuerURL
|
||||
|
||||
// Call buildAuthURL
|
||||
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
|
||||
|
||||
// Verify the URL starts with the expected prefix
|
||||
if !strings.HasPrefix(result, tc.expectedPrefix) {
|
||||
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
|
||||
}
|
||||
|
||||
// Parse the resulting URL to verify query parameters
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse resulting URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
expectedParams := map[string]string{
|
||||
"client_id": tOidc.clientID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": tc.redirectURL,
|
||||
"state": tc.state,
|
||||
"nonce": tc.nonce,
|
||||
}
|
||||
|
||||
for key, expected := range expectedParams {
|
||||
if got := query.Get(key); got != expected {
|
||||
t.Errorf("Expected %s=%q, got %q", key, expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify scopes are present and correct
|
||||
if len(tOidc.scopes) > 0 {
|
||||
expectedScopes := strings.Join(tOidc.scopes, " ")
|
||||
if got := query.Get("scope"); got != expectedScopes {
|
||||
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
|
||||
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create a request with query parameters
|
||||
req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Get session
|
||||
session, err := ts.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Call defaultInitiateAuthentication
|
||||
redirectURL := "http://example.com/callback"
|
||||
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
|
||||
// Verify that the incoming path includes query parameters
|
||||
incomingPath := session.GetIncomingPath()
|
||||
expectedPath := "/protected/resource?param1=value1¶m2=value2"
|
||||
if incomingPath != expectedPath {
|
||||
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MetadataCache struct {
|
||||
metadata *ProviderMetadata
|
||||
expiresAt time.Time
|
||||
mutex sync.RWMutex
|
||||
autoCleanupInterval time.Duration
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
func NewMetadataCache() *MetadataCache {
|
||||
c := &MetadataCache{
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
go c.startAutoCleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
// Cleanup removes expired metadata from the cache.
|
||||
func (c *MetadataCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if c.metadata != nil && now.After(c.expiresAt) {
|
||||
c.metadata = nil
|
||||
}
|
||||
}
|
||||
func (c *MetadataCache) isCacheValid() bool {
|
||||
return c.metadata != nil && time.Now().Before(c.expiresAt)
|
||||
}
|
||||
|
||||
// GetMetadata retrieves the metadata from cache or fetches it if expired
|
||||
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
|
||||
c.mutex.RLock()
|
||||
if c.isCacheValid() {
|
||||
defer c.mutex.RUnlock()
|
||||
return c.metadata, nil
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if c.isCacheValid() {
|
||||
return c.metadata, nil
|
||||
}
|
||||
|
||||
metadata, err := discoverProviderMetadata(providerURL, httpClient, logger)
|
||||
if err != nil {
|
||||
if c.metadata != nil {
|
||||
// On error, extend current cache by 5 minutes to prevent thundering herd
|
||||
c.expiresAt = time.Now().Add(5 * time.Minute)
|
||||
logger.Errorf("Failed to refresh metadata, using cached version for 5 more minutes: %v", err)
|
||||
return c.metadata, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
|
||||
}
|
||||
|
||||
c.metadata = metadata
|
||||
// Calculate expiration time based on usage patterns
|
||||
usageCount := 0 // This should be replaced with actual usage tracking logic
|
||||
if usageCount < 10 {
|
||||
c.expiresAt = time.Now().Add(30 * time.Minute)
|
||||
} else if usageCount < 50 {
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
} else {
|
||||
c.expiresAt = time.Now().Add(2 * time.Hour)
|
||||
}
|
||||
|
||||
// End of GetMetadata
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func (c *MetadataCache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
func (c *MetadataCache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsCacheValid(t *testing.T) {
|
||||
// Setup with a dummy ProviderMetadata.
|
||||
pm := &ProviderMetadata{}
|
||||
mc := &MetadataCache{
|
||||
metadata: pm,
|
||||
expiresAt: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
if !mc.isCacheValid() {
|
||||
t.Errorf("Expected cache to be valid")
|
||||
}
|
||||
mc.expiresAt = time.Now().Add(-1 * time.Hour)
|
||||
if mc.isCacheValid() {
|
||||
t.Errorf("Expected cache to be invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanup(t *testing.T) {
|
||||
pm := &ProviderMetadata{}
|
||||
mc := &MetadataCache{
|
||||
metadata: pm,
|
||||
expiresAt: time.Now().Add(-1 * time.Hour),
|
||||
}
|
||||
mc.Cleanup()
|
||||
if mc.metadata != nil {
|
||||
t.Errorf("Expected metadata to be nil after cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetadata_Cached(t *testing.T) {
|
||||
dummyData := &ProviderMetadata{}
|
||||
// Construct MetadataCache manually to avoid interference from auto cleanup.
|
||||
mc := &MetadataCache{
|
||||
metadata: dummyData,
|
||||
expiresAt: time.Now().Add(1 * time.Hour),
|
||||
stopCleanup: make(chan struct{}),
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
}
|
||||
// Use NewLogger to create a logger that writes errors only.
|
||||
logger := NewLogger("error")
|
||||
result, err := mc.GetMetadata("http://example.com", http.DefaultClient, logger)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if result != dummyData {
|
||||
t.Errorf("Expected cached metadata to be returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataCacheAutoCleanup(t *testing.T) {
|
||||
mc := &MetadataCache{
|
||||
autoCleanupInterval: 50 * time.Millisecond,
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
// Start auto cleanup.
|
||||
go mc.startAutoCleanup()
|
||||
mc.mutex.Lock()
|
||||
mc.metadata = &ProviderMetadata{}
|
||||
mc.expiresAt = time.Now().Add(-50 * time.Millisecond)
|
||||
mc.mutex.Unlock()
|
||||
|
||||
// Wait enough time for the auto cleanup to run.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
mc.Close()
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
if mc.metadata != nil {
|
||||
t.Errorf("Expected metadata to be cleared by auto cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
type errorRoundTripper struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e errorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return nil, e.err
|
||||
}
|
||||
|
||||
func TestGetMetadata_FetchError(t *testing.T) {
|
||||
// Create an HTTP client that always returns an error.
|
||||
errorClient := &http.Client{
|
||||
Transport: errorRoundTripper{err: fmt.Errorf("fake fetch error")},
|
||||
}
|
||||
|
||||
// Case 1: Cache is empty.
|
||||
mc := &MetadataCache{
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
logger := NewLogger("error")
|
||||
metadata, err := mc.GetMetadata("http://example.com", errorClient, logger)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
if metadata != nil {
|
||||
t.Errorf("Expected nil metadata, got %v", metadata)
|
||||
}
|
||||
|
||||
// Case 2: Cache has old metadata.
|
||||
dummy := &ProviderMetadata{}
|
||||
mc.metadata = dummy
|
||||
mc.expiresAt = time.Now().Add(-1 * time.Minute)
|
||||
logger2 := NewLogger("error")
|
||||
metadata, err = mc.GetMetadata("http://example.com", errorClient, logger2)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error when cached metadata exists, got %v", err)
|
||||
}
|
||||
if metadata != dummy {
|
||||
t.Errorf("Expected cached metadata to be returned")
|
||||
}
|
||||
}
|
||||
+10
@@ -0,0 +1,10 @@
|
||||
version: 1
|
||||
force:
|
||||
existing: true
|
||||
wording:
|
||||
patch:
|
||||
- patch-release
|
||||
minor:
|
||||
- minor-release
|
||||
major:
|
||||
- breaking
|
||||
+460
-57
@@ -1,106 +1,308 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const (
|
||||
mainCookieName = "_raczylo_oidc" // Main session cookie
|
||||
accessTokenCookie = "_raczylo_oidc_access" // Access token cookie
|
||||
refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie
|
||||
)
|
||||
|
||||
// SessionManager handles multiple session cookies
|
||||
type SessionManager struct {
|
||||
store sessions.Store
|
||||
forceHTTPS bool
|
||||
logger *Logger
|
||||
// generateSecureRandomString creates a cryptographically secure random string of specified length.
|
||||
// It returns the generated string or an error if random generation fails.
|
||||
func generateSecureRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager
|
||||
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
|
||||
return &SessionManager{
|
||||
// Cookie names and configuration constants used for session management
|
||||
const (
|
||||
// Using fixed prefixes for consistent cookie naming across restarts
|
||||
mainCookieName = "_oidc_raczylo_m"
|
||||
accessTokenCookie = "_oidc_raczylo_a"
|
||||
refreshTokenCookie = "_oidc_raczylo_r"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxCookieSize is the maximum size for each cookie chunk.
|
||||
// This value is calculated to ensure the final cookie size stays within browser limits:
|
||||
// 1. Browser cookie size limit is typically 4096 bytes
|
||||
// 2. Cookie content undergoes encryption (adds 28 bytes) and base64 encoding (4/3 ratio)
|
||||
// 3. Calculation:
|
||||
// - Let x be the chunk size
|
||||
// - After encryption: x + 28 bytes
|
||||
// - After base64: ((x + 28) * 4/3) bytes
|
||||
// - Must satisfy: ((x + 28) * 4/3) ≤ 4096
|
||||
// - Solving for x: x ≤ 3044
|
||||
// 4. We use 2000 as a conservative limit to account for cookie metadata
|
||||
maxCookieSize = 2000
|
||||
|
||||
// absoluteSessionTimeout defines the maximum lifetime of a session
|
||||
// regardless of activity (24 hours)
|
||||
absoluteSessionTimeout = 24 * time.Hour
|
||||
|
||||
// minEncryptionKeyLength defines the minimum length for the encryption key
|
||||
minEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// compressToken compresses a token using gzip and base64 encodes it.
|
||||
func compressToken(token string) string {
|
||||
var b bytes.Buffer
|
||||
gz := gzip.NewWriter(&b)
|
||||
if _, err := gz.Write([]byte(token)); err != nil {
|
||||
return token // fallback to uncompressed on error
|
||||
}
|
||||
if err := gz.Close(); err != nil {
|
||||
return token
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b.Bytes())
|
||||
}
|
||||
|
||||
// decompressToken decompresses a base64 encoded gzipped token.
|
||||
func decompressToken(compressed string) string {
|
||||
data, err := base64.StdEncoding.DecodeString(compressed)
|
||||
if err != nil {
|
||||
return compressed // return as-is if not base64
|
||||
}
|
||||
|
||||
gz, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return compressed
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(gz)
|
||||
if err != nil {
|
||||
return compressed
|
||||
}
|
||||
|
||||
return string(decompressed)
|
||||
}
|
||||
|
||||
// SessionManager handles the management of multiple session cookies for OIDC authentication.
|
||||
// It provides functionality for storing and retrieving authentication state, tokens,
|
||||
// and other session-related data across multiple cookies.
|
||||
type SessionManager struct {
|
||||
// store is the underlying session store for cookie management.
|
||||
store sessions.Store
|
||||
|
||||
// forceHTTPS enforces secure cookie attributes regardless of request scheme.
|
||||
forceHTTPS bool
|
||||
|
||||
// logger provides structured logging capabilities.
|
||||
logger *Logger
|
||||
|
||||
// sessionPool is a sync.Pool for reusing SessionData objects.
|
||||
sessionPool sync.Pool
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager with the specified configuration.
|
||||
// Parameters:
|
||||
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
|
||||
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
|
||||
// - logger: Logger instance for recording session-related events
|
||||
//
|
||||
// Returns an error if the encryption key does not meet minimum length requirements.
|
||||
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*SessionManager, error) {
|
||||
// Validate encryption key length.
|
||||
if len(encryptionKey) < minEncryptionKeyLength {
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
|
||||
sm := &SessionManager{
|
||||
store: sessions.NewCookieStore([]byte(encryptionKey)),
|
||||
forceHTTPS: forceHTTPS,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Initialize session pool.
|
||||
sm.sessionPool.New = func() interface{} {
|
||||
return &SessionData{
|
||||
manager: sm,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
}
|
||||
}
|
||||
|
||||
return sm, nil
|
||||
}
|
||||
|
||||
// getSessionOptions returns session options based on scheme
|
||||
// getSessionOptions returns secure session options configured for the current request.
|
||||
// Parameters:
|
||||
// - isSecure: Whether the current request is using HTTPS.
|
||||
//
|
||||
// The options ensure cookies are:
|
||||
// - HTTP-only (not accessible via JavaScript)
|
||||
// - Secure when using HTTPS or when forceHTTPS is enabled
|
||||
// - Using SameSite=Lax for CSRF protection
|
||||
// - Set with appropriate timeout and path settings
|
||||
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure || sm.forceHTTPS,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
MaxAge: int(absoluteSessionTimeout.Seconds()),
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// GetSession retrieves all session data
|
||||
// GetSession retrieves all session data for the current request.
|
||||
// It loads the main session and token sessions, including any chunked token data,
|
||||
// and combines them into a single SessionData structure for easy access.
|
||||
// Returns an error if any session component cannot be loaded.
|
||||
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
mainSession, err := sm.store.Get(r, mainCookieName)
|
||||
// Get session from pool.
|
||||
sessionData := sm.sessionPool.Get().(*SessionData)
|
||||
sessionData.request = r
|
||||
|
||||
var err error
|
||||
sessionData.mainSession, err = sm.store.Get(r, mainCookieName)
|
||||
if err != nil {
|
||||
sm.sessionPool.Put(sessionData)
|
||||
return nil, fmt.Errorf("failed to get main session: %w", err)
|
||||
}
|
||||
|
||||
accessSession, err := sm.store.Get(r, accessTokenCookie)
|
||||
// Check for absolute session timeout.
|
||||
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
|
||||
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
|
||||
sessionData.Clear(r, nil)
|
||||
return nil, fmt.Errorf("session expired")
|
||||
}
|
||||
}
|
||||
|
||||
sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)
|
||||
if err != nil {
|
||||
sm.sessionPool.Put(sessionData)
|
||||
return nil, fmt.Errorf("failed to get access token session: %w", err)
|
||||
}
|
||||
|
||||
refreshSession, err := sm.store.Get(r, refreshTokenCookie)
|
||||
sessionData.refreshSession, err = sm.store.Get(r, refreshTokenCookie)
|
||||
if err != nil {
|
||||
sm.sessionPool.Put(sessionData)
|
||||
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
|
||||
}
|
||||
|
||||
sessionData := &SessionData{
|
||||
manager: sm,
|
||||
mainSession: mainSession,
|
||||
accessSession: accessSession,
|
||||
refreshSession: refreshSession,
|
||||
// Clear and reuse chunk maps.
|
||||
for k := range sessionData.accessTokenChunks {
|
||||
delete(sessionData.accessTokenChunks, k)
|
||||
}
|
||||
for k := range sessionData.refreshTokenChunks {
|
||||
delete(sessionData.refreshTokenChunks, k)
|
||||
}
|
||||
|
||||
// Retrieve chunked token sessions.
|
||||
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
|
||||
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// SessionData holds all session information
|
||||
type SessionData struct {
|
||||
manager *SessionManager
|
||||
mainSession *sessions.Session
|
||||
accessSession *sessions.Session
|
||||
refreshSession *sessions.Session
|
||||
// getTokenChunkSessions retrieves all session chunks for a given token type.
|
||||
// Parameters:
|
||||
// - r: The HTTP request
|
||||
// - baseName: The base name for the token's session cookies
|
||||
// - chunks: Map to store the chunks in
|
||||
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
||||
session, err := sm.store.Get(r, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
chunks[i] = session
|
||||
}
|
||||
}
|
||||
|
||||
// Save saves all session data
|
||||
// SessionData holds all session information for an authenticated user.
|
||||
// It manages multiple session cookies to handle the main session state
|
||||
// and potentially large access and refresh tokens that may need to be
|
||||
// split across multiple cookies due to browser size limitations.
|
||||
type SessionData struct {
|
||||
// manager is the SessionManager that created this SessionData.
|
||||
manager *SessionManager
|
||||
|
||||
// request is the current HTTP request associated with this session.
|
||||
request *http.Request
|
||||
|
||||
// mainSession stores authentication state and basic user info.
|
||||
mainSession *sessions.Session
|
||||
|
||||
// accessSession stores the primary access token cookie.
|
||||
accessSession *sessions.Session
|
||||
|
||||
// refreshSession stores the primary refresh token cookie.
|
||||
refreshSession *sessions.Session
|
||||
|
||||
// accessTokenChunks stores additional chunks of the access token
|
||||
// when it exceeds the maximum cookie size.
|
||||
accessTokenChunks map[int]*sessions.Session
|
||||
|
||||
// refreshTokenChunks stores additional chunks of the refresh token
|
||||
// when it exceeds the maximum cookie size.
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
}
|
||||
|
||||
// Save persists all session data to cookies in the HTTP response.
|
||||
// It saves the main session, token sessions, and any token chunks,
|
||||
// applying appropriate security options to each cookie. All cookies
|
||||
// are saved with consistent security settings based on the request scheme.
|
||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
// Set options for all sessions
|
||||
sd.mainSession.Options = sd.manager.getSessionOptions(isSecure)
|
||||
sd.accessSession.Options = sd.manager.getSessionOptions(isSecure)
|
||||
sd.refreshSession.Options = sd.manager.getSessionOptions(isSecure)
|
||||
// Set options for all sessions.
|
||||
options := sd.manager.getSessionOptions(isSecure)
|
||||
sd.mainSession.Options = options
|
||||
sd.accessSession.Options = options
|
||||
sd.refreshSession.Options = options
|
||||
|
||||
// Save main session.
|
||||
if err := sd.mainSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save main session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token session.
|
||||
if err := sd.accessSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token session: %w", err)
|
||||
}
|
||||
|
||||
// Save refresh token session.
|
||||
if err := sd.refreshSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token chunks.
|
||||
for _, session := range sd.accessTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token chunk session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save refresh token chunks.
|
||||
for _, session := range sd.refreshTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear clears all session data
|
||||
// Clear removes all session data by expiring all cookies and clearing their values.
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions
|
||||
// Clear and expire all sessions.
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
sd.accessSession.Options.MaxAge = -1
|
||||
sd.refreshSession.Options.MaxAge = -1
|
||||
@@ -115,82 +317,283 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
delete(sd.refreshSession.Values, k)
|
||||
}
|
||||
|
||||
return sd.Save(r, w)
|
||||
// Clear chunk sessions.
|
||||
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
||||
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
||||
|
||||
var err error
|
||||
if w != nil {
|
||||
err = sd.Save(r, w)
|
||||
}
|
||||
|
||||
// Clear transient per-request fields.
|
||||
sd.request = nil
|
||||
|
||||
// Return session to pool.
|
||||
sd.manager.sessionPool.Put(sd)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAuthenticated returns authentication status
|
||||
// clearTokenChunks removes all session chunks for a given token type.
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
session.Options.MaxAge = -1
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthenticated returns whether the current session is authenticated.
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
return auth
|
||||
if !auth {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check session expiration.
|
||||
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
|
||||
}
|
||||
|
||||
// SetAuthenticated sets authentication status
|
||||
func (sd *SessionData) SetAuthenticated(value bool) {
|
||||
// SetAuthenticated updates the session's authentication status and rotates session ID.
|
||||
// Returns an error if generating a new session ID fails.
|
||||
func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
if value {
|
||||
id, err := generateSecureRandomString(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate secure session id: %w", err)
|
||||
}
|
||||
sd.mainSession.ID = id
|
||||
sd.mainSession.Values["created_at"] = time.Now().Unix()
|
||||
}
|
||||
sd.mainSession.Values["authenticated"] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccessToken returns the access token
|
||||
// GetAccessToken retrieves the complete access token from the session.
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
compressed, _ := sd.accessSession.Values["compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks.
|
||||
if len(sd.accessTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.accessTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
token = strings.Join(chunks, "")
|
||||
compressed, _ := sd.accessSession.Values["compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// SetAccessToken sets the access token
|
||||
// SetAccessToken stores the access token in the session.
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
sd.accessSession.Values["token"] = token
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called.
|
||||
}
|
||||
|
||||
// Clear and prepare chunks map for new token.
|
||||
sd.accessTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
// Compress token.
|
||||
compressed := compressToken(token)
|
||||
|
||||
if len(compressed) <= maxCookieSize {
|
||||
sd.accessSession.Values["token"] = compressed
|
||||
sd.accessSession.Values["compressed"] = true
|
||||
} else {
|
||||
// Split compressed token into chunks.
|
||||
sd.accessSession.Values["token"] = ""
|
||||
sd.accessSession.Values["compressed"] = true
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
sd.accessTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRefreshToken returns the refresh token
|
||||
// GetRefreshToken retrieves the complete refresh token from the session.
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks.
|
||||
if len(sd.refreshTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.refreshTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
token = strings.Join(chunks, "")
|
||||
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
||||
if compressed {
|
||||
return decompressToken(token)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// SetRefreshToken sets the refresh token
|
||||
// SetRefreshToken stores the refresh token in the session.
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
sd.refreshSession.Values["token"] = token
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called.
|
||||
}
|
||||
|
||||
// Clear and prepare chunks map for new token.
|
||||
sd.refreshTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
// Compress token.
|
||||
compressed := compressToken(token)
|
||||
|
||||
if len(compressed) <= maxCookieSize {
|
||||
sd.refreshSession.Values["token"] = compressed
|
||||
sd.refreshSession.Values["compressed"] = true
|
||||
} else {
|
||||
// Split compressed token into chunks.
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
sd.refreshSession.Values["compressed"] = true
|
||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
sd.refreshTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetCSRF returns the CSRF token
|
||||
// expireAccessTokenChunks expires any existing access token chunk cookies.
|
||||
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
if w != nil {
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("failed to save expired access token cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
|
||||
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
break
|
||||
}
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
if w != nil {
|
||||
if err := session.Save(sd.request, w); err != nil {
|
||||
sd.manager.logger.Errorf("failed to save expired refresh token cookie: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks splits a string into chunks of specified size.
|
||||
func splitIntoChunks(s string, chunkSize int) []string {
|
||||
var chunks []string
|
||||
for len(s) > 0 {
|
||||
if len(s) > chunkSize {
|
||||
chunks = append(chunks, s[:chunkSize])
|
||||
s = s[chunkSize:]
|
||||
} else {
|
||||
chunks = append(chunks, s)
|
||||
break
|
||||
}
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// GetCSRF retrieves the CSRF token from the session.
|
||||
func (sd *SessionData) GetCSRF() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// SetCSRF sets the CSRF token
|
||||
// SetCSRF stores a new CSRF token in the session.
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
}
|
||||
|
||||
// GetNonce returns the nonce
|
||||
// GetNonce retrieves the nonce value from the session.
|
||||
func (sd *SessionData) GetNonce() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// SetNonce sets the nonce
|
||||
// SetNonce stores a new nonce value in the session.
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetEmail returns the user's email
|
||||
// GetEmail retrieves the authenticated user's email address from the session.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// SetEmail sets the user's email
|
||||
// SetEmail stores the user's email address in the session.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
}
|
||||
|
||||
// GetIncomingPath returns the original incoming path
|
||||
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
|
||||
func (sd *SessionData) GetIncomingPath() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// SetIncomingPath sets the original incoming path
|
||||
// SetIncomingPath stores the original request path that triggered the authentication flow.
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
|
||||
+345
-23
@@ -1,60 +1,382 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSessionManager(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
manager := NewSessionManager("test-secret-key", false, logger)
|
||||
// generateRandomString creates a random string of specified length
|
||||
func generateRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[rand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// TestTokenCompression tests the token compression functionality
|
||||
func TestTokenCompression(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantSize int // Expected size after compression (approximate)
|
||||
}{
|
||||
{
|
||||
name: "Short token",
|
||||
token: "shorttoken",
|
||||
wantSize: 50, // Base64 encoded gzip has overhead for small content
|
||||
},
|
||||
{
|
||||
name: "Repeating content",
|
||||
token: strings.Repeat("abcdef", 1000),
|
||||
wantSize: 100, // Should compress well due to repetition
|
||||
},
|
||||
{
|
||||
name: "Random content",
|
||||
token: generateRandomString(1000),
|
||||
wantSize: 2000, // Random content won't compress much
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
compressed := compressToken(tt.token)
|
||||
decompressed := decompressToken(compressed)
|
||||
|
||||
// Only verify compression ratio for non-short tokens
|
||||
if len(tt.token) > 100 {
|
||||
compressionRatio := float64(len(compressed)) / float64(len(tt.token))
|
||||
t.Logf("Compression ratio for %s: %.2f", tt.name, compressionRatio)
|
||||
|
||||
if compressionRatio > 1.1 { // Allow up to 10% size increase
|
||||
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
|
||||
len(tt.token), len(compressed), compressionRatio)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify decompression restores original
|
||||
if decompressed != tt.token {
|
||||
t.Error("Decompression failed to restore original token")
|
||||
}
|
||||
|
||||
// Verify approximate compression ratio
|
||||
if len(compressed) > tt.wantSize*2 {
|
||||
t.Errorf("Compression ratio worse than expected: got=%d, want<%d", len(compressed), tt.wantSize*2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionManager tests the SessionManager functionality
|
||||
|
||||
func TestCookiePrefix(t *testing.T) {
|
||||
// Create a session and verify cookie names
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
session, err := manager.GetSession(req)
|
||||
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Test setting and getting values
|
||||
// Set some data to ensure cookies are created
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetAccessToken("test.access.token")
|
||||
session.SetRefreshToken("test.refresh.token")
|
||||
|
||||
// Expire any existing cookies
|
||||
session.expireAccessTokenChunks(rr)
|
||||
session.expireRefreshTokenChunks(rr)
|
||||
|
||||
// Set new tokens
|
||||
session.SetAccessToken("test_token")
|
||||
session.SetRefreshToken("test_refresh_token")
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies are set
|
||||
// Check cookie prefixes
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) != 3 {
|
||||
t.Errorf("Expected 3 cookies, got %d", len(cookies))
|
||||
for _, cookie := range cookies {
|
||||
if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") {
|
||||
t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRefreshCleanup(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Create a new request with the cookies
|
||||
// Set a large token that will be split into chunks
|
||||
largeToken := strings.Repeat("x", 5000)
|
||||
session.SetAccessToken(largeToken)
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get initial cookies
|
||||
initialCookies := rr.Result().Cookies()
|
||||
|
||||
// Create a new request with the initial cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
for _, cookie := range initialCookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
newRr := httptest.NewRecorder()
|
||||
|
||||
// Get the session again and verify values
|
||||
newSession, err := manager.GetSession(newReq)
|
||||
// Get session with cookies and set a new token
|
||||
newSession, err := sm.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
if !newSession.GetAuthenticated() {
|
||||
t.Error("Authentication status not preserved")
|
||||
// Create a response recorder for expired cookies
|
||||
expiredRr := httptest.NewRecorder()
|
||||
|
||||
// Expire old chunk cookies
|
||||
newSession.expireAccessTokenChunks(expiredRr)
|
||||
|
||||
// Set a smaller token that won't need chunks
|
||||
newSession.SetAccessToken("small_token")
|
||||
|
||||
// Save session with new token
|
||||
if err := newSession.Save(newReq, newRr); err != nil {
|
||||
t.Fatalf("Failed to save new session: %v", err)
|
||||
}
|
||||
if email := newSession.GetEmail(); email != "test@example.com" {
|
||||
t.Errorf("Expected email test@example.com, got %s", email)
|
||||
|
||||
// Check cookies in response where old cookies are expired
|
||||
intermediateResponse := expiredRr.Result()
|
||||
intermediateCount := 0
|
||||
chunkCount := 0
|
||||
expiredCount := 0
|
||||
|
||||
for _, cookie := range intermediateResponse.Cookies() {
|
||||
if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 {
|
||||
chunkCount++
|
||||
if cookie.MaxAge < 0 {
|
||||
expiredCount++
|
||||
t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
} else if cookie.MaxAge >= 0 {
|
||||
intermediateCount++
|
||||
t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
|
||||
}
|
||||
}
|
||||
if token := newSession.GetAccessToken(); token != "test.access.token" {
|
||||
t.Errorf("Expected access token test.access.token, got %s", token)
|
||||
|
||||
// All chunk cookies should be expired
|
||||
if chunkCount > 0 && chunkCount != expiredCount {
|
||||
t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount)
|
||||
}
|
||||
if token := newSession.GetRefreshToken(); token != "test.refresh.token" {
|
||||
t.Errorf("Expected refresh token test.refresh.token, got %s", token)
|
||||
|
||||
// Should have fewer active cookies after setting smaller token
|
||||
if intermediateCount >= len(initialCookies) {
|
||||
t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManager(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authenticated bool
|
||||
email string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
expectedCookieCount int
|
||||
wantCompressed bool // Whether tokens should be compressed
|
||||
}{
|
||||
{
|
||||
name: "Short tokens",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: "shortaccesstoken",
|
||||
refreshToken: "shortrefreshtoken",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
wantCompressed: true,
|
||||
},
|
||||
{
|
||||
name: "Long tokens exceeding 4096 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 5000),
|
||||
refreshToken: strings.Repeat("y", 6000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
|
||||
wantCompressed: true,
|
||||
},
|
||||
{
|
||||
name: "REALLY long tokens, exceeding 25000 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 25000),
|
||||
refreshToken: strings.Repeat("y", 25000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
|
||||
wantCompressed: true,
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated session",
|
||||
authenticated: false,
|
||||
email: "",
|
||||
accessToken: "",
|
||||
refreshToken: "",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
wantCompressed: false,
|
||||
},
|
||||
{
|
||||
name: "Random content tokens",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: generateRandomString(5000),
|
||||
refreshToken: generateRandomString(5000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)),
|
||||
wantCompressed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc // Capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
session, err := ts.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set session values
|
||||
session.SetAuthenticated(tc.authenticated)
|
||||
session.SetEmail(tc.email)
|
||||
|
||||
// Expire any existing cookies
|
||||
session.expireAccessTokenChunks(rr)
|
||||
session.expireRefreshTokenChunks(rr)
|
||||
|
||||
// Set new tokens
|
||||
session.SetAccessToken(tc.accessToken)
|
||||
session.SetRefreshToken(tc.refreshToken)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies are set and compression is used when appropriate
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) != tc.expectedCookieCount {
|
||||
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
|
||||
}
|
||||
|
||||
// Verify compression is working by checking token sizes
|
||||
for _, cookie := range cookies {
|
||||
if strings.Contains(cookie.Name, accessTokenCookie) {
|
||||
// Get original and stored sizes
|
||||
originalSize := len(tc.accessToken)
|
||||
storedSize := len(cookie.Value)
|
||||
|
||||
if originalSize > 100 && tc.wantCompressed {
|
||||
// For large tokens, verify some compression occurred
|
||||
compressionRatio := float64(storedSize) / float64(originalSize)
|
||||
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
|
||||
compressionRatio, originalSize, storedSize)
|
||||
|
||||
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
|
||||
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
|
||||
cookie.Name, compressionRatio)
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
|
||||
originalSize := len(tc.refreshToken)
|
||||
storedSize := len(cookie.Value)
|
||||
|
||||
if originalSize > 100 && tc.wantCompressed {
|
||||
compressionRatio := float64(storedSize) / float64(originalSize)
|
||||
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
|
||||
compressionRatio, originalSize, storedSize)
|
||||
|
||||
if compressionRatio > 0.9 {
|
||||
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
|
||||
cookie.Name, compressionRatio)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new request with the cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get the session again and verify values
|
||||
newSession, err := ts.sessionManager.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
// Verify session values
|
||||
if newSession.GetAuthenticated() != tc.authenticated {
|
||||
t.Errorf("Authentication status not preserved")
|
||||
}
|
||||
if email := newSession.GetEmail(); email != tc.email {
|
||||
t.Errorf("Expected email %s, got %s", tc.email, email)
|
||||
}
|
||||
if token := newSession.GetAccessToken(); token != tc.accessToken {
|
||||
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
|
||||
}
|
||||
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
|
||||
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
|
||||
}
|
||||
|
||||
// Verify session pooling by checking if the session is reused
|
||||
session2, _ := ts.sessionManager.GetSession(newReq)
|
||||
if session2 == newSession {
|
||||
t.Error("Session not properly pooled")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
|
||||
count := 3 // main, access, refresh
|
||||
|
||||
// Helper to calculate chunks for compressed token
|
||||
calculateChunks := func(token string) int {
|
||||
// Compress token (matching the actual implementation)
|
||||
compressed := compressToken(token)
|
||||
|
||||
// If compressed token fits in one cookie, no additional chunks needed
|
||||
if len(compressed) <= maxCookieSize {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Calculate chunks needed for compressed token
|
||||
return len(splitIntoChunks(compressed, maxCookieSize))
|
||||
}
|
||||
|
||||
// Add chunks for access token if needed
|
||||
accessChunks := calculateChunks(accessToken)
|
||||
if accessChunks > 0 {
|
||||
count += accessChunks
|
||||
}
|
||||
|
||||
// Add chunks for refresh token if needed
|
||||
refreshChunks := calculateChunks(refreshToken)
|
||||
if refreshChunks > 0 {
|
||||
count += refreshChunks
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
+206
-50
@@ -5,93 +5,235 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
cookieName = "_raczylo_oidc"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the OIDC middleware
|
||||
// Config holds the configuration for the OIDC middleware.
|
||||
// It provides all necessary settings to configure OpenID Connect authentication
|
||||
// with various providers like Auth0, Logto, or any standard OIDC provider.
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
Scopes []string `json:"scopes"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
// ProviderURL is the base URL of the OIDC provider (required)
|
||||
// Example: https://accounts.google.com
|
||||
ProviderURL string `json:"providerURL"`
|
||||
|
||||
// RevocationURL is the endpoint for revoking tokens (optional)
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
|
||||
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
|
||||
// Example: /oauth2/callback
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
|
||||
// LogoutURL is the path for handling logout requests (optional)
|
||||
// If not provided, it will be set to CallbackURL + "/logout"
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
|
||||
// ClientID is the OAuth 2.0 client identifier (required)
|
||||
ClientID string `json:"clientID"`
|
||||
|
||||
// ClientSecret is the OAuth 2.0 client secret (required)
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
|
||||
// Scopes defines the OAuth 2.0 scopes to request (optional)
|
||||
// Defaults to ["openid", "profile", "email"] if not provided
|
||||
Scopes []string `json:"scopes"`
|
||||
|
||||
// LogLevel sets the logging verbosity (optional)
|
||||
// Valid values: "debug", "info", "error"
|
||||
// Default: "info"
|
||||
LogLevel string `json:"logLevel"`
|
||||
|
||||
// SessionEncryptionKey is used to encrypt session data (required)
|
||||
// Must be a secure random string
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
|
||||
// ForceHTTPS forces the use of HTTPS for all URLs (optional)
|
||||
// Default: false
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
|
||||
// RateLimit sets the maximum number of requests per second (optional)
|
||||
// Default: 100
|
||||
RateLimit int `json:"rateLimit"`
|
||||
|
||||
// ExcludedURLs lists paths that bypass authentication (optional)
|
||||
// Example: ["/health", "/metrics"]
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
|
||||
// AllowedUserDomains restricts access to specific email domains (optional)
|
||||
// Example: ["company.com", "subsidiary.com"]
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
|
||||
// AllowedRolesAndGroups restricts access to users with specific roles or groups (optional)
|
||||
// Example: ["admin", "developer"]
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
HTTPClient *http.Client
|
||||
|
||||
// OIDCEndSessionURL is the provider's end session endpoint (optional)
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
|
||||
// PostLogoutRedirectURI is the URL to redirect to after logout (optional)
|
||||
// Default: "/"
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
|
||||
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// CreateConfig creates a new Config with default values
|
||||
const (
|
||||
// DefaultRateLimit defines the default rate limit for requests per second
|
||||
DefaultRateLimit = 100
|
||||
|
||||
// MinRateLimit defines the minimum allowed rate limit to prevent DOS
|
||||
MinRateLimit = 10
|
||||
|
||||
// DefaultLogLevel defines the default logging level
|
||||
DefaultLogLevel = "info"
|
||||
|
||||
// MinSessionEncryptionKeyLength defines the minimum length for session encryption key
|
||||
MinSessionEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// CreateConfig creates a new Config with secure default values.
|
||||
// Default values are set for optional fields:
|
||||
// - Scopes: ["openid", "profile", "email"]
|
||||
// - LogLevel: "info"
|
||||
// - LogoutURL: CallbackURL + "/logout"
|
||||
// - RateLimit: 100 requests per second
|
||||
// - PostLogoutRedirectURI: "/"
|
||||
// - ForceHTTPS: true (for security)
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{}
|
||||
|
||||
if c.Scopes == nil {
|
||||
c.Scopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
|
||||
if c.LogLevel == "" {
|
||||
c.LogLevel = "info"
|
||||
}
|
||||
|
||||
if c.LogoutURL == "" {
|
||||
c.LogoutURL = c.CallbackURL + "/logout"
|
||||
}
|
||||
|
||||
if c.RateLimit == 0 {
|
||||
c.RateLimit = 100
|
||||
c := &Config{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Validate validates the Config
|
||||
// Validate performs validation checks on the Config.
|
||||
// It ensures all required fields are set and have valid values.
|
||||
// Returns an error if any validation check fails.
|
||||
func (c *Config) Validate() error {
|
||||
// Validate provider URL
|
||||
if c.ProviderURL == "" {
|
||||
return fmt.Errorf("providerURL is required")
|
||||
}
|
||||
if !isValidSecureURL(c.ProviderURL) {
|
||||
return fmt.Errorf("providerURL must be a valid HTTPS URL")
|
||||
}
|
||||
|
||||
// Validate callback URL
|
||||
if c.CallbackURL == "" {
|
||||
return fmt.Errorf("callbackURL is required")
|
||||
}
|
||||
if !strings.HasPrefix(c.CallbackURL, "/") {
|
||||
return fmt.Errorf("callbackURL must start with /")
|
||||
}
|
||||
|
||||
// Validate client credentials
|
||||
if c.ClientID == "" {
|
||||
return fmt.Errorf("clientID is required")
|
||||
}
|
||||
if c.ClientSecret == "" {
|
||||
return fmt.Errorf("clientSecret is required")
|
||||
}
|
||||
|
||||
// Validate session encryption key
|
||||
if c.SessionEncryptionKey == "" {
|
||||
return fmt.Errorf("sessionEncryptionKey is required")
|
||||
}
|
||||
if len(c.SessionEncryptionKey) < MinSessionEncryptionKeyLength {
|
||||
return fmt.Errorf("sessionEncryptionKey must be at least %d characters long", MinSessionEncryptionKeyLength)
|
||||
}
|
||||
|
||||
// Validate log level
|
||||
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
|
||||
return fmt.Errorf("logLevel must be one of: debug, info, error")
|
||||
}
|
||||
|
||||
// Validate excluded URLs
|
||||
for _, url := range c.ExcludedURLs {
|
||||
if !strings.HasPrefix(url, "/") {
|
||||
return fmt.Errorf("excluded URL must start with /: %s", url)
|
||||
}
|
||||
if strings.Contains(url, "..") {
|
||||
return fmt.Errorf("excluded URL must not contain path traversal: %s", url)
|
||||
}
|
||||
if strings.Contains(url, "*") {
|
||||
return fmt.Errorf("excluded URL must not contain wildcards: %s", url)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate revocation URL if set
|
||||
if c.RevocationURL != "" && !isValidSecureURL(c.RevocationURL) {
|
||||
return fmt.Errorf("revocationURL must be a valid HTTPS URL")
|
||||
}
|
||||
|
||||
// Validate end session URL if set
|
||||
if c.OIDCEndSessionURL != "" && !isValidSecureURL(c.OIDCEndSessionURL) {
|
||||
return fmt.Errorf("oidcEndSessionURL must be a valid HTTPS URL")
|
||||
}
|
||||
|
||||
// Validate post-logout redirect URI if set
|
||||
if c.PostLogoutRedirectURI != "" && c.PostLogoutRedirectURI != "/" {
|
||||
if !isValidSecureURL(c.PostLogoutRedirectURI) && !strings.HasPrefix(c.PostLogoutRedirectURI, "/") {
|
||||
return fmt.Errorf("postLogoutRedirectURI must be either a valid HTTPS URL or start with /")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate rate limit
|
||||
if c.RateLimit < MinRateLimit {
|
||||
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger is a simple logger with different levels
|
||||
// isValidSecureURL checks if the provided string is a valid HTTPS URL
|
||||
func isValidSecureURL(s string) bool {
|
||||
u, err := url.Parse(s)
|
||||
return err == nil && u.Scheme == "https" && u.Host != ""
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the provided log level is valid
|
||||
func isValidLogLevel(level string) bool {
|
||||
return level == "debug" || level == "info" || level == "error"
|
||||
}
|
||||
|
||||
// Logger provides structured logging capabilities with different severity levels.
|
||||
// It supports error, info, and debug levels with appropriate output streams
|
||||
// and formatting for each level.
|
||||
type Logger struct {
|
||||
// logError handles error-level messages, writing to stderr
|
||||
logError *log.Logger
|
||||
logInfo *log.Logger
|
||||
// logInfo handles informational messages, writing to stdout
|
||||
logInfo *log.Logger
|
||||
// logDebug handles debug-level messages, writing to stdout when debug is enabled
|
||||
logDebug *log.Logger
|
||||
}
|
||||
|
||||
// NewLogger creates a new Logger
|
||||
// NewLogger creates a new Logger with the specified log level.
|
||||
// The log level determines which messages are output:
|
||||
// - "debug": Outputs all messages (debug, info, error)
|
||||
// - "info": Outputs info and error messages
|
||||
// - "error": Outputs only error messages
|
||||
//
|
||||
// Error messages are always written to stderr, while info and debug
|
||||
// messages are written to stdout when enabled.
|
||||
func NewLogger(logLevel string) *Logger {
|
||||
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
|
||||
logError.SetOutput(os.Stderr)
|
||||
logInfo.SetOutput(os.Stdout)
|
||||
|
||||
if logLevel == "debug" || logLevel == "info" {
|
||||
logInfo.SetOutput(os.Stdout)
|
||||
}
|
||||
if logLevel == "debug" {
|
||||
logDebug.SetOutput(os.Stdout)
|
||||
}
|
||||
@@ -103,37 +245,51 @@ func NewLogger(logLevel string) *Logger {
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
// Info logs an informational message.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
// Debug logs a debug message.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
// Error logs an error message.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs an info message
|
||||
// Infof logs an informational message using Printf formatting.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debugf logs a debug message
|
||||
// Debugf logs a debug message using Printf formatting.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs an error message
|
||||
// Errorf logs an error message using Printf formatting.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// handleError writes an error message to the response and logs it
|
||||
// handleError writes an error message to both the HTTP response and the error log.
|
||||
// It ensures consistent error handling across the middleware by logging the error
|
||||
// and sending an appropriate HTTP response to the client.
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Error(message)
|
||||
http.Error(w, message, code)
|
||||
|
||||
@@ -0,0 +1,397 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateConfig(t *testing.T) {
|
||||
t.Run("Default Values", func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
|
||||
// Check default scopes
|
||||
expectedScopes := []string{"openid", "profile", "email"}
|
||||
if len(config.Scopes) != len(expectedScopes) {
|
||||
t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes))
|
||||
}
|
||||
for i, scope := range expectedScopes {
|
||||
if config.Scopes[i] != scope {
|
||||
t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Check default log level
|
||||
if config.LogLevel != DefaultLogLevel {
|
||||
t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel)
|
||||
}
|
||||
|
||||
// Check default rate limit
|
||||
if config.RateLimit != DefaultRateLimit {
|
||||
t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit)
|
||||
}
|
||||
|
||||
// Check ForceHTTPS default
|
||||
if !config.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to be true by default")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Custom Values Preserved", func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.Scopes = []string{"custom_scope"}
|
||||
config.LogLevel = "debug"
|
||||
config.RateLimit = 50
|
||||
config.ForceHTTPS = false
|
||||
|
||||
// Verify custom values are not overwritten
|
||||
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
|
||||
t.Error("Custom scopes were overwritten")
|
||||
}
|
||||
if config.LogLevel != "debug" {
|
||||
t.Error("Custom log level was overwritten")
|
||||
}
|
||||
if config.RateLimit != 50 {
|
||||
t.Error("Custom rate limit was overwritten")
|
||||
}
|
||||
if config.ForceHTTPS {
|
||||
t.Error("Custom ForceHTTPS value was overwritten")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Empty Config",
|
||||
config: &Config{},
|
||||
expectedError: "providerURL is required",
|
||||
},
|
||||
{
|
||||
name: "Missing CallbackURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
},
|
||||
expectedError: "callbackURL is required",
|
||||
},
|
||||
{
|
||||
name: "Missing ClientID",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
},
|
||||
expectedError: "clientID is required",
|
||||
},
|
||||
{
|
||||
name: "Missing ClientSecret",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedError: "clientSecret is required",
|
||||
},
|
||||
{
|
||||
name: "Missing SessionEncryptionKey",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
},
|
||||
expectedError: "sessionEncryptionKey is required",
|
||||
},
|
||||
{
|
||||
name: "Non-HTTPS ProviderURL",
|
||||
config: &Config{
|
||||
ProviderURL: "http://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "encryption-key",
|
||||
},
|
||||
expectedError: "providerURL must be a valid HTTPS URL",
|
||||
},
|
||||
{
|
||||
name: "Invalid CallbackURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "callback", // Missing leading slash
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "encryption-key",
|
||||
},
|
||||
expectedError: "callbackURL must start with /",
|
||||
},
|
||||
{
|
||||
name: "Short SessionEncryptionKey",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "short",
|
||||
},
|
||||
expectedError: "sessionEncryptionKey must be at least 32 characters long",
|
||||
},
|
||||
{
|
||||
name: "Low RateLimit",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
RateLimit: 5,
|
||||
},
|
||||
expectedError: "rateLimit must be at least 10",
|
||||
},
|
||||
{
|
||||
name: "Invalid LogLevel",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
LogLevel: "invalid",
|
||||
},
|
||||
expectedError: "logLevel must be one of: debug, info, error",
|
||||
},
|
||||
{
|
||||
name: "Non-HTTPS RevocationURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
RevocationURL: "http://revoke.com",
|
||||
},
|
||||
expectedError: "revocationURL must be a valid HTTPS URL",
|
||||
},
|
||||
{
|
||||
name: "Non-HTTPS OIDCEndSessionURL",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
OIDCEndSessionURL: "http://endsession.com",
|
||||
},
|
||||
expectedError: "oidcEndSessionURL must be a valid HTTPS URL",
|
||||
},
|
||||
{
|
||||
name: "Valid Config",
|
||||
config: &Config{
|
||||
ProviderURL: "https://provider.com",
|
||||
CallbackURL: "/callback",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
|
||||
LogLevel: "debug",
|
||||
RateLimit: 100,
|
||||
RevocationURL: "https://revoke.com",
|
||||
OIDCEndSessionURL: "https://endsession.com",
|
||||
},
|
||||
expectedError: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.config.Validate()
|
||||
if tc.expectedError == "" {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
|
||||
} else if err.Error() != tc.expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
// Capture log output
|
||||
var debugBuf, infoBuf, errorBuf bytes.Buffer
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logLevel string
|
||||
testFunc func(*Logger)
|
||||
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
|
||||
}{
|
||||
{
|
||||
name: "Debug Level",
|
||||
logLevel: "debug",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debug("debug message")
|
||||
l.Info("info message")
|
||||
l.Error("error message")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if debugOut == "" {
|
||||
t.Error("Expected debug message in output")
|
||||
}
|
||||
if infoOut == "" {
|
||||
t.Error("Expected info message in output")
|
||||
}
|
||||
if errorOut == "" {
|
||||
t.Error("Expected error message in output")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Info Level",
|
||||
logLevel: "info",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debug("debug message")
|
||||
l.Info("info message")
|
||||
l.Error("error message")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if debugOut != "" {
|
||||
t.Error("Did not expect debug message in output")
|
||||
}
|
||||
if infoOut == "" {
|
||||
t.Error("Expected info message in output")
|
||||
}
|
||||
if errorOut == "" {
|
||||
t.Error("Expected error message in output")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error Level",
|
||||
logLevel: "error",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debug("debug message")
|
||||
l.Info("info message")
|
||||
l.Error("error message")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if debugOut != "" {
|
||||
t.Error("Did not expect debug message in output")
|
||||
}
|
||||
if infoOut != "" {
|
||||
t.Error("Did not expect info message in output")
|
||||
}
|
||||
if errorOut == "" {
|
||||
t.Error("Expected error message in output")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Printf Methods",
|
||||
logLevel: "debug",
|
||||
testFunc: func(l *Logger) {
|
||||
l.Debugf("debug %s", "formatted")
|
||||
l.Infof("info %s", "formatted")
|
||||
l.Errorf("error %s", "formatted")
|
||||
},
|
||||
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
|
||||
if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) {
|
||||
t.Error("Expected formatted debug message")
|
||||
}
|
||||
if !bytes.Contains([]byte(infoOut), []byte("info formatted")) {
|
||||
t.Error("Expected formatted info message")
|
||||
}
|
||||
if !bytes.Contains([]byte(errorOut), []byte("error formatted")) {
|
||||
t.Error("Expected formatted error message")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Reset buffers
|
||||
debugBuf.Reset()
|
||||
infoBuf.Reset()
|
||||
errorBuf.Reset()
|
||||
|
||||
// Create logger with test buffers
|
||||
logger := NewLogger(tc.logLevel)
|
||||
logger.logError.SetOutput(&errorBuf)
|
||||
|
||||
if tc.logLevel == "debug" || tc.logLevel == "info" {
|
||||
logger.logInfo.SetOutput(&infoBuf)
|
||||
}
|
||||
if tc.logLevel == "debug" {
|
||||
logger.logDebug.SetOutput(&debugBuf)
|
||||
}
|
||||
|
||||
// Run test
|
||||
tc.testFunc(logger)
|
||||
|
||||
// Check results
|
||||
tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleError(t *testing.T) {
|
||||
// Create a test logger with captured output
|
||||
var errorBuf bytes.Buffer
|
||||
logger := &Logger{
|
||||
logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime),
|
||||
}
|
||||
logger.logError.SetOutput(&errorBuf)
|
||||
|
||||
// Create a test response recorder
|
||||
rr := &testResponseRecorder{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
|
||||
// Test error handling
|
||||
message := "test error message"
|
||||
code := 400
|
||||
handleError(rr, message, code, logger)
|
||||
|
||||
// Check response code
|
||||
if rr.statusCode != code {
|
||||
t.Errorf("Expected status code %d, got %d", code, rr.statusCode)
|
||||
}
|
||||
|
||||
// Check response body
|
||||
expectedBody := message + "\n"
|
||||
if rr.body != expectedBody {
|
||||
t.Errorf("Expected body %q, got %q", expectedBody, rr.body)
|
||||
}
|
||||
|
||||
// Check error was logged
|
||||
if !bytes.Contains(errorBuf.Bytes(), []byte(message)) {
|
||||
t.Error("Error message was not logged")
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper types
|
||||
type testResponseRecorder struct {
|
||||
statusCode int
|
||||
body string
|
||||
headers map[string][]string
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) Header() http.Header {
|
||||
return r.headers
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) Write(b []byte) (int, error) {
|
||||
r.body = string(b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) WriteHeader(code int) {
|
||||
r.statusCode = code
|
||||
}
|
||||
Reference in New Issue
Block a user