Files
claude-mnemonic/internal/worker/ratelimit.go
T

228 lines
6.3 KiB
Go

// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
	"net"
	"net/http"
	"sync"
	"time"
)
// RateLimiter is a token bucket rate limiter.
//
// Every field declared after mu must only be accessed while mu is held.
type RateLimiter struct {
	// rate is the refill speed, in tokens per second.
	rate float64
	// burst is the bucket capacity: the most tokens that can accumulate.
	burst int

	// mu guards the fields declared below it.
	mu sync.Mutex
	// tokens is the number of tokens currently in the bucket.
	tokens float64
	// lastUpdate is the time at which tokens was last recomputed.
	lastUpdate time.Time
	// requests counts every Allow call; rejected counts the calls
	// that were denied.
	requests, rejected int64
}
// LastUpdateTime reports the timestamp currently stored in lastUpdate.
//
// It locks rl.mu for the duration of the read, so it is safe to call from
// any goroutine that does not already hold rl.mu.
func (rl *RateLimiter) LastUpdateTime() time.Time {
	rl.mu.Lock()
	last := rl.lastUpdate
	rl.mu.Unlock()
	return last
}
// lastUpdateTimeUnlocked reads lastUpdate without touching rl.mu.
//
// The caller is responsible for holding rl.mu around this call.
func (rl *RateLimiter) lastUpdateTimeUnlocked() time.Time {
	last := rl.lastUpdate
	return last
}
// NewRateLimiter builds a RateLimiter whose bucket starts full.
//
// rate is the number of tokens (requests) added back per second.
// burst is the bucket capacity, i.e. the largest burst of requests allowed.
func NewRateLimiter(rate float64, burst int) *RateLimiter {
	rl := new(RateLimiter)
	rl.rate = rate
	rl.burst = burst
	rl.tokens = float64(burst) // begin with a full bucket
	rl.lastUpdate = time.Now()
	return rl
}
// Allow reports whether one more request may proceed right now.
//
// Each call first refills the bucket in proportion to the time elapsed since
// the previous call (never beyond burst), then spends one token if at least
// one whole token is available. It returns true when a token was spent and
// false when the request is rate limited. Both outcomes increment requests;
// a denial also increments rejected.
func (rl *RateLimiter) Allow() bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	rl.requests++

	// Refill for the time that has passed, capped at the bucket size.
	now := time.Now()
	refill := now.Sub(rl.lastUpdate).Seconds() * rl.rate
	rl.tokens = min(rl.tokens+refill, float64(rl.burst))
	rl.lastUpdate = now

	// Spend a token when a whole one is available; otherwise record the denial.
	allowed := rl.tokens >= 1
	if allowed {
		rl.tokens--
	} else {
		rl.rejected++
	}
	return allowed
}
// Stats returns a snapshot of the limiter's configuration and counters,
// taken under rl.mu.
//
// Keys: "rate", "burst", "current_tokens", "total_requests", "rejected" and
// "rejection_rate". The rejection rate is rejected divided by total requests,
// with the divisor clamped to at least 1 so an unused limiter reports 0.
// current_tokens is the stored value; this method does not refill the bucket.
func (rl *RateLimiter) Stats() map[string]any {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	denied := float64(rl.rejected)
	divisor := max(float64(rl.requests), 1)

	stats := make(map[string]any, 6)
	stats["rate"] = rl.rate
	stats["burst"] = rl.burst
	stats["current_tokens"] = rl.tokens
	stats["total_requests"] = rl.requests
	stats["rejected"] = rl.rejected
	stats["rejection_rate"] = denied / divisor
	return stats
}
// RateLimitMiddleware returns HTTP middleware that gates every request on
// the single limiter passed in, so all requests share one bucket.
//
// When limiter.Allow returns false the request is answered with
// 429 Too Many Requests and the wrapped handler is not invoked.
func RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		guarded := func(w http.ResponseWriter, r *http.Request) {
			if limiter.Allow() {
				next.ServeHTTP(w, r)
				return
			}
			http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
		}
		return http.HandlerFunc(guarded)
	}
	return wrap
}
// PerClientRateLimiter keeps one RateLimiter per client key, so every client
// is limited independently of the others.
type PerClientRateLimiter struct {
	// rate is the refill rate passed to NewRateLimiter for each new client.
	rate float64
	// burst is the bucket size passed to NewRateLimiter for each new client.
	burst int
	// clients maps a client key to that client's limiter.
	clients map[string]*RateLimiter
	// mu is held while clients and lastCleanup are read or modified.
	mu sync.Mutex

	// Cleanup settings.
	//
	// cleanupInterval is the minimum time between sweeps of clients.
	// maxIdleTime is how long a limiter may go without an update before a
	// sweep removes it.
	cleanupInterval, maxIdleTime time.Duration
	// lastCleanup is when the most recent sweep ran.
	lastCleanup time.Time
}
// NewPerClientRateLimiter builds a PerClientRateLimiter with no clients yet.
//
// rate and burst are applied to every per-client limiter created later.
// Idle-client sweeps run at most every 5 minutes and drop limiters that have
// not been updated for more than 10 minutes.
func NewPerClientRateLimiter(rate float64, burst int) *PerClientRateLimiter {
	pcrl := &PerClientRateLimiter{rate: rate, burst: burst}
	pcrl.clients = make(map[string]*RateLimiter)
	pcrl.cleanupInterval = 5 * time.Minute
	pcrl.maxIdleTime = 10 * time.Minute
	pcrl.lastCleanup = time.Now()
	return pcrl
}
// getLimiter looks up the limiter registered for key, creating and storing
// a new one (with the shared rate and burst) when the key is unknown.
//
// The whole call runs under pcrl.mu. Before the lookup it triggers a sweep
// of idle clients if more than cleanupInterval has passed since the last one.
func (pcrl *PerClientRateLimiter) getLimiter(key string) *RateLimiter {
	pcrl.mu.Lock()
	defer pcrl.mu.Unlock()

	// Sweep idle clients first when one is due.
	if sinceSweep := time.Since(pcrl.lastCleanup); sinceSweep > pcrl.cleanupInterval {
		pcrl.cleanupLocked()
	}

	if existing, ok := pcrl.clients[key]; ok {
		return existing
	}

	fresh := NewRateLimiter(pcrl.rate, pcrl.burst)
	pcrl.clients[key] = fresh
	return fresh
}
// cleanupLocked drops every client whose limiter has not been updated for
// longer than maxIdleTime, then records the sweep time in lastCleanup.
//
// The caller must already hold pcrl.mu. Each limiter's own mutex is taken
// briefly (inside LastUpdateTime) while pcrl.mu is held, so the lock order
// is always pcrl.mu first, then limiter.mu.
func (pcrl *PerClientRateLimiter) cleanupLocked() {
	sweepStart := time.Now()

	// Removing the current entry while ranging over a map is permitted in Go.
	for key, limiter := range pcrl.clients {
		idleFor := sweepStart.Sub(limiter.LastUpdateTime())
		if idleFor > pcrl.maxIdleTime {
			delete(pcrl.clients, key)
		}
	}

	pcrl.lastCleanup = sweepStart
}
// Allow reports whether a request from clientKey may proceed, consulting
// (and creating on first use) the limiter dedicated to that key.
func (pcrl *PerClientRateLimiter) Allow(clientKey string) bool {
	limiter := pcrl.getLimiter(clientKey)
	return limiter.Allow()
}
// Stats returns totals across all currently tracked clients.
//
// Keys: "rate", "burst", "active_clients", "total_requests" and
// "total_rejected". The work is split in two steps so that pcrl.mu and a
// limiter's mutex are never held at the same time: first the set of limiters
// is copied under pcrl.mu, then each limiter is read under its own lock.
func (pcrl *PerClientRateLimiter) Stats() map[string]any {
	// Step 1: snapshot the configuration and the limiter set under pcrl.mu.
	pcrl.mu.Lock()
	snapshot := make([]*RateLimiter, 0, len(pcrl.clients))
	for _, rl := range pcrl.clients {
		snapshot = append(snapshot, rl)
	}
	rate, burst := pcrl.rate, pcrl.burst
	pcrl.mu.Unlock()

	// Step 2: sum the counters, locking one limiter at a time.
	var requests, rejected int64
	for _, rl := range snapshot {
		rl.mu.Lock()
		requests += rl.requests
		rejected += rl.rejected
		rl.mu.Unlock()
	}

	stats := make(map[string]any, 5)
	stats["rate"] = rate
	stats["burst"] = burst
	stats["active_clients"] = len(snapshot)
	stats["total_requests"] = requests
	stats["total_rejected"] = rejected
	return stats
}
// PerClientRateLimitMiddleware returns HTTP middleware that rate limits each
// client separately using limiter.
//
// The client key is the X-Real-IP request header when it is non-empty;
// otherwise it is the host part of r.RemoteAddr. The port is stripped from
// RemoteAddr because net/http fills it in as "IP:port", and the port changes
// with every new connection; keying on the raw value would hand each
// connection from the same client its own bucket (and its own map entry).
// If RemoteAddr cannot be split into host and port it is used unchanged.
//
// Requests that exceed the client's limit are answered with
// 429 Too Many Requests and the wrapped handler is not invoked.
//
// NOTE(review): X-Real-IP is read from the request as-is. Presumably an
// upstream RealIP middleware or trusted proxy sets it; confirm, since a
// client-supplied value would let callers choose their own bucket.
func PerClientRateLimitMiddleware(limiter *PerClientRateLimiter) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Default to the peer address without its ephemeral port.
			clientKey := r.RemoteAddr
			if host, _, err := net.SplitHostPort(clientKey); err == nil && host != "" {
				clientKey = host
			}

			// Prefer the X-Real-IP header when present.
			if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
				clientKey = realIP
			}

			if !limiter.Allow(clientKey) {
				http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}