diff --git a/.traefik.yml b/.traefik.yml index b0119c1..b80c2a5 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -62,6 +62,7 @@ testData: # Advanced parameters (usually discovered automatically from provider metadata) revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint + enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security # Configuration documentation configuration: @@ -230,3 +231,15 @@ configuration: Example: https://accounts.google.com/logout required: false + + enablePKCE: + type: boolean + description: | + Enables PKCE (Proof Key for Code Exchange) for the OAuth 2.0 authorization code flow. + PKCE adds an extra layer of security to protect against authorization code interception attacks. + + Not all OIDC providers support PKCE, so this should only be enabled if your provider supports it. + If enabled, the middleware will generate and use a code verifier/challenge pair during authentication. + + Default: false + required: false diff --git a/README.md b/README.md index 7678eb5..1e0175b 100644 --- a/README.md +++ b/README.md @@ -69,13 +69,14 @@ The middleware supports the following configuration options: | `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` | | `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` | | `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` | -| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` | -| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` | -| `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` | -| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` | -| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` | -| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` | -| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` | +| | `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` | +| | `rateLimit` | Sets the maximum number of requests per second | `100` | `500` | +| | `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` | +| | `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` | +| | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` | +| | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` | +| | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` | +| | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` | ## Usage Examples @@ -233,6 +234,30 @@ spec: - profile ``` +### With PKCE Enabled + +```yaml +apiVersion: traefik.io/v1alpha1 +kind: Middleware +metadata: + name: oidc-with-pkce + namespace: traefik +spec: + plugin: + traefikoidc: + providerURL: https://accounts.google.com + clientID: 1234567890.apps.googleusercontent.com + clientSecret: your-client-secret + sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long + callbackURL: /oauth2/callback + logoutURL: /oauth2/logout + enablePKCE: true # Enables PKCE for added security + scopes: + - openid + - email + - profile +``` + ### Keeping Secrets Secret in Kubernetes For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values: @@ -378,6 +403,17 @@ http: The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret. +### PKCE Support + +The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process. + +PKCE is recommended when: +- Your OIDC provider supports it (most modern providers do) +- You need an additional layer of security for the authorization code flow +- You're concerned about potential authorization code interception attacks + +Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature. + ### Token Caching and Blacklisting The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens. diff --git a/helpers.go b/helpers.go index e18ffe2..6d27528 100644 --- a/helpers.go +++ b/helpers.go @@ -3,6 +3,7 @@ package traefikoidc import ( "context" "crypto/rand" + "crypto/sha256" "encoding/base64" "encoding/json" "fmt" @@ -27,6 +28,31 @@ func generateNonce() (string, error) { return base64.URLEncoding.EncodeToString(nonceBytes), nil } +// generateCodeVerifier creates a cryptographically secure random string +// for use as a PKCE code verifier. The code verifier must be between 43 and 128 +// characters long, per the PKCE spec (RFC 7636). +func generateCodeVerifier() (string, error) { + // Using 32 bytes (256 bits) will produce a 43 character base64url string + verifierBytes := make([]byte, 32) + _, err := rand.Read(verifierBytes) + if err != nil { + return "", fmt.Errorf("could not generate code verifier: %w", err) + } + return base64.RawURLEncoding.EncodeToString(verifierBytes), nil +} + +// deriveCodeChallenge creates a code challenge from a code verifier +// using the SHA-256 method as specified in the PKCE standard (RFC 7636). +func deriveCodeChallenge(codeVerifier string) string { + // Calculate SHA-256 hash of the code verifier + hasher := sha256.New() + hasher.Write([]byte(codeVerifier)) + hash := hasher.Sum(nil) + + // Base64url encode the hash to get the code challenge + return base64.RawURLEncoding.EncodeToString(hash) +} + // TokenResponse represents the response from the OIDC token endpoint. // It contains the various tokens and metadata returned after successful // code exchange or token refresh operations. @@ -54,7 +80,8 @@ type TokenResponse struct { // - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token") // - codeOrToken: Either the authorization code or refresh token // - redirectURL: The callback URL for authorization code grant -func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) { +// - codeVerifier: Optional PKCE code verifier for authorization code grant +func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) { data := url.Values{ "grant_type": {grantType}, "client_id": {t.clientID}, @@ -64,6 +91,11 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken if grantType == "authorization_code" { data.Set("code", codeOrToken) data.Set("redirect_uri", redirectURL) + + // Add code_verifier if PKCE is being used + if codeVerifier != "" { + data.Set("code_verifier", codeVerifier) + } } else if grantType == "refresh_token" { data.Set("refresh_token", codeOrToken) } @@ -112,7 +144,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken // This is used to refresh access tokens before they expire. func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { ctx := context.Background() - tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "") + tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "") if err != nil { return nil, fmt.Errorf("failed to refresh token: %w", err) } @@ -190,7 +222,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL) + // Get the code verifier from the session for PKCE flow + codeVerifier := session.GetCodeVerifier() + + tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier) if err != nil { t.logger.Errorf("Failed to exchange code for token: %v", err) http.Error(rw, "Authentication failed", http.StatusInternalServerError) @@ -327,9 +362,18 @@ func (tc *TokenCache) Cleanup() { } // exchangeCodeForToken exchanges an authorization code for tokens. -func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) { +// It handles PKCE (Proof Key for Code Exchange) based on middleware configuration. +// The code verifier is only included in the token request if PKCE is enabled. +func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { ctx := context.Background() - tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL) + + // Only include code verifier if PKCE is enabled + effectiveCodeVerifier := "" + if t.enablePKCE && codeVerifier != "" { + effectiveCodeVerifier = codeVerifier + } + + tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier) if err != nil { return nil, fmt.Errorf("failed to exchange code for token: %w", err) } diff --git a/main.go b/main.go index 7c7cdcc..526d182 100644 --- a/main.go +++ b/main.go @@ -83,6 +83,7 @@ type TraefikOidc struct { scopes []string limiter *rate.Limiter forceHTTPS bool + enablePKCE bool scheme string tokenCache *TokenCache httpClient *http.Client @@ -93,7 +94,7 @@ type TraefikOidc struct { allowedUserDomains map[string]struct{} allowedRolesAndGroups map[string]struct{} initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) - exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error) + exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) initComplete chan struct{} endSessionURL string @@ -279,7 +280,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } else { httpClient = createDefaultHTTPClient() } - t := &TraefikOidc{ next: next, name: name, @@ -302,6 +302,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h clientID: config.ClientID, clientSecret: config.ClientSecret, forceHTTPS: config.ForceHTTPS, + enablePKCE: config.EnablePKCE, scopes: config.Scopes, limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit), tokenCache: NewTokenCache(), @@ -310,9 +311,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h allowedUserDomains: createStringMap(config.AllowedUserDomains), allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups), initComplete: make(chan struct{}), + logger: logger, } - // Assign the initialized logger - t.logger = logger t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) t.extractClaimsFunc = extractClaims @@ -732,12 +732,32 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req return } + // Generate PKCE code verifier and challenge if PKCE is enabled + var codeVerifier, codeChallenge string + if t.enablePKCE { + var err error + codeVerifier, err = generateCodeVerifier() + if err != nil { + http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError) + return + } + + // Derive code challenge from verifier + codeChallenge = deriveCodeChallenge(codeVerifier) + } + // Clear any existing session data to avoid stale state causing redirect loops session.Clear(req, rw) // Set new session values session.SetCSRF(csrfToken) session.SetNonce(nonce) + + // Only set code verifier if PKCE is enabled + if t.enablePKCE { + session.SetCodeVerifier(codeVerifier) + } + session.SetIncomingPath(req.URL.RequestURI()) // Save the session @@ -748,7 +768,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req } // Build and redirect to authentication URL - authURL := t.buildAuthURL(redirectURL, csrfToken, nonce) + authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge) http.Redirect(rw, req, authURL, http.StatusFound) } @@ -760,14 +780,21 @@ func (t *TraefikOidc) verifyToken(token string) error { return t.tokenVerifier.VerifyToken(token) } -// buildAuthURL constructs the authentication URL -func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { +// buildAuthURL constructs the authentication URL with optional PKCE support +func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string { params := url.Values{} params.Set("client_id", t.clientID) params.Set("response_type", "code") params.Set("redirect_uri", redirectURL) params.Set("state", state) params.Set("nonce", nonce) + + // Add PKCE parameters only if PKCE is enabled and we have a code challenge + if t.enablePKCE && codeChallenge != "" { + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + } + if len(t.scopes) > 0 { params.Set("scope", strings.Join(t.scopes, " ")) } diff --git a/main_test.go b/main_test.go index 8f847a5..4535798 100644 --- a/main_test.go +++ b/main_test.go @@ -118,7 +118,7 @@ func (ts *TestSuite) Setup() { } // Helper functions used by TraefikOidc -func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) { +func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", @@ -489,7 +489,7 @@ func TestHandleCallback(t *testing.T) { tests := []struct { name string queryParams string - exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error) + exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) sessionSetupFunc func(*SessionData) expectedStatus int @@ -497,7 +497,7 @@ func TestHandleCallback(t *testing.T) { { name: "Success", queryParams: "?code=test-code&state=test-csrf-token", - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", @@ -527,7 +527,7 @@ func TestHandleCallback(t *testing.T) { { name: "Exchange Code Error", queryParams: "?code=test-code&state=test-csrf-token", - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return nil, fmt.Errorf("exchange code error") }, sessionSetupFunc: func(session *SessionData) { @@ -539,7 +539,7 @@ func TestHandleCallback(t *testing.T) { { name: "Missing ID Token", queryParams: "?code=test-code&state=test-csrf-token", - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{}, nil }, sessionSetupFunc: func(session *SessionData) { @@ -551,7 +551,7 @@ func TestHandleCallback(t *testing.T) { { name: "Disallowed Email", queryParams: "?code=test-code&state=test-csrf-token", - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", @@ -572,7 +572,7 @@ func TestHandleCallback(t *testing.T) { { name: "Invalid State Parameter", queryParams: "?code=test-code&state=invalid-csrf-token", - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", @@ -593,7 +593,7 @@ func TestHandleCallback(t *testing.T) { { name: "Nonce Mismatch", queryParams: "?code=test-code&state=test-csrf-token", - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", @@ -614,7 +614,7 @@ func TestHandleCallback(t *testing.T) { { name: "Missing Nonce in Claims", queryParams: "?code=test-code&state=test-csrf-token", - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", @@ -730,7 +730,7 @@ func TestOIDCHandler(t *testing.T) { tests := []struct { name string queryParams string - exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error) + exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) sessionSetupFunc func(session *sessions.Session) expectedStatus int @@ -746,7 +746,7 @@ func TestOIDCHandler(t *testing.T) { session.Values["csrf"] = "test-csrf-token" session.Values["nonce"] = "test-nonce" }, - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Simulate token exchange return &TokenResponse{ IDToken: ts.token, @@ -770,7 +770,7 @@ func TestOIDCHandler(t *testing.T) { session.Values["csrf"] = "test-csrf-token" session.Values["nonce"] = "test-nonce" }, - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Simulate token exchange return &TokenResponse{ IDToken: ts.token, @@ -793,7 +793,7 @@ func TestOIDCHandler(t *testing.T) { session.Values["csrf"] = "test-csrf-token" session.Values["nonce"] = "test-nonce" }, - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Simulate token exchange return &TokenResponse{ IDToken: ts.token, @@ -817,7 +817,7 @@ func TestOIDCHandler(t *testing.T) { session.Values["csrf"] = "test-csrf-token" session.Values["nonce"] = "test-nonce" }, - exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) { + exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Simulate token exchange return &TokenResponse{ IDToken: ts.token, @@ -1664,6 +1664,17 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { } // Helper function to compare string slices +func stringSliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} // TestExchangeTokensWithRedirects tests the token exchange process with redirects func TestExchangeTokensWithRedirects(t *testing.T) { @@ -1748,7 +1759,7 @@ func TestExchangeTokensWithRedirects(t *testing.T) { tOidc.tokenURL = server.URL // Test token exchange - response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback") + response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier") if tc.expectError { if err == nil { @@ -1770,18 +1781,6 @@ func TestExchangeTokensWithRedirects(t *testing.T) { } } -func stringSliceEqual(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} - // TestBuildAuthURL tests the buildAuthURL function with various URL scenarios func TestBuildAuthURL(t *testing.T) { ts := &TestSuite{t: t} @@ -1794,7 +1793,10 @@ func TestBuildAuthURL(t *testing.T) { redirectURL string state string nonce string + enablePKCE bool + codeChallenge string expectedPrefix string + checkPKCE bool }{ { name: "Absolute Auth URL", @@ -1803,7 +1805,10 @@ func TestBuildAuthURL(t *testing.T) { redirectURL: "https://app.example.com/callback", state: "test-state", nonce: "test-nonce", + enablePKCE: false, + codeChallenge: "", expectedPrefix: "https://auth.example.com/oauth/authorize?", + checkPKCE: false, }, { name: "Relative Auth URL", @@ -1812,7 +1817,10 @@ func TestBuildAuthURL(t *testing.T) { redirectURL: "https://app.example.com/callback", state: "test-state", nonce: "test-nonce", + enablePKCE: false, + codeChallenge: "", expectedPrefix: "https://logto.example.com/oidc/auth?", + checkPKCE: false, }, { name: "Relative Auth URL with Different Issuer", @@ -1821,7 +1829,46 @@ func TestBuildAuthURL(t *testing.T) { redirectURL: "https://app.example.com/callback", state: "test-state", nonce: "test-nonce", + enablePKCE: false, + codeChallenge: "", expectedPrefix: "https://auth.example.com:8443/sign-in?", + checkPKCE: false, + }, + { + name: "With PKCE Enabled", + authURL: "https://auth.example.com/oauth/authorize", + issuerURL: "https://auth.example.com", + redirectURL: "https://app.example.com/callback", + state: "test-state", + nonce: "test-nonce", + enablePKCE: true, + codeChallenge: "test-code-challenge", + expectedPrefix: "https://auth.example.com/oauth/authorize?", + checkPKCE: true, + }, + { + name: "With PKCE Enabled but No Challenge", + authURL: "https://auth.example.com/oauth/authorize", + issuerURL: "https://auth.example.com", + redirectURL: "https://app.example.com/callback", + state: "test-state", + nonce: "test-nonce", + enablePKCE: true, + codeChallenge: "", + expectedPrefix: "https://auth.example.com/oauth/authorize?", + checkPKCE: false, + }, + { + name: "With PKCE Disabled but Challenge Provided", + authURL: "https://auth.example.com/oauth/authorize", + issuerURL: "https://auth.example.com", + redirectURL: "https://app.example.com/callback", + state: "test-state", + nonce: "test-nonce", + enablePKCE: false, + codeChallenge: "test-code-challenge", + expectedPrefix: "https://auth.example.com/oauth/authorize?", + checkPKCE: false, }, } @@ -1831,9 +1878,10 @@ func TestBuildAuthURL(t *testing.T) { tOidc := ts.tOidc tOidc.authURL = tc.authURL tOidc.issuerURL = tc.issuerURL + tOidc.enablePKCE = tc.enablePKCE - // Call buildAuthURL - result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce) + // Call buildAuthURL with code challenge + result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge) // Verify the URL starts with the expected prefix if !strings.HasPrefix(result, tc.expectedPrefix) { @@ -1861,6 +1909,23 @@ func TestBuildAuthURL(t *testing.T) { } } + // Verify PKCE parameters + if tc.checkPKCE { + if got := query.Get("code_challenge"); got != tc.codeChallenge { + t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got) + } + if got := query.Get("code_challenge_method"); got != "S256" { + t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got) + } + } else { + if got := query.Get("code_challenge"); got != "" { + t.Errorf("Expected no code_challenge, but got %q", got) + } + if got := query.Get("code_challenge_method"); got != "" { + t.Errorf("Expected no code_challenge_method, but got %q", got) + } + } + // Verify scopes are present and correct if len(tOidc.scopes) > 0 { expectedScopes := strings.Join(tOidc.scopes, " ") @@ -1872,6 +1937,125 @@ func TestBuildAuthURL(t *testing.T) { } } +// TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support +func TestExchangeCodeForToken(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + tests := []struct { + name string + enablePKCE bool + codeVerifier string + setupMock func(t *testing.T) *httptest.Server + }{ + { + name: "With PKCE Enabled and Code Verifier", + enablePKCE: true, + codeVerifier: "test-code-verifier", + setupMock: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("Failed to parse form: %v", err) + } + + // Verify code_verifier is included + if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" { + t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier) + } + + // Return successful token response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TokenResponse{ + IDToken: "test.id.token", + AccessToken: "test-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "test-refresh-token", + }) + })) + }, + }, + { + name: "With PKCE Disabled but Code Verifier Provided", + enablePKCE: false, + codeVerifier: "test-code-verifier", + setupMock: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("Failed to parse form: %v", err) + } + + // Verify code_verifier is NOT included + if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" { + t.Errorf("Expected no code_verifier, got %s", codeVerifier) + } + + // Return successful token response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TokenResponse{ + IDToken: "test.id.token", + AccessToken: "test-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "test-refresh-token", + }) + })) + }, + }, + { + name: "With PKCE Enabled but No Code Verifier", + enablePKCE: true, + codeVerifier: "", + setupMock: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("Failed to parse form: %v", err) + } + + // Verify code_verifier is NOT included + if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" { + t.Errorf("Expected no code_verifier, got %s", codeVerifier) + } + + // Return successful token response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(TokenResponse{ + IDToken: "test.id.token", + AccessToken: "test-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "test-refresh-token", + }) + })) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := tc.setupMock(t) + defer server.Close() + + // Configure the test instance + tOidc := ts.tOidc + tOidc.tokenURL = server.URL + tOidc.enablePKCE = tc.enablePKCE + + // Test exchangeCodeForToken + response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if response == nil { + t.Error("Expected token response but got nil") + } else if response.IDToken != "test.id.token" { + t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken) + } + }) + } +} + // TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path. func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) { ts := &TestSuite{t: t} @@ -1879,7 +2063,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) { // Create a request with query parameters req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil) - rw := httptest.NewRecorder() + responseRecorder := httptest.NewRecorder() // Get session session, err := ts.sessionManager.GetSession(req) @@ -1889,7 +2073,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) { // Call defaultInitiateAuthentication redirectURL := "http://example.com/callback" - ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL) + ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL) // Verify that the incoming path includes query parameters incomingPath := session.GetIncomingPath() diff --git a/session.go b/session.go index 3877cdb..7e96428 100644 --- a/session.go +++ b/session.go @@ -576,6 +576,17 @@ func (sd *SessionData) SetNonce(nonce string) { sd.mainSession.Values["nonce"] = nonce } +// GetCodeVerifier retrieves the PKCE code verifier from the session. +func (sd *SessionData) GetCodeVerifier() string { + codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string) + return codeVerifier +} + +// SetCodeVerifier stores the PKCE code verifier in the session. +func (sd *SessionData) SetCodeVerifier(codeVerifier string) { + sd.mainSession.Values["code_verifier"] = codeVerifier +} + // GetEmail retrieves the authenticated user's email address from the session. func (sd *SessionData) GetEmail() string { email, _ := sd.mainSession.Values["email"].(string) diff --git a/settings.go b/settings.go index 2d5380a..d16c8ee 100644 --- a/settings.go +++ b/settings.go @@ -22,6 +22,11 @@ type Config struct { // If not provided, it will be discovered from provider metadata RevocationURL string `json:"revocationURL"` + // EnablePKCE enables Proof Key for Code Exchange (PKCE) for the authorization code flow (optional) + // This enhances security but might not be supported by all OIDC providers + // Default: false + EnablePKCE bool `json:"enablePKCE"` + // CallbackURL is the path where the OIDC provider will redirect after authentication (required) // Example: /oauth2/callback CallbackURL string `json:"callbackURL"` @@ -103,12 +108,14 @@ const ( // - RateLimit: 100 requests per second // - PostLogoutRedirectURI: "/" // - ForceHTTPS: true (for security) +// - EnablePKCE: false (PKCE is opt-in) func CreateConfig() *Config { c := &Config{ Scopes: []string{"openid", "profile", "email"}, LogLevel: DefaultLogLevel, RateLimit: DefaultRateLimit, - ForceHTTPS: true, // Secure by default + ForceHTTPS: true, // Secure by default + EnablePKCE: false, // PKCE is opt-in } return c