Make things 'betterer' across the board (#23)

* Make things 'betterer' across the board

* fix: reorganize struct fields and config parameters for consistency

- [x] Reorder Config struct fields alphabetically and by related functionality
- [x] Reorganize Observation model fields with archival fields grouped together
- [x] Reorder ObservationStore fields to group related members
- [x] Reorder Store struct fields with health check caching grouped
- [x] Reorganize HealthInfo and PoolMetrics struct field order
- [x] Reorder maintenance Service struct fields logically
- [x] Reorganize MCP server handler parameter structs alphabetically
- [x] Reorder pattern detector candidate tracking fields
- [x] Reorganize search Manager struct fields by functionality
- [x] Reorder vector Client struct fields with mutex protections grouped
- [x] Reorganize handler request/response struct fields
- [x] Update handlers_test.go to expect wrapped response format
- [x] Reorder middleware TokenAuth and rate limiter fields
- [x] Reorganize Service struct fields with grouped functionality
- [x] Fix RateLimiter field ordering for clarity
- [x] Reorder CircuitBreaker metrics fields

* fix(security): improve JSON output safety and path traversal protection

- [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler
- [x] Remove escapeJSONString helper function in favor of standard JSON marshaling
- [x] Add safeResolvePath function to validate paths and prevent directory traversal
- [x] Apply path traversal validation in captureFileMtimes operations
- [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation

* fix(sdk): improve path traversal protection and allocation safety

- [x] Enhance safeResolvePath with stricter validation using filepath.Rel
- [x] Reject paths containing ".." after cleaning to prevent traversal
- [x] Validate absolute paths are within cwd when cwd is specified
- [x] Apply safeResolvePath validation to GetFileContent for consistency
- [x] Add comprehensive test coverage for path traversal protection
- [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
This commit is contained in:
2026-01-11 01:51:20 +00:00
committed by GitHub
parent 3107eddeb2
commit d04b60517a
46 changed files with 12710 additions and 2068 deletions
+428 -43
View File
@@ -4,11 +4,15 @@ package sdk
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
@@ -20,8 +24,178 @@ import (
"github.com/rs/zerolog/log"
)
// CircuitBreaker implements a simple circuit breaker pattern for CLI calls.
//
// The methods in this file access failures, lastFailure and state through
// sync/atomic. threshold and resetTimeout are assigned by NewCircuitBreaker
// and only read by the methods shown here.
type CircuitBreaker struct {
	failures     int64 // Current failure count; incremented by RecordFailure, zeroed by RecordSuccess
	lastFailure  int64 // Unix timestamp (seconds) of last failure; 0 until the first RecordFailure
	threshold    int64 // Number of failures before opening (compared with failures >= threshold)
	resetTimeout int64 // Seconds to wait after the last failure before trying again
	state        int32 // 0=closed, 1=open, 2=half-open (see the circuit* constants)
}
// Circuit breaker states, stored in CircuitBreaker.state.
const (
	circuitClosed   int32 = iota // 0: calls are allowed
	circuitOpen                  // 1: calls are rejected until the reset timeout elapses
	circuitHalfOpen              // 2: calls are allowed again after the timeout
)
// NewCircuitBreaker returns a breaker in the closed state with a zero failure count.
//
// threshold is the failure count at which RecordFailure opens the breaker.
// resetTimeout is the number of seconds that must pass after the last failure
// before Allow lets calls through again.
func NewCircuitBreaker(threshold int64, resetTimeout int64) *CircuitBreaker {
	cb := new(CircuitBreaker)
	cb.threshold = threshold
	cb.resetTimeout = resetTimeout
	return cb
}
// Allow reports whether a call may proceed right now.
//
//   - closed: always true.
//   - open: false until more than resetTimeout seconds have passed since the
//     last recorded failure; after that the breaker is switched to half-open
//     and the call is allowed.
//   - half-open: true. Calls keep being allowed until RecordSuccess or
//     RecordFailure changes the state; there is no single-probe limit.
func (cb *CircuitBreaker) Allow() bool {
	switch atomic.LoadInt32(&cb.state) {
	case circuitClosed:
		return true
	case circuitOpen:
		lastFailureAt := atomic.LoadInt64(&cb.lastFailure)
		elapsed := time.Now().Unix() - lastFailureAt
		if elapsed <= cb.resetTimeout {
			return false
		}
		// The swap result is deliberately ignored: if another goroutine has
		// already moved the breaker out of the open state, the call is still allowed.
		atomic.CompareAndSwapInt32(&cb.state, circuitOpen, circuitHalfOpen)
		return true
	default:
		// Half-open.
		return true
	}
}
// RecordSuccess notes a successful call: the failure counter is cleared and
// the breaker is put back into the closed state, whatever state it was in.
func (cb *CircuitBreaker) RecordSuccess() {
	// Reset the counter first, then close the circuit.
	atomic.StoreInt64(&cb.failures, 0)
	atomic.StoreInt32(&cb.state, circuitClosed)
}
// RecordFailure notes a failed call. It bumps the failure counter, stamps the
// time of the failure, and opens the breaker (with a warning log) once the
// counter has reached the threshold.
func (cb *CircuitBreaker) RecordFailure() {
	count := atomic.AddInt64(&cb.failures, 1)
	atomic.StoreInt64(&cb.lastFailure, time.Now().Unix())

	if count < cb.threshold {
		return
	}

	atomic.StoreInt32(&cb.state, circuitOpen)
	log.Warn().Int64("failures", count).Msg("Circuit breaker opened - Claude CLI calls temporarily disabled")
}
// State reports the breaker's current state as "open", "half-open" or
// "closed". Any value other than open or half-open is reported as "closed".
func (cb *CircuitBreaker) State() string {
	current := atomic.LoadInt32(&cb.state)
	if current == circuitOpen {
		return "open"
	}
	if current == circuitHalfOpen {
		return "half-open"
	}
	return "closed"
}
// CircuitBreakerMetrics contains metrics about the circuit breaker state.
// It is a point-in-time snapshot produced by CircuitBreaker.Metrics and is
// tagged for JSON serialization.
type CircuitBreakerMetrics struct {
	State             string `json:"state"`                         // "closed", "open" or "half-open"
	Failures          int64  `json:"failures"`                      // failure count at snapshot time
	Threshold         int64  `json:"threshold"`                     // failure count at which the breaker opens
	ResetTimeoutSecs  int64  `json:"reset_timeout_secs"`            // configured reset timeout in seconds
	LastFailureUnix   int64  `json:"last_failure_unix,omitempty"`   // Unix seconds of the last failure; omitted when none was recorded
	SecondsUntilReset int64  `json:"seconds_until_reset,omitempty"` // only set while open and the timeout has not yet elapsed
}
// Metrics takes a snapshot of the breaker for reporting.
//
// LastFailureUnix is filled in only when a failure has been recorded.
// SecondsUntilReset is filled in only while the breaker is open and some of
// the reset timeout still remains.
func (cb *CircuitBreaker) Metrics() CircuitBreakerMetrics {
	failureCount := atomic.LoadInt64(&cb.failures)
	lastFailureAt := atomic.LoadInt64(&cb.lastFailure)
	stateName := cb.State()

	snapshot := CircuitBreakerMetrics{
		State:            stateName,
		Failures:         failureCount,
		Threshold:        cb.threshold,
		ResetTimeoutSecs: cb.resetTimeout,
	}

	// No failure recorded yet: nothing more to report.
	if lastFailureAt <= 0 {
		return snapshot
	}
	snapshot.LastFailureUnix = lastFailureAt

	// A countdown only makes sense while the breaker is open.
	if stateName != "open" {
		return snapshot
	}
	elapsed := time.Now().Unix() - lastFailureAt
	if left := cb.resetTimeout - elapsed; left > 0 {
		snapshot.SecondsUntilReset = left
	}
	return snapshot
}
// RequestDeduplicator tracks recent requests to prevent duplicates.
//
// All methods are guarded by an internal lock. Note that calling IsDuplicate
// and then Record is not one atomic step: two goroutines that check the same
// previously unseen hash at the same moment can both be told it is not a
// duplicate.
type RequestDeduplicator struct {
	seen    map[string]int64 // hash -> Unix timestamp (seconds) of the most recent Record
	mu      sync.RWMutex     // guards seen
	ttlSecs int64            // seconds during which a recorded hash counts as a duplicate
	maxSize int              // maximum number of tracked hashes; <= 0 disables the size cap
}

// NewRequestDeduplicator creates a new deduplicator.
//
// ttlSecs is how long (in seconds) a recorded hash is reported as a duplicate.
// maxSize caps the number of hashes kept in memory; a value <= 0 means only
// TTL expiry limits the size.
func NewRequestDeduplicator(ttlSecs int64, maxSize int) *RequestDeduplicator {
	return &RequestDeduplicator{
		seen:    make(map[string]int64),
		ttlSecs: ttlSecs,
		maxSize: maxSize,
	}
}

// IsDuplicate checks if a request hash was seen recently, i.e. it was
// recorded less than ttlSecs seconds ago. It does not modify the tracked set.
func (d *RequestDeduplicator) IsDuplicate(hash string) bool {
	now := time.Now().Unix()

	d.mu.RLock()
	ts, exists := d.seen[hash]
	d.mu.RUnlock()

	return exists && now-ts < d.ttlSecs
}

// Record marks a request hash as seen at the current time.
//
// When the tracked set is at capacity, expired entries are purged first. If
// that does not free a slot for a new hash, the oldest entries are evicted so
// the set never holds more than maxSize hashes.
func (d *RequestDeduplicator) Record(hash string) {
	now := time.Now().Unix()

	d.mu.Lock()
	defer d.mu.Unlock()

	if len(d.seen) >= d.maxSize {
		d.purgeExpiredLocked(now)
		// Re-recording an already tracked hash does not grow the map, so a
		// slot only has to be freed for a hash that is new.
		if _, tracked := d.seen[hash]; !tracked && d.maxSize > 0 {
			d.evictOldestLocked()
		}
	}

	d.seen[hash] = now
}

// purgeExpiredLocked removes every entry older than the TTL. The caller must hold d.mu.
func (d *RequestDeduplicator) purgeExpiredLocked(now int64) {
	cutoff := now - d.ttlSecs
	for k, ts := range d.seen {
		if ts < cutoff {
			delete(d.seen, k)
		}
	}
}

// evictOldestLocked drops the oldest entries until there is room for one more
// hash below maxSize. The caller must hold d.mu and ensure maxSize > 0.
func (d *RequestDeduplicator) evictOldestLocked() {
	for len(d.seen) >= d.maxSize {
		var (
			oldestKey string
			oldestTS  int64
			found     bool
		)
		for k, ts := range d.seen {
			if !found || ts < oldestTS {
				oldestKey, oldestTS, found = k, ts, true
			}
		}
		delete(d.seen, oldestKey)
	}
}
// hashRequest creates a hash of a request for deduplication.
//
// The hash covers the tool name, the full input and at most the first 1000
// bytes of the output. Each field is written with a length prefix so that
// field boundaries are unambiguous: without it ("ab", "c") and ("a", "bc")
// would feed identical bytes to the hash. The result is the first 16 hex
// characters (64 bits) of the SHA-256 digest.
func hashRequest(toolName, input, output string) string {
	const maxOutputBytes = 1000 // only the head of the output is hashed

	fields := [...]string{toolName, input, output[:min(len(output), maxOutputBytes)]}

	h := sha256.New()
	for _, field := range fields {
		// hash.Hash writes never return an error.
		fmt.Fprintf(h, "%d:", len(field))
		h.Write([]byte(field))
	}
	return hex.EncodeToString(h.Sum(nil))[:16] // Short hash is sufficient
}
// BroadcastFunc is a callback for broadcasting events to SSE clients.
type BroadcastFunc func(event map[string]interface{})
type BroadcastFunc func(event map[string]any)
// SyncObservationFunc is a callback for syncing observations to vector DB.
type SyncObservationFunc func(obs *models.Observation)
@@ -29,16 +203,26 @@ type SyncObservationFunc func(obs *models.Observation)
// SyncSummaryFunc is a callback for syncing summaries to vector DB.
type SyncSummaryFunc func(summary *models.SessionSummary)
// MaxVectorSyncWorkers is the number of worker goroutines started by
// StartVectorSyncWorkers to run vector sync operations. NewProcessor sizes the
// vector sync channel buffer at twice this value.
// NOTE(review): when that channel is full, ProcessObservation starts an extra
// goroutine per observation, so this is not a hard cap on concurrent syncs.
const MaxVectorSyncWorkers = 8
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
// Field order optimized for memory alignment (fieldalignment).
type Processor struct {
	observationStore    *gorm.ObservationStore   // observation store supplied to NewProcessor
	summaryStore        *gorm.SummaryStore       // summary store supplied to NewProcessor
	broadcastFunc       BroadcastFunc            // optional event callback; broadcast is a no-op while nil
	syncObservationFunc SyncObservationFunc      // optional vector sync callback for observations; skipped while nil
	syncSummaryFunc     SyncSummaryFunc          // optional vector sync callback for summaries
	circuitBreaker      *CircuitBreaker          // consulted before CLI calls in ProcessObservation
	deduplicator        *RequestDeduplicator     // filters repeated requests in ProcessObservation
	vectorSyncChan      chan *models.Observation // buffered queue feeding the vector sync workers
	vectorSyncDone      chan struct{}            // closed by StopVectorSyncWorkers to tell workers to drain and exit
	sem                 chan struct{}            // buffered with capacity MaxConcurrentCLICalls; presumably limits concurrent CLI calls — confirm at use site
	claudePath          string                   // path checked with os.Stat by IsAvailable
	model               string                   // NOTE(review): presumably the model name passed to the CLI — not visible here
	vectorSyncWg        sync.WaitGroup           // tracks running vector sync workers
}
// SetBroadcastFunc sets the broadcast callback for SSE events.
@@ -57,7 +241,7 @@ func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) {
}
// broadcast sends an event via the broadcast callback if set.
func (p *Processor) broadcast(event map[string]interface{}) {
func (p *Processor) broadcast(event map[string]any) {
if p.broadcastFunc != nil {
p.broadcastFunc(event)
}
@@ -93,9 +277,65 @@ func NewProcessor(observationStore *gorm.ObservationStore, summaryStore *gorm.Su
observationStore: observationStore,
summaryStore: summaryStore,
sem: make(chan struct{}, MaxConcurrentCLICalls),
circuitBreaker: NewCircuitBreaker(5, 60), // Open after 5 failures, reset after 60s
deduplicator: NewRequestDeduplicator(300, 1000), // 5-minute TTL, 1000 max entries
vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2), // Buffered channel
vectorSyncDone: make(chan struct{}),
}, nil
}
// StartVectorSyncWorkers launches MaxVectorSyncWorkers worker goroutines that
// consume the vector sync channel, and registers each with the wait group
// that StopVectorSyncWorkers waits on.
// Call this after setting the sync function via SetSyncObservationFunc.
func (p *Processor) StartVectorSyncWorkers() {
	// Register the whole pool up front, then launch the workers.
	p.vectorSyncWg.Add(MaxVectorSyncWorkers)
	for started := 0; started < MaxVectorSyncWorkers; started++ {
		go p.vectorSyncWorker()
	}
	log.Info().Int("workers", MaxVectorSyncWorkers).Msg("Vector sync worker pool started")
}
// StopVectorSyncWorkers signals every worker to finish by closing the done
// channel, then blocks until all of them have returned. Because it closes a
// channel, it must not be called more than once.
func (p *Processor) StopVectorSyncWorkers() {
	// Signal shutdown; workers drain what is queued before exiting.
	close(p.vectorSyncDone)

	// Block until every worker has returned.
	p.vectorSyncWg.Wait()

	log.Info().Msg("Vector sync worker pool stopped")
}
// vectorSyncWorker is the body of one pool goroutine. It hands each queued
// observation to the sync callback (if one is set) until the done channel is
// closed, then empties whatever is still queued and returns.
func (p *Processor) vectorSyncWorker() {
	defer p.vectorSyncWg.Done()

	// handle runs the sync callback for one observation; a nil callback is skipped.
	handle := func(queued *models.Observation) {
		if p.syncObservationFunc != nil {
			p.syncObservationFunc(queued)
		}
	}

	for {
		select {
		case queued := <-p.vectorSyncChan:
			handle(queued)

		case <-p.vectorSyncDone:
			// Shutdown requested: process what is already buffered, without
			// blocking, and exit as soon as the channel is empty.
			for {
				select {
				case queued := <-p.vectorSyncChan:
					handle(queued)
				default:
					return
				}
			}
		}
	}
}
// CircuitBreakerState reports the processor's circuit breaker state as a
// string ("closed", "open" or "half-open").
func (p *Processor) CircuitBreakerState() string {
	breaker := p.circuitBreaker
	return breaker.State()
}
// CircuitBreakerMetrics returns a snapshot of the processor's circuit breaker
// metrics, as produced by CircuitBreaker.Metrics.
func (p *Processor) CircuitBreakerMetrics() CircuitBreakerMetrics {
	breaker := p.circuitBreaker
	return breaker.Metrics()
}
// IsAvailable checks if the Claude CLI is available for processing.
func (p *Processor) IsAvailable() bool {
_, err := os.Stat(p.claudePath)
@@ -103,7 +343,7 @@ func (p *Processor) IsAvailable() bool {
}
// ProcessObservation processes a single tool observation and extracts insights.
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error {
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse any, promptNumber int, cwd string) error {
// Skip certain tools that aren't worth processing
if shouldSkipTool(toolName) {
log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)")
@@ -120,11 +360,23 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
return nil
}
// Check for duplicate request within TTL window
reqHash := hashRequest(toolName, inputStr, outputStr)
if p.deduplicator.IsDuplicate(reqHash) {
log.Debug().Str("tool", toolName).Msg("Skipping duplicate request (dedup)")
return nil
}
// Check circuit breaker before making CLI call
if !p.circuitBreaker.Allow() {
log.Warn().Str("tool", toolName).Msg("Circuit breaker open - skipping CLI call")
return fmt.Errorf("circuit breaker open")
}
log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI")
// Note: Removed the "file already has observations" check
// Each tool execution can produce unique insights even for the same file
// Similarity-based deduplication will handle true duplicates
// Record this request to prevent duplicates
p.deduplicator.Record(reqHash)
// Build the prompt
exec := ToolExecution{
@@ -146,9 +398,11 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
// Call Claude Code CLI
response, err := p.callClaudeCLI(ctx, prompt)
if err != nil {
p.circuitBreaker.RecordFailure()
log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation")
return err
}
p.circuitBreaker.RecordSuccess()
// Parse observations from response
observations := ParseObservations(response, sdkSessionID)
@@ -199,16 +453,26 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
Int("trackedFiles", len(obs.FileMtimes)).
Msg("Observation stored")
// Sync to vector DB if callback is set
if p.syncObservationFunc != nil {
// Sync to vector DB via bounded worker pool (non-blocking to reduce latency)
if p.syncObservationFunc != nil && p.vectorSyncChan != nil {
fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0)
fullObs.ID = id
fullObs.CreatedAtEpoch = createdAtEpoch
p.syncObservationFunc(fullObs)
// Non-blocking send to worker pool - falls back to a separate goroutine if channel is full
select {
case p.vectorSyncChan <- fullObs:
// Sent to worker pool
default:
// Channel full, fall back to direct sync in a separate goroutine (note: these fallback goroutines are not limited by the worker pool or the channel buffer)
log.Debug().Int64("obs_id", id).Msg("Vector sync channel full, using fallback goroutine")
go func(obsToSync *models.Observation) {
p.syncObservationFunc(obsToSync)
}(fullObs)
}
}
// Broadcast new observation event for dashboard refresh
p.broadcast(map[string]interface{}{
p.broadcast(map[string]any{
"type": "observation",
"action": "created",
"id": id,
@@ -310,7 +574,7 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
}
// Broadcast new summary event for dashboard refresh
p.broadcast(map[string]interface{}{
p.broadcast(map[string]any{
"type": "summary",
"action": "created",
"id": id,
@@ -320,8 +584,31 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
return nil
}
// MaxPromptSize is the maximum size of a prompt that can be passed to the Claude CLI.
// This prevents resource exhaustion from extremely large prompts.
// callClaudeCLI compares it against len(prompt), i.e. a byte count, and the
// check runs before the prompt is sanitized or the system prompt is prepended.
const MaxPromptSize = 100 * 1024 // 100KB
// sanitizePrompt strips null bytes and the other ASCII control characters
// below 0x20 from a prompt, except newline, tab and carriage return, which
// are valid in prompts. Every rune from 0x20 upwards is kept unchanged,
// including DEL (0x7f) and the C1 control range.
func sanitizePrompt(s string) string {
	// keep reports whether a rune survives sanitization.
	keep := func(r rune) bool {
		switch r {
		case '\n', '\t', '\r':
			return true
		}
		return r >= 32
	}

	return strings.Map(func(r rune) rune {
		if !keep(r) {
			return -1 // a negative result makes strings.Map drop the rune
		}
		return r
	}, s)
}
// callClaudeCLI calls the Claude Code CLI with the given prompt.
func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) {
// Validate and sanitize prompt
if len(prompt) > MaxPromptSize {
return "", fmt.Errorf("prompt exceeds maximum size of %d bytes", MaxPromptSize)
}
prompt = sanitizePrompt(prompt)
// Build the full prompt with system instructions
fullPrompt := systemPrompt + "\n\n" + prompt
@@ -418,8 +705,11 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
return true
}
// Skip if output indicates an error or empty result
// Pre-compute lowercase strings once to avoid repeated allocations
lowerOutput := strings.ToLower(outputStr)
lowerInput := strings.ToLower(inputStr)
// Skip if output indicates an error or empty result
trivialOutputs := []string{
"no matches found",
"file not found",
@@ -443,13 +733,13 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
// Skip reading config files that rarely contain project-specific insights
boringFiles := []string{
"package-lock.json", "yarn.lock", "pnpm-lock.yaml",
"go.sum", "Cargo.lock", "Gemfile.lock", "poetry.lock",
"go.sum", "cargo.lock", "gemfile.lock", "poetry.lock",
".gitignore", ".dockerignore", ".eslintignore",
"tsconfig.json", "jsconfig.json", "vite.config",
"tailwind.config", "postcss.config",
}
for _, boring := range boringFiles {
if strings.Contains(inputStr, boring) {
if strings.Contains(lowerInput, boring) {
return true
}
}
@@ -461,14 +751,14 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
}
case "Bash":
// Skip simple status commands
// Skip simple status commands (use pre-computed lowerInput)
boringCommands := []string{
"git status", "git diff", "git log", "git branch",
"ls ", "pwd", "echo ", "cat ", "which ", "type ",
"npm list", "npm outdated", "npm audit",
}
for _, boring := range boringCommands {
if strings.Contains(strings.ToLower(inputStr), boring) {
if strings.Contains(lowerInput, boring) {
return true
}
}
@@ -478,7 +768,7 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
}
// toJSONString converts an interface to a JSON string.
func toJSONString(v interface{}) string {
func toJSONString(v any) string {
if v == nil {
return ""
}
@@ -492,38 +782,132 @@ func toJSONString(v interface{}) string {
return string(b)
}
// safeResolvePath resolves a path relative to cwd and validates it doesn't escape the cwd directory.
// This function is a security sanitizer for path traversal attacks.
//
// Returns the resolved path and true if valid, or empty string and false if
// path traversal is detected:
//
//   - A relative path that still climbs to a parent directory after cleaning
//     is always rejected, even when cwd is empty.
//   - With a non-empty cwd, a relative path is joined onto cwd, and an
//     absolute path is accepted only if it is cwd itself or lies beneath it.
//   - With an empty cwd, the cleaned path is returned as-is (a relative path
//     stays relative).
//
// Checks are made per path component, so file names that merely contain two
// dots (e.g. "notes..bak") are accepted. The check is purely lexical:
// symbolic links are not resolved.
func safeResolvePath(path, cwd string) (string, bool) {
	// escapes reports whether an already-cleaned relative path climbs above
	// its base. After filepath.Clean, any ".." components sit at the front.
	escapes := func(rel string) bool {
		return rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator))
	}

	// Clean the input path to normalize any .. or . components
	cleanPath := filepath.Clean(path)

	if filepath.IsAbs(cleanPath) {
		if cwd == "" {
			return cleanPath, true
		}
		// filepath.Rel handles a root cwd ("/") correctly, where a prefix
		// comparison against cwd+separator would not. It fails if cwd is
		// relative, which rejects the path.
		rel, err := filepath.Rel(filepath.Clean(cwd), cleanPath)
		if err != nil || escapes(rel) {
			return "", false
		}
		return cleanPath, true
	}

	// Reject relative paths that climb to a parent directory after cleaning
	if escapes(cleanPath) {
		return "", false
	}

	if cwd == "" {
		return cleanPath, true
	}

	cleanCwd := filepath.Clean(cwd)
	absPath := filepath.Join(cleanCwd, cleanPath)

	// Defense in depth: confirm the joined path is still within cwd.
	rel, err := filepath.Rel(cleanCwd, absPath)
	if err != nil || escapes(rel) {
		return "", false
	}

	return absPath, true
}
// captureFileMtimes captures current modification times for tracked files.
// Returns a map of absolute file paths to their mtime in epoch milliseconds.
// For large file lists (>10 files), uses parallel stat calls for better performance.
func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 {
mtimes := make(map[string]int64)
// Combine all unique file paths
allPaths := make(map[string]struct{}, len(filesRead)+len(filesModified))
for _, path := range filesRead {
allPaths[path] = struct{}{}
}
for _, path := range filesModified {
allPaths[path] = struct{}{}
}
// Helper to get mtime for a file path
getMtime := func(path string) (int64, bool) {
// Resolve relative paths against cwd
absPath := path
if !filepath.IsAbs(path) && cwd != "" {
absPath = filepath.Join(cwd, path)
// For small lists, use sequential processing (goroutine overhead not worth it)
if len(allPaths) <= 10 {
return captureFileMtimesSequential(allPaths, cwd)
}
// For larger lists, parallelize with bounded concurrency
return captureFileMtimesParallel(allPaths, cwd)
}
// captureFileMtimesSequential captures mtimes sequentially (efficient for small lists).
func captureFileMtimesSequential(paths map[string]struct{}, cwd string) map[string]int64 {
mtimes := make(map[string]int64, len(paths))
for path := range paths {
absPath, ok := safeResolvePath(path, cwd)
if !ok {
// Skip paths that attempt directory traversal
continue
}
info, err := os.Stat(absPath)
if err != nil {
return 0, false
}
return info.ModTime().UnixMilli(), true
}
// Capture mtimes for all read files
for _, path := range filesRead {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
if err == nil {
mtimes[path] = info.ModTime().UnixMilli()
}
}
// Capture mtimes for all modified files
for _, path := range filesModified {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
}
return mtimes
}
// captureFileMtimesParallel captures mtimes in parallel with bounded concurrency.
func captureFileMtimesParallel(paths map[string]struct{}, cwd string) map[string]int64 {
type mtimeResult struct {
path string
mtime int64
}
results := make(chan mtimeResult, len(paths))
sem := make(chan struct{}, 8) // Limit to 8 concurrent stat calls
var wg sync.WaitGroup
for path := range paths {
wg.Add(1)
go func(p string) {
defer wg.Done()
sem <- struct{}{} // Acquire
defer func() { <-sem }() // Release
absPath, ok := safeResolvePath(p, cwd)
if !ok {
// Skip paths that attempt directory traversal
return
}
info, err := os.Stat(absPath)
if err == nil {
results <- mtimeResult{path: p, mtime: info.ModTime().UnixMilli()}
}
}(path)
}
// Close results channel when all goroutines complete
go func() {
wg.Wait()
close(results)
}()
// Collect results
mtimes := make(map[string]int64, len(paths))
for res := range results {
mtimes[res.path] = res.mtime
}
return mtimes
@@ -538,12 +922,13 @@ func GetFileMtimes(paths []string, cwd string) map[string]int64 {
// GetFileContent reads file content for verification purposes.
// Returns content and ok status.
func GetFileContent(path, cwd string) (string, bool) {
absPath := path
if !filepath.IsAbs(path) && cwd != "" {
absPath = filepath.Join(cwd, path)
absPath, ok := safeResolvePath(path, cwd)
if !ok {
// Reject paths that attempt directory traversal
return "", false
}
content, err := os.ReadFile(absPath) // #nosec G304 -- intentional file read for verification
content, err := os.ReadFile(absPath) // #nosec G304 -- path validated by safeResolvePath
if err != nil {
return "", false
}