mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8c5df82dcf | |||
| aa96e9dbee | |||
| 1e33bb0a4d | |||
| bfd702a447 | |||
| 68c150eba4 | |||
| 9cbca4c4fb |
@@ -0,0 +1,15 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: lukaszraczylo
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
polar: # Replace with a single Polar username
|
||||
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
||||
thanks_dev: # Replace with a single thanks.dev username
|
||||
custom: https://monzo.me/lukaszraczylo
|
||||
@@ -23,6 +23,19 @@ testData:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: 1234567890.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
# Alternative: RFC 7523 private_key_jwt client authentication (Entra ID,
|
||||
# Okta, Auth0, Keycloak). Replaces clientSecret with a signed JWT assertion.
|
||||
# See README "Client authentication via private key JWT".
|
||||
# clientAuthMethod: private_key_jwt
|
||||
# clientAssertionKeyID: my-key-2026
|
||||
# clientAssertionAlg: RS256 # default; or PS256/384/512, ES256/384/512
|
||||
# # File path option:
|
||||
# clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
|
||||
# # Or inline PEM (PKCS#8 / PKCS#1 / SEC1):
|
||||
# clientAssertionPrivateKey: |
|
||||
# -----BEGIN PRIVATE KEY-----
|
||||
# MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDexampleexample
|
||||
# -----END PRIVATE KEY-----
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ More example configs in [`examples/`](examples/).
|
||||
|-----------|-------------|
|
||||
| `providerURL` | Issuer URL (used for OIDC discovery). |
|
||||
| `clientID` | OAuth 2.0 client ID. |
|
||||
| `clientSecret` | OAuth 2.0 client secret. Supports `urn:k8s:secret:ns:name:key`. |
|
||||
| `clientSecret` | OAuth 2.0 client secret. Supports `urn:k8s:secret:ns:name:key`. Required when `clientAuthMethod` is unset, `client_secret_post`, or `client_secret_basic`; optional with `private_key_jwt`. |
|
||||
| `sessionEncryptionKey` | Cookie encryption key, **min 32 bytes**. |
|
||||
| `callbackURL` | Callback path, e.g. `/oauth2/callback`. |
|
||||
|
||||
@@ -121,6 +121,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
| `cookiePrefix` | `_oidc_raczylo_` | Unique prefix per middleware instance to isolate sessions. |
|
||||
| `sessionMaxAge` | `86400` | Session lifetime in seconds. |
|
||||
| `refreshGracePeriodSeconds` | `60` | Proactively refresh tokens this many seconds before expiry. |
|
||||
| `maxRefreshTokenAgeSeconds` | `21600` | Heuristic max stored refresh-token lifetime (6h). Past this, the plugin treats the RT as expired without contacting the IdP — returns 401 to AJAX, full re-auth on navigations. Set `0` to disable. Tune to match your IdP's RT TTL. |
|
||||
| `rateLimit` | `100` | Requests/sec. Min `10`. |
|
||||
| `logLevel` | `info` | `debug`, `info`, `error`. |
|
||||
| `audience` | `clientID` | Custom access-token audience (Auth0 custom APIs). |
|
||||
@@ -132,6 +133,11 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
| `stripAuthCookies` | `false` | Strip OIDC cookies from backend hop (mitigates HTTP 431). |
|
||||
| `caCertPath` / `caCertPEM` | none | Trust an internal CA for the provider's TLS. |
|
||||
| `insecureSkipVerify` | `false` | **Local dev only.** Disables TLS verification, logs a security warning. |
|
||||
| `clientAuthMethod` | `client_secret_post` | Client auth method. Set `private_key_jwt` for RFC 7523 JWT assertions (Entra ID, Okta, Auth0, Keycloak). See [Client authentication via private key JWT](#client-authentication-via-private-key-jwt). |
|
||||
| `clientAssertionPrivateKey` | none | Inline PEM private key for `private_key_jwt`. Mutually exclusive with `clientAssertionKeyPath`. |
|
||||
| `clientAssertionKeyPath` | none | File path to PEM private key for `private_key_jwt`. |
|
||||
| `clientAssertionKeyID` | none | JWS `kid` header. Required when `clientAuthMethod=private_key_jwt`; must match the public key registered with the IdP. |
|
||||
| `clientAssertionAlg` | `RS256` | JWS alg for `private_key_jwt`. Supported: `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. |
|
||||
| `enableBackchannelLogout` / `backchannelLogoutURL` | `false` / none | OIDC Back-Channel Logout (server-to-server). |
|
||||
| `enableFrontchannelLogout` / `frontchannelLogoutURL` | `false` / none | OIDC Front-Channel Logout (iframe). |
|
||||
| `redis` | disabled | See [docs/REDIS.md](docs/REDIS.md). |
|
||||
@@ -165,6 +171,22 @@ Each instance must use a unique `cookiePrefix` **and** `sessionEncryptionKey`,
|
||||
otherwise a session minted by one instance can grant access through another.
|
||||
See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87).
|
||||
|
||||
### SSE and WebSocket endpoints
|
||||
|
||||
Browser clients cannot follow an OIDC `302` redirect on an SSE stream or a
|
||||
WebSocket upgrade. The middleware handles this automatically:
|
||||
|
||||
- **SSE** (`Accept: text/event-stream`) and **WebSocket** (`Upgrade: websocket`)
|
||||
requests skip the OIDC redirect.
|
||||
- They are **not** unauthenticated — a valid encrypted session cookie is
|
||||
required, otherwise the request is rejected. The session must already exist
|
||||
(i.e. the user logged in via a normal HTTP page first).
|
||||
- `X-Forwarded-User` is forwarded from the session.
|
||||
- Validation is cookie-only (no JWK fetch), so streaming keeps working during
|
||||
brief IdP outages.
|
||||
|
||||
No configuration needed — this is implicit behavior.
|
||||
|
||||
### HTTP 431 from backends
|
||||
|
||||
Either the ID token or the chunked OIDC cookies overflow your backend's header
|
||||
@@ -196,6 +218,44 @@ caCertPEM: |
|
||||
Both can be combined. An unparseable bundle fails the plugin at startup.
|
||||
See [#125](https://github.com/lukaszraczylo/traefikoidc/issues/125).
|
||||
|
||||
### Client authentication via private key JWT
|
||||
|
||||
Use when your IdP enforces short-lived secrets or pushes secretless client auth
|
||||
— Microsoft Entra ID / Azure AD, Okta, Auth0, Keycloak. Instead of sending a
|
||||
static `clientSecret`, the plugin signs a short-lived JWT and submits it as
|
||||
`client_assertion` per [RFC 7523](https://www.rfc-editor.org/rfc/rfc7523).
|
||||
|
||||
Minimal config:
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
|
||||
clientAssertionKeyID: my-key-2026
|
||||
# clientAssertionAlg: RS256 # default; or PS256/384/512, ES256/384/512
|
||||
```
|
||||
|
||||
Or inline:
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionPrivateKey: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
...
|
||||
-----END PRIVATE KEY-----
|
||||
clientAssertionKeyID: my-key-2026
|
||||
```
|
||||
|
||||
Accepted PEM forms: PKCS#8 (`PRIVATE KEY`), PKCS#1 (`RSA PRIVATE KEY`), SEC1
|
||||
(`EC PRIVATE KEY`). The assertion uses `iss=sub=clientID`, `aud=tokenURL`, 60s
|
||||
lifetime, random hex `jti` per request. Sent on `/token` (auth-code + refresh)
|
||||
and `/revoke`. The `kid` must match the public key registered with the IdP.
|
||||
|
||||
`clientSecret` becomes optional with `private_key_jwt`. Existing
|
||||
`client_secret_post` setups are unaffected. Keys are parsed once at startup —
|
||||
rotation requires a Traefik reload.
|
||||
|
||||
See [issue #135](https://github.com/lukaszraczylo/traefikoidc/issues/135).
|
||||
|
||||
### Environment variable names containing `API`
|
||||
|
||||
Traefik reserves `TRAEFIK_API_*`. User vars whose name contains `API` (e.g.
|
||||
|
||||
+1
-1
@@ -1491,7 +1491,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetUserIdentifier("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
|
||||
+3
-3
@@ -43,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
|
||||
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
|
||||
// Clear all existing session data
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
@@ -250,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
@@ -290,7 +290,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens to prevent replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
|
||||
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Pre-populate session with old data
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-access-token-with-many-characters")
|
||||
session.SetRefreshToken("old-refresh-token-with-many-characters")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Verify old data is cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
|
||||
// Verify new data is set
|
||||
s.Equal(csrfToken, session.GetCSRF())
|
||||
@@ -711,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
session, err := sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
session.mainSession.Values["redirect_count"] = 3
|
||||
|
||||
@@ -720,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
|
||||
// Session should be cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
s.Empty(session.GetIDToken())
|
||||
|
||||
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// isSupportedClientAssertionAlg reports whether alg is a recognized JWS
|
||||
// algorithm for private_key_jwt (RFC 7523 §2.2).
|
||||
func isSupportedClientAssertionAlg(alg string) bool {
|
||||
switch alg {
|
||||
case "RS256", "RS384", "RS512",
|
||||
"PS256", "PS384", "PS512",
|
||||
"ES256", "ES384", "ES512":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ClientAssertionSigner builds and signs client_assertion JWTs (RFC 7523 §2.2).
|
||||
type ClientAssertionSigner struct {
|
||||
key crypto.PrivateKey
|
||||
alg string
|
||||
kid string
|
||||
// rand is the entropy source for jti generation and PSS/ECDSA signing.
|
||||
// Defaults to crypto/rand.Reader when nil.
|
||||
rand io.Reader
|
||||
// now returns the current time. Defaults to time.Now when nil.
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// NewClientAssertionSigner parses pemBytes as a private key, validates that
|
||||
// alg is consistent with the key type, and returns a ready-to-use signer.
|
||||
// kid is placed verbatim in the JWS header.
|
||||
//
|
||||
// PEM block types understood:
|
||||
// - "PRIVATE KEY" → PKCS#8 (tried first for all types)
|
||||
// - "RSA PRIVATE KEY" → PKCS#1
|
||||
// - "EC PRIVATE KEY" → SEC1
|
||||
func NewClientAssertionSigner(pemBytes []byte, alg, kid string) (*ClientAssertionSigner, error) {
|
||||
if !isSupportedClientAssertionAlg(alg) {
|
||||
return nil, fmt.Errorf("unsupported client assertion alg %q", alg)
|
||||
}
|
||||
if kid == "" {
|
||||
return nil, fmt.Errorf("kid must not be empty")
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM block found in private key material")
|
||||
}
|
||||
|
||||
var key crypto.PrivateKey
|
||||
var parseErr error
|
||||
|
||||
switch block.Type {
|
||||
case "PRIVATE KEY":
|
||||
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
case "RSA PRIVATE KEY":
|
||||
key, parseErr = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "EC PRIVATE KEY":
|
||||
key, parseErr = x509.ParseECPrivateKey(block.Bytes)
|
||||
default:
|
||||
// Best-effort fallback for unknown block types.
|
||||
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
}
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key (block type %q): %w", block.Type, parseErr)
|
||||
}
|
||||
|
||||
if err := validateAlgKeyMatch(alg, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ClientAssertionSigner{key: key, alg: alg, kid: kid}, nil
|
||||
}
|
||||
|
||||
// validateAlgKeyMatch returns an error when alg implies a key type that does
|
||||
// not match the actual key.
|
||||
func validateAlgKeyMatch(alg string, key crypto.PrivateKey) error {
|
||||
switch alg[0] {
|
||||
case 'R', 'P': // RS* or PS*
|
||||
if _, ok := key.(*rsa.PrivateKey); !ok {
|
||||
return fmt.Errorf("alg %q requires an RSA key, got %T", alg, key)
|
||||
}
|
||||
case 'E': // ES*
|
||||
if _, ok := key.(*ecdsa.PrivateKey); !ok {
|
||||
return fmt.Errorf("alg %q requires an EC key, got %T", alg, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign constructs and returns a signed client_assertion JWT.
|
||||
// audience is typically the token endpoint URL (RFC 7523 §3).
|
||||
// clientID is used as both iss and sub per RFC 7523 §2.2.
|
||||
func (s *ClientAssertionSigner) Sign(audience, clientID string) (string, error) {
|
||||
rander := s.rand
|
||||
if rander == nil {
|
||||
rander = rand.Reader
|
||||
}
|
||||
nowFn := s.now
|
||||
if nowFn == nil {
|
||||
nowFn = time.Now
|
||||
}
|
||||
|
||||
now := nowFn()
|
||||
|
||||
// 16 random bytes as lowercase hex for jti uniqueness.
|
||||
jtiBytes := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rander, jtiBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate jti: %w", err)
|
||||
}
|
||||
jti := hex.EncodeToString(jtiBytes)
|
||||
|
||||
header := map[string]string{
|
||||
"alg": s.alg,
|
||||
"typ": "JWT",
|
||||
"kid": s.kid,
|
||||
}
|
||||
hdrJSON, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal JWT header: %w", err)
|
||||
}
|
||||
|
||||
claims := map[string]any{
|
||||
"iss": clientID,
|
||||
"sub": clientID,
|
||||
"aud": audience,
|
||||
"jti": jti,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(60 * time.Second).Unix(),
|
||||
}
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal JWT claims: %w", err)
|
||||
}
|
||||
|
||||
hdrB64 := base64.RawURLEncoding.EncodeToString(hdrJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signingInput := hdrB64 + "." + claimsB64
|
||||
|
||||
sig, err := s.sign(rander, []byte(signingInput))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signingInput + "." + base64.RawURLEncoding.EncodeToString(sig), nil
|
||||
}
|
||||
|
||||
// sign computes raw signature bytes for signingInput per s.alg.
|
||||
// validateAlgKeyMatch in NewClientAssertionSigner guarantees the key type
|
||||
// matches s.alg, but the comma-ok asserts here keep errcheck happy and
|
||||
// surface internal misuse loudly instead of via panic.
|
||||
func (s *ClientAssertionSigner) sign(rander io.Reader, input []byte) ([]byte, error) {
|
||||
switch s.alg {
|
||||
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
|
||||
rsaKey, ok := s.key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal: alg %q requires *rsa.PrivateKey, got %T", s.alg, s.key)
|
||||
}
|
||||
hash := rsaHashForAlg(s.alg)
|
||||
digest := hashSum(hash, input)
|
||||
if s.alg[0] == 'R' {
|
||||
return signRSAPKCS1v15(rander, rsaKey, hash, digest)
|
||||
}
|
||||
return signRSAPSS(rander, rsaKey, hash, digest)
|
||||
case "ES256", "ES384", "ES512":
|
||||
ecKey, ok := s.key.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal: alg %q requires *ecdsa.PrivateKey, got %T", s.alg, s.key)
|
||||
}
|
||||
hash := ecHashForAlg(s.alg)
|
||||
digest := hashSum(hash, input)
|
||||
return signECDSA(rander, ecKey, digest)
|
||||
}
|
||||
return nil, fmt.Errorf("unhandled alg %q", s.alg)
|
||||
}
|
||||
|
||||
func rsaHashForAlg(alg string) crypto.Hash {
|
||||
switch alg {
|
||||
case "RS256", "PS256":
|
||||
return crypto.SHA256
|
||||
case "RS384", "PS384":
|
||||
return crypto.SHA384
|
||||
case "RS512", "PS512":
|
||||
return crypto.SHA512
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func ecHashForAlg(alg string) crypto.Hash {
|
||||
switch alg {
|
||||
case "ES256":
|
||||
return crypto.SHA256
|
||||
case "ES384":
|
||||
return crypto.SHA384
|
||||
case "ES512":
|
||||
return crypto.SHA512
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func hashSum(h crypto.Hash, input []byte) []byte {
|
||||
switch h {
|
||||
case crypto.SHA256:
|
||||
sum := sha256.Sum256(input)
|
||||
return sum[:]
|
||||
case crypto.SHA384:
|
||||
sum := sha512.Sum384(input)
|
||||
return sum[:]
|
||||
case crypto.SHA512:
|
||||
sum := sha512.Sum512(input)
|
||||
return sum[:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func signRSAPKCS1v15(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
|
||||
sig, err := rsa.SignPKCS1v15(rander, key, hash, digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("RSA PKCS1v15 signing failed: %w", err)
|
||||
}
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
func signRSAPSS(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
|
||||
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hash}
|
||||
sig, err := rsa.SignPSS(rander, key, hash, digest, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("RSA PSS signing failed: %w", err)
|
||||
}
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// signECDSA produces the JWS raw r||s signature (RFC 7515 App. A.3).
|
||||
// Each scalar is zero-padded to (curve.BitSize+7)/8 bytes.
|
||||
func signECDSA(rander io.Reader, key *ecdsa.PrivateKey, digest []byte) ([]byte, error) {
|
||||
r, ss, err := ecdsa.Sign(rander, key, digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ECDSA signing failed: %w", err)
|
||||
}
|
||||
byteLen := (key.Curve.Params().BitSize + 7) / 8
|
||||
sig := make([]byte, 2*byteLen)
|
||||
padBigInt(sig[0:byteLen], r)
|
||||
padBigInt(sig[byteLen:], ss)
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// padBigInt writes n as a fixed-width big-endian integer into buf.
|
||||
func padBigInt(buf []byte, n *big.Int) {
|
||||
b := n.Bytes()
|
||||
copy(buf[len(buf)-len(b):], b)
|
||||
}
|
||||
|
||||
// buildClientAssertionSignerFromConfig loads key material and constructs a
|
||||
// ClientAssertionSigner. Called from NewWithContext when
|
||||
// ClientAuthMethod == "private_key_jwt".
|
||||
func buildClientAssertionSignerFromConfig(config *Config) (*ClientAssertionSigner, error) {
|
||||
var pemBytes []byte
|
||||
|
||||
if config.ClientAssertionPrivateKey != "" {
|
||||
pemBytes = []byte(config.ClientAssertionPrivateKey)
|
||||
} else {
|
||||
data, err := os.ReadFile(config.ClientAssertionKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read clientAssertionKeyPath %q: %w", config.ClientAssertionKeyPath, err)
|
||||
}
|
||||
pemBytes = data
|
||||
}
|
||||
|
||||
alg := config.ClientAssertionAlg
|
||||
if alg == "" {
|
||||
alg = "RS256"
|
||||
}
|
||||
|
||||
return NewClientAssertionSigner(pemBytes, alg, config.ClientAssertionKeyID)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("old-access-token")
|
||||
session.SetRefreshToken("old-refresh-token")
|
||||
session.SetIDToken("old-id-token")
|
||||
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Now perform selective clearing (as done in the fix)
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
|
||||
// Set initial session data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-token")
|
||||
session.SetCSRF("existing-csrf")
|
||||
|
||||
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
|
||||
// NEW BEHAVIOR: Selective clearing
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
|
||||
+158
-1
@@ -5,6 +5,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
|
||||
## Table of Contents
|
||||
|
||||
- [Required Parameters](#required-parameters)
|
||||
- [Client Authentication](#client-authentication)
|
||||
- [Optional Parameters](#optional-parameters)
|
||||
- [Security Options](#security-options)
|
||||
- [Session Management](#session-management)
|
||||
@@ -22,7 +23,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
|
||||
|-----------|------|-------------|---------|
|
||||
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
|
||||
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
|
||||
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
|
||||
| `clientSecret` | string | OAuth 2.0 client secret. Required when `clientAuthMethod` is unset, `client_secret_post`, or `client_secret_basic`. Optional when `clientAuthMethod: private_key_jwt`. | `your-client-secret` |
|
||||
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
|
||||
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
|
||||
|
||||
@@ -45,6 +46,129 @@ spec:
|
||||
|
||||
---
|
||||
|
||||
## Client Authentication
|
||||
|
||||
The middleware supports three client authentication methods at the token and
|
||||
revocation endpoints. The default is `client_secret_post` (current behavior);
|
||||
`private_key_jwt` is opt-in and backwards compatible.
|
||||
|
||||
| Method | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `client_secret_post` | yes | `client_id` + `client_secret` in the request body. |
|
||||
| `client_secret_basic` | no | RFC 6749 §2.3.1 — `client_id` + `client_secret` in the `Authorization: Basic` header (form-urlencoded then base64); not in the body. |
|
||||
| `private_key_jwt` | no | RFC 7523 §2.2 — plugin signs a short-lived JWT with a private key and sends it as `client_assertion`. |
|
||||
|
||||
Select via `clientAuthMethod`:
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
```
|
||||
|
||||
### client_secret_post
|
||||
|
||||
Default. The plugin sends `client_id` and `client_secret` as form parameters
|
||||
in the token / revocation request body. No additional configuration required.
|
||||
|
||||
### private_key_jwt
|
||||
|
||||
Asymmetric client authentication per
|
||||
[RFC 7523 §2.2](https://www.rfc-editor.org/rfc/rfc7523). Use this when your
|
||||
IdP enforces short secret TTLs, when policy mandates secretless clients, or
|
||||
when you want to avoid distributing a shared secret to the proxy.
|
||||
|
||||
For each token / revocation request the plugin builds a JWS with:
|
||||
|
||||
- `iss` = `sub` = `clientID`
|
||||
- `aud` = token endpoint URL
|
||||
- `iat` = now, `exp` = now + 60s
|
||||
- `jti` = random hex per request
|
||||
- `kid` header = `clientAssertionKeyID`
|
||||
|
||||
**Required fields:**
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `clientAuthMethod` | string | `client_secret_post` | Set to `private_key_jwt`. |
|
||||
| `clientAssertionPrivateKey` | string | none | Inline PEM private key. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8, PKCS#1, and SEC1 formats accepted. |
|
||||
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk. Mutually exclusive with `clientAssertionPrivateKey`. |
|
||||
| `clientAssertionKeyID` | string | none | `kid` header inserted in the JWS. Must match the public key registered with the IdP. |
|
||||
| `clientAssertionAlg` | string | `RS256` | One of `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`, `ES256`, `ES384`, `ES512`. |
|
||||
|
||||
When `clientAuthMethod: private_key_jwt`, `clientSecret` is optional.
|
||||
|
||||
**Example — inline PEM:**
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://idp.example.com
|
||||
clientID: my-client-id
|
||||
sessionEncryptionKey: your-32-byte-encryption-key-here
|
||||
callbackURL: /oauth2/callback
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyID: key-2026-01
|
||||
clientAssertionAlg: RS256
|
||||
clientAssertionPrivateKey: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7VJTUt9Us8cKj
|
||||
MZj4ev7QnMa1mYV3Kx1jRkH5YwXQ7N2J2j8K5pP6h0oZmXq1yQv4r8wZb3sH9D2k
|
||||
... (truncated) ...
|
||||
-----END PRIVATE KEY-----
|
||||
```
|
||||
|
||||
**Example — key on disk:**
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
|
||||
clientAssertionKeyID: key-2026-01
|
||||
clientAssertionAlg: RS256
|
||||
```
|
||||
|
||||
**Generating an RS256 key with OpenSSL:**
|
||||
|
||||
```bash
|
||||
openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 \
|
||||
-out client-key.pem
|
||||
openssl rsa -in client-key.pem -pubout -out client-pub.pem
|
||||
```
|
||||
|
||||
Register `client-pub.pem` (or its JWK form) with your IdP under the same
|
||||
`kid` you set in `clientAssertionKeyID`.
|
||||
|
||||
**Notes:**
|
||||
|
||||
- The private key is parsed once at plugin startup. Key rotation requires a
|
||||
Traefik reload.
|
||||
- Assertion lifetime is fixed at 60 seconds.
|
||||
- A fresh random `jti` is generated per request.
|
||||
- The `aud` claim is the token endpoint URL (from discovery).
|
||||
- Tracking issue:
|
||||
[#135](https://github.com/lukaszraczylo/traefikoidc/issues/135).
|
||||
|
||||
### client_secret_basic
|
||||
|
||||
Per [RFC 6749 §2.3.1][rfc6749-2-3-1], the plugin sends the client credentials
|
||||
in an `Authorization: Basic` header instead of the body. Both halves
|
||||
(`client_id`, `client_secret`) are form-urlencoded individually, joined with
|
||||
a colon, then base64-encoded. Use this when your IdP requires Basic auth at
|
||||
the token endpoint and rejects credentials in the body.
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: client_secret_basic
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
```
|
||||
|
||||
[rfc6749-2-3-1]: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
|
||||
|
||||
---
|
||||
|
||||
## Optional Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
@@ -59,6 +183,11 @@ spec:
|
||||
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
|
||||
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
|
||||
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
|
||||
| `clientAuthMethod` | string | `client_secret_post` | Client authentication method at token/revocation endpoints. One of `client_secret_post`, `client_secret_basic`, `private_key_jwt`. See [Client Authentication](#client-authentication). |
|
||||
| `clientAssertionPrivateKey` | string | none | Inline PEM private key for `private_key_jwt`. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8 / PKCS#1 / SEC1. |
|
||||
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk for `private_key_jwt`. Mutually exclusive with `clientAssertionPrivateKey`. |
|
||||
| `clientAssertionKeyID` | string | none | `kid` header for `private_key_jwt` assertions. Required when `clientAuthMethod: private_key_jwt`. |
|
||||
| `clientAssertionAlg` | string | `RS256` | Signing algorithm for `private_key_jwt`. One of `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. |
|
||||
|
||||
### TLS Termination at Load Balancer
|
||||
|
||||
@@ -70,6 +199,33 @@ overwrite it).
|
||||
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
|
||||
dev). Otherwise leave it at default.
|
||||
|
||||
### Streaming Endpoints (SSE and WebSocket)
|
||||
|
||||
The middleware automatically bypasses the OIDC redirect for two request kinds
|
||||
that browsers cannot follow a 302 on:
|
||||
|
||||
| Bypass | Triggered by |
|
||||
|--------|--------------|
|
||||
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
|
||||
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
|
||||
|
||||
These requests do **not** require any explicit configuration — they are
|
||||
handled implicitly. However, the bypass is **not** unauthenticated:
|
||||
|
||||
- A valid, encrypted session cookie is required. Requests without one are
|
||||
rejected (the connection cannot proceed to the backend).
|
||||
- The session cookie is sealed with `sessionEncryptionKey`, so the
|
||||
`authenticated` flag cannot be forged.
|
||||
- Validation is cookie-only — no JWK fetch / signature verification — so
|
||||
streaming endpoints keep working when the OIDC provider is briefly
|
||||
unavailable.
|
||||
- The user identifier from the session is forwarded as `X-Forwarded-User`
|
||||
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
|
||||
|
||||
For browser clients, the user must complete the normal OIDC flow on a
|
||||
regular HTTP page first; the resulting session cookie is then reused on the
|
||||
SSE / WebSocket connection.
|
||||
|
||||
---
|
||||
|
||||
## Security Options
|
||||
@@ -113,6 +269,7 @@ strictAudienceValidation: true
|
||||
|-----------|------|---------|-------------|
|
||||
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
|
||||
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
|
||||
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
|
||||
| `cookieDomain` | string | auto-detected | Domain for session cookies |
|
||||
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
|
||||
|
||||
|
||||
+46
-3
@@ -642,7 +642,7 @@ spec:
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientSecret</code></td>
|
||||
<td class="py-2 px-3">OAuth 2.0 client secret</td>
|
||||
<td class="py-2 px-3">OAuth 2.0 client secret. Only required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is unset or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">sessionEncryptionKey</code></td>
|
||||
@@ -718,6 +718,11 @@ spec:
|
||||
<td class="py-2 px-3">86400</td>
|
||||
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
|
||||
<td class="py-2 px-3">21600</td>
|
||||
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
|
||||
<td class="py-2 px-3">_oidc_raczylo_</td>
|
||||
@@ -748,15 +753,48 @@ spec:
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Require RFC 7662 introspection for opaque tokens</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">disableReplayDetection</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Disable JTI replay detection (for multi-replica without Redis)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code></td>
|
||||
<td class="py-2 px-3">client_secret_post</td>
|
||||
<td class="py-2 px-3">Selects how the plugin authenticates to the token endpoint. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code></td>
|
||||
<td class="py-2 px-3">none</td>
|
||||
<td class="py-2 px-3">Inline PEM private key used to sign client assertions for <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyPath</code></td>
|
||||
<td class="py-2 px-3">none</td>
|
||||
<td class="py-2 px-3">Path to a PEM private key file. Alternative to <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyID</code></td>
|
||||
<td class="py-2 px-3">none</td>
|
||||
<td class="py-2 px-3">JWS <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">kid</code> header value. Required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionAlg</code></td>
|
||||
<td class="py-2 px-3">RS256</td>
|
||||
<td class="py-2 px-3">Signing algorithm for the client assertion. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES512</code>.</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Private Key JWT (RFC 7523)</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Use this when your IdP (Entra ID, Okta, Auth0, Keycloak) pressures short-lived secrets, or when policy mandates secretless service-to-service authentication. The plugin signs a 60-second assertion with the configured private key and sends it as <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_assertion</code> instead of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret</code>. Public-key registration on the IdP replaces shared-secret rotation. See <a href="https://www.rfc-editor.org/rfc/rfc7523" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">RFC 7523</a> and <a href="https://github.com/lukaszraczylo/traefikoidc/issues/135" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">issue #135</a>.</p>
|
||||
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyPath: /etc/traefik/oidc-client.pem
|
||||
clientAssertionKeyID: my-client-key-2026
|
||||
# clientSecret no longer required</code></pre>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Example: Google Workspace with Domain Restriction</h3>
|
||||
|
||||
@@ -858,7 +896,12 @@ spec:
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
@@ -101,6 +101,16 @@ http:
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Optional: switch to RFC 7523 private_key_jwt client auth
|
||||
# (Entra ID, Okta, Auth0, Keycloak). Replaces clientSecret with a
|
||||
# signed JWT assertion. See README for details and PEM formats.
|
||||
# ----------------------------------------------------------------
|
||||
# clientAuthMethod: "private_key_jwt"
|
||||
# clientAssertionKeyPath: "/etc/traefik/oidc/client-key.pem"
|
||||
# clientAssertionKeyID: "prod-key-2026"
|
||||
# clientAssertionAlg: "RS256" # or PS256/384/512, ES256/384/512
|
||||
|
||||
# Session Configuration
|
||||
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
|
||||
sessionMaxAge: 28800 # 8 hours
|
||||
|
||||
+37
-4
@@ -107,9 +107,12 @@ type TokenResponse struct {
|
||||
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant)
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
"client_id": {t.clientID},
|
||||
"client_secret": {t.clientSecret},
|
||||
"grant_type": {grantType},
|
||||
}
|
||||
// client_id is sent in the body for every method except client_secret_basic,
|
||||
// where it is carried in the Authorization header per RFC 6749 §2.3.1.
|
||||
if t.clientAuthMethod != "client_secret_basic" || t.clientAssertion != nil {
|
||||
data.Set("client_id", t.clientID)
|
||||
}
|
||||
|
||||
if grantType == "authorization_code" {
|
||||
@@ -141,16 +144,33 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
}
|
||||
}
|
||||
|
||||
// Read tokenURL with RLock
|
||||
// Read tokenURL with RLock — needed as audience for private_key_jwt (RFC 7523 §3).
|
||||
t.metadataMu.RLock()
|
||||
tokenURL := t.tokenURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
useBasicAuth := false
|
||||
if t.clientAssertion != nil {
|
||||
assertion, err := t.clientAssertion.Sign(tokenURL, t.clientID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign client assertion: %w", err)
|
||||
}
|
||||
data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
|
||||
data.Set("client_assertion", assertion)
|
||||
} else if t.clientAuthMethod == "client_secret_basic" {
|
||||
useBasicAuth = true
|
||||
} else {
|
||||
data.Set("client_secret", t.clientSecret)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if useBasicAuth {
|
||||
setOAuthBasicAuth(req, t.clientID, t.clientSecret)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
@@ -423,6 +443,19 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// setOAuthBasicAuth sets the Authorization header per RFC 6749 §2.3.1: the
|
||||
// client_id and client_secret are form-urlencoded individually, joined with a
|
||||
// colon, then base64-encoded. This differs from http.Request.SetBasicAuth,
|
||||
// which skips the form-urlencode step — that matters for credentials with
|
||||
// reserved characters (`:`, `@`, `+`, `%`, etc.) where the wire format would
|
||||
// otherwise diverge from what the spec mandates.
|
||||
func setOAuthBasicAuth(req *http.Request, clientID, clientSecret string) {
|
||||
user := url.QueryEscape(clientID)
|
||||
pass := url.QueryEscape(clientSecret)
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
}
|
||||
|
||||
// deduplicateScopes removes duplicate scopes from a slice while preserving order.
|
||||
// This ensures that OAuth scope parameters don't contain duplicates which could
|
||||
// cause issues with some authorization servers.
|
||||
|
||||
Vendored
+3
@@ -24,6 +24,7 @@ type Config struct {
|
||||
Type BackendType
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
TLSServerName string
|
||||
PoolSize int
|
||||
RedisDB int
|
||||
CleanupInterval time.Duration
|
||||
@@ -34,6 +35,8 @@ type Config struct {
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
EnableMetrics bool
|
||||
EnableTLS bool
|
||||
TLSSkipVerify bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
|
||||
Vendored
+3
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
TLSServerName: config.TLSServerName,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
EnableTLS: config.EnableTLS,
|
||||
TLSSkipVerify: config.TLSSkipVerify,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
|
||||
+25
-3
@@ -2,6 +2,7 @@ package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -31,6 +32,7 @@ type ConnectionPool struct {
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
TLSServerName string // SNI server name; defaults to host(Address) when empty
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
@@ -39,6 +41,8 @@ type PoolConfig struct {
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
|
||||
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
conn, err = p.createConnection(ctx)
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if p.config.EnableTLS {
|
||||
serverName := p.config.TLSServerName
|
||||
if serverName == "" {
|
||||
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
|
||||
serverName = host
|
||||
}
|
||||
}
|
||||
tlsCfg := &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
|
||||
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
} else {
|
||||
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
+230
@@ -0,0 +1,230 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// drainRESPRequest consumes a single RESP request (array or inline) from r and
|
||||
// returns true on success. Any read error returns false.
|
||||
func drainRESPRequest(r *bufio.Reader) bool {
|
||||
header, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(header, "*") {
|
||||
return true // inline command (single line) — already consumed
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
|
||||
if err != nil || n <= 0 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
// Each bulk: "$len\r\n<bytes>\r\n"
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
|
||||
// answer PING with +PONG. Returns the listener address and a self-signed cert.
|
||||
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "localhost"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsCert := tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
|
||||
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stopCh := make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
c, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
defer conn.Close()
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
if !drainRESPRequest(reader) {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("+PONG\r\n"))
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
stop = func() {
|
||||
close(stopCh)
|
||||
_ = listener.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
return listener.Addr().String(), der, stop
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
|
||||
// Regression test for issue #133 (enableTLS not propagated to client).
|
||||
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=false rejects a self-signed server cert.
|
||||
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "tls")
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
|
||||
// fails to handshake against a plain (non-TLS) listener.
|
||||
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
|
||||
plain, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer plain.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, acceptErr := plain.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: plain.Addr().String(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
|
||||
// when EnableTLS=false (default).
|
||||
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
|
||||
// the fix for issue #132: token refresh path hardcoded the "email" claim and
|
||||
// ignored the configured userIdentifierClaim. Keycloak users without an email
|
||||
// claim (using sub or another identifier) were being kicked out on refresh
|
||||
// even though their initial login worked.
|
||||
//
|
||||
// The callback path (auth_flow.go) already honored userIdentifierClaim with
|
||||
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
|
||||
// after PR #100 (commit a316a98).
|
||||
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
name string
|
||||
userIdentifierClaim string
|
||||
expectedIdentifier string
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "sub claim configured, only sub present (Keycloak no-email case)",
|
||||
userIdentifierClaim: "sub",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-keycloak-12345",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user-uuid-keycloak-12345",
|
||||
},
|
||||
{
|
||||
name: "preferred_username configured, claim present",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"preferred_username": "alice",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "alice",
|
||||
},
|
||||
{
|
||||
name: "configured claim missing, falls back to sub",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "fallback-sub-id",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "fallback-sub-id",
|
||||
},
|
||||
{
|
||||
name: "email default, email present (backward compatibility)",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"email": "user@example.com",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "email default, no email and no sub - refresh fails",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: false,
|
||||
expectedIdentifier: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager(
|
||||
"test-encryption-key-32-bytes-long!!",
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
0,
|
||||
NewLogger("error"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("session manager: %v", err)
|
||||
}
|
||||
defer sessionManager.Shutdown()
|
||||
|
||||
capturedClaims := tt.claims
|
||||
tOidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
userIdentifierClaim: tt.userIdentifierClaim,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: &EnhancedMockTokenExchanger{
|
||||
RefreshResponse: &TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
IDToken: "new-id-token-jwt",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
},
|
||||
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
|
||||
extractClaimsFunc: func(token string) (map[string]any, error) {
|
||||
return capturedClaims, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("get session: %v", err)
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
session.SetRefreshToken("initial-refresh-token")
|
||||
|
||||
refreshed := tOidc.refreshToken(rw, req, session)
|
||||
|
||||
if refreshed != tt.expectSuccess {
|
||||
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
|
||||
}
|
||||
|
||||
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
|
||||
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,449 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// signGraphStyleAccessToken builds a JWT in Microsoft's Graph proprietary
|
||||
// nonce-header form: bytes that get signed contain the SHA256 hash of the
|
||||
// nonce, while the wire token ships the original nonce. A standard JWS
|
||||
// verifier always rejects these with `crypto/rsa: verification error`, which
|
||||
// is why Microsoft documents Graph access tokens as opaque to client apps:
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
|
||||
// "you can't validate tokens for Microsoft Graph according to these rules
|
||||
// due to their proprietary format"
|
||||
func signGraphStyleAccessToken(t *testing.T, key *rsa.PrivateKey, kid, originalNonce string, claims map[string]any) string {
|
||||
t.Helper()
|
||||
|
||||
wireHeader := map[string]any{
|
||||
"alg": "RS256",
|
||||
"kid": kid,
|
||||
"typ": "JWT",
|
||||
"nonce": originalNonce,
|
||||
}
|
||||
wireHeaderJSON, err := json.Marshal(wireHeader)
|
||||
require.NoError(t, err)
|
||||
|
||||
hashed := sha256.Sum256([]byte(originalNonce))
|
||||
signedHeader := map[string]any{
|
||||
"alg": "RS256",
|
||||
"kid": kid,
|
||||
"typ": "JWT",
|
||||
"nonce": fmt.Sprintf("%x", hashed),
|
||||
}
|
||||
signedHeaderJSON, err := json.Marshal(signedHeader)
|
||||
require.NoError(t, err)
|
||||
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
wireHeaderB64 := base64.RawURLEncoding.EncodeToString(wireHeaderJSON)
|
||||
signedHeaderB64 := base64.RawURLEncoding.EncodeToString(signedHeaderJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
signedInput := signedHeaderB64 + "." + claimsB64
|
||||
hSign := sha256.Sum256([]byte(signedInput))
|
||||
sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, hSign[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
return wireHeaderB64 + "." + claimsB64 + "." + base64.RawURLEncoding.EncodeToString(sig)
|
||||
}
|
||||
|
||||
// newAzureFollowupOIDC produces a TraefikOidc instance wired for an Azure
|
||||
// AD tenant with a captured error log buffer. Used by the issue #134 followup
|
||||
// tests to assert log behavior during validateAzureTokens flows.
|
||||
func newAzureFollowupOIDC(t *testing.T, jwks *JWKSet) (*TraefikOidc, *bytes.Buffer) {
|
||||
t.Helper()
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(io.Discard, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
jwkCache: &MockJWKCache{JWKS: jwks},
|
||||
tokenCache: tokenCache,
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
oidc.tokenVerifier = oidc
|
||||
oidc.jwtVerifier = oidc
|
||||
require.True(t, oidc.isAzureProvider(), "fixture must be detected as Azure provider")
|
||||
return oidc, errBuf
|
||||
}
|
||||
|
||||
// authedSessionWithTokens returns a SessionData populated with the supplied
|
||||
// access and ID tokens, marked authenticated and recently created. The
|
||||
// SessionManager carries a real ChunkManager so that GetAccessToken /
|
||||
// GetIDToken / GetRefreshToken behave like the production code path.
|
||||
func authedSessionWithTokens(t *testing.T, accessToken, idToken string) *SessionData {
|
||||
t.Helper()
|
||||
|
||||
chunkLogger := NewLogger("error")
|
||||
chunkManager := NewChunkManager(chunkLogger)
|
||||
t.Cleanup(chunkManager.Shutdown)
|
||||
|
||||
sd := CreateMockSessionData()
|
||||
sd.manager = &SessionManager{
|
||||
sessionMaxAge: 24 * time.Hour,
|
||||
chunkManager: chunkManager,
|
||||
logger: chunkLogger,
|
||||
}
|
||||
|
||||
sd.mainSession = sessions.NewSession(nil, "main")
|
||||
sd.mainSession.Values["authenticated"] = true
|
||||
sd.mainSession.Values["created_at"] = time.Now().Unix()
|
||||
|
||||
sd.accessSession = sessions.NewSession(nil, "access")
|
||||
sd.accessSession.Values["token"] = accessToken
|
||||
sd.accessSession.Values["compressed"] = false
|
||||
|
||||
sd.idTokenSession = sessions.NewSession(nil, "id")
|
||||
sd.idTokenSession.Values["token"] = idToken
|
||||
sd.idTokenSession.Values["compressed"] = false
|
||||
|
||||
sd.refreshSession = sessions.NewSession(nil, "refresh")
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
sd.refreshSession.Values["compressed"] = false
|
||||
|
||||
return sd
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_GraphAccessTokenReproducesUsersError sanity-checks
|
||||
// that our crafted Graph-style token reproduces the exact rsa error string
|
||||
// quoted on the issue thread (dada-engineer 2026-05-08, friek 2026-05-11).
|
||||
//
|
||||
// Sanity test: must always pass, regardless of the issue #134 followup fix.
|
||||
// It exists so a future contributor does not accidentally weaken the
|
||||
// reproducer and assume the followup fix is no longer needed.
|
||||
func TestIssue134_Followup_GraphAccessTokenReproducesUsersError(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-followup-kid"
|
||||
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
parsedJWT, err := parseJWT(graphToken)
|
||||
require.NoError(t, err)
|
||||
pubKey := &rsaKey.PublicKey
|
||||
alg, _ := parsedJWT.Header["alg"].(string)
|
||||
verifyErr := verifySignatureWithKey(graphToken, pubKey, alg)
|
||||
require.Error(t, verifyErr)
|
||||
assert.Contains(t, verifyErr.Error(), "crypto/rsa: verification error",
|
||||
"reproducer must emit the exact error string reported on issue #134")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken is the
|
||||
// failing-then-passing test for the followup fix.
|
||||
//
|
||||
// Symptom (before fix): validateAzureTokens calls verifyToken on every
|
||||
// JWT-shaped access token. For Microsoft Graph access tokens (the default
|
||||
// when no custom resource is registered), verification always fails with
|
||||
// `crypto/rsa: verification error`, generating two error log lines per
|
||||
// request:
|
||||
//
|
||||
// UNKNOWN token verification failed: signature verification failed:
|
||||
// crypto/rsa: verification error
|
||||
// DIAGNOSTIC: Signature verification failed for kid=<kid>, alg=RS256:
|
||||
// crypto/rsa: verification error
|
||||
//
|
||||
// Microsoft's own documentation tells client apps not to validate Graph
|
||||
// access tokens. The fix matches that guidance: when an Azure access token
|
||||
// carries Microsoft's proprietary `nonce` JWT header, treat it as opaque
|
||||
// (skip JWT verification, fall through to ID token validation).
|
||||
func TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-followup-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
now := time.Now()
|
||||
exp := now.Add(time.Hour).Unix()
|
||||
|
||||
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-azure-graph", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": exp,
|
||||
"iat": now.Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"appid": "test-client-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": now.Add(-2 * time.Minute).Unix(),
|
||||
"nbf": now.Add(-2 * time.Minute).Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"email": "user@example.com",
|
||||
"nonce": "id-token-oidc-nonce",
|
||||
"jti": "id-token-jti-followup",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, graphAccessToken, idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
|
||||
|
||||
output := errBuf.String()
|
||||
assert.NotContains(t, output, "crypto/rsa: verification error",
|
||||
"validateAzureTokens must not log rsa verification error for Graph-style access tokens; got: %q", output)
|
||||
assert.NotContains(t, output, "DIAGNOSTIC: Signature verification failed",
|
||||
"DIAGNOSTIC line must not fire for Graph-style access tokens; got: %q", output)
|
||||
assert.NotContains(t, output, "UNKNOWN token verification failed",
|
||||
"UNKNOWN classification log must not fire for Graph-style access tokens; got: %q", output)
|
||||
|
||||
assert.True(t, authenticated, "session must remain authenticated via the ID token fallback")
|
||||
assert.False(t, needsRefresh, "valid ID token must not signal a refresh need")
|
||||
assert.False(t, expired, "valid ID token must not be reported as expired")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection covers the
|
||||
// classifier added by the followup fix. Pure-function unit test for the
|
||||
// Microsoft proprietary marker we rely on (nonce in JWT header).
|
||||
func TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-detection-kid"
|
||||
standardToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
oidc, _ := newAzureFollowupOIDC(t, &JWKSet{})
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
token string
|
||||
wantUnverified bool
|
||||
}{
|
||||
{name: "standard JWT without nonce header", token: standardToken, wantUnverified: false},
|
||||
{name: "Microsoft proprietary token (nonce in header)", token: graphToken, wantUnverified: true},
|
||||
{name: "garbage token treated as unverifiable", token: "not-a-jwt-at-all", wantUnverified: true},
|
||||
{name: "empty token treated as unverifiable", token: "", wantUnverified: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := oidc.isUnverifiableAzureAccessToken(tc.token)
|
||||
assert.Equal(t, tc.wantUnverified, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_StandardAzureAccessTokenStillVerifies guards against
|
||||
// regression in the happy path: an access token issued for our own clientID
|
||||
// (custom Azure-registered API) — no proprietary nonce header, signed normally
|
||||
// — must still flow through the standard verification path and authenticate.
|
||||
func TestIssue134_Followup_StandardAzureAccessTokenStillVerifies(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-standard-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
now := time.Now()
|
||||
exp := now.Add(time.Hour).Unix()
|
||||
|
||||
// Custom-resource access token: aud points to the app, no nonce header.
|
||||
accessToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": now.Add(-2 * time.Minute).Unix(),
|
||||
"nbf": now.Add(-2 * time.Minute).Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "api.read",
|
||||
"jti": "standard-access-jti",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": now.Add(-2 * time.Minute).Unix(),
|
||||
"nbf": now.Add(-2 * time.Minute).Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"email": "user@example.com",
|
||||
"nonce": "id-token-oidc-nonce",
|
||||
"jti": "standard-id-jti",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, accessToken, idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
|
||||
|
||||
assert.True(t, authenticated, "standard Azure access token must verify and authenticate")
|
||||
assert.False(t, needsRefresh)
|
||||
assert.False(t, expired)
|
||||
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error",
|
||||
"standard Azure token must not produce signature errors")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_GraphAccessTokenWithoutIDToken covers the edge where
|
||||
// the session has only a Graph access token (no ID token). The classifier must
|
||||
// preserve the existing "treat as opaque" semantics for backward compatibility:
|
||||
// authenticated=true even when there is no ID token to verify.
|
||||
func TestIssue134_Followup_GraphAccessTokenWithoutIDToken(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-no-idt-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-no-idt", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, graphAccessToken, "")
|
||||
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
|
||||
|
||||
assert.True(t, authenticated, "Graph token without ID token must remain authenticated (matches existing opaque-token semantics)")
|
||||
assert.False(t, needsRefresh)
|
||||
assert.False(t, expired)
|
||||
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification proves
|
||||
// the classifier is not a security regression. An attacker who forges a JWT
|
||||
// with a `nonce` JWT header (Microsoft's proprietary marker) but a payload
|
||||
// claiming `aud=our-clientID` should NOT gain authenticated status simply by
|
||||
// triggering the "treat as opaque" branch.
|
||||
//
|
||||
// This is the confused-deputy guardrail Microsoft warns about
|
||||
// (https://cwe.mitre.org/data/definitions/441.html): we treat the access token
|
||||
// as opaque, which means we DO NOT authorize from it — authorization comes
|
||||
// only from a separately verifiable ID token. An attacker without a valid ID
|
||||
// token must not be authenticated.
|
||||
func TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
attackerKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-attack-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
// Forged: attacker uses their OWN key, sets aud = our clientID, plants a
|
||||
// `nonce` header to trip the opaque-detection path.
|
||||
forgedAccessToken := signGraphStyleAccessToken(t, attackerKey, kid, "attacker-nonce", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "attacker",
|
||||
"scp": "admin",
|
||||
})
|
||||
|
||||
// Forged ID token signed with the attacker's key — must fail verification
|
||||
// against the tenant JWKS.
|
||||
forgedIDToken, err := createTestJWT(attackerKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Minute).Unix(),
|
||||
"nbf": time.Now().Add(-2 * time.Minute).Unix(),
|
||||
"sub": "attacker",
|
||||
"email": "attacker@evil.example",
|
||||
"nonce": "id-token-oidc-nonce",
|
||||
"jti": "attacker-id-jti",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc, _ := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, forgedAccessToken, forgedIDToken)
|
||||
|
||||
authenticated, _, _ := oidc.validateAzureTokens(session)
|
||||
assert.False(t, authenticated,
|
||||
"attacker's forged tokens must not authenticate even when the access token has a nonce header — ID token verification rejects the wrong-key signature")
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError reproduces and
|
||||
// verifies the fix for issue #134.
|
||||
//
|
||||
// Symptom (before fix): with a Redis backend wired into UniversalCache,
|
||||
// caching the parsed *parsedJWKS triggered:
|
||||
//
|
||||
// json: cannot unmarshal number 2251513...
|
||||
// into Go value of type float64
|
||||
//
|
||||
// Root cause: under yaegi, json.Marshal of a struct exposes unexported
|
||||
// fields with an X-prefixed name. parsedJWKS{ keys map[string]crypto.PublicKey }
|
||||
// thus serialized the inner *rsa.PublicKey, whose modulus *big.Int marshals
|
||||
// as a JSON number hundreds of digits long. On read, json.Unmarshal into
|
||||
// interface{} parses numbers as float64, which cannot represent that range.
|
||||
// The user saw the error log on every request even though auth still worked
|
||||
// (fallback path rebuilt the keys in memory).
|
||||
//
|
||||
// Fix: route both *JWKSet and *parsedJWKS through SetLocal/GetLocal — the
|
||||
// distributed backend never sees them.
|
||||
func TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "issue134:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
const kid = "azure-test-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
|
||||
}
|
||||
|
||||
var fetchCount int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&fetchCount, 1)
|
||||
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
infoBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(infoBuf, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
jwkCache := &JWKCache{cache: cache}
|
||||
ctx := context.Background()
|
||||
|
||||
pub1, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err, "first GetPublicKey should succeed")
|
||||
require.NotNil(t, pub1)
|
||||
gotRSA, ok := pub1.(*rsa.PublicKey)
|
||||
require.True(t, ok, "returned key should be *rsa.PublicKey, got %T", pub1)
|
||||
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N), "modulus must survive intact")
|
||||
assert.Equal(t, rsaKey.E, gotRSA.E, "exponent must survive intact")
|
||||
|
||||
pub2, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err, "second GetPublicKey should succeed")
|
||||
require.True(t, samePublicKey(pub1, pub2), "second call must return the same parsed key (cache hit)")
|
||||
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&fetchCount),
|
||||
"upstream JWKS endpoint must be hit exactly once; second call must be served from local cache")
|
||||
|
||||
errOutput := errBuf.String()
|
||||
assert.NotContains(t, errOutput, "Failed to deserialize",
|
||||
"deserialize error must not appear with the fix in place; got: %s", errOutput)
|
||||
assert.NotContains(t, errOutput, "into Go value of type float64",
|
||||
"float64 unmarshal error must not appear; got: %s", errOutput)
|
||||
|
||||
parsedKey := server.URL + parsedKeysSuffix
|
||||
jwksKey := server.URL
|
||||
for _, k := range []string{cache.prefixKey(parsedKey), cache.prefixKey(jwksKey)} {
|
||||
fullKey := redisCfg.RedisPrefix + k
|
||||
assert.False(t, mr.Exists(fullKey),
|
||||
"key %q must not exist in Redis (local-only caching); got %v", fullKey, mr.Keys())
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue134_StalePoisonedRedisDataIgnored verifies that pre-existing bad
|
||||
// data left in Redis under a JWK :parsed key from a prior buggy version is
|
||||
// ignored: the local-only fix never reads that key, so no log spam, and the
|
||||
// fallback path returns a real *rsa.PublicKey.
|
||||
func TestIssue134_StalePoisonedRedisDataIgnored(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "issue134stale:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
const kid = "azure-test-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pre-poison Redis with the kind of payload the old buggy path would have
|
||||
// produced (huge unquoted JSON number for the modulus). With the fix the
|
||||
// JWKCache must not even read this key.
|
||||
poisoned := []byte("\x01" + strings.Replace(
|
||||
`{"Xkeys":{"azure-test-kid":{"N":NUMBER,"E":65537}}}`,
|
||||
"NUMBER", rsaKey.N.String(), 1,
|
||||
))
|
||||
parsedRedisKey := redisCfg.RedisPrefix + "jwk:" + server.URL + parsedKeysSuffix
|
||||
require.NoError(t, mr.Set(parsedRedisKey, string(poisoned)))
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(io.Discard, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
jwkCache := &JWKCache{cache: cache}
|
||||
pub, err := jwkCache.GetPublicKey(context.Background(), server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pub)
|
||||
gotRSA, ok := pub.(*rsa.PublicKey)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N))
|
||||
|
||||
assert.NotContains(t, errBuf.String(), "Failed to deserialize",
|
||||
"poisoned Redis entry must not be touched; got error log: %s", errBuf.String())
|
||||
}
|
||||
|
||||
// TestIssue134_SetLocalGetLocalSkipBackend verifies the new SetLocal/GetLocal
|
||||
// pair never reads or writes the configured backend.
|
||||
func TestIssue134_SetLocalGetLocalSkipBackend(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "local:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 10,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
type unsafeShape struct {
|
||||
hidden map[string]interface{}
|
||||
}
|
||||
val := &unsafeShape{hidden: map[string]interface{}{"k": 1}}
|
||||
|
||||
require.NoError(t, cache.SetLocal("local-key", val, 1*time.Hour))
|
||||
|
||||
got, found := cache.GetLocal("local-key")
|
||||
require.True(t, found)
|
||||
assert.Same(t, val, got, "GetLocal must return the exact pointer stored, no JSON round-trip")
|
||||
|
||||
for _, k := range mr.Keys() {
|
||||
assert.NotContains(t, k, "local-key",
|
||||
"SetLocal must not write to Redis; found key %q (all keys: %v)", k, mr.Keys())
|
||||
}
|
||||
|
||||
cache.mu.Lock()
|
||||
delete(cache.items, "local-key")
|
||||
cache.lruList.Init()
|
||||
cache.currentSize = 0
|
||||
cache.currentMemory = 0
|
||||
cache.mu.Unlock()
|
||||
|
||||
_, found = cache.GetLocal("local-key")
|
||||
assert.False(t, found, "GetLocal must not fall back to backend after local cache cleared")
|
||||
}
|
||||
|
||||
// big2bytes returns the big-endian byte slice for a positive int.
|
||||
func big2bytes(e int) []byte {
|
||||
if e <= 0 {
|
||||
return []byte{}
|
||||
}
|
||||
var buf []byte
|
||||
for e > 0 {
|
||||
buf = append([]byte{byte(e & 0xff)}, buf...)
|
||||
e >>= 8
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
// samePublicKey reports whether two crypto.PublicKey instances represent the
|
||||
// same RSA key, used to confirm cache hits return identical reconstructed
|
||||
// keys.
|
||||
func samePublicKey(a, b interface{}) bool {
|
||||
ar, ok1 := a.(*rsa.PublicKey)
|
||||
br, ok2 := b.(*rsa.PublicKey)
|
||||
if !ok1 || !ok2 {
|
||||
return false
|
||||
}
|
||||
return ar.N.Cmp(br.N) == 0 && ar.E == br.E
|
||||
}
|
||||
@@ -0,0 +1,925 @@
|
||||
package traefikoidc
|
||||
|
||||
// issue135_regression_test.go — regression tests for RFC 7523 private_key_jwt
|
||||
// client authentication (issue #135).
|
||||
//
|
||||
// These tests guard:
|
||||
// - Correct JWT construction and cryptographic signature for all supported
|
||||
// algorithms (RS*/PS*/ES*).
|
||||
// - Proper validation of alg/key type combinations and empty-kid rejection.
|
||||
// - JTI uniqueness across concurrent calls.
|
||||
// - PEM variant tolerance (PKCS#8, PKCS#1, SEC1).
|
||||
// - Config.Validate() behavior for all private_key_jwt configuration paths.
|
||||
// - buildClientAssertionSignerFromConfig: inline PEM, file-backed PEM, default alg.
|
||||
// - Wire-up in exchangeTokens: assertion fields sent, client_secret absent.
|
||||
// - Wire-up in RevokeTokenWithProvider: assertion fields sent, audience = tokenURL.
|
||||
// - Back-compat: client_secret_post path unchanged when clientAssertion == nil.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ── A. Signer unit tests ──────────────────────────────────────────────────────
|
||||
|
||||
// TestIssue135_SignerRSAFamily verifies that NewClientAssertionSigner + Sign
|
||||
// produces a well-formed, cryptographically valid JWT for every RSA-family
|
||||
// algorithm (RS256/RS384/RS512/PS256/PS384/PS512).
|
||||
func TestIssue135_SignerRSAFamily(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
cases := []struct {
|
||||
alg string
|
||||
hashFn func([]byte) []byte
|
||||
isPS bool
|
||||
hash crypto.Hash
|
||||
}{
|
||||
{"RS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, false, crypto.SHA256},
|
||||
{"RS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, false, crypto.SHA384},
|
||||
{"RS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, false, crypto.SHA512},
|
||||
{"PS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, true, crypto.SHA256},
|
||||
{"PS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, true, crypto.SHA384},
|
||||
{"PS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, true, crypto.SHA512},
|
||||
}
|
||||
|
||||
const (
|
||||
audience = "https://example.com/token"
|
||||
clientID = "client-abc"
|
||||
kid = "kid-1"
|
||||
)
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.alg, func(t *testing.T) {
|
||||
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtStr, err := signer.Sign(audience, clientID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3, "JWT must have three dot-separated parts")
|
||||
|
||||
// Decode and check header.
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, tc.alg, hdr["alg"])
|
||||
assert.Equal(t, "JWT", hdr["typ"])
|
||||
assert.Equal(t, kid, hdr["kid"])
|
||||
|
||||
// Decode and check claims.
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
assert.Equal(t, clientID, clms["iss"])
|
||||
assert.Equal(t, clientID, clms["sub"])
|
||||
assert.Equal(t, audience, clms["aud"])
|
||||
|
||||
iat, ok := clms["iat"].(float64)
|
||||
require.True(t, ok, "iat must be numeric")
|
||||
exp, ok := clms["exp"].(float64)
|
||||
require.True(t, ok, "exp must be numeric")
|
||||
assert.InDelta(t, 60, exp-iat, 2, "exp-iat must equal ~60s")
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
assert.True(t, iat <= now+2 && iat >= now-5, "iat must be current time ±5s")
|
||||
|
||||
jti, ok := clms["jti"].(string)
|
||||
require.True(t, ok, "jti must be a string")
|
||||
assert.Len(t, jti, 32, "jti must be 32-char hex (16 bytes → hex)")
|
||||
|
||||
// Verify cryptographic signature.
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := tc.hashFn([]byte(sigInput))
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
|
||||
pub := &rsaKey.PublicKey
|
||||
if tc.isPS {
|
||||
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: tc.hash}
|
||||
assert.NoError(t, rsa.VerifyPSS(pub, tc.hash, digest, sigBytes, opts),
|
||||
"PSS signature verification failed for %s", tc.alg)
|
||||
} else {
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(pub, tc.hash, digest, sigBytes),
|
||||
"PKCS1v15 signature verification failed for %s", tc.alg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerECDSAFamily verifies correct JWT production for all
|
||||
// ECDSA algorithms (ES256/ES384/ES512) including that the signature is the
|
||||
// raw r||s encoding (not ASN.1 DER) and is verifiable with the matching key.
|
||||
func TestIssue135_SignerECDSAFamily(t *testing.T) {
|
||||
cases := []struct {
|
||||
alg string
|
||||
curve elliptic.Curve
|
||||
hashFn func([]byte) []byte
|
||||
hash crypto.Hash
|
||||
}{
|
||||
{"ES256", elliptic.P256(), func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, crypto.SHA256},
|
||||
{"ES384", elliptic.P384(), func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, crypto.SHA384},
|
||||
{"ES512", elliptic.P521(), func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, crypto.SHA512},
|
||||
}
|
||||
|
||||
const (
|
||||
audience = "https://idp.example.com/token"
|
||||
clientID = "ec-client"
|
||||
kid = "ec-kid"
|
||||
)
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.alg, func(t *testing.T) {
|
||||
ecKey, err := ecdsa.GenerateKey(tc.curve, rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBytes := encodeECPKCS8(t, ecKey)
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtStr, err := signer.Sign(audience, clientID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
|
||||
byteLen := (tc.curve.Params().BitSize + 7) / 8
|
||||
assert.Len(t, sigBytes, 2*byteLen,
|
||||
"ECDSA signature must be raw r||s (2×%d bytes for %s)", byteLen, tc.alg)
|
||||
|
||||
r := new(big.Int).SetBytes(sigBytes[:byteLen])
|
||||
s := new(big.Int).SetBytes(sigBytes[byteLen:])
|
||||
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := tc.hashFn([]byte(sigInput))
|
||||
|
||||
ok := ecdsa.Verify(&ecKey.PublicKey, digest, r, s)
|
||||
assert.True(t, ok, "ECDSA signature verification failed for %s", tc.alg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerRejectsAlgKeyMismatch verifies that the signer constructor
|
||||
// rejects type mismatches between key type and algorithm, unknown algorithms,
|
||||
// and an empty kid.
|
||||
func TestIssue135_SignerRejectsAlgKeyMismatch(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
rsaPEM := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
ecPEM := encodeECPKCS8(t, ecKey)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
pemBytes []byte
|
||||
alg string
|
||||
kid string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "RSA key with ES256",
|
||||
pemBytes: rsaPEM,
|
||||
alg: "ES256",
|
||||
kid: "k1",
|
||||
wantErr: "EC key",
|
||||
},
|
||||
{
|
||||
name: "EC key with RS256",
|
||||
pemBytes: ecPEM,
|
||||
alg: "RS256",
|
||||
kid: "k1",
|
||||
wantErr: "RSA key",
|
||||
},
|
||||
{
|
||||
name: "unknown alg HS256",
|
||||
pemBytes: rsaPEM,
|
||||
alg: "HS256",
|
||||
kid: "k1",
|
||||
wantErr: "unsupported",
|
||||
},
|
||||
{
|
||||
name: "empty kid",
|
||||
pemBytes: rsaPEM,
|
||||
alg: "RS256",
|
||||
kid: "",
|
||||
wantErr: "kid must not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewClientAssertionSigner(tc.pemBytes, tc.alg, tc.kid)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.wantErr),
|
||||
"error should mention %q", tc.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerJTIUniqueness signs 50 assertions with the same signer
|
||||
// and asserts all jti values are distinct. Guards against broken entropy reuse.
|
||||
func TestIssue135_SignerJTIUniqueness(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "jti-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
seen := make(map[string]bool, 50)
|
||||
for i := range 50 {
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "client-x")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
jti, ok := clms["jti"].(string)
|
||||
require.True(t, ok)
|
||||
assert.False(t, seen[jti], "jti %q was reused at iteration %d", jti, i)
|
||||
seen[jti] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerPEMVariants confirms that all PEM block types understood
|
||||
// by NewClientAssertionSigner are parsed correctly: PKCS#8 ("PRIVATE KEY"),
|
||||
// PKCS#1 ("RSA PRIVATE KEY"), and SEC1 ("EC PRIVATE KEY").
|
||||
func TestIssue135_SignerPEMVariants(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("RSA PKCS8", func(t *testing.T) {
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
|
||||
require.NoError(t, err)
|
||||
assertValidRSAJWT(t, rsaKey, signer, "RS256")
|
||||
})
|
||||
|
||||
t.Run("RSA PKCS1", func(t *testing.T) {
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
|
||||
require.NoError(t, err)
|
||||
assertValidRSAJWT(t, rsaKey, signer, "RS256")
|
||||
})
|
||||
|
||||
t.Run("EC PKCS8", func(t *testing.T) {
|
||||
pemBytes := encodeECPKCS8(t, ecKey)
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
|
||||
require.NoError(t, err)
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "cid")
|
||||
require.NoError(t, err)
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
})
|
||||
|
||||
t.Run("EC SEC1", func(t *testing.T) {
|
||||
der, err := x509.MarshalECPrivateKey(ecKey)
|
||||
require.NoError(t, err)
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
|
||||
require.NoError(t, err)
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "cid")
|
||||
require.NoError(t, err)
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
})
|
||||
}
|
||||
|
||||
// ── B. Config validation ──────────────────────────────────────────────────────
|
||||
|
||||
// TestIssue135_ConfigValidation table-drives Config.Validate() for every
|
||||
// client-authentication-related validation branch.
|
||||
func TestIssue135_ConfigValidation(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
validPEM := string(encodeRSAPKCS8(t, rsaKey))
|
||||
|
||||
// baseConfig returns the minimum valid config, modified per test case.
|
||||
base := func() *Config {
|
||||
return &Config{
|
||||
ProviderURL: "https://idp.example.com",
|
||||
CallbackURL: "/cb",
|
||||
ClientID: "cid",
|
||||
ClientSecret: "secret",
|
||||
SessionEncryptionKey: "01234567890123456789012345678901", // 32 chars
|
||||
RateLimit: 100,
|
||||
}
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
wantErr string // empty = expect nil error
|
||||
}{
|
||||
{
|
||||
name: "default empty method + secret ok",
|
||||
mutate: func(c *Config) { /* nothing extra */ },
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "explicit client_secret_post + secret ok",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "client_secret_post"
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt inline key + kid ok",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt no key at all",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
},
|
||||
wantErr: "clientAssertionPrivateKey",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt both inline and path",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
c.ClientAssertionKeyPath = "/tmp/key.pem"
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
},
|
||||
wantErr: "only one of",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt key but no kid",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
},
|
||||
wantErr: "clientAssertionKeyID",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt unsupported alg HS256",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
c.ClientAssertionAlg = "HS256"
|
||||
},
|
||||
wantErr: "is not supported",
|
||||
},
|
||||
{
|
||||
name: "unknown client auth method",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "weird"
|
||||
},
|
||||
wantErr: "is not supported",
|
||||
},
|
||||
{
|
||||
name: "client_secret_post with no secret",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "client_secret_post"
|
||||
c.ClientSecret = ""
|
||||
},
|
||||
wantErr: "clientSecret is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg := base()
|
||||
tc.mutate(cfg)
|
||||
err := cfg.Validate()
|
||||
if tc.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.wantErr,
|
||||
"error must mention %q", tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_ConfigKeyPathLoadsFile verifies that buildClientAssertionSignerFromConfig
|
||||
// reads the PEM key from disk when ClientAssertionKeyPath is set.
|
||||
func TestIssue135_ConfigKeyPathLoadsFile(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
dir := t.TempDir()
|
||||
keyFile := dir + "/private.pem"
|
||||
require.NoError(t, os.WriteFile(keyFile, pemBytes, 0o600))
|
||||
|
||||
cfg := &Config{
|
||||
ClientAuthMethod: "private_key_jwt",
|
||||
ClientAssertionKeyPath: keyFile,
|
||||
ClientAssertionKeyID: "file-kid",
|
||||
ClientAssertionAlg: "RS256",
|
||||
}
|
||||
|
||||
signer, err := buildClientAssertionSignerFromConfig(cfg)
|
||||
require.NoError(t, err, "should load signer from key file")
|
||||
require.NotNil(t, signer)
|
||||
|
||||
// Confirm signer produces a valid JWT.
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "client-from-file")
|
||||
require.NoError(t, err)
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3, "should produce a 3-part JWT")
|
||||
}
|
||||
|
||||
// ── C. Wire-up — exchangeTokens ───────────────────────────────────────────────
|
||||
|
||||
// TestIssue135_AuthCodeExchangeUsesAssertion confirms that exchangeTokens sends
|
||||
// client_assertion + client_assertion_type instead of client_secret when a
|
||||
// ClientAssertionSigner is configured, and that the assertion JWT is valid.
|
||||
func TestIssue135_AuthCodeExchangeUsesAssertion(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
var capturedBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body := make([]byte, r.ContentLength)
|
||||
_, _ = r.Body.Read(body)
|
||||
capturedBody = body
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Return a minimal token response so exchangeTokens doesn't error.
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "at",
|
||||
IDToken: "it",
|
||||
RefreshToken: "rt",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "wire-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "wire-client",
|
||||
tokenHTTPClient: server.Client(),
|
||||
clientAssertion: signer,
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err = oidc.exchangeTokens(context.Background(), "authorization_code", "code-x", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
form, err := url.ParseQuery(string(capturedBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
form.Get("client_assertion_type"), "client_assertion_type must be set")
|
||||
assertionJWT := form.Get("client_assertion")
|
||||
assert.NotEmpty(t, assertionJWT, "client_assertion must be present")
|
||||
assert.Empty(t, form.Get("client_secret"), "client_secret must not be sent when using assertion")
|
||||
assert.Equal(t, "wire-client", form.Get("client_id"))
|
||||
assert.Equal(t, "code-x", form.Get("code"))
|
||||
assert.Equal(t, "authorization_code", form.Get("grant_type"))
|
||||
|
||||
// Verify assertion JWT: header, claims, signature.
|
||||
parts := strings.Split(assertionJWT, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, "RS256", hdr["alg"])
|
||||
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
assert.Equal(t, "wire-client", clms["iss"])
|
||||
assert.Equal(t, "wire-client", clms["sub"])
|
||||
assert.Equal(t, server.URL, clms["aud"],
|
||||
"audience must be the tokenURL (RFC 7523 §3)")
|
||||
|
||||
// Verify signature with RSA public key.
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
|
||||
}
|
||||
|
||||
// TestIssue135_RefreshTokenUsesAssertion verifies that the refresh_token grant
|
||||
// type also sends client_assertion and the correct form fields.
|
||||
func TestIssue135_RefreshTokenUsesAssertion(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "new-at",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rt-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "rt-client",
|
||||
tokenHTTPClient: server.Client(),
|
||||
clientAssertion: signer,
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err = oidc.exchangeTokens(context.Background(), "refresh_token", "rt-y", "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "refresh_token", capturedForm.Get("grant_type"))
|
||||
assert.Equal(t, "rt-y", capturedForm.Get("refresh_token"))
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
capturedForm.Get("client_assertion_type"))
|
||||
assert.NotEmpty(t, capturedForm.Get("client_assertion"))
|
||||
assert.Empty(t, capturedForm.Get("client_secret"))
|
||||
}
|
||||
|
||||
// TestIssue135_BackcompatClientSecretPath confirms that exchangeTokens sends
|
||||
// client_secret and does NOT send client_assertion when clientAssertion is nil.
|
||||
func TestIssue135_BackcompatClientSecretPath(t *testing.T) {
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "legacy-client",
|
||||
clientSecret: "legacy-secret",
|
||||
tokenHTTPClient: server.Client(),
|
||||
clientAssertion: nil, // back-compat path
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bc", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "legacy-secret", capturedForm.Get("client_secret"),
|
||||
"client_secret must be sent on the classic path")
|
||||
assert.Empty(t, capturedForm.Get("client_assertion"),
|
||||
"client_assertion must NOT be present on the classic path")
|
||||
assert.Empty(t, capturedForm.Get("client_assertion_type"),
|
||||
"client_assertion_type must NOT be present on the classic path")
|
||||
}
|
||||
|
||||
// TestIssue135_ClientSecretBasicAuth verifies that when clientAuthMethod is
|
||||
// "client_secret_basic", exchangeTokens sends an HTTP Basic Authorization
|
||||
// header carrying url-encoded client_id:client_secret per RFC 6749 §2.3.1,
|
||||
// and that neither client_id nor client_secret appears in the form body.
|
||||
func TestIssue135_ClientSecretBasicAuth(t *testing.T) {
|
||||
var capturedAuth string
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "at-basic", TokenType: "Bearer", ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "basic-client",
|
||||
clientSecret: "basic-secret",
|
||||
clientAuthMethod: "client_secret_basic",
|
||||
tokenHTTPClient: server.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bb", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, strings.HasPrefix(capturedAuth, "Basic "),
|
||||
"Authorization header must start with 'Basic ', got %q", capturedAuth)
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
|
||||
require.NoError(t, err, "Authorization payload must be valid base64")
|
||||
user, pass, ok := strings.Cut(string(raw), ":")
|
||||
require.True(t, ok, "Authorization payload must contain a single ':' separator")
|
||||
assert.Equal(t, "basic-client", user, "client_id should round-trip through QueryEscape")
|
||||
assert.Equal(t, "basic-secret", pass, "client_secret should round-trip through QueryEscape")
|
||||
|
||||
assert.Empty(t, capturedForm.Get("client_id"),
|
||||
"client_id must NOT be in the body when using client_secret_basic")
|
||||
assert.Empty(t, capturedForm.Get("client_secret"),
|
||||
"client_secret must NOT be in the body when using client_secret_basic")
|
||||
assert.Empty(t, capturedForm.Get("client_assertion"),
|
||||
"client_assertion must NOT be present on the basic-auth path")
|
||||
}
|
||||
|
||||
// TestIssue135_ClientSecretBasicURLEncodesReservedChars verifies that
|
||||
// credentials containing reserved characters (`:`, `+`, `/`, etc.) are
|
||||
// form-urlencoded before base64 per RFC 6749 §2.3.1, so the receiving
|
||||
// authorization server can decode them deterministically.
|
||||
func TestIssue135_ClientSecretBasicURLEncodesReservedChars(t *testing.T) {
|
||||
var capturedAuth string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{AccessToken: "at", TokenType: "Bearer", ExpiresIn: 3600})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
const (
|
||||
clientID = "weird:id+1"
|
||||
clientSecret = "p@ss/word=&" //nolint:gosec // test fixture
|
||||
)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
clientAuthMethod: "client_secret_basic",
|
||||
tokenHTTPClient: server.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "c", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
|
||||
require.NoError(t, err)
|
||||
|
||||
wantUser := url.QueryEscape(clientID)
|
||||
wantPass := url.QueryEscape(clientSecret)
|
||||
assert.Equal(t, wantUser+":"+wantPass, string(raw),
|
||||
"both halves must be form-urlencoded before the base64 step")
|
||||
}
|
||||
|
||||
// TestIssue135_ClientSecretBasicRevocation verifies that the revocation path
|
||||
// honors client_secret_basic identically to the token path.
|
||||
func TestIssue135_ClientSecretBasicRevocation(t *testing.T) {
|
||||
var capturedAuth string
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "rev-basic",
|
||||
clientSecret: "rev-secret",
|
||||
clientAuthMethod: "client_secret_basic",
|
||||
httpClient: server.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = "https://idp.example.com/token"
|
||||
oidc.revocationURL = server.URL
|
||||
|
||||
require.NoError(t, oidc.RevokeTokenWithProvider("opaque-tok", "access_token"))
|
||||
|
||||
require.True(t, strings.HasPrefix(capturedAuth, "Basic "), "got %q", capturedAuth)
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "rev-basic:rev-secret", string(raw))
|
||||
|
||||
assert.Equal(t, "opaque-tok", capturedForm.Get("token"))
|
||||
assert.Equal(t, "access_token", capturedForm.Get("token_type_hint"))
|
||||
assert.Empty(t, capturedForm.Get("client_id"),
|
||||
"client_id must NOT be in body on Basic-auth revocation")
|
||||
assert.Empty(t, capturedForm.Get("client_secret"),
|
||||
"client_secret must NOT be in body on Basic-auth revocation")
|
||||
}
|
||||
|
||||
// ── D. Wire-up — RevokeTokenWithProvider ────────────────────────────────────
|
||||
|
||||
// TestIssue135_RevocationUsesAssertion verifies that RevokeTokenWithProvider
|
||||
// sends client_assertion (not client_secret), and that the assertion's audience
|
||||
// is the tokenURL, not the revocationURL (per RFC 7523 §3).
|
||||
func TestIssue135_RevocationUsesAssertion(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
const (
|
||||
tokenEndpoint = "https://idp.example.com/token" // audience for assertion
|
||||
clientIDVal = "revoke-client"
|
||||
)
|
||||
|
||||
var capturedForm url.Values
|
||||
// Revocation endpoint — deliberate separate URL to confirm audience != revocationURL.
|
||||
revokeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer revokeServer.Close()
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rev-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: clientIDVal,
|
||||
clientAssertion: signer,
|
||||
httpClient: revokeServer.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
// tokenURL drives assertion audience; revocationURL is where the POST goes.
|
||||
oidc.tokenURL = tokenEndpoint
|
||||
oidc.revocationURL = revokeServer.URL
|
||||
|
||||
err = oidc.RevokeTokenWithProvider("some-token", "refresh_token")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
capturedForm.Get("client_assertion_type"))
|
||||
assertionJWT := capturedForm.Get("client_assertion")
|
||||
assert.NotEmpty(t, assertionJWT)
|
||||
assert.Empty(t, capturedForm.Get("client_secret"),
|
||||
"client_secret must not appear in revocation request with assertion")
|
||||
|
||||
// Verify the assertion audience is tokenURL (not revocationURL).
|
||||
parts := strings.Split(assertionJWT, ".")
|
||||
require.Len(t, parts, 3)
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
assert.Equal(t, tokenEndpoint, clms["aud"],
|
||||
"assertion audience must be tokenURL, not revocationURL")
|
||||
|
||||
// Sanity-check cryptographic validity.
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
|
||||
}
|
||||
|
||||
// ── E. End-to-end via buildClientAssertionSignerFromConfig ───────────────────
|
||||
|
||||
// TestIssue135_BuildSignerFromInlineConfig confirms that the full config→signer
|
||||
// pipeline works for an ES256 key specified inline in the Config struct.
|
||||
func TestIssue135_BuildSignerFromInlineConfig(t *testing.T) {
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
pemBytes := encodeECPKCS8(t, ecKey)
|
||||
|
||||
cfg := &Config{
|
||||
ClientAuthMethod: "private_key_jwt",
|
||||
ClientAssertionPrivateKey: string(pemBytes),
|
||||
ClientAssertionKeyID: "inline-ec-kid",
|
||||
ClientAssertionAlg: "ES256",
|
||||
}
|
||||
|
||||
signer, err := buildClientAssertionSignerFromConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signer)
|
||||
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "inline-client")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, "ES256", hdr["alg"])
|
||||
assert.Equal(t, "inline-ec-kid", hdr["kid"])
|
||||
|
||||
// Verify the EC signature.
|
||||
byteLen := (elliptic.P256().Params().BitSize + 7) / 8
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
require.Len(t, sigBytes, 2*byteLen)
|
||||
|
||||
r := new(big.Int).SetBytes(sigBytes[:byteLen])
|
||||
s := new(big.Int).SetBytes(sigBytes[byteLen:])
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
assert.True(t, ecdsa.Verify(&ecKey.PublicKey, digest, r, s))
|
||||
}
|
||||
|
||||
// TestIssue135_BuildSignerDefaultsToRS256 verifies that an empty
|
||||
// ClientAssertionAlg defaults to RS256.
|
||||
func TestIssue135_BuildSignerDefaultsToRS256(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
cfg := &Config{
|
||||
ClientAssertionPrivateKey: string(pemBytes),
|
||||
ClientAssertionKeyID: "default-alg-kid",
|
||||
ClientAssertionAlg: "", // intentionally empty
|
||||
}
|
||||
|
||||
signer, err := buildClientAssertionSignerFromConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "default-client")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, "RS256", hdr["alg"], "empty alg must default to RS256")
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
// genRSAKey generates an RSA key of the given bit size, failing the test on error.
|
||||
func genRSAKey(t *testing.T, bits int) *rsa.PrivateKey {
|
||||
t.Helper()
|
||||
k, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
require.NoError(t, err)
|
||||
return k
|
||||
}
|
||||
|
||||
// encodeRSAPKCS8 marshals an RSA key as PKCS#8 PEM ("PRIVATE KEY").
|
||||
func encodeRSAPKCS8(t *testing.T, key *rsa.PrivateKey) []byte {
|
||||
t.Helper()
|
||||
der, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
|
||||
}
|
||||
|
||||
// encodeECPKCS8 marshals an EC key as PKCS#8 PEM ("PRIVATE KEY").
|
||||
func encodeECPKCS8(t *testing.T, key *ecdsa.PrivateKey) []byte {
|
||||
t.Helper()
|
||||
der, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
|
||||
}
|
||||
|
||||
// decodeJSONPart base64url-decodes a JWT part and parses it as a JSON object.
|
||||
func decodeJSONPart(t *testing.T, b64url string) map[string]any {
|
||||
t.Helper()
|
||||
raw, err := base64.RawURLEncoding.DecodeString(b64url)
|
||||
require.NoError(t, err, "base64url decode of JWT part failed")
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(raw, &m), "JSON unmarshal of JWT part failed")
|
||||
return m
|
||||
}
|
||||
|
||||
// sha256SumBytes returns the SHA-256 digest of b as a byte slice.
|
||||
func sha256SumBytes(b []byte) []byte {
|
||||
h := sha256.Sum256(b)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
// assertValidRSAJWT signs a JWT with signer and verifies the RS256 signature
|
||||
// against the given RSA public key. Used by PEM variant tests.
|
||||
func assertValidRSAJWT(t *testing.T, key *rsa.PrivateKey, signer *ClientAssertionSigner, alg string) {
|
||||
t.Helper()
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "pem-client")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, alg, hdr["alg"])
|
||||
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(&key.PublicKey, crypto.SHA256, digest, sigBytes))
|
||||
}
|
||||
|
||||
@@ -76,9 +76,15 @@ func NewJWKCache() *JWKCache {
|
||||
}
|
||||
|
||||
// GetJWKS retrieves JWKS from cache or fetches from the remote URL if not cached.
|
||||
//
|
||||
// The entry is stored locally only via SetLocal/GetLocal. Going through a
|
||||
// distributed backend defeats the cache: JSON round-tripping turns *JWKSet
|
||||
// into map[string]interface{}, the type assertion below fails, and every
|
||||
// request refetches from the upstream. JWK rotation is rare and a per-replica
|
||||
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Check cache first
|
||||
if cachedValue, found := c.cache.Get(jwksURL); found {
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
return jwks, nil
|
||||
}
|
||||
@@ -88,7 +94,7 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if cachedValue, found := c.cache.Get(jwksURL); found {
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
return jwks, nil
|
||||
}
|
||||
@@ -105,7 +111,7 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
}
|
||||
|
||||
// Cache for 1 hour
|
||||
_ = c.cache.Set(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
@@ -114,9 +120,17 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
// caching the JWKS plus its derived parsedJWKS on miss. The parsed entry is
|
||||
// stored alongside the raw JWKSet under a sibling cache key with the same
|
||||
// 1-hour TTL, so both invalidate together when the upstream JWKS rotates.
|
||||
//
|
||||
// parsedJWKS is stored locally only (SetLocal/GetLocal). Its values are
|
||||
// crypto.PublicKey interfaces wrapping *rsa.PublicKey/*ecdsa.PublicKey,
|
||||
// which contain *big.Int that marshals to a hundreds-digit JSON number.
|
||||
// On a distributed backend round-trip, json.Unmarshal into interface{} would
|
||||
// try to fit that into float64 and fail with UnmarshalTypeError. Under yaegi
|
||||
// the unexported parsedJWKS.keys field is exposed via an X-prefixed name on
|
||||
// Marshal, leaking the modulus into the cached payload (issue #134).
|
||||
func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
parsedKey := jwksURL + parsedKeysSuffix
|
||||
if v, found := c.cache.Get(parsedKey); found {
|
||||
if v, found := c.cache.GetLocal(parsedKey); found {
|
||||
if pj, ok := v.(*parsedJWKS); ok {
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
@@ -130,7 +144,7 @@ func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpCl
|
||||
}
|
||||
|
||||
pj := buildParsedJWKS(jwks)
|
||||
_ = c.cache.Set(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
_ = c.cache.SetLocal(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
|
||||
@@ -169,6 +169,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
introspectionCache: cacheManager.GetSharedIntrospectionCache(), // Cache for introspection results
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
clientAuthMethod: func() string {
|
||||
if config.ClientAuthMethod != "" {
|
||||
return config.ClientAuthMethod
|
||||
}
|
||||
return "client_secret_post"
|
||||
}(),
|
||||
audience: func() string {
|
||||
if config.Audience != "" {
|
||||
return config.Audience
|
||||
@@ -273,6 +279,15 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
// rotates refresh tokens (Zitadel/Authentik default).
|
||||
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
|
||||
|
||||
if config.ClientAuthMethod == "private_key_jwt" {
|
||||
signer, err := buildClientAssertionSignerFromConfig(config)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
return nil, fmt.Errorf("failed to build client assertion signer: %w", err)
|
||||
}
|
||||
t.clientAssertion = signer
|
||||
}
|
||||
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
|
||||
+28
-28
@@ -138,7 +138,7 @@ func TestServeHTTP_EventStream(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
@@ -221,7 +221,7 @@ func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetEmail("ws-user@example.com")
|
||||
session.SetUserIdentifier("ws-user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
@@ -408,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "successful authorization with email",
|
||||
setupSession: func() *MockSessionData {
|
||||
session := &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -440,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "no email triggers reauth",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "",
|
||||
userIdentifier: "",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -461,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "roles and groups authorization",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -494,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "unauthorized role/group returns 403",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -521,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "template headers processing",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -553,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "OPTIONS request with CORS",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -604,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
manager: &SessionManager{logger: NewLogger("debug")},
|
||||
}
|
||||
// Copy values from mock to concrete session
|
||||
concreteSession.SetEmail(session.email)
|
||||
concreteSession.SetUserIdentifier(session.userIdentifier)
|
||||
concreteSession.SetIDToken(session.idToken)
|
||||
concreteSession.SetAccessToken(session.accessToken)
|
||||
concreteSession.SetRefreshToken(session.refreshToken)
|
||||
@@ -654,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
|
||||
// MockSessionData is a test implementation of SessionData interface
|
||||
type MockSessionData struct {
|
||||
email string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
userIdentifier string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
}
|
||||
|
||||
func (m *MockSessionData) GetEmail() string { return m.email }
|
||||
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
|
||||
func (m *MockSessionData) GetIDToken() string { return m.idToken }
|
||||
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
|
||||
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
|
||||
func (m *MockSessionData) SetEmail(email string) { m.email = email }
|
||||
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
|
||||
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
|
||||
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
|
||||
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
|
||||
@@ -762,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set up session data
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Call processAuthorizedRequest directly
|
||||
@@ -837,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -923,7 +923,7 @@ func TestStripAuthCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Now add OIDC session cookies (simulating what the browser would send)
|
||||
@@ -1004,7 +1004,7 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -1051,7 +1051,7 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only OIDC cookies
|
||||
@@ -1102,7 +1102,7 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only non-OIDC cookies
|
||||
@@ -1165,7 +1165,7 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add cookies with the custom prefix (should be stripped)
|
||||
|
||||
+15
-15
@@ -580,7 +580,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case to avoid replay issues
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -603,7 +603,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
// even if session.SetAuthenticated(true) was called.
|
||||
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
||||
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -660,7 +660,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -678,7 +678,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -706,7 +706,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -741,7 +741,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(nearExpiryToken)
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
||||
},
|
||||
@@ -772,7 +772,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(validToken)
|
||||
session.SetIDToken(validToken) // Ensure ID token is also set
|
||||
session.SetRefreshToken("should-not-be-used-refresh-token")
|
||||
@@ -792,7 +792,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -814,7 +814,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -2179,7 +2179,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAccessToken(expiredToken)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
},
|
||||
expectedPath: "/original/path",
|
||||
},
|
||||
@@ -2756,7 +2756,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2782,7 +2782,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2809,7 +2809,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
@@ -2829,7 +2829,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2851,7 +2851,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
|
||||
+15
-15
@@ -92,17 +92,17 @@ func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) b
|
||||
return false
|
||||
}
|
||||
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
}
|
||||
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, email)
|
||||
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -289,7 +289,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
// User authorization check
|
||||
if authenticated && userIdentifier != "" {
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
@@ -361,7 +361,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if refreshed {
|
||||
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
|
||||
userIdentifier = session.GetUserIdentifier()
|
||||
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
@@ -411,9 +411,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// - session: The user's session data containing tokens and claims.
|
||||
// - redirectURL: The callback URL for re-authentication if needed.
|
||||
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Info("No email found in session during final processing, initiating re-auth")
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
|
||||
// Reset redirect count to prevent loops when session is invalid
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
@@ -426,7 +426,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
if idToken != "" {
|
||||
sid, sub, createdAt := t.extractSessionInfo(idToken)
|
||||
if t.isSessionInvalidated(sid, sub, createdAt) {
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
|
||||
// Clear the session and redirect to login
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing invalidated session: %v", err)
|
||||
@@ -502,19 +502,19 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
t.logger.Infof("User %s does not have any allowed roles or groups", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
if idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
@@ -587,7 +587,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", userIdentifier)
|
||||
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create authenticated session
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetIDToken("dummy-token")
|
||||
session.Save(req, httptest.NewRecorder())
|
||||
@@ -203,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create session with forbidden domain
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@forbidden.com")
|
||||
session.SetUserIdentifier("user@forbidden.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save and inject cookies
|
||||
@@ -252,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
// Create session with opaque token
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
@@ -291,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("") // No email
|
||||
session.SetUserIdentifier("") // No email
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -321,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("") // No ID token
|
||||
session.SetAccessToken("") // No access token
|
||||
|
||||
@@ -349,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -383,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
testEmail := "user@example.com"
|
||||
session.SetEmail(testEmail)
|
||||
session.SetUserIdentifier(testEmail)
|
||||
session.SetIDToken("dummy-id-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
|
||||
// Simulate successful Azure authentication
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Azure may use opaque access tokens
|
||||
session.SetAccessToken("opaque-azure-access-token")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
|
||||
@@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, session2.GetAuthenticated(), "User should remain authenticated")
|
||||
assert.Equal(t, "user@example.com", session2.GetEmail())
|
||||
assert.Equal(t, "user@example.com", session2.GetUserIdentifier())
|
||||
assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist")
|
||||
assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist")
|
||||
assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist")
|
||||
|
||||
@@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
|
||||
// Set up the attacker's session with malicious data
|
||||
attackerSession.SetAuthenticated(true)
|
||||
attackerSession.SetEmail("attacker@evil.com")
|
||||
attackerSession.SetUserIdentifier("attacker@evil.com")
|
||||
attackerSession.SetIDToken(ValidIDToken)
|
||||
attackerSession.SetAccessToken(ValidAccessToken)
|
||||
|
||||
@@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get the email from the session
|
||||
email := session.GetEmail()
|
||||
email := session.GetUserIdentifier()
|
||||
w.Header().Set("X-User-Email", email)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
+26
-26
@@ -100,7 +100,7 @@ type combinedSessionPayload struct {
|
||||
A string `json:"a,omitempty"`
|
||||
R string `json:"r,omitempty"`
|
||||
I string `json:"i,omitempty"`
|
||||
E string `json:"e,omitempty"`
|
||||
Ui string `json:"ui,omitempty"`
|
||||
Cs string `json:"cs,omitempty"`
|
||||
N string `json:"n,omitempty"`
|
||||
Cv string `json:"cv,omitempty"`
|
||||
@@ -113,11 +113,11 @@ type combinedSessionPayload struct {
|
||||
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
|
||||
// All other mainSession.Values keys are stored in the X (extra) field.
|
||||
var knownSessionKeys = map[string]bool{
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"email": true,
|
||||
"authenticated": true,
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"user_identifier": true,
|
||||
"authenticated": true,
|
||||
"csrf": true,
|
||||
"nonce": true,
|
||||
"code_verifier": true,
|
||||
@@ -1134,7 +1134,7 @@ func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *
|
||||
sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())
|
||||
|
||||
// Populate legacy session values from combined payload
|
||||
sessionData.mainSession.Values["email"] = payload.E
|
||||
sessionData.mainSession.Values["user_identifier"] = payload.Ui
|
||||
sessionData.mainSession.Values["authenticated"] = payload.Au
|
||||
sessionData.mainSession.Values["csrf"] = payload.Cs
|
||||
sessionData.mainSession.Values["nonce"] = payload.N
|
||||
@@ -1278,7 +1278,7 @@ func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, opti
|
||||
A: sd.getAccessTokenUnsafe(),
|
||||
R: sd.getRefreshTokenUnsafe(),
|
||||
I: sd.getIDTokenUnsafe(),
|
||||
E: sd.getEmailUnsafe(),
|
||||
Ui: sd.getUserIdentifierUnsafe(),
|
||||
Au: sd.getAuthenticatedUnsafe(),
|
||||
Cs: sd.getCSRFUnsafe(),
|
||||
N: sd.getNonceUnsafe(),
|
||||
@@ -2469,30 +2469,30 @@ func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address.
|
||||
// The email is extracted from ID token claims and used for
|
||||
// authorization decisions and header injection.
|
||||
// GetUserIdentifier retrieves the authenticated user's identifier as extracted
|
||||
// from the configured userIdentifierClaim of the ID token (email, sub, oid,
|
||||
// upn, preferred_username, etc.). The value is used for authorization
|
||||
// decisions and header injection.
|
||||
// Returns:
|
||||
// - The user's email address string, or an empty string if not set.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
// - The user identifier string, or an empty string if not set.
|
||||
func (sd *SessionData) GetUserIdentifier() string {
|
||||
sd.sessionMutex.RLock()
|
||||
defer sd.sessionMutex.RUnlock()
|
||||
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// SetEmail stores the authenticated user's email address.
|
||||
// The email is typically extracted from the 'email' claim in the ID token.
|
||||
// SetUserIdentifier stores the authenticated user's identifier value.
|
||||
// Parameters:
|
||||
// - email: The user's email address to store.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
// - userIdentifier: The user identifier to store (email, sub, or other claim value).
|
||||
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
currentVal, _ := sd.mainSession.Values["email"].(string)
|
||||
if currentVal != email {
|
||||
sd.mainSession.Values["email"] = email
|
||||
currentVal, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
if currentVal != userIdentifier {
|
||||
sd.mainSession.Values["user_identifier"] = userIdentifier
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
@@ -2626,10 +2626,10 @@ func (sd *SessionData) getRefreshTokenUnsafe() string {
|
||||
return result.Token
|
||||
}
|
||||
|
||||
// getEmailUnsafe retrieves the email without acquiring locks.
|
||||
func (sd *SessionData) getEmailUnsafe() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks.
|
||||
func (sd *SessionData) getUserIdentifierUnsafe() string {
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
|
||||
|
||||
@@ -320,17 +320,16 @@ func (s *SessionBehaviourSuite) TestSessionData_DirtyTracking() {
|
||||
s.False(session.IsDirty())
|
||||
}
|
||||
|
||||
// TestSessionData_SetEmail tests email setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetEmail() {
|
||||
// TestSessionData_SetUserIdentifier tests user identifier setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetUserIdentifier() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
session, err := s.sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Set email
|
||||
session.SetEmail("test@example.com")
|
||||
s.Equal("test@example.com", session.GetEmail())
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.Equal("test@example.com", session.GetUserIdentifier())
|
||||
s.True(session.IsDirty())
|
||||
}
|
||||
|
||||
@@ -568,7 +567,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Clear() {
|
||||
// Set some data
|
||||
err = session.SetAuthenticated(true)
|
||||
s.Require().NoError(err)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetCSRF("csrf-token")
|
||||
|
||||
// Clear session
|
||||
@@ -588,7 +587,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Save() {
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Modify session
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.True(session.IsDirty())
|
||||
|
||||
// Save session
|
||||
|
||||
+6
-6
@@ -2688,7 +2688,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Set up initial session state (what user has when first logging in)
|
||||
session1.SetAuthenticated(true)
|
||||
session1.SetEmail(originalUserData["email"].(string))
|
||||
session1.SetUserIdentifier(originalUserData["email"].(string))
|
||||
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
|
||||
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
|
||||
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
|
||||
@@ -2732,7 +2732,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
// Simulate what happens when middleware detects expired tokens
|
||||
// It should preserve session state while attempting token refresh
|
||||
originalAuth := session2.GetAuthenticated()
|
||||
originalEmail := session2.GetEmail()
|
||||
originalEmail := session2.GetUserIdentifier()
|
||||
|
||||
// Reconstruct user data from individual stored keys
|
||||
originalUserDataStored := make(map[string]interface{})
|
||||
@@ -2813,7 +2813,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Verify all session data is still intact after token refresh
|
||||
postRefreshAuth := session2.GetAuthenticated()
|
||||
postRefreshEmail := session2.GetEmail()
|
||||
postRefreshEmail := session2.GetUserIdentifier()
|
||||
userDataPresent := true
|
||||
for k := range originalUserData {
|
||||
if session2.mainSession.Values["user_data_"+k] == nil {
|
||||
@@ -2907,7 +2907,7 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with specific creation time
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
|
||||
|
||||
// Create tokens with specific expiry
|
||||
@@ -3018,7 +3018,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with data that should be preserved or removed
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("cleanup@example.com")
|
||||
session.SetUserIdentifier("cleanup@example.com")
|
||||
|
||||
session.mainSession.Values["user_data"] = "Test User|user-123"
|
||||
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
|
||||
@@ -3049,7 +3049,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
if scenario.shouldCleanup {
|
||||
if sessionTooOld {
|
||||
session.SetAuthenticated(false)
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
for key := range session.mainSession.Values {
|
||||
|
||||
+56
-2
@@ -93,6 +93,38 @@ type Config struct {
|
||||
// providers. Enabling this in production is a security hole — prefer
|
||||
// CACertPath/CACertPEM. Emits a loud warning at startup.
|
||||
InsecureSkipVerify bool `json:"insecureSkipVerify,omitempty"`
|
||||
|
||||
// ClientAuthMethod selects the OAuth 2.0 client authentication method used
|
||||
// at the token / revocation / introspection endpoints. Supported values:
|
||||
//
|
||||
// - "client_secret_post" (default, current behavior): clientSecret is
|
||||
// sent in the request body alongside client_id.
|
||||
// - "private_key_jwt" (RFC 7523 §2.2): the plugin signs a short-lived JWT
|
||||
// assertion with a configured private key and sends it as
|
||||
// client_assertion. Use this when your IdP enforces short-lived secrets
|
||||
// or mandates secretless client auth (Entra ID, Okta, Auth0, Keycloak).
|
||||
//
|
||||
// When set to "private_key_jwt", clientSecret may be left empty and one of
|
||||
// clientAssertionPrivateKey / clientAssertionKeyPath must be configured.
|
||||
ClientAuthMethod string `json:"clientAuthMethod,omitempty"`
|
||||
|
||||
// ClientAssertionPrivateKey is an inline PEM-encoded private key used to
|
||||
// sign client_assertion JWTs. Mutually exclusive with
|
||||
// ClientAssertionKeyPath. Supports PKCS#8, PKCS#1 (RSA), and SEC1 (EC).
|
||||
ClientAssertionPrivateKey string `json:"clientAssertionPrivateKey,omitempty"`
|
||||
|
||||
// ClientAssertionKeyPath is a filesystem path to a PEM-encoded private key,
|
||||
// equivalent to ClientAssertionPrivateKey but loaded from disk.
|
||||
ClientAssertionKeyPath string `json:"clientAssertionKeyPath,omitempty"`
|
||||
|
||||
// ClientAssertionKeyID is the JWK key id (kid) advertised in the JWS
|
||||
// header. Required when using private_key_jwt so the IdP can locate the
|
||||
// matching public key registered for the client.
|
||||
ClientAssertionKeyID string `json:"clientAssertionKeyID,omitempty"`
|
||||
|
||||
// ClientAssertionAlg is the JWS signing algorithm. Defaults to RS256.
|
||||
// Supported: RS256/384/512, PS256/384/512, ES256/384/512.
|
||||
ClientAssertionAlg string `json:"clientAssertionAlg,omitempty"`
|
||||
}
|
||||
|
||||
// loadCACertPool assembles an x509.CertPool from CACertPath and CACertPEM.
|
||||
@@ -323,8 +355,30 @@ func (c *Config) Validate() error {
|
||||
if c.ClientID == "" {
|
||||
return fmt.Errorf("clientID is required")
|
||||
}
|
||||
if c.ClientSecret == "" {
|
||||
return fmt.Errorf("clientSecret is required")
|
||||
authMethod := c.ClientAuthMethod
|
||||
if authMethod == "" {
|
||||
authMethod = "client_secret_post"
|
||||
}
|
||||
switch authMethod {
|
||||
case "client_secret_post", "client_secret_basic":
|
||||
if c.ClientSecret == "" {
|
||||
return fmt.Errorf("clientSecret is required when clientAuthMethod is %q", authMethod)
|
||||
}
|
||||
case "private_key_jwt":
|
||||
if c.ClientAssertionPrivateKey == "" && c.ClientAssertionKeyPath == "" {
|
||||
return fmt.Errorf("clientAssertionPrivateKey or clientAssertionKeyPath is required when clientAuthMethod is private_key_jwt")
|
||||
}
|
||||
if c.ClientAssertionPrivateKey != "" && c.ClientAssertionKeyPath != "" {
|
||||
return fmt.Errorf("only one of clientAssertionPrivateKey or clientAssertionKeyPath may be set")
|
||||
}
|
||||
if c.ClientAssertionKeyID == "" {
|
||||
return fmt.Errorf("clientAssertionKeyID is required when clientAuthMethod is private_key_jwt")
|
||||
}
|
||||
if c.ClientAssertionAlg != "" && !isSupportedClientAssertionAlg(c.ClientAssertionAlg) {
|
||||
return fmt.Errorf("clientAssertionAlg %q is not supported (use RS256/384/512, PS256/384/512, or ES256/384/512)", c.ClientAssertionAlg)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("clientAuthMethod %q is not supported", authMethod)
|
||||
}
|
||||
|
||||
// Validate session encryption key
|
||||
|
||||
@@ -293,7 +293,7 @@ func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http.
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(tf.fixtures.UserEmail)
|
||||
session.SetUserIdentifier(tf.fixtures.UserEmail)
|
||||
session.SetAccessToken(tf.fixtures.AccessToken)
|
||||
session.SetRefreshToken(tf.fixtures.RefreshToken)
|
||||
session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims))
|
||||
|
||||
+97
-10
@@ -341,7 +341,17 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
|
||||
if err := verifySignatureWithKey(token, pubKey, alg); err != nil {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
|
||||
// Microsoft Graph access tokens carry a `nonce` JWT header and are
|
||||
// signed in a proprietary form Microsoft documents as unverifiable
|
||||
// by client applications. They reach this path only when the
|
||||
// per-provider classifier (validateAzureTokens) didn't catch them,
|
||||
// so log at debug to keep the error stream actionable while still
|
||||
// surfacing the cause for diagnostics.
|
||||
if _, isMSProprietary := jwt.Header["nonce"]; isMSProprietary {
|
||||
t.safeLogDebugf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s (Microsoft proprietary nonce header — token is opaque to clients): %v", kid, alg, err)
|
||||
} else {
|
||||
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
@@ -434,7 +444,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetRefreshToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens as well to prevent any replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
@@ -476,12 +486,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
|
||||
return false
|
||||
}
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
|
||||
return false
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("refreshToken failed: User identifier claim '%s' missing or empty in refreshed token", t.userIdentifierClaim)
|
||||
return false
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found in refreshed token, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
|
||||
// Get token expiry information for logging
|
||||
var expiryTime time.Time
|
||||
@@ -507,7 +523,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -654,11 +670,33 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
}
|
||||
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, revocationURL)
|
||||
|
||||
// Read tokenURL with RLock — used as audience for private_key_jwt (RFC 7523 §3).
|
||||
t.metadataMu.RLock()
|
||||
tokenURL := t.tokenURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
data := url.Values{
|
||||
"token": {token},
|
||||
"token_type_hint": {tokenType},
|
||||
"client_id": {t.clientID},
|
||||
"client_secret": {t.clientSecret},
|
||||
}
|
||||
// client_id is sent in the body for every method except client_secret_basic,
|
||||
// where it is carried in the Authorization header per RFC 6749 §2.3.1.
|
||||
if t.clientAuthMethod != "client_secret_basic" || t.clientAssertion != nil {
|
||||
data.Set("client_id", t.clientID)
|
||||
}
|
||||
|
||||
useBasicAuth := false
|
||||
if t.clientAssertion != nil {
|
||||
assertion, err := t.clientAssertion.Sign(tokenURL, t.clientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign client assertion: %w", err)
|
||||
}
|
||||
data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
|
||||
data.Set("client_assertion", assertion)
|
||||
} else if t.clientAuthMethod == "client_secret_basic" {
|
||||
useBasicAuth = true
|
||||
} else {
|
||||
data.Set("client_secret", t.clientSecret)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", revocationURL, strings.NewReader(data.Encode()))
|
||||
@@ -668,6 +706,9 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if useBasicAuth {
|
||||
setOAuthBasicAuth(req, t.clientID, t.clientSecret)
|
||||
}
|
||||
|
||||
// Send the request with circuit breaker protection if available
|
||||
var resp *http.Response
|
||||
@@ -754,6 +795,27 @@ func (t *TraefikOidc) isGoogleProvider() bool {
|
||||
return strings.Contains(issuerURL, "google") || strings.Contains(issuerURL, "accounts.google.com")
|
||||
}
|
||||
|
||||
// isUnverifiableAzureAccessToken reports whether a JWT-shaped access token
|
||||
// matches the Microsoft proprietary format that client applications must not
|
||||
// validate. Microsoft injects a `nonce` value into the JWT header, signs over
|
||||
// the SHA256 hash of that nonce, and ships the original nonce on the wire,
|
||||
// guaranteeing that any standard JWS verifier rejects the signature. This is
|
||||
// the documented mechanism that keeps access tokens opaque to non-resource
|
||||
// holders (Microsoft Graph, Azure Management API).
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
|
||||
//
|
||||
// Returns true on parse failure as well — a token we cannot parse should not
|
||||
// be passed through the verification path that emits ERROR logs.
|
||||
func (t *TraefikOidc) isUnverifiableAzureAccessToken(token string) bool {
|
||||
parsed, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
_, hasProprietaryNonce := parsed.Header["nonce"]
|
||||
return hasProprietaryNonce
|
||||
}
|
||||
|
||||
// isAzureProvider detects if the configured OIDC provider is Azure AD.
|
||||
// It checks the issuer URL for Microsoft Azure AD domains.
|
||||
// Returns:
|
||||
@@ -796,6 +858,31 @@ func (t *TraefikOidc) validateAzureTokens(session *SessionData) (bool, bool, boo
|
||||
|
||||
if accessToken != "" {
|
||||
if strings.Count(accessToken, ".") == 2 {
|
||||
// Microsoft documents that client apps cannot validate access
|
||||
// tokens issued for Microsoft-owned APIs (Graph, Azure Mgmt) due
|
||||
// to their proprietary signing format (nonce in JWT header is
|
||||
// the marker — signed bytes hash the nonce, wire bytes ship the
|
||||
// raw value, so rsa verification always fails). Treat such
|
||||
// tokens as opaque, matching Microsoft's guidance and avoiding
|
||||
// per-request signature-error log spam (issue #134 followup).
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
|
||||
// "you can't validate tokens for Microsoft Graph according to
|
||||
// these rules due to their proprietary format"
|
||||
if t.isUnverifiableAzureAccessToken(accessToken) {
|
||||
t.logger.Debug("Azure access token is Microsoft-proprietary (Graph/Mgmt) — treating as opaque per Microsoft guidance")
|
||||
if idToken != "" {
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
t.logger.Debugf("Azure: ID token validation failed while access token was opaque: %v", err)
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
if err := t.verifyToken(accessToken); err != nil {
|
||||
if idToken != "" {
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
|
||||
@@ -119,6 +119,8 @@ type TraefikOidc struct {
|
||||
audience string
|
||||
clientID string
|
||||
clientSecret string
|
||||
clientAuthMethod string
|
||||
clientAssertion *ClientAssertionSigner
|
||||
registrationURL string
|
||||
backchannelLogoutPath string
|
||||
frontchannelLogoutPath string
|
||||
|
||||
@@ -252,6 +252,25 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
|
||||
}
|
||||
}
|
||||
|
||||
return c.setLocal(key, value, ttl)
|
||||
}
|
||||
|
||||
// SetLocal stores a value only in the in-memory LRU, bypassing any
|
||||
// distributed backend. Use for values that don't survive JSON round-tripping
|
||||
// — interfaces holding concrete crypto keys, *big.Int, or types whose
|
||||
// unexported fields yaegi exposes under an X prefix on Marshal. Each replica
|
||||
// caches independently; correctness must not depend on cross-replica
|
||||
// coherence for these keys.
|
||||
func (c *UniversalCache) SetLocal(key string, value interface{}, ttl time.Duration) error {
|
||||
if ttl == 0 {
|
||||
ttl = c.config.DefaultTTL
|
||||
}
|
||||
return c.setLocal(key, value, ttl)
|
||||
}
|
||||
|
||||
// setLocal performs the in-memory portion of a write. ttl must already be
|
||||
// resolved against DefaultTTL by the caller.
|
||||
func (c *UniversalCache) setLocal(key string, value interface{}, ttl time.Duration) error {
|
||||
size := c.estimateSize(value)
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -343,6 +362,19 @@ func (c *UniversalCache) Get(key string) (interface{}, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
return c.getLocal(key)
|
||||
}
|
||||
|
||||
// GetLocal retrieves a value only from the in-memory LRU, never querying the
|
||||
// distributed backend. Pair with SetLocal for values that aren't safe to
|
||||
// serialize (see SetLocal docstring).
|
||||
func (c *UniversalCache) GetLocal(key string) (interface{}, bool) {
|
||||
return c.getLocal(key)
|
||||
}
|
||||
|
||||
// getLocal returns the in-memory entry for key honoring expiry, grace
|
||||
// periods, and the RLock fast path used by token/JWK/session caches.
|
||||
func (c *UniversalCache) getLocal(key string) (interface{}, bool) {
|
||||
// Fast read path for caches whose eviction is dominated by TTL rather than
|
||||
// access-recency (token, JWK, session). Holding only an RLock here lets all
|
||||
// concurrent readers verify cached tokens in parallel — under yaegi the
|
||||
|
||||
@@ -210,6 +210,8 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
RedisPrefix: redisConfig.KeyPrefix,
|
||||
PoolSize: redisConfig.PoolSize,
|
||||
EnableMetrics: true,
|
||||
EnableTLS: redisConfig.EnableTLS,
|
||||
TLSSkipVerify: redisConfig.TLSSkipVerify,
|
||||
}
|
||||
|
||||
// Use concrete type to avoid Yaegi reflection issues with interface assignment
|
||||
|
||||
Reference in New Issue
Block a user