Add support for PKCE (#31)

* Add PKCE support.
* Add option to toggle PKCE checks feature.
* GoFMT
This commit is contained in:
2025-03-18 01:09:14 +00:00
committed by GitHub
parent 4ce2815123
commit 4322407129
7 changed files with 373 additions and 51 deletions
+13
View File
@@ -62,6 +62,7 @@ testData:
# Advanced parameters (usually discovered automatically from provider metadata) # Advanced parameters (usually discovered automatically from provider metadata)
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
# Configuration documentation # Configuration documentation
configuration: configuration:
@@ -230,3 +231,15 @@ configuration:
Example: https://accounts.google.com/logout Example: https://accounts.google.com/logout
required: false required: false
enablePKCE:
type: boolean
description: |
Enables PKCE (Proof Key for Code Exchange) for the OAuth 2.0 authorization code flow.
PKCE adds an extra layer of security to protect against authorization code interception attacks.
Not all OIDC providers support PKCE, so this should only be enabled if your provider supports it.
If enabled, the middleware will generate and use a code verifier/challenge pair during authentication.
Default: false
required: false
+43 -7
View File
@@ -69,13 +69,14 @@ The middleware supports the following configuration options:
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` | | `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` | | `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` | | `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` | | | `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` | | | `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` | | | `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` | | | `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` | | | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` | | | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` | | | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
## Usage Examples ## Usage Examples
@@ -233,6 +234,30 @@ spec:
- profile - profile
``` ```
### With PKCE Enabled
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-pkce
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
enablePKCE: true # Enables PKCE for added security
scopes:
- openid
- email
- profile
```
### Keeping Secrets Secret in Kubernetes ### Keeping Secrets Secret in Kubernetes
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values: For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
@@ -378,6 +403,17 @@ http:
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret. The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
### PKCE Support
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
PKCE is recommended when:
- Your OIDC provider supports it (most modern providers do)
- You need an additional layer of security for the authorization code flow
- You're concerned about potential authorization code interception attacks
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
### Token Caching and Blacklisting ### Token Caching and Blacklisting
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens. The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
+49 -5
View File
@@ -3,6 +3,7 @@ package traefikoidc
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -27,6 +28,31 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil return base64.URLEncoding.EncodeToString(nonceBytes), nil
} }
// generateCodeVerifier creates a cryptographically secure random string
// for use as a PKCE code verifier. The code verifier must be between 43 and 128
// characters long, per the PKCE spec (RFC 7636).
func generateCodeVerifier() (string, error) {
// Using 32 bytes (256 bits) will produce a 43 character base64url string
verifierBytes := make([]byte, 32)
_, err := rand.Read(verifierBytes)
if err != nil {
return "", fmt.Errorf("could not generate code verifier: %w", err)
}
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
}
// deriveCodeChallenge creates a code challenge from a code verifier
// using the SHA-256 method as specified in the PKCE standard (RFC 7636).
func deriveCodeChallenge(codeVerifier string) string {
// Calculate SHA-256 hash of the code verifier
hasher := sha256.New()
hasher.Write([]byte(codeVerifier))
hash := hasher.Sum(nil)
// Base64url encode the hash to get the code challenge
return base64.RawURLEncoding.EncodeToString(hash)
}
// TokenResponse represents the response from the OIDC token endpoint. // TokenResponse represents the response from the OIDC token endpoint.
// It contains the various tokens and metadata returned after successful // It contains the various tokens and metadata returned after successful
// code exchange or token refresh operations. // code exchange or token refresh operations.
@@ -54,7 +80,8 @@ type TokenResponse struct {
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token") // - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
// - codeOrToken: Either the authorization code or refresh token // - codeOrToken: Either the authorization code or refresh token
// - redirectURL: The callback URL for authorization code grant // - redirectURL: The callback URL for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) { // - codeVerifier: Optional PKCE code verifier for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) {
data := url.Values{ data := url.Values{
"grant_type": {grantType}, "grant_type": {grantType},
"client_id": {t.clientID}, "client_id": {t.clientID},
@@ -64,6 +91,11 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
if grantType == "authorization_code" { if grantType == "authorization_code" {
data.Set("code", codeOrToken) data.Set("code", codeOrToken)
data.Set("redirect_uri", redirectURL) data.Set("redirect_uri", redirectURL)
// Add code_verifier if PKCE is being used
if codeVerifier != "" {
data.Set("code_verifier", codeVerifier)
}
} else if grantType == "refresh_token" { } else if grantType == "refresh_token" {
data.Set("refresh_token", codeOrToken) data.Set("refresh_token", codeOrToken)
} }
@@ -112,7 +144,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
// This is used to refresh access tokens before they expire. // This is used to refresh access tokens before they expire.
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background() ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "") tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err) return nil, fmt.Errorf("failed to refresh token: %w", err)
} }
@@ -190,7 +222,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return return
} }
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL) // Get the code verifier from the session for PKCE flow
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier)
if err != nil { if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err) t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError) http.Error(rw, "Authentication failed", http.StatusInternalServerError)
@@ -327,9 +362,18 @@ func (tc *TokenCache) Cleanup() {
} }
// exchangeCodeForToken exchanges an authorization code for tokens. // exchangeCodeForToken exchanges an authorization code for tokens.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) { // It handles PKCE (Proof Key for Code Exchange) based on middleware configuration.
// The code verifier is only included in the token request if PKCE is enabled.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
ctx := context.Background() ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
// Only include code verifier if PKCE is enabled
effectiveCodeVerifier := ""
if t.enablePKCE && codeVerifier != "" {
effectiveCodeVerifier = codeVerifier
}
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err) return nil, fmt.Errorf("failed to exchange code for token: %w", err)
} }
+34 -7
View File
@@ -83,6 +83,7 @@ type TraefikOidc struct {
scopes []string scopes []string
limiter *rate.Limiter limiter *rate.Limiter
forceHTTPS bool forceHTTPS bool
enablePKCE bool
scheme string scheme string
tokenCache *TokenCache tokenCache *TokenCache
httpClient *http.Client httpClient *http.Client
@@ -93,7 +94,7 @@ type TraefikOidc struct {
allowedUserDomains map[string]struct{} allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{} allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error) exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{} initComplete chan struct{}
endSessionURL string endSessionURL string
@@ -279,7 +280,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
} else { } else {
httpClient = createDefaultHTTPClient() httpClient = createDefaultHTTPClient()
} }
t := &TraefikOidc{ t := &TraefikOidc{
next: next, next: next,
name: name, name: name,
@@ -302,6 +302,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
clientID: config.ClientID, clientID: config.ClientID,
clientSecret: config.ClientSecret, clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS, forceHTTPS: config.ForceHTTPS,
enablePKCE: config.EnablePKCE,
scopes: config.Scopes, scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit), limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(), tokenCache: NewTokenCache(),
@@ -310,9 +311,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
allowedUserDomains: createStringMap(config.AllowedUserDomains), allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups), allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}), initComplete: make(chan struct{}),
logger: logger,
} }
// Assign the initialized logger
t.logger = logger
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims t.extractClaimsFunc = extractClaims
@@ -732,12 +732,32 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return return
} }
// Generate PKCE code verifier and challenge if PKCE is enabled
var codeVerifier, codeChallenge string
if t.enablePKCE {
var err error
codeVerifier, err = generateCodeVerifier()
if err != nil {
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
return
}
// Derive code challenge from verifier
codeChallenge = deriveCodeChallenge(codeVerifier)
}
// Clear any existing session data to avoid stale state causing redirect loops // Clear any existing session data to avoid stale state causing redirect loops
session.Clear(req, rw) session.Clear(req, rw)
// Set new session values // Set new session values
session.SetCSRF(csrfToken) session.SetCSRF(csrfToken)
session.SetNonce(nonce) session.SetNonce(nonce)
// Only set code verifier if PKCE is enabled
if t.enablePKCE {
session.SetCodeVerifier(codeVerifier)
}
session.SetIncomingPath(req.URL.RequestURI()) session.SetIncomingPath(req.URL.RequestURI())
// Save the session // Save the session
@@ -748,7 +768,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
} }
// Build and redirect to authentication URL // Build and redirect to authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce) authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
http.Redirect(rw, req, authURL, http.StatusFound) http.Redirect(rw, req, authURL, http.StatusFound)
} }
@@ -760,14 +780,21 @@ func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token) return t.tokenVerifier.VerifyToken(token)
} }
// buildAuthURL constructs the authentication URL // buildAuthURL constructs the authentication URL with optional PKCE support
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
params := url.Values{} params := url.Values{}
params.Set("client_id", t.clientID) params.Set("client_id", t.clientID)
params.Set("response_type", "code") params.Set("response_type", "code")
params.Set("redirect_uri", redirectURL) params.Set("redirect_uri", redirectURL)
params.Set("state", state) params.Set("state", state)
params.Set("nonce", nonce) params.Set("nonce", nonce)
// Add PKCE parameters only if PKCE is enabled and we have a code challenge
if t.enablePKCE && codeChallenge != "" {
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
}
if len(t.scopes) > 0 { if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " ")) params.Set("scope", strings.Join(t.scopes, " "))
} }
+215 -31
View File
@@ -118,7 +118,7 @@ func (ts *TestSuite) Setup() {
} }
// Helper functions used by TraefikOidc // Helper functions used by TraefikOidc
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) { func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
RefreshToken: "test-refresh-token", RefreshToken: "test-refresh-token",
@@ -489,7 +489,7 @@ func TestHandleCallback(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
queryParams string queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error) exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(*SessionData) sessionSetupFunc func(*SessionData)
expectedStatus int expectedStatus int
@@ -497,7 +497,7 @@ func TestHandleCallback(t *testing.T) {
{ {
name: "Success", name: "Success",
queryParams: "?code=test-code&state=test-csrf-token", queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
RefreshToken: "test-refresh-token", RefreshToken: "test-refresh-token",
@@ -527,7 +527,7 @@ func TestHandleCallback(t *testing.T) {
{ {
name: "Exchange Code Error", name: "Exchange Code Error",
queryParams: "?code=test-code&state=test-csrf-token", queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error") return nil, fmt.Errorf("exchange code error")
}, },
sessionSetupFunc: func(session *SessionData) { sessionSetupFunc: func(session *SessionData) {
@@ -539,7 +539,7 @@ func TestHandleCallback(t *testing.T) {
{ {
name: "Missing ID Token", name: "Missing ID Token",
queryParams: "?code=test-code&state=test-csrf-token", queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{}, nil return &TokenResponse{}, nil
}, },
sessionSetupFunc: func(session *SessionData) { sessionSetupFunc: func(session *SessionData) {
@@ -551,7 +551,7 @@ func TestHandleCallback(t *testing.T) {
{ {
name: "Disallowed Email", name: "Disallowed Email",
queryParams: "?code=test-code&state=test-csrf-token", queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
RefreshToken: "test-refresh-token", RefreshToken: "test-refresh-token",
@@ -572,7 +572,7 @@ func TestHandleCallback(t *testing.T) {
{ {
name: "Invalid State Parameter", name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token", queryParams: "?code=test-code&state=invalid-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
RefreshToken: "test-refresh-token", RefreshToken: "test-refresh-token",
@@ -593,7 +593,7 @@ func TestHandleCallback(t *testing.T) {
{ {
name: "Nonce Mismatch", name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token", queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
RefreshToken: "test-refresh-token", RefreshToken: "test-refresh-token",
@@ -614,7 +614,7 @@ func TestHandleCallback(t *testing.T) {
{ {
name: "Missing Nonce in Claims", name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token", queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
RefreshToken: "test-refresh-token", RefreshToken: "test-refresh-token",
@@ -730,7 +730,7 @@ func TestOIDCHandler(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
queryParams string queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error) exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session) sessionSetupFunc func(session *sessions.Session)
expectedStatus int expectedStatus int
@@ -746,7 +746,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token" session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce" session.Values["nonce"] = "test-nonce"
}, },
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange // Simulate token exchange
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
@@ -770,7 +770,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token" session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce" session.Values["nonce"] = "test-nonce"
}, },
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange // Simulate token exchange
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
@@ -793,7 +793,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token" session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce" session.Values["nonce"] = "test-nonce"
}, },
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange // Simulate token exchange
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
@@ -817,7 +817,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token" session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce" session.Values["nonce"] = "test-nonce"
}, },
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange // Simulate token exchange
return &TokenResponse{ return &TokenResponse{
IDToken: ts.token, IDToken: ts.token,
@@ -1664,6 +1664,17 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
} }
// Helper function to compare string slices // Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestExchangeTokensWithRedirects tests the token exchange process with redirects // TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) { func TestExchangeTokensWithRedirects(t *testing.T) {
@@ -1748,7 +1759,7 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
tOidc.tokenURL = server.URL tOidc.tokenURL = server.URL
// Test token exchange // Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback") response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier")
if tc.expectError { if tc.expectError {
if err == nil { if err == nil {
@@ -1770,18 +1781,6 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
} }
} }
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios // TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) { func TestBuildAuthURL(t *testing.T) {
ts := &TestSuite{t: t} ts := &TestSuite{t: t}
@@ -1794,7 +1793,10 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL string redirectURL string
state string state string
nonce string nonce string
enablePKCE bool
codeChallenge string
expectedPrefix string expectedPrefix string
checkPKCE bool
}{ }{
{ {
name: "Absolute Auth URL", name: "Absolute Auth URL",
@@ -1803,7 +1805,10 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL: "https://app.example.com/callback", redirectURL: "https://app.example.com/callback",
state: "test-state", state: "test-state",
nonce: "test-nonce", nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://auth.example.com/oauth/authorize?", expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
}, },
{ {
name: "Relative Auth URL", name: "Relative Auth URL",
@@ -1812,7 +1817,10 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL: "https://app.example.com/callback", redirectURL: "https://app.example.com/callback",
state: "test-state", state: "test-state",
nonce: "test-nonce", nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://logto.example.com/oidc/auth?", expectedPrefix: "https://logto.example.com/oidc/auth?",
checkPKCE: false,
}, },
{ {
name: "Relative Auth URL with Different Issuer", name: "Relative Auth URL with Different Issuer",
@@ -1821,7 +1829,46 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL: "https://app.example.com/callback", redirectURL: "https://app.example.com/callback",
state: "test-state", state: "test-state",
nonce: "test-nonce", nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://auth.example.com:8443/sign-in?", expectedPrefix: "https://auth.example.com:8443/sign-in?",
checkPKCE: false,
},
{
name: "With PKCE Enabled",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: true,
codeChallenge: "test-code-challenge",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: true,
},
{
name: "With PKCE Enabled but No Challenge",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: true,
codeChallenge: "",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
{
name: "With PKCE Disabled but Challenge Provided",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "test-code-challenge",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
}, },
} }
@@ -1831,9 +1878,10 @@ func TestBuildAuthURL(t *testing.T) {
tOidc := ts.tOidc tOidc := ts.tOidc
tOidc.authURL = tc.authURL tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL tOidc.issuerURL = tc.issuerURL
tOidc.enablePKCE = tc.enablePKCE
// Call buildAuthURL // Call buildAuthURL with code challenge
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce) result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge)
// Verify the URL starts with the expected prefix // Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) { if !strings.HasPrefix(result, tc.expectedPrefix) {
@@ -1861,6 +1909,23 @@ func TestBuildAuthURL(t *testing.T) {
} }
} }
// Verify PKCE parameters
if tc.checkPKCE {
if got := query.Get("code_challenge"); got != tc.codeChallenge {
t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got)
}
if got := query.Get("code_challenge_method"); got != "S256" {
t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got)
}
} else {
if got := query.Get("code_challenge"); got != "" {
t.Errorf("Expected no code_challenge, but got %q", got)
}
if got := query.Get("code_challenge_method"); got != "" {
t.Errorf("Expected no code_challenge_method, but got %q", got)
}
}
// Verify scopes are present and correct // Verify scopes are present and correct
if len(tOidc.scopes) > 0 { if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ") expectedScopes := strings.Join(tOidc.scopes, " ")
@@ -1872,6 +1937,125 @@ func TestBuildAuthURL(t *testing.T) {
} }
} }
// TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support
func TestExchangeCodeForToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
enablePKCE bool
codeVerifier string
setupMock func(t *testing.T) *httptest.Server
}{
{
name: "With PKCE Enabled and Code Verifier",
enablePKCE: true,
codeVerifier: "test-code-verifier",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" {
t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
{
name: "With PKCE Disabled but Code Verifier Provided",
enablePKCE: false,
codeVerifier: "test-code-verifier",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is NOT included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
{
name: "With PKCE Enabled but No Code Verifier",
enablePKCE: true,
codeVerifier: "",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is NOT included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupMock(t)
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
tOidc.enablePKCE = tc.enablePKCE
// Test exchangeCodeForToken
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
})
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path. // TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) { func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t} ts := &TestSuite{t: t}
@@ -1879,7 +2063,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
// Create a request with query parameters // Create a request with query parameters
req := httptest.NewRequest("GET", "/protected/resource?param1=value1&param2=value2", nil) req := httptest.NewRequest("GET", "/protected/resource?param1=value1&param2=value2", nil)
rw := httptest.NewRecorder() responseRecorder := httptest.NewRecorder()
// Get session // Get session
session, err := ts.sessionManager.GetSession(req) session, err := ts.sessionManager.GetSession(req)
@@ -1889,7 +2073,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
// Call defaultInitiateAuthentication // Call defaultInitiateAuthentication
redirectURL := "http://example.com/callback" redirectURL := "http://example.com/callback"
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL) ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL)
// Verify that the incoming path includes query parameters // Verify that the incoming path includes query parameters
incomingPath := session.GetIncomingPath() incomingPath := session.GetIncomingPath()
+11
View File
@@ -576,6 +576,17 @@ func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce sd.mainSession.Values["nonce"] = nonce
} }
// GetCodeVerifier retrieves the PKCE code verifier from the session.
func (sd *SessionData) GetCodeVerifier() string {
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
return codeVerifier
}
// SetCodeVerifier stores the PKCE code verifier in the session.
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
sd.mainSession.Values["code_verifier"] = codeVerifier
}
// GetEmail retrieves the authenticated user's email address from the session. // GetEmail retrieves the authenticated user's email address from the session.
func (sd *SessionData) GetEmail() string { func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string) email, _ := sd.mainSession.Values["email"].(string)
+8 -1
View File
@@ -22,6 +22,11 @@ type Config struct {
// If not provided, it will be discovered from provider metadata // If not provided, it will be discovered from provider metadata
RevocationURL string `json:"revocationURL"` RevocationURL string `json:"revocationURL"`
// EnablePKCE enables Proof Key for Code Exchange (PKCE) for the authorization code flow (optional)
// This enhances security but might not be supported by all OIDC providers
// Default: false
EnablePKCE bool `json:"enablePKCE"`
// CallbackURL is the path where the OIDC provider will redirect after authentication (required) // CallbackURL is the path where the OIDC provider will redirect after authentication (required)
// Example: /oauth2/callback // Example: /oauth2/callback
CallbackURL string `json:"callbackURL"` CallbackURL string `json:"callbackURL"`
@@ -103,12 +108,14 @@ const (
// - RateLimit: 100 requests per second // - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/" // - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security) // - ForceHTTPS: true (for security)
// - EnablePKCE: false (PKCE is opt-in)
func CreateConfig() *Config { func CreateConfig() *Config {
c := &Config{ c := &Config{
Scopes: []string{"openid", "profile", "email"}, Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel, LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit, RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default ForceHTTPS: true, // Secure by default
EnablePKCE: false, // PKCE is opt-in
} }
return c return c