diff --git a/jwk.go b/jwk.go index 1cba7f5..b376401 100644 --- a/jwk.go +++ b/jwk.go @@ -53,10 +53,26 @@ type JWKSet struct { Keys []JWK `json:"keys"` } -// JWKCache provides thread-safe caching of JWKS using UniversalCache +// JWKCache provides thread-safe caching of JWKS using UniversalCache. +// +// inflightFetches deduplicates concurrent fetches for the same JWKS URL. +// It replaces a global sync.RWMutex that was previously held for the entire +// HTTP round-trip in GetJWKS: on a cold cache (cold pod, JWK rotation, brief +// network blip) every concurrent request piled up on that single Lock(), and +// under Yaegi each Lock acquisition costs 10-50ms of interpreter-dispatch +// overhead. The singleflight pattern keeps the cold-cache cost O(1) HTTP +// fetch regardless of how many requests are waiting. type JWKCache struct { - cache *UniversalCache - mutex sync.RWMutex + cache *UniversalCache + inflightFetches sync.Map // map[jwksURL string]*jwksFetch +} + +// jwksFetch represents an in-flight JWKS fetch. Done is closed when the fetch +// completes; jwks and err carry the result (one of them is set, never both). +type jwksFetch struct { + done chan struct{} + jwks *JWKSet + err error } // JWKCacheInterface defines the contract for JWK caching implementations. @@ -83,36 +99,58 @@ func NewJWKCache() *JWKCache { // request refetches from the upstream. JWK rotation is rare and a per-replica // HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing. func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) { - // Check cache first + // Fast path: cache hit. if cachedValue, found := c.cache.GetLocal(jwksURL); found { if jwks, ok := cachedValue.(*JWKSet); ok { return jwks, nil } } - c.mutex.Lock() - defer c.mutex.Unlock() + // Singleflight: dedupe concurrent fetches per URL key. The first arrival + // performs the HTTP fetch; any later arrival for the same URL waits on + // its done channel and shares the result. No global lock is held during + // the fetch. + candidate := &jwksFetch{done: make(chan struct{})} + if existing, loaded := c.inflightFetches.LoadOrStore(jwksURL, candidate); loaded { + f, _ := existing.(*jwksFetch) + select { + case <-f.done: + return f.jwks, f.err + case <-ctx.Done(): + return nil, ctx.Err() + } + } - // Double-check after acquiring lock + // We're the leader. Make absolutely sure the result fields and the + // in-flight map entry are cleaned up before any waiter unblocks. + defer func() { + c.inflightFetches.Delete(jwksURL) + close(candidate.done) + }() + + // Re-check the cache in case a concurrent fetch completed between our + // initial miss and our LoadOrStore win. if cachedValue, found := c.cache.GetLocal(jwksURL); found { if jwks, ok := cachedValue.(*JWKSet); ok { + candidate.jwks = jwks return jwks, nil } } - // Fetch from URL jwks, err := fetchJWKS(ctx, jwksURL, httpClient) if err != nil { + candidate.err = err return nil, err } - if len(jwks.Keys) == 0 { - return nil, fmt.Errorf("JWKS response contains no keys") + candidate.err = fmt.Errorf("JWKS response contains no keys") + return nil, candidate.err } - // Cache for 1 hour + // Cache for 1 hour. _ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical + candidate.jwks = jwks return jwks, nil }