mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4ce2815123 | |||
| 7d204113ea | |||
| c721913cbe | |||
| 0f8b7f7ab1 | |||
| 2743b0e024 | |||
| e6fc36937b | |||
| df051e0cfb | |||
| 5d5ce8ae5e | |||
| d194cd778a | |||
| 803a1e5e21 | |||
| 3ad8fb4518 | |||
| 9402f1bca5 | |||
| e6205b3a48 | |||
| fdb8e3233e |
+219
-18
@@ -4,28 +4,229 @@ type: middleware
|
||||
import: github.com/lukaszraczylo/traefikoidc
|
||||
|
||||
summary: |
|
||||
Middleware adding OIDC authentication to traefik routes. Does what it says on the tin.
|
||||
Middleware has been tested with Auth0 and Logto. It should work with any OIDC provider.
|
||||
Middleware adding OpenID Connect (OIDC) authentication to Traefik routes.
|
||||
|
||||
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy.
|
||||
It provides a complete OIDC authentication solution with features like domain restrictions,
|
||||
role-based access control, token caching, and more.
|
||||
|
||||
The middleware has been tested with Auth0, Logto, Google, and other standard OIDC providers.
|
||||
It supports various authentication scenarios including:
|
||||
|
||||
- Basic authentication with customizable callback and logout URLs
|
||||
- Email domain restrictions to limit access to specific organizations
|
||||
- Role and group-based access control
|
||||
- Public URLs that bypass authentication
|
||||
- Rate limiting to prevent brute force attacks
|
||||
- Custom post-logout redirect behavior
|
||||
- Secure session management with encrypted cookies
|
||||
- Automatic token validation and refresh
|
||||
|
||||
testData:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
# Required parameters
|
||||
providerURL: https://accounts.google.com # Base URL of the OIDC provider
|
||||
clientID: 1234567890.apps.googleusercontent.com # OAuth 2.0 client identifier
|
||||
clientSecret: secret # OAuth 2.0 client secret
|
||||
callbackURL: /oauth2/callback # Path where the OIDC provider will redirect after authentication
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long # Key used to encrypt session data (must be at least 32 bytes)
|
||||
|
||||
# Optional parameters with defaults
|
||||
logoutURL: /oauth2/logout # Path for handling logout requests (if not provided, it will be set to callbackURL + "/logout")
|
||||
postLogoutRedirectURI: /oidc/different-logout # URL to redirect to after logout (default: "/")
|
||||
|
||||
scopes: # OAuth 2.0 scopes to request (default: ["openid", "email", "profile"])
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
|
||||
- raczylo.com
|
||||
allowedRolesAndGroups:
|
||||
- roles # Include this to get role information from the provider
|
||||
|
||||
allowedUserDomains: # Restricts access to specific email domains (if not provided, relies on OIDC provider)
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
|
||||
allowedRolesAndGroups: # Restricts access to users with specific roles or groups (if not provided, no role/group restrictions)
|
||||
- guest-endpoints
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
forceHTTPS: false
|
||||
logLevel: debug # debug, info, warn, error
|
||||
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- admin
|
||||
- developer
|
||||
|
||||
forceHTTPS: false # Forces the use of HTTPS for all URLs (default: true for security)
|
||||
logLevel: debug # Sets logging verbosity: debug, info, error (default: info)
|
||||
rateLimit: 100 # Maximum number of requests per second (default: 100, minimum: 10)
|
||||
|
||||
excludedURLs: # Lists paths that bypass authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
- /public
|
||||
- /health
|
||||
- /metrics
|
||||
|
||||
# Advanced parameters (usually discovered automatically from provider metadata)
|
||||
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
|
||||
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
||||
|
||||
# Configuration documentation
|
||||
configuration:
|
||||
providerURL:
|
||||
type: string
|
||||
description: |
|
||||
The base URL of the OIDC provider. This is the issuer URL that will be used to discover
|
||||
OIDC endpoints like authorization, token, and JWKS URIs.
|
||||
|
||||
Examples:
|
||||
- https://accounts.google.com
|
||||
- https://login.microsoftonline.com/tenant-id/v2.0
|
||||
- https://your-auth0-domain.auth0.com
|
||||
- https://your-logto-instance.com/oidc
|
||||
required: true
|
||||
|
||||
clientID:
|
||||
type: string
|
||||
description: |
|
||||
The OAuth 2.0 client identifier obtained from your OIDC provider.
|
||||
This is the public identifier for your application.
|
||||
required: true
|
||||
|
||||
clientSecret:
|
||||
type: string
|
||||
description: |
|
||||
The OAuth 2.0 client secret obtained from your OIDC provider.
|
||||
This should be kept confidential and not exposed in client-side code.
|
||||
|
||||
For Kubernetes deployments, you can use the secret reference format:
|
||||
urn:k8s:secret:namespace:secret-name:key
|
||||
required: true
|
||||
|
||||
callbackURL:
|
||||
type: string
|
||||
description: |
|
||||
The path where the OIDC provider will redirect after authentication.
|
||||
This must match one of the redirect URIs configured in your OIDC provider.
|
||||
|
||||
The full redirect URI will be constructed as:
|
||||
[scheme]://[host][callbackURL]
|
||||
|
||||
Example: /oauth2/callback
|
||||
required: true
|
||||
|
||||
sessionEncryptionKey:
|
||||
type: string
|
||||
description: |
|
||||
Key used to encrypt session data stored in cookies.
|
||||
Must be at least 32 bytes long for security.
|
||||
|
||||
Example: potato-secret-is-at-least-32-bytes-long
|
||||
required: true
|
||||
|
||||
logoutURL:
|
||||
type: string
|
||||
description: |
|
||||
The path for handling logout requests.
|
||||
If not provided, it will be set to callbackURL + "/logout".
|
||||
|
||||
Example: /oauth2/logout
|
||||
required: false
|
||||
|
||||
postLogoutRedirectURI:
|
||||
type: string
|
||||
description: |
|
||||
The URL to redirect to after logout.
|
||||
Default: "/"
|
||||
|
||||
Example: /logged-out-page
|
||||
required: false
|
||||
|
||||
scopes:
|
||||
type: array
|
||||
description: |
|
||||
The OAuth 2.0 scopes to request from the OIDC provider.
|
||||
Default: ["openid", "profile", "email"]
|
||||
|
||||
Include "roles" or similar scope if you need role/group information.
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
logLevel:
|
||||
type: string
|
||||
description: |
|
||||
Sets the logging verbosity.
|
||||
Valid values: "debug", "info", "error"
|
||||
Default: "info"
|
||||
required: false
|
||||
enum:
|
||||
- debug
|
||||
- info
|
||||
- error
|
||||
|
||||
forceHTTPS:
|
||||
type: boolean
|
||||
description: |
|
||||
Forces the use of HTTPS for all URLs.
|
||||
This is recommended for security in production environments.
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
rateLimit:
|
||||
type: integer
|
||||
description: |
|
||||
Sets the maximum number of requests per second.
|
||||
This helps prevent brute force attacks.
|
||||
Default: 100
|
||||
Minimum: 10
|
||||
required: false
|
||||
|
||||
excludedURLs:
|
||||
type: array
|
||||
description: |
|
||||
Lists paths that bypass authentication.
|
||||
These paths will be accessible without OIDC authentication.
|
||||
|
||||
The middleware uses prefix matching, so "/public" will match
|
||||
"/public", "/public/page", "/public-data", etc.
|
||||
|
||||
Examples: ["/health", "/metrics", "/public"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
allowedUserDomains:
|
||||
type: array
|
||||
description: |
|
||||
Restricts access to users with email addresses from specific domains.
|
||||
If not provided, the middleware relies entirely on the OIDC provider
|
||||
for authentication decisions.
|
||||
|
||||
Examples: ["company.com", "subsidiary.com"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
allowedRolesAndGroups:
|
||||
type: array
|
||||
description: |
|
||||
Restricts access to users with specific roles or groups.
|
||||
If not provided, no role/group restrictions are applied.
|
||||
|
||||
The middleware checks both the "roles" and "groups" claims in the ID token.
|
||||
|
||||
Examples: ["admin", "developer"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
revocationURL:
|
||||
type: string
|
||||
description: |
|
||||
The endpoint for revoking tokens.
|
||||
If not provided, it will be discovered from provider metadata.
|
||||
|
||||
Example: https://accounts.google.com/revoke
|
||||
required: false
|
||||
|
||||
oidcEndSessionURL:
|
||||
type: string
|
||||
description: |
|
||||
The provider's end session endpoint.
|
||||
If not provided, it will be discovered from provider metadata.
|
||||
|
||||
Example: https://accounts.google.com/logout
|
||||
required: false
|
||||
|
||||
@@ -1,153 +1,283 @@
|
||||
## Traefik OIDC middleware
|
||||
# Traefik OIDC Middleware
|
||||
|
||||
This middleware is supposed to replace the need for the forward-auth and oauth2-proxy when using traefik as a reverse proxy to support the OIDC authentication.
|
||||
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy to support OpenID Connect (OIDC) authentication.
|
||||
|
||||
Middleware has been tested with Auth0 and Logto.
|
||||
## Overview
|
||||
|
||||
### Traefik version compatibility
|
||||
The Traefik OIDC middleware provides a complete OIDC authentication solution with features like:
|
||||
- Token validation and verification
|
||||
- Session management
|
||||
- Domain restrictions
|
||||
- Role-based access control
|
||||
- Token caching and blacklisting
|
||||
- Rate limiting
|
||||
- Excluded paths (public URLs)
|
||||
|
||||
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
|
||||
The middleware has been tested with Auth0 and Logto, but should work with any standard OIDC provider.
|
||||
|
||||
### Configuration options
|
||||
## Traefik Version Compatibility
|
||||
|
||||
Middleware currently supports following scenarios:
|
||||
This middleware follows closely the current Traefik helm chart versions. If the plugin fails to load, it's time to update to the latest version of the Traefik helm chart.
|
||||
|
||||
* Setting custom callback and logout URLs via `callbackURL` and `logoutURL`
|
||||
* Allowing for access only from the listed domains if `allowedUserDomains` is set, otherwise it relies entirely on the OIDC provider
|
||||
* Using excluded URLs which do **NOT** require the OIDC authentication
|
||||
* Rate limiting requests to prevent the bruteforce attacks
|
||||
## Installation
|
||||
|
||||
#### How to configure...
|
||||
### As a Traefik Plugin
|
||||
|
||||
* `sessionEncryptionKey` should be at least 32 bytes long.
|
||||
|
||||
##### Keeping secrets secret
|
||||
|
||||
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
|
||||
1. Enable the plugin in your Traefik static configuration:
|
||||
|
||||
```yaml
|
||||
# traefik.yml
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.2.1 # Use the latest version
|
||||
```
|
||||
|
||||
2. Configure the middleware in your dynamic configuration (see examples below).
|
||||
|
||||
### Local Development with Docker Compose
|
||||
|
||||
For local development or testing, you can use the provided Docker Compose setup:
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
This will start Traefik with the OIDC middleware and two test services.
|
||||
|
||||
## Configuration Options
|
||||
|
||||
The middleware supports the following configuration options:
|
||||
|
||||
### Required Parameters
|
||||
|
||||
| Parameter | Description | Example |
|
||||
|-----------|-------------|---------|
|
||||
| `providerURL` | The base URL of the OIDC provider | `https://accounts.google.com` |
|
||||
| `clientID` | The OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
|
||||
| `clientSecret` | The OAuth 2.0 client secret | `your-client-secret` |
|
||||
| `sessionEncryptionKey` | Key used to encrypt session data (must be at least 32 bytes long) | `potato-secret-is-at-least-32-bytes-long` |
|
||||
| `callbackURL` | The path where the OIDC provider will redirect after authentication | `/oauth2/callback` |
|
||||
|
||||
### Optional Parameters
|
||||
|
||||
| Parameter | Description | Default | Example |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `logoutURL` | The path for handling logout requests | `callbackURL + "/logout"` | `/oauth2/logout` |
|
||||
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
|
||||
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-basic
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
```
|
||||
|
||||
### With Excluded URLs (Public Access Paths)
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-open-urls
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs:
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /public-data
|
||||
- /health
|
||||
- /metrics
|
||||
```
|
||||
|
||||
### With Email Domain Restrictions
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-domain-restricted
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
```
|
||||
|
||||
### With Role-Based Access Control
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-rbac
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # Include this to get role information from the provider
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
```
|
||||
|
||||
### With Custom Logging and Rate Limiting
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-custom-settings
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
logLevel: debug # Options: debug, info, error (default: info)
|
||||
rateLimit: 500 # Requests per second (default: 100)
|
||||
forceHTTPS: false # Default is true for security
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
```
|
||||
|
||||
### With Custom Post-Logout Redirect
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-custom-logout
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /logged-out-page # Where to redirect after logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
```
|
||||
|
||||
### Keeping Secrets Secret in Kubernetes
|
||||
|
||||
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-secrets
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
|
||||
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
|
||||
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
```
|
||||
|
||||
##### Excluded URLs with open access
|
||||
Don't forget to create the secret:
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-open-urls
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
```bash
|
||||
kubectl create secret generic traefik-middleware-oidc \
|
||||
--from-literal=ISSUER=https://accounts.google.com \
|
||||
--from-literal=CLIENT_ID=1234567890.apps.googleusercontent.com \
|
||||
--from-literal=SECRET=your-client-secret \
|
||||
-n traefik
|
||||
```
|
||||
|
||||
## Complete Docker Compose Example
|
||||
|
||||
##### Allowed email domains
|
||||
|
||||
Assuming that your OIDC provider allows anyone to log in, you may want to limit the access to people using emains in specific domain.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-only-my-users
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /new-oidc/callback
|
||||
logoutURL: /new-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains:
|
||||
- raczylo.com
|
||||
```
|
||||
|
||||
|
||||
##### Allowed groups and roles
|
||||
|
||||
In case of multiple roles / groups and access separation for various endpoints you will need to create multiple traefik middlewares.
|
||||
Following example allows access for users who have additional role `guest-endpoints` assigned.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-guest-endpoints
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: xxx
|
||||
clientID: yyy
|
||||
clientSecret: zzz
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /my-oidc/callback
|
||||
logoutURL: /my-oidc/logout
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
- roles # This line queries the OIDC provider for roles
|
||||
forceHTTPS: true
|
||||
allowedRolesAndGroups:
|
||||
- guest-endpoints # This line specifies the roles or groups allowed to access content
|
||||
allowedUserDomains:
|
||||
- raczylo.com
|
||||
```
|
||||
|
||||
|
||||
#### Docker compose example
|
||||
|
||||
`docker-compose.yaml`
|
||||
Here's a complete example of using the middleware with Docker Compose:
|
||||
|
||||
```yaml
|
||||
version: "3.7"
|
||||
|
||||
services:
|
||||
traefik:
|
||||
image: traefik:v3.0.1
|
||||
image: traefik:v3.2.1
|
||||
command:
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.2.1"
|
||||
@@ -158,7 +288,6 @@ services:
|
||||
labels:
|
||||
- "traefik.http.routers.dash.rule=Host(`dash.localhost`)"
|
||||
- "traefik.http.routers.dash.service=api@internal"
|
||||
|
||||
ports:
|
||||
- "80:80"
|
||||
|
||||
@@ -181,8 +310,7 @@ services:
|
||||
- traefik.http.routers.whoami.middlewares=my-plugin@file
|
||||
```
|
||||
|
||||
`traefik-config/traefik.yaml`
|
||||
|
||||
`traefik-config/traefik.yml`:
|
||||
```yaml
|
||||
log:
|
||||
level: INFO
|
||||
@@ -211,7 +339,7 @@ providers:
|
||||
filename: /etc/traefik/dynamic-configuration.yml
|
||||
```
|
||||
|
||||
`traefik-config/dynamic-configuration.yaml`
|
||||
`traefik-config/dynamic-configuration.yml`:
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
@@ -220,20 +348,78 @@ http:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: secret
|
||||
clientSecret: your-client-secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
postLogoutRedirectURI: /logged-out-page
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
|
||||
- raczylo.com
|
||||
sessionEncryptionKey: potato-secret
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
forceHTTPS: false
|
||||
logLevel: debug # debug, info, warn, error
|
||||
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
logLevel: debug
|
||||
rateLimit: 100
|
||||
excludedURLs:
|
||||
- /login
|
||||
- /public
|
||||
- /health
|
||||
- /metrics
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Session Management
|
||||
|
||||
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
|
||||
|
||||
### Token Caching and Blacklisting
|
||||
|
||||
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
|
||||
|
||||
### Headers Set for Downstream Services
|
||||
|
||||
When a user is authenticated, the middleware sets the following headers for downstream services:
|
||||
|
||||
- `X-Forwarded-User`: The user's email address
|
||||
- `X-User-Groups`: Comma-separated list of user groups (if available)
|
||||
- `X-User-Roles`: Comma-separated list of user roles (if available)
|
||||
- `X-Auth-Request-Redirect`: The original request URI
|
||||
- `X-Auth-Request-User`: The user's email address
|
||||
- `X-Auth-Request-Token`: The user's access token
|
||||
|
||||
### Security Headers
|
||||
|
||||
The middleware also sets the following security headers:
|
||||
|
||||
- `X-Frame-Options: DENY`
|
||||
- `X-Content-Type-Options: nosniff`
|
||||
- `X-XSS-Protection: 1; mode=block`
|
||||
- `Referrer-Policy: strict-origin-when-cross-origin`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Logging
|
||||
|
||||
Set the `logLevel` to `debug` to get more detailed logs:
|
||||
|
||||
```yaml
|
||||
logLevel: debug
|
||||
```
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Token verification failed**: Check that your `providerURL` is correct and accessible.
|
||||
2. **Session encryption key too short**: Ensure your `sessionEncryptionKey` is at least 32 bytes long.
|
||||
3. **No matching public key found**: The JWKS endpoint might be unavailable or the token's key ID (kid) doesn't match any key in the JWKS.
|
||||
4. **Access denied: Your email domain is not allowed**: The user's email domain is not in the `allowedUserDomains` list.
|
||||
5. **Access denied: You do not have any of the allowed roles or groups**: The user doesn't have any of the roles or groups specified in `allowedRolesAndGroups`.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package traefikoidc
|
||||
|
||||
import "time"
|
||||
|
||||
// autoCleanupRoutine runs a ticker that calls the provided cleanup function at the specified interval.
|
||||
// It stops when a value is received on the stop channel.
|
||||
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cleanup()
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAutoCleanupRoutine(t *testing.T) {
|
||||
var counter int32
|
||||
cleanupFunc := func() {
|
||||
atomic.AddInt32(&counter, 1)
|
||||
}
|
||||
stop := make(chan struct{})
|
||||
go autoCleanupRoutine(50*time.Millisecond, stop, cleanupFunc)
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
close(stop)
|
||||
|
||||
if atomic.LoadInt32(&counter) < 3 {
|
||||
t.Errorf("Expected cleanup to be called at least 3 times, got %d", counter)
|
||||
}
|
||||
}
|
||||
+110
@@ -0,0 +1,110 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenBlacklist manages a thread-safe list of revoked tokens with expiration.
|
||||
type TokenBlacklist struct {
|
||||
tokens map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new token blacklist instance.
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
tokens: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist with an expiration time.
|
||||
func (b *TokenBlacklist) Add(token string, expiry time.Time) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
// Clean up expired tokens if we're at capacity
|
||||
if len(b.tokens) >= 1000 {
|
||||
now := time.Now()
|
||||
futureThreshold := now.Add(time.Minute)
|
||||
for t, exp := range b.tokens {
|
||||
if now.After(exp) || futureThreshold.After(exp) {
|
||||
delete(b.tokens, t)
|
||||
}
|
||||
}
|
||||
|
||||
// If still at capacity, remove oldest token
|
||||
if len(b.tokens) >= 1000 {
|
||||
var oldestToken string
|
||||
var oldestTime time.Time
|
||||
first := true
|
||||
for t, exp := range b.tokens {
|
||||
if first || exp.Before(oldestTime) {
|
||||
oldestToken = t
|
||||
oldestTime = exp
|
||||
first = false
|
||||
}
|
||||
}
|
||||
if oldestToken != "" {
|
||||
delete(b.tokens, oldestToken)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.tokens[token] = expiry
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is in the blacklist and not expired.
|
||||
func (b *TokenBlacklist) IsBlacklisted(token string) bool {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
|
||||
expiry, exists := b.tokens[token]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// If token is expired, remove it and return false
|
||||
if time.Now().After(expiry) {
|
||||
// Switch to write lock to remove expired token
|
||||
b.mutex.RUnlock()
|
||||
b.mutex.Lock()
|
||||
delete(b.tokens, token)
|
||||
b.mutex.Unlock()
|
||||
b.mutex.RLock()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist.
|
||||
// Also removes tokens that will expire within the next minute to prevent edge cases.
|
||||
func (b *TokenBlacklist) Cleanup() {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
futureThreshold := now.Add(time.Minute)
|
||||
|
||||
for token, expiry := range b.tokens {
|
||||
// Remove tokens that are expired or will expire soon
|
||||
if now.After(expiry) || futureThreshold.After(expiry) {
|
||||
delete(b.tokens, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove removes a token from the blacklist regardless of its expiration.
|
||||
func (b *TokenBlacklist) Remove(token string) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
delete(b.tokens, token)
|
||||
}
|
||||
|
||||
// Count returns the current number of tokens in the blacklist.
|
||||
func (b *TokenBlacklist) Count() int {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
return len(b.tokens)
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTokenBlacklist_Add(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
|
||||
if !blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be blacklisted, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_IsBlacklisted(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
|
||||
if !blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be blacklisted, but it was not")
|
||||
}
|
||||
|
||||
if blacklist.IsBlacklisted("nonExistentToken") {
|
||||
t.Errorf("Expected non-existent token to not be blacklisted, but it was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Cleanup(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(-time.Hour) // Expired token
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
blacklist.Cleanup()
|
||||
|
||||
if blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected expired token to be removed after cleanup, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Remove(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
blacklist.Remove(token)
|
||||
|
||||
if blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be removed, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Count(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token1 := "token1"
|
||||
token2 := "token2"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token1, expiry)
|
||||
blacklist.Add(token2, expiry)
|
||||
|
||||
if blacklist.Count() != 2 {
|
||||
t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count())
|
||||
}
|
||||
}
|
||||
@@ -37,19 +37,27 @@ type Cache struct {
|
||||
|
||||
// maxSize is the maximum number of items allowed in the cache.
|
||||
maxSize int
|
||||
// autoCleanupInterval defines how often Cleanup is called automatically.
|
||||
autoCleanupInterval time.Duration
|
||||
// stopCleanup channel to terminate the auto cleanup goroutine.
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
// DefaultMaxSize is the default maximum number of items in the cache.
|
||||
const DefaultMaxSize = 1000
|
||||
const DefaultMaxSize = 500
|
||||
|
||||
// NewCache creates a new empty cache instance that is ready for use.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
order: list.New(),
|
||||
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||
maxSize: DefaultMaxSize,
|
||||
c := &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
order: list.New(),
|
||||
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||
maxSize: DefaultMaxSize,
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
go c.startAutoCleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
// Set adds or updates an item in the cache with the specified expiration duration.
|
||||
@@ -128,7 +136,8 @@ func (c *Cache) Cleanup() {
|
||||
|
||||
now := time.Now()
|
||||
for key, item := range c.items {
|
||||
if now.After(item.ExpiresAt) {
|
||||
// Remove items that are expired or within 10% of expiration
|
||||
if now.After(item.ExpiresAt) || now.Add(time.Duration(float64(item.ExpiresAt.Sub(now))*0.1)).After(item.ExpiresAt) {
|
||||
c.removeItem(key)
|
||||
}
|
||||
}
|
||||
@@ -136,8 +145,23 @@ func (c *Cache) Cleanup() {
|
||||
|
||||
// evictOldest removes the least recently used item from the cache.
|
||||
func (c *Cache) evictOldest() {
|
||||
now := time.Now()
|
||||
elem := c.order.Front()
|
||||
if elem != nil {
|
||||
|
||||
// First try to find an expired item from the front
|
||||
for elem != nil {
|
||||
entry := elem.Value.(lruEntry)
|
||||
if item, exists := c.items[entry.key]; exists {
|
||||
if now.After(item.ExpiresAt) {
|
||||
c.removeItem(entry.key)
|
||||
return
|
||||
}
|
||||
}
|
||||
elem = elem.Next()
|
||||
}
|
||||
|
||||
// If no expired items found, remove the oldest item
|
||||
if elem = c.order.Front(); elem != nil {
|
||||
entry := elem.Value.(lruEntry)
|
||||
c.removeItem(entry.key)
|
||||
}
|
||||
@@ -151,3 +175,13 @@ func (c *Cache) removeItem(key string) {
|
||||
delete(c.elems, key)
|
||||
}
|
||||
}
|
||||
|
||||
// startAutoCleanup initiates a goroutine that periodically cleans up expired cache items.
|
||||
func (c *Cache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
// Close terminates the auto cleanup goroutine.
|
||||
func (c *Cache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
|
||||
-74
@@ -11,7 +11,6 @@ import (
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -283,79 +282,6 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// TokenBlacklist maintains a thread-safe list of revoked tokens.
|
||||
// It stores tokens with their expiration times and automatically
|
||||
// removes expired entries during cleanup operations.
|
||||
type TokenBlacklist struct {
|
||||
// blacklist maps token IDs to their expiration times
|
||||
blacklist map[string]time.Time
|
||||
|
||||
// mutex protects concurrent access to the blacklist
|
||||
mutex sync.RWMutex
|
||||
|
||||
// maxSize is the maximum number of tokens in the blacklist
|
||||
maxSize int
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new TokenBlacklist instance.
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
blacklist: make(map[string]time.Time),
|
||||
maxSize: 1000, // Limit the size to prevent unbounded growth
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist with an expiration time.
|
||||
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
|
||||
// Clean up expired tokens if we're at capacity
|
||||
if len(tb.blacklist) >= tb.maxSize {
|
||||
now := time.Now()
|
||||
for token, exp := range tb.blacklist {
|
||||
if now.After(exp) {
|
||||
delete(tb.blacklist, token)
|
||||
}
|
||||
}
|
||||
// If still at capacity after cleanup, remove oldest token
|
||||
if len(tb.blacklist) >= tb.maxSize {
|
||||
var oldestToken string
|
||||
var oldestTime time.Time
|
||||
first := true
|
||||
for token, exp := range tb.blacklist {
|
||||
if first || exp.Before(oldestTime) {
|
||||
oldestToken = token
|
||||
oldestTime = exp
|
||||
first = false
|
||||
}
|
||||
}
|
||||
delete(tb.blacklist, oldestToken)
|
||||
}
|
||||
}
|
||||
tb.blacklist[tokenID] = expiration
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is in the blacklist and not expired.
|
||||
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
expiration, exists := tb.blacklist[tokenID]
|
||||
return exists && time.Now().Before(expiration)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist.
|
||||
func (tb *TokenBlacklist) Cleanup() {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for tokenID, expiration := range tb.blacklist {
|
||||
if now.After(expiration) {
|
||||
delete(tb.blacklist, tokenID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCache provides a caching mechanism for validated tokens.
|
||||
// It stores token claims to avoid repeated validation of the
|
||||
// same token, improving performance for frequently used tokens.
|
||||
|
||||
+26
-24
@@ -9,34 +9,34 @@ import (
|
||||
|
||||
func TestTokenBlacklistSizeLimit(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
|
||||
// Add tokens up to maxSize
|
||||
for i := 0; i < 1000; i++ {
|
||||
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
|
||||
}
|
||||
|
||||
// Verify size is at max
|
||||
if len(tb.blacklist) != 1000 {
|
||||
t.Errorf("Expected blacklist size to be 1000, got %d", len(tb.blacklist))
|
||||
if tb.Count() != 1000 {
|
||||
t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count())
|
||||
}
|
||||
|
||||
// Add one more token, should trigger cleanup/eviction
|
||||
tb.Add("newtoken", time.Now().Add(time.Hour))
|
||||
|
||||
// Size should still be at max
|
||||
if len(tb.blacklist) > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist))
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
|
||||
// Add some expired tokens
|
||||
for i := 0; i < 500; i++ {
|
||||
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
|
||||
}
|
||||
|
||||
|
||||
// Add some valid tokens
|
||||
for i := 0; i < 500; i++ {
|
||||
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
|
||||
@@ -46,12 +46,14 @@ func TestTokenBlacklistExpiredCleanup(t *testing.T) {
|
||||
tb.Cleanup()
|
||||
|
||||
// Only valid tokens should remain
|
||||
if len(tb.blacklist) != 500 {
|
||||
t.Errorf("Expected 500 valid tokens after cleanup, got %d", len(tb.blacklist))
|
||||
if tb.Count() != 500 {
|
||||
t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count())
|
||||
}
|
||||
|
||||
// Verify only valid tokens remain
|
||||
for token, expiry := range tb.blacklist {
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
for token, expiry := range tb.tokens {
|
||||
if time.Now().After(expiry) {
|
||||
t.Errorf("Found expired token after cleanup: %s", token)
|
||||
}
|
||||
@@ -60,14 +62,14 @@ func TestTokenBlacklistExpiredCleanup(t *testing.T) {
|
||||
|
||||
func TestTokenBlacklistOldestEviction(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
|
||||
// Add tokens at capacity with different expiration times
|
||||
baseTime := time.Now()
|
||||
oldestToken := "oldest"
|
||||
|
||||
|
||||
// Add oldest token first
|
||||
tb.Add(oldestToken, baseTime.Add(time.Hour))
|
||||
|
||||
|
||||
// Fill up to capacity with newer tokens
|
||||
for i := 0; i < 999; i++ {
|
||||
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
|
||||
@@ -94,7 +96,7 @@ func TestTokenBlacklistMemoryUsage(t *testing.T) {
|
||||
|
||||
// Force initial GC
|
||||
runtime.GC()
|
||||
|
||||
|
||||
// Record initial memory stats
|
||||
var m1, m2 runtime.MemStats
|
||||
runtime.ReadMemStats(&m1)
|
||||
@@ -103,12 +105,12 @@ func TestTokenBlacklistMemoryUsage(t *testing.T) {
|
||||
for i := 0; i < iterations; i++ {
|
||||
// Add new token
|
||||
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
|
||||
|
||||
|
||||
// Periodically check blacklisted status
|
||||
if i%100 == 0 {
|
||||
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
|
||||
}
|
||||
|
||||
|
||||
// Periodically cleanup
|
||||
if i%1000 == 0 {
|
||||
tb.Cleanup()
|
||||
@@ -130,8 +132,8 @@ func TestTokenBlacklistMemoryUsage(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify size stayed within limits
|
||||
if len(tb.blacklist) > tb.maxSize {
|
||||
t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist))
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,8 +169,8 @@ func TestConcurrentTokenBlacklistOperations(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify size constraints were maintained
|
||||
if len(tb.blacklist) > tb.maxSize {
|
||||
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", len(tb.blacklist))
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,7 +180,7 @@ func TestTokenCacheMemoryUsage(t *testing.T) {
|
||||
|
||||
// Force initial GC
|
||||
runtime.GC()
|
||||
|
||||
|
||||
// Record initial memory stats
|
||||
var m1, m2 runtime.MemStats
|
||||
runtime.ReadMemStats(&m1)
|
||||
@@ -189,15 +191,15 @@ func TestTokenCacheMemoryUsage(t *testing.T) {
|
||||
"sub": fmt.Sprintf("user%d", i),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
|
||||
// Add to cache
|
||||
tc.Set(fmt.Sprintf("token%d", i), claims, time.Hour)
|
||||
|
||||
|
||||
// Periodically retrieve
|
||||
if i%100 == 0 {
|
||||
tc.Get(fmt.Sprintf("token%d", i-50))
|
||||
}
|
||||
|
||||
|
||||
// Periodically cleanup
|
||||
if i%1000 == 0 {
|
||||
tc.Cleanup()
|
||||
|
||||
@@ -1,91 +1,51 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWK represents a JSON Web Key as defined in RFC 7517.
|
||||
// It contains the cryptographic key information used for token verification.
|
||||
type JWK struct {
|
||||
// Kty is the key type (e.g., "RSA", "EC")
|
||||
Kty string `json:"kty"`
|
||||
|
||||
// Kid is the unique key identifier
|
||||
Kid string `json:"kid"`
|
||||
|
||||
// Use specifies the intended use of the key (e.g., "sig" for signature)
|
||||
Use string `json:"use"`
|
||||
|
||||
// N is the modulus for RSA keys
|
||||
N string `json:"n"`
|
||||
|
||||
// E is the exponent for RSA keys
|
||||
E string `json:"e"`
|
||||
|
||||
// Alg is the algorithm intended for use with the key
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Alg string `json:"alg"`
|
||||
|
||||
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
|
||||
Crv string `json:"crv"`
|
||||
|
||||
// X is the x-coordinate for EC keys
|
||||
X string `json:"x"`
|
||||
|
||||
// Y is the y-coordinate for EC keys
|
||||
Y string `json:"y"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
|
||||
// OIDC providers typically expose multiple keys to support key rotation.
|
||||
type JWKSet struct {
|
||||
// Keys is the array of JSON Web Keys
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache provides a thread-safe caching mechanism for JWK sets.
|
||||
// It caches the keys for a configurable duration to reduce load on the OIDC provider
|
||||
// while ensuring keys are refreshed periodically to handle key rotation.
|
||||
type JWKCache struct {
|
||||
// jwks holds the cached set of JSON Web Keys
|
||||
jwks *JWKSet
|
||||
|
||||
// expiresAt is the timestamp when the cached keys should be refreshed
|
||||
jwks *JWKSet
|
||||
expiresAt time.Time
|
||||
|
||||
// mutex protects concurrent access to the cache
|
||||
mutex sync.RWMutex
|
||||
mutex sync.RWMutex
|
||||
// CacheLifetime is configurable to determine how long the JWKS is cached.
|
||||
CacheLifetime time.Duration
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the interface for JWK caching operations.
|
||||
// This interface allows for different caching implementations while
|
||||
// maintaining consistent behavior in the token verification process.
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
Cleanup()
|
||||
}
|
||||
|
||||
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
|
||||
// from the OIDC provider. It implements a thread-safe double-checked locking
|
||||
// pattern to prevent multiple simultaneous fetches of the same keys.
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for fetching keys
|
||||
//
|
||||
// Returns:
|
||||
// - The JSON Web Key Set
|
||||
// - An error if the keys cannot be retrieved or parsed
|
||||
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
defer c.mutex.RUnlock()
|
||||
@@ -95,33 +55,43 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
return c.jwks, nil
|
||||
}
|
||||
|
||||
jwks, err := fetchJWKS(jwksURL, httpClient)
|
||||
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.jwks = jwks
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
lifetime := c.CacheLifetime
|
||||
if lifetime == 0 {
|
||||
lifetime = 1 * time.Hour
|
||||
}
|
||||
c.expiresAt = time.Now().Add(lifetime)
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
|
||||
// It handles HTTP communication and JSON parsing of the response.
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for the request
|
||||
//
|
||||
// Returns:
|
||||
// - The parsed JSON Web Key Set
|
||||
// - An error if the request fails or the response is invalid
|
||||
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
resp, err := httpClient.Get(jwksURL)
|
||||
func (c *JWKCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if c.jwks != nil && now.After(c.expiresAt) {
|
||||
c.jwks = nil
|
||||
}
|
||||
}
|
||||
|
||||
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Create a request with context to enforce timeout
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
|
||||
}
|
||||
@@ -139,9 +109,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
|
||||
// cryptographic functions. It supports both RSA and EC keys, delegating
|
||||
// to the appropriate converter based on the key type.
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
@@ -157,9 +124,6 @@ var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
|
||||
// It handles base64url decoding of the modulus and exponent,
|
||||
// constructs an RSA public key, and encodes it in PEM format.
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
@@ -191,10 +155,6 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
|
||||
// It supports the P-256, P-384, and P-521 curves as defined in the
|
||||
// OIDC specification, decoding the x and y coordinates and encoding
|
||||
// the resulting public key in PEM format.
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,44 +4,41 @@ import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
||||
"math/big"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var replayCacheMu sync.Mutex
|
||||
var replayCache = make(map[string]time.Time)
|
||||
|
||||
func cleanupReplayCache() {
|
||||
now := time.Now()
|
||||
for token, expiry := range replayCache {
|
||||
if expiry.Before(now) {
|
||||
delete(replayCache, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClockSkewTolerance is configurable to adjust time-based validations.
|
||||
var ClockSkewTolerance = 2 * time.Minute
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
// It contains the three parts of a JWT: header, claims (payload),
|
||||
// and signature, along with the original token string.
|
||||
type JWT struct {
|
||||
// Header contains the token metadata (algorithm, key ID, etc.)
|
||||
Header map[string]interface{}
|
||||
|
||||
// Claims contains the token claims (subject, expiration, etc.)
|
||||
Claims map[string]interface{}
|
||||
|
||||
// Signature contains the raw signature bytes
|
||||
Header map[string]interface{}
|
||||
Claims map[string]interface{}
|
||||
Signature []byte
|
||||
|
||||
// Token is the original JWT string
|
||||
Token string
|
||||
Token string
|
||||
}
|
||||
|
||||
// parseJWT parses a JWT token string into a JWT struct.
|
||||
// It validates the token format and decodes the three parts
|
||||
// (header, claims, signature) using base64url decoding.
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT token string
|
||||
//
|
||||
// Returns:
|
||||
// - A parsed JWT struct
|
||||
// - An error if the token format is invalid or parsing fails
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -52,7 +49,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
Token: tokenString,
|
||||
}
|
||||
|
||||
// Decode and unmarshal the header
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
||||
@@ -61,7 +57,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
||||
}
|
||||
|
||||
// Decode and unmarshal the claims
|
||||
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
||||
@@ -70,7 +65,6 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
// Decode the signature
|
||||
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
||||
@@ -81,28 +75,13 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
}
|
||||
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// It checks:
|
||||
// - issuer (iss) matches the expected issuer URL
|
||||
// - audience (aud) includes the client ID
|
||||
// - expiration time (exp) is in the future (with clock skew tolerance)
|
||||
// - issued at time (iat) is in the past (with clock skew tolerance)
|
||||
// - not before time (nbf) is in the past (with clock skew tolerance)
|
||||
// - subject (sub) is present and not empty
|
||||
// - algorithm matches expected value to prevent algorithm switching attacks
|
||||
//
|
||||
// Returns an error if any validation fails.
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Debug logging of validation parameters
|
||||
fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID)
|
||||
// Debug logging of token header
|
||||
fmt.Printf("Token header: %+v\n", j.Header)
|
||||
|
||||
// Validate algorithm to prevent algorithm switching attacks
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'alg' header")
|
||||
}
|
||||
// List of supported algorithms - should match those in verifySignature
|
||||
supportedAlgs := map[string]bool{
|
||||
"RS256": true, "RS384": true, "RS512": true,
|
||||
"PS256": true, "PS384": true, "PS512": true,
|
||||
@@ -114,9 +93,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
|
||||
claims := j.Claims
|
||||
|
||||
// Debug logging of all claims
|
||||
fmt.Printf("Token claims: %+v\n", claims)
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'iss' claim")
|
||||
@@ -149,17 +125,36 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate nbf (not before) claim if present
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if err := verifyNotBefore(nbf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Validate jti (JWT ID) claim if present
|
||||
// Implement replay protection by checking the jti (JWT ID)
|
||||
if jti, ok := claims["jti"].(string); ok {
|
||||
// Could add replay detection here if needed
|
||||
_ = jti
|
||||
// Skip replay detection for tokens that are being verified from the cache
|
||||
if j.Token == "" {
|
||||
// This is a parsed JWT without the original token string,
|
||||
// which means it's likely from a cached token verification
|
||||
return nil
|
||||
}
|
||||
|
||||
replayCacheMu.Lock()
|
||||
cleanupReplayCache()
|
||||
if _, exists := replayCache[jti]; exists {
|
||||
replayCacheMu.Unlock()
|
||||
return fmt.Errorf("token replay detected")
|
||||
}
|
||||
expFloat, ok := claims["exp"].(float64)
|
||||
var expTime time.Time
|
||||
if ok {
|
||||
expTime = time.Unix(int64(expFloat), 0)
|
||||
} else {
|
||||
expTime = time.Now().Add(10 * time.Minute)
|
||||
}
|
||||
replayCache[jti] = expTime
|
||||
replayCacheMu.Unlock()
|
||||
}
|
||||
|
||||
sub, ok := claims["sub"].(string)
|
||||
@@ -169,20 +164,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience validates the token's audience claim.
|
||||
// The audience can be either a single string or an array of strings.
|
||||
// For array audiences, the expected audience must match any one value.
|
||||
// Parameters:
|
||||
// - tokenAudience: The audience claim from the token
|
||||
// - expectedAudience: The expected audience value
|
||||
//
|
||||
// Returns an error if validation fails.
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
|
||||
tokenAudience, expectedAudience)
|
||||
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
if aud != expectedAudience {
|
||||
@@ -205,165 +187,80 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer validates the token's issuer claim.
|
||||
// The issuer URL must exactly match the expected issuer.
|
||||
// Parameters:
|
||||
// - tokenIssuer: The issuer claim from the token
|
||||
// - expectedIssuer: The expected issuer URL
|
||||
//
|
||||
// Returns an error if validation fails.
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
|
||||
tokenIssuer, expectedIssuer)
|
||||
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)",
|
||||
tokenIssuer, expectedIssuer)
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clock skew tolerance for time-based validations
|
||||
const clockSkewTolerance = 2 * time.Minute
|
||||
// verifyTimeConstraint is a generic function to verify time-based claims
|
||||
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
|
||||
claimTime := time.Unix(int64(unixTime), 0)
|
||||
now := time.Now().Truncate(time.Second)
|
||||
|
||||
// For expiration (future=true), we add skew to now (making now later)
|
||||
// For iat/nbf (future=false), we subtract skew from now (making now earlier)
|
||||
skewDirection := 1
|
||||
if !future {
|
||||
skewDirection = -1
|
||||
}
|
||||
skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance)
|
||||
|
||||
if claimTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For expiration: if skewedNow (later) is after expiration, token expired
|
||||
// For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid
|
||||
if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) {
|
||||
var reason string
|
||||
if future {
|
||||
reason = "has expired"
|
||||
} else {
|
||||
if claimName == "iat" {
|
||||
reason = "used before issued"
|
||||
} else {
|
||||
reason = "not yet valid"
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyExpiration checks if the token's expiration time has passed.
|
||||
// The expiration time is compared against the current time with clock skew tolerance.
|
||||
// Parameters:
|
||||
// - expiration: The expiration timestamp from the token
|
||||
//
|
||||
// Returns an error if the token has expired.
|
||||
func verifyExpiration(expiration float64) error {
|
||||
expirationTime := time.Unix(int64(expiration), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(clockSkewTolerance)
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
expirationTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
// Allow tokens that expire exactly now
|
||||
if expirationTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if skewedNow.After(expirationTime) {
|
||||
return fmt.Errorf("token has expired (exp: %v, now: %v)",
|
||||
expirationTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(expiration, "exp", true)
|
||||
}
|
||||
|
||||
// verifyIssuedAt validates the token's issued-at time.
|
||||
// Ensures the token wasn't issued in the future, accounting for clock skew.
|
||||
// Parameters:
|
||||
// - issuedAt: The issued-at timestamp from the token
|
||||
//
|
||||
// Returns an error if the token was issued in the future.
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(-clockSkewTolerance)
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
issuedAtTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
// Allow tokens issued in the same second as current time
|
||||
if issuedAtTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if skewedNow.Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
|
||||
issuedAtTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(issuedAt, "iat", false)
|
||||
}
|
||||
|
||||
// verifyNotBefore validates the token's not-before time if present.
|
||||
// Ensures the token is not used before its valid time period, accounting for clock skew.
|
||||
// Parameters:
|
||||
// - notBefore: The not-before timestamp from the token
|
||||
//
|
||||
// Returns an error if the token is not yet valid.
|
||||
func verifyNotBefore(notBefore float64) error {
|
||||
notBeforeTime := time.Unix(int64(notBefore), 0)
|
||||
// Truncate current time to seconds for consistent comparison
|
||||
now := time.Now().Truncate(time.Second)
|
||||
skewedNow := now.Add(-clockSkewTolerance)
|
||||
|
||||
// Debug logging
|
||||
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
|
||||
notBeforeTime.UTC(),
|
||||
now.UTC(),
|
||||
skewedNow.UTC(),
|
||||
clockSkewTolerance)
|
||||
|
||||
// Allow tokens that become valid exactly now
|
||||
if notBeforeTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if skewedNow.Before(notBeforeTime) {
|
||||
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
|
||||
notBeforeTime.UTC(), now.UTC())
|
||||
}
|
||||
return nil
|
||||
return verifyTimeConstraint(notBefore, "nbf", false)
|
||||
}
|
||||
|
||||
// verifySignature validates the token's cryptographic signature.
|
||||
// Supports multiple signature algorithms:
|
||||
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
|
||||
// - RSA-PSS: PS256, PS384, PS512
|
||||
// - ECDSA: ES256, ES384, ES512
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The complete JWT token string
|
||||
// - publicKeyPEM: The PEM-encoded public key for verification
|
||||
// - alg: The signature algorithm identifier
|
||||
//
|
||||
// Returns an error if signature verification fails.
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
// Debug logging
|
||||
fmt.Printf("Verifying signature with algorithm: %s\n", alg)
|
||||
|
||||
// Split the token into its three parts
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
|
||||
// Decode the signature from the token
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
// Decode the PEM-encoded public key
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
|
||||
// Parse the public key
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
// Determine the hash function to use based on the algorithm
|
||||
var hashFunc crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
hashFunc = crypto.SHA256
|
||||
@@ -374,27 +271,20 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
|
||||
default:
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
// Hash the signed content
|
||||
h := hashFunc.New()
|
||||
h.Write([]byte(signedContent))
|
||||
hashed := h.Sum(nil)
|
||||
|
||||
// Verify the signature based on the key type and algorithm
|
||||
switch pubKey := pubKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
// RSA PKCS#1 v1.5 signature
|
||||
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
|
||||
} else if strings.HasPrefix(alg, "PS") {
|
||||
// RSA PSS signature
|
||||
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "ES") {
|
||||
// ECDSA signature
|
||||
var r, s big.Int
|
||||
sigLen := len(signature)
|
||||
if sigLen%2 != 0 {
|
||||
|
||||
@@ -18,6 +18,40 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// createDefaultHTTPClient creates an HTTP client with optimized settings for OIDC
|
||||
func createDefaultHTTPClient() *http.Client {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second, // Reduced timeout
|
||||
KeepAlive: 15 * time.Second, // Reduced keepalive
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 30, // Reduced from 100
|
||||
MaxIdleConnsPerHost: 10, // Reduced from 100
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 50, // Limit max connections
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Timeout: time.Second * 15, // Reduced timeout
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const ConstSessionTimeout = 86400 // Session timeout in seconds
|
||||
|
||||
// TokenVerifier interface for token verification
|
||||
@@ -39,6 +73,7 @@ type TraefikOidc struct {
|
||||
issuerURL string
|
||||
revocationURL string
|
||||
jwkCache JWKCacheInterface
|
||||
metadataCache *MetadataCache
|
||||
tokenBlacklist *TokenBlacklist
|
||||
jwksURL string
|
||||
clientID string
|
||||
@@ -81,11 +116,50 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon": {},
|
||||
}
|
||||
|
||||
// VerifyToken verifies the provided JWT token
|
||||
// VerifyToken implements the TokenVerifier interface to verify an OIDC token.
|
||||
// It performs a complete verification process including:
|
||||
// 1. Checking the token cache to avoid redundant verifications
|
||||
// 2. Performing rate limiting and blacklist checks
|
||||
// 3. Parsing the JWT structure
|
||||
// 4. Verifying the JWT signature against the JWKS from the provider
|
||||
// 5. Validating standard JWT claims (iss, aud, exp, etc.)
|
||||
// 6. Caching the verified token for future requests
|
||||
//
|
||||
// Returns nil if the token is valid, or an error describing the validation failure.
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
// Check cache first
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
t.logger.Debugf("Token found in cache with valid claims; skipping verification")
|
||||
return nil
|
||||
}
|
||||
|
||||
t.logger.Debugf("Verifying token")
|
||||
|
||||
// Rate limiting
|
||||
// Perform pre-verification checks
|
||||
if err := t.performPreVerificationChecks(token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
// Verify JWT signature and standard claims
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache the verified token
|
||||
t.cacheVerifiedToken(token, jwt.Claims)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// performPreVerificationChecks performs rate limiting and blacklist checks
|
||||
func (t *TraefikOidc) performPreVerificationChecks(token string) error {
|
||||
// Enforce rate limiting
|
||||
if !t.limiter.Allow() {
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
@@ -95,30 +169,15 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
return fmt.Errorf("token is blacklisted")
|
||||
}
|
||||
|
||||
// Check if token is cached
|
||||
if _, exists := t.tokenCache.Get(token); exists {
|
||||
t.logger.Debugf("Token is valid and cached")
|
||||
return nil // Token is valid and cached
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
// Verify JWT signature and claims
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache the token until it expires
|
||||
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
|
||||
// cacheVerifiedToken caches a verified token until its expiration time
|
||||
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) {
|
||||
expirationTime := time.Unix(int64(claims["exp"].(float64)), 0)
|
||||
now := time.Now()
|
||||
duration := expirationTime.Sub(now)
|
||||
t.tokenCache.Set(token, jwt.Claims, duration)
|
||||
|
||||
return nil
|
||||
t.tokenCache.Set(token, claims, duration)
|
||||
}
|
||||
|
||||
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
|
||||
@@ -126,7 +185,7 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
t.logger.Debugf("Verifying JWT signature and claims")
|
||||
|
||||
// Get JWKS
|
||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||
jwks, err := t.jwkCache.GetJWKS(context.Background(), t.jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
@@ -172,7 +231,24 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates a new instance of the OIDC middleware
|
||||
// New creates a new instance of the OIDC middleware.
|
||||
// This is the main entry point for the middleware and is called by Traefik when loading the plugin.
|
||||
// It initializes all components needed for OIDC authentication:
|
||||
// - Session management for storing user state
|
||||
// - Token caching and blacklisting
|
||||
// - JWK caching for signature verification
|
||||
// - Rate limiting to prevent abuse
|
||||
// - Metadata discovery for OIDC provider endpoints
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for initialization operations
|
||||
// - next: The next handler in the middleware chain
|
||||
// - config: Configuration options for the middleware
|
||||
// - name: Identifier for this middleware instance
|
||||
//
|
||||
// Returns:
|
||||
// - An http.Handler that implements the middleware
|
||||
// - An error if initialization fails
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
if config == nil {
|
||||
config = CreateConfig()
|
||||
@@ -186,7 +262,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
|
||||
// Initialize logger
|
||||
logger := NewLogger(config.LogLevel)
|
||||
|
||||
// Ensure key meets minimum length requirement
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
if runtime.Compiler == "yaegi" {
|
||||
@@ -197,42 +272,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
}
|
||||
|
||||
// Setup HTTP client
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 15 * time.Second, // Reduced timeout
|
||||
KeepAlive: 15 * time.Second, // Reduced keepalive
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 30, // Reduced from 100
|
||||
MaxIdleConnsPerHost: 10, // Reduced from 100
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 50, // Limit max connections
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
if config.HTTPClient != nil {
|
||||
httpClient = config.HTTPClient
|
||||
} else {
|
||||
httpClient = &http.Client{
|
||||
Timeout: time.Second * 15, // Reduced timeout
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Always follow redirects for OIDC endpoints
|
||||
if len(via) >= 50 {
|
||||
return fmt.Errorf("stopped after 50 redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
httpClient = createDefaultHTTPClient()
|
||||
}
|
||||
|
||||
t := &TraefikOidc{
|
||||
@@ -253,6 +298,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}(),
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
jwkCache: &JWKCache{},
|
||||
metadataCache: NewMetadataCache(),
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
@@ -292,40 +338,55 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
t.logger.Debug("Starting provider metadata discovery")
|
||||
|
||||
// Keep retrying until successful
|
||||
backoff := time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
for {
|
||||
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
|
||||
// Get metadata from cache or fetch it
|
||||
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to get provider metadata: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
t.logger.Debug("Successfully initialized provider metadata")
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
|
||||
// Start metadata refresh goroutine
|
||||
go t.startMetadataRefresh(providerURL)
|
||||
|
||||
// Only close channel on success
|
||||
close(t.initComplete)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Error("Received nil metadata")
|
||||
}
|
||||
|
||||
// updateMetadataEndpoints updates the middleware with metadata endpoints
|
||||
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
}
|
||||
|
||||
// startMetadataRefresh periodically refreshes the OIDC metadata
|
||||
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
t.logger.Debug("Refreshing OIDC metadata")
|
||||
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
|
||||
time.Sleep(backoff)
|
||||
|
||||
// Exponential backoff with max
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
t.logger.Errorf("Failed to refresh metadata: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
t.logger.Debug("Successfully initialized provider metadata")
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
|
||||
// Only close channel on success
|
||||
close(t.initComplete)
|
||||
return
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
t.logger.Debug("Successfully refreshed metadata")
|
||||
}
|
||||
|
||||
t.logger.Error("Received nil metadata, retrying")
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,7 +452,18 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// ServeHTTP is the main handler for the middleware
|
||||
// ServeHTTP is the main handler for the middleware that processes all HTTP requests.
|
||||
// It implements the http.Handler interface and performs the following operations:
|
||||
// 1. Waits for OIDC provider metadata initialization to complete
|
||||
// 2. Checks if the requested URL is in the excluded list (bypassing authentication)
|
||||
// 3. Retrieves or creates a user session
|
||||
// 4. Handles special paths like callback and logout URLs
|
||||
// 5. Verifies authentication status and token validity
|
||||
// 6. Refreshes tokens that are about to expire
|
||||
// 7. Validates user email domains, roles, and groups against configured restrictions
|
||||
// 8. Sets appropriate headers for downstream services
|
||||
// 9. Applies security headers to responses
|
||||
// 10. Forwards the authenticated request to the next handler
|
||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
@@ -517,6 +589,34 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Set user information in headers
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
// Set OIDC-specific headers
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetAccessToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
|
||||
// Set security headers
|
||||
rw.Header().Set("X-Frame-Options", "DENY")
|
||||
rw.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
rw.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Set CORS headers
|
||||
origin := req.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
rw.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
|
||||
// Handle preflight requests
|
||||
if req.Method == "OPTIONS" {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Process the request
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
@@ -552,7 +652,17 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
return req.Host
|
||||
}
|
||||
|
||||
// isUserAuthenticated checks if the user is authenticated
|
||||
// isUserAuthenticated checks if the user is authenticated by validating their session and token.
|
||||
// It performs a comprehensive check of the authentication state including:
|
||||
// 1. Verifying the session's authenticated flag
|
||||
// 2. Checking for the presence of an access token
|
||||
// 3. Validating the token's signature and claims
|
||||
// 4. Checking the token's expiration time
|
||||
//
|
||||
// Returns three boolean values:
|
||||
// - authenticated: Whether the user is currently authenticated
|
||||
// - needsRefresh: Whether the token is valid but will expire soon (within grace period)
|
||||
// - expired: Whether the token has expired or is otherwise invalid
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debug("User is not authenticated according to session")
|
||||
@@ -600,7 +710,19 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// defaultInitiateAuthentication initiates the authentication process
|
||||
// defaultInitiateAuthentication initiates the OIDC authentication process.
|
||||
// This function prepares and starts a new authentication flow by:
|
||||
// 1. Generating security tokens (CSRF token and nonce) to prevent attacks
|
||||
// 2. Clearing any existing session data to avoid state conflicts
|
||||
// 3. Storing the original request path to redirect back after authentication
|
||||
// 4. Building the authorization URL with all required OIDC parameters
|
||||
// 5. Redirecting the user to the OIDC provider's authorization endpoint
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer for sending the redirect
|
||||
// - req: The original HTTP request that triggered authentication
|
||||
// - session: The user's session data for storing authentication state
|
||||
// - redirectURL: The callback URL where the OIDC provider will redirect after authentication
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Generate CSRF token and nonce
|
||||
csrfToken := uuid.NewString()
|
||||
@@ -630,7 +752,10 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// verifyToken verifies the token using the token verifier
|
||||
// verifyToken verifies the token using the token verifier interface.
|
||||
// This function delegates to the configured token verifier implementation,
|
||||
// which by default is the TraefikOidc instance itself (implementing the VerifyToken method).
|
||||
// This design allows for easy mocking in tests and potential future extension.
|
||||
func (t *TraefikOidc) verifyToken(token string) error {
|
||||
return t.tokenVerifier.VerifyToken(token)
|
||||
}
|
||||
@@ -647,59 +772,37 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||
params.Set("scope", strings.Join(t.scopes, " "))
|
||||
}
|
||||
|
||||
// Ensure authURL is absolute
|
||||
if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") {
|
||||
return t.buildURLWithParams(t.authURL, params)
|
||||
}
|
||||
|
||||
// buildURLWithParams ensures a URL is absolute and appends query parameters
|
||||
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
// Ensure URL is absolute
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
// Extract issuer base URL
|
||||
issuerURL, err := url.Parse(t.issuerURL)
|
||||
if err == nil {
|
||||
return fmt.Sprintf("%s://%s%s?%s",
|
||||
issuerURL.Scheme,
|
||||
issuerURL.Host,
|
||||
t.authURL,
|
||||
return fmt.Sprintf("%s://%s%s?%s",
|
||||
issuerURL.Scheme,
|
||||
issuerURL.Host,
|
||||
baseURL,
|
||||
params.Encode())
|
||||
}
|
||||
}
|
||||
return t.authURL + "?" + params.Encode()
|
||||
return baseURL + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// startTokenCleanup starts the token cleanup goroutine
|
||||
func (t *TraefikOidc) startTokenCleanup() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ticker := time.NewTicker(30 * time.Second) // Increased frequency to prevent memory buildup
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
t.logger.Debug("Starting token cleanup cycle")
|
||||
|
||||
// Run cleanup in a separate goroutine with timeout
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
t.tokenCache.Cleanup()
|
||||
t.tokenBlacklist.Cleanup()
|
||||
}()
|
||||
|
||||
// Wait for cleanup to complete or timeout
|
||||
select {
|
||||
case <-cleanupCtx.Done():
|
||||
if cleanupCtx.Err() == context.DeadlineExceeded {
|
||||
t.logger.Error("Token cleanup cycle timed out")
|
||||
}
|
||||
case <-done:
|
||||
t.logger.Debug("Token cleanup cycle completed successfully")
|
||||
}
|
||||
|
||||
cleanupCancel()
|
||||
}
|
||||
for range ticker.C {
|
||||
t.logger.Debug("Starting token cleanup cycle")
|
||||
t.tokenCache.Cleanup()
|
||||
t.tokenBlacklist.Cleanup()
|
||||
t.jwkCache.Cleanup() // Assuming jwkCache is the cache from cache.go
|
||||
// Removed runtime.GC() call
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
+25
-8
@@ -131,10 +131,16 @@ type MockJWKCache struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) Cleanup() {
|
||||
// Mock cleanup implementation
|
||||
m.JWKS = nil
|
||||
m.Err = nil
|
||||
}
|
||||
|
||||
// Helper function to create a JWT token
|
||||
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
|
||||
header := map[string]interface{}{
|
||||
@@ -221,7 +227,7 @@ func TestVerifyToken(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Reset token blacklist and cache
|
||||
// Reset token blacklist and cache for each test
|
||||
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
|
||||
ts.tOidc.tokenCache = NewTokenCache()
|
||||
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
||||
@@ -237,9 +243,20 @@ func TestVerifyToken(t *testing.T) {
|
||||
}
|
||||
|
||||
if tc.cacheToken {
|
||||
// Use more realistic claims for cached token
|
||||
ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{
|
||||
"empty": "claim",
|
||||
}, 60)
|
||||
"iss": "https://test-issuer.com",
|
||||
"sub": "test-subject",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"jti": generateRandomString(16), // Add a JTI claim to prevent replay detection
|
||||
}, time.Minute)
|
||||
|
||||
// Verify the token is actually in the cache
|
||||
if claims, exists := ts.tOidc.tokenCache.Get(tc.token); exists {
|
||||
t.Logf("Token found in cache with claims: %v", claims)
|
||||
} else {
|
||||
t.Logf("Token NOT found in cache despite cacheToken=true")
|
||||
}
|
||||
}
|
||||
|
||||
err := ts.tOidc.VerifyToken(tc.token)
|
||||
@@ -1776,7 +1793,7 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
issuerURL string
|
||||
redirectURL string
|
||||
state string
|
||||
nonce string
|
||||
nonce string
|
||||
expectedPrefix string
|
||||
}{
|
||||
{
|
||||
@@ -1785,7 +1802,7 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
issuerURL: "https://auth.example.com",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
nonce: "test-nonce",
|
||||
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||
},
|
||||
{
|
||||
@@ -1794,7 +1811,7 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
issuerURL: "https://logto.example.com",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
nonce: "test-nonce",
|
||||
expectedPrefix: "https://logto.example.com/oidc/auth?",
|
||||
},
|
||||
{
|
||||
@@ -1803,7 +1820,7 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
issuerURL: "https://auth.example.com:8443",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
nonce: "test-nonce",
|
||||
expectedPrefix: "https://auth.example.com:8443/sign-in?",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MetadataCache struct {
|
||||
metadata *ProviderMetadata
|
||||
expiresAt time.Time
|
||||
mutex sync.RWMutex
|
||||
autoCleanupInterval time.Duration
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
func NewMetadataCache() *MetadataCache {
|
||||
c := &MetadataCache{
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
go c.startAutoCleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
// Cleanup removes expired metadata from the cache.
|
||||
func (c *MetadataCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if c.metadata != nil && now.After(c.expiresAt) {
|
||||
c.metadata = nil
|
||||
}
|
||||
}
|
||||
func (c *MetadataCache) isCacheValid() bool {
|
||||
return c.metadata != nil && time.Now().Before(c.expiresAt)
|
||||
}
|
||||
|
||||
// GetMetadata retrieves the metadata from cache or fetches it if expired
|
||||
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
|
||||
c.mutex.RLock()
|
||||
if c.isCacheValid() {
|
||||
defer c.mutex.RUnlock()
|
||||
return c.metadata, nil
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if c.isCacheValid() {
|
||||
return c.metadata, nil
|
||||
}
|
||||
|
||||
metadata, err := discoverProviderMetadata(providerURL, httpClient, logger)
|
||||
if err != nil {
|
||||
if c.metadata != nil {
|
||||
// On error, extend current cache by 5 minutes to prevent thundering herd
|
||||
c.expiresAt = time.Now().Add(5 * time.Minute)
|
||||
logger.Errorf("Failed to refresh metadata, using cached version for 5 more minutes: %v", err)
|
||||
return c.metadata, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
|
||||
}
|
||||
|
||||
c.metadata = metadata
|
||||
// Calculate expiration time based on usage patterns
|
||||
usageCount := 0 // This should be replaced with actual usage tracking logic
|
||||
if usageCount < 10 {
|
||||
c.expiresAt = time.Now().Add(30 * time.Minute)
|
||||
} else if usageCount < 50 {
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
} else {
|
||||
c.expiresAt = time.Now().Add(2 * time.Hour)
|
||||
}
|
||||
|
||||
// End of GetMetadata
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func (c *MetadataCache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
func (c *MetadataCache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsCacheValid(t *testing.T) {
|
||||
// Setup with a dummy ProviderMetadata.
|
||||
pm := &ProviderMetadata{}
|
||||
mc := &MetadataCache{
|
||||
metadata: pm,
|
||||
expiresAt: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
if !mc.isCacheValid() {
|
||||
t.Errorf("Expected cache to be valid")
|
||||
}
|
||||
mc.expiresAt = time.Now().Add(-1 * time.Hour)
|
||||
if mc.isCacheValid() {
|
||||
t.Errorf("Expected cache to be invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanup(t *testing.T) {
|
||||
pm := &ProviderMetadata{}
|
||||
mc := &MetadataCache{
|
||||
metadata: pm,
|
||||
expiresAt: time.Now().Add(-1 * time.Hour),
|
||||
}
|
||||
mc.Cleanup()
|
||||
if mc.metadata != nil {
|
||||
t.Errorf("Expected metadata to be nil after cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetadata_Cached(t *testing.T) {
|
||||
dummyData := &ProviderMetadata{}
|
||||
// Construct MetadataCache manually to avoid interference from auto cleanup.
|
||||
mc := &MetadataCache{
|
||||
metadata: dummyData,
|
||||
expiresAt: time.Now().Add(1 * time.Hour),
|
||||
stopCleanup: make(chan struct{}),
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
}
|
||||
// Use NewLogger to create a logger that writes errors only.
|
||||
logger := NewLogger("error")
|
||||
result, err := mc.GetMetadata("http://example.com", http.DefaultClient, logger)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if result != dummyData {
|
||||
t.Errorf("Expected cached metadata to be returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataCacheAutoCleanup(t *testing.T) {
|
||||
mc := &MetadataCache{
|
||||
autoCleanupInterval: 50 * time.Millisecond,
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
// Start auto cleanup.
|
||||
go mc.startAutoCleanup()
|
||||
mc.mutex.Lock()
|
||||
mc.metadata = &ProviderMetadata{}
|
||||
mc.expiresAt = time.Now().Add(-50 * time.Millisecond)
|
||||
mc.mutex.Unlock()
|
||||
|
||||
// Wait enough time for the auto cleanup to run.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
mc.Close()
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
if mc.metadata != nil {
|
||||
t.Errorf("Expected metadata to be cleared by auto cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
type errorRoundTripper struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e errorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return nil, e.err
|
||||
}
|
||||
|
||||
func TestGetMetadata_FetchError(t *testing.T) {
|
||||
// Create an HTTP client that always returns an error.
|
||||
errorClient := &http.Client{
|
||||
Transport: errorRoundTripper{err: fmt.Errorf("fake fetch error")},
|
||||
}
|
||||
|
||||
// Case 1: Cache is empty.
|
||||
mc := &MetadataCache{
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
logger := NewLogger("error")
|
||||
metadata, err := mc.GetMetadata("http://example.com", errorClient, logger)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
if metadata != nil {
|
||||
t.Errorf("Expected nil metadata, got %v", metadata)
|
||||
}
|
||||
|
||||
// Case 2: Cache has old metadata.
|
||||
dummy := &ProviderMetadata{}
|
||||
mc.metadata = dummy
|
||||
mc.expiresAt = time.Now().Add(-1 * time.Minute)
|
||||
logger2 := NewLogger("error")
|
||||
metadata, err = mc.GetMetadata("http://example.com", errorClient, logger2)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error when cached metadata exists, got %v", err)
|
||||
}
|
||||
if metadata != dummy {
|
||||
t.Errorf("Expected cached metadata to be returned")
|
||||
}
|
||||
}
|
||||
@@ -326,6 +326,9 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
err = sd.Save(r, w)
|
||||
}
|
||||
|
||||
// Clear transient per-request fields.
|
||||
sd.request = nil
|
||||
|
||||
// Return session to pool.
|
||||
sd.manager.sessionPool.Put(sd)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user