mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Add support for PKCE (#31)
* Add PKCE support. * Add option to toggle PKCE checks feature. * GoFMT
This commit is contained in:
@@ -62,6 +62,7 @@ testData:
|
|||||||
# Advanced parameters (usually discovered automatically from provider metadata)
|
# Advanced parameters (usually discovered automatically from provider metadata)
|
||||||
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
|
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
|
||||||
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
||||||
|
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
|
||||||
|
|
||||||
# Configuration documentation
|
# Configuration documentation
|
||||||
configuration:
|
configuration:
|
||||||
@@ -230,3 +231,15 @@ configuration:
|
|||||||
|
|
||||||
Example: https://accounts.google.com/logout
|
Example: https://accounts.google.com/logout
|
||||||
required: false
|
required: false
|
||||||
|
|
||||||
|
enablePKCE:
|
||||||
|
type: boolean
|
||||||
|
description: |
|
||||||
|
Enables PKCE (Proof Key for Code Exchange) for the OAuth 2.0 authorization code flow.
|
||||||
|
PKCE adds an extra layer of security to protect against authorization code interception attacks.
|
||||||
|
|
||||||
|
Not all OIDC providers support PKCE, so this should only be enabled if your provider supports it.
|
||||||
|
If enabled, the middleware will generate and use a code verifier/challenge pair during authentication.
|
||||||
|
|
||||||
|
Default: false
|
||||||
|
required: false
|
||||||
|
|||||||
@@ -69,13 +69,14 @@ The middleware supports the following configuration options:
|
|||||||
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
|
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
|
||||||
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
|
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
|
||||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
| | `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
| | `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||||
| `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
|
| | `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
|
||||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
| | `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
| | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
| | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
| | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||||
|
| | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||||
|
|
||||||
## Usage Examples
|
## Usage Examples
|
||||||
|
|
||||||
@@ -233,6 +234,30 @@ spec:
|
|||||||
- profile
|
- profile
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### With PKCE Enabled
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
apiVersion: traefik.io/v1alpha1
|
||||||
|
kind: Middleware
|
||||||
|
metadata:
|
||||||
|
name: oidc-with-pkce
|
||||||
|
namespace: traefik
|
||||||
|
spec:
|
||||||
|
plugin:
|
||||||
|
traefikoidc:
|
||||||
|
providerURL: https://accounts.google.com
|
||||||
|
clientID: 1234567890.apps.googleusercontent.com
|
||||||
|
clientSecret: your-client-secret
|
||||||
|
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||||
|
callbackURL: /oauth2/callback
|
||||||
|
logoutURL: /oauth2/logout
|
||||||
|
enablePKCE: true # Enables PKCE for added security
|
||||||
|
scopes:
|
||||||
|
- openid
|
||||||
|
- email
|
||||||
|
- profile
|
||||||
|
```
|
||||||
|
|
||||||
### Keeping Secrets Secret in Kubernetes
|
### Keeping Secrets Secret in Kubernetes
|
||||||
|
|
||||||
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
|
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
|
||||||
@@ -378,6 +403,17 @@ http:
|
|||||||
|
|
||||||
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
|
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
|
||||||
|
|
||||||
|
### PKCE Support
|
||||||
|
|
||||||
|
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
|
||||||
|
|
||||||
|
PKCE is recommended when:
|
||||||
|
- Your OIDC provider supports it (most modern providers do)
|
||||||
|
- You need an additional layer of security for the authorization code flow
|
||||||
|
- You're concerned about potential authorization code interception attacks
|
||||||
|
|
||||||
|
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
|
||||||
|
|
||||||
### Token Caching and Blacklisting
|
### Token Caching and Blacklisting
|
||||||
|
|
||||||
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
|
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
|
||||||
|
|||||||
+49
-5
@@ -3,6 +3,7 @@ package traefikoidc
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -27,6 +28,31 @@ func generateNonce() (string, error) {
|
|||||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateCodeVerifier creates a cryptographically secure random string
|
||||||
|
// for use as a PKCE code verifier. The code verifier must be between 43 and 128
|
||||||
|
// characters long, per the PKCE spec (RFC 7636).
|
||||||
|
func generateCodeVerifier() (string, error) {
|
||||||
|
// Using 32 bytes (256 bits) will produce a 43 character base64url string
|
||||||
|
verifierBytes := make([]byte, 32)
|
||||||
|
_, err := rand.Read(verifierBytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("could not generate code verifier: %w", err)
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deriveCodeChallenge creates a code challenge from a code verifier
|
||||||
|
// using the SHA-256 method as specified in the PKCE standard (RFC 7636).
|
||||||
|
func deriveCodeChallenge(codeVerifier string) string {
|
||||||
|
// Calculate SHA-256 hash of the code verifier
|
||||||
|
hasher := sha256.New()
|
||||||
|
hasher.Write([]byte(codeVerifier))
|
||||||
|
hash := hasher.Sum(nil)
|
||||||
|
|
||||||
|
// Base64url encode the hash to get the code challenge
|
||||||
|
return base64.RawURLEncoding.EncodeToString(hash)
|
||||||
|
}
|
||||||
|
|
||||||
// TokenResponse represents the response from the OIDC token endpoint.
|
// TokenResponse represents the response from the OIDC token endpoint.
|
||||||
// It contains the various tokens and metadata returned after successful
|
// It contains the various tokens and metadata returned after successful
|
||||||
// code exchange or token refresh operations.
|
// code exchange or token refresh operations.
|
||||||
@@ -54,7 +80,8 @@ type TokenResponse struct {
|
|||||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
|
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
|
||||||
// - codeOrToken: Either the authorization code or refresh token
|
// - codeOrToken: Either the authorization code or refresh token
|
||||||
// - redirectURL: The callback URL for authorization code grant
|
// - redirectURL: The callback URL for authorization code grant
|
||||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
// - codeVerifier: Optional PKCE code verifier for authorization code grant
|
||||||
|
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"grant_type": {grantType},
|
"grant_type": {grantType},
|
||||||
"client_id": {t.clientID},
|
"client_id": {t.clientID},
|
||||||
@@ -64,6 +91,11 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
|||||||
if grantType == "authorization_code" {
|
if grantType == "authorization_code" {
|
||||||
data.Set("code", codeOrToken)
|
data.Set("code", codeOrToken)
|
||||||
data.Set("redirect_uri", redirectURL)
|
data.Set("redirect_uri", redirectURL)
|
||||||
|
|
||||||
|
// Add code_verifier if PKCE is being used
|
||||||
|
if codeVerifier != "" {
|
||||||
|
data.Set("code_verifier", codeVerifier)
|
||||||
|
}
|
||||||
} else if grantType == "refresh_token" {
|
} else if grantType == "refresh_token" {
|
||||||
data.Set("refresh_token", codeOrToken)
|
data.Set("refresh_token", codeOrToken)
|
||||||
}
|
}
|
||||||
@@ -112,7 +144,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
|||||||
// This is used to refresh access tokens before they expire.
|
// This is used to refresh access tokens before they expire.
|
||||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
||||||
}
|
}
|
||||||
@@ -190,7 +222,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
|
// Get the code verifier from the session for PKCE flow
|
||||||
|
codeVerifier := session.GetCodeVerifier()
|
||||||
|
|
||||||
|
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||||
@@ -327,9 +362,18 @@ func (tc *TokenCache) Cleanup() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// exchangeCodeForToken exchanges an authorization code for tokens.
|
// exchangeCodeForToken exchanges an authorization code for tokens.
|
||||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
|
// It handles PKCE (Proof Key for Code Exchange) based on middleware configuration.
|
||||||
|
// The code verifier is only included in the token request if PKCE is enabled.
|
||||||
|
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
|
|
||||||
|
// Only include code verifier if PKCE is enabled
|
||||||
|
effectiveCodeVerifier := ""
|
||||||
|
if t.enablePKCE && codeVerifier != "" {
|
||||||
|
effectiveCodeVerifier = codeVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL, effectiveCodeVerifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ type TraefikOidc struct {
|
|||||||
scopes []string
|
scopes []string
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
forceHTTPS bool
|
forceHTTPS bool
|
||||||
|
enablePKCE bool
|
||||||
scheme string
|
scheme string
|
||||||
tokenCache *TokenCache
|
tokenCache *TokenCache
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
@@ -93,7 +94,7 @@ type TraefikOidc struct {
|
|||||||
allowedUserDomains map[string]struct{}
|
allowedUserDomains map[string]struct{}
|
||||||
allowedRolesAndGroups map[string]struct{}
|
allowedRolesAndGroups map[string]struct{}
|
||||||
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
|
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
|
||||||
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
|
exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||||
initComplete chan struct{}
|
initComplete chan struct{}
|
||||||
endSessionURL string
|
endSessionURL string
|
||||||
@@ -279,7 +280,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
|||||||
} else {
|
} else {
|
||||||
httpClient = createDefaultHTTPClient()
|
httpClient = createDefaultHTTPClient()
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &TraefikOidc{
|
t := &TraefikOidc{
|
||||||
next: next,
|
next: next,
|
||||||
name: name,
|
name: name,
|
||||||
@@ -302,6 +302,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
|||||||
clientID: config.ClientID,
|
clientID: config.ClientID,
|
||||||
clientSecret: config.ClientSecret,
|
clientSecret: config.ClientSecret,
|
||||||
forceHTTPS: config.ForceHTTPS,
|
forceHTTPS: config.ForceHTTPS,
|
||||||
|
enablePKCE: config.EnablePKCE,
|
||||||
scopes: config.Scopes,
|
scopes: config.Scopes,
|
||||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||||
tokenCache: NewTokenCache(),
|
tokenCache: NewTokenCache(),
|
||||||
@@ -310,9 +311,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
|||||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||||
initComplete: make(chan struct{}),
|
initComplete: make(chan struct{}),
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
// Assign the initialized logger
|
|
||||||
t.logger = logger
|
|
||||||
|
|
||||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||||
t.extractClaimsFunc = extractClaims
|
t.extractClaimsFunc = extractClaims
|
||||||
@@ -732,12 +732,32 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate PKCE code verifier and challenge if PKCE is enabled
|
||||||
|
var codeVerifier, codeChallenge string
|
||||||
|
if t.enablePKCE {
|
||||||
|
var err error
|
||||||
|
codeVerifier, err = generateCodeVerifier()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive code challenge from verifier
|
||||||
|
codeChallenge = deriveCodeChallenge(codeVerifier)
|
||||||
|
}
|
||||||
|
|
||||||
// Clear any existing session data to avoid stale state causing redirect loops
|
// Clear any existing session data to avoid stale state causing redirect loops
|
||||||
session.Clear(req, rw)
|
session.Clear(req, rw)
|
||||||
|
|
||||||
// Set new session values
|
// Set new session values
|
||||||
session.SetCSRF(csrfToken)
|
session.SetCSRF(csrfToken)
|
||||||
session.SetNonce(nonce)
|
session.SetNonce(nonce)
|
||||||
|
|
||||||
|
// Only set code verifier if PKCE is enabled
|
||||||
|
if t.enablePKCE {
|
||||||
|
session.SetCodeVerifier(codeVerifier)
|
||||||
|
}
|
||||||
|
|
||||||
session.SetIncomingPath(req.URL.RequestURI())
|
session.SetIncomingPath(req.URL.RequestURI())
|
||||||
|
|
||||||
// Save the session
|
// Save the session
|
||||||
@@ -748,7 +768,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build and redirect to authentication URL
|
// Build and redirect to authentication URL
|
||||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
|
||||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -760,14 +780,21 @@ func (t *TraefikOidc) verifyToken(token string) error {
|
|||||||
return t.tokenVerifier.VerifyToken(token)
|
return t.tokenVerifier.VerifyToken(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildAuthURL constructs the authentication URL
|
// buildAuthURL constructs the authentication URL with optional PKCE support
|
||||||
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Set("client_id", t.clientID)
|
params.Set("client_id", t.clientID)
|
||||||
params.Set("response_type", "code")
|
params.Set("response_type", "code")
|
||||||
params.Set("redirect_uri", redirectURL)
|
params.Set("redirect_uri", redirectURL)
|
||||||
params.Set("state", state)
|
params.Set("state", state)
|
||||||
params.Set("nonce", nonce)
|
params.Set("nonce", nonce)
|
||||||
|
|
||||||
|
// Add PKCE parameters only if PKCE is enabled and we have a code challenge
|
||||||
|
if t.enablePKCE && codeChallenge != "" {
|
||||||
|
params.Set("code_challenge", codeChallenge)
|
||||||
|
params.Set("code_challenge_method", "S256")
|
||||||
|
}
|
||||||
|
|
||||||
if len(t.scopes) > 0 {
|
if len(t.scopes) > 0 {
|
||||||
params.Set("scope", strings.Join(t.scopes, " "))
|
params.Set("scope", strings.Join(t.scopes, " "))
|
||||||
}
|
}
|
||||||
|
|||||||
+215
-31
@@ -118,7 +118,7 @@ func (ts *TestSuite) Setup() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions used by TraefikOidc
|
// Helper functions used by TraefikOidc
|
||||||
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) {
|
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
RefreshToken: "test-refresh-token",
|
RefreshToken: "test-refresh-token",
|
||||||
@@ -489,7 +489,7 @@ func TestHandleCallback(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
queryParams string
|
queryParams string
|
||||||
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
|
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||||
sessionSetupFunc func(*SessionData)
|
sessionSetupFunc func(*SessionData)
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
@@ -497,7 +497,7 @@ func TestHandleCallback(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Success",
|
name: "Success",
|
||||||
queryParams: "?code=test-code&state=test-csrf-token",
|
queryParams: "?code=test-code&state=test-csrf-token",
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
RefreshToken: "test-refresh-token",
|
RefreshToken: "test-refresh-token",
|
||||||
@@ -527,7 +527,7 @@ func TestHandleCallback(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Exchange Code Error",
|
name: "Exchange Code Error",
|
||||||
queryParams: "?code=test-code&state=test-csrf-token",
|
queryParams: "?code=test-code&state=test-csrf-token",
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
return nil, fmt.Errorf("exchange code error")
|
return nil, fmt.Errorf("exchange code error")
|
||||||
},
|
},
|
||||||
sessionSetupFunc: func(session *SessionData) {
|
sessionSetupFunc: func(session *SessionData) {
|
||||||
@@ -539,7 +539,7 @@ func TestHandleCallback(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Missing ID Token",
|
name: "Missing ID Token",
|
||||||
queryParams: "?code=test-code&state=test-csrf-token",
|
queryParams: "?code=test-code&state=test-csrf-token",
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
return &TokenResponse{}, nil
|
return &TokenResponse{}, nil
|
||||||
},
|
},
|
||||||
sessionSetupFunc: func(session *SessionData) {
|
sessionSetupFunc: func(session *SessionData) {
|
||||||
@@ -551,7 +551,7 @@ func TestHandleCallback(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Disallowed Email",
|
name: "Disallowed Email",
|
||||||
queryParams: "?code=test-code&state=test-csrf-token",
|
queryParams: "?code=test-code&state=test-csrf-token",
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
RefreshToken: "test-refresh-token",
|
RefreshToken: "test-refresh-token",
|
||||||
@@ -572,7 +572,7 @@ func TestHandleCallback(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Invalid State Parameter",
|
name: "Invalid State Parameter",
|
||||||
queryParams: "?code=test-code&state=invalid-csrf-token",
|
queryParams: "?code=test-code&state=invalid-csrf-token",
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
RefreshToken: "test-refresh-token",
|
RefreshToken: "test-refresh-token",
|
||||||
@@ -593,7 +593,7 @@ func TestHandleCallback(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Nonce Mismatch",
|
name: "Nonce Mismatch",
|
||||||
queryParams: "?code=test-code&state=test-csrf-token",
|
queryParams: "?code=test-code&state=test-csrf-token",
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
RefreshToken: "test-refresh-token",
|
RefreshToken: "test-refresh-token",
|
||||||
@@ -614,7 +614,7 @@ func TestHandleCallback(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Missing Nonce in Claims",
|
name: "Missing Nonce in Claims",
|
||||||
queryParams: "?code=test-code&state=test-csrf-token",
|
queryParams: "?code=test-code&state=test-csrf-token",
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
RefreshToken: "test-refresh-token",
|
RefreshToken: "test-refresh-token",
|
||||||
@@ -730,7 +730,7 @@ func TestOIDCHandler(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
queryParams string
|
queryParams string
|
||||||
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
|
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||||
sessionSetupFunc func(session *sessions.Session)
|
sessionSetupFunc func(session *sessions.Session)
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
@@ -746,7 +746,7 @@ func TestOIDCHandler(t *testing.T) {
|
|||||||
session.Values["csrf"] = "test-csrf-token"
|
session.Values["csrf"] = "test-csrf-token"
|
||||||
session.Values["nonce"] = "test-nonce"
|
session.Values["nonce"] = "test-nonce"
|
||||||
},
|
},
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
// Simulate token exchange
|
// Simulate token exchange
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
@@ -770,7 +770,7 @@ func TestOIDCHandler(t *testing.T) {
|
|||||||
session.Values["csrf"] = "test-csrf-token"
|
session.Values["csrf"] = "test-csrf-token"
|
||||||
session.Values["nonce"] = "test-nonce"
|
session.Values["nonce"] = "test-nonce"
|
||||||
},
|
},
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
// Simulate token exchange
|
// Simulate token exchange
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
@@ -793,7 +793,7 @@ func TestOIDCHandler(t *testing.T) {
|
|||||||
session.Values["csrf"] = "test-csrf-token"
|
session.Values["csrf"] = "test-csrf-token"
|
||||||
session.Values["nonce"] = "test-nonce"
|
session.Values["nonce"] = "test-nonce"
|
||||||
},
|
},
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
// Simulate token exchange
|
// Simulate token exchange
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
@@ -817,7 +817,7 @@ func TestOIDCHandler(t *testing.T) {
|
|||||||
session.Values["csrf"] = "test-csrf-token"
|
session.Values["csrf"] = "test-csrf-token"
|
||||||
session.Values["nonce"] = "test-nonce"
|
session.Values["nonce"] = "test-nonce"
|
||||||
},
|
},
|
||||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||||
// Simulate token exchange
|
// Simulate token exchange
|
||||||
return &TokenResponse{
|
return &TokenResponse{
|
||||||
IDToken: ts.token,
|
IDToken: ts.token,
|
||||||
@@ -1664,6 +1664,17 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to compare string slices
|
// Helper function to compare string slices
|
||||||
|
func stringSliceEqual(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
|
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
|
||||||
func TestExchangeTokensWithRedirects(t *testing.T) {
|
func TestExchangeTokensWithRedirects(t *testing.T) {
|
||||||
@@ -1748,7 +1759,7 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
|
|||||||
tOidc.tokenURL = server.URL
|
tOidc.tokenURL = server.URL
|
||||||
|
|
||||||
// Test token exchange
|
// Test token exchange
|
||||||
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback")
|
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier")
|
||||||
|
|
||||||
if tc.expectError {
|
if tc.expectError {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -1770,18 +1781,6 @@ func TestExchangeTokensWithRedirects(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func stringSliceEqual(a, b []string) bool {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := range a {
|
|
||||||
if a[i] != b[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
|
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
|
||||||
func TestBuildAuthURL(t *testing.T) {
|
func TestBuildAuthURL(t *testing.T) {
|
||||||
ts := &TestSuite{t: t}
|
ts := &TestSuite{t: t}
|
||||||
@@ -1794,7 +1793,10 @@ func TestBuildAuthURL(t *testing.T) {
|
|||||||
redirectURL string
|
redirectURL string
|
||||||
state string
|
state string
|
||||||
nonce string
|
nonce string
|
||||||
|
enablePKCE bool
|
||||||
|
codeChallenge string
|
||||||
expectedPrefix string
|
expectedPrefix string
|
||||||
|
checkPKCE bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Absolute Auth URL",
|
name: "Absolute Auth URL",
|
||||||
@@ -1803,7 +1805,10 @@ func TestBuildAuthURL(t *testing.T) {
|
|||||||
redirectURL: "https://app.example.com/callback",
|
redirectURL: "https://app.example.com/callback",
|
||||||
state: "test-state",
|
state: "test-state",
|
||||||
nonce: "test-nonce",
|
nonce: "test-nonce",
|
||||||
|
enablePKCE: false,
|
||||||
|
codeChallenge: "",
|
||||||
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||||
|
checkPKCE: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Relative Auth URL",
|
name: "Relative Auth URL",
|
||||||
@@ -1812,7 +1817,10 @@ func TestBuildAuthURL(t *testing.T) {
|
|||||||
redirectURL: "https://app.example.com/callback",
|
redirectURL: "https://app.example.com/callback",
|
||||||
state: "test-state",
|
state: "test-state",
|
||||||
nonce: "test-nonce",
|
nonce: "test-nonce",
|
||||||
|
enablePKCE: false,
|
||||||
|
codeChallenge: "",
|
||||||
expectedPrefix: "https://logto.example.com/oidc/auth?",
|
expectedPrefix: "https://logto.example.com/oidc/auth?",
|
||||||
|
checkPKCE: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Relative Auth URL with Different Issuer",
|
name: "Relative Auth URL with Different Issuer",
|
||||||
@@ -1821,7 +1829,46 @@ func TestBuildAuthURL(t *testing.T) {
|
|||||||
redirectURL: "https://app.example.com/callback",
|
redirectURL: "https://app.example.com/callback",
|
||||||
state: "test-state",
|
state: "test-state",
|
||||||
nonce: "test-nonce",
|
nonce: "test-nonce",
|
||||||
|
enablePKCE: false,
|
||||||
|
codeChallenge: "",
|
||||||
expectedPrefix: "https://auth.example.com:8443/sign-in?",
|
expectedPrefix: "https://auth.example.com:8443/sign-in?",
|
||||||
|
checkPKCE: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "With PKCE Enabled",
|
||||||
|
authURL: "https://auth.example.com/oauth/authorize",
|
||||||
|
issuerURL: "https://auth.example.com",
|
||||||
|
redirectURL: "https://app.example.com/callback",
|
||||||
|
state: "test-state",
|
||||||
|
nonce: "test-nonce",
|
||||||
|
enablePKCE: true,
|
||||||
|
codeChallenge: "test-code-challenge",
|
||||||
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||||
|
checkPKCE: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "With PKCE Enabled but No Challenge",
|
||||||
|
authURL: "https://auth.example.com/oauth/authorize",
|
||||||
|
issuerURL: "https://auth.example.com",
|
||||||
|
redirectURL: "https://app.example.com/callback",
|
||||||
|
state: "test-state",
|
||||||
|
nonce: "test-nonce",
|
||||||
|
enablePKCE: true,
|
||||||
|
codeChallenge: "",
|
||||||
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||||
|
checkPKCE: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "With PKCE Disabled but Challenge Provided",
|
||||||
|
authURL: "https://auth.example.com/oauth/authorize",
|
||||||
|
issuerURL: "https://auth.example.com",
|
||||||
|
redirectURL: "https://app.example.com/callback",
|
||||||
|
state: "test-state",
|
||||||
|
nonce: "test-nonce",
|
||||||
|
enablePKCE: false,
|
||||||
|
codeChallenge: "test-code-challenge",
|
||||||
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
||||||
|
checkPKCE: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1831,9 +1878,10 @@ func TestBuildAuthURL(t *testing.T) {
|
|||||||
tOidc := ts.tOidc
|
tOidc := ts.tOidc
|
||||||
tOidc.authURL = tc.authURL
|
tOidc.authURL = tc.authURL
|
||||||
tOidc.issuerURL = tc.issuerURL
|
tOidc.issuerURL = tc.issuerURL
|
||||||
|
tOidc.enablePKCE = tc.enablePKCE
|
||||||
|
|
||||||
// Call buildAuthURL
|
// Call buildAuthURL with code challenge
|
||||||
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce)
|
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge)
|
||||||
|
|
||||||
// Verify the URL starts with the expected prefix
|
// Verify the URL starts with the expected prefix
|
||||||
if !strings.HasPrefix(result, tc.expectedPrefix) {
|
if !strings.HasPrefix(result, tc.expectedPrefix) {
|
||||||
@@ -1861,6 +1909,23 @@ func TestBuildAuthURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verify PKCE parameters
|
||||||
|
if tc.checkPKCE {
|
||||||
|
if got := query.Get("code_challenge"); got != tc.codeChallenge {
|
||||||
|
t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got)
|
||||||
|
}
|
||||||
|
if got := query.Get("code_challenge_method"); got != "S256" {
|
||||||
|
t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if got := query.Get("code_challenge"); got != "" {
|
||||||
|
t.Errorf("Expected no code_challenge, but got %q", got)
|
||||||
|
}
|
||||||
|
if got := query.Get("code_challenge_method"); got != "" {
|
||||||
|
t.Errorf("Expected no code_challenge_method, but got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Verify scopes are present and correct
|
// Verify scopes are present and correct
|
||||||
if len(tOidc.scopes) > 0 {
|
if len(tOidc.scopes) > 0 {
|
||||||
expectedScopes := strings.Join(tOidc.scopes, " ")
|
expectedScopes := strings.Join(tOidc.scopes, " ")
|
||||||
@@ -1872,6 +1937,125 @@ func TestBuildAuthURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support
|
||||||
|
func TestExchangeCodeForToken(t *testing.T) {
|
||||||
|
ts := &TestSuite{t: t}
|
||||||
|
ts.Setup()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
enablePKCE bool
|
||||||
|
codeVerifier string
|
||||||
|
setupMock func(t *testing.T) *httptest.Server
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "With PKCE Enabled and Code Verifier",
|
||||||
|
enablePKCE: true,
|
||||||
|
codeVerifier: "test-code-verifier",
|
||||||
|
setupMock: func(t *testing.T) *httptest.Server {
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
t.Fatalf("Failed to parse form: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify code_verifier is included
|
||||||
|
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" {
|
||||||
|
t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return successful token response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(TokenResponse{
|
||||||
|
IDToken: "test.id.token",
|
||||||
|
AccessToken: "test-access-token",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
RefreshToken: "test-refresh-token",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "With PKCE Disabled but Code Verifier Provided",
|
||||||
|
enablePKCE: false,
|
||||||
|
codeVerifier: "test-code-verifier",
|
||||||
|
setupMock: func(t *testing.T) *httptest.Server {
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
t.Fatalf("Failed to parse form: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify code_verifier is NOT included
|
||||||
|
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
|
||||||
|
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return successful token response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(TokenResponse{
|
||||||
|
IDToken: "test.id.token",
|
||||||
|
AccessToken: "test-access-token",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
RefreshToken: "test-refresh-token",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "With PKCE Enabled but No Code Verifier",
|
||||||
|
enablePKCE: true,
|
||||||
|
codeVerifier: "",
|
||||||
|
setupMock: func(t *testing.T) *httptest.Server {
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
t.Fatalf("Failed to parse form: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify code_verifier is NOT included
|
||||||
|
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
|
||||||
|
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return successful token response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(TokenResponse{
|
||||||
|
IDToken: "test.id.token",
|
||||||
|
AccessToken: "test-access-token",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
RefreshToken: "test-refresh-token",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
server := tc.setupMock(t)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Configure the test instance
|
||||||
|
tOidc := ts.tOidc
|
||||||
|
tOidc.tokenURL = server.URL
|
||||||
|
tOidc.enablePKCE = tc.enablePKCE
|
||||||
|
|
||||||
|
// Test exchangeCodeForToken
|
||||||
|
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if response == nil {
|
||||||
|
t.Error("Expected token response but got nil")
|
||||||
|
} else if response.IDToken != "test.id.token" {
|
||||||
|
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
|
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
|
||||||
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
||||||
ts := &TestSuite{t: t}
|
ts := &TestSuite{t: t}
|
||||||
@@ -1879,7 +2063,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
|||||||
|
|
||||||
// Create a request with query parameters
|
// Create a request with query parameters
|
||||||
req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil)
|
req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil)
|
||||||
rw := httptest.NewRecorder()
|
responseRecorder := httptest.NewRecorder()
|
||||||
|
|
||||||
// Get session
|
// Get session
|
||||||
session, err := ts.sessionManager.GetSession(req)
|
session, err := ts.sessionManager.GetSession(req)
|
||||||
@@ -1889,7 +2073,7 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
|||||||
|
|
||||||
// Call defaultInitiateAuthentication
|
// Call defaultInitiateAuthentication
|
||||||
redirectURL := "http://example.com/callback"
|
redirectURL := "http://example.com/callback"
|
||||||
ts.tOidc.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL)
|
||||||
|
|
||||||
// Verify that the incoming path includes query parameters
|
// Verify that the incoming path includes query parameters
|
||||||
incomingPath := session.GetIncomingPath()
|
incomingPath := session.GetIncomingPath()
|
||||||
|
|||||||
+11
@@ -576,6 +576,17 @@ func (sd *SessionData) SetNonce(nonce string) {
|
|||||||
sd.mainSession.Values["nonce"] = nonce
|
sd.mainSession.Values["nonce"] = nonce
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCodeVerifier retrieves the PKCE code verifier from the session.
|
||||||
|
func (sd *SessionData) GetCodeVerifier() string {
|
||||||
|
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
|
||||||
|
return codeVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCodeVerifier stores the PKCE code verifier in the session.
|
||||||
|
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||||
|
sd.mainSession.Values["code_verifier"] = codeVerifier
|
||||||
|
}
|
||||||
|
|
||||||
// GetEmail retrieves the authenticated user's email address from the session.
|
// GetEmail retrieves the authenticated user's email address from the session.
|
||||||
func (sd *SessionData) GetEmail() string {
|
func (sd *SessionData) GetEmail() string {
|
||||||
email, _ := sd.mainSession.Values["email"].(string)
|
email, _ := sd.mainSession.Values["email"].(string)
|
||||||
|
|||||||
+8
-1
@@ -22,6 +22,11 @@ type Config struct {
|
|||||||
// If not provided, it will be discovered from provider metadata
|
// If not provided, it will be discovered from provider metadata
|
||||||
RevocationURL string `json:"revocationURL"`
|
RevocationURL string `json:"revocationURL"`
|
||||||
|
|
||||||
|
// EnablePKCE enables Proof Key for Code Exchange (PKCE) for the authorization code flow (optional)
|
||||||
|
// This enhances security but might not be supported by all OIDC providers
|
||||||
|
// Default: false
|
||||||
|
EnablePKCE bool `json:"enablePKCE"`
|
||||||
|
|
||||||
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
|
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
|
||||||
// Example: /oauth2/callback
|
// Example: /oauth2/callback
|
||||||
CallbackURL string `json:"callbackURL"`
|
CallbackURL string `json:"callbackURL"`
|
||||||
@@ -103,12 +108,14 @@ const (
|
|||||||
// - RateLimit: 100 requests per second
|
// - RateLimit: 100 requests per second
|
||||||
// - PostLogoutRedirectURI: "/"
|
// - PostLogoutRedirectURI: "/"
|
||||||
// - ForceHTTPS: true (for security)
|
// - ForceHTTPS: true (for security)
|
||||||
|
// - EnablePKCE: false (PKCE is opt-in)
|
||||||
func CreateConfig() *Config {
|
func CreateConfig() *Config {
|
||||||
c := &Config{
|
c := &Config{
|
||||||
Scopes: []string{"openid", "profile", "email"},
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
LogLevel: DefaultLogLevel,
|
LogLevel: DefaultLogLevel,
|
||||||
RateLimit: DefaultRateLimit,
|
RateLimit: DefaultRateLimit,
|
||||||
ForceHTTPS: true, // Secure by default
|
ForceHTTPS: true, // Secure by default
|
||||||
|
EnablePKCE: false, // PKCE is opt-in
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
|||||||
Reference in New Issue
Block a user