mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8c5df82dcf | |||
| aa96e9dbee | |||
| 1e33bb0a4d | |||
| bfd702a447 | |||
| 68c150eba4 | |||
| 9cbca4c4fb | |||
| 684a990f59 | |||
| 1b6c8616fd |
@@ -0,0 +1,15 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: lukaszraczylo
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
polar: # Replace with a single Polar username
|
||||
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
||||
thanks_dev: # Replace with a single thanks.dev username
|
||||
custom: https://monzo.me/lukaszraczylo
|
||||
@@ -23,6 +23,19 @@ testData:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
# Alternative: RFC 7523 private_key_jwt client authentication (Entra ID,
|
||||
# Okta, Auth0, Keycloak). Replaces clientSecret with a signed JWT assertion.
|
||||
# See README "Client authentication via private key JWT".
|
||||
# clientAuthMethod: private_key_jwt
|
||||
# clientAssertionKeyID: my-key-2026
|
||||
# clientAssertionAlg: RS256 # default; or PS256/384/512, ES256/384/512
|
||||
# # File path option:
|
||||
# clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
|
||||
# # Or inline PEM (PKCS#8 / PKCS#1 / SEC1):
|
||||
# clientAssertionPrivateKey: |
|
||||
# -----BEGIN PRIVATE KEY-----
|
||||
# MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDexampleexample
|
||||
# -----END PRIVATE KEY-----
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ More example configs in [`examples/`](examples/).
|
||||
|-----------|-------------|
|
||||
| `providerURL` | Issuer URL (used for OIDC discovery). |
|
||||
| `clientID` | OAuth 2.0 client ID. |
|
||||
| `clientSecret` | OAuth 2.0 client secret. Supports `urn:k8s:secret:ns:name:key`. |
|
||||
| `clientSecret` | OAuth 2.0 client secret. Supports `urn:k8s:secret:ns:name:key`. Required when `clientAuthMethod` is unset, `client_secret_post`, or `client_secret_basic`; optional with `private_key_jwt`. |
|
||||
| `sessionEncryptionKey` | Cookie encryption key, **min 32 bytes**. |
|
||||
| `callbackURL` | Callback path, e.g. `/oauth2/callback`. |
|
||||
|
||||
@@ -121,6 +121,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
| `cookiePrefix` | `_oidc_raczylo_` | Unique prefix per middleware instance to isolate sessions. |
|
||||
| `sessionMaxAge` | `86400` | Session lifetime in seconds. |
|
||||
| `refreshGracePeriodSeconds` | `60` | Proactively refresh tokens this many seconds before expiry. |
|
||||
| `maxRefreshTokenAgeSeconds` | `21600` | Heuristic max stored refresh-token lifetime (6h). Past this, the plugin treats the RT as expired without contacting the IdP — returns 401 to AJAX, full re-auth on navigations. Set `0` to disable. Tune to match your IdP's RT TTL. |
|
||||
| `rateLimit` | `100` | Requests/sec. Min `10`. |
|
||||
| `logLevel` | `info` | `debug`, `info`, `error`. |
|
||||
| `audience` | `clientID` | Custom access-token audience (Auth0 custom APIs). |
|
||||
@@ -132,6 +133,11 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
| `stripAuthCookies` | `false` | Strip OIDC cookies from backend hop (mitigates HTTP 431). |
|
||||
| `caCertPath` / `caCertPEM` | none | Trust an internal CA for the provider's TLS. |
|
||||
| `insecureSkipVerify` | `false` | **Local dev only.** Disables TLS verification, logs a security warning. |
|
||||
| `clientAuthMethod` | `client_secret_post` | Client auth method. Set `private_key_jwt` for RFC 7523 JWT assertions (Entra ID, Okta, Auth0, Keycloak). See [Client authentication via private key JWT](#client-authentication-via-private-key-jwt). |
|
||||
| `clientAssertionPrivateKey` | none | Inline PEM private key for `private_key_jwt`. Mutually exclusive with `clientAssertionKeyPath`. |
|
||||
| `clientAssertionKeyPath` | none | File path to PEM private key for `private_key_jwt`. |
|
||||
| `clientAssertionKeyID` | none | JWS `kid` header. Required when `clientAuthMethod=private_key_jwt`; must match the public key registered with the IdP. |
|
||||
| `clientAssertionAlg` | `RS256` | JWS alg for `private_key_jwt`. Supported: `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. |
|
||||
| `enableBackchannelLogout` / `backchannelLogoutURL` | `false` / none | OIDC Back-Channel Logout (server-to-server). |
|
||||
| `enableFrontchannelLogout` / `frontchannelLogoutURL` | `false` / none | OIDC Front-Channel Logout (iframe). |
|
||||
| `redis` | disabled | See [docs/REDIS.md](docs/REDIS.md). |
|
||||
@@ -165,6 +171,22 @@ Each instance must use a unique `cookiePrefix` **and** `sessionEncryptionKey`,
|
||||
otherwise a session minted by one instance can grant access through another.
|
||||
See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87).
|
||||
|
||||
### SSE and WebSocket endpoints
|
||||
|
||||
Browser clients cannot follow an OIDC `302` redirect on an SSE stream or a
|
||||
WebSocket upgrade. The middleware handles this automatically:
|
||||
|
||||
- **SSE** (`Accept: text/event-stream`) and **WebSocket** (`Upgrade: websocket`)
|
||||
requests skip the OIDC redirect.
|
||||
- They are **not** unauthenticated — a valid encrypted session cookie is
|
||||
required, otherwise the request is rejected. The session must already exist
|
||||
(i.e. the user logged in via a normal HTTP page first).
|
||||
- `X-Forwarded-User` is forwarded from the session.
|
||||
- Validation is cookie-only (no JWK fetch), so streaming keeps working during
|
||||
brief IdP outages.
|
||||
|
||||
No configuration needed — this is implicit behavior.
|
||||
|
||||
### HTTP 431 from backends
|
||||
|
||||
Either the ID token or the chunked OIDC cookies overflow your backend's header
|
||||
@@ -196,6 +218,44 @@ caCertPEM: |
|
||||
Both can be combined. An unparseable bundle fails the plugin at startup.
|
||||
See [#125](https://github.com/lukaszraczylo/traefikoidc/issues/125).
|
||||
|
||||
### Client authentication via private key JWT
|
||||
|
||||
Use when your IdP enforces short-lived secrets or pushes secretless client auth
|
||||
— Microsoft Entra ID / Azure AD, Okta, Auth0, Keycloak. Instead of sending a
|
||||
static `clientSecret`, the plugin signs a short-lived JWT and submits it as
|
||||
`client_assertion` per [RFC 7523](https://www.rfc-editor.org/rfc/rfc7523).
|
||||
|
||||
Minimal config:
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
|
||||
clientAssertionKeyID: my-key-2026
|
||||
# clientAssertionAlg: RS256 # default; or PS256/384/512, ES256/384/512
|
||||
```
|
||||
|
||||
Or inline:
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionPrivateKey: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
...
|
||||
-----END PRIVATE KEY-----
|
||||
clientAssertionKeyID: my-key-2026
|
||||
```
|
||||
|
||||
Accepted PEM forms: PKCS#8 (`PRIVATE KEY`), PKCS#1 (`RSA PRIVATE KEY`), SEC1
|
||||
(`EC PRIVATE KEY`). The assertion uses `iss=sub=clientID`, `aud=tokenURL`, 60s
|
||||
lifetime, random hex `jti` per request. Sent on `/token` (auth-code + refresh)
|
||||
and `/revoke`. The `kid` must match the public key registered with the IdP.
|
||||
|
||||
`clientSecret` becomes optional with `private_key_jwt`. Existing
|
||||
`client_secret_post` setups are unaffected. Keys are parsed once at startup —
|
||||
rotation requires a Traefik reload.
|
||||
|
||||
See [issue #135](https://github.com/lukaszraczylo/traefikoidc/issues/135).
|
||||
|
||||
### Environment variable names containing `API`
|
||||
|
||||
Traefik reserves `TRAEFIK_API_*`. User vars whose name contains `API` (e.g.
|
||||
|
||||
+1
-1
@@ -1491,7 +1491,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetUserIdentifier("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
|
||||
+30
-7
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// validateRedirectCount checks if redirect limit is exceeded and handles the error
|
||||
@@ -42,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
|
||||
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
|
||||
// Clear all existing session data
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
@@ -249,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
@@ -289,7 +290,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens to prevent replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
@@ -360,9 +361,31 @@ func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool {
|
||||
return !strings.Contains(accept, "text/html")
|
||||
}
|
||||
|
||||
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
|
||||
// isRefreshTokenExpired checks whether the stored refresh token is likely
|
||||
// past its useful lifetime, using the cookie-side issued_at timestamp set by
|
||||
// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a
|
||||
// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via
|
||||
// MaxRefreshTokenAgeSeconds; 0 disables the check).
|
||||
//
|
||||
// The point of this check is to short-circuit the refresh path BEFORE the
|
||||
// thundering herd hits the IdP for a token the provider has almost certainly
|
||||
// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana-
|
||||
// style polling clients from looping on invalid_grant after a long pause.
|
||||
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
|
||||
// This is a heuristic check - actual implementation would depend on
|
||||
// the specific provider and token metadata
|
||||
return false // Placeholder implementation
|
||||
if t == nil || session == nil {
|
||||
return false
|
||||
}
|
||||
if t.maxRefreshTokenAge <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
issuedAt := session.GetRefreshTokenIssuedAt()
|
||||
if issuedAt.IsZero() {
|
||||
// No timestamp recorded (legacy session pre-dating the issued_at
|
||||
// field). Don't force a re-auth - attempt refresh once and let the
|
||||
// IdP be the source of truth.
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Since(issuedAt) > t.maxRefreshTokenAge
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Pre-populate session with old data
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-access-token-with-many-characters")
|
||||
session.SetRefreshToken("old-refresh-token-with-many-characters")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Verify old data is cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
|
||||
// Verify new data is set
|
||||
s.Equal(csrfToken, session.GetCSRF())
|
||||
@@ -711,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
session, err := sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
session.mainSession.Values["redirect_count"] = 3
|
||||
|
||||
@@ -720,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
|
||||
// Session should be cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
s.Empty(session.GetIDToken())
|
||||
|
||||
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
|
||||
|
||||
@@ -113,6 +113,14 @@ func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedRefreshResultCache returns the short-lived refresh-result cache used
|
||||
// by the refresh path to coalesce grants across Traefik replicas via Redis.
|
||||
func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
func (cm *CacheManager) Close() error {
|
||||
cm.mu.Lock()
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// isSupportedClientAssertionAlg reports whether alg is a recognized JWS
|
||||
// algorithm for private_key_jwt (RFC 7523 §2.2).
|
||||
func isSupportedClientAssertionAlg(alg string) bool {
|
||||
switch alg {
|
||||
case "RS256", "RS384", "RS512",
|
||||
"PS256", "PS384", "PS512",
|
||||
"ES256", "ES384", "ES512":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ClientAssertionSigner builds and signs client_assertion JWTs (RFC 7523 §2.2).
|
||||
type ClientAssertionSigner struct {
|
||||
key crypto.PrivateKey
|
||||
alg string
|
||||
kid string
|
||||
// rand is the entropy source for jti generation and PSS/ECDSA signing.
|
||||
// Defaults to crypto/rand.Reader when nil.
|
||||
rand io.Reader
|
||||
// now returns the current time. Defaults to time.Now when nil.
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// NewClientAssertionSigner parses pemBytes as a private key, validates that
|
||||
// alg is consistent with the key type, and returns a ready-to-use signer.
|
||||
// kid is placed verbatim in the JWS header.
|
||||
//
|
||||
// PEM block types understood:
|
||||
// - "PRIVATE KEY" → PKCS#8 (tried first for all types)
|
||||
// - "RSA PRIVATE KEY" → PKCS#1
|
||||
// - "EC PRIVATE KEY" → SEC1
|
||||
func NewClientAssertionSigner(pemBytes []byte, alg, kid string) (*ClientAssertionSigner, error) {
|
||||
if !isSupportedClientAssertionAlg(alg) {
|
||||
return nil, fmt.Errorf("unsupported client assertion alg %q", alg)
|
||||
}
|
||||
if kid == "" {
|
||||
return nil, fmt.Errorf("kid must not be empty")
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM block found in private key material")
|
||||
}
|
||||
|
||||
var key crypto.PrivateKey
|
||||
var parseErr error
|
||||
|
||||
switch block.Type {
|
||||
case "PRIVATE KEY":
|
||||
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
case "RSA PRIVATE KEY":
|
||||
key, parseErr = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "EC PRIVATE KEY":
|
||||
key, parseErr = x509.ParseECPrivateKey(block.Bytes)
|
||||
default:
|
||||
// Best-effort fallback for unknown block types.
|
||||
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
}
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key (block type %q): %w", block.Type, parseErr)
|
||||
}
|
||||
|
||||
if err := validateAlgKeyMatch(alg, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ClientAssertionSigner{key: key, alg: alg, kid: kid}, nil
|
||||
}
|
||||
|
||||
// validateAlgKeyMatch returns an error when alg implies a key type that does
|
||||
// not match the actual key.
|
||||
func validateAlgKeyMatch(alg string, key crypto.PrivateKey) error {
|
||||
switch alg[0] {
|
||||
case 'R', 'P': // RS* or PS*
|
||||
if _, ok := key.(*rsa.PrivateKey); !ok {
|
||||
return fmt.Errorf("alg %q requires an RSA key, got %T", alg, key)
|
||||
}
|
||||
case 'E': // ES*
|
||||
if _, ok := key.(*ecdsa.PrivateKey); !ok {
|
||||
return fmt.Errorf("alg %q requires an EC key, got %T", alg, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign constructs and returns a signed client_assertion JWT.
|
||||
// audience is typically the token endpoint URL (RFC 7523 §3).
|
||||
// clientID is used as both iss and sub per RFC 7523 §2.2.
|
||||
func (s *ClientAssertionSigner) Sign(audience, clientID string) (string, error) {
|
||||
rander := s.rand
|
||||
if rander == nil {
|
||||
rander = rand.Reader
|
||||
}
|
||||
nowFn := s.now
|
||||
if nowFn == nil {
|
||||
nowFn = time.Now
|
||||
}
|
||||
|
||||
now := nowFn()
|
||||
|
||||
// 16 random bytes as lowercase hex for jti uniqueness.
|
||||
jtiBytes := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rander, jtiBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate jti: %w", err)
|
||||
}
|
||||
jti := hex.EncodeToString(jtiBytes)
|
||||
|
||||
header := map[string]string{
|
||||
"alg": s.alg,
|
||||
"typ": "JWT",
|
||||
"kid": s.kid,
|
||||
}
|
||||
hdrJSON, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal JWT header: %w", err)
|
||||
}
|
||||
|
||||
claims := map[string]any{
|
||||
"iss": clientID,
|
||||
"sub": clientID,
|
||||
"aud": audience,
|
||||
"jti": jti,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(60 * time.Second).Unix(),
|
||||
}
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal JWT claims: %w", err)
|
||||
}
|
||||
|
||||
hdrB64 := base64.RawURLEncoding.EncodeToString(hdrJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signingInput := hdrB64 + "." + claimsB64
|
||||
|
||||
sig, err := s.sign(rander, []byte(signingInput))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signingInput + "." + base64.RawURLEncoding.EncodeToString(sig), nil
|
||||
}
|
||||
|
||||
// sign computes raw signature bytes for signingInput per s.alg.
|
||||
// validateAlgKeyMatch in NewClientAssertionSigner guarantees the key type
|
||||
// matches s.alg, but the comma-ok asserts here keep errcheck happy and
|
||||
// surface internal misuse loudly instead of via panic.
|
||||
func (s *ClientAssertionSigner) sign(rander io.Reader, input []byte) ([]byte, error) {
|
||||
switch s.alg {
|
||||
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
|
||||
rsaKey, ok := s.key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal: alg %q requires *rsa.PrivateKey, got %T", s.alg, s.key)
|
||||
}
|
||||
hash := rsaHashForAlg(s.alg)
|
||||
digest := hashSum(hash, input)
|
||||
if s.alg[0] == 'R' {
|
||||
return signRSAPKCS1v15(rander, rsaKey, hash, digest)
|
||||
}
|
||||
return signRSAPSS(rander, rsaKey, hash, digest)
|
||||
case "ES256", "ES384", "ES512":
|
||||
ecKey, ok := s.key.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal: alg %q requires *ecdsa.PrivateKey, got %T", s.alg, s.key)
|
||||
}
|
||||
hash := ecHashForAlg(s.alg)
|
||||
digest := hashSum(hash, input)
|
||||
return signECDSA(rander, ecKey, digest)
|
||||
}
|
||||
return nil, fmt.Errorf("unhandled alg %q", s.alg)
|
||||
}
|
||||
|
||||
func rsaHashForAlg(alg string) crypto.Hash {
|
||||
switch alg {
|
||||
case "RS256", "PS256":
|
||||
return crypto.SHA256
|
||||
case "RS384", "PS384":
|
||||
return crypto.SHA384
|
||||
case "RS512", "PS512":
|
||||
return crypto.SHA512
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func ecHashForAlg(alg string) crypto.Hash {
|
||||
switch alg {
|
||||
case "ES256":
|
||||
return crypto.SHA256
|
||||
case "ES384":
|
||||
return crypto.SHA384
|
||||
case "ES512":
|
||||
return crypto.SHA512
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func hashSum(h crypto.Hash, input []byte) []byte {
|
||||
switch h {
|
||||
case crypto.SHA256:
|
||||
sum := sha256.Sum256(input)
|
||||
return sum[:]
|
||||
case crypto.SHA384:
|
||||
sum := sha512.Sum384(input)
|
||||
return sum[:]
|
||||
case crypto.SHA512:
|
||||
sum := sha512.Sum512(input)
|
||||
return sum[:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func signRSAPKCS1v15(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
|
||||
sig, err := rsa.SignPKCS1v15(rander, key, hash, digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("RSA PKCS1v15 signing failed: %w", err)
|
||||
}
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
func signRSAPSS(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
|
||||
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hash}
|
||||
sig, err := rsa.SignPSS(rander, key, hash, digest, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("RSA PSS signing failed: %w", err)
|
||||
}
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// signECDSA produces the JWS raw r||s signature (RFC 7515 App. A.3).
|
||||
// Each scalar is zero-padded to (curve.BitSize+7)/8 bytes.
|
||||
func signECDSA(rander io.Reader, key *ecdsa.PrivateKey, digest []byte) ([]byte, error) {
|
||||
r, ss, err := ecdsa.Sign(rander, key, digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ECDSA signing failed: %w", err)
|
||||
}
|
||||
byteLen := (key.Curve.Params().BitSize + 7) / 8
|
||||
sig := make([]byte, 2*byteLen)
|
||||
padBigInt(sig[0:byteLen], r)
|
||||
padBigInt(sig[byteLen:], ss)
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// padBigInt writes n as a fixed-width big-endian integer into buf.
|
||||
func padBigInt(buf []byte, n *big.Int) {
|
||||
b := n.Bytes()
|
||||
copy(buf[len(buf)-len(b):], b)
|
||||
}
|
||||
|
||||
// buildClientAssertionSignerFromConfig loads key material and constructs a
|
||||
// ClientAssertionSigner. Called from NewWithContext when
|
||||
// ClientAuthMethod == "private_key_jwt".
|
||||
func buildClientAssertionSignerFromConfig(config *Config) (*ClientAssertionSigner, error) {
|
||||
var pemBytes []byte
|
||||
|
||||
if config.ClientAssertionPrivateKey != "" {
|
||||
pemBytes = []byte(config.ClientAssertionPrivateKey)
|
||||
} else {
|
||||
data, err := os.ReadFile(config.ClientAssertionKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read clientAssertionKeyPath %q: %w", config.ClientAssertionKeyPath, err)
|
||||
}
|
||||
pemBytes = data
|
||||
}
|
||||
|
||||
alg := config.ClientAssertionAlg
|
||||
if alg == "" {
|
||||
alg = "RS256"
|
||||
}
|
||||
|
||||
return NewClientAssertionSigner(pemBytes, alg, config.ClientAssertionKeyID)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("old-access-token")
|
||||
session.SetRefreshToken("old-refresh-token")
|
||||
session.SetIDToken("old-id-token")
|
||||
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Now perform selective clearing (as done in the fix)
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
|
||||
// Set initial session data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-token")
|
||||
session.SetCSRF("existing-csrf")
|
||||
|
||||
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
|
||||
// NEW BEHAVIOR: Selective clearing
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
|
||||
+158
-1
@@ -5,6 +5,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
|
||||
## Table of Contents
|
||||
|
||||
- [Required Parameters](#required-parameters)
|
||||
- [Client Authentication](#client-authentication)
|
||||
- [Optional Parameters](#optional-parameters)
|
||||
- [Security Options](#security-options)
|
||||
- [Session Management](#session-management)
|
||||
@@ -22,7 +23,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
|
||||
|-----------|------|-------------|---------|
|
||||
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
|
||||
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
|
||||
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
|
||||
| `clientSecret` | string | OAuth 2.0 client secret. Required when `clientAuthMethod` is unset, `client_secret_post`, or `client_secret_basic`. Optional when `clientAuthMethod: private_key_jwt`. | `your-client-secret` |
|
||||
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
|
||||
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
|
||||
|
||||
@@ -45,6 +46,129 @@ spec:
|
||||
|
||||
---
|
||||
|
||||
## Client Authentication
|
||||
|
||||
The middleware supports three client authentication methods at the token and
|
||||
revocation endpoints. The default is `client_secret_post` (current behavior);
|
||||
`private_key_jwt` is opt-in and backwards compatible.
|
||||
|
||||
| Method | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `client_secret_post` | yes | `client_id` + `client_secret` in the request body. |
|
||||
| `client_secret_basic` | no | RFC 6749 §2.3.1 — `client_id` + `client_secret` in the `Authorization: Basic` header (form-urlencoded then base64); not in the body. |
|
||||
| `private_key_jwt` | no | RFC 7523 §2.2 — plugin signs a short-lived JWT with a private key and sends it as `client_assertion`. |
|
||||
|
||||
Select via `clientAuthMethod`:
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
```
|
||||
|
||||
### client_secret_post
|
||||
|
||||
Default. The plugin sends `client_id` and `client_secret` as form parameters
|
||||
in the token / revocation request body. No additional configuration required.
|
||||
|
||||
### private_key_jwt
|
||||
|
||||
Asymmetric client authentication per
|
||||
[RFC 7523 §2.2](https://www.rfc-editor.org/rfc/rfc7523). Use this when your
|
||||
IdP enforces short secret TTLs, when policy mandates secretless clients, or
|
||||
when you want to avoid distributing a shared secret to the proxy.
|
||||
|
||||
For each token / revocation request the plugin builds a JWS with:
|
||||
|
||||
- `iss` = `sub` = `clientID`
|
||||
- `aud` = token endpoint URL
|
||||
- `iat` = now, `exp` = now + 60s
|
||||
- `jti` = random hex per request
|
||||
- `kid` header = `clientAssertionKeyID`
|
||||
|
||||
**Required fields:**
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `clientAuthMethod` | string | `client_secret_post` | Set to `private_key_jwt`. |
|
||||
| `clientAssertionPrivateKey` | string | none | Inline PEM private key. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8, PKCS#1, and SEC1 formats accepted. |
|
||||
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk. Mutually exclusive with `clientAssertionPrivateKey`. |
|
||||
| `clientAssertionKeyID` | string | none | `kid` header inserted in the JWS. Must match the public key registered with the IdP. |
|
||||
| `clientAssertionAlg` | string | `RS256` | One of `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`, `ES256`, `ES384`, `ES512`. |
|
||||
|
||||
When `clientAuthMethod: private_key_jwt`, `clientSecret` is optional.
|
||||
|
||||
**Example — inline PEM:**
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://idp.example.com
|
||||
clientID: my-client-id
|
||||
sessionEncryptionKey: your-32-byte-encryption-key-here
|
||||
callbackURL: /oauth2/callback
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyID: key-2026-01
|
||||
clientAssertionAlg: RS256
|
||||
clientAssertionPrivateKey: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7VJTUt9Us8cKj
|
||||
MZj4ev7QnMa1mYV3Kx1jRkH5YwXQ7N2J2j8K5pP6h0oZmXq1yQv4r8wZb3sH9D2k
|
||||
... (truncated) ...
|
||||
-----END PRIVATE KEY-----
|
||||
```
|
||||
|
||||
**Example — key on disk:**
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
|
||||
clientAssertionKeyID: key-2026-01
|
||||
clientAssertionAlg: RS256
|
||||
```
|
||||
|
||||
**Generating an RS256 key with OpenSSL:**
|
||||
|
||||
```bash
|
||||
openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 \
|
||||
-out client-key.pem
|
||||
openssl rsa -in client-key.pem -pubout -out client-pub.pem
|
||||
```
|
||||
|
||||
Register `client-pub.pem` (or its JWK form) with your IdP under the same
|
||||
`kid` you set in `clientAssertionKeyID`.
|
||||
|
||||
**Notes:**
|
||||
|
||||
- The private key is parsed once at plugin startup. Key rotation requires a
|
||||
Traefik reload.
|
||||
- Assertion lifetime is fixed at 60 seconds.
|
||||
- A fresh random `jti` is generated per request.
|
||||
- The `aud` claim is the token endpoint URL (from discovery).
|
||||
- Tracking issue:
|
||||
[#135](https://github.com/lukaszraczylo/traefikoidc/issues/135).
|
||||
|
||||
### client_secret_basic
|
||||
|
||||
Per [RFC 6749 §2.3.1][rfc6749-2-3-1], the plugin sends the client credentials
|
||||
in an `Authorization: Basic` header instead of the body. Both halves
|
||||
(`client_id`, `client_secret`) are form-urlencoded individually, joined with
|
||||
a colon, then base64-encoded. Use this when your IdP requires Basic auth at
|
||||
the token endpoint and rejects credentials in the body.
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: client_secret_basic
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
```
|
||||
|
||||
[rfc6749-2-3-1]: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
|
||||
|
||||
---
|
||||
|
||||
## Optional Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
@@ -59,6 +183,11 @@ spec:
|
||||
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
|
||||
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
|
||||
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
|
||||
| `clientAuthMethod` | string | `client_secret_post` | Client authentication method at token/revocation endpoints. One of `client_secret_post`, `client_secret_basic`, `private_key_jwt`. See [Client Authentication](#client-authentication). |
|
||||
| `clientAssertionPrivateKey` | string | none | Inline PEM private key for `private_key_jwt`. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8 / PKCS#1 / SEC1. |
|
||||
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk for `private_key_jwt`. Mutually exclusive with `clientAssertionPrivateKey`. |
|
||||
| `clientAssertionKeyID` | string | none | `kid` header for `private_key_jwt` assertions. Required when `clientAuthMethod: private_key_jwt`. |
|
||||
| `clientAssertionAlg` | string | `RS256` | Signing algorithm for `private_key_jwt`. One of `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. |
|
||||
|
||||
### TLS Termination at Load Balancer
|
||||
|
||||
@@ -70,6 +199,33 @@ overwrite it).
|
||||
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
|
||||
dev). Otherwise leave it at default.
|
||||
|
||||
### Streaming Endpoints (SSE and WebSocket)
|
||||
|
||||
The middleware automatically bypasses the OIDC redirect for two request kinds
|
||||
that browsers cannot follow a 302 on:
|
||||
|
||||
| Bypass | Triggered by |
|
||||
|--------|--------------|
|
||||
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
|
||||
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
|
||||
|
||||
These requests do **not** require any explicit configuration — they are
|
||||
handled implicitly. However, the bypass is **not** unauthenticated:
|
||||
|
||||
- A valid, encrypted session cookie is required. Requests without one are
|
||||
rejected (the connection cannot proceed to the backend).
|
||||
- The session cookie is sealed with `sessionEncryptionKey`, so the
|
||||
`authenticated` flag cannot be forged.
|
||||
- Validation is cookie-only — no JWK fetch / signature verification — so
|
||||
streaming endpoints keep working when the OIDC provider is briefly
|
||||
unavailable.
|
||||
- The user identifier from the session is forwarded as `X-Forwarded-User`
|
||||
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
|
||||
|
||||
For browser clients, the user must complete the normal OIDC flow on a
|
||||
regular HTTP page first; the resulting session cookie is then reused on the
|
||||
SSE / WebSocket connection.
|
||||
|
||||
---
|
||||
|
||||
## Security Options
|
||||
@@ -113,6 +269,7 @@ strictAudienceValidation: true
|
||||
|-----------|------|---------|-------------|
|
||||
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
|
||||
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
|
||||
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
|
||||
| `cookieDomain` | string | auto-detected | Domain for session cookies |
|
||||
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
|
||||
|
||||
|
||||
+46
-3
@@ -642,7 +642,7 @@ spec:
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientSecret</code></td>
|
||||
<td class="py-2 px-3">OAuth 2.0 client secret</td>
|
||||
<td class="py-2 px-3">OAuth 2.0 client secret. Only required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is unset or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">sessionEncryptionKey</code></td>
|
||||
@@ -718,6 +718,11 @@ spec:
|
||||
<td class="py-2 px-3">86400</td>
|
||||
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
|
||||
<td class="py-2 px-3">21600</td>
|
||||
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
|
||||
<td class="py-2 px-3">_oidc_raczylo_</td>
|
||||
@@ -748,15 +753,48 @@ spec:
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Require RFC 7662 introspection for opaque tokens</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">disableReplayDetection</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Disable JTI replay detection (for multi-replica without Redis)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code></td>
|
||||
<td class="py-2 px-3">client_secret_post</td>
|
||||
<td class="py-2 px-3">Selects how the plugin authenticates to the token endpoint. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code></td>
|
||||
<td class="py-2 px-3">none</td>
|
||||
<td class="py-2 px-3">Inline PEM private key used to sign client assertions for <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyPath</code></td>
|
||||
<td class="py-2 px-3">none</td>
|
||||
<td class="py-2 px-3">Path to a PEM private key file. Alternative to <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyID</code></td>
|
||||
<td class="py-2 px-3">none</td>
|
||||
<td class="py-2 px-3">JWS <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">kid</code> header value. Required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionAlg</code></td>
|
||||
<td class="py-2 px-3">RS256</td>
|
||||
<td class="py-2 px-3">Signing algorithm for the client assertion. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES512</code>.</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Private Key JWT (RFC 7523)</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Use this when your IdP (Entra ID, Okta, Auth0, Keycloak) pressures short-lived secrets, or when policy mandates secretless service-to-service authentication. The plugin signs a 60-second assertion with the configured private key and sends it as <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_assertion</code> instead of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret</code>. Public-key registration on the IdP replaces shared-secret rotation. See <a href="https://www.rfc-editor.org/rfc/rfc7523" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">RFC 7523</a> and <a href="https://github.com/lukaszraczylo/traefikoidc/issues/135" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">issue #135</a>.</p>
|
||||
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyPath: /etc/traefik/oidc-client.pem
|
||||
clientAssertionKeyID: my-client-key-2026
|
||||
# clientSecret no longer required</code></pre>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Example: Google Workspace with Domain Restriction</h3>
|
||||
|
||||
@@ -858,7 +896,12 @@ spec:
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
@@ -101,6 +101,16 @@ http:
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Optional: switch to RFC 7523 private_key_jwt client auth
|
||||
# (Entra ID, Okta, Auth0, Keycloak). Replaces clientSecret with a
|
||||
# signed JWT assertion. See README for details and PEM formats.
|
||||
# ----------------------------------------------------------------
|
||||
# clientAuthMethod: "private_key_jwt"
|
||||
# clientAssertionKeyPath: "/etc/traefik/oidc/client-key.pem"
|
||||
# clientAssertionKeyID: "prod-key-2026"
|
||||
# clientAssertionAlg: "RS256" # or PS256/384/512, ES256/384/512
|
||||
|
||||
# Session Configuration
|
||||
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
|
||||
sessionMaxAge: 28800 # 8 hours
|
||||
|
||||
+37
-4
@@ -107,9 +107,12 @@ type TokenResponse struct {
|
||||
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant)
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
"client_id": {t.clientID},
|
||||
"client_secret": {t.clientSecret},
|
||||
"grant_type": {grantType},
|
||||
}
|
||||
// client_id is sent in the body for every method except client_secret_basic,
|
||||
// where it is carried in the Authorization header per RFC 6749 §2.3.1.
|
||||
if t.clientAuthMethod != "client_secret_basic" || t.clientAssertion != nil {
|
||||
data.Set("client_id", t.clientID)
|
||||
}
|
||||
|
||||
if grantType == "authorization_code" {
|
||||
@@ -141,16 +144,33 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
}
|
||||
}
|
||||
|
||||
// Read tokenURL with RLock
|
||||
// Read tokenURL with RLock — needed as audience for private_key_jwt (RFC 7523 §3).
|
||||
t.metadataMu.RLock()
|
||||
tokenURL := t.tokenURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
useBasicAuth := false
|
||||
if t.clientAssertion != nil {
|
||||
assertion, err := t.clientAssertion.Sign(tokenURL, t.clientID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign client assertion: %w", err)
|
||||
}
|
||||
data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
|
||||
data.Set("client_assertion", assertion)
|
||||
} else if t.clientAuthMethod == "client_secret_basic" {
|
||||
useBasicAuth = true
|
||||
} else {
|
||||
data.Set("client_secret", t.clientSecret)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if useBasicAuth {
|
||||
setOAuthBasicAuth(req, t.clientID, t.clientSecret)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
@@ -423,6 +443,19 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// setOAuthBasicAuth sets the Authorization header per RFC 6749 §2.3.1: the
|
||||
// client_id and client_secret are form-urlencoded individually, joined with a
|
||||
// colon, then base64-encoded. This differs from http.Request.SetBasicAuth,
|
||||
// which skips the form-urlencode step — that matters for credentials with
|
||||
// reserved characters (`:`, `@`, `+`, `%`, etc.) where the wire format would
|
||||
// otherwise diverge from what the spec mandates.
|
||||
func setOAuthBasicAuth(req *http.Request, clientID, clientSecret string) {
|
||||
user := url.QueryEscape(clientID)
|
||||
pass := url.QueryEscape(clientSecret)
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
}
|
||||
|
||||
// deduplicateScopes removes duplicate scopes from a slice while preserving order.
|
||||
// This ensures that OAuth scope parameters don't contain duplicates which could
|
||||
// cause issues with some authorization servers.
|
||||
|
||||
Vendored
+3
@@ -24,6 +24,7 @@ type Config struct {
|
||||
Type BackendType
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
TLSServerName string
|
||||
PoolSize int
|
||||
RedisDB int
|
||||
CleanupInterval time.Duration
|
||||
@@ -34,6 +35,8 @@ type Config struct {
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
EnableMetrics bool
|
||||
EnableTLS bool
|
||||
TLSSkipVerify bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
|
||||
Vendored
+73
-26
@@ -20,6 +20,7 @@ type HybridBackend struct {
|
||||
ctx context.Context
|
||||
syncWriteCacheTypes map[string]bool
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
l1BackfillBuffer chan *l1BackfillItem
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
l1Hits atomic.Int64
|
||||
@@ -28,6 +29,7 @@ type HybridBackend struct {
|
||||
l1Writes atomic.Int64
|
||||
misses atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
l1BackfillDrops atomic.Int64
|
||||
fallbackMode atomic.Bool
|
||||
}
|
||||
|
||||
@@ -39,6 +41,15 @@ type asyncWriteItem struct {
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// l1BackfillItem represents a deferred write of an L2-resolved value back into
|
||||
// L1. Backfills run on a single bounded worker so a burst of L2 hits cannot
|
||||
// detonate the goroutine count (issue: ~1000% CPU under sustained polling).
|
||||
type l1BackfillItem struct {
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
@@ -114,6 +125,7 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
l1BackfillBuffer: make(chan *l1BackfillItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
@@ -123,6 +135,11 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start L1 backfill worker (single goroutine) to bound goroutine growth on
|
||||
// L2 hits regardless of request rate.
|
||||
h.wg.Add(1)
|
||||
go h.l1BackfillWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
@@ -223,18 +240,10 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
// Populate L1 cache with value from L2 (write-through on read).
|
||||
// Hand off to the bounded backfill worker instead of spawning a goroutine
|
||||
// per read - under burst that would mint thousands of goroutines.
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
@@ -371,6 +380,7 @@ func (h *HybridBackend) Close() error {
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
close(h.l1BackfillBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
@@ -440,13 +450,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
h.queueL1Backfill(key, value, 0) // 0 = primary backend default TTL
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -455,13 +459,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -538,6 +536,55 @@ func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, tt
|
||||
return nil
|
||||
}
|
||||
|
||||
// queueL1Backfill enqueues an L2-resolved value for write-through into L1.
|
||||
// Drops on full buffer to keep the read path constant-time; the next L2 hit
|
||||
// for the same key simply re-queues it.
|
||||
func (h *HybridBackend) queueL1Backfill(key string, value []byte, ttl time.Duration) {
|
||||
select {
|
||||
case h.l1BackfillBuffer <- &l1BackfillItem{key: key, value: value, ttl: ttl}:
|
||||
default:
|
||||
h.l1BackfillDrops.Add(1)
|
||||
h.logger.Debugf("L1 backfill buffer full, dropping for key: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// l1BackfillWorker drains the backfill queue serially. Single worker is
|
||||
// intentional - L1 writes are local and cheap, and serializing them keeps
|
||||
// goroutine count bounded under any read rate.
|
||||
func (h *HybridBackend) l1BackfillWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items best-effort then exit.
|
||||
for len(h.l1BackfillBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.l1BackfillBuffer:
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.primary.Set(writeCtx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.l1BackfillBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
if err := h.primary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", item.key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", item.key)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
+112
@@ -0,0 +1,112 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHybridBackend_L1BackfillBounded verifies that a burst of L2 hits does
|
||||
// not detonate the goroutine count. Pre-fix the code spawned one goroutine
|
||||
// per Get() L2 hit; post-fix all backfills funnel through a single worker.
|
||||
func TestHybridBackend_L1BackfillBounded(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 256,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
const burst = 1000
|
||||
|
||||
// Pre-populate L2 with `burst` distinct keys so each Get triggers a
|
||||
// fresh L1 backfill enqueue.
|
||||
for i := 0; i < burst; i++ {
|
||||
require.NoError(t, secondary.Set(ctx, fmt.Sprintf("k:%d", i), []byte("v"), time.Minute))
|
||||
}
|
||||
|
||||
baseline := runtime.NumGoroutine()
|
||||
|
||||
// Issue the burst as fast as possible; the backfill worker MUST be the
|
||||
// only goroutine doing L1 writes. Allow brief slack for the test runtime
|
||||
// scheduling but anything north of +20 means goroutine leakage.
|
||||
peak := baseline
|
||||
for i := 0; i < burst; i++ {
|
||||
_, _, exists, err := hybrid.Get(ctx, fmt.Sprintf("k:%d", i))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
if g := runtime.NumGoroutine(); g > peak {
|
||||
peak = g
|
||||
}
|
||||
}
|
||||
|
||||
delta := peak - baseline
|
||||
if delta > 20 {
|
||||
t.Fatalf("goroutine count grew by %d during burst (baseline=%d peak=%d); backfill worker not bounding goroutines",
|
||||
delta, baseline, peak)
|
||||
}
|
||||
|
||||
// L1 must eventually catch up via the worker. Worker drains serially so
|
||||
// give it a generous window proportional to the burst size.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
var populated int
|
||||
for i := 0; i < burst; i++ {
|
||||
if _, _, ok, _ := primary.Get(ctx, fmt.Sprintf("k:%d", i)); ok {
|
||||
populated++
|
||||
}
|
||||
}
|
||||
// Be lenient: drops are acceptable under buffer pressure, just want
|
||||
// most of the keys to make it.
|
||||
if populated >= burst-int(hybrid.l1BackfillDrops.Load()) {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("L1 not backfilled within deadline: l2Hits=%d l1Writes=%d drops=%d",
|
||||
hybrid.l2Hits.Load(), hybrid.l1Writes.Load(), hybrid.l1BackfillDrops.Load())
|
||||
}
|
||||
|
||||
// TestHybridBackend_L1BackfillFullDrops verifies the drop semantics when the
|
||||
// buffer is saturated. Drops must be counted, never block, never spawn a
|
||||
// goroutine.
|
||||
func TestHybridBackend_L1BackfillFullDrops(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
// Tiny buffer + slow primary writes via failSet so the worker stays
|
||||
// blocked enough to overflow the buffer.
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 4,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
// Stop the worker from draining: cancel the underlying context so the
|
||||
// worker bails out, leaving us with a cold buffer and the queue method
|
||||
// itself responsible for drop accounting.
|
||||
hybrid.cancel()
|
||||
// Wait for worker to exit so it can't drain.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
hybrid.queueL1Backfill(fmt.Sprintf("k:%d", i), []byte("v"), time.Minute)
|
||||
}
|
||||
|
||||
assert.Greater(t, hybrid.l1BackfillDrops.Load(), int64(0),
|
||||
"expected some drops when buffer is saturated and worker is stopped")
|
||||
}
|
||||
Vendored
+3
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
TLSServerName: config.TLSServerName,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
EnableTLS: config.EnableTLS,
|
||||
TLSSkipVerify: config.TLSSkipVerify,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
|
||||
+25
-3
@@ -2,6 +2,7 @@ package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -31,6 +32,7 @@ type ConnectionPool struct {
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
TLSServerName string // SNI server name; defaults to host(Address) when empty
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
@@ -39,6 +41,8 @@ type PoolConfig struct {
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
|
||||
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
conn, err = p.createConnection(ctx)
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if p.config.EnableTLS {
|
||||
serverName := p.config.TLSServerName
|
||||
if serverName == "" {
|
||||
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
|
||||
serverName = host
|
||||
}
|
||||
}
|
||||
tlsCfg := &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
|
||||
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
} else {
|
||||
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
+230
@@ -0,0 +1,230 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// drainRESPRequest consumes a single RESP request (array or inline) from r and
|
||||
// returns true on success. Any read error returns false.
|
||||
func drainRESPRequest(r *bufio.Reader) bool {
|
||||
header, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(header, "*") {
|
||||
return true // inline command (single line) — already consumed
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
|
||||
if err != nil || n <= 0 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
// Each bulk: "$len\r\n<bytes>\r\n"
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
|
||||
// answer PING with +PONG. Returns the listener address and a self-signed cert.
|
||||
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "localhost"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsCert := tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
|
||||
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stopCh := make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
c, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
defer conn.Close()
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
if !drainRESPRequest(reader) {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("+PONG\r\n"))
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
stop = func() {
|
||||
close(stopCh)
|
||||
_ = listener.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
return listener.Addr().String(), der, stop
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
|
||||
// Regression test for issue #133 (enableTLS not propagated to client).
|
||||
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=false rejects a self-signed server cert.
|
||||
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "tls")
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
|
||||
// fails to handshake against a plain (non-TLS) listener.
|
||||
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
|
||||
plain, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer plain.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, acceptErr := plain.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: plain.Addr().String(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
|
||||
// when EnableTLS=false (default).
|
||||
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
|
||||
// the fix for issue #132: token refresh path hardcoded the "email" claim and
|
||||
// ignored the configured userIdentifierClaim. Keycloak users without an email
|
||||
// claim (using sub or another identifier) were being kicked out on refresh
|
||||
// even though their initial login worked.
|
||||
//
|
||||
// The callback path (auth_flow.go) already honored userIdentifierClaim with
|
||||
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
|
||||
// after PR #100 (commit a316a98).
|
||||
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
name string
|
||||
userIdentifierClaim string
|
||||
expectedIdentifier string
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "sub claim configured, only sub present (Keycloak no-email case)",
|
||||
userIdentifierClaim: "sub",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-keycloak-12345",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user-uuid-keycloak-12345",
|
||||
},
|
||||
{
|
||||
name: "preferred_username configured, claim present",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"preferred_username": "alice",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "alice",
|
||||
},
|
||||
{
|
||||
name: "configured claim missing, falls back to sub",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "fallback-sub-id",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "fallback-sub-id",
|
||||
},
|
||||
{
|
||||
name: "email default, email present (backward compatibility)",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"email": "user@example.com",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "email default, no email and no sub - refresh fails",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: false,
|
||||
expectedIdentifier: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager(
|
||||
"test-encryption-key-32-bytes-long!!",
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
0,
|
||||
NewLogger("error"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("session manager: %v", err)
|
||||
}
|
||||
defer sessionManager.Shutdown()
|
||||
|
||||
capturedClaims := tt.claims
|
||||
tOidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
userIdentifierClaim: tt.userIdentifierClaim,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: &EnhancedMockTokenExchanger{
|
||||
RefreshResponse: &TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
IDToken: "new-id-token-jwt",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
},
|
||||
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
|
||||
extractClaimsFunc: func(token string) (map[string]any, error) {
|
||||
return capturedClaims, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("get session: %v", err)
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
session.SetRefreshToken("initial-refresh-token")
|
||||
|
||||
refreshed := tOidc.refreshToken(rw, req, session)
|
||||
|
||||
if refreshed != tt.expectSuccess {
|
||||
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
|
||||
}
|
||||
|
||||
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
|
||||
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,449 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// signGraphStyleAccessToken builds a JWT in Microsoft's Graph proprietary
|
||||
// nonce-header form: bytes that get signed contain the SHA256 hash of the
|
||||
// nonce, while the wire token ships the original nonce. A standard JWS
|
||||
// verifier always rejects these with `crypto/rsa: verification error`, which
|
||||
// is why Microsoft documents Graph access tokens as opaque to client apps:
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
|
||||
// "you can't validate tokens for Microsoft Graph according to these rules
|
||||
// due to their proprietary format"
|
||||
func signGraphStyleAccessToken(t *testing.T, key *rsa.PrivateKey, kid, originalNonce string, claims map[string]any) string {
|
||||
t.Helper()
|
||||
|
||||
wireHeader := map[string]any{
|
||||
"alg": "RS256",
|
||||
"kid": kid,
|
||||
"typ": "JWT",
|
||||
"nonce": originalNonce,
|
||||
}
|
||||
wireHeaderJSON, err := json.Marshal(wireHeader)
|
||||
require.NoError(t, err)
|
||||
|
||||
hashed := sha256.Sum256([]byte(originalNonce))
|
||||
signedHeader := map[string]any{
|
||||
"alg": "RS256",
|
||||
"kid": kid,
|
||||
"typ": "JWT",
|
||||
"nonce": fmt.Sprintf("%x", hashed),
|
||||
}
|
||||
signedHeaderJSON, err := json.Marshal(signedHeader)
|
||||
require.NoError(t, err)
|
||||
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
wireHeaderB64 := base64.RawURLEncoding.EncodeToString(wireHeaderJSON)
|
||||
signedHeaderB64 := base64.RawURLEncoding.EncodeToString(signedHeaderJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
signedInput := signedHeaderB64 + "." + claimsB64
|
||||
hSign := sha256.Sum256([]byte(signedInput))
|
||||
sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, hSign[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
return wireHeaderB64 + "." + claimsB64 + "." + base64.RawURLEncoding.EncodeToString(sig)
|
||||
}
|
||||
|
||||
// newAzureFollowupOIDC produces a TraefikOidc instance wired for an Azure
|
||||
// AD tenant with a captured error log buffer. Used by the issue #134 followup
|
||||
// tests to assert log behavior during validateAzureTokens flows.
|
||||
func newAzureFollowupOIDC(t *testing.T, jwks *JWKSet) (*TraefikOidc, *bytes.Buffer) {
|
||||
t.Helper()
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(io.Discard, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
jwkCache: &MockJWKCache{JWKS: jwks},
|
||||
tokenCache: tokenCache,
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
oidc.tokenVerifier = oidc
|
||||
oidc.jwtVerifier = oidc
|
||||
require.True(t, oidc.isAzureProvider(), "fixture must be detected as Azure provider")
|
||||
return oidc, errBuf
|
||||
}
|
||||
|
||||
// authedSessionWithTokens returns a SessionData populated with the supplied
|
||||
// access and ID tokens, marked authenticated and recently created. The
|
||||
// SessionManager carries a real ChunkManager so that GetAccessToken /
|
||||
// GetIDToken / GetRefreshToken behave like the production code path.
|
||||
func authedSessionWithTokens(t *testing.T, accessToken, idToken string) *SessionData {
|
||||
t.Helper()
|
||||
|
||||
chunkLogger := NewLogger("error")
|
||||
chunkManager := NewChunkManager(chunkLogger)
|
||||
t.Cleanup(chunkManager.Shutdown)
|
||||
|
||||
sd := CreateMockSessionData()
|
||||
sd.manager = &SessionManager{
|
||||
sessionMaxAge: 24 * time.Hour,
|
||||
chunkManager: chunkManager,
|
||||
logger: chunkLogger,
|
||||
}
|
||||
|
||||
sd.mainSession = sessions.NewSession(nil, "main")
|
||||
sd.mainSession.Values["authenticated"] = true
|
||||
sd.mainSession.Values["created_at"] = time.Now().Unix()
|
||||
|
||||
sd.accessSession = sessions.NewSession(nil, "access")
|
||||
sd.accessSession.Values["token"] = accessToken
|
||||
sd.accessSession.Values["compressed"] = false
|
||||
|
||||
sd.idTokenSession = sessions.NewSession(nil, "id")
|
||||
sd.idTokenSession.Values["token"] = idToken
|
||||
sd.idTokenSession.Values["compressed"] = false
|
||||
|
||||
sd.refreshSession = sessions.NewSession(nil, "refresh")
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
sd.refreshSession.Values["compressed"] = false
|
||||
|
||||
return sd
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_GraphAccessTokenReproducesUsersError sanity-checks
|
||||
// that our crafted Graph-style token reproduces the exact rsa error string
|
||||
// quoted on the issue thread (dada-engineer 2026-05-08, friek 2026-05-11).
|
||||
//
|
||||
// Sanity test: must always pass, regardless of the issue #134 followup fix.
|
||||
// It exists so a future contributor does not accidentally weaken the
|
||||
// reproducer and assume the followup fix is no longer needed.
|
||||
func TestIssue134_Followup_GraphAccessTokenReproducesUsersError(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-followup-kid"
|
||||
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
parsedJWT, err := parseJWT(graphToken)
|
||||
require.NoError(t, err)
|
||||
pubKey := &rsaKey.PublicKey
|
||||
alg, _ := parsedJWT.Header["alg"].(string)
|
||||
verifyErr := verifySignatureWithKey(graphToken, pubKey, alg)
|
||||
require.Error(t, verifyErr)
|
||||
assert.Contains(t, verifyErr.Error(), "crypto/rsa: verification error",
|
||||
"reproducer must emit the exact error string reported on issue #134")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken is the
|
||||
// failing-then-passing test for the followup fix.
|
||||
//
|
||||
// Symptom (before fix): validateAzureTokens calls verifyToken on every
|
||||
// JWT-shaped access token. For Microsoft Graph access tokens (the default
|
||||
// when no custom resource is registered), verification always fails with
|
||||
// `crypto/rsa: verification error`, generating two error log lines per
|
||||
// request:
|
||||
//
|
||||
// UNKNOWN token verification failed: signature verification failed:
|
||||
// crypto/rsa: verification error
|
||||
// DIAGNOSTIC: Signature verification failed for kid=<kid>, alg=RS256:
|
||||
// crypto/rsa: verification error
|
||||
//
|
||||
// Microsoft's own documentation tells client apps not to validate Graph
|
||||
// access tokens. The fix matches that guidance: when an Azure access token
|
||||
// carries Microsoft's proprietary `nonce` JWT header, treat it as opaque
|
||||
// (skip JWT verification, fall through to ID token validation).
|
||||
func TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-followup-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
now := time.Now()
|
||||
exp := now.Add(time.Hour).Unix()
|
||||
|
||||
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-azure-graph", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": exp,
|
||||
"iat": now.Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"appid": "test-client-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": now.Add(-2 * time.Minute).Unix(),
|
||||
"nbf": now.Add(-2 * time.Minute).Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"email": "user@example.com",
|
||||
"nonce": "id-token-oidc-nonce",
|
||||
"jti": "id-token-jti-followup",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, graphAccessToken, idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
|
||||
|
||||
output := errBuf.String()
|
||||
assert.NotContains(t, output, "crypto/rsa: verification error",
|
||||
"validateAzureTokens must not log rsa verification error for Graph-style access tokens; got: %q", output)
|
||||
assert.NotContains(t, output, "DIAGNOSTIC: Signature verification failed",
|
||||
"DIAGNOSTIC line must not fire for Graph-style access tokens; got: %q", output)
|
||||
assert.NotContains(t, output, "UNKNOWN token verification failed",
|
||||
"UNKNOWN classification log must not fire for Graph-style access tokens; got: %q", output)
|
||||
|
||||
assert.True(t, authenticated, "session must remain authenticated via the ID token fallback")
|
||||
assert.False(t, needsRefresh, "valid ID token must not signal a refresh need")
|
||||
assert.False(t, expired, "valid ID token must not be reported as expired")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection covers the
|
||||
// classifier added by the followup fix. Pure-function unit test for the
|
||||
// Microsoft proprietary marker we rely on (nonce in JWT header).
|
||||
func TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-detection-kid"
|
||||
standardToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
oidc, _ := newAzureFollowupOIDC(t, &JWKSet{})
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
token string
|
||||
wantUnverified bool
|
||||
}{
|
||||
{name: "standard JWT without nonce header", token: standardToken, wantUnverified: false},
|
||||
{name: "Microsoft proprietary token (nonce in header)", token: graphToken, wantUnverified: true},
|
||||
{name: "garbage token treated as unverifiable", token: "not-a-jwt-at-all", wantUnverified: true},
|
||||
{name: "empty token treated as unverifiable", token: "", wantUnverified: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := oidc.isUnverifiableAzureAccessToken(tc.token)
|
||||
assert.Equal(t, tc.wantUnverified, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_StandardAzureAccessTokenStillVerifies guards against
|
||||
// regression in the happy path: an access token issued for our own clientID
|
||||
// (custom Azure-registered API) — no proprietary nonce header, signed normally
|
||||
// — must still flow through the standard verification path and authenticate.
|
||||
func TestIssue134_Followup_StandardAzureAccessTokenStillVerifies(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-standard-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
now := time.Now()
|
||||
exp := now.Add(time.Hour).Unix()
|
||||
|
||||
// Custom-resource access token: aud points to the app, no nonce header.
|
||||
accessToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": now.Add(-2 * time.Minute).Unix(),
|
||||
"nbf": now.Add(-2 * time.Minute).Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "api.read",
|
||||
"jti": "standard-access-jti",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": now.Add(-2 * time.Minute).Unix(),
|
||||
"nbf": now.Add(-2 * time.Minute).Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"email": "user@example.com",
|
||||
"nonce": "id-token-oidc-nonce",
|
||||
"jti": "standard-id-jti",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, accessToken, idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
|
||||
|
||||
assert.True(t, authenticated, "standard Azure access token must verify and authenticate")
|
||||
assert.False(t, needsRefresh)
|
||||
assert.False(t, expired)
|
||||
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error",
|
||||
"standard Azure token must not produce signature errors")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_GraphAccessTokenWithoutIDToken covers the edge where
|
||||
// the session has only a Graph access token (no ID token). The classifier must
|
||||
// preserve the existing "treat as opaque" semantics for backward compatibility:
|
||||
// authenticated=true even when there is no ID token to verify.
|
||||
func TestIssue134_Followup_GraphAccessTokenWithoutIDToken(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-no-idt-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-no-idt", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, graphAccessToken, "")
|
||||
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
|
||||
|
||||
assert.True(t, authenticated, "Graph token without ID token must remain authenticated (matches existing opaque-token semantics)")
|
||||
assert.False(t, needsRefresh)
|
||||
assert.False(t, expired)
|
||||
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification proves
|
||||
// the classifier is not a security regression. An attacker who forges a JWT
|
||||
// with a `nonce` JWT header (Microsoft's proprietary marker) but a payload
|
||||
// claiming `aud=our-clientID` should NOT gain authenticated status simply by
|
||||
// triggering the "treat as opaque" branch.
|
||||
//
|
||||
// This is the confused-deputy guardrail Microsoft warns about
|
||||
// (https://cwe.mitre.org/data/definitions/441.html): we treat the access token
|
||||
// as opaque, which means we DO NOT authorize from it — authorization comes
|
||||
// only from a separately verifiable ID token. An attacker without a valid ID
|
||||
// token must not be authenticated.
|
||||
func TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
attackerKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-attack-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
// Forged: attacker uses their OWN key, sets aud = our clientID, plants a
|
||||
// `nonce` header to trip the opaque-detection path.
|
||||
forgedAccessToken := signGraphStyleAccessToken(t, attackerKey, kid, "attacker-nonce", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "attacker",
|
||||
"scp": "admin",
|
||||
})
|
||||
|
||||
// Forged ID token signed with the attacker's key — must fail verification
|
||||
// against the tenant JWKS.
|
||||
forgedIDToken, err := createTestJWT(attackerKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Minute).Unix(),
|
||||
"nbf": time.Now().Add(-2 * time.Minute).Unix(),
|
||||
"sub": "attacker",
|
||||
"email": "attacker@evil.example",
|
||||
"nonce": "id-token-oidc-nonce",
|
||||
"jti": "attacker-id-jti",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc, _ := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, forgedAccessToken, forgedIDToken)
|
||||
|
||||
authenticated, _, _ := oidc.validateAzureTokens(session)
|
||||
assert.False(t, authenticated,
|
||||
"attacker's forged tokens must not authenticate even when the access token has a nonce header — ID token verification rejects the wrong-key signature")
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError reproduces and
|
||||
// verifies the fix for issue #134.
|
||||
//
|
||||
// Symptom (before fix): with a Redis backend wired into UniversalCache,
|
||||
// caching the parsed *parsedJWKS triggered:
|
||||
//
|
||||
// json: cannot unmarshal number 2251513...
|
||||
// into Go value of type float64
|
||||
//
|
||||
// Root cause: under yaegi, json.Marshal of a struct exposes unexported
|
||||
// fields with an X-prefixed name. parsedJWKS{ keys map[string]crypto.PublicKey }
|
||||
// thus serialized the inner *rsa.PublicKey, whose modulus *big.Int marshals
|
||||
// as a JSON number hundreds of digits long. On read, json.Unmarshal into
|
||||
// interface{} parses numbers as float64, which cannot represent that range.
|
||||
// The user saw the error log on every request even though auth still worked
|
||||
// (fallback path rebuilt the keys in memory).
|
||||
//
|
||||
// Fix: route both *JWKSet and *parsedJWKS through SetLocal/GetLocal — the
|
||||
// distributed backend never sees them.
|
||||
func TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "issue134:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
const kid = "azure-test-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
|
||||
}
|
||||
|
||||
var fetchCount int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&fetchCount, 1)
|
||||
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
infoBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(infoBuf, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
jwkCache := &JWKCache{cache: cache}
|
||||
ctx := context.Background()
|
||||
|
||||
pub1, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err, "first GetPublicKey should succeed")
|
||||
require.NotNil(t, pub1)
|
||||
gotRSA, ok := pub1.(*rsa.PublicKey)
|
||||
require.True(t, ok, "returned key should be *rsa.PublicKey, got %T", pub1)
|
||||
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N), "modulus must survive intact")
|
||||
assert.Equal(t, rsaKey.E, gotRSA.E, "exponent must survive intact")
|
||||
|
||||
pub2, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err, "second GetPublicKey should succeed")
|
||||
require.True(t, samePublicKey(pub1, pub2), "second call must return the same parsed key (cache hit)")
|
||||
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&fetchCount),
|
||||
"upstream JWKS endpoint must be hit exactly once; second call must be served from local cache")
|
||||
|
||||
errOutput := errBuf.String()
|
||||
assert.NotContains(t, errOutput, "Failed to deserialize",
|
||||
"deserialize error must not appear with the fix in place; got: %s", errOutput)
|
||||
assert.NotContains(t, errOutput, "into Go value of type float64",
|
||||
"float64 unmarshal error must not appear; got: %s", errOutput)
|
||||
|
||||
parsedKey := server.URL + parsedKeysSuffix
|
||||
jwksKey := server.URL
|
||||
for _, k := range []string{cache.prefixKey(parsedKey), cache.prefixKey(jwksKey)} {
|
||||
fullKey := redisCfg.RedisPrefix + k
|
||||
assert.False(t, mr.Exists(fullKey),
|
||||
"key %q must not exist in Redis (local-only caching); got %v", fullKey, mr.Keys())
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue134_StalePoisonedRedisDataIgnored verifies that pre-existing bad
|
||||
// data left in Redis under a JWK :parsed key from a prior buggy version is
|
||||
// ignored: the local-only fix never reads that key, so no log spam, and the
|
||||
// fallback path returns a real *rsa.PublicKey.
|
||||
func TestIssue134_StalePoisonedRedisDataIgnored(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "issue134stale:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
const kid = "azure-test-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pre-poison Redis with the kind of payload the old buggy path would have
|
||||
// produced (huge unquoted JSON number for the modulus). With the fix the
|
||||
// JWKCache must not even read this key.
|
||||
poisoned := []byte("\x01" + strings.Replace(
|
||||
`{"Xkeys":{"azure-test-kid":{"N":NUMBER,"E":65537}}}`,
|
||||
"NUMBER", rsaKey.N.String(), 1,
|
||||
))
|
||||
parsedRedisKey := redisCfg.RedisPrefix + "jwk:" + server.URL + parsedKeysSuffix
|
||||
require.NoError(t, mr.Set(parsedRedisKey, string(poisoned)))
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(io.Discard, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
jwkCache := &JWKCache{cache: cache}
|
||||
pub, err := jwkCache.GetPublicKey(context.Background(), server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pub)
|
||||
gotRSA, ok := pub.(*rsa.PublicKey)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N))
|
||||
|
||||
assert.NotContains(t, errBuf.String(), "Failed to deserialize",
|
||||
"poisoned Redis entry must not be touched; got error log: %s", errBuf.String())
|
||||
}
|
||||
|
||||
// TestIssue134_SetLocalGetLocalSkipBackend verifies the new SetLocal/GetLocal
|
||||
// pair never reads or writes the configured backend.
|
||||
func TestIssue134_SetLocalGetLocalSkipBackend(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "local:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 10,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
type unsafeShape struct {
|
||||
hidden map[string]interface{}
|
||||
}
|
||||
val := &unsafeShape{hidden: map[string]interface{}{"k": 1}}
|
||||
|
||||
require.NoError(t, cache.SetLocal("local-key", val, 1*time.Hour))
|
||||
|
||||
got, found := cache.GetLocal("local-key")
|
||||
require.True(t, found)
|
||||
assert.Same(t, val, got, "GetLocal must return the exact pointer stored, no JSON round-trip")
|
||||
|
||||
for _, k := range mr.Keys() {
|
||||
assert.NotContains(t, k, "local-key",
|
||||
"SetLocal must not write to Redis; found key %q (all keys: %v)", k, mr.Keys())
|
||||
}
|
||||
|
||||
cache.mu.Lock()
|
||||
delete(cache.items, "local-key")
|
||||
cache.lruList.Init()
|
||||
cache.currentSize = 0
|
||||
cache.currentMemory = 0
|
||||
cache.mu.Unlock()
|
||||
|
||||
_, found = cache.GetLocal("local-key")
|
||||
assert.False(t, found, "GetLocal must not fall back to backend after local cache cleared")
|
||||
}
|
||||
|
||||
// big2bytes returns the big-endian byte slice for a positive int.
|
||||
func big2bytes(e int) []byte {
|
||||
if e <= 0 {
|
||||
return []byte{}
|
||||
}
|
||||
var buf []byte
|
||||
for e > 0 {
|
||||
buf = append([]byte{byte(e & 0xff)}, buf...)
|
||||
e >>= 8
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
// samePublicKey reports whether two crypto.PublicKey instances represent the
|
||||
// same RSA key, used to confirm cache hits return identical reconstructed
|
||||
// keys.
|
||||
func samePublicKey(a, b interface{}) bool {
|
||||
ar, ok1 := a.(*rsa.PublicKey)
|
||||
br, ok2 := b.(*rsa.PublicKey)
|
||||
if !ok1 || !ok2 {
|
||||
return false
|
||||
}
|
||||
return ar.N.Cmp(br.N) == 0 && ar.E == br.E
|
||||
}
|
||||
@@ -0,0 +1,925 @@
|
||||
package traefikoidc
|
||||
|
||||
// issue135_regression_test.go — regression tests for RFC 7523 private_key_jwt
|
||||
// client authentication (issue #135).
|
||||
//
|
||||
// These tests guard:
|
||||
// - Correct JWT construction and cryptographic signature for all supported
|
||||
// algorithms (RS*/PS*/ES*).
|
||||
// - Proper validation of alg/key type combinations and empty-kid rejection.
|
||||
// - JTI uniqueness across concurrent calls.
|
||||
// - PEM variant tolerance (PKCS#8, PKCS#1, SEC1).
|
||||
// - Config.Validate() behavior for all private_key_jwt configuration paths.
|
||||
// - buildClientAssertionSignerFromConfig: inline PEM, file-backed PEM, default alg.
|
||||
// - Wire-up in exchangeTokens: assertion fields sent, client_secret absent.
|
||||
// - Wire-up in RevokeTokenWithProvider: assertion fields sent, audience = tokenURL.
|
||||
// - Back-compat: client_secret_post path unchanged when clientAssertion == nil.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ── A. Signer unit tests ──────────────────────────────────────────────────────
|
||||
|
||||
// TestIssue135_SignerRSAFamily verifies that NewClientAssertionSigner + Sign
|
||||
// produces a well-formed, cryptographically valid JWT for every RSA-family
|
||||
// algorithm (RS256/RS384/RS512/PS256/PS384/PS512).
|
||||
func TestIssue135_SignerRSAFamily(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
cases := []struct {
|
||||
alg string
|
||||
hashFn func([]byte) []byte
|
||||
isPS bool
|
||||
hash crypto.Hash
|
||||
}{
|
||||
{"RS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, false, crypto.SHA256},
|
||||
{"RS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, false, crypto.SHA384},
|
||||
{"RS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, false, crypto.SHA512},
|
||||
{"PS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, true, crypto.SHA256},
|
||||
{"PS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, true, crypto.SHA384},
|
||||
{"PS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, true, crypto.SHA512},
|
||||
}
|
||||
|
||||
const (
|
||||
audience = "https://example.com/token"
|
||||
clientID = "client-abc"
|
||||
kid = "kid-1"
|
||||
)
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.alg, func(t *testing.T) {
|
||||
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtStr, err := signer.Sign(audience, clientID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3, "JWT must have three dot-separated parts")
|
||||
|
||||
// Decode and check header.
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, tc.alg, hdr["alg"])
|
||||
assert.Equal(t, "JWT", hdr["typ"])
|
||||
assert.Equal(t, kid, hdr["kid"])
|
||||
|
||||
// Decode and check claims.
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
assert.Equal(t, clientID, clms["iss"])
|
||||
assert.Equal(t, clientID, clms["sub"])
|
||||
assert.Equal(t, audience, clms["aud"])
|
||||
|
||||
iat, ok := clms["iat"].(float64)
|
||||
require.True(t, ok, "iat must be numeric")
|
||||
exp, ok := clms["exp"].(float64)
|
||||
require.True(t, ok, "exp must be numeric")
|
||||
assert.InDelta(t, 60, exp-iat, 2, "exp-iat must equal ~60s")
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
assert.True(t, iat <= now+2 && iat >= now-5, "iat must be current time ±5s")
|
||||
|
||||
jti, ok := clms["jti"].(string)
|
||||
require.True(t, ok, "jti must be a string")
|
||||
assert.Len(t, jti, 32, "jti must be 32-char hex (16 bytes → hex)")
|
||||
|
||||
// Verify cryptographic signature.
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := tc.hashFn([]byte(sigInput))
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
|
||||
pub := &rsaKey.PublicKey
|
||||
if tc.isPS {
|
||||
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: tc.hash}
|
||||
assert.NoError(t, rsa.VerifyPSS(pub, tc.hash, digest, sigBytes, opts),
|
||||
"PSS signature verification failed for %s", tc.alg)
|
||||
} else {
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(pub, tc.hash, digest, sigBytes),
|
||||
"PKCS1v15 signature verification failed for %s", tc.alg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerECDSAFamily verifies correct JWT production for all
|
||||
// ECDSA algorithms (ES256/ES384/ES512) including that the signature is the
|
||||
// raw r||s encoding (not ASN.1 DER) and is verifiable with the matching key.
|
||||
func TestIssue135_SignerECDSAFamily(t *testing.T) {
|
||||
cases := []struct {
|
||||
alg string
|
||||
curve elliptic.Curve
|
||||
hashFn func([]byte) []byte
|
||||
hash crypto.Hash
|
||||
}{
|
||||
{"ES256", elliptic.P256(), func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, crypto.SHA256},
|
||||
{"ES384", elliptic.P384(), func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, crypto.SHA384},
|
||||
{"ES512", elliptic.P521(), func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, crypto.SHA512},
|
||||
}
|
||||
|
||||
const (
|
||||
audience = "https://idp.example.com/token"
|
||||
clientID = "ec-client"
|
||||
kid = "ec-kid"
|
||||
)
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.alg, func(t *testing.T) {
|
||||
ecKey, err := ecdsa.GenerateKey(tc.curve, rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBytes := encodeECPKCS8(t, ecKey)
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtStr, err := signer.Sign(audience, clientID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
|
||||
byteLen := (tc.curve.Params().BitSize + 7) / 8
|
||||
assert.Len(t, sigBytes, 2*byteLen,
|
||||
"ECDSA signature must be raw r||s (2×%d bytes for %s)", byteLen, tc.alg)
|
||||
|
||||
r := new(big.Int).SetBytes(sigBytes[:byteLen])
|
||||
s := new(big.Int).SetBytes(sigBytes[byteLen:])
|
||||
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := tc.hashFn([]byte(sigInput))
|
||||
|
||||
ok := ecdsa.Verify(&ecKey.PublicKey, digest, r, s)
|
||||
assert.True(t, ok, "ECDSA signature verification failed for %s", tc.alg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerRejectsAlgKeyMismatch verifies that the signer constructor
|
||||
// rejects type mismatches between key type and algorithm, unknown algorithms,
|
||||
// and an empty kid.
|
||||
func TestIssue135_SignerRejectsAlgKeyMismatch(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
rsaPEM := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
ecPEM := encodeECPKCS8(t, ecKey)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
pemBytes []byte
|
||||
alg string
|
||||
kid string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "RSA key with ES256",
|
||||
pemBytes: rsaPEM,
|
||||
alg: "ES256",
|
||||
kid: "k1",
|
||||
wantErr: "EC key",
|
||||
},
|
||||
{
|
||||
name: "EC key with RS256",
|
||||
pemBytes: ecPEM,
|
||||
alg: "RS256",
|
||||
kid: "k1",
|
||||
wantErr: "RSA key",
|
||||
},
|
||||
{
|
||||
name: "unknown alg HS256",
|
||||
pemBytes: rsaPEM,
|
||||
alg: "HS256",
|
||||
kid: "k1",
|
||||
wantErr: "unsupported",
|
||||
},
|
||||
{
|
||||
name: "empty kid",
|
||||
pemBytes: rsaPEM,
|
||||
alg: "RS256",
|
||||
kid: "",
|
||||
wantErr: "kid must not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewClientAssertionSigner(tc.pemBytes, tc.alg, tc.kid)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.wantErr),
|
||||
"error should mention %q", tc.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerJTIUniqueness signs 50 assertions with the same signer
|
||||
// and asserts all jti values are distinct. Guards against broken entropy reuse.
|
||||
func TestIssue135_SignerJTIUniqueness(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "jti-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
seen := make(map[string]bool, 50)
|
||||
for i := range 50 {
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "client-x")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
jti, ok := clms["jti"].(string)
|
||||
require.True(t, ok)
|
||||
assert.False(t, seen[jti], "jti %q was reused at iteration %d", jti, i)
|
||||
seen[jti] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerPEMVariants confirms that all PEM block types understood
|
||||
// by NewClientAssertionSigner are parsed correctly: PKCS#8 ("PRIVATE KEY"),
|
||||
// PKCS#1 ("RSA PRIVATE KEY"), and SEC1 ("EC PRIVATE KEY").
|
||||
func TestIssue135_SignerPEMVariants(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("RSA PKCS8", func(t *testing.T) {
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
|
||||
require.NoError(t, err)
|
||||
assertValidRSAJWT(t, rsaKey, signer, "RS256")
|
||||
})
|
||||
|
||||
t.Run("RSA PKCS1", func(t *testing.T) {
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
|
||||
require.NoError(t, err)
|
||||
assertValidRSAJWT(t, rsaKey, signer, "RS256")
|
||||
})
|
||||
|
||||
t.Run("EC PKCS8", func(t *testing.T) {
|
||||
pemBytes := encodeECPKCS8(t, ecKey)
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
|
||||
require.NoError(t, err)
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "cid")
|
||||
require.NoError(t, err)
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
})
|
||||
|
||||
t.Run("EC SEC1", func(t *testing.T) {
|
||||
der, err := x509.MarshalECPrivateKey(ecKey)
|
||||
require.NoError(t, err)
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
|
||||
require.NoError(t, err)
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "cid")
|
||||
require.NoError(t, err)
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
})
|
||||
}
|
||||
|
||||
// ── B. Config validation ──────────────────────────────────────────────────────
|
||||
|
||||
// TestIssue135_ConfigValidation table-drives Config.Validate() for every
|
||||
// client-authentication-related validation branch.
|
||||
func TestIssue135_ConfigValidation(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
validPEM := string(encodeRSAPKCS8(t, rsaKey))
|
||||
|
||||
// baseConfig returns the minimum valid config, modified per test case.
|
||||
base := func() *Config {
|
||||
return &Config{
|
||||
ProviderURL: "https://idp.example.com",
|
||||
CallbackURL: "/cb",
|
||||
ClientID: "cid",
|
||||
ClientSecret: "secret",
|
||||
SessionEncryptionKey: "01234567890123456789012345678901", // 32 chars
|
||||
RateLimit: 100,
|
||||
}
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
wantErr string // empty = expect nil error
|
||||
}{
|
||||
{
|
||||
name: "default empty method + secret ok",
|
||||
mutate: func(c *Config) { /* nothing extra */ },
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "explicit client_secret_post + secret ok",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "client_secret_post"
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt inline key + kid ok",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt no key at all",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
},
|
||||
wantErr: "clientAssertionPrivateKey",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt both inline and path",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
c.ClientAssertionKeyPath = "/tmp/key.pem"
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
},
|
||||
wantErr: "only one of",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt key but no kid",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
},
|
||||
wantErr: "clientAssertionKeyID",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt unsupported alg HS256",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
c.ClientAssertionAlg = "HS256"
|
||||
},
|
||||
wantErr: "is not supported",
|
||||
},
|
||||
{
|
||||
name: "unknown client auth method",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "weird"
|
||||
},
|
||||
wantErr: "is not supported",
|
||||
},
|
||||
{
|
||||
name: "client_secret_post with no secret",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "client_secret_post"
|
||||
c.ClientSecret = ""
|
||||
},
|
||||
wantErr: "clientSecret is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg := base()
|
||||
tc.mutate(cfg)
|
||||
err := cfg.Validate()
|
||||
if tc.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.wantErr,
|
||||
"error must mention %q", tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_ConfigKeyPathLoadsFile verifies that buildClientAssertionSignerFromConfig
|
||||
// reads the PEM key from disk when ClientAssertionKeyPath is set.
|
||||
func TestIssue135_ConfigKeyPathLoadsFile(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
dir := t.TempDir()
|
||||
keyFile := dir + "/private.pem"
|
||||
require.NoError(t, os.WriteFile(keyFile, pemBytes, 0o600))
|
||||
|
||||
cfg := &Config{
|
||||
ClientAuthMethod: "private_key_jwt",
|
||||
ClientAssertionKeyPath: keyFile,
|
||||
ClientAssertionKeyID: "file-kid",
|
||||
ClientAssertionAlg: "RS256",
|
||||
}
|
||||
|
||||
signer, err := buildClientAssertionSignerFromConfig(cfg)
|
||||
require.NoError(t, err, "should load signer from key file")
|
||||
require.NotNil(t, signer)
|
||||
|
||||
// Confirm signer produces a valid JWT.
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "client-from-file")
|
||||
require.NoError(t, err)
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3, "should produce a 3-part JWT")
|
||||
}
|
||||
|
||||
// ── C. Wire-up — exchangeTokens ───────────────────────────────────────────────
|
||||
|
||||
// TestIssue135_AuthCodeExchangeUsesAssertion confirms that exchangeTokens sends
|
||||
// client_assertion + client_assertion_type instead of client_secret when a
|
||||
// ClientAssertionSigner is configured, and that the assertion JWT is valid.
|
||||
func TestIssue135_AuthCodeExchangeUsesAssertion(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
var capturedBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body := make([]byte, r.ContentLength)
|
||||
_, _ = r.Body.Read(body)
|
||||
capturedBody = body
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Return a minimal token response so exchangeTokens doesn't error.
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "at",
|
||||
IDToken: "it",
|
||||
RefreshToken: "rt",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "wire-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "wire-client",
|
||||
tokenHTTPClient: server.Client(),
|
||||
clientAssertion: signer,
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err = oidc.exchangeTokens(context.Background(), "authorization_code", "code-x", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
form, err := url.ParseQuery(string(capturedBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
form.Get("client_assertion_type"), "client_assertion_type must be set")
|
||||
assertionJWT := form.Get("client_assertion")
|
||||
assert.NotEmpty(t, assertionJWT, "client_assertion must be present")
|
||||
assert.Empty(t, form.Get("client_secret"), "client_secret must not be sent when using assertion")
|
||||
assert.Equal(t, "wire-client", form.Get("client_id"))
|
||||
assert.Equal(t, "code-x", form.Get("code"))
|
||||
assert.Equal(t, "authorization_code", form.Get("grant_type"))
|
||||
|
||||
// Verify assertion JWT: header, claims, signature.
|
||||
parts := strings.Split(assertionJWT, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, "RS256", hdr["alg"])
|
||||
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
assert.Equal(t, "wire-client", clms["iss"])
|
||||
assert.Equal(t, "wire-client", clms["sub"])
|
||||
assert.Equal(t, server.URL, clms["aud"],
|
||||
"audience must be the tokenURL (RFC 7523 §3)")
|
||||
|
||||
// Verify signature with RSA public key.
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
|
||||
}
|
||||
|
||||
// TestIssue135_RefreshTokenUsesAssertion verifies that the refresh_token grant
|
||||
// type also sends client_assertion and the correct form fields.
|
||||
func TestIssue135_RefreshTokenUsesAssertion(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "new-at",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rt-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "rt-client",
|
||||
tokenHTTPClient: server.Client(),
|
||||
clientAssertion: signer,
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err = oidc.exchangeTokens(context.Background(), "refresh_token", "rt-y", "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "refresh_token", capturedForm.Get("grant_type"))
|
||||
assert.Equal(t, "rt-y", capturedForm.Get("refresh_token"))
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
capturedForm.Get("client_assertion_type"))
|
||||
assert.NotEmpty(t, capturedForm.Get("client_assertion"))
|
||||
assert.Empty(t, capturedForm.Get("client_secret"))
|
||||
}
|
||||
|
||||
// TestIssue135_BackcompatClientSecretPath confirms that exchangeTokens sends
|
||||
// client_secret and does NOT send client_assertion when clientAssertion is nil.
|
||||
func TestIssue135_BackcompatClientSecretPath(t *testing.T) {
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "legacy-client",
|
||||
clientSecret: "legacy-secret",
|
||||
tokenHTTPClient: server.Client(),
|
||||
clientAssertion: nil, // back-compat path
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bc", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "legacy-secret", capturedForm.Get("client_secret"),
|
||||
"client_secret must be sent on the classic path")
|
||||
assert.Empty(t, capturedForm.Get("client_assertion"),
|
||||
"client_assertion must NOT be present on the classic path")
|
||||
assert.Empty(t, capturedForm.Get("client_assertion_type"),
|
||||
"client_assertion_type must NOT be present on the classic path")
|
||||
}
|
||||
|
||||
// TestIssue135_ClientSecretBasicAuth verifies that when clientAuthMethod is
|
||||
// "client_secret_basic", exchangeTokens sends an HTTP Basic Authorization
|
||||
// header carrying url-encoded client_id:client_secret per RFC 6749 §2.3.1,
|
||||
// and that neither client_id nor client_secret appears in the form body.
|
||||
func TestIssue135_ClientSecretBasicAuth(t *testing.T) {
|
||||
var capturedAuth string
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "at-basic", TokenType: "Bearer", ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "basic-client",
|
||||
clientSecret: "basic-secret",
|
||||
clientAuthMethod: "client_secret_basic",
|
||||
tokenHTTPClient: server.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bb", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, strings.HasPrefix(capturedAuth, "Basic "),
|
||||
"Authorization header must start with 'Basic ', got %q", capturedAuth)
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
|
||||
require.NoError(t, err, "Authorization payload must be valid base64")
|
||||
user, pass, ok := strings.Cut(string(raw), ":")
|
||||
require.True(t, ok, "Authorization payload must contain a single ':' separator")
|
||||
assert.Equal(t, "basic-client", user, "client_id should round-trip through QueryEscape")
|
||||
assert.Equal(t, "basic-secret", pass, "client_secret should round-trip through QueryEscape")
|
||||
|
||||
assert.Empty(t, capturedForm.Get("client_id"),
|
||||
"client_id must NOT be in the body when using client_secret_basic")
|
||||
assert.Empty(t, capturedForm.Get("client_secret"),
|
||||
"client_secret must NOT be in the body when using client_secret_basic")
|
||||
assert.Empty(t, capturedForm.Get("client_assertion"),
|
||||
"client_assertion must NOT be present on the basic-auth path")
|
||||
}
|
||||
|
||||
// TestIssue135_ClientSecretBasicURLEncodesReservedChars verifies that
|
||||
// credentials containing reserved characters (`:`, `+`, `/`, etc.) are
|
||||
// form-urlencoded before base64 per RFC 6749 §2.3.1, so the receiving
|
||||
// authorization server can decode them deterministically.
|
||||
func TestIssue135_ClientSecretBasicURLEncodesReservedChars(t *testing.T) {
|
||||
var capturedAuth string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{AccessToken: "at", TokenType: "Bearer", ExpiresIn: 3600})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
const (
|
||||
clientID = "weird:id+1"
|
||||
clientSecret = "p@ss/word=&" //nolint:gosec // test fixture
|
||||
)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
clientAuthMethod: "client_secret_basic",
|
||||
tokenHTTPClient: server.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "c", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
|
||||
require.NoError(t, err)
|
||||
|
||||
wantUser := url.QueryEscape(clientID)
|
||||
wantPass := url.QueryEscape(clientSecret)
|
||||
assert.Equal(t, wantUser+":"+wantPass, string(raw),
|
||||
"both halves must be form-urlencoded before the base64 step")
|
||||
}
|
||||
|
||||
// TestIssue135_ClientSecretBasicRevocation verifies that the revocation path
|
||||
// honors client_secret_basic identically to the token path.
|
||||
func TestIssue135_ClientSecretBasicRevocation(t *testing.T) {
|
||||
var capturedAuth string
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "rev-basic",
|
||||
clientSecret: "rev-secret",
|
||||
clientAuthMethod: "client_secret_basic",
|
||||
httpClient: server.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = "https://idp.example.com/token"
|
||||
oidc.revocationURL = server.URL
|
||||
|
||||
require.NoError(t, oidc.RevokeTokenWithProvider("opaque-tok", "access_token"))
|
||||
|
||||
require.True(t, strings.HasPrefix(capturedAuth, "Basic "), "got %q", capturedAuth)
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "rev-basic:rev-secret", string(raw))
|
||||
|
||||
assert.Equal(t, "opaque-tok", capturedForm.Get("token"))
|
||||
assert.Equal(t, "access_token", capturedForm.Get("token_type_hint"))
|
||||
assert.Empty(t, capturedForm.Get("client_id"),
|
||||
"client_id must NOT be in body on Basic-auth revocation")
|
||||
assert.Empty(t, capturedForm.Get("client_secret"),
|
||||
"client_secret must NOT be in body on Basic-auth revocation")
|
||||
}
|
||||
|
||||
// ── D. Wire-up — RevokeTokenWithProvider ────────────────────────────────────
|
||||
|
||||
// TestIssue135_RevocationUsesAssertion verifies that RevokeTokenWithProvider
|
||||
// sends client_assertion (not client_secret), and that the assertion's audience
|
||||
// is the tokenURL, not the revocationURL (per RFC 7523 §3).
|
||||
func TestIssue135_RevocationUsesAssertion(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
const (
|
||||
tokenEndpoint = "https://idp.example.com/token" // audience for assertion
|
||||
clientIDVal = "revoke-client"
|
||||
)
|
||||
|
||||
var capturedForm url.Values
|
||||
// Revocation endpoint — deliberate separate URL to confirm audience != revocationURL.
|
||||
revokeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer revokeServer.Close()
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rev-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: clientIDVal,
|
||||
clientAssertion: signer,
|
||||
httpClient: revokeServer.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
// tokenURL drives assertion audience; revocationURL is where the POST goes.
|
||||
oidc.tokenURL = tokenEndpoint
|
||||
oidc.revocationURL = revokeServer.URL
|
||||
|
||||
err = oidc.RevokeTokenWithProvider("some-token", "refresh_token")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
capturedForm.Get("client_assertion_type"))
|
||||
assertionJWT := capturedForm.Get("client_assertion")
|
||||
assert.NotEmpty(t, assertionJWT)
|
||||
assert.Empty(t, capturedForm.Get("client_secret"),
|
||||
"client_secret must not appear in revocation request with assertion")
|
||||
|
||||
// Verify the assertion audience is tokenURL (not revocationURL).
|
||||
parts := strings.Split(assertionJWT, ".")
|
||||
require.Len(t, parts, 3)
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
assert.Equal(t, tokenEndpoint, clms["aud"],
|
||||
"assertion audience must be tokenURL, not revocationURL")
|
||||
|
||||
// Sanity-check cryptographic validity.
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
|
||||
}
|
||||
|
||||
// ── E. End-to-end via buildClientAssertionSignerFromConfig ───────────────────
|
||||
|
||||
// TestIssue135_BuildSignerFromInlineConfig confirms that the full config→signer
|
||||
// pipeline works for an ES256 key specified inline in the Config struct.
|
||||
func TestIssue135_BuildSignerFromInlineConfig(t *testing.T) {
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
pemBytes := encodeECPKCS8(t, ecKey)
|
||||
|
||||
cfg := &Config{
|
||||
ClientAuthMethod: "private_key_jwt",
|
||||
ClientAssertionPrivateKey: string(pemBytes),
|
||||
ClientAssertionKeyID: "inline-ec-kid",
|
||||
ClientAssertionAlg: "ES256",
|
||||
}
|
||||
|
||||
signer, err := buildClientAssertionSignerFromConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signer)
|
||||
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "inline-client")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, "ES256", hdr["alg"])
|
||||
assert.Equal(t, "inline-ec-kid", hdr["kid"])
|
||||
|
||||
// Verify the EC signature.
|
||||
byteLen := (elliptic.P256().Params().BitSize + 7) / 8
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
require.Len(t, sigBytes, 2*byteLen)
|
||||
|
||||
r := new(big.Int).SetBytes(sigBytes[:byteLen])
|
||||
s := new(big.Int).SetBytes(sigBytes[byteLen:])
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
assert.True(t, ecdsa.Verify(&ecKey.PublicKey, digest, r, s))
|
||||
}
|
||||
|
||||
// TestIssue135_BuildSignerDefaultsToRS256 verifies that an empty
|
||||
// ClientAssertionAlg defaults to RS256.
|
||||
func TestIssue135_BuildSignerDefaultsToRS256(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
cfg := &Config{
|
||||
ClientAssertionPrivateKey: string(pemBytes),
|
||||
ClientAssertionKeyID: "default-alg-kid",
|
||||
ClientAssertionAlg: "", // intentionally empty
|
||||
}
|
||||
|
||||
signer, err := buildClientAssertionSignerFromConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "default-client")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, "RS256", hdr["alg"], "empty alg must default to RS256")
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
// genRSAKey generates an RSA key of the given bit size, failing the test on error.
|
||||
func genRSAKey(t *testing.T, bits int) *rsa.PrivateKey {
|
||||
t.Helper()
|
||||
k, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
require.NoError(t, err)
|
||||
return k
|
||||
}
|
||||
|
||||
// encodeRSAPKCS8 marshals an RSA key as PKCS#8 PEM ("PRIVATE KEY").
|
||||
func encodeRSAPKCS8(t *testing.T, key *rsa.PrivateKey) []byte {
|
||||
t.Helper()
|
||||
der, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
|
||||
}
|
||||
|
||||
// encodeECPKCS8 marshals an EC key as PKCS#8 PEM ("PRIVATE KEY").
|
||||
func encodeECPKCS8(t *testing.T, key *ecdsa.PrivateKey) []byte {
|
||||
t.Helper()
|
||||
der, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
|
||||
}
|
||||
|
||||
// decodeJSONPart base64url-decodes a JWT part and parses it as a JSON object.
|
||||
func decodeJSONPart(t *testing.T, b64url string) map[string]any {
|
||||
t.Helper()
|
||||
raw, err := base64.RawURLEncoding.DecodeString(b64url)
|
||||
require.NoError(t, err, "base64url decode of JWT part failed")
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(raw, &m), "JSON unmarshal of JWT part failed")
|
||||
return m
|
||||
}
|
||||
|
||||
// sha256SumBytes returns the SHA-256 digest of b as a byte slice.
|
||||
func sha256SumBytes(b []byte) []byte {
|
||||
h := sha256.Sum256(b)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
// assertValidRSAJWT signs a JWT with signer and verifies the RS256 signature
|
||||
// against the given RSA public key. Used by PEM variant tests.
|
||||
func assertValidRSAJWT(t *testing.T, key *rsa.PrivateKey, signer *ClientAssertionSigner, alg string) {
|
||||
t.Helper()
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "pem-client")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, alg, hdr["alg"])
|
||||
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(&key.PublicKey, crypto.SHA256, digest, sigBytes))
|
||||
}
|
||||
|
||||
@@ -76,9 +76,15 @@ func NewJWKCache() *JWKCache {
|
||||
}
|
||||
|
||||
// GetJWKS retrieves JWKS from cache or fetches from the remote URL if not cached.
|
||||
//
|
||||
// The entry is stored locally only via SetLocal/GetLocal. Going through a
|
||||
// distributed backend defeats the cache: JSON round-tripping turns *JWKSet
|
||||
// into map[string]interface{}, the type assertion below fails, and every
|
||||
// request refetches from the upstream. JWK rotation is rare and a per-replica
|
||||
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Check cache first
|
||||
if cachedValue, found := c.cache.Get(jwksURL); found {
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
return jwks, nil
|
||||
}
|
||||
@@ -88,7 +94,7 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if cachedValue, found := c.cache.Get(jwksURL); found {
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
return jwks, nil
|
||||
}
|
||||
@@ -105,7 +111,7 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
}
|
||||
|
||||
// Cache for 1 hour
|
||||
_ = c.cache.Set(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
@@ -114,9 +120,17 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
// caching the JWKS plus its derived parsedJWKS on miss. The parsed entry is
|
||||
// stored alongside the raw JWKSet under a sibling cache key with the same
|
||||
// 1-hour TTL, so both invalidate together when the upstream JWKS rotates.
|
||||
//
|
||||
// parsedJWKS is stored locally only (SetLocal/GetLocal). Its values are
|
||||
// crypto.PublicKey interfaces wrapping *rsa.PublicKey/*ecdsa.PublicKey,
|
||||
// which contain *big.Int that marshals to a hundreds-digit JSON number.
|
||||
// On a distributed backend round-trip, json.Unmarshal into interface{} would
|
||||
// try to fit that into float64 and fail with UnmarshalTypeError. Under yaegi
|
||||
// the unexported parsedJWKS.keys field is exposed via an X-prefixed name on
|
||||
// Marshal, leaking the modulus into the cached payload (issue #134).
|
||||
func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
parsedKey := jwksURL + parsedKeysSuffix
|
||||
if v, found := c.cache.Get(parsedKey); found {
|
||||
if v, found := c.cache.GetLocal(parsedKey); found {
|
||||
if pj, ok := v.(*parsedJWKS); ok {
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
@@ -130,7 +144,7 @@ func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpCl
|
||||
}
|
||||
|
||||
pj := buildParsedJWKS(jwks)
|
||||
_ = c.cache.Set(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
_ = c.cache.SetLocal(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
|
||||
@@ -169,6 +169,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
introspectionCache: cacheManager.GetSharedIntrospectionCache(), // Cache for introspection results
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
clientAuthMethod: func() string {
|
||||
if config.ClientAuthMethod != "" {
|
||||
return config.ClientAuthMethod
|
||||
}
|
||||
return "client_secret_post"
|
||||
}(),
|
||||
audience: func() string {
|
||||
if config.Audience != "" {
|
||||
return config.Audience
|
||||
@@ -226,6 +232,13 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
return 60 * time.Second
|
||||
}(),
|
||||
maxRefreshTokenAge: func() time.Duration {
|
||||
// 0 (or unset) disables the heuristic; negative is rejected by Validate.
|
||||
if config.MaxRefreshTokenAgeSeconds > 0 {
|
||||
return time.Duration(config.MaxRefreshTokenAgeSeconds) * time.Second
|
||||
}
|
||||
return 0
|
||||
}(),
|
||||
tokenCleanupStopChan: make(chan struct{}),
|
||||
metadataRefreshStopChan: make(chan struct{}),
|
||||
ctx: pluginCtx,
|
||||
@@ -242,6 +255,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
|
||||
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
|
||||
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
|
||||
refreshResultCache: cacheManager.GetSharedRefreshResultCache(),
|
||||
}
|
||||
|
||||
// Log audience configuration
|
||||
@@ -260,6 +274,20 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
tokenResilienceConfig := DefaultTokenResilienceConfig()
|
||||
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
|
||||
|
||||
// Coalesces concurrent refresh-token grants per refresh_token to one upstream
|
||||
// call, preventing the thundering herd that yields invalid_grant when the IdP
|
||||
// rotates refresh tokens (Zitadel/Authentik default).
|
||||
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
|
||||
|
||||
if config.ClientAuthMethod == "private_key_jwt" {
|
||||
signer, err := buildClientAssertionSignerFromConfig(config)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
return nil, fmt.Errorf("failed to build client assertion signer: %w", err)
|
||||
}
|
||||
t.clientAssertion = signer
|
||||
}
|
||||
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
|
||||
+199
-47
@@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestServeHTTP_EventStream tests the event-stream bypass functionality
|
||||
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
|
||||
// handshake must skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but it must STILL require an authenticated session, otherwise
|
||||
// any caller could reach the backend by setting Accept: text/event-stream.
|
||||
func TestServeHTTP_EventStream(t *testing.T) {
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
|
||||
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
|
||||
}
|
||||
})
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
// Build an authenticated session and inject its cookies onto req.
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated SSE request to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
|
||||
// handshake bypasses the OIDC redirect (clients can't follow it) but the
|
||||
// session must already be authenticated, otherwise the backend is exposed
|
||||
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
|
||||
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
}))
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("expected event-stream request to bypass OIDC")
|
||||
}
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
// Mixed-case + multi-token Connection header to exercise parsing.
|
||||
req.Header.Set("Connection", "keep-alive, Upgrade")
|
||||
req.Header.Set("Upgrade", "WebSocket")
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("ws-user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "ws-user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
|
||||
// Sanity: requests without Upgrade headers must NOT hit the WS
|
||||
// bypass branch (otherwise the new code path could short-circuit
|
||||
// normal authentication).
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("backend must not be called for unauthenticated plain HTTP")
|
||||
}))
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
if rw.Code == http.StatusOK {
|
||||
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
|
||||
@@ -256,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "successful authorization with email",
|
||||
setupSession: func() *MockSessionData {
|
||||
session := &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -288,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "no email triggers reauth",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "",
|
||||
userIdentifier: "",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -309,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "roles and groups authorization",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -342,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "unauthorized role/group returns 403",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -369,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "template headers processing",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -401,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "OPTIONS request with CORS",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -452,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
manager: &SessionManager{logger: NewLogger("debug")},
|
||||
}
|
||||
// Copy values from mock to concrete session
|
||||
concreteSession.SetEmail(session.email)
|
||||
concreteSession.SetUserIdentifier(session.userIdentifier)
|
||||
concreteSession.SetIDToken(session.idToken)
|
||||
concreteSession.SetAccessToken(session.accessToken)
|
||||
concreteSession.SetRefreshToken(session.refreshToken)
|
||||
@@ -502,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
|
||||
// MockSessionData is a test implementation of SessionData interface
|
||||
type MockSessionData struct {
|
||||
email string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
userIdentifier string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
}
|
||||
|
||||
func (m *MockSessionData) GetEmail() string { return m.email }
|
||||
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
|
||||
func (m *MockSessionData) GetIDToken() string { return m.idToken }
|
||||
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
|
||||
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
|
||||
func (m *MockSessionData) SetEmail(email string) { m.email = email }
|
||||
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
|
||||
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
|
||||
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
|
||||
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
|
||||
@@ -610,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set up session data
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Call processAuthorizedRequest directly
|
||||
@@ -685,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -771,7 +923,7 @@ func TestStripAuthCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Now add OIDC session cookies (simulating what the browser would send)
|
||||
@@ -852,7 +1004,7 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -899,7 +1051,7 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only OIDC cookies
|
||||
@@ -950,7 +1102,7 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only non-OIDC cookies
|
||||
@@ -1013,7 +1165,7 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add cookies with the custom prefix (should be stripped)
|
||||
|
||||
+15
-15
@@ -580,7 +580,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case to avoid replay issues
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -603,7 +603,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
// even if session.SetAuthenticated(true) was called.
|
||||
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
||||
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -660,7 +660,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -678,7 +678,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -706,7 +706,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -741,7 +741,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(nearExpiryToken)
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
||||
},
|
||||
@@ -772,7 +772,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(validToken)
|
||||
session.SetIDToken(validToken) // Ensure ID token is also set
|
||||
session.SetRefreshToken("should-not-be-used-refresh-token")
|
||||
@@ -792,7 +792,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -814,7 +814,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -2179,7 +2179,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAccessToken(expiredToken)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
},
|
||||
expectedPath: "/original/path",
|
||||
},
|
||||
@@ -2756,7 +2756,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2782,7 +2782,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2809,7 +2809,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
@@ -2829,7 +2829,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2851,7 +2851,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
|
||||
+150
-78
@@ -14,21 +14,40 @@ import (
|
||||
)
|
||||
|
||||
// bypassReason describes why a request is being forwarded without OIDC auth.
|
||||
// It is only used for logging and to decide whether extra SSE-specific work
|
||||
// It is only used for logging and to decide whether extra side-effects
|
||||
// (propagating the user header from an existing session) should run.
|
||||
const (
|
||||
bypassReasonExcluded = "excluded-url"
|
||||
bypassReasonSSE = "sse"
|
||||
bypassReasonExcluded = "excluded-url"
|
||||
bypassReasonSSE = "sse"
|
||||
bypassReasonWebSocket = "websocket"
|
||||
)
|
||||
|
||||
// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake
|
||||
// (RFC 6455). The middleware can only see the handshake; once Traefik
|
||||
// completes the upgrade it forwards frames directly, so we never re-process
|
||||
// per-frame traffic. We bypass auth on the handshake the same way we do for
|
||||
// SSE, because browser WebSocket clients cannot follow an OIDC redirect.
|
||||
func isWebSocketUpgrade(req *http.Request) bool {
|
||||
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
|
||||
return false
|
||||
}
|
||||
for _, token := range strings.Split(req.Header.Get("Connection"), ",") {
|
||||
if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// shouldBypassAuth decides whether a request must skip OIDC authentication
|
||||
// entirely. It returns (true, reason) when either the request path matches a
|
||||
// configured excluded URL or the Accept header asks for a text/event-stream
|
||||
// response (SSE). The reason lets ServeHTTP apply any side-effects that are
|
||||
// unique to the bypass kind (e.g. propagating user headers for SSE).
|
||||
// configured excluded URL, the Accept header asks for a text/event-stream
|
||||
// response (SSE), or the request is a WebSocket upgrade handshake. The
|
||||
// reason lets ServeHTTP apply any side-effects that are unique to the bypass
|
||||
// kind (e.g. propagating user headers).
|
||||
//
|
||||
// This must be called BEFORE waiting on t.initComplete so excluded and SSE
|
||||
// traffic is never blocked by a slow/broken provider.
|
||||
// This must be called BEFORE waiting on t.initComplete so excluded, SSE and
|
||||
// WebSocket traffic is never blocked by a slow/broken provider.
|
||||
func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
return true, bypassReasonExcluded
|
||||
@@ -36,38 +55,55 @@ func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
|
||||
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
|
||||
return true, bypassReasonSSE
|
||||
}
|
||||
if isWebSocketUpgrade(req) {
|
||||
return true, bypassReasonWebSocket
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// applySSEUserHeaders attempts to copy the authenticated user's identity from
|
||||
// an existing session onto the outgoing SSE request so downstream services
|
||||
// can still see who the user is. Failures are logged (not silenced) because
|
||||
// they indicate either a corrupt cookie or a misconfigured session manager
|
||||
// and are useful for debugging, but they never block the bypass itself.
|
||||
func (t *TraefikOidc) applySSEUserHeaders(req *http.Request) {
|
||||
// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass
|
||||
// requests and, on success, copies the authenticated user's identity onto
|
||||
// the outgoing request so downstream services can see who the user is.
|
||||
//
|
||||
// Returns true when the request carries a valid authenticated session and
|
||||
// the bypass should proceed. Returns false when no usable session is
|
||||
// present; callers must then reject the request (typically with 401) to
|
||||
// prevent unauthenticated traffic from reaching the backend just by setting
|
||||
// `Accept: text/event-stream` or sending a WebSocket upgrade.
|
||||
//
|
||||
// The check is cookie-only: the session cookie is sealed by our encryption
|
||||
// key, so the authenticated flag cannot be forged. We do NOT run full token
|
||||
// signature verification here so that SSE/WS keeps working when the OIDC
|
||||
// provider is briefly unavailable for JWK fetches.
|
||||
func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool {
|
||||
if t.sessionManager == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
// Intentionally not fatal: SSE requests bypass auth, we just lose the
|
||||
// forwarded-user header for this request.
|
||||
t.logger.Debugf("SSE bypass: unable to load session for user header propagation: %v", err)
|
||||
return
|
||||
t.logger.Debugf("%s bypass: unable to load session: %v", reason, err)
|
||||
return false
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
return
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason)
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
|
||||
return false
|
||||
}
|
||||
t.logger.Debugf("SSE bypass: forwarded user %s from session", email)
|
||||
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
}
|
||||
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
|
||||
return true
|
||||
}
|
||||
|
||||
// ServeHTTP implements the main middleware logic for processing HTTP requests.
|
||||
@@ -124,16 +160,32 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded URLs
|
||||
// and SSE requests must not block on provider init. For SSE we additionally
|
||||
// attempt to forward the user identity from an existing session (best
|
||||
// effort) so downstream handlers still see X-Forwarded-User.
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded
|
||||
// URLs, SSE and WebSocket upgrade requests must not block on provider
|
||||
// init. For SSE/WebSocket we ALSO require an authenticated session
|
||||
// (cookie-only check, no JWK fetch) and otherwise return 401 — clients
|
||||
// of in-flight streams can't follow an OIDC redirect, so forwarding
|
||||
// unauthenticated traffic would silently expose the backend.
|
||||
if bypass, reason := t.shouldBypassAuth(req); bypass {
|
||||
t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason)
|
||||
if reason == bypassReasonSSE {
|
||||
t.applySSEUserHeaders(req)
|
||||
switch reason {
|
||||
case bypassReasonExcluded:
|
||||
// Operator-declared excluded URLs forward unconditionally.
|
||||
t.next.ServeHTTP(rw, req)
|
||||
case bypassReasonSSE, bypassReasonWebSocket:
|
||||
// Skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but still require an authenticated session.
|
||||
// Otherwise an unauthenticated client could hit the backend
|
||||
// just by setting Accept: text/event-stream or sending a
|
||||
// WebSocket upgrade.
|
||||
if !t.applyBypassUserHeaders(req, reason) {
|
||||
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
default:
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -237,7 +289,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
// User authorization check
|
||||
if authenticated && userIdentifier != "" {
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
@@ -309,7 +361,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if refreshed {
|
||||
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
|
||||
userIdentifier = session.GetUserIdentifier()
|
||||
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
@@ -359,9 +411,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// - session: The user's session data containing tokens and claims.
|
||||
// - redirectURL: The callback URL for re-authentication if needed.
|
||||
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Info("No email found in session during final processing, initiating re-auth")
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
|
||||
// Reset redirect count to prevent loops when session is invalid
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
@@ -374,7 +426,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
if idToken != "" {
|
||||
sid, sub, createdAt := t.extractSessionInfo(idToken)
|
||||
if t.isSessionInvalidated(sid, sub, createdAt) {
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
|
||||
// Clear the session and redirect to login
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing invalidated session: %v", err)
|
||||
@@ -386,31 +438,52 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
|
||||
tokenForClaims := session.GetIDToken()
|
||||
if tokenForClaims == "" {
|
||||
tokenForClaims = session.GetAccessToken()
|
||||
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
// Reset redirect count to prevent loops when token is missing
|
||||
// Resolve ID-token claims at most once per request. SessionData caches
|
||||
// the parsed claims keyed on the raw ID token, so concurrent dashboard
|
||||
// panel requests on the same session don't repeatedly base64-decode and
|
||||
// JSON-unmarshal the same JWT (a real cost under the yaegi interpreter
|
||||
// that hosts Traefik plugins). idClaims is reused below by the
|
||||
// header-templates branch.
|
||||
idToken := session.GetIDToken()
|
||||
var (
|
||||
idClaims map[string]interface{}
|
||||
idClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
}
|
||||
|
||||
// Choose which claims drive groups/roles extraction. Prefer the ID
|
||||
// token (cached) and fall back to the access token if there is no ID
|
||||
// token in the session — matching the prior behavior for opaque
|
||||
// ID-token providers.
|
||||
var (
|
||||
groupClaims map[string]interface{}
|
||||
groupClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
groupClaims, groupClaimsErr = idClaims, idClaimsErr
|
||||
} else if accessToken := session.GetAccessToken(); accessToken != "" {
|
||||
groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken)
|
||||
} else if len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
var groups, roles []string
|
||||
|
||||
if groupClaimsErr == nil && groupClaims != nil {
|
||||
var err error
|
||||
groups, roles, err = t.extractGroupsAndRolesFromClaims(groupClaims)
|
||||
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize empty slices
|
||||
var groups, roles []string
|
||||
|
||||
if tokenForClaims != "" {
|
||||
var err error
|
||||
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
|
||||
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
// Reset redirect count to prevent loops when claim extraction fails
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
} else if err == nil {
|
||||
if err == nil {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
@@ -429,54 +502,53 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
t.logger.Infof("User %s does not have any allowed roles or groups", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
if idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.headerTemplates) > 0 {
|
||||
// Reuse claims parsed earlier in this request if the ID token has not
|
||||
// changed. Saves an unnecessary JWT parse on every authenticated
|
||||
// request that uses headerTemplates.
|
||||
claims, err := session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
|
||||
if idClaimsErr != nil {
|
||||
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", idClaimsErr)
|
||||
} else {
|
||||
// idClaims may be nil when no ID token is present; templates
|
||||
// referencing .Claims.* will simply produce empty values, which
|
||||
// matches the prior behavior.
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": session.GetAccessToken(),
|
||||
"IDToken": session.GetIDToken(),
|
||||
"IDToken": idToken,
|
||||
"RefreshToken": session.GetRefreshToken(),
|
||||
"Claims": claims,
|
||||
"Claims": idClaims,
|
||||
}
|
||||
|
||||
for headerName, tmpl := range t.headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := tmpl.Execute(&buf, templateData); err != nil {
|
||||
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
|
||||
continue
|
||||
}
|
||||
headerValue := buf.String()
|
||||
|
||||
req.Header.Set(headerName, headerValue)
|
||||
|
||||
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
|
||||
}
|
||||
session.MarkDirty()
|
||||
t.logger.Debugf("Session marked dirty after templated header processing.")
|
||||
// NOTE: templates only mutate request headers (not session state),
|
||||
// so we deliberately do NOT MarkDirty / Save here. Previously every
|
||||
// authenticated request with header templates re-encrypted and
|
||||
// rewrote all session cookies, which was a measurable CPU and
|
||||
// Set-Cookie tax on dashboards that poll many panels per second.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,7 +587,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", userIdentifier)
|
||||
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create authenticated session
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetIDToken("dummy-token")
|
||||
session.Save(req, httptest.NewRecorder())
|
||||
@@ -203,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create session with forbidden domain
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@forbidden.com")
|
||||
session.SetUserIdentifier("user@forbidden.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save and inject cookies
|
||||
@@ -252,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
// Create session with opaque token
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
@@ -291,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("") // No email
|
||||
session.SetUserIdentifier("") // No email
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -321,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("") // No ID token
|
||||
session.SetAccessToken("") // No access token
|
||||
|
||||
@@ -349,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -383,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
testEmail := "user@example.com"
|
||||
session.SetEmail(testEmail)
|
||||
session.SetUserIdentifier(testEmail)
|
||||
session.SetIDToken("dummy-id-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@@ -466,10 +466,23 @@ func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
|
||||
// hashRefreshToken creates a hash of the refresh token for deduplication
|
||||
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
|
||||
return refreshCoordinatorSessionID(token)
|
||||
}
|
||||
|
||||
// refreshCoordinatorSessionID derives a stable identifier from a refresh token
|
||||
// for both deduplication and per-session attempt tracking. Using sha256 of the
|
||||
// raw token means each rotation produces a fresh sessionID with its own attempt
|
||||
// budget, which is what we want.
|
||||
func refreshCoordinatorSessionID(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// refreshCoordinatorWaitTimeout caps how long a request may wait for a
|
||||
// coordinated refresh result. It is wider than RefreshTimeout so a follower
|
||||
// always sees the leader's result instead of timing out independently.
|
||||
const refreshCoordinatorWaitTimeout = 35 * time.Second
|
||||
|
||||
// isUnderMemoryPressure checks if the system is under memory pressure by
|
||||
// consulting the global memory monitor. Returns true when pressure reaches
|
||||
// High or Critical, at which point we refuse new refresh operations to
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// stubTokenExchanger lets us count how many upstream refresh-token grants
|
||||
// happen for a given refresh_token across concurrent middleware-level calls.
|
||||
type stubTokenExchanger struct {
|
||||
calls int32
|
||||
delay time.Duration
|
||||
resp *TokenResponse
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.delay > 0 {
|
||||
time.Sleep(s.delay)
|
||||
}
|
||||
return s.resp, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many
|
||||
// concurrent calls to coordinatedTokenRefresh with the same refresh token
|
||||
// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call.
|
||||
//
|
||||
// Without the wireup this assertion fails (one upstream call per goroutine).
|
||||
func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 100 * time.Millisecond,
|
||||
resp: &TokenResponse{
|
||||
AccessToken: "new_access",
|
||||
RefreshToken: "new_refresh",
|
||||
IDToken: "new_id",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const concurrency = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "new_access" {
|
||||
t.Errorf("unexpected response: %+v", resp)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
got := atomic.LoadInt32(&stub.calls)
|
||||
// Up to 2 is acceptable to absorb the documented timing slack in the
|
||||
// existing coordinator tests (e.g. operation just cleaned up before a
|
||||
// late goroutine reads the in-flight map). Anything beyond that means
|
||||
// coalescing is broken.
|
||||
if got > 2 {
|
||||
t.Fatalf("expected <=2 upstream refresh calls, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil
|
||||
// coordinator path so existing tests that build TraefikOidc literals stay
|
||||
// green.
|
||||
func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenExchanger: stub,
|
||||
// refreshCoordinator deliberately nil
|
||||
}
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, "rt")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "ok" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected exactly 1 upstream call, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that
|
||||
// distinct refresh tokens are not falsely coalesced.
|
||||
func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 20 * time.Millisecond,
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
cfg.DeduplicationCleanupDelay = 0
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const distinct = 8
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(distinct)
|
||||
for i := 0; i < distinct; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i))))
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(&stub.calls); int(got) != distinct {
|
||||
t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// inMemoryCache is the smallest CacheInterface that satisfies the cross-
|
||||
// replica dedup contract: Set/Get with TTL. Used in place of the universal
|
||||
// cache singleton so these tests stay hermetic.
|
||||
type inMemoryCache struct {
|
||||
entries map[string]inMemoryCacheEntry
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type inMemoryCacheEntry struct {
|
||||
expiresAt time.Time
|
||||
value interface{}
|
||||
}
|
||||
|
||||
func newInMemoryCache() *inMemoryCache {
|
||||
return &inMemoryCache{entries: make(map[string]inMemoryCacheEntry)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries[key] = inMemoryCacheEntry{value: value, expiresAt: time.Now().Add(ttl)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Get(key string) (any, bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
e, ok := c.entries[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(e.expiresAt) {
|
||||
delete(c.entries, key)
|
||||
return nil, false
|
||||
}
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.entries, key)
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) SetMaxSize(int) {}
|
||||
func (c *inMemoryCache) Cleanup() {}
|
||||
func (c *inMemoryCache) Close() {}
|
||||
func (c *inMemoryCache) Size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.entries)
|
||||
}
|
||||
func (c *inMemoryCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries = map[string]inMemoryCacheEntry{}
|
||||
}
|
||||
func (c *inMemoryCache) GetStats() map[string]any { return map[string]any{} }
|
||||
|
||||
// erroringTokenExchanger always errors - simulates an IdP rejection.
|
||||
type erroringTokenExchanger struct {
|
||||
calls int32
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, errors.New("not used")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&e.calls, 1)
|
||||
return nil, errors.New("invalid_grant")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) RevokeTokenWithProvider(_, _ string) error { return nil }
|
||||
|
||||
// TestCoordinatedTokenRefresh_CrossReplicaCacheHit simulates a peer Traefik
|
||||
// replica having just refreshed: the shared cache already has the result, so
|
||||
// this pod must reuse it without ever calling the IdP.
|
||||
func TestCoordinatedTokenRefresh_CrossReplicaCacheHit(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "should_not_be_called"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
preExisting := &TokenResponse{
|
||||
AccessToken: "from_peer",
|
||||
RefreshToken: "rotated_by_peer",
|
||||
IDToken: "id_from_peer",
|
||||
}
|
||||
rt := "shared_refresh_token"
|
||||
cache.Set(refreshResultCacheKey(refreshCoordinatorSessionID(rt)), preExisting, refreshResultCacheTTL)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(httptest.NewRequest("GET", "/", nil), rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "from_peer" {
|
||||
t.Fatalf("expected peer-provided response, got %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 0 {
|
||||
t.Fatalf("expected 0 upstream calls (peer already refreshed), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache verifies that on a
|
||||
// cache miss the leader stores its result for peers to find within the TTL.
|
||||
func TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "fresh_grant"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
rt := "fresh_refresh_token"
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected 1 upstream call, got %d", got)
|
||||
}
|
||||
|
||||
v, ok := cache.Get(refreshResultCacheKey(refreshCoordinatorSessionID(rt)))
|
||||
if !ok {
|
||||
t.Fatal("expected refresh result to be cached after upstream success")
|
||||
}
|
||||
if tr, ok := v.(*TokenResponse); !ok || tr.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("cached value malformed: %+v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_ErrorIsNotCached makes sure we don't poison the
|
||||
// dedup cache when the IdP rejects the grant. Peers must run their own
|
||||
// refresh; they cannot inherit an error.
|
||||
func TestCoordinatedTokenRefresh_ErrorIsNotCached(t *testing.T) {
|
||||
failing := &erroringTokenExchanger{}
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: failing,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
if _, err := oidc.coordinatedTokenRefresh(nil, "doomed_refresh_token"); err == nil {
|
||||
t.Fatal("expected an error from the failing exchanger")
|
||||
}
|
||||
if cache.Size() != 0 {
|
||||
t.Fatalf("error result must not be cached, size=%d", cache.Size())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// sessionWithIssuedAt builds the smallest SessionData that GetRefreshTokenIssuedAt
|
||||
// reads from. We can't reuse sessionPool.Get() here because that requires a
|
||||
// fully initialized SessionManager - overkill for this unit-level check.
|
||||
func sessionWithIssuedAt(t *testing.T, issuedAt time.Time) *SessionData {
|
||||
t.Helper()
|
||||
rs := sessions.NewSession(nil, "refresh")
|
||||
if !issuedAt.IsZero() {
|
||||
rs.Values["issued_at"] = issuedAt.Unix()
|
||||
}
|
||||
return &SessionData{
|
||||
refreshSession: rs,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
idTokenChunks: make(map[int]*sessions.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_DisabledWhenAgeZero(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 0}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-30*24*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when maxRefreshTokenAge is 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Time{}) // no issued_at value
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when issued_at missing (legacy session)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_WithinWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-1*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false within max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_BeyondWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-7*time.Hour))
|
||||
if !tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=true beyond max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_NilGuards(t *testing.T) {
|
||||
var tr *TraefikOidc
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil receiver must not panic and must return false")
|
||||
}
|
||||
tr = &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil session must return false")
|
||||
}
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
|
||||
// Simulate successful Azure authentication
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Azure may use opaque access tokens
|
||||
session.SetAccessToken("opaque-azure-access-token")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
|
||||
@@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, session2.GetAuthenticated(), "User should remain authenticated")
|
||||
assert.Equal(t, "user@example.com", session2.GetEmail())
|
||||
assert.Equal(t, "user@example.com", session2.GetUserIdentifier())
|
||||
assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist")
|
||||
assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist")
|
||||
assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist")
|
||||
|
||||
@@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
|
||||
// Set up the attacker's session with malicious data
|
||||
attackerSession.SetAuthenticated(true)
|
||||
attackerSession.SetEmail("attacker@evil.com")
|
||||
attackerSession.SetUserIdentifier("attacker@evil.com")
|
||||
attackerSession.SetIDToken(ValidIDToken)
|
||||
attackerSession.SetAccessToken(ValidAccessToken)
|
||||
|
||||
@@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get the email from the session
|
||||
email := session.GetEmail()
|
||||
email := session.GetUserIdentifier()
|
||||
w.Header().Set("X-User-Email", email)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
+26
-26
@@ -100,7 +100,7 @@ type combinedSessionPayload struct {
|
||||
A string `json:"a,omitempty"`
|
||||
R string `json:"r,omitempty"`
|
||||
I string `json:"i,omitempty"`
|
||||
E string `json:"e,omitempty"`
|
||||
Ui string `json:"ui,omitempty"`
|
||||
Cs string `json:"cs,omitempty"`
|
||||
N string `json:"n,omitempty"`
|
||||
Cv string `json:"cv,omitempty"`
|
||||
@@ -113,11 +113,11 @@ type combinedSessionPayload struct {
|
||||
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
|
||||
// All other mainSession.Values keys are stored in the X (extra) field.
|
||||
var knownSessionKeys = map[string]bool{
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"email": true,
|
||||
"authenticated": true,
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"user_identifier": true,
|
||||
"authenticated": true,
|
||||
"csrf": true,
|
||||
"nonce": true,
|
||||
"code_verifier": true,
|
||||
@@ -1134,7 +1134,7 @@ func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *
|
||||
sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())
|
||||
|
||||
// Populate legacy session values from combined payload
|
||||
sessionData.mainSession.Values["email"] = payload.E
|
||||
sessionData.mainSession.Values["user_identifier"] = payload.Ui
|
||||
sessionData.mainSession.Values["authenticated"] = payload.Au
|
||||
sessionData.mainSession.Values["csrf"] = payload.Cs
|
||||
sessionData.mainSession.Values["nonce"] = payload.N
|
||||
@@ -1278,7 +1278,7 @@ func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, opti
|
||||
A: sd.getAccessTokenUnsafe(),
|
||||
R: sd.getRefreshTokenUnsafe(),
|
||||
I: sd.getIDTokenUnsafe(),
|
||||
E: sd.getEmailUnsafe(),
|
||||
Ui: sd.getUserIdentifierUnsafe(),
|
||||
Au: sd.getAuthenticatedUnsafe(),
|
||||
Cs: sd.getCSRFUnsafe(),
|
||||
N: sd.getNonceUnsafe(),
|
||||
@@ -2469,30 +2469,30 @@ func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address.
|
||||
// The email is extracted from ID token claims and used for
|
||||
// authorization decisions and header injection.
|
||||
// GetUserIdentifier retrieves the authenticated user's identifier as extracted
|
||||
// from the configured userIdentifierClaim of the ID token (email, sub, oid,
|
||||
// upn, preferred_username, etc.). The value is used for authorization
|
||||
// decisions and header injection.
|
||||
// Returns:
|
||||
// - The user's email address string, or an empty string if not set.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
// - The user identifier string, or an empty string if not set.
|
||||
func (sd *SessionData) GetUserIdentifier() string {
|
||||
sd.sessionMutex.RLock()
|
||||
defer sd.sessionMutex.RUnlock()
|
||||
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// SetEmail stores the authenticated user's email address.
|
||||
// The email is typically extracted from the 'email' claim in the ID token.
|
||||
// SetUserIdentifier stores the authenticated user's identifier value.
|
||||
// Parameters:
|
||||
// - email: The user's email address to store.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
// - userIdentifier: The user identifier to store (email, sub, or other claim value).
|
||||
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
currentVal, _ := sd.mainSession.Values["email"].(string)
|
||||
if currentVal != email {
|
||||
sd.mainSession.Values["email"] = email
|
||||
currentVal, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
if currentVal != userIdentifier {
|
||||
sd.mainSession.Values["user_identifier"] = userIdentifier
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
@@ -2626,10 +2626,10 @@ func (sd *SessionData) getRefreshTokenUnsafe() string {
|
||||
return result.Token
|
||||
}
|
||||
|
||||
// getEmailUnsafe retrieves the email without acquiring locks.
|
||||
func (sd *SessionData) getEmailUnsafe() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks.
|
||||
func (sd *SessionData) getUserIdentifierUnsafe() string {
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
|
||||
|
||||
@@ -320,17 +320,16 @@ func (s *SessionBehaviourSuite) TestSessionData_DirtyTracking() {
|
||||
s.False(session.IsDirty())
|
||||
}
|
||||
|
||||
// TestSessionData_SetEmail tests email setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetEmail() {
|
||||
// TestSessionData_SetUserIdentifier tests user identifier setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetUserIdentifier() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
session, err := s.sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Set email
|
||||
session.SetEmail("test@example.com")
|
||||
s.Equal("test@example.com", session.GetEmail())
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.Equal("test@example.com", session.GetUserIdentifier())
|
||||
s.True(session.IsDirty())
|
||||
}
|
||||
|
||||
@@ -568,7 +567,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Clear() {
|
||||
// Set some data
|
||||
err = session.SetAuthenticated(true)
|
||||
s.Require().NoError(err)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetCSRF("csrf-token")
|
||||
|
||||
// Clear session
|
||||
@@ -588,7 +587,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Save() {
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Modify session
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.True(session.IsDirty())
|
||||
|
||||
// Save session
|
||||
|
||||
+6
-6
@@ -2688,7 +2688,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Set up initial session state (what user has when first logging in)
|
||||
session1.SetAuthenticated(true)
|
||||
session1.SetEmail(originalUserData["email"].(string))
|
||||
session1.SetUserIdentifier(originalUserData["email"].(string))
|
||||
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
|
||||
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
|
||||
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
|
||||
@@ -2732,7 +2732,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
// Simulate what happens when middleware detects expired tokens
|
||||
// It should preserve session state while attempting token refresh
|
||||
originalAuth := session2.GetAuthenticated()
|
||||
originalEmail := session2.GetEmail()
|
||||
originalEmail := session2.GetUserIdentifier()
|
||||
|
||||
// Reconstruct user data from individual stored keys
|
||||
originalUserDataStored := make(map[string]interface{})
|
||||
@@ -2813,7 +2813,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Verify all session data is still intact after token refresh
|
||||
postRefreshAuth := session2.GetAuthenticated()
|
||||
postRefreshEmail := session2.GetEmail()
|
||||
postRefreshEmail := session2.GetUserIdentifier()
|
||||
userDataPresent := true
|
||||
for k := range originalUserData {
|
||||
if session2.mainSession.Values["user_data_"+k] == nil {
|
||||
@@ -2907,7 +2907,7 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with specific creation time
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
|
||||
|
||||
// Create tokens with specific expiry
|
||||
@@ -3018,7 +3018,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with data that should be preserved or removed
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("cleanup@example.com")
|
||||
session.SetUserIdentifier("cleanup@example.com")
|
||||
|
||||
session.mainSession.Values["user_data"] = "Test User|user-123"
|
||||
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
|
||||
@@ -3049,7 +3049,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
if scenario.shouldCleanup {
|
||||
if sessionTooOld {
|
||||
session.SetAuthenticated(false)
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
for key := range session.mainSession.Values {
|
||||
|
||||
+71
-2
@@ -55,6 +55,15 @@ type Config struct {
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
// MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of
|
||||
// a stored refresh token. Once the token has been in the session longer
|
||||
// than this, requests treat it as expired up-front - returning 401 to
|
||||
// AJAX callers and triggering full re-auth on navigations - instead of
|
||||
// hammering the IdP with grants that will only fail with invalid_grant.
|
||||
// IdPs do not expose RT TTL on the wire, so this is intentionally a
|
||||
// conservative heuristic; tune to match your provider configuration.
|
||||
// Default 21600 (6h). Set to 0 to disable the check.
|
||||
MaxRefreshTokenAgeSeconds int `json:"maxRefreshTokenAgeSeconds"`
|
||||
SessionMaxAge int `json:"sessionMaxAge"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
@@ -84,6 +93,38 @@ type Config struct {
|
||||
// providers. Enabling this in production is a security hole — prefer
|
||||
// CACertPath/CACertPEM. Emits a loud warning at startup.
|
||||
InsecureSkipVerify bool `json:"insecureSkipVerify,omitempty"`
|
||||
|
||||
// ClientAuthMethod selects the OAuth 2.0 client authentication method used
|
||||
// at the token / revocation / introspection endpoints. Supported values:
|
||||
//
|
||||
// - "client_secret_post" (default, current behavior): clientSecret is
|
||||
// sent in the request body alongside client_id.
|
||||
// - "private_key_jwt" (RFC 7523 §2.2): the plugin signs a short-lived JWT
|
||||
// assertion with a configured private key and sends it as
|
||||
// client_assertion. Use this when your IdP enforces short-lived secrets
|
||||
// or mandates secretless client auth (Entra ID, Okta, Auth0, Keycloak).
|
||||
//
|
||||
// When set to "private_key_jwt", clientSecret may be left empty and one of
|
||||
// clientAssertionPrivateKey / clientAssertionKeyPath must be configured.
|
||||
ClientAuthMethod string `json:"clientAuthMethod,omitempty"`
|
||||
|
||||
// ClientAssertionPrivateKey is an inline PEM-encoded private key used to
|
||||
// sign client_assertion JWTs. Mutually exclusive with
|
||||
// ClientAssertionKeyPath. Supports PKCS#8, PKCS#1 (RSA), and SEC1 (EC).
|
||||
ClientAssertionPrivateKey string `json:"clientAssertionPrivateKey,omitempty"`
|
||||
|
||||
// ClientAssertionKeyPath is a filesystem path to a PEM-encoded private key,
|
||||
// equivalent to ClientAssertionPrivateKey but loaded from disk.
|
||||
ClientAssertionKeyPath string `json:"clientAssertionKeyPath,omitempty"`
|
||||
|
||||
// ClientAssertionKeyID is the JWK key id (kid) advertised in the JWS
|
||||
// header. Required when using private_key_jwt so the IdP can locate the
|
||||
// matching public key registered for the client.
|
||||
ClientAssertionKeyID string `json:"clientAssertionKeyID,omitempty"`
|
||||
|
||||
// ClientAssertionAlg is the JWS signing algorithm. Defaults to RS256.
|
||||
// Supported: RS256/384/512, PS256/384/512, ES256/384/512.
|
||||
ClientAssertionAlg string `json:"clientAssertionAlg,omitempty"`
|
||||
}
|
||||
|
||||
// loadCACertPool assembles an x509.CertPool from CACertPath and CACertPEM.
|
||||
@@ -247,6 +288,7 @@ func CreateConfig() *Config {
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
OverrideScopes: false, // Default to appending scopes, not overriding
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
MaxRefreshTokenAgeSeconds: 21600, // 6h - conservative heuristic, see field doc
|
||||
SecurityHeaders: createDefaultSecurityConfig(),
|
||||
Redis: nil, // Redis is disabled by default, configure via Traefik or env vars
|
||||
}
|
||||
@@ -313,8 +355,30 @@ func (c *Config) Validate() error {
|
||||
if c.ClientID == "" {
|
||||
return fmt.Errorf("clientID is required")
|
||||
}
|
||||
if c.ClientSecret == "" {
|
||||
return fmt.Errorf("clientSecret is required")
|
||||
authMethod := c.ClientAuthMethod
|
||||
if authMethod == "" {
|
||||
authMethod = "client_secret_post"
|
||||
}
|
||||
switch authMethod {
|
||||
case "client_secret_post", "client_secret_basic":
|
||||
if c.ClientSecret == "" {
|
||||
return fmt.Errorf("clientSecret is required when clientAuthMethod is %q", authMethod)
|
||||
}
|
||||
case "private_key_jwt":
|
||||
if c.ClientAssertionPrivateKey == "" && c.ClientAssertionKeyPath == "" {
|
||||
return fmt.Errorf("clientAssertionPrivateKey or clientAssertionKeyPath is required when clientAuthMethod is private_key_jwt")
|
||||
}
|
||||
if c.ClientAssertionPrivateKey != "" && c.ClientAssertionKeyPath != "" {
|
||||
return fmt.Errorf("only one of clientAssertionPrivateKey or clientAssertionKeyPath may be set")
|
||||
}
|
||||
if c.ClientAssertionKeyID == "" {
|
||||
return fmt.Errorf("clientAssertionKeyID is required when clientAuthMethod is private_key_jwt")
|
||||
}
|
||||
if c.ClientAssertionAlg != "" && !isSupportedClientAssertionAlg(c.ClientAssertionAlg) {
|
||||
return fmt.Errorf("clientAssertionAlg %q is not supported (use RS256/384/512, PS256/384/512, or ES256/384/512)", c.ClientAssertionAlg)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("clientAuthMethod %q is not supported", authMethod)
|
||||
}
|
||||
|
||||
// Validate session encryption key
|
||||
@@ -370,6 +434,11 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// Validate refresh-token max-age heuristic
|
||||
if c.MaxRefreshTokenAgeSeconds < 0 {
|
||||
return fmt.Errorf("maxRefreshTokenAgeSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// Validate audience if specified
|
||||
if c.Audience != "" {
|
||||
// Validate audience format - should be a valid identifier or URL
|
||||
|
||||
@@ -293,7 +293,7 @@ func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http.
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(tf.fixtures.UserEmail)
|
||||
session.SetUserIdentifier(tf.fixtures.UserEmail)
|
||||
session.SetAccessToken(tf.fixtures.AccessToken)
|
||||
session.SetRefreshToken(tf.fixtures.RefreshToken)
|
||||
session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims))
|
||||
|
||||
+213
-29
@@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -46,6 +47,17 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Hot-path fast-return: a previously-verified token has already passed
|
||||
// signature, claims, and replay checks. Skipping the parseJWT cost here
|
||||
// matters under bursty traffic (e.g. 10+ concurrent panel requests on
|
||||
// every Grafana dashboard refresh) where the same token is validated
|
||||
// dozens of times per second by validateStandardTokens.
|
||||
if t.tokenCache != nil {
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
parsedJWT, parseErr := parseJWT(token)
|
||||
if parseErr != nil {
|
||||
return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr)
|
||||
@@ -63,12 +75,6 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check token cache FIRST - if token is already verified and cached, return immediately
|
||||
// This prevents false positives when multiple goroutines validate the same token concurrently
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only check JTI blacklist for tokens that aren't already in the cache
|
||||
// This is for FIRST-TIME validation to detect replay attacks
|
||||
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
|
||||
@@ -335,7 +341,17 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
|
||||
if err := verifySignatureWithKey(token, pubKey, alg); err != nil {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
|
||||
// Microsoft Graph access tokens carry a `nonce` JWT header and are
|
||||
// signed in a proprietary form Microsoft documents as unverifiable
|
||||
// by client applications. They reach this path only when the
|
||||
// per-provider classifier (validateAzureTokens) didn't catch them,
|
||||
// so log at debug to keep the error stream actionable while still
|
||||
// surfacing the cause for diagnostics.
|
||||
if _, isMSProprietary := jwt.Header["nonce"]; isMSProprietary {
|
||||
t.safeLogDebugf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s (Microsoft proprietary nonce header — token is opaque to clients): %v", kid, alg, err)
|
||||
} else {
|
||||
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
@@ -416,7 +432,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
}
|
||||
t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix)
|
||||
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
|
||||
newToken, err := t.coordinatedTokenRefresh(req, initialRefreshToken)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
|
||||
@@ -428,7 +444,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetRefreshToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens as well to prevent any replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
@@ -470,12 +486,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
|
||||
return false
|
||||
}
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
|
||||
return false
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("refreshToken failed: User identifier claim '%s' missing or empty in refreshed token", t.userIdentifierClaim)
|
||||
return false
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found in refreshed token, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
|
||||
// Get token expiry information for logging
|
||||
var expiryTime time.Time
|
||||
@@ -501,7 +523,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -518,6 +540,91 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return true
|
||||
}
|
||||
|
||||
// coordinatedTokenRefresh routes a refresh-token grant through the
|
||||
// RefreshCoordinator so that concurrent requests sharing the same refresh
|
||||
// token coalesce into a single upstream call. This prevents the thundering
|
||||
// herd that yields invalid_grant when the IdP rotates refresh tokens.
|
||||
//
|
||||
// Falls back to a direct call when the coordinator is nil, which only
|
||||
// happens in tests that build TraefikOidc literals without going through
|
||||
// NewWithContext.
|
||||
func (t *TraefikOidc) coordinatedTokenRefresh(req *http.Request, refreshToken string) (*TokenResponse, error) {
|
||||
if t.refreshCoordinator == nil {
|
||||
return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
parentCtx := context.Background()
|
||||
if req != nil {
|
||||
parentCtx = req.Context()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(parentCtx, refreshCoordinatorWaitTimeout)
|
||||
defer cancel()
|
||||
|
||||
sessionID := refreshCoordinatorSessionID(refreshToken)
|
||||
|
||||
return t.refreshCoordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
sessionID,
|
||||
refreshToken,
|
||||
func() (*TokenResponse, error) {
|
||||
// Cross-replica dedup. The in-process coordinator already
|
||||
// collapses concurrent grants on this pod; this Redis-backed
|
||||
// short-TTL cache covers the (rare) case of a failover or
|
||||
// load-balancer reroute mid-refresh, where two pods would
|
||||
// otherwise both POST the same refresh_token to the IdP.
|
||||
if cached, ok := t.lookupCachedRefreshResult(sessionID); ok {
|
||||
return cached, nil
|
||||
}
|
||||
resp, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
if err == nil && resp != nil {
|
||||
t.cacheRefreshResult(sessionID, resp)
|
||||
}
|
||||
return resp, err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// lookupCachedRefreshResult returns a previously-stored TokenResponse for the
|
||||
// given refresh-token hash, if one exists and is still within its short TTL.
|
||||
// The cache wraps the universal cache, which is Redis-backed in production -
|
||||
// so a "hit" here means another Traefik replica refreshed this same token
|
||||
// within the last few seconds.
|
||||
func (t *TraefikOidc) lookupCachedRefreshResult(sessionID string) (*TokenResponse, bool) {
|
||||
if t.refreshResultCache == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := t.refreshResultCache.Get(refreshResultCacheKey(sessionID))
|
||||
if !ok || v == nil {
|
||||
return nil, false
|
||||
}
|
||||
if tr, ok := v.(*TokenResponse); ok && tr != nil {
|
||||
return tr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// cacheRefreshResult stores the new TokenResponse under the refresh-token
|
||||
// hash for a short window. TTL is intentionally tight: the rotated refresh
|
||||
// token cannot be re-presented to the IdP, and any peer waiting longer than
|
||||
// this window has almost certainly given up via its own coordinator timeout.
|
||||
func (t *TraefikOidc) cacheRefreshResult(sessionID string, resp *TokenResponse) {
|
||||
if t.refreshResultCache == nil || resp == nil {
|
||||
return
|
||||
}
|
||||
t.refreshResultCache.Set(refreshResultCacheKey(sessionID), resp, refreshResultCacheTTL)
|
||||
}
|
||||
|
||||
// refreshResultCacheKey namespaces refresh-result entries inside the shared
|
||||
// cache namespace.
|
||||
func refreshResultCacheKey(sessionID string) string {
|
||||
return "rt-result:" + sessionID
|
||||
}
|
||||
|
||||
// refreshResultCacheTTL bounds how long a peer can lean on the dedup cache.
|
||||
// Long enough for a sibling replica to observe the result, short enough that
|
||||
// a stale entry never re-supplies a token after the IdP has already moved on.
|
||||
const refreshResultCacheTTL = 5 * time.Second
|
||||
|
||||
// RevokeToken revokes a token locally by adding it to the blacklist cache.
|
||||
// It removes the token from the verification cache and adds both the token
|
||||
// and its JTI (if present) to the blacklist to prevent future use.
|
||||
@@ -563,11 +670,33 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
}
|
||||
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, revocationURL)
|
||||
|
||||
// Read tokenURL with RLock — used as audience for private_key_jwt (RFC 7523 §3).
|
||||
t.metadataMu.RLock()
|
||||
tokenURL := t.tokenURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
data := url.Values{
|
||||
"token": {token},
|
||||
"token_type_hint": {tokenType},
|
||||
"client_id": {t.clientID},
|
||||
"client_secret": {t.clientSecret},
|
||||
}
|
||||
// client_id is sent in the body for every method except client_secret_basic,
|
||||
// where it is carried in the Authorization header per RFC 6749 §2.3.1.
|
||||
if t.clientAuthMethod != "client_secret_basic" || t.clientAssertion != nil {
|
||||
data.Set("client_id", t.clientID)
|
||||
}
|
||||
|
||||
useBasicAuth := false
|
||||
if t.clientAssertion != nil {
|
||||
assertion, err := t.clientAssertion.Sign(tokenURL, t.clientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign client assertion: %w", err)
|
||||
}
|
||||
data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
|
||||
data.Set("client_assertion", assertion)
|
||||
} else if t.clientAuthMethod == "client_secret_basic" {
|
||||
useBasicAuth = true
|
||||
} else {
|
||||
data.Set("client_secret", t.clientSecret)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", revocationURL, strings.NewReader(data.Encode()))
|
||||
@@ -577,6 +706,9 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if useBasicAuth {
|
||||
setOAuthBasicAuth(req, t.clientID, t.clientSecret)
|
||||
}
|
||||
|
||||
// Send the request with circuit breaker protection if available
|
||||
var resp *http.Response
|
||||
@@ -663,6 +795,27 @@ func (t *TraefikOidc) isGoogleProvider() bool {
|
||||
return strings.Contains(issuerURL, "google") || strings.Contains(issuerURL, "accounts.google.com")
|
||||
}
|
||||
|
||||
// isUnverifiableAzureAccessToken reports whether a JWT-shaped access token
|
||||
// matches the Microsoft proprietary format that client applications must not
|
||||
// validate. Microsoft injects a `nonce` value into the JWT header, signs over
|
||||
// the SHA256 hash of that nonce, and ships the original nonce on the wire,
|
||||
// guaranteeing that any standard JWS verifier rejects the signature. This is
|
||||
// the documented mechanism that keeps access tokens opaque to non-resource
|
||||
// holders (Microsoft Graph, Azure Management API).
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
|
||||
//
|
||||
// Returns true on parse failure as well — a token we cannot parse should not
|
||||
// be passed through the verification path that emits ERROR logs.
|
||||
func (t *TraefikOidc) isUnverifiableAzureAccessToken(token string) bool {
|
||||
parsed, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
_, hasProprietaryNonce := parsed.Header["nonce"]
|
||||
return hasProprietaryNonce
|
||||
}
|
||||
|
||||
// isAzureProvider detects if the configured OIDC provider is Azure AD.
|
||||
// It checks the issuer URL for Microsoft Azure AD domains.
|
||||
// Returns:
|
||||
@@ -705,6 +858,31 @@ func (t *TraefikOidc) validateAzureTokens(session *SessionData) (bool, bool, boo
|
||||
|
||||
if accessToken != "" {
|
||||
if strings.Count(accessToken, ".") == 2 {
|
||||
// Microsoft documents that client apps cannot validate access
|
||||
// tokens issued for Microsoft-owned APIs (Graph, Azure Mgmt) due
|
||||
// to their proprietary signing format (nonce in JWT header is
|
||||
// the marker — signed bytes hash the nonce, wire bytes ship the
|
||||
// raw value, so rsa verification always fails). Treat such
|
||||
// tokens as opaque, matching Microsoft's guidance and avoiding
|
||||
// per-request signature-error log spam (issue #134 followup).
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
|
||||
// "you can't validate tokens for Microsoft Graph according to
|
||||
// these rules due to their proprietary format"
|
||||
if t.isUnverifiableAzureAccessToken(accessToken) {
|
||||
t.logger.Debug("Azure access token is Microsoft-proprietary (Graph/Mgmt) — treating as opaque per Microsoft guidance")
|
||||
if idToken != "" {
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
t.logger.Debugf("Azure: ID token validation failed while access token was opaque: %v", err)
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
if err := t.verifyToken(accessToken); err != nil {
|
||||
if idToken != "" {
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
@@ -1103,9 +1281,14 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
sessionManager := t.sessionManager
|
||||
logger := t.logger
|
||||
|
||||
// Only use the fast cleanup interval when actually running under `go test`.
|
||||
// runtime.Compiler == "yaegi" makes isTestMode() return true in production
|
||||
// (Traefik interprets the plugin via yaegi), which would otherwise pin this
|
||||
// ticker to 20 Hz on a real cluster despite tokenCache.Cleanup and
|
||||
// jwkCache.Cleanup both being no-ops there.
|
||||
cleanupInterval := 1 * time.Minute
|
||||
if isTestMode() {
|
||||
cleanupInterval = 50 * time.Millisecond // Fast interval for tests
|
||||
if isTestMode() && runtime.Compiler != "yaegi" {
|
||||
cleanupInterval = 50 * time.Millisecond
|
||||
}
|
||||
|
||||
// Create cleanup function
|
||||
@@ -1147,25 +1330,27 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
}
|
||||
|
||||
// extractGroupsAndRoles extracts group and role information from token claims.
|
||||
// It parses the 'groups' and 'roles' claims from the ID token and validates their format.
|
||||
// Parameters:
|
||||
// - idToken: The ID token containing claims to extract.
|
||||
// It parses the configured group/role claims from the supplied ID token.
|
||||
//
|
||||
// Returns:
|
||||
// - groups: Array of group names from the 'groups' claim.
|
||||
// - roles: Array of role names from the 'roles' claim.
|
||||
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present
|
||||
// but not arrays of strings.
|
||||
// Most callers should prefer extractGroupsAndRolesFromClaims when claims have
|
||||
// already been parsed for the request (e.g. via SessionData.GetIDTokenClaims),
|
||||
// to avoid re-parsing the JWT.
|
||||
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
|
||||
}
|
||||
return t.extractGroupsAndRolesFromClaims(claims)
|
||||
}
|
||||
|
||||
// extractGroupsAndRolesFromClaims extracts group and role information from
|
||||
// already-parsed claims. Hot path: callers that have a cached claims map (such
|
||||
// as SessionData.GetIDTokenClaims) should use this to skip a redundant
|
||||
// base64+JSON decode of the JWT on every authenticated request.
|
||||
func (t *TraefikOidc) extractGroupsAndRolesFromClaims(claims map[string]interface{}) ([]string, []string, error) {
|
||||
var groups []string
|
||||
var roles []string
|
||||
|
||||
// Extract groups using configurable claim name (defaults to "groups")
|
||||
if groupsClaim, exists := claims[t.groupClaimName]; exists {
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
@@ -1181,7 +1366,6 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract roles using configurable claim name (defaults to "roles")
|
||||
if rolesClaim, exists := claims[t.roleClaimName]; exists {
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
|
||||
@@ -95,6 +95,7 @@ type TraefikOidc struct {
|
||||
cancelFunc context.CancelFunc
|
||||
errorRecoveryManager *ErrorRecoveryManager
|
||||
tokenResilienceManager *TokenResilienceManager
|
||||
refreshCoordinator *RefreshCoordinator
|
||||
goroutineWG *sync.WaitGroup
|
||||
dcrConfig *DynamicClientRegistrationConfig
|
||||
dynamicClientRegistrar *DynamicClientRegistrar
|
||||
@@ -118,17 +119,21 @@ type TraefikOidc struct {
|
||||
audience string
|
||||
clientID string
|
||||
clientSecret string
|
||||
clientAuthMethod string
|
||||
clientAssertion *ClientAssertionSigner
|
||||
registrationURL string
|
||||
backchannelLogoutPath string
|
||||
frontchannelLogoutPath string
|
||||
scopesSupported []string
|
||||
scopes []string
|
||||
refreshGracePeriod time.Duration
|
||||
maxRefreshTokenAge time.Duration
|
||||
metadataMu sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
metadataRetryMutex sync.Mutex
|
||||
firstRequestMutex sync.Mutex
|
||||
sessionInvalidationCache CacheInterface
|
||||
refreshResultCache CacheInterface
|
||||
minimalHeaders bool
|
||||
stripAuthCookies bool
|
||||
enableBackchannelLogout bool
|
||||
|
||||
@@ -252,6 +252,25 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
|
||||
}
|
||||
}
|
||||
|
||||
return c.setLocal(key, value, ttl)
|
||||
}
|
||||
|
||||
// SetLocal stores a value only in the in-memory LRU, bypassing any
|
||||
// distributed backend. Use for values that don't survive JSON round-tripping
|
||||
// — interfaces holding concrete crypto keys, *big.Int, or types whose
|
||||
// unexported fields yaegi exposes under an X prefix on Marshal. Each replica
|
||||
// caches independently; correctness must not depend on cross-replica
|
||||
// coherence for these keys.
|
||||
func (c *UniversalCache) SetLocal(key string, value interface{}, ttl time.Duration) error {
|
||||
if ttl == 0 {
|
||||
ttl = c.config.DefaultTTL
|
||||
}
|
||||
return c.setLocal(key, value, ttl)
|
||||
}
|
||||
|
||||
// setLocal performs the in-memory portion of a write. ttl must already be
|
||||
// resolved against DefaultTTL by the caller.
|
||||
func (c *UniversalCache) setLocal(key string, value interface{}, ttl time.Duration) error {
|
||||
size := c.estimateSize(value)
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -343,6 +362,19 @@ func (c *UniversalCache) Get(key string) (interface{}, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
return c.getLocal(key)
|
||||
}
|
||||
|
||||
// GetLocal retrieves a value only from the in-memory LRU, never querying the
|
||||
// distributed backend. Pair with SetLocal for values that aren't safe to
|
||||
// serialize (see SetLocal docstring).
|
||||
func (c *UniversalCache) GetLocal(key string) (interface{}, bool) {
|
||||
return c.getLocal(key)
|
||||
}
|
||||
|
||||
// getLocal returns the in-memory entry for key honoring expiry, grace
|
||||
// periods, and the RLock fast path used by token/JWK/session caches.
|
||||
func (c *UniversalCache) getLocal(key string) (interface{}, bool) {
|
||||
// Fast read path for caches whose eviction is dominated by TTL rather than
|
||||
// access-recency (token, JWK, session). Holding only an RLock here lets all
|
||||
// concurrent readers verify cached tokens in parallel — under yaegi the
|
||||
|
||||
@@ -23,6 +23,7 @@ type UniversalCacheManager struct {
|
||||
metadataCache *UniversalCache
|
||||
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
|
||||
sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout
|
||||
refreshResultCache *UniversalCache // Short-lived cross-replica refresh-result dedup (paired with RefreshCoordinator)
|
||||
logger *Logger
|
||||
blacklistCache *UniversalCache
|
||||
cancel context.CancelFunc
|
||||
@@ -181,6 +182,18 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
|
||||
// Refresh-result cache: short-lived store keyed by sha256(refreshToken).
|
||||
// In Redis-backed mode this gives cross-replica dedup of refresh grants;
|
||||
// in memory-only mode it's effectively redundant with RefreshCoordinator
|
||||
// but safe and cheap to keep.
|
||||
manager.refreshResultCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 5 * time.Second,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
}
|
||||
|
||||
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
|
||||
@@ -197,6 +210,8 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
RedisPrefix: redisConfig.KeyPrefix,
|
||||
PoolSize: redisConfig.PoolSize,
|
||||
EnableMetrics: true,
|
||||
EnableTLS: redisConfig.EnableTLS,
|
||||
TLSSkipVerify: redisConfig.TLSSkipVerify,
|
||||
}
|
||||
|
||||
// Use concrete type to avoid Yaegi reflection issues with interface assignment
|
||||
@@ -387,6 +402,21 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
createBackend("session_invalidation"),
|
||||
)
|
||||
|
||||
// Refresh-result cache - shared via Redis so concurrent refreshes across
|
||||
// Traefik replicas can dedup their grants. The 5s TTL is long enough for
|
||||
// peers to observe a recent refresh and short enough that a stale entry
|
||||
// can't be replayed against a now-rotated refresh token.
|
||||
manager.refreshResultCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 5 * time.Second,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
},
|
||||
createBackend("refresh_result"),
|
||||
)
|
||||
|
||||
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
|
||||
}
|
||||
|
||||
@@ -436,6 +466,7 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
|
||||
m.tokenTypeCache,
|
||||
m.dcrCredentialsCache,
|
||||
m.sessionInvalidationCache,
|
||||
m.refreshResultCache,
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
@@ -498,6 +529,14 @@ func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache {
|
||||
return m.sessionInvalidationCache
|
||||
}
|
||||
|
||||
// GetRefreshResultCache returns the short-lived refresh-result cache used to
|
||||
// coalesce refresh-token grants across Traefik replicas.
|
||||
func (m *UniversalCacheManager) GetRefreshResultCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.refreshResultCache
|
||||
}
|
||||
|
||||
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
|
||||
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
@@ -520,7 +559,7 @@ func (m *UniversalCacheManager) Close() error {
|
||||
|
||||
// Close all caches first (they won't close the shared backend)
|
||||
for _, cache := range []*UniversalCache{
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache,
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, m.refreshResultCache,
|
||||
} {
|
||||
if cache != nil {
|
||||
_ = cache.Close() // Safe to ignore: best effort cache cleanup
|
||||
|
||||
@@ -250,6 +250,11 @@ func (t *TraefikOidc) Close() error {
|
||||
t.safeLogDebug("metadataRefreshStopChan closed")
|
||||
}
|
||||
|
||||
if t.refreshCoordinator != nil {
|
||||
t.refreshCoordinator.Shutdown()
|
||||
t.safeLogDebug("refreshCoordinator shut down")
|
||||
}
|
||||
|
||||
if t.goroutineWG != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestVerifyToken_CacheHitSkipsParse proves the hot-path optimization: when a
|
||||
// token is in the cache, VerifyToken returns nil without calling parseJWT.
|
||||
// We construct a token that PASSES the cheap format checks (3 segments, len
|
||||
// >= 10) but whose body is unparseable JSON. With the cache hit hoisted ahead
|
||||
// of parseJWT, the function returns nil. Without the hoist, parseJWT would
|
||||
// fail with "failed to parse JWT for blacklist check".
|
||||
func TestVerifyToken_CacheHitSkipsParse(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenCache: NewTokenCache(),
|
||||
// limiter intentionally absent; if we reached the rate-limit check
|
||||
// the test would NPE - this is a stronger assertion that we exit
|
||||
// before that point.
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
tr.tokenVerifier = tr
|
||||
|
||||
// Three segments separated by '.', body is junk after base64-decode + JSON.
|
||||
// Pre-fix this fails parseJWT; post-fix it returns nil because the cache
|
||||
// short-circuits.
|
||||
junkToken := "header.bm90LWpzb24.signature" // base64(not-json) in the middle
|
||||
tr.tokenCache.Set(junkToken, map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
"sub": "test",
|
||||
}, time.Hour)
|
||||
|
||||
if err := tr.VerifyToken(junkToken); err != nil {
|
||||
t.Fatalf("expected cache-hit fast path to return nil, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyToken_CacheMissStillParses ensures we did not skip too aggressively
|
||||
// - on a cache miss, the function must still parse and reach the rate-limit
|
||||
// check. We assert by passing a syntactically valid token whose signature
|
||||
// won't verify, expecting an error from later in the pipeline.
|
||||
func TestVerifyToken_CacheMissStillParses(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
// no tokenBlacklist, no jwkCache - the function will fail somewhere
|
||||
// after parseJWT. We just need a non-nil error to confirm we did
|
||||
// progress past the cache check.
|
||||
}
|
||||
tr.tokenVerifier = tr
|
||||
|
||||
// Real JWT structure but unsigned/unverifiable.
|
||||
rawToken := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature"
|
||||
|
||||
if err := tr.VerifyToken(rawToken); err == nil {
|
||||
t.Fatal("expected an error past parseJWT for an unsigned token, got nil")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user