package traefikoidc import ( "bytes" "context" "crypto" "crypto/ecdsa" "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/pem" "fmt" "math/big" "strings" "sync" "time" "github.com/lukaszraczylo/traefikoidc/internal/pool" ) // Replay attack protection cache using sharded design for reduced lock contention. // This cache tracks JWT IDs (jti claims) to prevent token reuse attacks. // Under high load (500+ req/sec), the sharded design reduces contention significantly. var ( // replayCacheMu protects access to the replay cache instance (only used for initialization) replayCacheMu sync.RWMutex // replayCache stores JWT IDs with expiration to prevent replay attacks (legacy interface) replayCache CacheInterface // shardedReplayCache is the new high-performance sharded cache for replay detection shardedReplayCache *ShardedCache // replayCacheOnce ensures the replay cache is initialized only once replayCacheOnce sync.Once // replayCacheCleanupWG waits for cleanup goroutine to finish replayCacheCleanupWG sync.WaitGroup // replayCacheCancel cancels the cleanup context replayCacheCancel context.CancelFunc // replayCacheCleanupMu protects cleanup operations replayCacheCleanupMu sync.Mutex ) // initReplayCache initializes the JWT replay protection cache with bounded size. // Uses a sharded cache design with 64 shards for reduced lock contention under high load. // The cache is bounded to 10,000 entries to prevent unbounded memory growth. // This function uses sync.Once to ensure thread-safe single initialization. func initReplayCache() { replayCacheOnce.Do(func() { // Hold mutex during initialization to synchronize with cleanup goroutine replayCacheMu.Lock() defer replayCacheMu.Unlock() // Create sharded cache with 64 shards for reduced contention // Under 500 req/sec, this reduces lock contention by ~64x compared to single mutex shardedReplayCache = NewShardedCache(64, 10000) // Also initialize legacy cache for backward compatibility replayCache = NewCache() replayCache.SetMaxSize(10000) }) } // cleanupReplayCache performs graceful shutdown of the replay cache system. // It cancels the cleanup context, waits for background goroutines to finish, // and properly closes the cache to ensure proper cleanup during shutdown. func cleanupReplayCache() { replayCacheCleanupMu.Lock() shouldWait := replayCacheCancel != nil if replayCacheCancel != nil { replayCacheCancel() replayCacheCancel = nil } replayCacheCleanupMu.Unlock() // Only wait if there was a cleanup routine running if shouldWait { replayCacheCleanupWG.Wait() } replayCacheMu.Lock() defer replayCacheMu.Unlock() // Clear sharded cache if shardedReplayCache != nil { shardedReplayCache.Clear() shardedReplayCache = nil } // Clear legacy cache if replayCache != nil { replayCache.Close() replayCache = nil } replayCacheOnce = sync.Once{} } // getReplayCacheStats returns statistics about the replay cache state. // Returns: // - size: Current number of entries in the cache // - maxSize: Maximum allowed entries (10,000) func getReplayCacheStats() (size int, maxSize int) { // Use sharded cache if available (no mutex needed due to internal sharding) if shardedReplayCache != nil { return shardedReplayCache.Size(), 10000 } // Fall back to legacy cache replayCacheMu.RLock() defer replayCacheMu.RUnlock() if replayCache == nil { return 0, 10000 } return 0, 10000 } // startReplayCacheCleanup starts a background goroutine for periodic cache maintenance. // The goroutine runs every 5 minutes to clean expired entries and log cache statistics. // Uses the global task registry with circuit breaker pattern to prevent duplicate tasks. // Parameters: // - ctx: Parent context for cancellation // - logger: Logger for debug output (can be nil) func startReplayCacheCleanup(ctx context.Context, logger *Logger) { registry := GetGlobalTaskRegistry() // Define the cleanup task function cleanupFunc := func() { // Use mutex to safely access cache pointers - this prevents race with initReplayCache replayCacheMu.RLock() shardedCache := shardedReplayCache legacyCache := replayCache replayCacheMu.RUnlock() // Only proceed if caches have been initialized if shardedCache == nil && legacyCache == nil { return } size, maxSize := getReplayCacheStats() if logger != nil { logger.Debugf("Replay cache stats: size=%d, maxSize=%d", size, maxSize) } // Clean up sharded cache if shardedCache != nil { shardedCache.Cleanup() } // Also clean up legacy cache for backward compatibility if legacyCache != nil { legacyCache.Cleanup() } } // Create or get singleton cleanup task task, err := registry.CreateSingletonTask( "replay-cache-cleanup", 5*time.Minute, cleanupFunc, logger, &replayCacheCleanupWG, ) if err != nil { if logger != nil { logger.Debugf("Replay cache cleanup task already exists or circuit breaker limit reached: %v (this is expected with multiple instances)", err) } return } // Start the task task.Start() if logger != nil { logger.Debug("Started replay cache cleanup task with circuit breaker protection") } } // ClockSkewToleranceFuture defines the maximum allowable clock skew for future time validation. // Tokens are considered valid for an additional 2 minutes past their expiration time. var ClockSkewToleranceFuture = 2 * time.Minute // ClockSkewTolerancePast defines the maximum allowable clock skew for past time validation. // Tokens are considered valid if issued up to 10 seconds in the future. var ClockSkewTolerancePast = 10 * time.Second // ClockSkewTolerance is an alias for ClockSkewToleranceFuture for backward compatibility. var ClockSkewTolerance = ClockSkewToleranceFuture // JWT represents a parsed JSON Web Token with its constituent parts. // It provides a structured representation of JWT components // for validation and processing within the OIDC middleware. type JWT struct { // Header contains the JWT header claims (alg, typ, kid, etc.) Header map[string]interface{} // Claims contains the JWT payload claims (iss, sub, aud, exp, etc.) Claims map[string]interface{} // Token is the original JWT token string Token string // Signature contains the decoded JWT signature bytes Signature []byte } // parseJWT parses a JWT token string into its constituent parts. // It decodes the base64url-encoded header, claims, and signature components // and unmarshals the JSON data into structured maps. Uses memory pools // for efficient memory allocation during parsing. // Parameters: // - tokenString: The JWT token string to parse // // Returns: // - *JWT: Parsed JWT structure with header, claims, and signature // - An error if the token format is invalid or decoding/unmarshaling fails func parseJWT(tokenString string) (*JWT, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) } pm := pool.Get() jwtBuf := pm.GetJWTBuffer() defer pm.PutJWTBuffer(jwtBuf) jwt := &JWT{ Token: tokenString, } headerLen := base64.RawURLEncoding.DecodedLen(len(parts[0])) if headerLen > cap(jwtBuf.Header) { jwtBuf.Header = make([]byte, headerLen) } else { jwtBuf.Header = jwtBuf.Header[:headerLen] } n, err := base64.RawURLEncoding.Decode(jwtBuf.Header, []byte(parts[0])) if err != nil { return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err) } headerBytes := jwtBuf.Header[:n] decoder := pm.GetJSONDecoder(bytes.NewReader(headerBytes)) defer pm.PutJSONDecoder(decoder) if err := decoder.Decode(&jwt.Header); err != nil { return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err) } if jwt.Header == nil { return nil, fmt.Errorf("invalid JWT format: header is nil after unmarshaling") } claimsLen := base64.RawURLEncoding.DecodedLen(len(parts[1])) if claimsLen > cap(jwtBuf.Payload) { jwtBuf.Payload = make([]byte, claimsLen) } else { jwtBuf.Payload = jwtBuf.Payload[:claimsLen] } n, err = base64.RawURLEncoding.Decode(jwtBuf.Payload, []byte(parts[1])) if err != nil { return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err) } claimsBytes := jwtBuf.Payload[:n] decoder2 := pm.GetJSONDecoder(bytes.NewReader(claimsBytes)) defer pm.PutJSONDecoder(decoder2) if err := decoder2.Decode(&jwt.Claims); err != nil { return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err) } if jwt.Claims == nil { return nil, fmt.Errorf("invalid JWT format: claims is nil after unmarshaling") } sigLen := base64.RawURLEncoding.DecodedLen(len(parts[2])) if sigLen > cap(jwtBuf.Signature) { jwtBuf.Signature = make([]byte, sigLen) } else { jwtBuf.Signature = jwtBuf.Signature[:sigLen] } n, err = base64.RawURLEncoding.Decode(jwtBuf.Signature, []byte(parts[2])) if err != nil { return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err) } // Reuse the signature buffer if it's large enough, otherwise allocate if cap(jwtBuf.Signature) >= n { jwt.Signature = jwtBuf.Signature[:n:n] // Use slice trick to prevent aliasing } else { jwt.Signature = make([]byte, n) copy(jwt.Signature, jwtBuf.Signature[:n]) } return jwt, nil } // Verify performs comprehensive JWT token validation according to OIDC specifications. // It validates the token signature algorithm, issuer, audience, expiration, issued-at time, // not-before time (if present), and prevents replay attacks using JTI claims. // Parameters: // - issuerURL: Expected issuer URL to validate against // - expectedAudience: Expected audience to validate against (can be clientID or custom audience) // - skipReplayCheck: Optional parameter to skip replay attack protection // // Returns: // - An error describing the first validation failure encountered func (j *JWT) Verify(issuerURL, expectedAudience string, skipReplayCheck ...bool) error { alg, ok := j.Header["alg"].(string) if !ok { return fmt.Errorf("missing 'alg' header") } supportedAlgs := map[string]bool{ "RS256": true, "RS384": true, "RS512": true, "PS256": true, "PS384": true, "PS512": true, "ES256": true, "ES384": true, "ES512": true, } if !supportedAlgs[alg] { return fmt.Errorf("unsupported algorithm: %s", alg) } claims := j.Claims iss, ok := claims["iss"].(string) if !ok { return fmt.Errorf("missing 'iss' claim") } if err := verifyIssuer(iss, issuerURL); err != nil { return err } aud, ok := claims["aud"] if !ok { return fmt.Errorf("missing 'aud' claim") } if err := verifyAudience(aud, expectedAudience); err != nil { return err } exp, ok := claims["exp"].(float64) if !ok { return fmt.Errorf("missing or invalid 'exp' claim") } if err := verifyExpiration(exp); err != nil { return err } iat, ok := claims["iat"].(float64) if !ok { return fmt.Errorf("missing or invalid 'iat' claim") } if err := verifyIssuedAt(iat); err != nil { return err } if nbf, ok := claims["nbf"].(float64); ok { if err := verifyNotBefore(nbf); err != nil { return err } } shouldSkipReplay := len(skipReplayCheck) > 0 && skipReplayCheck[0] jtiValue, jtiOk := claims["jti"].(string) if jtiOk && !shouldSkipReplay && jtiValue != "" { initReplayCache() // Use sharded cache for replay detection - no global mutex needed // This reduces lock contention by ~64x under high load if shardedReplayCache != nil { if shardedReplayCache.Exists(jtiValue) { return fmt.Errorf("token replay detected (jti: %s)", jtiValue) } expFloat, ok := claims["exp"].(float64) var expTime time.Time if ok { expTime = time.Unix(int64(expFloat), 0) } else { expTime = time.Now().Add(10 * time.Minute) } duration := time.Until(expTime) if duration > 0 { shardedReplayCache.Set(jtiValue, true, duration) } } else { // Fall back to legacy cache with mutex (should rarely happen) replayCacheMu.RLock() _, exists := replayCache.Get(jtiValue) replayCacheMu.RUnlock() if exists { return fmt.Errorf("token replay detected (jti: %s)", jtiValue) } expFloat, ok := claims["exp"].(float64) var expTime time.Time if ok { expTime = time.Unix(int64(expFloat), 0) } else { expTime = time.Now().Add(10 * time.Minute) } duration := time.Until(expTime) if duration > 0 { replayCacheMu.Lock() if replayCache != nil { replayCache.Set(jtiValue, true, duration) } replayCacheMu.Unlock() } } } sub, ok := claims["sub"].(string) if !ok || sub == "" { return fmt.Errorf("missing or empty 'sub' claim") } return nil } // verifyAudience validates the JWT audience claim against the expected client ID. // The audience claim can be either a single string or an array of strings. // Parameters: // - tokenAudience: The audience claim from the JWT (string or []interface{}) // - expectedAudience: The expected audience value (typically the OAuth client ID) // // Returns: // - An error if the claim type is invalid or the expected audience is not present func verifyAudience(tokenAudience interface{}, expectedAudience string) error { switch aud := tokenAudience.(type) { case string: if aud != expectedAudience { return fmt.Errorf("invalid audience") } case []interface{}: found := false for _, v := range aud { if str, ok := v.(string); ok && str == expectedAudience { found = true break } } if !found { return fmt.Errorf("invalid audience") } default: return fmt.Errorf("invalid 'aud' claim type") } return nil } // verifyIssuer validates the JWT issuer claim against the expected issuer URL. // Parameters: // - tokenIssuer: The issuer claim from the JWT // - expectedIssuer: The expected issuer URL from OIDC configuration // // Returns: // - An error if the issuers do not match func verifyIssuer(tokenIssuer, expectedIssuer string) error { if tokenIssuer != expectedIssuer { return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer) } return nil } // verifyTimeConstraint validates time-based JWT claims with clock skew tolerance. // It handles both future constraints (exp) and past constraints (iat, nbf). // Parameters: // - unixTime: The Unix timestamp from the JWT claim // - claimName: Name of the claim being validated (for error messages) // - future: If true, validates against future tolerance; if false, against past tolerance // // Returns: // - An error describing the failure (e.g., "token has expired", "token used before issued") func verifyTimeConstraint(unixTime float64, claimName string, future bool) error { claimTime := time.Unix(int64(unixTime), 0) now := time.Now() var err error if future { allowedExpiry := claimTime.Add(ClockSkewToleranceFuture) if now.After(allowedExpiry) { err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC()) } } else { allowedStart := claimTime.Add(-ClockSkewTolerancePast) if now.Before(allowedStart) { reason := "not yet valid" if claimName == "iat" { reason = "used before issued" } err = fmt.Errorf("token %s (%s: %v, now: %v, allowed_from: %v)", reason, claimName, claimTime.UTC(), now.UTC(), allowedStart.UTC()) } } return err } // verifyExpiration validates the JWT expiration time (exp claim) with clock skew tolerance. // It calls verifyTimeConstraint with future=true. func verifyExpiration(expiration float64) error { return verifyTimeConstraint(expiration, "exp", true) } // verifyIssuedAt validates the JWT issued-at time (iat claim) with clock skew tolerance. // It calls verifyTimeConstraint with future=false. func verifyIssuedAt(issuedAt float64) error { return verifyTimeConstraint(issuedAt, "iat", false) } // verifyNotBefore validates the JWT not-before time (nbf claim) with clock skew tolerance. // It calls verifyTimeConstraint with future=false. func verifyNotBefore(notBefore float64) error { return verifyTimeConstraint(notBefore, "nbf", false) } // verifySignature verifies the JWT signature using the provided public key. // Supports RSA (RS256/384/512, PS256/384/512) and ECDSA (ES256/384/512) algorithms. // Parameters: // - tokenString: The complete JWT token string // - publicKeyPEM: The public key in PEM format // - alg: The signing algorithm specified in the JWT header // // Returns: // - An error if the key parsing fails, the algorithm is unsupported, // or the signature verification fails func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error { parts := strings.Split(tokenString, ".") if len(parts) != 3 { return fmt.Errorf("invalid token format") } signedContent := parts[0] + "." + parts[1] signature, err := base64.RawURLEncoding.DecodeString(parts[2]) if err != nil { return fmt.Errorf("failed to decode signature: %w", err) } block, _ := pem.Decode(publicKeyPEM) if block == nil { return fmt.Errorf("failed to parse PEM block containing the public key") } pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return fmt.Errorf("failed to parse public key: %w", err) } var hashFunc crypto.Hash switch alg { case "RS256", "PS256", "ES256": hashFunc = crypto.SHA256 case "RS384", "PS384", "ES384": hashFunc = crypto.SHA384 case "RS512", "PS512", "ES512": hashFunc = crypto.SHA512 default: return fmt.Errorf("unsupported algorithm: %s", alg) } h := hashFunc.New() h.Write([]byte(signedContent)) hashed := h.Sum(nil) switch pubKey := pubKey.(type) { case *rsa.PublicKey: if strings.HasPrefix(alg, "RS") { return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature) } else if strings.HasPrefix(alg, "PS") { return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil) } else { return fmt.Errorf("unexpected key type for algorithm %s", alg) } case *ecdsa.PublicKey: if strings.HasPrefix(alg, "ES") { var r, s big.Int sigLen := len(signature) if sigLen%2 != 0 { return fmt.Errorf("invalid ECDSA signature length") } r.SetBytes(signature[:sigLen/2]) s.SetBytes(signature[sigLen/2:]) if ecdsa.Verify(pubKey, hashed, &r, &s) { return nil } else { return fmt.Errorf("invalid ECDSA signature") } } else { return fmt.Errorf("unexpected key type for algorithm %s", alg) } default: return fmt.Errorf("unsupported public key type: %T", pubKey) } }