Compare commits

..

40 Commits

Author SHA1 Message Date
lukaszraczylo 4ce2815123 Update the documentation. 2025-02-25 14:02:08 +00:00
lukaszraczylo 7d204113ea Cleanup the codebase, DRY and abstract functions, increase the test coverage. 2025-02-25 12:53:52 +00:00
lukaszraczylo c721913cbe Increase tests coverage. 2025-02-24 12:25:32 +00:00
lukaszraczylo 0f8b7f7ab1 Abstract the cleanup logic and add helper for cache valid. 2025-02-24 12:02:12 +00:00
lukaszraczylo 2743b0e024 Ensure cleanups actually happen. 2025-02-24 00:19:44 +00:00
lukaszraczylo e6fc36937b Clear per-request reference to stop leaking contexts. 2025-02-24 00:04:04 +00:00
lukaszraczylo df051e0cfb Improve expiration logic. 2025-02-19 20:33:26 +00:00
lukaszraczylo 5d5ce8ae5e Additional tests for the blacklists 2025-02-19 12:08:37 +00:00
lukaszraczylo d194cd778a gofmt the updated files. 2025-02-19 11:56:31 +00:00
lukaszraczylo 803a1e5e21 Clean the caches properly to avoid memleak 2025-02-19 11:55:32 +00:00
lukaszraczylo 3ad8fb4518 Optimise cache cleanup run to avoid the GC which causes CPU usage to go higher than necessary. 2025-02-10 09:30:56 +00:00
lukaszraczylo 9402f1bca5 Token blacklist, cache and metadata improvements
TokenBlacklist Improvements:
Fixed size limit enforcement to properly maintain max size of 1000 tokens
Improved eviction strategy to remove expired tokens first before removing oldest
Added proper cleanup of tokens during Add operation to prevent size overflow
Fixed oldest token eviction logic to ensure correct token removal
Added proper locking mechanisms to prevent race conditions
Cache Improvements:
Fixed cleanup mechanism to only remove truly expired items
Improved eviction strategy in LRU cache to prioritize expired items
Added smarter eviction in evictOldest to scan for expired items first
Fixed aggressive cleanup that was removing valid items
Maintained proper LRU ordering while handling evictions
MetadataCache:
Verified proper implementation of metadata caching with hourly refresh
Confirmed proper handling of cache extension on fetch failures
Validated thread-safe operations with proper RWMutex usage
2025-02-09 23:53:05 +00:00
lukaszraczylo e6205b3a48 Add metadata caching capability to avoid unnecesary API calls 2025-02-09 23:37:50 +00:00
lukaszraczylo fdb8e3233e Testing (could be unstable) additional headers.
This adds additional headers to control the access origin and control allow headers.
2025-02-06 23:46:08 +00:00
lukaszraczylo 33c71fd6fe Enhance test suite. 2025-02-06 23:38:22 +00:00
lukaszraczylo 241cb1c209 Deal with the memory growth issue.
* TokenBlacklist limit is set to 1000
* Increased token cleanup frequency
2025-02-06 23:34:05 +00:00
lukaszraczylo 09daa1025c Follow multiple redirects during the OIDC flow. 2025-02-06 23:31:13 +00:00
lukaszraczylo c09e7a9228 Add additional test cases to cover it. 2025-02-06 21:50:35 +00:00
lukaszraczylo e5da5d4fe9 Fix redirection to the provider when session expires 2025-02-06 21:48:56 +00:00
lukaszraczylo 31db701dda Trigger build and release. 2025-02-05 19:04:44 +00:00
lukaszraczylo 16481afd36 Add todo: Improve test coverage. 2025-02-01 12:20:01 +00:00
lukaszraczylo 751933ffa0 Multiple improvements.
* Add todo list.

* fixup! Add todo list.

* fixup! fixup! Add todo list.

* fixup! fixup! fixup! Add todo list.

* Improve the session handling and cache.

* Fix an issue where expired session can cause infinite redirect loop

* fixup! Fix an issue where expired session can cause infinite redirect loop

* Add semver setup for automatic releases.

* fixup! Add semver setup for automatic releases.

* fixup! fixup! Add semver setup for automatic releases.

* fixup! fixup! fixup! Add semver setup for automatic releases.
2025-02-01 12:16:50 +00:00
lukaszraczylo e74153b107 Merge pull request #28 from lukaszraczylo/additional-improvements
additional improvements
2025-01-21 19:34:01 +00:00
lukaszraczylo 025107fe3e Well, release it finally. 2025-01-21 19:31:51 +00:00
lukaszraczylo dfb9c0771e Fix session handling and the redirection to the original URL incl. get parameters 2025-01-21 17:49:54 +00:00
lukaszraczylo 1107df40e7 Merge pull request #26 from lukaszraczylo/additional-improvements
Cleanup old cookies properly.
2025-01-21 17:34:16 +00:00
lukaszraczylo bf294569eb Cleanup old cookies properly. 2025-01-21 17:09:48 +00:00
lukaszraczylo 482c346840 Merge pull request #24 from lukaszraczylo/additional-improvements
additional improvements
2025-01-21 00:19:49 +00:00
lukaszraczylo a462e44896 Fix remaining issues with session handling and add additional tests. 2025-01-21 00:18:10 +00:00
lukaszraczylo 5eff0dc866 Clean up old cookies. 2025-01-21 00:03:13 +00:00
lukaszraczylo dfc534a400 Merge pull request #23 from lukaszraczylo/additional-improvements
Add useful defaults allowing traefik hub to pass.
2025-01-20 23:57:51 +00:00
lukaszraczylo 061c12d0a3 Add useful defaults allowing traefik hub to pass. 2025-01-20 23:55:58 +00:00
lukaszraczylo 4c4fff3613 Merge pull request #22 from lukaszraczylo/additional-improvements
Quite important fix
2025-01-20 23:50:35 +00:00
lukaszraczylo 0dcb44c187 Quite important fix
When user session expires, reauthentication fails as CSRF token disappears.
This commit fixes the issue by initiating new authentication flow.
2025-01-20 23:48:31 +00:00
lukaszraczylo cbe773d96a Merge pull request #20 from lukaszraczylo/additional-improvements
Provide default session encryption key if not specified.
2025-01-18 11:00:07 +00:00
lukaszraczylo 40254888d7 Provide default session encryption key if not specified. 2025-01-18 10:54:30 +00:00
lukaszraczylo ef41870c81 Merge pull request #18 from lukaszraczylo/additional-improvements
additional improvements
2025-01-18 02:28:29 +00:00
lukaszraczylo 081c32925a fixup! Security improvements have been implemented and verified across four main areas: 2025-01-14 11:47:49 +00:00
lukaszraczylo 17dea67229 Security improvements have been implemented and verified across four main areas:
JWT Token Security:
Protected against algorithm switching attacks by validating and whitelisting algorithms (RS256, RS384, RS512, PS256, PS384, PS512, ES256, ES384, ES512)
Added 2-minute clock skew tolerance for time-based validations
Added "not before" (nbf) claim validation with clock skew tolerance
Required JWT ID (jti) claim to prevent replay attacks
Added strict algorithm validation to prevent downgrade attacks
Session Management Security:
Implemented cryptographically secure random cookie names to prevent targeting
Added automatic session ID rotation after successful login to prevent session fixation
Enforced 24-hour absolute session timeout
Added strict encryption key length validation (minimum 32 bytes)
Added comprehensive session validation including timeout checks
Implemented session pooling for secure resource management
Added secure session cleanup on expiration
Configuration and URL Security:
Enforced HTTPS for all provider URLs and external endpoints
Added minimum rate limit (10 req/sec) to prevent DOS attacks
Added strict validation for excluded URLs:
Must start with "/"
No path traversal (..)
No wildcards (*)
Made ForceHTTPS true by default for secure cookies
Added validation for secure redirect URIs
Added validation for all OIDC endpoints (must be HTTPS)
Added secure defaults in configuration
Test Coverage:
Added comprehensive test cases verifying all security validations
Added test cases for HTTPS enforcement on all endpoints
Added test cases for minimum rate limits
Added test cases for secure session management
Added test cases for token validation with clock skew
Added test cases for secure configuration defaults
All security improvements have been verified through passing test cases, protecting against:

