Make things 'betterer' across the board

This commit is contained in:
2026-01-11 00:53:44 +00:00
parent 7ab4b07cf2
commit 548b27702e
47 changed files with 12535 additions and 1784 deletions
+377 -36
View File
@@ -4,11 +4,15 @@ package sdk
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
@@ -20,8 +24,178 @@ import (
"github.com/rs/zerolog/log"
)
// CircuitBreaker implements a simple circuit breaker pattern for CLI calls.
//
// The mutable fields (failures, lastFailure, state) are accessed through
// sync/atomic by the methods in this file, so no mutex is involved. In the
// visible code threshold and resetTimeout are assigned once by
// NewCircuitBreaker and only read afterwards.
type CircuitBreaker struct {
failures int64 // Failures counted since the last RecordSuccess
lastFailure int64 // Unix timestamp (seconds) of the most recent RecordFailure; 0 if none yet
threshold int64 // Failure count at which RecordFailure opens the circuit
resetTimeout int64 // Seconds after lastFailure during which an open circuit keeps rejecting
state int32 // One of circuitClosed (0), circuitOpen (1), circuitHalfOpen (2)
}
// Circuit breaker states stored in CircuitBreaker.state.
// The zero value is circuitClosed, so a freshly allocated breaker admits calls.
const (
	circuitClosed   int32 = iota // calls flow normally
	circuitOpen                  // calls are rejected until the reset timeout elapses
	circuitHalfOpen              // the breaker is testing whether calls succeed again
)
// NewCircuitBreaker returns a breaker that starts closed with no recorded
// failures.
//
// threshold is the failure count at which the circuit opens; resetTimeout is
// the number of seconds an open circuit waits after the last failure before
// letting a call through again.
func NewCircuitBreaker(threshold int64, resetTimeout int64) *CircuitBreaker {
	cb := new(CircuitBreaker)
	cb.threshold = threshold
	cb.resetTimeout = resetTimeout
	return cb
}
// Allow reports whether a call may proceed.
//
//   - Closed: every call is admitted.
//   - Open: calls are rejected until more than resetTimeout seconds have
//     passed since the last failure. The first caller to observe the expired
//     timeout wins the open -> half-open transition and is admitted as the
//     single probe; callers that lose that race are rejected.
//   - Half-open: a probe is in flight, so further calls are rejected until the
//     probe reports via RecordSuccess or RecordFailure. If no report arrives
//     within a further resetTimeout seconds the probe is presumed abandoned
//     and calls are admitted again, so the breaker can never wedge.
//
// A caller that receives true is expected to call RecordSuccess or
// RecordFailure afterwards.
func (cb *CircuitBreaker) Allow() bool {
	state := atomic.LoadInt32(&cb.state)
	if state == circuitClosed {
		return true
	}

	elapsed := time.Now().Unix() - atomic.LoadInt64(&cb.lastFailure)

	if state == circuitOpen {
		if elapsed <= cb.resetTimeout {
			return false
		}
		// Only the goroutine whose CAS succeeds becomes the probe; everyone
		// else racing on the same expiry is turned away.
		return atomic.CompareAndSwapInt32(&cb.state, circuitOpen, circuitHalfOpen)
	}

	// Half-open. Written as a subtraction rather than 2*resetTimeout so a very
	// large timeout cannot overflow.
	return elapsed-cb.resetTimeout > cb.resetTimeout
}
// RecordSuccess records a successful call.
//
// It zeroes the failure counter and then closes the circuit, whatever state it
// was in. The counter is cleared first so that a reader that sees the closed
// state also sees the reset counter.
func (cb *CircuitBreaker) RecordSuccess() {
atomic.StoreInt64(&cb.failures, 0)
atomic.StoreInt32(&cb.state, circuitClosed)
}
// RecordFailure records a failed call.
//
// It increments the failure counter and stamps lastFailure with the current
// Unix time. Once the counter reaches threshold the circuit is (re)opened and
// a warning is logged; because the counter is only reset by RecordSuccess,
// every further failure re-opens the circuit and logs again.
func (cb *CircuitBreaker) RecordFailure() {
	count := atomic.AddInt64(&cb.failures, 1)
	atomic.StoreInt64(&cb.lastFailure, time.Now().Unix())

	if count < cb.threshold {
		return
	}

	atomic.StoreInt32(&cb.state, circuitOpen)
	log.Warn().Int64("failures", count).Msg("Circuit breaker opened - Claude CLI calls temporarily disabled")
}
// State returns the current state as a string: "open", "half-open" or
// "closed". Any value other than circuitOpen or circuitHalfOpen is reported
// as "closed".
func (cb *CircuitBreaker) State() string {
	current := atomic.LoadInt32(&cb.state)
	if current == circuitOpen {
		return "open"
	}
	if current == circuitHalfOpen {
		return "half-open"
	}
	return "closed"
}
// CircuitBreakerMetrics contains metrics about the circuit breaker state.
//
// It is a point-in-time snapshot produced by CircuitBreaker.Metrics. The two
// omitempty fields are left out of the JSON encoding when they are zero.
type CircuitBreakerMetrics struct {
// State is "closed", "open" or "half-open".
State string `json:"state"`
// Failures is the current failure counter.
Failures int64 `json:"failures"`
// Threshold is the failure count at which the circuit opens.
Threshold int64 `json:"threshold"`
// ResetTimeoutSecs is the configured reset timeout in seconds.
ResetTimeoutSecs int64 `json:"reset_timeout_secs"`
// LastFailureUnix is the Unix time (seconds) of the last failure; zero when none was recorded.
LastFailureUnix int64 `json:"last_failure_unix,omitempty"`
// SecondsUntilReset is set only while the circuit is open and time remains before the reset timeout expires.
SecondsUntilReset int64 `json:"seconds_until_reset,omitempty"`
}
// Metrics returns a snapshot of the breaker's counters and configuration.
//
// LastFailureUnix is filled in only when a failure has been recorded, and
// SecondsUntilReset only while the circuit is open and the reset timeout has
// not yet run out. The individual fields are read with separate atomic loads,
// so the snapshot is not guaranteed to be mutually consistent under
// concurrent updates.
func (cb *CircuitBreaker) Metrics() CircuitBreakerMetrics {
	failureCount := atomic.LoadInt64(&cb.failures)
	lastFailureAt := atomic.LoadInt64(&cb.lastFailure)
	current := cb.State()

	snapshot := CircuitBreakerMetrics{
		State:            current,
		Failures:         failureCount,
		Threshold:        cb.threshold,
		ResetTimeoutSecs: cb.resetTimeout,
	}

	if lastFailureAt <= 0 {
		return snapshot
	}
	snapshot.LastFailureUnix = lastFailureAt

	if current != "open" {
		return snapshot
	}
	if left := cb.resetTimeout - (time.Now().Unix() - lastFailureAt); left > 0 {
		snapshot.SecondsUntilReset = left
	}
	return snapshot
}
// RequestDeduplicator tracks recent requests to prevent duplicates.
//
// Entries are keyed by a request hash and remembered together with the Unix
// time (seconds) at which they were recorded. mu guards seen.
type RequestDeduplicator struct {
seen map[string]int64 // hash -> Unix timestamp (seconds) of the last Record call
mu sync.RWMutex // protects seen
ttlSecs int64 // age in seconds below which a recorded hash counts as a duplicate
maxSize int // entry count at which Record starts evicting
}
// NewRequestDeduplicator returns an empty deduplicator.
//
// ttlSecs is how long, in seconds, a recorded hash is treated as a duplicate;
// maxSize is the entry count at which Record begins evicting.
func NewRequestDeduplicator(ttlSecs int64, maxSize int) *RequestDeduplicator {
	d := new(RequestDeduplicator)
	d.seen = make(map[string]int64)
	d.ttlSecs = ttlSecs
	d.maxSize = maxSize
	return d
}
// IsDuplicate reports whether hash was recorded less than ttlSecs seconds ago.
// It only reads the table; callers register a request with Record.
func (d *RequestDeduplicator) IsDuplicate(hash string) bool {
	now := time.Now().Unix()

	d.mu.RLock()
	recordedAt, found := d.seen[hash]
	d.mu.RUnlock()

	if !found {
		return false
	}
	return now-recordedAt < d.ttlSecs
}
// Record marks a request hash as seen at the current time.
//
// When the table already holds maxSize entries, expired entries (older than
// ttlSecs) are purged first. If that frees no room and hash is not already
// present, the oldest remaining entry is evicted so the table never grows
// beyond maxSize (or a single entry when maxSize is not positive). Evicting a
// live entry means that request may be processed again inside its TTL, which
// is preferred over unbounded growth and a full rescan on every insert.
func (d *RequestDeduplicator) Record(hash string) {
	now := time.Now().Unix()

	d.mu.Lock()
	defer d.mu.Unlock()

	if len(d.seen) >= d.maxSize {
		cutoff := now - d.ttlSecs
		var (
			oldestKey string
			oldestTS  int64
			haveLive  bool
		)
		// Deleting during range is safe in Go; the same pass finds the oldest
		// survivor in case nothing has expired.
		for k, ts := range d.seen {
			if ts < cutoff {
				delete(d.seen, k)
				continue
			}
			if !haveLive || ts < oldestTS {
				oldestKey, oldestTS, haveLive = k, ts, true
			}
		}
		// Refreshing an existing hash does not add an entry, so nothing needs
		// to be evicted for it.
		if _, present := d.seen[hash]; !present && haveLive && len(d.seen) >= d.maxSize {
			delete(d.seen, oldestKey)
		}
	}

	d.seen[hash] = now
}
// hashRequest returns a short fingerprint of a request for deduplication.
//
// The tool name, the input and at most the first 1000 bytes of the output are
// fed into SHA-256. Each field is written with a length prefix so that
// different splits of the same bytes (for example ("ab", "c") and
// ("a", "bc")) cannot produce the same digest. The result is the first 16 hex
// characters (64 bits) of the digest, which is ample for a small in-memory
// table.
func hashRequest(toolName, input, output string) string {
	const maxOutputBytes = 1000 // byte count, not rune count

	fields := [...]string{
		toolName,
		input,
		output[:min(len(output), maxOutputBytes)],
	}

	h := sha256.New()
	for _, field := range fields {
		// hash.Hash.Write is documented never to return an error.
		fmt.Fprintf(h, "%d:%s", len(field), field)
	}
	return hex.EncodeToString(h.Sum(nil))[:16]
}
// BroadcastFunc is a callback for broadcasting events to SSE clients.
type BroadcastFunc func(event map[string]interface{})
type BroadcastFunc func(event map[string]any)
// SyncObservationFunc is a callback for syncing observations to vector DB.
// In this file it is invoked from the vector sync worker goroutines and from
// the fallback goroutine ProcessObservation starts when the worker channel is
// full, so it may be called concurrently.
type SyncObservationFunc func(obs *models.Observation)
@@ -29,6 +203,10 @@ type SyncObservationFunc func(obs *models.Observation)
// SyncSummaryFunc is a callback for syncing summaries to vector DB.
// It receives the session summary to sync and returns nothing, so it has no
// way to report a failure to its caller.
type SyncSummaryFunc func(summary *models.SessionSummary)
// MaxVectorSyncWorkers is the maximum number of concurrent vector sync operations
// performed by the worker pool: StartVectorSyncWorkers launches this many
// goroutines, and NewProcessor sizes the vector sync channel buffer at twice
// this value.
// NOTE(review): the bound covers the pool only. When the channel is full,
// ProcessObservation starts one extra goroutine per observation, so total
// concurrency during bursts can exceed this number.
const MaxVectorSyncWorkers = 8
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
type Processor struct {
claudePath string
@@ -40,6 +218,14 @@ type Processor struct {
syncSummaryFunc SyncSummaryFunc
// Semaphore to limit concurrent Claude CLI calls (prevents API overload)
sem chan struct{}
// Circuit breaker for CLI failures
circuitBreaker *CircuitBreaker
// Request deduplicator to prevent duplicate processing
deduplicator *RequestDeduplicator
// Bounded worker pool for vector sync operations
vectorSyncChan chan *models.Observation
vectorSyncWg sync.WaitGroup
vectorSyncDone chan struct{}
}
// SetBroadcastFunc sets the broadcast callback for SSE events.
@@ -58,7 +244,7 @@ func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) {
}
// broadcast sends an event via the broadcast callback if set.
func (p *Processor) broadcast(event map[string]interface{}) {
func (p *Processor) broadcast(event map[string]any) {
if p.broadcastFunc != nil {
p.broadcastFunc(event)
}
@@ -94,9 +280,65 @@ func NewProcessor(observationStore *gorm.ObservationStore, summaryStore *gorm.Su
observationStore: observationStore,
summaryStore: summaryStore,
sem: make(chan struct{}, MaxConcurrentCLICalls),
circuitBreaker: NewCircuitBreaker(5, 60), // Open after 5 failures, reset after 60s
deduplicator: NewRequestDeduplicator(300, 1000), // 5-minute TTL, 1000 max entries
vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2), // Buffered channel
vectorSyncDone: make(chan struct{}),
}, nil
}
// StartVectorSyncWorkers launches MaxVectorSyncWorkers goroutines that consume
// the vector sync channel. Call it after installing the sync callback with
// SetSyncObservationFunc. Every call starts a fresh set of workers, so it is
// meant to be called once per Processor.
func (p *Processor) StartVectorSyncWorkers() {
	// Register the whole pool up front so StopVectorSyncWorkers waits for
	// every worker.
	p.vectorSyncWg.Add(MaxVectorSyncWorkers)
	for started := 0; started < MaxVectorSyncWorkers; started++ {
		go p.vectorSyncWorker()
	}
	log.Info().Int("workers", MaxVectorSyncWorkers).Msg("Vector sync worker pool started")
}
// StopVectorSyncWorkers gracefully stops the worker pool.
//
// It closes vectorSyncDone, which tells every worker to drain what is
// currently buffered in the vector sync channel and exit, then blocks until
// all workers registered on vectorSyncWg have returned.
//
// NOTE(review): calling this twice panics, because closing an already-closed
// channel panics in Go. The fallback goroutines ProcessObservation starts
// when the channel is full are not registered on the WaitGroup, so they are
// not waited for here.
func (p *Processor) StopVectorSyncWorkers() {
close(p.vectorSyncDone)
p.vectorSyncWg.Wait()
log.Info().Msg("Vector sync worker pool stopped")
}
// vectorSyncWorker is the body of one pool goroutine. It forwards each
// observation received on the vector sync channel to the sync callback (if one
// is set). Once vectorSyncDone is closed it empties whatever is still buffered
// in the channel without blocking and then returns.
func (p *Processor) vectorSyncWorker() {
	defer p.vectorSyncWg.Done()

	// deliver hands one observation to the callback, tolerating an unset one.
	deliver := func(obs *models.Observation) {
		if p.syncObservationFunc != nil {
			p.syncObservationFunc(obs)
		}
	}

	for {
		select {
		case obs := <-p.vectorSyncChan:
			deliver(obs)

		case <-p.vectorSyncDone:
			// Shutdown requested: flush the backlog, then stop as soon as the
			// channel has nothing ready.
			for {
				select {
				case obs := <-p.vectorSyncChan:
					deliver(obs)
				default:
					return
				}
			}
		}
	}
}
// CircuitBreakerState returns the current state of the circuit breaker
// as reported by CircuitBreaker.State: "closed", "open" or "half-open".
func (p *Processor) CircuitBreakerState() string {
return p.circuitBreaker.State()
}
// CircuitBreakerMetrics returns detailed metrics about the circuit breaker.
// The value is the snapshot produced by CircuitBreaker.Metrics at call time.
func (p *Processor) CircuitBreakerMetrics() CircuitBreakerMetrics {
return p.circuitBreaker.Metrics()
}
// IsAvailable checks if the Claude CLI is available for processing.
func (p *Processor) IsAvailable() bool {
_, err := os.Stat(p.claudePath)
@@ -104,7 +346,7 @@ func (p *Processor) IsAvailable() bool {
}
// ProcessObservation processes a single tool observation and extracts insights.
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error {
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse any, promptNumber int, cwd string) error {
// Skip certain tools that aren't worth processing
if shouldSkipTool(toolName) {
log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)")
@@ -121,11 +363,23 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
return nil
}
// Check for duplicate request within TTL window
reqHash := hashRequest(toolName, inputStr, outputStr)
if p.deduplicator.IsDuplicate(reqHash) {
log.Debug().Str("tool", toolName).Msg("Skipping duplicate request (dedup)")
return nil
}
// Check circuit breaker before making CLI call
if !p.circuitBreaker.Allow() {
log.Warn().Str("tool", toolName).Msg("Circuit breaker open - skipping CLI call")
return fmt.Errorf("circuit breaker open")
}
log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI")
// Note: Removed the "file already has observations" check
// Each tool execution can produce unique insights even for the same file
// Similarity-based deduplication will handle true duplicates
// Record this request to prevent duplicates
p.deduplicator.Record(reqHash)
// Build the prompt
exec := ToolExecution{
@@ -147,9 +401,11 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
// Call Claude Code CLI
response, err := p.callClaudeCLI(ctx, prompt)
if err != nil {
p.circuitBreaker.RecordFailure()
log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation")
return err
}
p.circuitBreaker.RecordSuccess()
// Parse observations from response
observations := ParseObservations(response, sdkSessionID)
@@ -200,16 +456,26 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
Int("trackedFiles", len(obs.FileMtimes)).
Msg("Observation stored")
// Sync to vector DB if callback is set
if p.syncObservationFunc != nil {
// Sync to vector DB via bounded worker pool (non-blocking to reduce latency)
if p.syncObservationFunc != nil && p.vectorSyncChan != nil {
fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0)
fullObs.ID = id
fullObs.CreatedAtEpoch = createdAtEpoch
p.syncObservationFunc(fullObs)
// Non-blocking send to worker pool - drops if channel is full
select {
case p.vectorSyncChan <- fullObs:
// Sent to worker pool
default:
// Channel full, fall back to direct sync in goroutine (bounded by channel buffer)
log.Debug().Int64("obs_id", id).Msg("Vector sync channel full, using fallback goroutine")
go func(obsToSync *models.Observation) {
p.syncObservationFunc(obsToSync)
}(fullObs)
}
}
// Broadcast new observation event for dashboard refresh
p.broadcast(map[string]interface{}{
p.broadcast(map[string]any{
"type": "observation",
"action": "created",
"id": id,
@@ -311,7 +577,7 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
}
// Broadcast new summary event for dashboard refresh
p.broadcast(map[string]interface{}{
p.broadcast(map[string]any{
"type": "summary",
"action": "created",
"id": id,
@@ -321,8 +587,31 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
return nil
}
// MaxPromptSize is the maximum size of a prompt that can be passed to the Claude CLI.
// This prevents resource exhaustion from extremely large prompts.
// callClaudeCLI compares it against the prompt's length in bytes, before the
// prompt is sanitized and before the system prompt is prepended.
const MaxPromptSize = 100 * 1024 // 100KB
// sanitizePrompt strips NUL bytes and the other ASCII control characters
// below U+0020 from a prompt, except for newline, tab and carriage return,
// which are legitimate prompt content. Every rune from U+0020 upwards is kept
// unchanged; note that this includes DEL (U+007F) and the C1 control range.
func sanitizePrompt(s string) string {
	filter := func(r rune) rune {
		switch {
		case r == '\n', r == '\t', r == '\r':
			return r // whitespace worth keeping
		case r < 32:
			return -1 // a negative result tells strings.Map to drop the rune
		default:
			return r
		}
	}
	return strings.Map(filter, s)
}
// callClaudeCLI calls the Claude Code CLI with the given prompt.
func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) {
// Validate and sanitize prompt
if len(prompt) > MaxPromptSize {
return "", fmt.Errorf("prompt exceeds maximum size of %d bytes", MaxPromptSize)
}
prompt = sanitizePrompt(prompt)
// Build the full prompt with system instructions
fullPrompt := systemPrompt + "\n\n" + prompt
@@ -419,8 +708,11 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
return true
}
// Skip if output indicates an error or empty result
// Pre-compute lowercase strings once to avoid repeated allocations
lowerOutput := strings.ToLower(outputStr)
lowerInput := strings.ToLower(inputStr)
// Skip if output indicates an error or empty result
trivialOutputs := []string{
"no matches found",
"file not found",
@@ -444,13 +736,13 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
// Skip reading config files that rarely contain project-specific insights
boringFiles := []string{
"package-lock.json", "yarn.lock", "pnpm-lock.yaml",
"go.sum", "Cargo.lock", "Gemfile.lock", "poetry.lock",
"go.sum", "cargo.lock", "gemfile.lock", "poetry.lock",
".gitignore", ".dockerignore", ".eslintignore",
"tsconfig.json", "jsconfig.json", "vite.config",
"tailwind.config", "postcss.config",
}
for _, boring := range boringFiles {
if strings.Contains(inputStr, boring) {
if strings.Contains(lowerInput, boring) {
return true
}
}
@@ -462,14 +754,14 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
}
case "Bash":
// Skip simple status commands
// Skip simple status commands (use pre-computed lowerInput)
boringCommands := []string{
"git status", "git diff", "git log", "git branch",
"ls ", "pwd", "echo ", "cat ", "which ", "type ",
"npm list", "npm outdated", "npm audit",
}
for _, boring := range boringCommands {
if strings.Contains(strings.ToLower(inputStr), boring) {
if strings.Contains(lowerInput, boring) {
return true
}
}
@@ -479,7 +771,7 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
}
// toJSONString converts an interface to a JSON string.
func toJSONString(v interface{}) string {
func toJSONString(v any) string {
if v == nil {
return ""
}
@@ -495,36 +787,85 @@ func toJSONString(v interface{}) string {
// captureFileMtimes captures current modification times for tracked files.
// Returns a map of absolute file paths to their mtime in epoch milliseconds.
// For large file lists (>10 files), uses parallel stat calls for better performance.
func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 {
mtimes := make(map[string]int64)
// Combine all unique file paths
allPaths := make(map[string]struct{}, len(filesRead)+len(filesModified))
for _, path := range filesRead {
allPaths[path] = struct{}{}
}
for _, path := range filesModified {
allPaths[path] = struct{}{}
}
// Helper to get mtime for a file path
getMtime := func(path string) (int64, bool) {
// Resolve relative paths against cwd
// For small lists, use sequential processing (goroutine overhead not worth it)
if len(allPaths) <= 10 {
return captureFileMtimesSequential(allPaths, cwd)
}
// For larger lists, parallelize with bounded concurrency
return captureFileMtimesParallel(allPaths, cwd)
}
// captureFileMtimesSequential captures mtimes sequentially (efficient for small lists).
func captureFileMtimesSequential(paths map[string]struct{}, cwd string) map[string]int64 {
mtimes := make(map[string]int64, len(paths))
for path := range paths {
absPath := path
if !filepath.IsAbs(path) && cwd != "" {
absPath = filepath.Join(cwd, path)
}
info, err := os.Stat(absPath)
if err != nil {
return 0, false
}
return info.ModTime().UnixMilli(), true
}
// Capture mtimes for all read files
for _, path := range filesRead {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
if err == nil {
mtimes[path] = info.ModTime().UnixMilli()
}
}
// Capture mtimes for all modified files
for _, path := range filesModified {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
}
return mtimes
}
// captureFileMtimesParallel captures mtimes in parallel with bounded concurrency.
func captureFileMtimesParallel(paths map[string]struct{}, cwd string) map[string]int64 {
type mtimeResult struct {
path string
mtime int64
}
results := make(chan mtimeResult, len(paths))
sem := make(chan struct{}, 8) // Limit to 8 concurrent stat calls
var wg sync.WaitGroup
for path := range paths {
wg.Add(1)
go func(p string) {
defer wg.Done()
sem <- struct{}{} // Acquire
defer func() { <-sem }() // Release
absPath := p
if !filepath.IsAbs(p) && cwd != "" {
absPath = filepath.Join(cwd, p)
}
info, err := os.Stat(absPath)
if err == nil {
results <- mtimeResult{path: p, mtime: info.ModTime().UnixMilli()}
}
}(path)
}
// Close results channel when all goroutines complete
go func() {
wg.Wait()
close(results)
}()
// Collect results
mtimes := make(map[string]int64, len(paths))
for res := range results {
mtimes[res.path] = res.mtime
}
return mtimes