diff --git a/enhanced_mocks_test.go b/enhanced_mocks_test.go index 874bb89..4e636f1 100644 --- a/enhanced_mocks_test.go +++ b/enhanced_mocks_test.go @@ -2,6 +2,8 @@ package traefikoidc import ( "context" + "crypto" + "fmt" "net/http" "sync" "sync/atomic" @@ -40,6 +42,31 @@ func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, http return m.JWKS, m.Err } +func (m *EnhancedMockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) { + jwks, err := m.GetJWKS(ctx, jwksURL, httpClient) + if err != nil { + return nil, err + } + if jwks == nil { + return nil, fmt.Errorf("JWKS is nil") + } + for i := range jwks.Keys { + k := &jwks.Keys[i] + if k.Kid != kid { + continue + } + switch k.Kty { + case "RSA": + return k.ToRSAPublicKey() + case "EC": + return k.ToECDSAPublicKey() + default: + return nil, fmt.Errorf("unsupported key type: %s", k.Kty) + } + } + return nil, fmt.Errorf("no matching public key found for kid: %s", kid) +} + func (m *EnhancedMockJWKCache) Cleanup() { atomic.AddInt32(&m.CleanupCalls, 1) m.mu.Lock() diff --git a/jwk.go b/jwk.go index e6c07bc..ab381a9 100644 --- a/jwk.go +++ b/jwk.go @@ -2,6 +2,7 @@ package traefikoidc import ( "context" + "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" @@ -18,6 +19,18 @@ import ( "time" ) +// parsedKeysSuffix marks the parallel UniversalCache entry that stores +// pre-parsed public keys for a given JWKS URL. +const parsedKeysSuffix = ":parsed" + +// parsedJWKS holds keys decoded from a JWKSet, indexed by kid. Storing the +// already-parsed crypto.PublicKey avoids re-running the DER/PEM round trip +// on every JWT verification — a costly operation under the yaegi interpreter +// that hosts Traefik plugins. +type parsedJWKS struct { + keys map[string]crypto.PublicKey +} + // JWK represents a JSON Web Key as defined in RFC 7517. // It can represent different key types including RSA, EC, and symmetric keys. type JWK struct { @@ -49,6 +62,7 @@ type JWKCache struct { // JWKCacheInterface defines the contract for JWK caching implementations. type JWKCacheInterface interface { GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) + GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) Cleanup() Close() } @@ -96,6 +110,62 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http return jwks, nil } +// GetPublicKey returns the parsed public key for a given kid, fetching and +// caching the JWKS plus its derived parsedJWKS on miss. The parsed entry is +// stored alongside the raw JWKSet under a sibling cache key with the same +// 1-hour TTL, so both invalidate together when the upstream JWKS rotates. +func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) { + parsedKey := jwksURL + parsedKeysSuffix + if v, found := c.cache.Get(parsedKey); found { + if pj, ok := v.(*parsedJWKS); ok { + if k, ok := pj.keys[kid]; ok { + return k, nil + } + } + } + + jwks, err := c.GetJWKS(ctx, jwksURL, httpClient) + if err != nil { + return nil, err + } + + pj := buildParsedJWKS(jwks) + _ = c.cache.Set(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical + + if k, ok := pj.keys[kid]; ok { + return k, nil + } + return nil, fmt.Errorf("no matching public key found for kid: %s", kid) +} + +// buildParsedJWKS pre-parses every JWK in the set into the matching +// crypto.PublicKey, indexed by kid. Errors on individual keys are skipped so +// a single bad key does not block the rest of the keyset. +func buildParsedJWKS(jwks *JWKSet) *parsedJWKS { + out := make(map[string]crypto.PublicKey, len(jwks.Keys)) + for i := range jwks.Keys { + k := &jwks.Keys[i] + if k.Kid == "" { + continue + } + var pub crypto.PublicKey + var err error + switch k.Kty { + case "RSA": + pub, err = k.ToRSAPublicKey() + case "EC": + pub, err = k.ToECDSAPublicKey() + default: + continue + } + if err != nil { + continue + } + out[k.Kid] = pub + } + return &parsedJWKS{keys: out} +} + // Cleanup is a no-op as cleanup is handled by UniversalCache func (c *JWKCache) Cleanup() { // Handled internally by UniversalCache diff --git a/jwt.go b/jwt.go index 64a343c..9cf597f 100644 --- a/jwt.go +++ b/jwt.go @@ -528,6 +528,21 @@ func verifyNotBefore(notBefore float64) error { // - An error if the key parsing fails, the algorithm is unsupported, // or the signature verification fails func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error { + block, _ := pem.Decode(publicKeyPEM) + if block == nil { + return fmt.Errorf("failed to parse PEM block containing the public key") + } + pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse public key: %w", err) + } + return verifySignatureWithKey(tokenString, pubKey, alg) +} + +// verifySignatureWithKey verifies a JWT signature using an already-parsed +// public key, skipping the PEM-encode/decode round trip that verifySignature +// performs. This is the hot path used by VerifyJWTSignatureAndClaims. +func verifySignatureWithKey(tokenString string, pubKey crypto.PublicKey, alg string) error { parts := strings.Split(tokenString, ".") if len(parts) != 3 { return fmt.Errorf("invalid token format") @@ -537,14 +552,6 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error if err != nil { return fmt.Errorf("failed to decode signature: %w", err) } - block, _ := pem.Decode(publicKeyPEM) - if block == nil { - return fmt.Errorf("failed to parse PEM block containing the public key") - } - pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) - if err != nil { - return fmt.Errorf("failed to parse public key: %w", err) - } var hashFunc crypto.Hash switch alg { case "RS256", "PS256", "ES256": diff --git a/logout_test.go b/logout_test.go index 4f9b359..54597d6 100644 --- a/logout_test.go +++ b/logout_test.go @@ -2,6 +2,7 @@ package traefikoidc import ( "context" + "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -639,6 +640,26 @@ func (m *mockJWKCacheForLogout) GetJWKS(ctx context.Context, jwksURL string, htt }, nil } +func (m *mockJWKCacheForLogout) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) { + jwks, err := m.GetJWKS(ctx, jwksURL, httpClient) + if err != nil { + return nil, err + } + for i := range jwks.Keys { + k := &jwks.Keys[i] + if k.Kid != kid { + continue + } + switch k.Kty { + case "RSA": + return k.ToRSAPublicKey() + case "EC": + return k.ToECDSAPublicKey() + } + } + return nil, fmt.Errorf("no matching public key found for kid: %s", kid) +} + func (m *mockJWKCacheForLogout) Clear() {} func (m *mockJWKCacheForLogout) Cleanup() {} func (m *mockJWKCacheForLogout) Close() {} @@ -755,6 +776,22 @@ func (s *staticJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient return s.jwks, nil } +func (s *staticJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) { + for i := range s.jwks.Keys { + k := &s.jwks.Keys[i] + if k.Kid != kid { + continue + } + switch k.Kty { + case "RSA": + return k.ToRSAPublicKey() + case "EC": + return k.ToECDSAPublicKey() + } + } + return nil, fmt.Errorf("no matching public key found for kid: %s", kid) +} + func (s *staticJWKCache) Clear() {} func (s *staticJWKCache) Cleanup() {} func (s *staticJWKCache) Close() {} diff --git a/main_test.go b/main_test.go index d9286c4..8abff2d 100644 --- a/main_test.go +++ b/main_test.go @@ -208,6 +208,32 @@ func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient * return m.JWKS, m.Err } +func (m *MockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) { + m.mu.RLock() + defer m.mu.RUnlock() + if m.Err != nil { + return nil, m.Err + } + if m.JWKS == nil { + return nil, fmt.Errorf("JWKS is nil") + } + for i := range m.JWKS.Keys { + k := &m.JWKS.Keys[i] + if k.Kid != kid { + continue + } + switch k.Kty { + case "RSA": + return k.ToRSAPublicKey() + case "EC": + return k.ToECDSAPublicKey() + default: + return nil, fmt.Errorf("unsupported key type: %s", k.Kty) + } + } + return nil, fmt.Errorf("no matching public key found for kid: %s", kid) +} + func (m *MockJWKCache) Cleanup() { // Mock cleanup is a no-op - we don't want to destroy the mock JWKS data // Real cleanup is for expired entries, not resetting all data diff --git a/token_manager.go b/token_manager.go index 4f81875..01e208b 100644 --- a/token_manager.go +++ b/token_manager.go @@ -315,15 +315,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error jwksURL := t.jwksURL t.metadataMu.RUnlock() - jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient) - if err != nil { - return fmt.Errorf("failed to get JWKS: %w", err) - } - - if !t.suppressDiagnosticLogs && jwks != nil { - t.safeLogDebugf("DIAGNOSTIC: Retrieved JWKS with %d keys from URL: %s", len(jwks.Keys), jwksURL) - } - kid, ok := jwt.Header["kid"].(string) if !ok { return fmt.Errorf("missing key ID in token header") @@ -337,38 +328,12 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error t.safeLogDebugf("DIAGNOSTIC: Looking for kid=%s, alg=%s in JWKS", kid, alg) } - if jwks == nil { - return fmt.Errorf("JWKS is nil, cannot verify token") - } - - // Find the matching key in JWKS - var matchingKey *JWK - availableKids := make([]string, 0, len(jwks.Keys)) - for _, key := range jwks.Keys { - availableKids = append(availableKids, key.Kid) - if key.Kid == kid { - matchingKey = &key - break - } - } - - if matchingKey == nil { - if !t.suppressDiagnosticLogs { - t.safeLogErrorf("DIAGNOSTIC: No matching key found for kid=%s. Available kids: %v", kid, availableKids) - } - return fmt.Errorf("no matching public key found for kid: %s", kid) - } - - if !t.suppressDiagnosticLogs { - t.safeLogDebugf("DIAGNOSTIC: Found matching key for kid=%s, key type: %s", kid, matchingKey.Kty) - } - - publicKeyPEM, err := jwkToPEM(matchingKey) + pubKey, err := t.jwkCache.GetPublicKey(context.Background(), jwksURL, kid, t.httpClient) if err != nil { - return fmt.Errorf("failed to convert JWK to PEM: %w", err) + return fmt.Errorf("failed to get public key: %w", err) } - if err := verifySignature(token, publicKeyPEM, alg); err != nil { + if err := verifySignatureWithKey(token, pubKey, alg); err != nil { if !t.suppressDiagnosticLogs { t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err) } diff --git a/universal_cache.go b/universal_cache.go index 6a75a05..562280b 100644 --- a/universal_cache.go +++ b/universal_cache.go @@ -343,6 +343,31 @@ func (c *UniversalCache) Get(key string) (interface{}, bool) { } } + // Fast read path for caches whose eviction is dominated by TTL rather than + // access-recency (token, JWK, session). Holding only an RLock here lets all + // concurrent readers verify cached tokens in parallel — under yaegi the + // previous unconditional Lock serialized every JWT verify on a single + // mutex and pinned a CPU under load. + switch c.config.Type { + case CacheTypeToken, CacheTypeJWK, CacheTypeSession: + c.mu.RLock() + item, exists := c.items[key] + if !exists { + c.mu.RUnlock() + atomic.AddInt64(&c.misses, 1) + return nil, false + } + if !time.Now().After(item.ExpiresAt) { + value := item.Value + c.mu.RUnlock() + atomic.AddInt64(&c.hits, 1) + return value, true + } + c.mu.RUnlock() + // Expired — fall through to the write-locked slow path below to + // remove the entry under exclusive access. + } + c.mu.Lock() defer c.mu.Unlock()