Session fixation attacks
Token replay attacks
Algorithm switching attacks
Path traversal attacks
Session hijacking
Timing attacks
DOS attacks
Man-in-the-middle attacks through enforced HTTPS
2025-01-14 11:33:48 +00:00
lukaszraczylo 8512ad6d68 Revert "Update vendored modules."
This reverts commit 5aa838c669.
2025-01-07 13:19:41 +00:00
30 changed files with 2665 additions and 917 deletions
+219 -18
View File
@@ -4,28 +4,229 @@ type: middleware
import: github.com/lukaszraczylo/traefikoidc
summary: |
Middleware adding OIDC authentication to traefik routes. Does what it says on the tin.
Middleware has been tested with Auth0 and Logto. It should work with any OIDC provider.
Middleware adding OpenID Connect (OIDC) authentication to Traefik routes.
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy.
It provides a complete OIDC authentication solution with features like domain restrictions,
role-based access control, token caching, and more.
The middleware has been tested with Auth0, Logto, Google, and other standard OIDC providers.
It supports various authentication scenarios including:
- Basic authentication with customizable callback and logout URLs
- Email domain restrictions to limit access to specific organizations
- Role and group-based access control
- Public URLs that bypass authentication
- Rate limiting to prevent brute force attacks
- Custom post-logout redirect behavior
- Secure session management with encrypted cookies
- Automatic token validation and refresh
testData:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
scopes: # If not provided, default scopes will be used (openid, email, profile)
# Required parameters
providerURL: https://accounts.google.com # Base URL of the OIDC provider
clientID: 1234567890.apps.googleusercontent.com # OAuth 2.0 client identifier
clientSecret: secret # OAuth 2.0 client secret
callbackURL: /oauth2/callback # Path where the OIDC provider will redirect after authentication
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long # Key used to encrypt session data (must be at least 32 bytes)
# Optional parameters with defaults
logoutURL: /oauth2/logout # Path for handling logout requests (if not provided, it will be set to callbackURL + "/logout")
postLogoutRedirectURI: /oidc/different-logout # URL to redirect to after logout (default: "/")
scopes: # OAuth 2.0 scopes to request (default: ["openid", "email", "profile"])
- openid
- email
- profile
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
- raczylo.com
allowedRolesAndGroups:
- roles # Include this to get role information from the provider
allowedUserDomains: # Restricts access to specific email domains (if not provided, relies on OIDC provider)
- company.com
- subsidiary.com
allowedRolesAndGroups: # Restricts access to users with specific roles or groups (if not provided, no role/group restrictions)
- guest-endpoints
sessionEncryptionKey: potato-secret
forceHTTPS: false
logLevel: debug # debug, info, warn, error
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- admin
- developer
forceHTTPS: false # Forces the use of HTTPS for all URLs (default: true for security)
logLevel: debug # Sets logging verbosity: debug, info, error (default: info)
rateLimit: 100 # Maximum number of requests per second (default: 100, minimum: 10)
excludedURLs: # Lists paths that bypass authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
- /public
- /health
- /metrics
# Advanced parameters (usually discovered automatically from provider metadata)
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
# Configuration documentation
configuration:
providerURL:
type: string
description: |
The base URL of the OIDC provider. This is the issuer URL that will be used to discover
OIDC endpoints like authorization, token, and JWKS URIs.
Examples:
- https://accounts.google.com
- https://login.microsoftonline.com/tenant-id/v2.0
- https://your-auth0-domain.auth0.com
- https://your-logto-instance.com/oidc
required: true
clientID:
type: string
description: |
The OAuth 2.0 client identifier obtained from your OIDC provider.
This is the public identifier for your application.
required: true
clientSecret:
type: string
description: |
The OAuth 2.0 client secret obtained from your OIDC provider.
This should be kept confidential and not exposed in client-side code.
For Kubernetes deployments, you can use the secret reference format:
urn:k8s:secret:namespace:secret-name:key
required: true
callbackURL:
type: string
description: |
The path where the OIDC provider will redirect after authentication.
This must match one of the redirect URIs configured in your OIDC provider.
The full redirect URI will be constructed as:
[scheme]://[host][callbackURL]
Example: /oauth2/callback
required: true
sessionEncryptionKey:
type: string
description: |
Key used to encrypt session data stored in cookies.
Must be at least 32 bytes long for security.
Example: potato-secret-is-at-least-32-bytes-long
required: true
logoutURL:
type: string
description: |
The path for handling logout requests.
If not provided, it will be set to callbackURL + "/logout".
Example: /oauth2/logout
required: false
postLogoutRedirectURI:
type: string
description: |
The URL to redirect to after logout.
Default: "/"
Example: /logged-out-page
required: false
scopes:
type: array
description: |
The OAuth 2.0 scopes to request from the OIDC provider.
Default: ["openid", "profile", "email"]
Include "roles" or similar scope if you need role/group information.
required: false
items:
type: string
logLevel:
type: string
description: |
Sets the logging verbosity.
Valid values: "debug", "info", "error"
Default: "info"
required: false
enum:
- debug
- info
- error
forceHTTPS:
type: boolean
description: |
Forces the use of HTTPS for all URLs.
This is recommended for security in production environments.
Default: true
required: false
rateLimit:
type: integer
description: |
Sets the maximum number of requests per second.
This helps prevent brute force attacks.
Default: 100
Minimum: 10
required: false
excludedURLs:
type: array
description: |
Lists paths that bypass authentication.
These paths will be accessible without OIDC authentication.
The middleware uses prefix matching, so "/public" will match
"/public", "/public/page", "/public-data", etc.
Examples: ["/health", "/metrics", "/public"]
required: false
items:
type: string
allowedUserDomains:
type: array
description: |
Restricts access to users with email addresses from specific domains.
If not provided, the middleware relies entirely on the OIDC provider
for authentication decisions.
Examples: ["company.com", "subsidiary.com"]
required: false
items:
type: string
allowedRolesAndGroups:
type: array
description: |
Restricts access to users with specific roles or groups.
If not provided, no role/group restrictions are applied.
The middleware checks both the "roles" and "groups" claims in the ID token.
Examples: ["admin", "developer"]
required: false
items:
type: string
revocationURL:
type: string
description: |
The endpoint for revoking tokens.
If not provided, it will be discovered from provider metadata.
Example: https://accounts.google.com/revoke
required: false
oidcEndSessionURL:
type: string
description: |
The provider's end session endpoint.
If not provided, it will be discovered from provider metadata.
Example: https://accounts.google.com/logout
required: false
+312 -124
View File
@@ -1,151 +1,283 @@
## Traefik OIDC middleware
# Traefik OIDC Middleware
This middleware is supposed to replace the need for the forward-auth and oauth2-proxy when using traefik as a reverse proxy to support the OIDC authentication.
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy to support OpenID Connect (OIDC) authentication.
Middleware has been tested with Auth0 and Logto.
## Overview
### Traefik version compatibility
The Traefik OIDC middleware provides a complete OIDC authentication solution with features like:
- Token validation and verification
- Session management
- Domain restrictions
- Role-based access control
- Token caching and blacklisting
- Rate limiting
- Excluded paths (public URLs)
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
The middleware has been tested with Auth0 and Logto, but should work with any standard OIDC provider.
### Configuration options
## Traefik Version Compatibility
Middleware currently supports following scenarios:
This middleware follows closely the current Traefik helm chart versions. If the plugin fails to load, it's time to update to the latest version of the Traefik helm chart.
* Setting custom callback and logout URLs via `callbackURL` and `logoutURL`
* Allowing for access only from the listed domains if `allowedUserDomains` is set, otherwise it relies entirely on the OIDC provider
* Using excluded URLs which do **NOT** require the OIDC authentication
* Rate limiting requests to prevent the bruteforce attacks
## Installation
#### How to configure...
### As a Traefik Plugin
##### Keeping secrets secret
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
1. Enable the plugin in your Traefik static configuration:
```yaml
# traefik.yml
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.2.1 # Use the latest version
```
2. Configure the middleware in your dynamic configuration (see examples below).
### Local Development with Docker Compose
For local development or testing, you can use the provided Docker Compose setup:
```bash
cd docker
docker-compose up -d
```
This will start Traefik with the OIDC middleware and two test services.
## Configuration Options
The middleware supports the following configuration options:
### Required Parameters
| Parameter | Description | Example |
|-----------|-------------|---------|
| `providerURL` | The base URL of the OIDC provider | `https://accounts.google.com` |
| `clientID` | The OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
| `clientSecret` | The OAuth 2.0 client secret | `your-client-secret` |
| `sessionEncryptionKey` | Key used to encrypt session data (must be at least 32 bytes long) | `potato-secret-is-at-least-32-bytes-long` |
| `callbackURL` | The path where the OIDC provider will redirect after authentication | `/oauth2/callback` |
### Optional Parameters
| Parameter | Description | Default | Example |
|-----------|-------------|---------|---------|
| `logoutURL` | The path for handling logout requests | `callbackURL + "/logout"` | `/oauth2/logout` |
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
## Usage Examples
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-basic
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
```
### With Excluded URLs (Public Access Paths)
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-open-urls
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
excludedURLs:
- /login # covers /login, /login/me, /login/reminder etc.
- /public-data
- /health
- /metrics
```
### With Email Domain Restrictions
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-domain-restricted
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
allowedUserDomains:
- company.com
- subsidiary.com
```
### With Role-Based Access Control
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-rbac
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles # Include this to get role information from the provider
allowedRolesAndGroups:
- admin
- developer
```
### With Custom Logging and Rate Limiting
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-custom-settings
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
logLevel: debug # Options: debug, info, error (default: info)
rateLimit: 500 # Requests per second (default: 100)
forceHTTPS: false # Default is true for security
scopes:
- openid
- email
- profile
```
### With Custom Post-Logout Redirect
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-custom-logout
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /logged-out-page # Where to redirect after logout
scopes:
- openid
- email
- profile
```
### Keeping Secrets Secret in Kubernetes
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-secrets
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
```
##### Excluded URLs with open access
Don't forget to create the secret:
```
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-open-urls
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: xxx
clientID: yyy
clientSecret: zzz
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
scopes:
- openid
- email
- profile
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
```bash
kubectl create secret generic traefik-middleware-oidc \
--from-literal=ISSUER=https://accounts.google.com \
--from-literal=CLIENT_ID=1234567890.apps.googleusercontent.com \
--from-literal=SECRET=your-client-secret \
-n traefik
```
## Complete Docker Compose Example
##### Allowed email domains
Assuming that your OIDC provider allows anyone to log in, you may want to limit the access to people using emains in specific domain.
```
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-only-my-users
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: xxx
clientID: yyy
clientSecret: zzz
sessionEncryptionKey: vvv
callbackURL: /new-oidc/callback
logoutURL: /new-oidc/logout
scopes:
- openid
- email
- profile
allowedUserDomains:
- raczylo.com
```
##### Allowed groups and roles
In case of multiple roles / groups and access separation for various endpoints you will need to create multiple traefik middlewares.
Following example allows access for users who have additional role `guest-endpoints` assigned.
```
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-guest-endpoints
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: xxx
clientID: yyy
clientSecret: zzz
sessionEncryptionKey: vvv
callbackURL: /my-oidc/callback
logoutURL: /my-oidc/logout
scopes:
- openid
- email
- profile
- roles # This line queries the OIDC provider for roles
forceHTTPS: true
allowedRolesAndGroups:
- guest-endpoints # This line specifies the roles or groups allowed to access content
allowedUserDomains:
- raczylo.com
```
#### Docker compose example
`docker-compose.yaml`
Here's a complete example of using the middleware with Docker Compose:
```yaml
version: "3.7"
services:
traefik:
image: traefik:v3.0.1
image: traefik:v3.2.1
command:
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
- "--experimental.plugins.traefikoidc.version=v0.2.1"
@@ -156,7 +288,6 @@ services:
labels:
- "traefik.http.routers.dash.rule=Host(`dash.localhost`)"
- "traefik.http.routers.dash.service=api@internal"
ports:
- "80:80"
@@ -179,8 +310,7 @@ services:
- traefik.http.routers.whoami.middlewares=my-plugin@file
```
`traefik-config/traefik.yaml`
`traefik-config/traefik.yml`:
```yaml
log:
level: INFO
@@ -209,7 +339,7 @@ providers:
filename: /etc/traefik/dynamic-configuration.yml
```
`traefik-config/dynamic-configuration.yaml`
`traefik-config/dynamic-configuration.yml`:
```yaml
http:
middlewares:
@@ -218,20 +348,78 @@ http:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: secret
clientSecret: your-client-secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes: # If not provided, default scopes will be used (openid, email, profile)
postLogoutRedirectURI: /logged-out-page
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
scopes:
- openid
- email
- profile
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
- raczylo.com
sessionEncryptionKey: potato-secret
allowedUserDomains:
- company.com
allowedRolesAndGroups:
- admin
- developer
forceHTTPS: false
logLevel: debug # debug, info, warn, error
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
logLevel: debug
rateLimit: 100
excludedURLs:
- /login
- /public
- /health
- /metrics
```
## Advanced Configuration
### Session Management
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
### Token Caching and Blacklisting
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
### Headers Set for Downstream Services
When a user is authenticated, the middleware sets the following headers for downstream services:
- `X-Forwarded-User`: The user's email address
- `X-User-Groups`: Comma-separated list of user groups (if available)
- `X-User-Roles`: Comma-separated list of user roles (if available)
- `X-Auth-Request-Redirect`: The original request URI
- `X-Auth-Request-User`: The user's email address
- `X-Auth-Request-Token`: The user's access token
### Security Headers
The middleware also sets the following security headers:
- `X-Frame-Options: DENY`
- `X-Content-Type-Options: nosniff`
- `X-XSS-Protection: 1; mode=block`
- `Referrer-Policy: strict-origin-when-cross-origin`
## Troubleshooting
### Logging
Set the `logLevel` to `debug` to get more detailed logs:
```yaml
logLevel: debug
```
### Common Issues
1. **Token verification failed**: Check that your `providerURL` is correct and accessible.
2. **Session encryption key too short**: Ensure your `sessionEncryptionKey` is at least 32 bytes long.
3. **No matching public key found**: The JWKS endpoint might be unavailable or the token's key ID (kid) doesn't match any key in the JWKS.
4. **Access denied: Your email domain is not allowed**: The user's email domain is not in the `allowedUserDomains` list.
5. **Access denied: You do not have any of the allowed roles or groups**: The user doesn't have any of the roles or groups specified in `allowedRolesAndGroups`.
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
+5
View File
@@ -0,0 +1,5 @@
### TODO / wishlist
- [] Improve test coverage
- [x] Improve caching mechanism
- [x] Add automatic release and semver generation
+18
View File
@@ -0,0 +1,18 @@
package traefikoidc
import "time"
// autoCleanupRoutine runs a ticker that calls the provided cleanup function at the specified interval.
// It stops when a value is received on the stop channel.
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
cleanup()
case <-stop:
return
}
}
}
+22
View File
@@ -0,0 +1,22 @@
package traefikoidc
import (
"sync/atomic"
"testing"
"time"
)
func TestAutoCleanupRoutine(t *testing.T) {
var counter int32
cleanupFunc := func() {
atomic.AddInt32(&counter, 1)
}
stop := make(chan struct{})
go autoCleanupRoutine(50*time.Millisecond, stop, cleanupFunc)
time.Sleep(250 * time.Millisecond)
close(stop)
if atomic.LoadInt32(&counter) < 3 {
t.Errorf("Expected cleanup to be called at least 3 times, got %d", counter)
}
}
+110
View File
@@ -0,0 +1,110 @@
package traefikoidc
import (
"sync"
"time"
)
// TokenBlacklist manages a thread-safe list of revoked tokens with expiration.
type TokenBlacklist struct {
tokens map[string]time.Time
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new token blacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
tokens: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist with an expiration time.
func (b *TokenBlacklist) Add(token string, expiry time.Time) {
b.mutex.Lock()
defer b.mutex.Unlock()
// Clean up expired tokens if we're at capacity
if len(b.tokens) >= 1000 {
now := time.Now()
futureThreshold := now.Add(time.Minute)
for t, exp := range b.tokens {
if now.After(exp) || futureThreshold.After(exp) {
delete(b.tokens, t)
}
}
// If still at capacity, remove oldest token
if len(b.tokens) >= 1000 {
var oldestToken string
var oldestTime time.Time
first := true
for t, exp := range b.tokens {
if first || exp.Before(oldestTime) {
oldestToken = t
oldestTime = exp
first = false
}
}
if oldestToken != "" {
delete(b.tokens, oldestToken)
}
}
}
b.tokens[token] = expiry
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (b *TokenBlacklist) IsBlacklisted(token string) bool {
b.mutex.RLock()
defer b.mutex.RUnlock()
expiry, exists := b.tokens[token]
if !exists {
return false
}
// If token is expired, remove it and return false
if time.Now().After(expiry) {
// Switch to write lock to remove expired token
b.mutex.RUnlock()
b.mutex.Lock()
delete(b.tokens, token)
b.mutex.Unlock()
b.mutex.RLock()
return false
}
return true
}
// Cleanup removes expired tokens from the blacklist.
// Also removes tokens that will expire within the next minute to prevent edge cases.
func (b *TokenBlacklist) Cleanup() {
b.mutex.Lock()
defer b.mutex.Unlock()
now := time.Now()
futureThreshold := now.Add(time.Minute)
for token, expiry := range b.tokens {
// Remove tokens that are expired or will expire soon
if now.After(expiry) || futureThreshold.After(expiry) {
delete(b.tokens, token)
}
}
}
// Remove removes a token from the blacklist regardless of its expiration.
func (b *TokenBlacklist) Remove(token string) {
b.mutex.Lock()
defer b.mutex.Unlock()
delete(b.tokens, token)
}
// Count returns the current number of tokens in the blacklist.
func (b *TokenBlacklist) Count() int {
b.mutex.RLock()
defer b.mutex.RUnlock()
return len(b.tokens)
}
+74
View File
@@ -0,0 +1,74 @@
package traefikoidc
import (
"testing"
"time"
)
func TestTokenBlacklist_Add(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
if !blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be blacklisted, but it was not")
}
}
func TestTokenBlacklist_IsBlacklisted(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
if !blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be blacklisted, but it was not")
}
if blacklist.IsBlacklisted("nonExistentToken") {
t.Errorf("Expected non-existent token to not be blacklisted, but it was")
}
}
func TestTokenBlacklist_Cleanup(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(-time.Hour) // Expired token
blacklist.Add(token, expiry)
blacklist.Cleanup()
if blacklist.IsBlacklisted(token) {
t.Errorf("Expected expired token to be removed after cleanup, but it was not")
}
}
func TestTokenBlacklist_Remove(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
blacklist.Remove(token)
if blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be removed, but it was not")
}
}
func TestTokenBlacklist_Count(t *testing.T) {
blacklist := NewTokenBlacklist()
token1 := "token1"
token2 := "token2"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token1, expiry)
blacklist.Add(token2, expiry)
if blacklist.Count() != 2 {
t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count())
}
}
+114 -96
View File
@@ -1,169 +1,187 @@
package traefikoidc
import (
"container/list"
"sync"
"time"
)
// CacheItem represents an item stored in the cache with its associated metadata.
type CacheItem struct {
// Value is the cached data of any type
// Value is the cached data of any type.
Value interface{}
// ExpiresAt is the timestamp when this item should be considered expired
// and removed from the cache during cleanup operations
// ExpiresAt is the timestamp when this item should be considered expired.
ExpiresAt time.Time
}
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
type Cache struct {
// items stores the cached data with string keys
items map[string]CacheItem
// mutex protects concurrent access to the items map
// Use RLock/RUnlock for reads and Lock/Unlock for writes
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache
maxSize int
// accessList maintains the order of item access for eviction
accessList []string
// lruEntry represents an entry in the LRU list.
type lruEntry struct {
key string
}
// DefaultMaxSize is the default maximum number of items in the cache
const DefaultMaxSize = 1000
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
type Cache struct {
// items stores the cached data with string keys.
items map[string]CacheItem
// NewCache creates a new empty cache instance.
// The cache is immediately ready for use and is thread-safe.
// order maintains the usage order; most recently used items are at the back.
order *list.List
// elems maps keys to their corresponding list elements for O(1) access.
elems map[string]*list.Element
// mutex protects concurrent access to the cache.
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache.
maxSize int
// autoCleanupInterval defines how often Cleanup is called automatically.
autoCleanupInterval time.Duration
// stopCleanup channel to terminate the auto cleanup goroutine.
stopCleanup chan struct{}
}
// DefaultMaxSize is the default maximum number of items in the cache.
const DefaultMaxSize = 500
// NewCache creates a new empty cache instance that is ready for use.
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
maxSize: DefaultMaxSize,
accessList: make([]string, 0, DefaultMaxSize),
c := &Cache{
items: make(map[string]CacheItem, DefaultMaxSize),
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
autoCleanupInterval: 5 * time.Minute,
stopCleanup: make(chan struct{}),
}
go c.startAutoCleanup()
return c
}
// Set adds or updates an item in the cache with the specified expiration duration.
// Parameters:
// - key: Unique identifier for the cached item
// - value: The data to cache (can be of any type)
// - expiration: How long the item should remain in the cache
// Thread-safe: Uses write locking to ensure safe concurrent access.
// It moves the item to the most recently used position.
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
// If key exists, update it
now := time.Now()
expTime := now.Add(expiration)
// Update existing item.
if _, exists := c.items[key]; exists {
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: expTime,
}
if elem, ok := c.elems[key]; ok {
c.order.MoveToBack(elem)
}
return
}
// If cache is full, remove oldest item
// Evict oldest item if cache is full.
if len(c.items) >= c.maxSize {
c.evictOldest()
}
// Add new item
// Add new item.
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: expTime,
}
c.accessList = append(c.accessList, key)
elem := c.order.PushBack(lruEntry{key: key})
c.elems[key] = elem
}
// Get retrieves an item from the cache if it exists and hasn't expired.
// Parameters:
// - key: The identifier of the item to retrieve
// Returns:
// - value: The cached data (nil if not found or expired)
// - found: true if the item was found and is valid, false otherwise
// Thread-safe: Uses read locking to ensure safe concurrent access.
// Moving the accessed item to the most recently used position.
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
item, found := c.items[key]
c.mutex.RUnlock()
if !found {
return nil, false
}
if time.Now().After(item.ExpiresAt) {
c.mutex.Lock()
c.removeItem(key)
c.mutex.Unlock()
return nil, false
}
// Update access order
c.mutex.Lock()
c.updateAccessOrder(key)
c.mutex.Unlock()
defer c.mutex.Unlock()
item, exists := c.items[key]
if !exists {
return nil, false
}
// Check for expiration.
if time.Now().After(item.ExpiresAt) {
c.removeItem(key)
return nil, false
}
// Move item to the back (most recently used).
if elem, ok := c.elems[key]; ok {
c.order.MoveToBack(elem)
}
return item.Value, true
}
// Delete removes an item from the cache if it exists.
// If the item doesn't exist, this operation is a no-op.
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Delete removes an item from the cache.
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
c.removeItem(key)
}
// Cleanup removes all expired items from the cache.
// This should be called periodically to prevent memory leaks from
// expired items that haven't been accessed (and thus not removed during Get operations).
// Thread-safe: Uses write locking to ensure safe concurrent access.
// Cleanup removes all expired items from the cache. This should be called periodically
// to prevent memory bloat from expired entries.
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
var newAccessList []string
for _, key := range c.accessList {
if item, exists := c.items[key]; exists && !now.After(item.ExpiresAt) {
newAccessList = append(newAccessList, key)
} else {
delete(c.items, key)
for key, item := range c.items {
// Remove items that are expired or within 10% of expiration
if now.After(item.ExpiresAt) || now.Add(time.Duration(float64(item.ExpiresAt.Sub(now))*0.1)).After(item.ExpiresAt) {
c.removeItem(key)
}
}
c.accessList = newAccessList
}
// evictOldest removes the least recently used item from the cache
// evictOldest removes the least recently used item from the cache.
func (c *Cache) evictOldest() {
if len(c.accessList) > 0 {
oldest := c.accessList[0]
c.removeItem(oldest)
now := time.Now()
elem := c.order.Front()
// First try to find an expired item from the front
for elem != nil {
entry := elem.Value.(lruEntry)
if item, exists := c.items[entry.key]; exists {
if now.After(item.ExpiresAt) {
c.removeItem(entry.key)
return
}
}
elem = elem.Next()
}
// If no expired items found, remove the oldest item
if elem = c.order.Front(); elem != nil {
entry := elem.Value.(lruEntry)
c.removeItem(entry.key)
}
}
// removeItem removes an item from both the cache and access list
// removeItem removes an item from both the cache and the LRU tracking structures.
func (c *Cache) removeItem(key string) {
delete(c.items, key)
for i, k := range c.accessList {
if k == key {
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
break
}
if elem, ok := c.elems[key]; ok {
c.order.Remove(elem)
delete(c.elems, key)
}
}
// updateAccessOrder moves the accessed key to the end of the access list
func (c *Cache) updateAccessOrder(key string) {
for i, k := range c.accessList {
if k == key {
c.accessList = append(append(c.accessList[:i], c.accessList[i+1:]...), key)
break
}
}
// startAutoCleanup initiates a goroutine that periodically cleans up expired cache items.
func (c *Cache) startAutoCleanup() {
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
}
// Close terminates the auto cleanup goroutine.
func (c *Cache) Close() {
close(c.stopCleanup)
}
+2 -2
View File
@@ -6,8 +6,8 @@ toolchain go1.23.1
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.4.0
golang.org/x/time v0.9.0
github.com/gorilla/sessions v1.3.0
golang.org/x/time v0.7.0
)
require github.com/gorilla/securecookie v1.1.2 // indirect
+4 -4
View File
@@ -4,7 +4,7 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+27 -68
View File
@@ -8,31 +8,12 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
)
// newSessionOptions creates secure session cookie options.
// Parameters:
// - isSecure: Whether to set the Secure flag on cookies
// Returns session options configured for security with:
// - HttpOnly flag to prevent JavaScript access
// - SameSite=Lax for CSRF protection
// - Appropriate timeout and path settings
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce creates a cryptographically secure random nonce
// for use in the OIDC authentication flow. The nonce is used to
// prevent replay attacks by ensuring the token received matches
@@ -87,13 +68,28 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
data.Set("refresh_token", codeOrToken)
}
// Create a cookie jar for this request to handle redirects with cookies
jar, _ := cookiejar.New(nil)
client := &http.Client{
Transport: t.httpClient.Transport,
Timeout: t.httpClient.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
Jar: jar,
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := t.httpClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
@@ -128,11 +124,19 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
// Clear authentication data but preserve CSRF state
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
// Save the cleared session state
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -278,51 +282,6 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a thread-safe list of revoked tokens.
// It stores tokens with their expiration times and automatically
// removes expired entries during cleanup operations.
type TokenBlacklist struct {
// blacklist maps token IDs to their expiration times
blacklist map[string]time.Time
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist with an expiration time.
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
expiration, exists := tb.blacklist[tokenID]
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist.
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
now := time.Now()
for tokenID, expiration := range tb.blacklist {
if now.After(expiration) {
delete(tb.blacklist, tokenID)
}
}
}
// TokenCache provides a caching mechanism for validated tokens.
// It stores token claims to avoid repeated validation of the
// same token, improving performance for frequently used tokens.
+227
View File
@@ -0,0 +1,227 @@
package traefikoidc
import (
"fmt"
"runtime"
"testing"
"time"
)
func TestTokenBlacklistSizeLimit(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens up to maxSize
for i := 0; i < 1000; i++ {
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
}
// Verify size is at max
if tb.Count() != 1000 {
t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count())
}
// Add one more token, should trigger cleanup/eviction
tb.Add("newtoken", time.Now().Add(time.Hour))
// Size should still be at max
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
}
}
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
tb := NewTokenBlacklist()
// Add some expired tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
}
// Add some valid tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
}
// Force cleanup
tb.Cleanup()
// Only valid tokens should remain
if tb.Count() != 500 {
t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count())
}
// Verify only valid tokens remain
tb.mutex.RLock()
defer tb.mutex.RUnlock()
for token, expiry := range tb.tokens {
if time.Now().After(expiry) {
t.Errorf("Found expired token after cleanup: %s", token)
}
}
}
func TestTokenBlacklistOldestEviction(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens at capacity with different expiration times
baseTime := time.Now()
oldestToken := "oldest"
// Add oldest token first
tb.Add(oldestToken, baseTime.Add(time.Hour))
// Fill up to capacity with newer tokens
for i := 0; i < 999; i++ {
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
}
// Add a new token that should evict the oldest
newToken := "newest"
tb.Add(newToken, baseTime.Add(time.Hour*3))
// Verify oldest token was evicted
if tb.IsBlacklisted(oldestToken) {
t.Error("Oldest token should have been evicted")
}
// Verify newest token is present
if !tb.IsBlacklisted(newToken) {
t.Error("Newest token should be present")
}
}
func TestTokenBlacklistMemoryUsage(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy usage
for i := 0; i < iterations; i++ {
// Add new token
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
// Periodically check blacklisted status
if i%100 == 0 {
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tb.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
}
// Verify size stayed within limits
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
}
}
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 1000
concurrency := 10
done := make(chan bool)
// Start multiple goroutines performing operations
for i := 0; i < concurrency; i++ {
go func(id int) {
for j := 0; j < iterations; j++ {
// Add tokens
token := fmt.Sprintf("token%d-%d", id, j)
tb.Add(token, time.Now().Add(time.Hour))
// Check blacklist status
tb.IsBlacklisted(token)
// Periodic cleanup
if j%100 == 0 {
tb.Cleanup()
}
}
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < concurrency; i++ {
<-done
}
// Verify size constraints were maintained
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count())
}
}
func TestTokenCacheMemoryUsage(t *testing.T) {
tc := NewTokenCache()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy cache usage
for i := 0; i < iterations; i++ {
claims := map[string]interface{}{
"sub": fmt.Sprintf("user%d", i),
"exp": time.Now().Add(time.Hour).Unix(),
}
// Add to cache
tc.Set(fmt.Sprintf("token%d", i), claims, time.Hour)
// Periodically retrieve
if i%100 == 0 {
tc.Get(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tc.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive cache memory growth: %d bytes", memoryGrowth)
}
// Verify cache size stayed within limits
if len(tc.cache.items) > tc.cache.maxSize {
t.Errorf("Cache exceeded max size: %d", len(tc.cache.items))
}
}
+37 -75
View File
@@ -1,90 +1,51 @@
package traefikoidc
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"math/big"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sync"
"time"
)
// JWK represents a JSON Web Key as defined in RFC 7517.
// It contains the cryptographic key information used for token verification.
type JWK struct {
// Kty is the key type (e.g., "RSA", "EC")
Kty string `json:"kty"`
// Kid is the unique key identifier
Kid string `json:"kid"`
// Use specifies the intended use of the key (e.g., "sig" for signature)
Use string `json:"use"`
// N is the modulus for RSA keys
N string `json:"n"`
// E is the exponent for RSA keys
E string `json:"e"`
// Alg is the algorithm intended for use with the key
N string `json:"n"`
E string `json:"e"`
Alg string `json:"alg"`
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
Crv string `json:"crv"`
// X is the x-coordinate for EC keys
X string `json:"x"`
// Y is the y-coordinate for EC keys
Y string `json:"y"`
X string `json:"x"`
Y string `json:"y"`
}
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
// OIDC providers typically expose multiple keys to support key rotation.
type JWKSet struct {
// Keys is the array of JSON Web Keys
Keys []JWK `json:"keys"`
}
// JWKCache provides a thread-safe caching mechanism for JWK sets.
// It caches the keys for a configurable duration to reduce load on the OIDC provider
// while ensuring keys are refreshed periodically to handle key rotation.
type JWKCache struct {
// jwks holds the cached set of JSON Web Keys
jwks *JWKSet
// expiresAt is the timestamp when the cached keys should be refreshed
jwks *JWKSet
expiresAt time.Time
// mutex protects concurrent access to the cache
mutex sync.RWMutex
mutex sync.RWMutex
// CacheLifetime is configurable to determine how long the JWKS is cached.
CacheLifetime time.Duration
}
// JWKCacheInterface defines the interface for JWK caching operations.
// This interface allows for different caching implementations while
// maintaining consistent behavior in the token verification process.
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
Cleanup()
}
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
// from the OIDC provider. It implements a thread-safe double-checked locking
// pattern to prevent multiple simultaneous fetches of the same keys.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for fetching keys
// Returns:
// - The JSON Web Key Set
// - An error if the keys cannot be retrieved or parsed
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
defer c.mutex.RUnlock()
@@ -94,32 +55,43 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
c.mutex.Lock()
defer c.mutex.Unlock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
return c.jwks, nil
}
jwks, err := fetchJWKS(jwksURL, httpClient)
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
if err != nil {
return nil, err
}
c.jwks = jwks
c.expiresAt = time.Now().Add(1 * time.Hour)
lifetime := c.CacheLifetime
if lifetime == 0 {
lifetime = 1 * time.Hour
}
c.expiresAt = time.Now().Add(lifetime)
return jwks, nil
}
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
// It handles HTTP communication and JSON parsing of the response.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for the request
// Returns:
// - The parsed JSON Web Key Set
// - An error if the request fails or the response is invalid
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
func (c *JWKCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
if c.jwks != nil && now.After(c.expiresAt) {
c.jwks = nil
}
}
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
// Create a request with context to enforce timeout
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
}
@@ -137,9 +109,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
// cryptographic functions. It supports both RSA and EC keys, delegating
// to the appropriate converter based on the key type.
func jwkToPEM(jwk *JWK) ([]byte, error) {
converter, ok := jwkConverters[jwk.Kty]
if !ok {
@@ -155,9 +124,6 @@ var jwkConverters = map[string]jwkToPEMConverter{
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
// It handles base64url decoding of the modulus and exponent,
// constructs an RSA public key, and encodes it in PEM format.
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
@@ -189,10 +155,6 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
return pubKeyPEM, nil
}
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
// It supports the P-256, P-384, and P-521 curves as defined in the
// OIDC specification, decoding the x and y coordinates and encoding
// the resulting public key in PEM format.
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
+111 -96
View File
@@ -4,43 +4,41 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"strings"
"sync"
"time"
)
var replayCacheMu sync.Mutex
var replayCache = make(map[string]time.Time)
func cleanupReplayCache() {
now := time.Now()
for token, expiry := range replayCache {
if expiry.Before(now) {
delete(replayCache, token)
}
}
}
// ClockSkewTolerance is configurable to adjust time-based validations.
var ClockSkewTolerance = 2 * time.Minute
// JWT represents a JSON Web Token as defined in RFC 7519.
// It contains the three parts of a JWT: header, claims (payload),
// and signature, along with the original token string.
type JWT struct {
// Header contains the token metadata (algorithm, key ID, etc.)
Header map[string]interface{}
// Claims contains the token claims (subject, expiration, etc.)
Claims map[string]interface{}
// Signature contains the raw signature bytes
Header map[string]interface{}
Claims map[string]interface{}
Signature []byte
// Token is the original JWT string
Token string
Token string
}
// parseJWT parses a JWT token string into a JWT struct.
// It validates the token format and decodes the three parts
// (header, claims, signature) using base64url decoding.
// Parameters:
// - tokenString: The raw JWT token string
// Returns:
// - A parsed JWT struct
// - An error if the token format is invalid or parsing fails
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -51,7 +49,6 @@ func parseJWT(tokenString string) (*JWT, error) {
Token: tokenString,
}
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
@@ -60,7 +57,6 @@ func parseJWT(tokenString string) (*JWT, error) {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
}
// Decode and unmarshal the claims
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
@@ -69,7 +65,6 @@ func parseJWT(tokenString string) (*JWT, error) {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Decode the signature
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
@@ -80,14 +75,22 @@ func parseJWT(tokenString string) (*JWT, error) {
}
// Verify validates the standard JWT claims as defined in RFC 7519.
// It checks:
// - issuer (iss) matches the expected issuer URL
// - audience (aud) includes the client ID
// - expiration time (exp) is in the future
// - issued at time (iat) is in the past
// - subject (sub) is present and not empty
// Returns an error if any validation fails.
// Verify validates the standard JWT claims as defined in RFC 7519.
func (j *JWT) Verify(issuerURL, clientID string) error {
// Validate algorithm to prevent algorithm switching attacks
alg, ok := j.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing 'alg' header")
}
supportedAlgs := map[string]bool{
"RS256": true, "RS384": true, "RS512": true,
"PS256": true, "PS384": true, "PS512": true,
"ES256": true, "ES384": true, "ES512": true,
}
if !supportedAlgs[alg] {
return fmt.Errorf("unsupported algorithm: %s", alg)
}
claims := j.Claims
iss, ok := claims["iss"].(string)
@@ -122,6 +125,38 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return err
}
if nbf, ok := claims["nbf"].(float64); ok {
if err := verifyNotBefore(nbf); err != nil {
return err
}
}
// Implement replay protection by checking the jti (JWT ID)
if jti, ok := claims["jti"].(string); ok {
// Skip replay detection for tokens that are being verified from the cache
if j.Token == "" {
// This is a parsed JWT without the original token string,
// which means it's likely from a cached token verification
return nil
}
replayCacheMu.Lock()
cleanupReplayCache()
if _, exists := replayCache[jti]; exists {
replayCacheMu.Unlock()
return fmt.Errorf("token replay detected")
}
expFloat, ok := claims["exp"].(float64)
var expTime time.Time
if ok {
expTime = time.Unix(int64(expFloat), 0)
} else {
expTime = time.Now().Add(10 * time.Minute)
}
replayCache[jti] = expTime
replayCacheMu.Unlock()
}
sub, ok := claims["sub"].(string)
if !ok || sub == "" {
return fmt.Errorf("missing or empty 'sub' claim")
@@ -129,14 +164,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// verifyAudience validates the token's audience claim.
// The audience can be either a single string or an array of strings.
// For array audiences, the expected audience must match any one value.
// Parameters:
// - tokenAudience: The audience claim from the token
// - expectedAudience: The expected audience value
// Returns an error if validation fails.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
@@ -160,85 +187,80 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
return nil
}
// verifyIssuer validates the token's issuer claim.
// The issuer URL must exactly match the expected issuer.
// Parameters:
// - tokenIssuer: The issuer claim from the token
// - expectedIssuer: The expected issuer URL
// Returns an error if validation fails.
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
}
return nil
}
// verifyExpiration checks if the token's expiration time has passed.
// The expiration time is compared against the current time.
// Parameters:
// - expiration: The expiration timestamp from the token
// Returns an error if the token has expired.
// verifyTimeConstraint is a generic function to verify time-based claims
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
claimTime := time.Unix(int64(unixTime), 0)
now := time.Now().Truncate(time.Second)
// For expiration (future=true), we add skew to now (making now later)
// For iat/nbf (future=false), we subtract skew from now (making now earlier)
skewDirection := 1
if !future {
skewDirection = -1
}
skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance)
if claimTime.Equal(now) {
return nil
}
// For expiration: if skewedNow (later) is after expiration, token expired
// For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid
if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) {
var reason string
if future {
reason = "has expired"
} else {
if claimName == "iat" {
reason = "used before issued"
} else {
reason = "not yet valid"
}
}
return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC())
}
return nil
}
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
return fmt.Errorf("token has expired")
}
return nil
return verifyTimeConstraint(expiration, "exp", true)
}
// verifyIssuedAt validates the token's issued-at time.
// Ensures the token wasn't issued in the future, which could
// indicate clock skew or a malicious token.
// Parameters:
// - issuedAt: The issued-at timestamp from the token
// Returns an error if the token was issued in the future.
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
return verifyTimeConstraint(issuedAt, "iat", false)
}
func verifyNotBefore(notBefore float64) error {
return verifyTimeConstraint(notBefore, "nbf", false)
}
// verifySignature validates the token's cryptographic signature.
// Supports multiple signature algorithms:
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
// - RSA-PSS: PS256, PS384, PS512
// - ECDSA: ES256, ES384, ES512
// Parameters:
// - tokenString: The complete JWT token string
// - publicKeyPEM: The PEM-encoded public key for verification
// - alg: The signature algorithm identifier
// Returns an error if signature verification fails.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
// Decode the signature from the token
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
// Decode the PEM-encoded public key
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
// Parse the public key
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
// Determine the hash function to use based on the algorithm
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256", "ES256":
hashFunc = crypto.SHA256
@@ -249,27 +271,20 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
}
// Hash the signed content
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
// Verify the signature based on the key type and algorithm
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
// RSA PKCS#1 v1.5 signature
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
// RSA PSS signature
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature
var r, s big.Int
sigLen := len(signature)
if sigLen%2 != 0 {
+285 -97
View File
@@ -12,10 +12,46 @@ import (
"strings"
"time"
"runtime"
"github.com/google/uuid"
"golang.org/x/time/rate"
)
// createDefaultHTTPClient creates an HTTP client with optimized settings for OIDC
func createDefaultHTTPClient() *http.Client {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
ExpectContinueTimeout: 0,
MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections
}
return &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
}
}
const ConstSessionTimeout = 86400 // Session timeout in seconds
// TokenVerifier interface for token verification
@@ -37,6 +73,7 @@ type TraefikOidc struct {
issuerURL string
revocationURL string
jwkCache JWKCacheInterface
metadataCache *MetadataCache
tokenBlacklist *TokenBlacklist
jwksURL string
clientID string
@@ -60,7 +97,6 @@ type TraefikOidc struct {
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
@@ -80,13 +116,50 @@ var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
var newTicker = time.NewTicker
// VerifyToken verifies the provided JWT token
// VerifyToken implements the TokenVerifier interface to verify an OIDC token.
// It performs a complete verification process including:
// 1. Checking the token cache to avoid redundant verifications
// 2. Performing rate limiting and blacklist checks
// 3. Parsing the JWT structure
// 4. Verifying the JWT signature against the JWKS from the provider
// 5. Validating standard JWT claims (iss, aud, exp, etc.)
// 6. Caching the verified token for future requests
//
// Returns nil if the token is valid, or an error describing the validation failure.
func (t *TraefikOidc) VerifyToken(token string) error {
// Check cache first
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
t.logger.Debugf("Token found in cache with valid claims; skipping verification")
return nil
}
t.logger.Debugf("Verifying token")
// Rate limiting
// Perform pre-verification checks
if err := t.performPreVerificationChecks(token); err != nil {
return err
}
// Parse the JWT
jwt, err := parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
// Verify JWT signature and standard claims
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
return err
}
// Cache the verified token
t.cacheVerifiedToken(token, jwt.Claims)
return nil
}
// performPreVerificationChecks performs rate limiting and blacklist checks
func (t *TraefikOidc) performPreVerificationChecks(token string) error {
// Enforce rate limiting
if !t.limiter.Allow() {
return fmt.Errorf("rate limit exceeded")
}
@@ -96,30 +169,15 @@ func (t *TraefikOidc) VerifyToken(token string) error {
return fmt.Errorf("token is blacklisted")
}
// Check if token is cached
if _, exists := t.tokenCache.Get(token); exists {
t.logger.Debugf("Token is valid and cached")
return nil // Token is valid and cached
}
return nil
}
// Parse the JWT
jwt, err := parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
// Verify JWT signature and claims
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
return err
}
// Cache the token until it expires
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
// cacheVerifiedToken caches a verified token until its expiration time
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) {
expirationTime := time.Unix(int64(claims["exp"].(float64)), 0)
now := time.Now()
duration := expirationTime.Sub(now)
t.tokenCache.Set(token, jwt.Claims, duration)
return nil
t.tokenCache.Set(token, claims, duration)
}
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
@@ -127,7 +185,7 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
t.logger.Debugf("Verifying JWT signature and claims")
// Get JWKS
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
jwks, err := t.jwkCache.GetJWKS(context.Background(), t.jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
@@ -173,36 +231,53 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
return nil
}
// New creates a new instance of the OIDC middleware
// New creates a new instance of the OIDC middleware.
// This is the main entry point for the middleware and is called by Traefik when loading the plugin.
// It initializes all components needed for OIDC authentication:
// - Session management for storing user state
// - Token caching and blacklisting
// - JWK caching for signature verification
// - Rate limiting to prevent abuse
// - Metadata discovery for OIDC provider endpoints
//
// Parameters:
// - ctx: Context for initialization operations
// - next: The next handler in the middleware chain
// - config: Configuration options for the middleware
// - name: Identifier for this middleware instance
//
// Returns:
// - An http.Handler that implements the middleware
// - An error if initialization fails
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
ExpectContinueTimeout: 0,
MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections
if config == nil {
config = CreateConfig()
}
// Generate default session encryption key if not provided
if config.SessionEncryptionKey == "" {
// Generate a fixed key for Traefik Hub testing
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}
// Initialize logger
logger := NewLogger(config.LogLevel)
// Ensure key meets minimum length requirement
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
if runtime.Compiler == "yaegi" {
// Set default encryption key for Yaegi (Traefik Plugin Analyzer)
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
logger.Infof("Session encryption key is too short; using default key for analyzer")
} else {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
}
// Setup HTTP client
var httpClient *http.Client
if config.HTTPClient != nil {
httpClient = config.HTTPClient
} else {
httpClient = &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Transport: transport,
}
httpClient = createDefaultHTTPClient()
}
t := &TraefikOidc{
@@ -223,6 +298,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
metadataCache: NewMetadataCache(),
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
@@ -230,14 +306,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
}
// Assign the initialized logger
t.logger = logger
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
@@ -260,41 +337,56 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Debug("Starting provider metadata discovery")
// Keep retrying until successful
backoff := time.Second
maxBackoff := 30 * time.Second
for {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
// Get metadata from cache or fetch it
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to get provider metadata: %v", err)
return
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.updateMetadataEndpoints(metadata)
// Start metadata refresh goroutine
go t.startMetadataRefresh(providerURL)
// Only close channel on success
close(t.initComplete)
return
}
t.logger.Error("Received nil metadata")
}
// updateMetadataEndpoints updates the middleware with metadata endpoints
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
}
// startMetadataRefresh periodically refreshes the OIDC metadata
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
t.logger.Debug("Refreshing OIDC metadata")
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
time.Sleep(backoff)
// Exponential backoff with max
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
t.logger.Errorf("Failed to refresh metadata: %v", err)
continue
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Only close channel on success
close(t.initComplete)
return
t.updateMetadataEndpoints(metadata)
t.logger.Debug("Successfully refreshed metadata")
}
t.logger.Error("Received nil metadata, retrying")
time.Sleep(backoff)
}
}
@@ -360,7 +452,18 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return &metadata, nil
}
// ServeHTTP is the main handler for the middleware
// ServeHTTP is the main handler for the middleware that processes all HTTP requests.
// It implements the http.Handler interface and performs the following operations:
// 1. Waits for OIDC provider metadata initialization to complete
// 2. Checks if the requested URL is in the excluded list (bypassing authentication)
// 3. Retrieves or creates a user session
// 4. Handles special paths like callback and logout URLs
// 5. Verifies authentication status and token validity
// 6. Refreshes tokens that are about to expire
// 7. Validates user email domains, roles, and groups against configured restrictions
// 8. Sets appropriate headers for downstream services
// 9. Applies security headers to responses
// 10. Forwards the authenticated request to the next handler
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
@@ -389,7 +492,18 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
// Obtain a new session and clear any residual session cookies
session, _ = t.sessionManager.GetSession(req)
session.Clear(req, rw)
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Initiate authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -475,6 +589,34 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Set OIDC-specific headers
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetAccessToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
// Set security headers
rw.Header().Set("X-Frame-Options", "DENY")
rw.Header().Set("X-Content-Type-Options", "nosniff")
rw.Header().Set("X-XSS-Protection", "1; mode=block")
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Set CORS headers
origin := req.Header.Get("Origin")
if origin != "" {
rw.Header().Set("Access-Control-Allow-Origin", origin)
rw.Header().Set("Access-Control-Allow-Credentials", "true")
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
// Handle preflight requests
if req.Method == "OPTIONS" {
rw.WriteHeader(http.StatusOK)
return
}
}
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -493,9 +635,6 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
return "https"
}
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
@@ -513,7 +652,17 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
return req.Host
}
// isUserAuthenticated checks if the user is authenticated
// isUserAuthenticated checks if the user is authenticated by validating their session and token.
// It performs a comprehensive check of the authentication state including:
// 1. Verifying the session's authenticated flag
// 2. Checking for the presence of an access token
// 3. Validating the token's signature and claims
// 4. Checking the token's expiration time
//
// Returns three boolean values:
// - authenticated: Whether the user is currently authenticated
// - needsRefresh: Whether the token is valid but will expire soon (within grace period)
// - expired: Whether the token has expired or is otherwise invalid
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
t.logger.Debug("User is not authenticated according to session")
@@ -561,20 +710,35 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return true, false, false
}
// defaultInitiateAuthentication initiates the authentication process
// defaultInitiateAuthentication initiates the OIDC authentication process.
// This function prepares and starts a new authentication flow by:
// 1. Generating security tokens (CSRF token and nonce) to prevent attacks
// 2. Clearing any existing session data to avoid state conflicts
// 3. Storing the original request path to redirect back after authentication
// 4. Building the authorization URL with all required OIDC parameters
// 5. Redirecting the user to the OIDC provider's authorization endpoint
//
// Parameters:
// - rw: The HTTP response writer for sending the redirect
// - req: The original HTTP request that triggered authentication
// - session: The user's session data for storing authentication state
// - redirectURL: The callback URL where the OIDC provider will redirect after authentication
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
csrfToken := uuid.New().String()
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Set session values
// Clear any existing session data to avoid stale state causing redirect loops
session.Clear(req, rw)
// Set new session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.Path)
session.SetIncomingPath(req.URL.RequestURI())
// Save the session
if err := session.Save(req, rw); err != nil {
@@ -583,12 +747,15 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Build and redirect to auth URL
// Build and redirect to authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// verifyToken verifies the token using the token verifier
// verifyToken verifies the token using the token verifier interface.
// This function delegates to the configured token verifier implementation,
// which by default is the TraefikOidc instance itself (implementing the VerifyToken method).
// This design allows for easy mocking in tests and potential future extension.
func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
@@ -604,17 +771,38 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " "))
}
return t.authURL + "?" + params.Encode()
return t.buildURLWithParams(t.authURL, params)
}
// buildURLWithParams ensures a URL is absolute and appends query parameters
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
// Ensure URL is absolute
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
baseURL,
params.Encode())
}
}
return baseURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := newTicker(1 * time.Minute)
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
go func() {
defer ticker.Stop()
for range ticker.C {
t.logger.Debug("Cleaning up token cache")
t.logger.Debug("Starting token cleanup cycle")
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
t.jwkCache.Cleanup() // Assuming jwkCache is the cache from cache.go
// Removed runtime.GC() call
}
}()
}
+291 -26
View File
@@ -13,6 +13,7 @@ import (
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@@ -67,21 +68,29 @@ func (ts *TestSuite) Setup() {
}
// Create a test JWT token signed with the RSA private key
// Create timestamps with proper clock skew
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
"jti": generateRandomString(16),
})
if err != nil {
ts.t.Fatalf("Failed to create test JWT: %v", err)
}
logger := NewLogger("info")
ts.sessionManager = NewSessionManager("test-secret-key", false, logger)
ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
@@ -122,10 +131,16 @@ type MockJWKCache struct {
Err error
}
func (m *MockJWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return m.JWKS, m.Err
}
func (m *MockJWKCache) Cleanup() {
// Mock cleanup implementation
m.JWKS = nil
m.Err = nil
}
// Helper function to create a JWT token
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
header := map[string]interface{}{
@@ -212,7 +227,7 @@ func TestVerifyToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache
// Reset token blacklist and cache for each test
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
ts.tOidc.tokenCache = NewTokenCache()
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
@@ -228,9 +243,20 @@ func TestVerifyToken(t *testing.T) {
}
if tc.cacheToken {
// Use more realistic claims for cached token
ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{
"empty": "claim",
}, 60)
"iss": "https://test-issuer.com",
"sub": "test-subject",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"jti": generateRandomString(16), // Add a JTI claim to prevent replay detection
}, time.Minute)
// Verify the token is actually in the cache
if claims, exists := ts.tOidc.tokenCache.Get(tc.token); exists {
t.Logf("Token found in cache with claims: %v", claims)
} else {
t.Logf("Token NOT found in cache despite cacheToken=true")
}
}
err := ts.tOidc.VerifyToken(tc.token)
@@ -611,7 +637,7 @@ func TestHandleCallback(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Create a new instance for each test to avoid state carryover
tOidc := &TraefikOidc{
@@ -916,7 +942,7 @@ func TestHandleLogout(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
tOidc := &TraefikOidc{
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
@@ -1205,7 +1231,7 @@ func TestHandleExpiredToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
tOidc := &TraefikOidc{
sessionManager: sessionManager,
@@ -1362,23 +1388,23 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
// Create base config
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
}
// Create multiple middleware instances
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
var middlewares []*TraefikOidc
for _, route := range routes {
config.CallbackURL = route + "/callback"
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test")
if err != nil {
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
}
@@ -1440,6 +1466,12 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create consistent timestamps for all test cases
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
tests := []struct {
name string
allowedRolesAndGroups map[string]struct{}
@@ -1456,11 +1488,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"admin", "user"},
"groups": []interface{}{"group1"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1480,11 +1514,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"allowed-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1505,11 +1541,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1523,11 +1561,13 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1545,9 +1585,11 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"jti": generateRandomString(16),
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
@@ -1622,6 +1664,112 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
}
// Helper function to compare string slices
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
setupServer func() *httptest.Server
expectError bool
errorContains string
}{
{
name: "Successful token exchange with redirects",
setupServer: func() *httptest.Server {
redirectCount := 0
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 3 {
// Set a cookie before redirecting
http.SetCookie(w, &http.Cookie{
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
Value: "test-value",
})
redirectCount++
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
return
}
// Verify all cookies from previous redirects are present
cookies := r.Cookies()
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
for i := 0; i < 3; i++ {
found := false
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
for _, cookie := range cookies {
if cookie.Name == expectedName {
found = true
break
}
}
if !found {
t.Errorf("Cookie %s not found", expectedName)
}
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
expectError: false,
},
{
name: "Too many redirects",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", r.URL.String())
w.WriteHeader(http.StatusFound)
}))
},
expectError: true,
errorContains: "stopped after 50 redirects",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupServer()
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
// Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
} else if !strings.Contains(err.Error(), tc.errorContains) {
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
}
})
}
}
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
@@ -1633,3 +1781,120 @@ func stringSliceEqual(a, b []string) bool {
}
return true
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authURL string
issuerURL string
redirectURL string
state string
nonce string
expectedPrefix string
}{
{
name: "Absolute Auth URL",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
},
{
name: "Relative Auth URL",
authURL: "/oidc/auth",
issuerURL: "https://logto.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://logto.example.com/oidc/auth?",
},
{
name: "Relative Auth URL with Different Issuer",
authURL: "/sign-in",
issuerURL: "https://auth.example.com:8443",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure the test instance
tOidc := ts.tOidc
tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL
// Call buildAuthURL
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
// Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) {
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
}
// Parse the resulting URL to verify query parameters
parsedURL, err := url.Parse(result)
if err != nil {
t.Fatalf("Failed to parse resulting URL: %v", err)
}
query := parsedURL.Query()
expectedParams := map[string]string{
"client_id": tOidc.clientID,
"response_type": "code",
"redirect_uri": tc.redirectURL,
"state": tc.state,
"nonce": tc.nonce,
}
for key, expected := range expectedParams {
if got := query.Get(key); got != expected {
t.Errorf("Expected %s=%q, got %q", key, expected, got)
}
}
// Verify scopes are present and correct
if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ")
if got := query.Get("scope"); got != expectedScopes {
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
}
}
})
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create a request with query parameters
req := httptest.NewRequest("GET", "/protected/resource?param1=value1&param2=value2", nil)
rw := httptest.NewRecorder()
// Get session
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Call defaultInitiateAuthentication
redirectURL := "http://example.com/callback"
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
// Verify that the incoming path includes query parameters
incomingPath := session.GetIncomingPath()
expectedPath := "/protected/resource?param1=value1&param2=value2"
if incomingPath != expectedPath {
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
}
}
+90
View File
@@ -0,0 +1,90 @@
package traefikoidc
import (
"fmt"
"net/http"
"sync"
"time"
)
type MetadataCache struct {
metadata *ProviderMetadata
expiresAt time.Time
mutex sync.RWMutex
autoCleanupInterval time.Duration
stopCleanup chan struct{}
}
func NewMetadataCache() *MetadataCache {
c := &MetadataCache{
autoCleanupInterval: 5 * time.Minute,
stopCleanup: make(chan struct{}),
}
go c.startAutoCleanup()
return c
}
// Cleanup removes expired metadata from the cache.
func (c *MetadataCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
if c.metadata != nil && now.After(c.expiresAt) {
c.metadata = nil
}
}
func (c *MetadataCache) isCacheValid() bool {
return c.metadata != nil && time.Now().Before(c.expiresAt)
}
// GetMetadata retrieves the metadata from cache or fetches it if expired
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
c.mutex.RLock()
if c.isCacheValid() {
defer c.mutex.RUnlock()
return c.metadata, nil
}
c.mutex.RUnlock()
c.mutex.Lock()
defer c.mutex.Unlock()
// Double-check after acquiring write lock
if c.isCacheValid() {
return c.metadata, nil
}
metadata, err := discoverProviderMetadata(providerURL, httpClient, logger)
if err != nil {
if c.metadata != nil {
// On error, extend current cache by 5 minutes to prevent thundering herd
c.expiresAt = time.Now().Add(5 * time.Minute)
logger.Errorf("Failed to refresh metadata, using cached version for 5 more minutes: %v", err)
return c.metadata, nil
}
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
}
c.metadata = metadata
// Calculate expiration time based on usage patterns
usageCount := 0 // This should be replaced with actual usage tracking logic
if usageCount < 10 {
c.expiresAt = time.Now().Add(30 * time.Minute)
} else if usageCount < 50 {
c.expiresAt = time.Now().Add(1 * time.Hour)
} else {
c.expiresAt = time.Now().Add(2 * time.Hour)
}
// End of GetMetadata
return metadata, nil
}
func (c *MetadataCache) startAutoCleanup() {
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
}
func (c *MetadataCache) Close() {
close(c.stopCleanup)
}
+119
View File
@@ -0,0 +1,119 @@
package traefikoidc
import (
"fmt"
"net/http"
"testing"
"time"
)
func TestIsCacheValid(t *testing.T) {
// Setup with a dummy ProviderMetadata.
pm := &ProviderMetadata{}
mc := &MetadataCache{
metadata: pm,
expiresAt: time.Now().Add(1 * time.Hour),
}
if !mc.isCacheValid() {
t.Errorf("Expected cache to be valid")
}
mc.expiresAt = time.Now().Add(-1 * time.Hour)
if mc.isCacheValid() {
t.Errorf("Expected cache to be invalid")
}
}
func TestCleanup(t *testing.T) {
pm := &ProviderMetadata{}
mc := &MetadataCache{
metadata: pm,
expiresAt: time.Now().Add(-1 * time.Hour),
}
mc.Cleanup()
if mc.metadata != nil {
t.Errorf("Expected metadata to be nil after cleanup")
}
}
func TestGetMetadata_Cached(t *testing.T) {
dummyData := &ProviderMetadata{}
// Construct MetadataCache manually to avoid interference from auto cleanup.
mc := &MetadataCache{
metadata: dummyData,
expiresAt: time.Now().Add(1 * time.Hour),
stopCleanup: make(chan struct{}),
autoCleanupInterval: 5 * time.Minute,
}
// Use NewLogger to create a logger that writes errors only.
logger := NewLogger("error")
result, err := mc.GetMetadata("http://example.com", http.DefaultClient, logger)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if result != dummyData {
t.Errorf("Expected cached metadata to be returned")
}
}
func TestMetadataCacheAutoCleanup(t *testing.T) {
mc := &MetadataCache{
autoCleanupInterval: 50 * time.Millisecond,
stopCleanup: make(chan struct{}),
}
// Start auto cleanup.
go mc.startAutoCleanup()
mc.mutex.Lock()
mc.metadata = &ProviderMetadata{}
mc.expiresAt = time.Now().Add(-50 * time.Millisecond)
mc.mutex.Unlock()
// Wait enough time for the auto cleanup to run.
time.Sleep(200 * time.Millisecond)
mc.Close()
mc.mutex.RLock()
defer mc.mutex.RUnlock()
if mc.metadata != nil {
t.Errorf("Expected metadata to be cleared by auto cleanup")
}
}
type errorRoundTripper struct {
err error
}
func (e errorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, e.err
}
func TestGetMetadata_FetchError(t *testing.T) {
// Create an HTTP client that always returns an error.
errorClient := &http.Client{
Transport: errorRoundTripper{err: fmt.Errorf("fake fetch error")},
}
// Case 1: Cache is empty.
mc := &MetadataCache{
stopCleanup: make(chan struct{}),
}
logger := NewLogger("error")
metadata, err := mc.GetMetadata("http://example.com", errorClient, logger)
if err == nil {
t.Errorf("Expected error, got nil")
}
if metadata != nil {
t.Errorf("Expected nil metadata, got %v", metadata)
}
// Case 2: Cache has old metadata.
dummy := &ProviderMetadata{}
mc.metadata = dummy
mc.expiresAt = time.Now().Add(-1 * time.Minute)
logger2 := NewLogger("error")
metadata, err = mc.GetMetadata("http://example.com", errorClient, logger2)
if err != nil {
t.Errorf("Expected no error when cached metadata exists, got %v", err)
}
if metadata != dummy {
t.Errorf("Expected cached metadata to be returned")
}
}
+10
View File
@@ -0,0 +1,10 @@
version: 1
force:
existing: true
wording:
patch:
- patch-release
minor:
- minor-release
major:
- breaking
+165 -112
View File
@@ -3,30 +3,38 @@ package traefikoidc
import (
"bytes"
"compress/gzip"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
)
// generateSecureRandomString creates a cryptographically secure random string of specified length.
// It returns the generated string or an error if random generation fails.
func generateSecureRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
return hex.EncodeToString(bytes), nil
}
// Cookie names and configuration constants used for session management
const (
// mainCookieName is the name of the main session cookie that stores authentication state
// and basic user information like email and CSRF tokens
mainCookieName = "_raczylo_oidc"
// accessTokenCookie is the name of the cookie that stores the OIDC access token
// This may be split into multiple cookies if the token is large
accessTokenCookie = "_raczylo_oidc_access"
// refreshTokenCookie is the name of the cookie that stores the OIDC refresh token
// This may be split into multiple cookies if the token is large
refreshTokenCookie = "_raczylo_oidc_refresh"
// Using fixed prefixes for consistent cookie naming across restarts
mainCookieName = "_oidc_raczylo_m"
accessTokenCookie = "_oidc_raczylo_a"
refreshTokenCookie = "_oidc_raczylo_r"
)
const (
// maxCookieSize is the maximum size for each cookie chunk.
// This value is calculated to ensure the final cookie size stays within browser limits:
// 1. Browser cookie size limit is typically 4096 bytes
@@ -39,9 +47,16 @@ const (
// - Solving for x: x ≤ 3044
// 4. We use 2000 as a conservative limit to account for cookie metadata
maxCookieSize = 2000
// absoluteSessionTimeout defines the maximum lifetime of a session
// regardless of activity (24 hours)
absoluteSessionTimeout = 24 * time.Hour
// minEncryptionKeyLength defines the minimum length for the encryption key
minEncryptionKeyLength = 32
)
// compressToken compresses a token using gzip and base64 encodes it
// compressToken compresses a token using gzip and base64 encodes it.
func compressToken(token string) string {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
@@ -54,41 +69,41 @@ func compressToken(token string) string {
return base64.StdEncoding.EncodeToString(b.Bytes())
}
// decompressToken decompresses a base64 encoded gzipped token
// decompressToken decompresses a base64 encoded gzipped token.
func decompressToken(compressed string) string {
data, err := base64.StdEncoding.DecodeString(compressed)
if err != nil {
return compressed // return as-is if not base64
}
gz, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return compressed
}
defer gz.Close()
decompressed, err := io.ReadAll(gz)
if err != nil {
return compressed
}
return string(decompressed)
}
// SessionManager handles the management of multiple session cookies for OIDC authentication.
// It provides functionality for storing and retrieving authentication state, tokens,
// and other session-related data across multiple cookies to handle large tokens.
// and other session-related data across multiple cookies.
type SessionManager struct {
// store is the underlying session store for cookie management
// store is the underlying session store for cookie management.
store sessions.Store
// forceHTTPS enforces secure cookie attributes regardless of request scheme
// forceHTTPS enforces secure cookie attributes regardless of request scheme.
forceHTTPS bool
// logger provides structured logging capabilities
// logger provides structured logging capabilities.
logger *Logger
// sessionPool is a sync.Pool for reusing SessionData objects
// sessionPool is a sync.Pool for reusing SessionData objects.
sessionPool sync.Pool
}
@@ -97,29 +112,36 @@ type SessionManager struct {
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
// - logger: Logger instance for recording session-related events
// The manager handles session creation, storage, and cookie security settings.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
//
// Returns an error if the encryption key does not meet minimum length requirements.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*SessionManager, error) {
// Validate encryption key length.
if len(encryptionKey) < minEncryptionKeyLength {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
sm := &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
// Initialize session pool
// Initialize session pool.
sm.sessionPool.New = func() interface{} {
return &SessionData{
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
}
}
return sm
return sm, nil
}
// getSessionOptions returns secure session options configured for the current request.
// Parameters:
// - isSecure: Whether the current request is using HTTPS
// - isSecure: Whether the current request is using HTTPS.
//
// The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript)
// - Secure when using HTTPS or when forceHTTPS is enabled
@@ -130,7 +152,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
MaxAge: int(absoluteSessionTimeout.Seconds()),
Path: "/",
}
}
@@ -140,7 +162,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
// and combines them into a single SessionData structure for easy access.
// Returns an error if any session component cannot be loaded.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
// Get session from pool
// Get session from pool.
sessionData := sm.sessionPool.Get().(*SessionData)
sessionData.request = r
@@ -151,6 +173,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
return nil, fmt.Errorf("failed to get main session: %w", err)
}
// Check for absolute session timeout.
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
sessionData.Clear(r, nil)
return nil, fmt.Errorf("session expired")
}
}
sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)
if err != nil {
sm.sessionPool.Put(sessionData)
@@ -163,7 +193,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
// Clear and reuse chunk maps
// Clear and reuse chunk maps.
for k := range sessionData.accessTokenChunks {
delete(sessionData.accessTokenChunks, k)
}
@@ -171,7 +201,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
delete(sessionData.refreshTokenChunks, k)
}
// Retrieve chunked token sessions
// Retrieve chunked token sessions.
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
@@ -188,7 +218,6 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
// No more sessions
break
}
chunks[i] = session
@@ -200,27 +229,27 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
// and potentially large access and refresh tokens that may need to be
// split across multiple cookies due to browser size limitations.
type SessionData struct {
// manager is the SessionManager that created this SessionData
// manager is the SessionManager that created this SessionData.
manager *SessionManager
// request is the current HTTP request associated with this session
// request is the current HTTP request associated with this session.
request *http.Request
// mainSession stores authentication state and basic user info
// mainSession stores authentication state and basic user info.
mainSession *sessions.Session
// accessSession stores the primary access token cookie
// accessSession stores the primary access token cookie.
accessSession *sessions.Session
// refreshSession stores the primary refresh token cookie
// refreshSession stores the primary refresh token cookie.
refreshSession *sessions.Session
// accessTokenChunks stores additional chunks of the access token
// when it exceeds the maximum cookie size
// when it exceeds the maximum cookie size.
accessTokenChunks map[int]*sessions.Session
// refreshTokenChunks stores additional chunks of the refresh token
// when it exceeds the maximum cookie size
// when it exceeds the maximum cookie size.
refreshTokenChunks map[int]*sessions.Session
}
@@ -231,28 +260,28 @@ type SessionData struct {
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
// Set options for all sessions.
options := sd.manager.getSessionOptions(isSecure)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
// Save main session
// Save main session.
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
// Save access token session
// Save access token session.
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
// Save refresh token session
// Save refresh token session.
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
// Save access token chunks
// Save access token chunks.
for _, session := range sd.accessTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
@@ -260,7 +289,7 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
}
}
// Save refresh token chunks
// Save refresh token chunks.
for _, session := range sd.refreshTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
@@ -272,10 +301,8 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
}
// Clear removes all session data by expiring all cookies and clearing their values.
// This is typically used during logout to ensure all session data is properly cleaned up.
// It handles both main session data and any token chunks that may exist.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
// Clear and expire all sessions.
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
@@ -290,21 +317,25 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
delete(sd.refreshSession.Values, k)
}
// Clear chunk sessions
// Clear chunk sessions.
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
err := sd.Save(r, w)
// Return session to pool
var err error
if w != nil {
err = sd.Save(r, w)
}
// Clear transient per-request fields.
sd.request = nil
// Return session to pool.
sd.manager.sessionPool.Put(sd)
return err
}
// clearTokenChunks removes all session chunks for a given token type.
// It expires the cookies and removes all stored values to ensure
// no token data remains after logout or token invalidation.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
@@ -315,23 +346,36 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
}
// GetAuthenticated returns whether the current session is authenticated.
// Returns true if the user has successfully completed OIDC authentication,
// false otherwise or if the authentication status cannot be determined.
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
return auth
if !auth {
return false
}
// Check session expiration.
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
if !ok {
return false
}
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
}
// SetAuthenticated updates the session's authentication status.
// This should be called after successful OIDC authentication or during logout.
func (sd *SessionData) SetAuthenticated(value bool) {
// SetAuthenticated updates the session's authentication status and rotates session ID.
// Returns an error if generating a new session ID fails.
func (sd *SessionData) SetAuthenticated(value bool) error {
if value {
id, err := generateSecureRandomString(32)
if err != nil {
return fmt.Errorf("failed to generate secure session id: %w", err)
}
sd.mainSession.ID = id
sd.mainSession.Values["created_at"] = time.Now().Unix()
}
sd.mainSession.Values["authenticated"] = value
return nil
}
// GetAccessToken retrieves the complete access token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
@@ -342,7 +386,7 @@ func (sd *SessionData) GetAccessToken() string {
return token
}
// Reassemble token from chunks
// Reassemble token from chunks.
if len(sd.accessTokenChunks) == 0 {
return ""
}
@@ -366,23 +410,23 @@ func (sd *SessionData) GetAccessToken() string {
}
// SetAccessToken stores the access token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetAccessToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
// Expire any existing chunk cookies first.
if sd.request != nil {
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called.
}
// Clear and prepare chunks map for new token.
sd.accessTokenChunks = make(map[int]*sessions.Session)
// Compress token
// Compress token.
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.accessSession.Values["token"] = compressed
sd.accessSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
// Split compressed token into chunks.
sd.accessSession.Values["token"] = ""
sd.accessSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
@@ -396,9 +440,6 @@ func (sd *SessionData) SetAccessToken(token string) {
}
// GetRefreshToken retrieves the complete refresh token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
@@ -409,7 +450,7 @@ func (sd *SessionData) GetRefreshToken() string {
return token
}
// Reassemble token from chunks
// Reassemble token from chunks.
if len(sd.refreshTokenChunks) == 0 {
return ""
}
@@ -433,23 +474,23 @@ func (sd *SessionData) GetRefreshToken() string {
}
// SetRefreshToken stores the refresh token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetRefreshToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
// Expire any existing chunk cookies first.
if sd.request != nil {
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called.
}
// Clear and prepare chunks map for new token.
sd.refreshTokenChunks = make(map[int]*sessions.Session)
// Compress token
// Compress token.
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.refreshSession.Values["token"] = compressed
sd.refreshSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
// Split compressed token into chunks.
sd.refreshSession.Values["token"] = ""
sd.refreshSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
@@ -462,12 +503,43 @@ func (sd *SessionData) SetRefreshToken(token string) {
}
}
// expireAccessTokenChunks expires any existing access token chunk cookies.
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired access token cookie: %v", err)
}
}
}
}
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, err := sd.manager.store.Get(sd.request, sessionName)
if err != nil || session.IsNew {
break
}
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if w != nil {
if err := session.Save(sd.request, w); err != nil {
sd.manager.logger.Errorf("failed to save expired refresh token cookie: %v", err)
}
}
}
}
// splitIntoChunks splits a string into chunks of specified size.
// This is used internally to handle large tokens that exceed cookie size limits.
// Parameters:
// - s: The string to split
// - chunkSize: Maximum size of each chunk
// Returns an array of string chunks, each no larger than chunkSize.
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
@@ -483,64 +555,45 @@ func splitIntoChunks(s string, chunkSize int) []string {
}
// GetCSRF retrieves the CSRF token from the session.
// This token is used to prevent cross-site request forgery attacks
// by ensuring requests originate from the authenticated user.
// Returns an empty string if no CSRF token is found.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores a new CSRF token in the session.
// This should be called when initiating authentication to generate
// a new token for the authentication flow.
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce retrieves the nonce value from the session.
// The nonce is used to prevent replay attacks in the OIDC flow
// by ensuring the token received matches the authentication request.
// Returns an empty string if no nonce is found.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores a new nonce value in the session.
// This should be called when initiating authentication to generate
// a new nonce for the OIDC authentication flow.
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail retrieves the authenticated user's email address from the session.
// The email is typically extracted from the OIDC ID token claims.
// Returns an empty string if no email is found.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the user's email address in the session.
// This should be called after successful authentication when
// processing the OIDC ID token claims.
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath retrieves the original request path that triggered
// the authentication flow. This is used to redirect the user back
// to their intended destination after successful authentication.
// Returns an empty string if no path was stored.
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request path that triggered
// the authentication flow. This should be called before redirecting
// to the OIDC provider to remember where to send the user afterward.
// SetIncomingPath stores the original request path that triggered the authentication flow.
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
+236 -122
View File
@@ -5,14 +5,8 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
)
func init() {
// Initialize random seed
rand.Seed(time.Now().UnixNano())
}
// generateRandomString creates a random string of specified length
func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
@@ -56,9 +50,9 @@ func TestTokenCompression(t *testing.T) {
if len(tt.token) > 100 {
compressionRatio := float64(len(compressed)) / float64(len(tt.token))
t.Logf("Compression ratio for %s: %.2f", tt.name, compressionRatio)
if compressionRatio > 1.1 { // Allow up to 10% size increase
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
len(tt.token), len(compressed), compressionRatio)
}
}
@@ -77,6 +71,120 @@ func TestTokenCompression(t *testing.T) {
}
// TestSessionManager tests the SessionManager functionality
func TestCookiePrefix(t *testing.T) {
// Create a session and verify cookie names
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set some data to ensure cookies are created
session.SetAuthenticated(true)
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Set new tokens
session.SetAccessToken("test_token")
session.SetRefreshToken("test_refresh_token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Check cookie prefixes
cookies := rr.Result().Cookies()
for _, cookie := range cookies {
if !strings.HasPrefix(cookie.Name, "_oidc_raczylo_") {
t.Errorf("Cookie %s does not have expected prefix '_oidc_raczylo_'", cookie.Name)
}
}
}
func TestTokenRefreshCleanup(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set a large token that will be split into chunks
largeToken := strings.Repeat("x", 5000)
session.SetAccessToken(largeToken)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get initial cookies
initialCookies := rr.Result().Cookies()
// Create a new request with the initial cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range initialCookies {
newReq.AddCookie(cookie)
}
newRr := httptest.NewRecorder()
// Get session with cookies and set a new token
newSession, err := sm.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Create a response recorder for expired cookies
expiredRr := httptest.NewRecorder()
// Expire old chunk cookies
newSession.expireAccessTokenChunks(expiredRr)
// Set a smaller token that won't need chunks
newSession.SetAccessToken("small_token")
// Save session with new token
if err := newSession.Save(newReq, newRr); err != nil {
t.Fatalf("Failed to save new session: %v", err)
}
// Check cookies in response where old cookies are expired
intermediateResponse := expiredRr.Result()
intermediateCount := 0
chunkCount := 0
expiredCount := 0
for _, cookie := range intermediateResponse.Cookies() {
if strings.Contains(cookie.Name, "_oidc_raczylo_a_") && strings.Count(cookie.Name, "_") > 3 {
chunkCount++
if cookie.MaxAge < 0 {
expiredCount++
t.Logf("Found expired chunk cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
} else if cookie.MaxAge >= 0 {
intermediateCount++
t.Logf("Found active cookie: %s (MaxAge=%d)", cookie.Name, cookie.MaxAge)
}
}
// All chunk cookies should be expired
if chunkCount > 0 && chunkCount != expiredCount {
t.Errorf("Not all chunk cookies are expired: %d chunks, %d expired", chunkCount, expiredCount)
}
// Should have fewer active cookies after setting smaller token
if intermediateCount >= len(initialCookies) {
t.Errorf("Expected fewer active cookies after token refresh, got %d, want less than %d", intermediateCount, len(initialCookies))
}
}
func TestSessionManager(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
@@ -84,154 +192,160 @@ func TestSessionManager(t *testing.T) {
tests := []struct {
name string
authenticated bool
email string
accessToken string
refreshToken string
email string
accessToken string
refreshToken string
expectedCookieCount int
wantCompressed bool // Whether tokens should be compressed
wantCompressed bool // Whether tokens should be compressed
}{
{
name: "Short tokens",
authenticated: true,
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: true,
wantCompressed: true,
},
{
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
wantCompressed: true,
wantCompressed: true,
},
{
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
wantCompressed: true,
wantCompressed: true,
},
{
name: "Unauthenticated session",
authenticated: false,
email: "",
accessToken: "",
refreshToken: "",
email: "",
accessToken: "",
refreshToken: "",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: false,
wantCompressed: false,
},
{
name: "Random content tokens",
authenticated: true,
email: "test@example.com",
accessToken: generateRandomString(5000),
refreshToken: generateRandomString(5000),
name: "Random content tokens",
authenticated: true,
email: "test@example.com",
accessToken: generateRandomString(5000),
refreshToken: generateRandomString(5000),
expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)),
wantCompressed: true,
wantCompressed: true,
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Expire any existing cookies
session.expireAccessTokenChunks(rr)
session.expireRefreshTokenChunks(rr)
// Verify cookies are set and compression is used when appropriate
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Set new tokens
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Verify compression is working by checking token sizes
for _, cookie := range cookies {
if strings.Contains(cookie.Name, accessTokenCookie) {
// Get original and stored sizes
originalSize := len(tc.accessToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
// For large tokens, verify some compression occurred
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
originalSize := len(tc.refreshToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 {
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify cookies are set and compression is used when appropriate
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Verify compression is working by checking token sizes
for _, cookie := range cookies {
if strings.Contains(cookie.Name, accessTokenCookie) {
// Get original and stored sizes
originalSize := len(tc.accessToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
// For large tokens, verify some compression occurred
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
originalSize := len(tc.refreshToken)
storedSize := len(cookie.Value)
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
if originalSize > 100 && tc.wantCompressed {
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
if compressionRatio > 0.9 {
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
}
}
// Verify session values
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Verify session pooling by checking if the session is reused
session2, _ := ts.sessionManager.GetSession(newReq)
if session2 == newSession {
t.Error("Session not properly pooled")
}
})
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Verify session values
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
}
// Verify session pooling by checking if the session is reused
session2, _ := ts.sessionManager.GetSession(newReq)
if session2 == newSession {
t.Error("Session not properly pooled")
}
})
}
}
@@ -242,12 +356,12 @@ func calculateExpectedCookieCount(accessToken, refreshToken string) int {
calculateChunks := func(token string) int {
// Compress token (matching the actual implementation)
compressed := compressToken(token)
// If compressed token fits in one cookie, no additional chunks needed
if len(compressed) <= maxCookieSize {
return 0
}
// Calculate chunks needed for compressed token
return len(splitIntoChunks(compressed, maxCookieSize))
}
+75 -32
View File
@@ -10,10 +10,6 @@ import (
"strings"
)
const (
cookieName = "_raczylo_oidc"
)
// Config holds the configuration for the OIDC middleware.
// It provides all necessary settings to configure OpenID Connect authentication
// with various providers like Auth0, Logto, or any standard OIDC provider.
@@ -85,30 +81,34 @@ type Config struct {
HTTPClient *http.Client
}
// CreateConfig creates a new Config with sensible default values.
const (
// DefaultRateLimit defines the default rate limit for requests per second
DefaultRateLimit = 100
// MinRateLimit defines the minimum allowed rate limit to prevent DOS
MinRateLimit = 10
// DefaultLogLevel defines the default logging level
DefaultLogLevel = "info"
// MinSessionEncryptionKeyLength defines the minimum length for session encryption key
MinSessionEncryptionKeyLength = 32
)
// CreateConfig creates a new Config with secure default values.
// Default values are set for optional fields:
// - Scopes: ["openid", "profile", "email"]
// - LogLevel: "info"
// - LogoutURL: CallbackURL + "/logout"
// - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security)
func CreateConfig() *Config {
c := &Config{}
if c.Scopes == nil {
c.Scopes = []string{"openid", "profile", "email"}
}
if c.LogLevel == "" {
c.LogLevel = "info"
}
if c.LogoutURL == "" {
c.LogoutURL = c.CallbackURL + "/logout"
}
if c.RateLimit == 0 {
c.RateLimit = 100
c := &Config{
Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
}
return c
@@ -118,43 +118,85 @@ func CreateConfig() *Config {
// It ensures all required fields are set and have valid values.
// Returns an error if any validation check fails.
func (c *Config) Validate() error {
// Validate provider URL
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
}
if !isValidURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid URL")
if !isValidSecureURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid HTTPS URL")
}
// Validate callback URL
if c.CallbackURL == "" {
return fmt.Errorf("callbackURL is required")
}
if !strings.HasPrefix(c.CallbackURL, "/") {
return fmt.Errorf("callbackURL must start with /")
}
// Validate client credentials
if c.ClientID == "" {
return fmt.Errorf("clientID is required")
}
if c.ClientSecret == "" {
return fmt.Errorf("clientSecret is required")
}
// Validate session encryption key
if c.SessionEncryptionKey == "" {
return fmt.Errorf("sessionEncryptionKey is required")
}
if len(c.SessionEncryptionKey) < 32 {
return fmt.Errorf("sessionEncryptionKey must be at least 32 characters long")
}
if c.RateLimit < 0 {
return fmt.Errorf("rateLimit must be non-negative")
if len(c.SessionEncryptionKey) < MinSessionEncryptionKeyLength {
return fmt.Errorf("sessionEncryptionKey must be at least %d characters long", MinSessionEncryptionKeyLength)
}
// Validate log level
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
return fmt.Errorf("logLevel must be one of: debug, info, error")
}
// Validate excluded URLs
for _, url := range c.ExcludedURLs {
if !strings.HasPrefix(url, "/") {
return fmt.Errorf("excluded URL must start with /: %s", url)
}
if strings.Contains(url, "..") {
return fmt.Errorf("excluded URL must not contain path traversal: %s", url)
}
if strings.Contains(url, "*") {
return fmt.Errorf("excluded URL must not contain wildcards: %s", url)
}
}
// Validate revocation URL if set
if c.RevocationURL != "" && !isValidSecureURL(c.RevocationURL) {
return fmt.Errorf("revocationURL must be a valid HTTPS URL")
}
// Validate end session URL if set
if c.OIDCEndSessionURL != "" && !isValidSecureURL(c.OIDCEndSessionURL) {
return fmt.Errorf("oidcEndSessionURL must be a valid HTTPS URL")
}
// Validate post-logout redirect URI if set
if c.PostLogoutRedirectURI != "" && c.PostLogoutRedirectURI != "/" {
if !isValidSecureURL(c.PostLogoutRedirectURI) && !strings.HasPrefix(c.PostLogoutRedirectURI, "/") {
return fmt.Errorf("postLogoutRedirectURI must be either a valid HTTPS URL or start with /")
}
}
// Validate rate limit
if c.RateLimit < MinRateLimit {
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
}
return nil
}
// isValidURL checks if the provided string is a valid URL
func isValidURL(s string) bool {
// isValidSecureURL checks if the provided string is a valid HTTPS URL
func isValidSecureURL(s string) bool {
u, err := url.Parse(s)
return err == nil && u.Scheme != "" && u.Host != ""
return err == nil && u.Scheme == "https" && u.Host != ""
}
// isValidLogLevel checks if the provided log level is valid
@@ -179,6 +221,7 @@ type Logger struct {
// - "debug": Outputs all messages (debug, info, error)
// - "info": Outputs info and error messages
// - "error": Outputs only error messages
//
// Error messages are always written to stderr, while info and debug
// messages are written to stdout when enabled.
func NewLogger(logLevel string) *Logger {
@@ -187,7 +230,7 @@ func NewLogger(logLevel string) *Logger {
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logError.SetOutput(os.Stderr)
if logLevel == "debug" || logLevel == "info" {
logInfo.SetOutput(os.Stdout)
}
+49 -14
View File
@@ -23,13 +23,18 @@ func TestCreateConfig(t *testing.T) {
}
// Check default log level
if config.LogLevel != "info" {
t.Errorf("Expected default log level 'info', got '%s'", config.LogLevel)
if config.LogLevel != DefaultLogLevel {
t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel)
}
// Check default rate limit
if config.RateLimit != 100 {
t.Errorf("Expected default rate limit 100, got %d", config.RateLimit)
if config.RateLimit != DefaultRateLimit {
t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit)
}
// Check ForceHTTPS default
if !config.ForceHTTPS {
t.Error("Expected ForceHTTPS to be true by default")
}
})
@@ -38,6 +43,7 @@ func TestCreateConfig(t *testing.T) {
config.Scopes = []string{"custom_scope"}
config.LogLevel = "debug"
config.RateLimit = 50
config.ForceHTTPS = false
// Verify custom values are not overwritten
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
@@ -49,6 +55,9 @@ func TestCreateConfig(t *testing.T) {
if config.RateLimit != 50 {
t.Error("Custom rate limit was overwritten")
}
if config.ForceHTTPS {
t.Error("Custom ForceHTTPS value was overwritten")
}
})
}
@@ -98,15 +107,15 @@ func TestConfigValidate(t *testing.T) {
expectedError: "sessionEncryptionKey is required",
},
{
name: "Invalid ProviderURL",
name: "Non-HTTPS ProviderURL",
config: &Config{
ProviderURL: "not-a-url",
ProviderURL: "http://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "providerURL must be a valid URL",
expectedError: "providerURL must be a valid HTTPS URL",
},
{
name: "Invalid CallbackURL",
@@ -131,16 +140,16 @@ func TestConfigValidate(t *testing.T) {
expectedError: "sessionEncryptionKey must be at least 32 characters long",
},
{
name: "Negative RateLimit",
name: "Low RateLimit",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RateLimit: -1,
RateLimit: 5,
},
expectedError: "rateLimit must be non-negative",
expectedError: "rateLimit must be at least 10",
},
{
name: "Invalid LogLevel",
@@ -154,6 +163,30 @@ func TestConfigValidate(t *testing.T) {
},
expectedError: "logLevel must be one of: debug, info, error",
},
{
name: "Non-HTTPS RevocationURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RevocationURL: "http://revoke.com",
},
expectedError: "revocationURL must be a valid HTTPS URL",
},
{
name: "Non-HTTPS OIDCEndSessionURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
OIDCEndSessionURL: "http://endsession.com",
},
expectedError: "oidcEndSessionURL must be a valid HTTPS URL",
},
{
name: "Valid Config",
config: &Config{
@@ -164,6 +197,8 @@ func TestConfigValidate(t *testing.T) {
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "debug",
RateLimit: 100,
RevocationURL: "https://revoke.com",
OIDCEndSessionURL: "https://endsession.com",
},
expectedError: "",
},
@@ -192,9 +227,9 @@ func TestLogger(t *testing.T) {
var debugBuf, infoBuf, errorBuf bytes.Buffer
tests := []struct {
name string
logLevel string
testFunc func(*Logger)
name string
logLevel string
testFunc func(*Logger)
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
}{
{
@@ -289,7 +324,7 @@ func TestLogger(t *testing.T) {
// Create logger with test buffers
logger := NewLogger(tc.logLevel)
logger.logError.SetOutput(&errorBuf)
if tc.logLevel == "debug" || tc.logLevel == "info" {
logger.logInfo.SetOutput(&infoBuf)
}
+1 -1
View File
@@ -1,4 +1,4 @@
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
+1 -5
View File
@@ -1,7 +1,4 @@
# Gorilla Sessions
> [!IMPORTANT]
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
# sessions
![testing](https://github.com/gorilla/sessions/actions/workflows/test.yml/badge.svg)
[![codecov](https://codecov.io/github/gorilla/sessions/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/sessions)
@@ -77,7 +74,6 @@ Other implementations of the `sessions.Store` interface:
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
+9 -12
View File
@@ -1,6 +1,5 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.11
// +build !go1.11
package sessions
@@ -9,15 +8,13 @@ import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Partitioned: options.Partitioned,
SameSite: options.SameSite,
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
}
}
+21
View File
@@ -0,0 +1,21 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
SameSite: options.SameSite,
}
}
+5 -10
View File
@@ -1,11 +1,8 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.11
// +build !go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
@@ -16,9 +13,7 @@ type Options struct {
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
Partitioned bool
SameSite http.SameSite
MaxAge int
Secure bool
HttpOnly bool
}
+23
View File
@@ -0,0 +1,23 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
type Options struct {
Path string
Domain string
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
// Defaults to http.SameSiteDefaultMode
SameSite http.SameSite
}
+3 -3
View File
@@ -4,9 +4,9 @@ github.com/google/uuid
# github.com/gorilla/securecookie v1.1.2
## explicit; go 1.20
github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.4.0
## explicit; go 1.23
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
github.com/gorilla/sessions
# golang.org/x/time v0.9.0
# golang.org/x/time v0.7.0
## explicit; go 1.18
golang.org/x/time/rate