Compare commits

...

8 Commits

Author SHA1 Message Date
lukaszraczylo f3bec6cf68 fixup! Add support for the anchor redirects 2025-03-18 01:59:31 +00:00
lukaszraczylo a3b8cbf9f3 Add support for the anchor redirects 2025-03-18 01:57:59 +00:00
lukaszraczylo 4322407129 Add support for PKCE (#31)
* Add PKCE support.
* Add option to toggle PKCE checks feature.
* GoFMT
2025-03-18 01:09:14 +00:00
lukaszraczylo 4ce2815123 Update the documentation. 2025-02-25 14:02:08 +00:00
lukaszraczylo 7d204113ea Cleanup the codebase, DRY and abstract functions, increase the test coverage. 2025-02-25 12:53:52 +00:00
lukaszraczylo c721913cbe Increase tests coverage. 2025-02-24 12:25:32 +00:00
lukaszraczylo 0f8b7f7ab1 Abstract the cleanup logic and add helper for cache valid. 2025-02-24 12:02:12 +00:00
lukaszraczylo 2743b0e024 Ensure cleanups actually happen. 2025-02-24 00:19:44 +00:00
14 changed files with 1516 additions and 565 deletions
+232 -18
View File
@@ -4,28 +4,242 @@ type: middleware
import: github.com/lukaszraczylo/traefikoidc
summary: |
Middleware adding OIDC authentication to traefik routes. Does what it says on the tin.
Middleware has been tested with Auth0 and Logto. It should work with any OIDC provider.
Middleware adding OpenID Connect (OIDC) authentication to Traefik routes.
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy.
It provides a complete OIDC authentication solution with features like domain restrictions,
role-based access control, token caching, and more.
The middleware has been tested with Auth0, Logto, Google, and other standard OIDC providers.
It supports various authentication scenarios including:
- Basic authentication with customizable callback and logout URLs
- Email domain restrictions to limit access to specific organizations
- Role and group-based access control
- Public URLs that bypass authentication
- Rate limiting to prevent brute force attacks
- Custom post-logout redirect behavior
- Secure session management with encrypted cookies
- Automatic token validation and refresh
testData:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
scopes: # If not provided, default scopes will be used (openid, email, profile)
# Required parameters
providerURL: https://accounts.google.com # Base URL of the OIDC provider
clientID: 1234567890.apps.googleusercontent.com # OAuth 2.0 client identifier
clientSecret: secret # OAuth 2.0 client secret
callbackURL: /oauth2/callback # Path where the OIDC provider will redirect after authentication
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long # Key used to encrypt session data (must be at least 32 bytes)
# Optional parameters with defaults
logoutURL: /oauth2/logout # Path for handling logout requests (if not provided, it will be set to callbackURL + "/logout")
postLogoutRedirectURI: /oidc/different-logout # URL to redirect to after logout (default: "/")
scopes: # OAuth 2.0 scopes to request (default: ["openid", "email", "profile"])
- openid
- email
- profile
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
- raczylo.com
allowedRolesAndGroups:
- roles # Include this to get role information from the provider
allowedUserDomains: # Restricts access to specific email domains (if not provided, relies on OIDC provider)
- company.com
- subsidiary.com
allowedRolesAndGroups: # Restricts access to users with specific roles or groups (if not provided, no role/group restrictions)
- guest-endpoints
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
forceHTTPS: false
logLevel: debug # debug, info, warn, error
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- admin
- developer
forceHTTPS: false # Forces the use of HTTPS for all URLs (default: true for security)
logLevel: debug # Sets logging verbosity: debug, info, error (default: info)
rateLimit: 100 # Maximum number of requests per second (default: 100, minimum: 10)
excludedURLs: # Lists paths that bypass authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
- /public
- /health
- /metrics
# Advanced parameters (usually discovered automatically from provider metadata)
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
# Configuration documentation
configuration:
providerURL:
type: string
description: |
The base URL of the OIDC provider. This is the issuer URL that will be used to discover
OIDC endpoints like authorization, token, and JWKS URIs.
Examples:
- https://accounts.google.com
- https://login.microsoftonline.com/tenant-id/v2.0
- https://your-auth0-domain.auth0.com
- https://your-logto-instance.com/oidc
required: true
clientID:
type: string
description: |
The OAuth 2.0 client identifier obtained from your OIDC provider.
This is the public identifier for your application.
required: true
clientSecret:
type: string
description: |
The OAuth 2.0 client secret obtained from your OIDC provider.
This should be kept confidential and not exposed in client-side code.
For Kubernetes deployments, you can use the secret reference format:
urn:k8s:secret:namespace:secret-name:key
required: true
callbackURL:
type: string
description: |
The path where the OIDC provider will redirect after authentication.
This must match one of the redirect URIs configured in your OIDC provider.
The full redirect URI will be constructed as:
[scheme]://[host][callbackURL]
Example: /oauth2/callback
required: true
sessionEncryptionKey:
type: string
description: |
Key used to encrypt session data stored in cookies.
Must be at least 32 bytes long for security.
Example: potato-secret-is-at-least-32-bytes-long
required: true
logoutURL:
type: string
description: |
The path for handling logout requests.
If not provided, it will be set to callbackURL + "/logout".
Example: /oauth2/logout
required: false
postLogoutRedirectURI:
type: string
description: |
The URL to redirect to after logout.
Default: "/"
Example: /logged-out-page
required: false
scopes:
type: array
description: |
The OAuth 2.0 scopes to request from the OIDC provider.
Default: ["openid", "profile", "email"]
Include "roles" or similar scope if you need role/group information.
required: false
items:
type: string
logLevel:
type: string
description: |
Sets the logging verbosity.
Valid values: "debug", "info", "error"
Default: "info"
required: false
enum:
- debug
- info
- error
forceHTTPS:
type: boolean
description: |
Forces the use of HTTPS for all URLs.
This is recommended for security in production environments.
Default: true
required: false
rateLimit:
type: integer
description: |
Sets the maximum number of requests per second.
This helps prevent brute force attacks.
Default: 100
Minimum: 10
required: false
excludedURLs:
type: array
description: |
Lists paths that bypass authentication.
These paths will be accessible without OIDC authentication.
The middleware uses prefix matching, so "/public" will match
"/public", "/public/page", "/public-data", etc.
Examples: ["/health", "/metrics", "/public"]
required: false
items:
type: string
allowedUserDomains:
type: array
description: |
Restricts access to users with email addresses from specific domains.
If not provided, the middleware relies entirely on the OIDC provider
for authentication decisions.
Examples: ["company.com", "subsidiary.com"]
required: false
items:
type: string
allowedRolesAndGroups:
type: array
description: |
Restricts access to users with specific roles or groups.
If not provided, no role/group restrictions are applied.
The middleware checks both the "roles" and "groups" claims in the ID token.
Examples: ["admin", "developer"]
required: false
items:
type: string
revocationURL:
type: string
description: |
The endpoint for revoking tokens.
If not provided, it will be discovered from provider metadata.
Example: https://accounts.google.com/revoke
required: false
oidcEndSessionURL:
type: string
description: |
The provider's end session endpoint.
If not provided, it will be discovered from provider metadata.
Example: https://accounts.google.com/logout
required: false
enablePKCE:
type: boolean
description: |
Enables PKCE (Proof Key for Code Exchange) for the OAuth 2.0 authorization code flow.
PKCE adds an extra layer of security to protect against authorization code interception attacks.
Not all OIDC providers support PKCE, so this should only be enabled if your provider supports it.
If enabled, the middleware will generate and use a code verifier/challenge pair during authentication.
Default: false
required: false
+348 -126
View File
@@ -1,153 +1,308 @@
## Traefik OIDC middleware
# Traefik OIDC Middleware
This middleware is supposed to replace the need for the forward-auth and oauth2-proxy when using traefik as a reverse proxy to support the OIDC authentication.
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy to support OpenID Connect (OIDC) authentication.
Middleware has been tested with Auth0 and Logto.
## Overview
### Traefik version compatibility
The Traefik OIDC middleware provides a complete OIDC authentication solution with features like:
- Token validation and verification
- Session management
- Domain restrictions
- Role-based access control
- Token caching and blacklisting
- Rate limiting
- Excluded paths (public URLs)
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
The middleware has been tested with Auth0 and Logto, but should work with any standard OIDC provider.
### Configuration options
## Traefik Version Compatibility
Middleware currently supports following scenarios:
This middleware follows closely the current Traefik helm chart versions. If the plugin fails to load, it's time to update to the latest version of the Traefik helm chart.
* Setting custom callback and logout URLs via `callbackURL` and `logoutURL`
* Allowing for access only from the listed domains if `allowedUserDomains` is set, otherwise it relies entirely on the OIDC provider
* Using excluded URLs which do **NOT** require the OIDC authentication
* Rate limiting requests to prevent the bruteforce attacks
## Installation
#### How to configure...
### As a Traefik Plugin
* `sessionEncryptionKey` should be at least 32 bytes long.
##### Keeping secrets secret
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
1. Enable the plugin in your Traefik static configuration:
```yaml
# traefik.yml
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.2.1 # Use the latest version
```
2. Configure the middleware in your dynamic configuration (see examples below).
### Local Development with Docker Compose
For local development or testing, you can use the provided Docker Compose setup:
```bash
cd docker
docker-compose up -d
```
This will start Traefik with the OIDC middleware and two test services.
## Configuration Options
The middleware supports the following configuration options:
### Required Parameters
| Parameter | Description | Example |
|-----------|-------------|---------|
| `providerURL` | The base URL of the OIDC provider | `https://accounts.google.com` |
| `clientID` | The OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
| `clientSecret` | The OAuth 2.0 client secret | `your-client-secret` |
| `sessionEncryptionKey` | Key used to encrypt session data (must be at least 32 bytes long) | `potato-secret-is-at-least-32-bytes-long` |
| `callbackURL` | The path where the OIDC provider will redirect after authentication | `/oauth2/callback` |
### Optional Parameters
| Parameter | Description | Default | Example |
|-----------|-------------|---------|---------|
| `logoutURL` | The path for handling logout requests | `callbackURL + "/logout"` | `/oauth2/logout` |
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
| | `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| | `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| | `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
| | `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
## Usage Examples
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-basic
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
```
### With Excluded URLs (Public Access Paths)
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-open-urls
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
excludedURLs:
- /login # covers /login, /login/me, /login/reminder etc.
- /public-data
- /health
- /metrics
```
### With Email Domain Restrictions
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-domain-restricted
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
allowedUserDomains:
- company.com
- subsidiary.com
```
### With Role-Based Access Control
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-rbac
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
- roles # Include this to get role information from the provider
allowedRolesAndGroups:
- admin
- developer
```
### With Custom Logging and Rate Limiting
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-custom-settings
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
logLevel: debug # Options: debug, info, error (default: info)
rateLimit: 500 # Requests per second (default: 100)
forceHTTPS: false # Default is true for security
scopes:
- openid
- email
- profile
```
### With Custom Post-Logout Redirect
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-custom-logout
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /logged-out-page # Where to redirect after logout
scopes:
- openid
- email
- profile
```
### With PKCE Enabled
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-pkce
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
enablePKCE: true # Enables PKCE for added security
scopes:
- openid
- email
- profile
```
### Keeping Secrets Secret in Kubernetes
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-secrets
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes:
- openid
- email
- profile
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
```
##### Excluded URLs with open access
Don't forget to create the secret:
```
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-open-urls
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: xxx
clientID: yyy
clientSecret: zzz
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
scopes:
- openid
- email
- profile
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
```bash
kubectl create secret generic traefik-middleware-oidc \
--from-literal=ISSUER=https://accounts.google.com \
--from-literal=CLIENT_ID=1234567890.apps.googleusercontent.com \
--from-literal=SECRET=your-client-secret \
-n traefik
```
## Complete Docker Compose Example
##### Allowed email domains
Assuming that your OIDC provider allows anyone to log in, you may want to limit the access to people using emains in specific domain.
```
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-only-my-users
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: xxx
clientID: yyy
clientSecret: zzz
sessionEncryptionKey: vvv
callbackURL: /new-oidc/callback
logoutURL: /new-oidc/logout
scopes:
- openid
- email
- profile
allowedUserDomains:
- raczylo.com
```
##### Allowed groups and roles
In case of multiple roles / groups and access separation for various endpoints you will need to create multiple traefik middlewares.
Following example allows access for users who have additional role `guest-endpoints` assigned.
```
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-guest-endpoints
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: xxx
clientID: yyy
clientSecret: zzz
sessionEncryptionKey: vvv
callbackURL: /my-oidc/callback
logoutURL: /my-oidc/logout
scopes:
- openid
- email
- profile
- roles # This line queries the OIDC provider for roles
forceHTTPS: true
allowedRolesAndGroups:
- guest-endpoints # This line specifies the roles or groups allowed to access content
allowedUserDomains:
- raczylo.com
```
#### Docker compose example
`docker-compose.yaml`
Here's a complete example of using the middleware with Docker Compose:
```yaml
version: "3.7"
services:
traefik:
image: traefik:v3.0.1
image: traefik:v3.2.1
command:
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
- "--experimental.plugins.traefikoidc.version=v0.2.1"
@@ -158,7 +313,6 @@ services:
labels:
- "traefik.http.routers.dash.rule=Host(`dash.localhost`)"
- "traefik.http.routers.dash.service=api@internal"
ports:
- "80:80"
@@ -181,8 +335,7 @@ services:
- traefik.http.routers.whoami.middlewares=my-plugin@file
```
`traefik-config/traefik.yaml`
`traefik-config/traefik.yml`:
```yaml
log:
level: INFO
@@ -211,7 +364,7 @@ providers:
filename: /etc/traefik/dynamic-configuration.yml
```
`traefik-config/dynamic-configuration.yaml`
`traefik-config/dynamic-configuration.yml`:
```yaml
http:
middlewares:
@@ -220,20 +373,89 @@ http:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: secret
clientSecret: your-client-secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
scopes: # If not provided, default scopes will be used (openid, email, profile)
postLogoutRedirectURI: /logged-out-page
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
scopes:
- openid
- email
- profile
allowedUserDomains: # If not provided - will rely entirely on the OIDC yes/no
- raczylo.com
sessionEncryptionKey: potato-secret
allowedUserDomains:
- company.com
allowedRolesAndGroups:
- admin
- developer
forceHTTPS: false
logLevel: debug # debug, info, warn, error
rateLimit: 100 # Simple rate limiter to prevent brute force attacks
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
logLevel: debug
rateLimit: 100
excludedURLs:
- /login
- /public
- /health
- /metrics
```
## Advanced Configuration
### Session Management
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
### PKCE Support
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
PKCE is recommended when:
- Your OIDC provider supports it (most modern providers do)
- You need an additional layer of security for the authorization code flow
- You're concerned about potential authorization code interception attacks
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
### Token Caching and Blacklisting
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
### Headers Set for Downstream Services
When a user is authenticated, the middleware sets the following headers for downstream services:
- `X-Forwarded-User`: The user's email address
- `X-User-Groups`: Comma-separated list of user groups (if available)
- `X-User-Roles`: Comma-separated list of user roles (if available)
- `X-Auth-Request-Redirect`: The original request URI
- `X-Auth-Request-User`: The user's email address
- `X-Auth-Request-Token`: The user's access token
### Security Headers
The middleware also sets the following security headers:
- `X-Frame-Options: DENY`
- `X-Content-Type-Options: nosniff`
- `X-XSS-Protection: 1; mode=block`
- `Referrer-Policy: strict-origin-when-cross-origin`
## Troubleshooting
### Logging
Set the `logLevel` to `debug` to get more detailed logs:
```yaml
logLevel: debug
```
### Common Issues
1. **Token verification failed**: Check that your `providerURL` is correct and accessible.
2. **Session encryption key too short**: Ensure your `sessionEncryptionKey` is at least 32 bytes long.
3. **No matching public key found**: The JWKS endpoint might be unavailable or the token's key ID (kid) doesn't match any key in the JWKS.
4. **Access denied: Your email domain is not allowed**: The user's email domain is not in the `allowedUserDomains` list.
5. **Access denied: You do not have any of the allowed roles or groups**: The user doesn't have any of the roles or groups specified in `allowedRolesAndGroups`.
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
+18
View File
@@ -0,0 +1,18 @@
package traefikoidc
import "time"
// autoCleanupRoutine runs a ticker that calls the provided cleanup function at the specified interval.
// It stops when a value is received on the stop channel.
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
cleanup()
case <-stop:
return
}
}
}
+22
View File
@@ -0,0 +1,22 @@
package traefikoidc
import (
"sync/atomic"
"testing"
"time"
)
func TestAutoCleanupRoutine(t *testing.T) {
var counter int32
cleanupFunc := func() {
atomic.AddInt32(&counter, 1)
}
stop := make(chan struct{})
go autoCleanupRoutine(50*time.Millisecond, stop, cleanupFunc)
time.Sleep(250 * time.Millisecond)
close(stop)
if atomic.LoadInt32(&counter) < 3 {
t.Errorf("Expected cleanup to be called at least 3 times, got %d", counter)
}
}
+23 -5
View File
@@ -37,6 +37,10 @@ type Cache struct {
// maxSize is the maximum number of items allowed in the cache.
maxSize int
// autoCleanupInterval defines how often Cleanup is called automatically.
autoCleanupInterval time.Duration
// stopCleanup channel to terminate the auto cleanup goroutine.
stopCleanup chan struct{}
}
// DefaultMaxSize is the default maximum number of items in the cache.
@@ -44,12 +48,16 @@ const DefaultMaxSize = 500
// NewCache creates a new empty cache instance that is ready for use.
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem, DefaultMaxSize),
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
c := &Cache{
items: make(map[string]CacheItem, DefaultMaxSize),
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
autoCleanupInterval: 5 * time.Minute,
stopCleanup: make(chan struct{}),
}
go c.startAutoCleanup()
return c
}
// Set adds or updates an item in the cache with the specified expiration duration.
@@ -167,3 +175,13 @@ func (c *Cache) removeItem(key string) {
delete(c.elems, key)
}
}
// startAutoCleanup initiates a goroutine that periodically cleans up expired cache items.
func (c *Cache) startAutoCleanup() {
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
}
// Close terminates the auto cleanup goroutine.
func (c *Cache) Close() {
close(c.stopCleanup)
}
+78 -6
View File
@@ -3,6 +3,7 @@ package traefikoidc
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
@@ -27,6 +28,31 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// generateCodeVerifier creates a cryptographically secure random string
// for use as a PKCE code verifier. The code verifier must be between 43 and 128
// characters long, per the PKCE spec (RFC 7636).
func generateCodeVerifier() (string, error) {
// Using 32 bytes (256 bits) will produce a 43 character base64url string
verifierBytes := make([]byte, 32)
_, err := rand.Read(verifierBytes)
if err != nil {
return "", fmt.Errorf("could not generate code verifier: %w", err)
}
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
}
// deriveCodeChallenge creates a code challenge from a code verifier
// using the SHA-256 method as specified in the PKCE standard (RFC 7636).
func deriveCodeChallenge(codeVerifier string) string {
// Calculate SHA-256 hash of the code verifier
hasher := sha256.New()
hasher.Write([]byte(codeVerifier))
hash := hasher.Sum(nil)
// Base64url encode the hash to get the code challenge
return base64.RawURLEncoding.EncodeToString(hash)
}
// TokenResponse represents the response from the OIDC token endpoint.
// It contains the various tokens and metadata returned after successful
// code exchange or token refresh operations.
@@ -54,7 +80,8 @@ type TokenResponse struct {
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
// - codeOrToken: Either the authorization code or refresh token
// - redirectURL: The callback URL for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
// - codeVerifier: Optional PKCE code verifier for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
@@ -64,6 +91,11 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
if grantType == "authorization_code" {
data.Set("code", codeOrToken)
data.Set("redirect_uri", redirectURL)
// Add code_verifier if PKCE is being used
if codeVerifier != "" {
data.Set("code_verifier", codeVerifier)
}
} else if grantType == "refresh_token" {
data.Set("refresh_token", codeOrToken)
}
@@ -112,7 +144,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
// This is used to refresh access tokens before they expire.
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
@@ -190,7 +222,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
// Get the code verifier from the session for PKCE flow
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
@@ -251,13 +286,41 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Redirect to original path or root
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
// For redirecting, we need to ensure URL fragments are preserved
// To do this, we'll use a small JavaScript snippet that preserves any URL fragments
// This is necessary because URL fragments are not sent to the server
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.WriteHeader(http.StatusOK)
fmt.Fprintf(rw, `<!DOCTYPE html>
<html>
<head>
<title>Authentication Complete</title>
<script>
// Preserve URL fragments by combining the redirectPath with any fragment in the current URL
(function() {
var redirectPath = %q;
var redirectUrl = new URL(redirectPath, window.location.href);
// If we have a hash in the current URL, and the redirect path doesn't already have one,
// append the hash to the redirect URL to preserve anchors
if (window.location.hash && !redirectPath.includes('#')) {
redirectUrl.hash = window.location.hash;
}
window.location.replace(redirectUrl.toString());
})();
</script>
</head>
<body>
<p>Authentication successful. Redirecting...</p>
</body>
</html>`, redirectPath)
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
@@ -327,9 +390,18 @@ func (tc *TokenCache) Cleanup() {
}
// exchangeCodeForToken exchanges an authorization code for tokens.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
// It handles PKCE (Proof Key for Code Exchange) based on middleware configuration.
// The code verifier is only included in the token request if PKCE is enabled.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
// Only include code verifier if PKCE is enabled
effectiveCodeVerifier := ""
if t.enablePKCE && codeVerifier != "" {
effectiveCodeVerifier = codeVerifier
}
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
+27 -79
View File
@@ -1,92 +1,51 @@
package traefikoidc
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"math/big"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sync"
"time"
)
// JWK represents a JSON Web Key as defined in RFC 7517.
// It contains the cryptographic key information used for token verification.
type JWK struct {
// Kty is the key type (e.g., "RSA", "EC")
Kty string `json:"kty"`
// Kid is the unique key identifier
Kid string `json:"kid"`
// Use specifies the intended use of the key (e.g., "sig" for signature)
Use string `json:"use"`
// N is the modulus for RSA keys
N string `json:"n"`
// E is the exponent for RSA keys
E string `json:"e"`
// Alg is the algorithm intended for use with the key
N string `json:"n"`
E string `json:"e"`
Alg string `json:"alg"`
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
Crv string `json:"crv"`
// X is the x-coordinate for EC keys
X string `json:"x"`
// Y is the y-coordinate for EC keys
Y string `json:"y"`
X string `json:"x"`
Y string `json:"y"`
}
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
// OIDC providers typically expose multiple keys to support key rotation.
type JWKSet struct {
// Keys is the array of JSON Web Keys
Keys []JWK `json:"keys"`
}
// JWKCache provides a thread-safe caching mechanism for JWK sets.
// It caches the keys for a configurable duration to reduce load on the OIDC provider
// while ensuring keys are refreshed periodically to handle key rotation.
type JWKCache struct {
// jwks holds the cached set of JSON Web Keys
jwks *JWKSet
// expiresAt is the timestamp when the cached keys should be refreshed
jwks *JWKSet
expiresAt time.Time
// mutex protects concurrent access to the cache
mutex sync.RWMutex
mutex sync.RWMutex
// CacheLifetime is configurable to determine how long the JWKS is cached.
CacheLifetime time.Duration
}
// JWKCacheInterface defines the interface for JWK caching operations.
// This interface allows for different caching implementations while
// maintaining consistent behavior in the token verification process.
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
Cleanup() // Add Cleanup method to the interface
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
Cleanup()
}
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
// from the OIDC provider. It implements a thread-safe double-checked locking
// pattern to prevent multiple simultaneous fetches of the same keys.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for fetching keys
//
// Returns:
// - The JSON Web Key Set
// - An error if the keys cannot be retrieved or parsed
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
defer c.mutex.RUnlock()
@@ -96,23 +55,25 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
c.mutex.Lock()
defer c.mutex.Unlock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
return c.jwks, nil
}
jwks, err := fetchJWKS(jwksURL, httpClient)
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
if err != nil {
return nil, err
}
c.jwks = jwks
c.expiresAt = time.Now().Add(1 * time.Hour)
lifetime := c.CacheLifetime
if lifetime == 0 {
lifetime = 1 * time.Hour
}
c.expiresAt = time.Now().Add(lifetime)
return jwks, nil
}
// Cleanup removes expired JWKs from the cache.
func (c *JWKCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -123,17 +84,14 @@ func (c *JWKCache) Cleanup() {
}
}
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
// It handles HTTP communication and JSON parsing of the response.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for the request
//
// Returns:
// - The parsed JSON Web Key Set
// - An error if the request fails or the response is invalid
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
// Create a request with context to enforce timeout
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create JWKS request: %w", err)
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
}
@@ -151,9 +109,6 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
// cryptographic functions. It supports both RSA and EC keys, delegating
// to the appropriate converter based on the key type.
func jwkToPEM(jwk *JWK) ([]byte, error) {
converter, ok := jwkConverters[jwk.Kty]
if !ok {
@@ -169,9 +124,6 @@ var jwkConverters = map[string]jwkToPEMConverter{
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
// It handles base64url decoding of the modulus and exponent,
// constructs an RSA public key, and encodes it in PEM format.
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
@@ -203,10 +155,6 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
return pubKeyPEM, nil
}
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
// It supports the P-256, P-384, and P-521 curves as defined in the
// OIDC specification, decoding the x and y coordinates and encoding
// the resulting public key in PEM format.
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
+84 -194
View File
@@ -4,44 +4,41 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"strings"
"sync"
"time"
)
var replayCacheMu sync.Mutex
var replayCache = make(map[string]time.Time)
func cleanupReplayCache() {
now := time.Now()
for token, expiry := range replayCache {
if expiry.Before(now) {
delete(replayCache, token)
}
}
}
// ClockSkewTolerance is configurable to adjust time-based validations.
var ClockSkewTolerance = 2 * time.Minute
// JWT represents a JSON Web Token as defined in RFC 7519.
// It contains the three parts of a JWT: header, claims (payload),
// and signature, along with the original token string.
type JWT struct {
// Header contains the token metadata (algorithm, key ID, etc.)
Header map[string]interface{}
// Claims contains the token claims (subject, expiration, etc.)
Claims map[string]interface{}
// Signature contains the raw signature bytes
Header map[string]interface{}
Claims map[string]interface{}
Signature []byte
// Token is the original JWT string
Token string
Token string
}
// parseJWT parses a JWT token string into a JWT struct.
// It validates the token format and decodes the three parts
// (header, claims, signature) using base64url decoding.
// Parameters:
// - tokenString: The raw JWT token string
//
// Returns:
// - A parsed JWT struct
// - An error if the token format is invalid or parsing fails
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -52,7 +49,6 @@ func parseJWT(tokenString string) (*JWT, error) {
Token: tokenString,
}
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
@@ -61,7 +57,6 @@ func parseJWT(tokenString string) (*JWT, error) {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
}
// Decode and unmarshal the claims
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
@@ -70,7 +65,6 @@ func parseJWT(tokenString string) (*JWT, error) {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Decode the signature
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
@@ -81,28 +75,13 @@ func parseJWT(tokenString string) (*JWT, error) {
}
// Verify validates the standard JWT claims as defined in RFC 7519.
// It checks:
// - issuer (iss) matches the expected issuer URL
// - audience (aud) includes the client ID
// - expiration time (exp) is in the future (with clock skew tolerance)
// - issued at time (iat) is in the past (with clock skew tolerance)
// - not before time (nbf) is in the past (with clock skew tolerance)
// - subject (sub) is present and not empty
// - algorithm matches expected value to prevent algorithm switching attacks
//
// Returns an error if any validation fails.
// Verify validates the standard JWT claims as defined in RFC 7519.
func (j *JWT) Verify(issuerURL, clientID string) error {
// Debug logging of validation parameters
fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID)
// Debug logging of token header
fmt.Printf("Token header: %+v\n", j.Header)
// Validate algorithm to prevent algorithm switching attacks
alg, ok := j.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing 'alg' header")
}
// List of supported algorithms - should match those in verifySignature
supportedAlgs := map[string]bool{
"RS256": true, "RS384": true, "RS512": true,
"PS256": true, "PS384": true, "PS512": true,
@@ -114,9 +93,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
claims := j.Claims
// Debug logging of all claims
fmt.Printf("Token claims: %+v\n", claims)
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing 'iss' claim")
@@ -149,17 +125,36 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return err
}
// Validate nbf (not before) claim if present
if nbf, ok := claims["nbf"].(float64); ok {
if err := verifyNotBefore(nbf); err != nil {
return err
}
}
// Validate jti (JWT ID) claim if present
// Implement replay protection by checking the jti (JWT ID)
if jti, ok := claims["jti"].(string); ok {
// Could add replay detection here if needed
_ = jti
// Skip replay detection for tokens that are being verified from the cache
if j.Token == "" {
// This is a parsed JWT without the original token string,
// which means it's likely from a cached token verification
return nil
}
replayCacheMu.Lock()
cleanupReplayCache()
if _, exists := replayCache[jti]; exists {
replayCacheMu.Unlock()
return fmt.Errorf("token replay detected")
}
expFloat, ok := claims["exp"].(float64)
var expTime time.Time
if ok {
expTime = time.Unix(int64(expFloat), 0)
} else {
expTime = time.Now().Add(10 * time.Minute)
}
replayCache[jti] = expTime
replayCacheMu.Unlock()
}
sub, ok := claims["sub"].(string)
@@ -169,20 +164,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// verifyAudience validates the token's audience claim.
// The audience can be either a single string or an array of strings.
// For array audiences, the expected audience must match any one value.
// Parameters:
// - tokenAudience: The audience claim from the token
// - expectedAudience: The expected audience value
//
// Returns an error if validation fails.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
// Debug logging
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
tokenAudience, expectedAudience)
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
@@ -205,165 +187,80 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
return nil
}
// verifyIssuer validates the token's issuer claim.
// The issuer URL must exactly match the expected issuer.
// Parameters:
// - tokenIssuer: The issuer claim from the token
// - expectedIssuer: The expected issuer URL
//
// Returns an error if validation fails.
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
// Debug logging
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
tokenIssuer, expectedIssuer)
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer (token: %s, expected: %s)",
tokenIssuer, expectedIssuer)
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
}
return nil
}
// Clock skew tolerance for time-based validations
const clockSkewTolerance = 2 * time.Minute
// verifyTimeConstraint is a generic function to verify time-based claims
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
claimTime := time.Unix(int64(unixTime), 0)
now := time.Now().Truncate(time.Second)
// For expiration (future=true), we add skew to now (making now later)
// For iat/nbf (future=false), we subtract skew from now (making now earlier)
skewDirection := 1
if !future {
skewDirection = -1
}
skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance)
if claimTime.Equal(now) {
return nil
}
// For expiration: if skewedNow (later) is after expiration, token expired
// For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid
if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) {
var reason string
if future {
reason = "has expired"
} else {
if claimName == "iat" {
reason = "used before issued"
} else {
reason = "not yet valid"
}
}
return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC())
}
return nil
}
// verifyExpiration checks if the token's expiration time has passed.
// The expiration time is compared against the current time with clock skew tolerance.
// Parameters:
// - expiration: The expiration timestamp from the token
//
// Returns an error if the token has expired.
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(clockSkewTolerance)
// Debug logging
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
expirationTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that expire exactly now
if expirationTime.Equal(now) {
return nil
}
if skewedNow.After(expirationTime) {
return fmt.Errorf("token has expired (exp: %v, now: %v)",
expirationTime.UTC(), now.UTC())
}
return nil
return verifyTimeConstraint(expiration, "exp", true)
}
// verifyIssuedAt validates the token's issued-at time.
// Ensures the token wasn't issued in the future, accounting for clock skew.
// Parameters:
// - issuedAt: The issued-at timestamp from the token
//
// Returns an error if the token was issued in the future.
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
issuedAtTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens issued in the same second as current time
if issuedAtTime.Equal(now) {
return nil
}
if skewedNow.Before(issuedAtTime) {
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
issuedAtTime.UTC(), now.UTC())
}
return nil
return verifyTimeConstraint(issuedAt, "iat", false)
}
// verifyNotBefore validates the token's not-before time if present.
// Ensures the token is not used before its valid time period, accounting for clock skew.
// Parameters:
// - notBefore: The not-before timestamp from the token
//
// Returns an error if the token is not yet valid.
func verifyNotBefore(notBefore float64) error {
notBeforeTime := time.Unix(int64(notBefore), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
notBeforeTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that become valid exactly now
if notBeforeTime.Equal(now) {
return nil
}
if skewedNow.Before(notBeforeTime) {
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
notBeforeTime.UTC(), now.UTC())
}
return nil
return verifyTimeConstraint(notBefore, "nbf", false)
}
// verifySignature validates the token's cryptographic signature.
// Supports multiple signature algorithms:
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
// - RSA-PSS: PS256, PS384, PS512
// - ECDSA: ES256, ES384, ES512
//
// Parameters:
// - tokenString: The complete JWT token string
// - publicKeyPEM: The PEM-encoded public key for verification
// - alg: The signature algorithm identifier
//
// Returns an error if signature verification fails.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Debug logging
fmt.Printf("Verifying signature with algorithm: %s\n", alg)
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
// Decode the signature from the token
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
// Decode the PEM-encoded public key
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
// Parse the public key
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
// Determine the hash function to use based on the algorithm
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256", "ES256":
hashFunc = crypto.SHA256
@@ -374,27 +271,20 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
}
// Hash the signed content
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
// Verify the signature based on the key type and algorithm
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
// RSA PKCS#1 v1.5 signature
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
// RSA PSS signature
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature
var r, s big.Int
sigLen := len(signature)
if sigLen%2 != 0 {
+196 -84
View File
@@ -18,6 +18,40 @@ import (
"golang.org/x/time/rate"
)
// createDefaultHTTPClient creates an HTTP client with optimized settings for OIDC
func createDefaultHTTPClient() *http.Client {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
ExpectContinueTimeout: 0,
MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections
}
return &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
}
}
const ConstSessionTimeout = 86400 // Session timeout in seconds
// TokenVerifier interface for token verification
@@ -49,6 +83,7 @@ type TraefikOidc struct {
scopes []string
limiter *rate.Limiter
forceHTTPS bool
enablePKCE bool
scheme string
tokenCache *TokenCache
httpClient *http.Client
@@ -59,7 +94,7 @@ type TraefikOidc struct {
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
@@ -82,11 +117,50 @@ var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
// VerifyToken verifies the provided JWT token
// VerifyToken implements the TokenVerifier interface to verify an OIDC token.
// It performs a complete verification process including:
// 1. Checking the token cache to avoid redundant verifications
// 2. Performing rate limiting and blacklist checks
// 3. Parsing the JWT structure
// 4. Verifying the JWT signature against the JWKS from the provider
// 5. Validating standard JWT claims (iss, aud, exp, etc.)
// 6. Caching the verified token for future requests
//
// Returns nil if the token is valid, or an error describing the validation failure.
func (t *TraefikOidc) VerifyToken(token string) error {
// Check cache first
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
t.logger.Debugf("Token found in cache with valid claims; skipping verification")
return nil
}
t.logger.Debugf("Verifying token")
// Rate limiting
// Perform pre-verification checks
if err := t.performPreVerificationChecks(token); err != nil {
return err
}
// Parse the JWT
jwt, err := parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
// Verify JWT signature and standard claims
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
return err
}
// Cache the verified token
t.cacheVerifiedToken(token, jwt.Claims)
return nil
}
// performPreVerificationChecks performs rate limiting and blacklist checks
func (t *TraefikOidc) performPreVerificationChecks(token string) error {
// Enforce rate limiting
if !t.limiter.Allow() {
return fmt.Errorf("rate limit exceeded")
}
@@ -96,30 +170,15 @@ func (t *TraefikOidc) VerifyToken(token string) error {
return fmt.Errorf("token is blacklisted")
}
// Check if token is cached
if _, exists := t.tokenCache.Get(token); exists {
t.logger.Debugf("Token is valid and cached")
return nil // Token is valid and cached
}
return nil
}
// Parse the JWT
jwt, err := parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
// Verify JWT signature and claims
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
return err
}
// Cache the token until it expires
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
// cacheVerifiedToken caches a verified token until its expiration time
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) {
expirationTime := time.Unix(int64(claims["exp"].(float64)), 0)
now := time.Now()
duration := expirationTime.Sub(now)
t.tokenCache.Set(token, jwt.Claims, duration)
return nil
t.tokenCache.Set(token, claims, duration)
}
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
@@ -127,7 +186,7 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
t.logger.Debugf("Verifying JWT signature and claims")
// Get JWKS
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
jwks, err := t.jwkCache.GetJWKS(context.Background(), t.jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
@@ -173,7 +232,24 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
return nil
}
// New creates a new instance of the OIDC middleware
// New creates a new instance of the OIDC middleware.
// This is the main entry point for the middleware and is called by Traefik when loading the plugin.
// It initializes all components needed for OIDC authentication:
// - Session management for storing user state
// - Token caching and blacklisting
// - JWK caching for signature verification
// - Rate limiting to prevent abuse
// - Metadata discovery for OIDC provider endpoints
//
// Parameters:
// - ctx: Context for initialization operations
// - next: The next handler in the middleware chain
// - config: Configuration options for the middleware
// - name: Identifier for this middleware instance
//
// Returns:
// - An http.Handler that implements the middleware
// - An error if initialization fails
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
if config == nil {
config = CreateConfig()
@@ -187,7 +263,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// Initialize logger
logger := NewLogger(config.LogLevel)
// Ensure key meets minimum length requirement
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
if runtime.Compiler == "yaegi" {
@@ -198,44 +273,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
}
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
ExpectContinueTimeout: 0,
MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections
}
var httpClient *http.Client
if config.HTTPClient != nil {
httpClient = config.HTTPClient
} else {
httpClient = &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
}
httpClient = createDefaultHTTPClient()
}
t := &TraefikOidc{
next: next,
name: name,
@@ -258,6 +302,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
enablePKCE: config.EnablePKCE,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
@@ -266,9 +311,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
logger: logger,
}
// Assign the initialized logger
t.logger = logger
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
@@ -303,12 +347,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
t.updateMetadataEndpoints(metadata)
// Start metadata refresh goroutine
go t.startMetadataRefresh(providerURL)
@@ -321,6 +360,16 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Error("Received nil metadata")
}
// updateMetadataEndpoints updates the middleware with metadata endpoints
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
}
// startMetadataRefresh periodically refreshes the OIDC metadata
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
ticker := time.NewTicker(1 * time.Hour)
@@ -335,12 +384,7 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
}
if metadata != nil {
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
t.updateMetadataEndpoints(metadata)
t.logger.Debug("Successfully refreshed metadata")
}
}
@@ -408,7 +452,18 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return &metadata, nil
}
// ServeHTTP is the main handler for the middleware
// ServeHTTP is the main handler for the middleware that processes all HTTP requests.
// It implements the http.Handler interface and performs the following operations:
// 1. Waits for OIDC provider metadata initialization to complete
// 2. Checks if the requested URL is in the excluded list (bypassing authentication)
// 3. Retrieves or creates a user session
// 4. Handles special paths like callback and logout URLs
// 5. Verifies authentication status and token validity
// 6. Refreshes tokens that are about to expire
// 7. Validates user email domains, roles, and groups against configured restrictions
// 8. Sets appropriate headers for downstream services
// 9. Applies security headers to responses
// 10. Forwards the authenticated request to the next handler
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
@@ -597,7 +652,17 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
return req.Host
}
// isUserAuthenticated checks if the user is authenticated
// isUserAuthenticated checks if the user is authenticated by validating their session and token.
// It performs a comprehensive check of the authentication state including:
// 1. Verifying the session's authenticated flag
// 2. Checking for the presence of an access token
// 3. Validating the token's signature and claims
// 4. Checking the token's expiration time
//
// Returns three boolean values:
// - authenticated: Whether the user is currently authenticated
// - needsRefresh: Whether the token is valid but will expire soon (within grace period)
// - expired: Whether the token has expired or is otherwise invalid
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
t.logger.Debug("User is not authenticated according to session")
@@ -645,7 +710,19 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return true, false, false
}
// defaultInitiateAuthentication initiates the authentication process
// defaultInitiateAuthentication initiates the OIDC authentication process.
// This function prepares and starts a new authentication flow by:
// 1. Generating security tokens (CSRF token and nonce) to prevent attacks
// 2. Clearing any existing session data to avoid state conflicts
// 3. Storing the original request path to redirect back after authentication
// 4. Building the authorization URL with all required OIDC parameters
// 5. Redirecting the user to the OIDC provider's authorization endpoint
//
// Parameters:
// - rw: The HTTP response writer for sending the redirect
// - req: The original HTTP request that triggered authentication
// - session: The user's session data for storing authentication state
// - redirectURL: The callback URL where the OIDC provider will redirect after authentication
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
csrfToken := uuid.NewString()
@@ -655,12 +732,32 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Generate PKCE code verifier and challenge if PKCE is enabled
var codeVerifier, codeChallenge string
if t.enablePKCE {
var err error
codeVerifier, err = generateCodeVerifier()
if err != nil {
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
return
}
// Derive code challenge from verifier
codeChallenge = deriveCodeChallenge(codeVerifier)
}
// Clear any existing session data to avoid stale state causing redirect loops
session.Clear(req, rw)
// Set new session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
// Only set code verifier if PKCE is enabled
if t.enablePKCE {
session.SetCodeVerifier(codeVerifier)
}
session.SetIncomingPath(req.URL.RequestURI())
// Save the session
@@ -671,40 +768,55 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
}
// Build and redirect to authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// verifyToken verifies the token using the token verifier
// verifyToken verifies the token using the token verifier interface.
// This function delegates to the configured token verifier implementation,
// which by default is the TraefikOidc instance itself (implementing the VerifyToken method).
// This design allows for easy mocking in tests and potential future extension.
func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
// buildAuthURL constructs the authentication URL
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
// buildAuthURL constructs the authentication URL with optional PKCE support
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", t.clientID)
params.Set("response_type", "code")
params.Set("redirect_uri", redirectURL)
params.Set("state", state)
params.Set("nonce", nonce)
// Add PKCE parameters only if PKCE is enabled and we have a code challenge
if t.enablePKCE && codeChallenge != "" {
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
}
if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " "))
}
// Ensure authURL is absolute
if !strings.HasPrefix(t.authURL, "http://") && !strings.HasPrefix(t.authURL, "https://") {
return t.buildURLWithParams(t.authURL, params)
}
// buildURLWithParams ensures a URL is absolute and appends query parameters
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
// Ensure URL is absolute
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
t.authURL,
baseURL,
params.Encode())
}
}
return t.authURL + "?" + params.Encode()
return baseURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
+317 -36
View File
@@ -118,7 +118,7 @@ func (ts *TestSuite) Setup() {
}
// Helper functions used by TraefikOidc
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) {
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -131,7 +131,7 @@ type MockJWKCache struct {
Err error
}
func (m *MockJWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return m.JWKS, m.Err
}
@@ -227,7 +227,7 @@ func TestVerifyToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache
// Reset token blacklist and cache for each test
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
ts.tOidc.tokenCache = NewTokenCache()
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
@@ -243,9 +243,20 @@ func TestVerifyToken(t *testing.T) {
}
if tc.cacheToken {
// Use more realistic claims for cached token
ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{
"empty": "claim",
}, 60)
"iss": "https://test-issuer.com",
"sub": "test-subject",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"jti": generateRandomString(16), // Add a JTI claim to prevent replay detection
}, time.Minute)
// Verify the token is actually in the cache
if claims, exists := ts.tOidc.tokenCache.Get(tc.token); exists {
t.Logf("Token found in cache with claims: %v", claims)
} else {
t.Logf("Token NOT found in cache despite cacheToken=true")
}
}
err := ts.tOidc.VerifyToken(tc.token)
@@ -478,7 +489,7 @@ func TestHandleCallback(t *testing.T) {
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(*SessionData)
expectedStatus int
@@ -486,7 +497,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Success",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -502,7 +513,7 @@ func TestHandleCallback(t *testing.T) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusFound,
expectedStatus: http.StatusOK, // Changed from StatusFound since we now return HTML instead of redirect
},
{
name: "Missing Code",
@@ -516,7 +527,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Exchange Code Error",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *SessionData) {
@@ -528,7 +539,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Missing ID Token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *SessionData) {
@@ -540,7 +551,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Disallowed Email",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -561,7 +572,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -582,7 +593,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -603,7 +614,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -719,7 +730,7 @@ func TestOIDCHandler(t *testing.T) {
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
@@ -735,7 +746,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -759,7 +770,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -782,7 +793,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -806,7 +817,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -1653,6 +1664,17 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
}
// Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) {
@@ -1737,7 +1759,7 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
tOidc.tokenURL = server.URL
// Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier")
if tc.expectError {
if err == nil {
@@ -1759,18 +1781,6 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
}
}
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) {
ts := &TestSuite{t: t}
@@ -1783,7 +1793,10 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL string
state string
nonce string
enablePKCE bool
codeChallenge string
expectedPrefix string
checkPKCE bool
}{
{
name: "Absolute Auth URL",
@@ -1792,7 +1805,10 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
{
name: "Relative Auth URL",
@@ -1801,7 +1817,10 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://logto.example.com/oidc/auth?",
checkPKCE: false,
},
{
name: "Relative Auth URL with Different Issuer",
@@ -1810,7 +1829,46 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
checkPKCE: false,
},
{
name: "With PKCE Enabled",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: true,
codeChallenge: "test-code-challenge",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: true,
},
{
name: "With PKCE Enabled but No Challenge",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: true,
codeChallenge: "",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
{
name: "With PKCE Disabled but Challenge Provided",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "test-code-challenge",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
}
@@ -1820,9 +1878,10 @@ func TestBuildAuthURL(t *testing.T) {
tOidc := ts.tOidc
tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL
tOidc.enablePKCE = tc.enablePKCE
// Call buildAuthURL
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
// Call buildAuthURL with code challenge
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge)
// Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) {
@@ -1850,6 +1909,23 @@ func TestBuildAuthURL(t *testing.T) {
}
}
// Verify PKCE parameters
if tc.checkPKCE {
if got := query.Get("code_challenge"); got != tc.codeChallenge {
t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got)
}
if got := query.Get("code_challenge_method"); got != "S256" {
t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got)
}
} else {
if got := query.Get("code_challenge"); got != "" {
t.Errorf("Expected no code_challenge, but got %q", got)
}
if got := query.Get("code_challenge_method"); got != "" {
t.Errorf("Expected no code_challenge_method, but got %q", got)
}
}
// Verify scopes are present and correct
if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ")
@@ -1861,6 +1937,211 @@ func TestBuildAuthURL(t *testing.T) {
}
}
// TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support
func TestExchangeCodeForToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
enablePKCE bool
codeVerifier string
setupMock func(t *testing.T) *httptest.Server
}{
{
name: "With PKCE Enabled and Code Verifier",
enablePKCE: true,
codeVerifier: "test-code-verifier",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" {
t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
{
name: "With PKCE Disabled but Code Verifier Provided",
enablePKCE: false,
codeVerifier: "test-code-verifier",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is NOT included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
{
name: "With PKCE Enabled but No Code Verifier",
enablePKCE: true,
codeVerifier: "",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is NOT included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupMock(t)
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
tOidc.enablePKCE = tc.enablePKCE
// Test exchangeCodeForToken
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
})
}
}
// TestHandleCallback_PreservesURLFragments tests that URL fragments (anchors) are preserved during the authentication callback process.
func TestHandleCallback_PreservesURLFragments(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create a new instance for this specific test
logger := NewLogger("info")
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
tOidc := &TraefikOidc{
allowedUserDomains: map[string]struct{}{"example.com": {}},
logger: logger,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
sessionManager: sessionManager,
redirURLPath: "/callback",
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
exchangeCodeForTokenFunc: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
}
// Create a request with the callback URL
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-csrf-token", nil)
rr := httptest.NewRecorder()
// Create session with an incoming path that contains a URL fragment
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set up the session with necessary values and an incoming path with a fragment
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
session.SetIncomingPath("/dashboard?param=value") // The fragment will be client-side only
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder
rr = httptest.NewRecorder()
// Call handleCallback
tOidc.handleCallback(rr, req, "http://example.com/callback")
// The response should be OK (200) since we're returning HTML, not a redirect
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
// Verify that the response is HTML and contains our JavaScript for preserving fragments
contentType := rr.Header().Get("Content-Type")
if !strings.Contains(contentType, "text/html") {
t.Errorf("Expected Content-Type to contain 'text/html', got %s", contentType)
}
// Verify the response contains the redirect path and JavaScript for preserving fragments
body := rr.Body.String()
if !strings.Contains(body, "/dashboard?param=value") {
t.Errorf("Response body doesn't contain the original redirect path")
}
if !strings.Contains(body, "window.location.hash") {
t.Errorf("Response doesn't contain JavaScript logic to preserve URL fragments")
}
if !strings.Contains(body, "redirectUrl.hash = window.location.hash") {
t.Errorf("Response doesn't contain logic to copy the fragment from current URL")
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t}
@@ -1868,7 +2149,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
// Create a request with query parameters
req := httptest.NewRequest("GET", "/protected/resource?param1=value1&param2=value2", nil)
rw := httptest.NewRecorder()
responseRecorder := httptest.NewRecorder()
// Get session
session, err := ts.sessionManager.GetSession(req)
@@ -1878,7 +2159,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
// Call defaultInitiateAuthentication
redirectURL := "http://example.com/callback"
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL)
// Verify that the incoming path includes query parameters
incomingPath := session.GetIncomingPath()
+33 -16
View File
@@ -7,16 +7,21 @@ import (
"time"
)
// MetadataCache provides thread-safe caching for OIDC provider metadata
type MetadataCache struct {
metadata *ProviderMetadata
expiresAt time.Time
mutex sync.RWMutex
metadata *ProviderMetadata
expiresAt time.Time
mutex sync.RWMutex
autoCleanupInterval time.Duration
stopCleanup chan struct{}
}
// NewMetadataCache creates a new metadata cache instance
func NewMetadataCache() *MetadataCache {
return &MetadataCache{}
c := &MetadataCache{
autoCleanupInterval: 5 * time.Minute,
stopCleanup: make(chan struct{}),
}
go c.startAutoCleanup()
return c
}
// Cleanup removes expired metadata from the cache.
@@ -29,11 +34,14 @@ func (c *MetadataCache) Cleanup() {
c.metadata = nil
}
}
func (c *MetadataCache) isCacheValid() bool {
return c.metadata != nil && time.Now().Before(c.expiresAt)
}
// GetMetadata retrieves the metadata from cache or fetches it if expired
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
c.mutex.RLock()
if c.metadata != nil && time.Now().Before(c.expiresAt) {
if c.isCacheValid() {
defer c.mutex.RUnlock()
return c.metadata, nil
}
@@ -43,7 +51,7 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
defer c.mutex.Unlock()
// Double-check after acquiring write lock
if c.metadata != nil && time.Now().Before(c.expiresAt) {
if c.isCacheValid() {
return c.metadata, nil
}
@@ -60,14 +68,23 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
c.metadata = metadata
// Calculate expiration time based on usage patterns
usageCount := 0 // This should be replaced with actual usage tracking logic
if usageCount < 10 {
c.expiresAt = time.Now().Add(30 * time.Minute)
} else if usageCount < 50 {
c.expiresAt = time.Now().Add(1 * time.Hour)
} else {
c.expiresAt = time.Now().Add(2 * time.Hour)
}
usageCount := 0 // This should be replaced with actual usage tracking logic
if usageCount < 10 {
c.expiresAt = time.Now().Add(30 * time.Minute)
} else if usageCount < 50 {
c.expiresAt = time.Now().Add(1 * time.Hour)
} else {
c.expiresAt = time.Now().Add(2 * time.Hour)
}
// End of GetMetadata
return metadata, nil
}
func (c *MetadataCache) startAutoCleanup() {
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
}
func (c *MetadataCache) Close() {
close(c.stopCleanup)
}
+119
View File
@@ -0,0 +1,119 @@
package traefikoidc
import (
"fmt"
"net/http"
"testing"
"time"
)
func TestIsCacheValid(t *testing.T) {
// Setup with a dummy ProviderMetadata.
pm := &ProviderMetadata{}
mc := &MetadataCache{
metadata: pm,
expiresAt: time.Now().Add(1 * time.Hour),
}
if !mc.isCacheValid() {
t.Errorf("Expected cache to be valid")
}
mc.expiresAt = time.Now().Add(-1 * time.Hour)
if mc.isCacheValid() {
t.Errorf("Expected cache to be invalid")
}
}
func TestCleanup(t *testing.T) {
pm := &ProviderMetadata{}
mc := &MetadataCache{
metadata: pm,
expiresAt: time.Now().Add(-1 * time.Hour),
}
mc.Cleanup()
if mc.metadata != nil {
t.Errorf("Expected metadata to be nil after cleanup")
}
}
func TestGetMetadata_Cached(t *testing.T) {
dummyData := &ProviderMetadata{}
// Construct MetadataCache manually to avoid interference from auto cleanup.
mc := &MetadataCache{
metadata: dummyData,
expiresAt: time.Now().Add(1 * time.Hour),
stopCleanup: make(chan struct{}),
autoCleanupInterval: 5 * time.Minute,
}
// Use NewLogger to create a logger that writes errors only.
logger := NewLogger("error")
result, err := mc.GetMetadata("http://example.com", http.DefaultClient, logger)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if result != dummyData {
t.Errorf("Expected cached metadata to be returned")
}
}
func TestMetadataCacheAutoCleanup(t *testing.T) {
mc := &MetadataCache{
autoCleanupInterval: 50 * time.Millisecond,
stopCleanup: make(chan struct{}),
}
// Start auto cleanup.
go mc.startAutoCleanup()
mc.mutex.Lock()
mc.metadata = &ProviderMetadata{}
mc.expiresAt = time.Now().Add(-50 * time.Millisecond)
mc.mutex.Unlock()
// Wait enough time for the auto cleanup to run.
time.Sleep(200 * time.Millisecond)
mc.Close()
mc.mutex.RLock()
defer mc.mutex.RUnlock()
if mc.metadata != nil {
t.Errorf("Expected metadata to be cleared by auto cleanup")
}
}
type errorRoundTripper struct {
err error
}
func (e errorRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, e.err
}
func TestGetMetadata_FetchError(t *testing.T) {
// Create an HTTP client that always returns an error.
errorClient := &http.Client{
Transport: errorRoundTripper{err: fmt.Errorf("fake fetch error")},
}
// Case 1: Cache is empty.
mc := &MetadataCache{
stopCleanup: make(chan struct{}),
}
logger := NewLogger("error")
metadata, err := mc.GetMetadata("http://example.com", errorClient, logger)
if err == nil {
t.Errorf("Expected error, got nil")
}
if metadata != nil {
t.Errorf("Expected nil metadata, got %v", metadata)
}
// Case 2: Cache has old metadata.
dummy := &ProviderMetadata{}
mc.metadata = dummy
mc.expiresAt = time.Now().Add(-1 * time.Minute)
logger2 := NewLogger("error")
metadata, err = mc.GetMetadata("http://example.com", errorClient, logger2)
if err != nil {
t.Errorf("Expected no error when cached metadata exists, got %v", err)
}
if metadata != dummy {
t.Errorf("Expected cached metadata to be returned")
}
}
+11
View File
@@ -576,6 +576,17 @@ func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetCodeVerifier retrieves the PKCE code verifier from the session.
func (sd *SessionData) GetCodeVerifier() string {
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
return codeVerifier
}
// SetCodeVerifier stores the PKCE code verifier in the session.
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
sd.mainSession.Values["code_verifier"] = codeVerifier
}
// GetEmail retrieves the authenticated user's email address from the session.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
+8 -1
View File
@@ -22,6 +22,11 @@ type Config struct {
// If not provided, it will be discovered from provider metadata
RevocationURL string `json:"revocationURL"`
// EnablePKCE enables Proof Key for Code Exchange (PKCE) for the authorization code flow (optional)
// This enhances security but might not be supported by all OIDC providers
// Default: false
EnablePKCE bool `json:"enablePKCE"`
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
// Example: /oauth2/callback
CallbackURL string `json:"callbackURL"`
@@ -103,12 +108,14 @@ const (
// - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security)
// - EnablePKCE: false (PKCE is opt-in)
func CreateConfig() *Config {
c := &Config{
Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
ForceHTTPS: true, // Secure by default
EnablePKCE: false, // PKCE is opt-in
}
return c