mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Add support for PKCE (#31)
* Add PKCE support. * Add option to toggle PKCE checks feature. * GoFMT
This commit is contained in:
@@ -62,6 +62,7 @@ testData:
|
||||
# Advanced parameters (usually discovered automatically from provider metadata)
|
||||
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
|
||||
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
||||
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
|
||||
|
||||
# Configuration documentation
|
||||
configuration:
|
||||
@@ -230,3 +231,15 @@ configuration:
|
||||
|
||||
Example: https://accounts.google.com/logout
|
||||
required: false
|
||||
|
||||
enablePKCE:
|
||||
type: boolean
|
||||
description: |
|
||||
Enables PKCE (Proof Key for Code Exchange) for the OAuth 2.0 authorization code flow.
|
||||
PKCE adds an extra layer of security to protect against authorization code interception attacks.
|
||||
|
||||
Not all OIDC providers support PKCE, so this should only be enabled if your provider supports it.
|
||||
If enabled, the middleware will generate and use a code verifier/challenge pair during authentication.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
@@ -69,13 +69,14 @@ The middleware supports the following configuration options:
|
||||
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
|
||||
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| | `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| | `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| | `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
|
||||
| | `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
@@ -233,6 +234,30 @@ spec:
|
||||
- profile
|
||||
```
|
||||
|
||||
### With PKCE Enabled
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-pkce
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
enablePKCE: true # Enables PKCE for added security
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
```
|
||||
|
||||
### Keeping Secrets Secret in Kubernetes
|
||||
|
||||
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
|
||||
@@ -378,6 +403,17 @@ http:
|
||||
|
||||
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
|
||||
|
||||
### PKCE Support
|
||||
|
||||
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
|
||||
|
||||
PKCE is recommended when:
|
||||
- Your OIDC provider supports it (most modern providers do)
|
||||
- You need an additional layer of security for the authorization code flow
|
||||
- You're concerned about potential authorization code interception attacks
|
||||
|
||||
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
|
||||
|
||||
### Token Caching and Blacklisting
|
||||
|
||||
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
|
||||
|
||||
+49
-5
@@ -3,6 +3,7 @@ package traefikoidc
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -27,6 +28,31 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// generateCodeVerifier creates a cryptographically secure random string
|
||||
// for use as a PKCE code verifier. The code verifier must be between 43 and 128
|
||||
// characters long, per the PKCE spec (RFC 7636).
|
||||
func generateCodeVerifier() (string, error) {
|
||||
// Using 32 bytes (256 bits) will produce a 43 character base64url string
|
||||
verifierBytes := make([]byte, 32)
|
||||
_, err := rand.Read(verifierBytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not generate code verifier: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
|
||||
}
|
||||
|
||||
// deriveCodeChallenge creates a code challenge from a code verifier
|
||||
// using the SHA-256 method as specified in the PKCE standard (RFC 7636).
|
||||
func deriveCodeChallenge(codeVerifier string) string {
|
||||
// Calculate SHA-256 hash of the code verifier
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(codeVerifier))
|
||||
hash := hasher.Sum(nil)
|
||||
|
||||
// Base64url encode the hash to get the code challenge
|
||||
return base64.RawURLEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from the OIDC token endpoint.
|
||||
// It contains the various tokens and metadata returned after successful
|
||||
// code exchange or token refresh operations.
|
||||
@@ -54,7 +80,8 @@ type TokenResponse struct {
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
|
||||
// - codeOrToken: Either the authorization code or refresh token
|
||||
// - redirectURL: The callback URL for authorization code grant
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
||||
// - codeVerifier: Optional PKCE code verifier for authorization code grant
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
"client_id": {t.clientID},
|
||||
@@ -64,6 +91,11 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
if grantType == "authorization_code" {
|
||||
data.Set("code", codeOrToken)
|
||||
data.Set("redirect_uri", redirectURL)
|
||||
|
||||
// Add code_verifier if PKCE is being used
|
||||
if codeVerifier != "" {
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
}
|
||||
} else if grantType == "refresh_token" {
|
||||
data.Set("refresh_token", codeOrToken)
|
||||
}
|
||||
@@ -112,7 +144,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
// This is used to refresh access tokens before they expire.
|
||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
@@ -190,7 +222,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
|
||||
// Get the code verifier from the session for PKCE flow
|
||||
codeVerifier := session.GetCodeVerifier()
|
||||
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
@@ -327,9 +362,18 @@ func (tc *TokenCache) Cleanup() {
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges an authorization code for tokens.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
|
||||
// It handles PKCE (Proof Key for Code Exchange) based on middleware configuration.
|
||||
// The code verifier is only included in the token request if PKCE is enabled.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
|
||||
|
||||
// Only include code verifier if PKCE is enabled
|
||||
effectiveCodeVerifier := ""
|
||||
if t.enablePKCE && codeVerifier != "" {
|
||||
effectiveCodeVerifier = codeVerifier
|
||||
}
|
||||
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
|
||||
@@ -83,6 +83,7 @@ type TraefikOidc struct {
|
||||
scopes []string
|
||||
limiter *rate.Limiter
|
||||
forceHTTPS bool
|
||||
enablePKCE bool
|
||||
scheme string
|
||||
tokenCache *TokenCache
|
||||
httpClient *http.Client
|
||||
@@ -93,7 +94,7 @@ type TraefikOidc struct {
|
||||
allowedUserDomains map[string]struct{}
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
|
||||
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
|
||||
exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
initComplete chan struct{}
|
||||
endSessionURL string
|
||||
@@ -279,7 +280,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
} else {
|
||||
httpClient = createDefaultHTTPClient()
|
||||
}
|
||||
|
||||
t := &TraefikOidc{
|
||||
next: next,
|
||||
name: name,
|
||||
@@ -302,6 +302,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
enablePKCE: config.EnablePKCE,
|
||||
scopes: config.Scopes,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: NewTokenCache(),
|
||||
@@ -310,9 +311,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||
initComplete: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
// Assign the initialized logger
|
||||
t.logger = logger
|
||||
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
t.extractClaimsFunc = extractClaims
|
||||
@@ -732,12 +732,32 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
return
|
||||
}
|
||||
|
||||
// Generate PKCE code verifier and challenge if PKCE is enabled
|
||||
var codeVerifier, codeChallenge string
|
||||
if t.enablePKCE {
|
||||
var err error
|
||||
codeVerifier, err = generateCodeVerifier()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Derive code challenge from verifier
|
||||
codeChallenge = deriveCodeChallenge(codeVerifier)
|
||||
}
|
||||
|
||||
// Clear any existing session data to avoid stale state causing redirect loops
|
||||
session.Clear(req, rw)
|
||||
|
||||
// Set new session values
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
|
||||
// Only set code verifier if PKCE is enabled
|
||||
if t.enablePKCE {
|
||||
session.SetCodeVerifier(codeVerifier)
|
||||
}
|
||||
|
||||
session.SetIncomingPath(req.URL.RequestURI())
|
||||
|
||||
// Save the session
|
||||
@@ -748,7 +768,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
}
|
||||
|
||||
// Build and redirect to authentication URL
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -760,14 +780,21 @@ func (t *TraefikOidc) verifyToken(token string) error {
|
||||
return t.tokenVerifier.VerifyToken(token)
|
||||
}
|
||||
|
||||
// buildAuthURL constructs the authentication URL
|
||||
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||
// buildAuthURL constructs the authentication URL with optional PKCE support
|
||||
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", t.clientID)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("redirect_uri", redirectURL)
|
||||
params.Set("state", state)
|
||||
params.Set("nonce", nonce)
|
||||
|
||||
// Add PKCE parameters only if PKCE is enabled and we have a code challenge
|
||||
if t.enablePKCE && codeChallenge != "" {
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
if len(t.scopes) > 0 {
|
||||
params.Set("scope", strings.Join(t.scopes, " "))
|
||||
}
|
||||
|
||||
+215
-31
@@ -118,7 +118,7 @@ func (ts *TestSuite) Setup() {
|
||||
}
|
||||
|
||||
// Helper functions used by TraefikOidc
|
||||
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) {
|
||||
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -489,7 +489,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
|
||||
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
sessionSetupFunc func(*SessionData)
|
||||
expectedStatus int
|
||||
@@ -497,7 +497,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Success",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -527,7 +527,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Exchange Code Error",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
return nil, fmt.Errorf("exchange code error")
|
||||
},
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
@@ -539,7 +539,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Missing ID Token",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
return &TokenResponse{}, nil
|
||||
},
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
@@ -551,7 +551,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Disallowed Email",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -572,7 +572,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Invalid State Parameter",
|
||||
queryParams: "?code=test-code&state=invalid-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -593,7 +593,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Nonce Mismatch",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -614,7 +614,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Missing Nonce in Claims",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -730,7 +730,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
|
||||
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
sessionSetupFunc func(session *sessions.Session)
|
||||
expectedStatus int
|
||||
@@ -746,7 +746,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -770,7 +770,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -793,7 +793,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -817,7 +817,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -1664,6 +1664,17 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
}
|
||||
|
||||
// Helper function to compare string slices
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
|
||||
func TestExchangeTokensWithRedirects(t *testing.T) {
|
||||
@@ -1748,7 +1759,7 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
|
||||
tOidc.tokenURL = server.URL
|
||||
|
||||
// Test token exchange
|
||||
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
|
||||
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier")
|
||||
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
@@ -1770,18 +1781,6 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
|
||||
func TestBuildAuthURL(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
@@ -1794,7 +1793,10 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
redirectURL string
|
||||
state string
|
||||
nonce string
|
||||
enablePKCE bool
|
||||
codeChallenge string
|
||||
expectedPrefix string
|
||||
checkPKCE bool
|
||||
}{
|
||||
{
|
||||
name: "Absolute Auth URL",
|
||||
@@ -1803,7 +1805,10 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
enablePKCE: false,
|
||||
codeChallenge: "",
|
||||
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||
checkPKCE: false,
|
||||
},
|
||||
{
|
||||
name: "Relative Auth URL",
|
||||
@@ -1812,7 +1817,10 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
enablePKCE: false,
|
||||
codeChallenge: "",
|
||||
expectedPrefix: "https://logto.example.com/oidc/auth?",
|
||||
checkPKCE: false,
|
||||
},
|
||||
{
|
||||
name: "Relative Auth URL with Different Issuer",
|
||||
@@ -1821,7 +1829,46 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
enablePKCE: false,
|
||||
codeChallenge: "",
|
||||
expectedPrefix: "https://auth.example.com:8443/sign-in?",
|
||||
checkPKCE: false,
|
||||
},
|
||||
{
|
||||
name: "With PKCE Enabled",
|
||||
authURL: "https://auth.example.com/oauth/authorize",
|
||||
issuerURL: "https://auth.example.com",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
enablePKCE: true,
|
||||
codeChallenge: "test-code-challenge",
|
||||
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||
checkPKCE: true,
|
||||
},
|
||||
{
|
||||
name: "With PKCE Enabled but No Challenge",
|
||||
authURL: "https://auth.example.com/oauth/authorize",
|
||||
issuerURL: "https://auth.example.com",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
enablePKCE: true,
|
||||
codeChallenge: "",
|
||||
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||
checkPKCE: false,
|
||||
},
|
||||
{
|
||||
name: "With PKCE Disabled but Challenge Provided",
|
||||
authURL: "https://auth.example.com/oauth/authorize",
|
||||
issuerURL: "https://auth.example.com",
|
||||
redirectURL: "https://app.example.com/callback",
|
||||
state: "test-state",
|
||||
nonce: "test-nonce",
|
||||
enablePKCE: false,
|
||||
codeChallenge: "test-code-challenge",
|
||||
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||
checkPKCE: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1831,9 +1878,10 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
tOidc := ts.tOidc
|
||||
tOidc.authURL = tc.authURL
|
||||
tOidc.issuerURL = tc.issuerURL
|
||||
tOidc.enablePKCE = tc.enablePKCE
|
||||
|
||||
// Call buildAuthURL
|
||||
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
|
||||
// Call buildAuthURL with code challenge
|
||||
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge)
|
||||
|
||||
// Verify the URL starts with the expected prefix
|
||||
if !strings.HasPrefix(result, tc.expectedPrefix) {
|
||||
@@ -1861,6 +1909,23 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verify PKCE parameters
|
||||
if tc.checkPKCE {
|
||||
if got := query.Get("code_challenge"); got != tc.codeChallenge {
|
||||
t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got)
|
||||
}
|
||||
if got := query.Get("code_challenge_method"); got != "S256" {
|
||||
t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got)
|
||||
}
|
||||
} else {
|
||||
if got := query.Get("code_challenge"); got != "" {
|
||||
t.Errorf("Expected no code_challenge, but got %q", got)
|
||||
}
|
||||
if got := query.Get("code_challenge_method"); got != "" {
|
||||
t.Errorf("Expected no code_challenge_method, but got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify scopes are present and correct
|
||||
if len(tOidc.scopes) > 0 {
|
||||
expectedScopes := strings.Join(tOidc.scopes, " ")
|
||||
@@ -1872,6 +1937,125 @@ func TestBuildAuthURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support
|
||||
func TestExchangeCodeForToken(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
enablePKCE bool
|
||||
codeVerifier string
|
||||
setupMock func(t *testing.T) *httptest.Server
|
||||
}{
|
||||
{
|
||||
name: "With PKCE Enabled and Code Verifier",
|
||||
enablePKCE: true,
|
||||
codeVerifier: "test-code-verifier",
|
||||
setupMock: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("Failed to parse form: %v", err)
|
||||
}
|
||||
|
||||
// Verify code_verifier is included
|
||||
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" {
|
||||
t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier)
|
||||
}
|
||||
|
||||
// Return successful token response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
IDToken: "test.id.token",
|
||||
AccessToken: "test-access-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "test-refresh-token",
|
||||
})
|
||||
}))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "With PKCE Disabled but Code Verifier Provided",
|
||||
enablePKCE: false,
|
||||
codeVerifier: "test-code-verifier",
|
||||
setupMock: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("Failed to parse form: %v", err)
|
||||
}
|
||||
|
||||
// Verify code_verifier is NOT included
|
||||
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
|
||||
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
|
||||
}
|
||||
|
||||
// Return successful token response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
IDToken: "test.id.token",
|
||||
AccessToken: "test-access-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "test-refresh-token",
|
||||
})
|
||||
}))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "With PKCE Enabled but No Code Verifier",
|
||||
enablePKCE: true,
|
||||
codeVerifier: "",
|
||||
setupMock: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("Failed to parse form: %v", err)
|
||||
}
|
||||
|
||||
// Verify code_verifier is NOT included
|
||||
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
|
||||
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
|
||||
}
|
||||
|
||||
// Return successful token response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
IDToken: "test.id.token",
|
||||
AccessToken: "test-access-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "test-refresh-token",
|
||||
})
|
||||
}))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := tc.setupMock(t)
|
||||
defer server.Close()
|
||||
|
||||
// Configure the test instance
|
||||
tOidc := ts.tOidc
|
||||
tOidc.tokenURL = server.URL
|
||||
tOidc.enablePKCE = tc.enablePKCE
|
||||
|
||||
// Test exchangeCodeForToken
|
||||
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if response == nil {
|
||||
t.Error("Expected token response but got nil")
|
||||
} else if response.IDToken != "test.id.token" {
|
||||
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
|
||||
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
@@ -1879,7 +2063,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
||||
|
||||
// Create a request with query parameters
|
||||
req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
// Get session
|
||||
session, err := ts.sessionManager.GetSession(req)
|
||||
@@ -1889,7 +2073,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
||||
|
||||
// Call defaultInitiateAuthentication
|
||||
redirectURL := "http://example.com/callback"
|
||||
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL)
|
||||
|
||||
// Verify that the incoming path includes query parameters
|
||||
incomingPath := session.GetIncomingPath()
|
||||
|
||||
+11
@@ -576,6 +576,17 @@ func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetCodeVerifier retrieves the PKCE code verifier from the session.
|
||||
func (sd *SessionData) GetCodeVerifier() string {
|
||||
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
|
||||
return codeVerifier
|
||||
}
|
||||
|
||||
// SetCodeVerifier stores the PKCE code verifier in the session.
|
||||
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
sd.mainSession.Values["code_verifier"] = codeVerifier
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address from the session.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
|
||||
+8
-1
@@ -22,6 +22,11 @@ type Config struct {
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
|
||||
// EnablePKCE enables Proof Key for Code Exchange (PKCE) for the authorization code flow (optional)
|
||||
// This enhances security but might not be supported by all OIDC providers
|
||||
// Default: false
|
||||
EnablePKCE bool `json:"enablePKCE"`
|
||||
|
||||
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
|
||||
// Example: /oauth2/callback
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
@@ -103,12 +108,14 @@ const (
|
||||
// - RateLimit: 100 requests per second
|
||||
// - PostLogoutRedirectURI: "/"
|
||||
// - ForceHTTPS: true (for security)
|
||||
// - EnablePKCE: false (PKCE is opt-in)
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
ForceHTTPS: true, // Secure by default
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
}
|
||||
|
||||
return c
|
||||
|
||||
Reference in New Issue
Block a user