Add support for PKCE (#31)

* Add PKCE support.
* Add option to toggle PKCE checks feature.
* GoFMT
This commit is contained in:
2025-03-18 01:09:14 +00:00
committed by GitHub
parent 4ce2815123
commit 4322407129
7 changed files with 373 additions and 51 deletions
+13
View File
@@ -62,6 +62,7 @@ testData:
# Advanced parameters (usually discovered automatically from provider metadata)
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
# Configuration documentation
configuration:
@@ -230,3 +231,15 @@ configuration:
Example: https://accounts.google.com/logout
required: false
enablePKCE:
type: boolean
description: |
Enables PKCE (Proof Key for Code Exchange) for the OAuth 2.0 authorization code flow.
PKCE adds an extra layer of security to protect against authorization code interception attacks.
Not all OIDC providers support PKCE, so this should only be enabled if your provider supports it.
If enabled, the middleware will generate and use a code verifier/challenge pair during authentication.
Default: false
required: false
+43 -7
View File
@@ -69,13 +69,14 @@ The middleware supports the following configuration options:
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| | `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| | `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| | `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
| | `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
## Usage Examples
@@ -233,6 +234,30 @@ spec:
- profile
```
### With PKCE Enabled
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-pkce
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
enablePKCE: true # Enables PKCE for added security
scopes:
- openid
- email
- profile
```
### Keeping Secrets Secret in Kubernetes
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
@@ -378,6 +403,17 @@ http:
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
### PKCE Support
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
PKCE is recommended when:
- Your OIDC provider supports it (most modern providers do)
- You need an additional layer of security for the authorization code flow
- You're concerned about potential authorization code interception attacks
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
### Token Caching and Blacklisting
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
+49 -5
View File
@@ -3,6 +3,7 @@ package traefikoidc
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
@@ -27,6 +28,31 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// generateCodeVerifier creates a cryptographically secure random string
// for use as a PKCE code verifier. The code verifier must be between 43 and 128
// characters long, per the PKCE spec (RFC 7636).
func generateCodeVerifier() (string, error) {
// Using 32 bytes (256 bits) will produce a 43 character base64url string
verifierBytes := make([]byte, 32)
_, err := rand.Read(verifierBytes)
if err != nil {
return "", fmt.Errorf("could not generate code verifier: %w", err)
}
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
}
// deriveCodeChallenge creates a code challenge from a code verifier
// using the SHA-256 method as specified in the PKCE standard (RFC 7636).
func deriveCodeChallenge(codeVerifier string) string {
// Calculate SHA-256 hash of the code verifier
hasher := sha256.New()
hasher.Write([]byte(codeVerifier))
hash := hasher.Sum(nil)
// Base64url encode the hash to get the code challenge
return base64.RawURLEncoding.EncodeToString(hash)
}
// TokenResponse represents the response from the OIDC token endpoint.
// It contains the various tokens and metadata returned after successful
// code exchange or token refresh operations.
@@ -54,7 +80,8 @@ type TokenResponse struct {
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
// - codeOrToken: Either the authorization code or refresh token
// - redirectURL: The callback URL for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
// - codeVerifier: Optional PKCE code verifier for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
@@ -64,6 +91,11 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
if grantType == "authorization_code" {
data.Set("code", codeOrToken)
data.Set("redirect_uri", redirectURL)
// Add code_verifier if PKCE is being used
if codeVerifier != "" {
data.Set("code_verifier", codeVerifier)
}
} else if grantType == "refresh_token" {
data.Set("refresh_token", codeOrToken)
}
@@ -112,7 +144,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
// This is used to refresh access tokens before they expire.
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
@@ -190,7 +222,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
// Get the code verifier from the session for PKCE flow
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
@@ -327,9 +362,18 @@ func (tc *TokenCache) Cleanup() {
}
// exchangeCodeForToken exchanges an authorization code for tokens.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
// It handles PKCE (Proof Key for Code Exchange) based on middleware configuration.
// The code verifier is only included in the token request if PKCE is enabled.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
// Only include code verifier if PKCE is enabled
effectiveCodeVerifier := ""
if t.enablePKCE && codeVerifier != "" {
effectiveCodeVerifier = codeVerifier
}
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
+34 -7
View File
@@ -83,6 +83,7 @@ type TraefikOidc struct {
scopes []string
limiter *rate.Limiter
forceHTTPS bool
enablePKCE bool
scheme string
tokenCache *TokenCache
httpClient *http.Client
@@ -93,7 +94,7 @@ type TraefikOidc struct {
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
@@ -279,7 +280,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
} else {
httpClient = createDefaultHTTPClient()
}
t := &TraefikOidc{
next: next,
name: name,
@@ -302,6 +302,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
enablePKCE: config.EnablePKCE,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
@@ -310,9 +311,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
logger: logger,
}
// Assign the initialized logger
t.logger = logger
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
@@ -732,12 +732,32 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Generate PKCE code verifier and challenge if PKCE is enabled
var codeVerifier, codeChallenge string
if t.enablePKCE {
var err error
codeVerifier, err = generateCodeVerifier()
if err != nil {
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
return
}
// Derive code challenge from verifier
codeChallenge = deriveCodeChallenge(codeVerifier)
}
// Clear any existing session data to avoid stale state causing redirect loops
session.Clear(req, rw)
// Set new session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
// Only set code verifier if PKCE is enabled
if t.enablePKCE {
session.SetCodeVerifier(codeVerifier)
}
session.SetIncomingPath(req.URL.RequestURI())
// Save the session
@@ -748,7 +768,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
}
// Build and redirect to authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -760,14 +780,21 @@ func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
// buildAuthURL constructs the authentication URL
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
// buildAuthURL constructs the authentication URL with optional PKCE support
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", t.clientID)
params.Set("response_type", "code")
params.Set("redirect_uri", redirectURL)
params.Set("state", state)
params.Set("nonce", nonce)
// Add PKCE parameters only if PKCE is enabled and we have a code challenge
if t.enablePKCE && codeChallenge != "" {
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
}
if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " "))
}
+215 -31
View File
@@ -118,7 +118,7 @@ func (ts *TestSuite) Setup() {
}
// Helper functions used by TraefikOidc
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) {
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -489,7 +489,7 @@ func TestHandleCallback(t *testing.T) {
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(*SessionData)
expectedStatus int
@@ -497,7 +497,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Success",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -527,7 +527,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Exchange Code Error",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *SessionData) {
@@ -539,7 +539,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Missing ID Token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *SessionData) {
@@ -551,7 +551,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Disallowed Email",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -572,7 +572,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -593,7 +593,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -614,7 +614,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -730,7 +730,7 @@ func TestOIDCHandler(t *testing.T) {
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
@@ -746,7 +746,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -770,7 +770,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -793,7 +793,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -817,7 +817,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -1664,6 +1664,17 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
}
// Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
func TestExchangeTokensWithRedirects(t *testing.T) {
@@ -1748,7 +1759,7 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
tOidc.tokenURL = server.URL
// Test token exchange
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier")
if tc.expectError {
if err == nil {
@@ -1770,18 +1781,6 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
}
}
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
func TestBuildAuthURL(t *testing.T) {
ts := &TestSuite{t: t}
@@ -1794,7 +1793,10 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL string
state string
nonce string
enablePKCE bool
codeChallenge string
expectedPrefix string
checkPKCE bool
}{
{
name: "Absolute Auth URL",
@@ -1803,7 +1805,10 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
{
name: "Relative Auth URL",
@@ -1812,7 +1817,10 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://logto.example.com/oidc/auth?",
checkPKCE: false,
},
{
name: "Relative Auth URL with Different Issuer",
@@ -1821,7 +1829,46 @@ func TestBuildAuthURL(t *testing.T) {
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "",
expectedPrefix: "https://auth.example.com:8443/sign-in?",
checkPKCE: false,
},
{
name: "With PKCE Enabled",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: true,
codeChallenge: "test-code-challenge",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: true,
},
{
name: "With PKCE Enabled but No Challenge",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: true,
codeChallenge: "",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
{
name: "With PKCE Disabled but Challenge Provided",
authURL: "https://auth.example.com/oauth/authorize",
issuerURL: "https://auth.example.com",
redirectURL: "https://app.example.com/callback",
state: "test-state",
nonce: "test-nonce",
enablePKCE: false,
codeChallenge: "test-code-challenge",
expectedPrefix: "https://auth.example.com/oauth/authorize?",
checkPKCE: false,
},
}
@@ -1831,9 +1878,10 @@ func TestBuildAuthURL(t *testing.T) {
tOidc := ts.tOidc
tOidc.authURL = tc.authURL
tOidc.issuerURL = tc.issuerURL
tOidc.enablePKCE = tc.enablePKCE
// Call buildAuthURL
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
// Call buildAuthURL with code challenge
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge)
// Verify the URL starts with the expected prefix
if !strings.HasPrefix(result, tc.expectedPrefix) {
@@ -1861,6 +1909,23 @@ func TestBuildAuthURL(t *testing.T) {
}
}
// Verify PKCE parameters
if tc.checkPKCE {
if got := query.Get("code_challenge"); got != tc.codeChallenge {
t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got)
}
if got := query.Get("code_challenge_method"); got != "S256" {
t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got)
}
} else {
if got := query.Get("code_challenge"); got != "" {
t.Errorf("Expected no code_challenge, but got %q", got)
}
if got := query.Get("code_challenge_method"); got != "" {
t.Errorf("Expected no code_challenge_method, but got %q", got)
}
}
// Verify scopes are present and correct
if len(tOidc.scopes) > 0 {
expectedScopes := strings.Join(tOidc.scopes, " ")
@@ -1872,6 +1937,125 @@ func TestBuildAuthURL(t *testing.T) {
}
}
// TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support
func TestExchangeCodeForToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
enablePKCE bool
codeVerifier string
setupMock func(t *testing.T) *httptest.Server
}{
{
name: "With PKCE Enabled and Code Verifier",
enablePKCE: true,
codeVerifier: "test-code-verifier",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" {
t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
{
name: "With PKCE Disabled but Code Verifier Provided",
enablePKCE: false,
codeVerifier: "test-code-verifier",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is NOT included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
{
name: "With PKCE Enabled but No Code Verifier",
enablePKCE: true,
codeVerifier: "",
setupMock: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify code_verifier is NOT included
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
}
// Return successful token response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
IDToken: "test.id.token",
AccessToken: "test-access-token",
TokenType: "Bearer",
ExpiresIn: 3600,
RefreshToken: "test-refresh-token",
})
}))
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupMock(t)
defer server.Close()
// Configure the test instance
tOidc := ts.tOidc
tOidc.tokenURL = server.URL
tOidc.enablePKCE = tc.enablePKCE
// Test exchangeCodeForToken
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if response == nil {
t.Error("Expected token response but got nil")
} else if response.IDToken != "test.id.token" {
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
}
})
}
}
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
ts := &TestSuite{t: t}
@@ -1879,7 +2063,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
// Create a request with query parameters
req := httptest.NewRequest("GET", "/protected/resource?param1=value1&param2=value2", nil)
rw := httptest.NewRecorder()
responseRecorder := httptest.NewRecorder()
// Get session
session, err := ts.sessionManager.GetSession(req)
@@ -1889,7 +2073,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
// Call defaultInitiateAuthentication
redirectURL := "http://example.com/callback"
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL)
// Verify that the incoming path includes query parameters
incomingPath := session.GetIncomingPath()
+11
View File
@@ -576,6 +576,17 @@ func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetCodeVerifier retrieves the PKCE code verifier from the session.
func (sd *SessionData) GetCodeVerifier() string {
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
return codeVerifier
}
// SetCodeVerifier stores the PKCE code verifier in the session.
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
sd.mainSession.Values["code_verifier"] = codeVerifier
}
// GetEmail retrieves the authenticated user's email address from the session.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
+8 -1
View File
@@ -22,6 +22,11 @@ type Config struct {
// If not provided, it will be discovered from provider metadata
RevocationURL string `json:"revocationURL"`
// EnablePKCE enables Proof Key for Code Exchange (PKCE) for the authorization code flow (optional)
// This enhances security but might not be supported by all OIDC providers
// Default: false
EnablePKCE bool `json:"enablePKCE"`
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
// Example: /oauth2/callback
CallbackURL string `json:"callbackURL"`
@@ -103,12 +108,14 @@ const (
// - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security)
// - EnablePKCE: false (PKCE is opt-in)
func CreateConfig() *Config {
c := &Config{
Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
ForceHTTPS: true, // Secure by default
EnablePKCE: false, // PKCE is opt-in
}
return c