mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-11 00:09:28 +00:00
Make things 'betterer' across the board (#23)
* Make things 'betterer' across the board * fix: reorganize struct fields and config parameters for consistency - [x] Reorder Config struct fields alphabetically and by related functionality - [x] Reorganize Observation model fields with archival fields grouped together - [x] Reorder ObservationStore fields to group related members - [x] Reorder Store struct fields with health check caching grouped - [x] Reorganize HealthInfo and PoolMetrics struct field order - [x] Reorder maintenance Service struct fields logically - [x] Reorganize MCP server handler parameter structs alphabetically - [x] Reorder pattern detector candidate tracking fields - [x] Reorganize search Manager struct fields by functionality - [x] Reorder vector Client struct fields with mutex protections grouped - [x] Reorganize handler request/response struct fields - [x] Update handlers_test.go to expect wrapped response format - [x] Reorder middleware TokenAuth and rate limiter fields - [x] Reorganize Service struct fields with grouped functionality - [x] Fix RateLimiter field ordering for clarity - [x] Reorder CircuitBreaker metrics fields * fix(security): improve JSON output safety and path traversal protection - [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler - [x] Remove escapeJSONString helper function in favor of standard JSON marshaling - [x] Add safeResolvePath function to validate paths and prevent directory traversal - [x] Apply path traversal validation in captureFileMtimes operations - [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation * fix(sdk): improve path traversal protection and allocation safety - [x] Enhance safeResolvePath with stricter validation using filepath.Rel - [x] Reject paths containing ".." after cleaning to prevent traversal - [x] Validate absolute paths are within cwd when cwd is specified - [x] Apply safeResolvePath validation to GetFileContent for consistency - [x] Add comprehensive test coverage for path traversal protection - [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
This commit is contained in:
@@ -4,11 +4,15 @@ package sdk
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
@@ -20,8 +24,178 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// CircuitBreaker implements a simple circuit breaker pattern for CLI calls.
|
||||
type CircuitBreaker struct {
|
||||
failures int64 // Current failure count
|
||||
lastFailure int64 // Unix timestamp of last failure
|
||||
threshold int64 // Number of failures before opening
|
||||
resetTimeout int64 // Seconds to wait before trying again
|
||||
state int32 // 0=closed, 1=open, 2=half-open
|
||||
}
|
||||
|
||||
const (
|
||||
circuitClosed int32 = 0
|
||||
circuitOpen int32 = 1
|
||||
circuitHalfOpen int32 = 2
|
||||
)
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker.
|
||||
func NewCircuitBreaker(threshold int64, resetTimeout int64) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
threshold: threshold,
|
||||
resetTimeout: resetTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request should be allowed through.
|
||||
func (cb *CircuitBreaker) Allow() bool {
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
if state == circuitClosed {
|
||||
return true
|
||||
}
|
||||
|
||||
if state == circuitOpen {
|
||||
// Check if reset timeout has passed
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailure)
|
||||
if time.Now().Unix()-lastFail > cb.resetTimeout {
|
||||
// Transition to half-open
|
||||
atomic.CompareAndSwapInt32(&cb.state, circuitOpen, circuitHalfOpen)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Half-open: allow one request through
|
||||
return true
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful call.
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
atomic.StoreInt64(&cb.failures, 0)
|
||||
atomic.StoreInt32(&cb.state, circuitClosed)
|
||||
}
|
||||
|
||||
// RecordFailure records a failed call.
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
failures := atomic.AddInt64(&cb.failures, 1)
|
||||
atomic.StoreInt64(&cb.lastFailure, time.Now().Unix())
|
||||
|
||||
if failures >= cb.threshold {
|
||||
atomic.StoreInt32(&cb.state, circuitOpen)
|
||||
log.Warn().Int64("failures", failures).Msg("Circuit breaker opened - Claude CLI calls temporarily disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the current state as a string.
|
||||
func (cb *CircuitBreaker) State() string {
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case circuitOpen:
|
||||
return "open"
|
||||
case circuitHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "closed"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerMetrics contains metrics about the circuit breaker state.
|
||||
type CircuitBreakerMetrics struct {
|
||||
State string `json:"state"`
|
||||
Failures int64 `json:"failures"`
|
||||
Threshold int64 `json:"threshold"`
|
||||
ResetTimeoutSecs int64 `json:"reset_timeout_secs"`
|
||||
LastFailureUnix int64 `json:"last_failure_unix,omitempty"`
|
||||
SecondsUntilReset int64 `json:"seconds_until_reset,omitempty"`
|
||||
}
|
||||
|
||||
// Metrics returns the current metrics of the circuit breaker.
|
||||
func (cb *CircuitBreaker) Metrics() CircuitBreakerMetrics {
|
||||
failures := atomic.LoadInt64(&cb.failures)
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailure)
|
||||
state := cb.State()
|
||||
|
||||
metrics := CircuitBreakerMetrics{
|
||||
State: state,
|
||||
Failures: failures,
|
||||
Threshold: cb.threshold,
|
||||
ResetTimeoutSecs: cb.resetTimeout,
|
||||
}
|
||||
|
||||
if lastFail > 0 {
|
||||
metrics.LastFailureUnix = lastFail
|
||||
if state == "open" {
|
||||
remaining := cb.resetTimeout - (time.Now().Unix() - lastFail)
|
||||
if remaining > 0 {
|
||||
metrics.SecondsUntilReset = remaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// RequestDeduplicator tracks recent requests to prevent duplicates.
|
||||
type RequestDeduplicator struct {
|
||||
seen map[string]int64 // hash -> timestamp
|
||||
mu sync.RWMutex
|
||||
ttlSecs int64
|
||||
maxSize int
|
||||
}
|
||||
|
||||
// NewRequestDeduplicator creates a new deduplicator.
|
||||
func NewRequestDeduplicator(ttlSecs int64, maxSize int) *RequestDeduplicator {
|
||||
return &RequestDeduplicator{
|
||||
seen: make(map[string]int64),
|
||||
ttlSecs: ttlSecs,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// IsDuplicate checks if a request hash was seen recently.
|
||||
func (d *RequestDeduplicator) IsDuplicate(hash string) bool {
|
||||
now := time.Now().Unix()
|
||||
|
||||
d.mu.RLock()
|
||||
ts, exists := d.seen[hash]
|
||||
d.mu.RUnlock()
|
||||
|
||||
if exists && now-ts < d.ttlSecs {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Record marks a request hash as seen.
|
||||
func (d *RequestDeduplicator) Record(hash string) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// Evict old entries if at capacity
|
||||
if len(d.seen) >= d.maxSize {
|
||||
threshold := now - d.ttlSecs
|
||||
for k, ts := range d.seen {
|
||||
if ts < threshold {
|
||||
delete(d.seen, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.seen[hash] = now
|
||||
}
|
||||
|
||||
// hashRequest creates a hash of a request for deduplication.
|
||||
func hashRequest(toolName, input, output string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(toolName))
|
||||
h.Write([]byte(input))
|
||||
h.Write([]byte(output[:min(len(output), 1000)])) // Only hash first 1000 chars of output
|
||||
return hex.EncodeToString(h.Sum(nil))[:16] // Short hash is sufficient
|
||||
}
|
||||
|
||||
// BroadcastFunc is a callback for broadcasting events to SSE clients.
|
||||
type BroadcastFunc func(event map[string]interface{})
|
||||
type BroadcastFunc func(event map[string]any)
|
||||
|
||||
// SyncObservationFunc is a callback for syncing observations to vector DB.
|
||||
type SyncObservationFunc func(obs *models.Observation)
|
||||
@@ -29,16 +203,26 @@ type SyncObservationFunc func(obs *models.Observation)
|
||||
// SyncSummaryFunc is a callback for syncing summaries to vector DB.
|
||||
type SyncSummaryFunc func(summary *models.SessionSummary)
|
||||
|
||||
// MaxVectorSyncWorkers is the maximum number of concurrent vector sync operations.
|
||||
// This prevents unbounded goroutine spawning during high-volume observation ingestion.
|
||||
const MaxVectorSyncWorkers = 8
|
||||
|
||||
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
|
||||
// Field order optimized for memory alignment (fieldalignment).
|
||||
type Processor struct {
|
||||
observationStore *gorm.ObservationStore
|
||||
summaryStore *gorm.SummaryStore
|
||||
broadcastFunc BroadcastFunc
|
||||
syncObservationFunc SyncObservationFunc
|
||||
syncSummaryFunc SyncSummaryFunc
|
||||
circuitBreaker *CircuitBreaker
|
||||
deduplicator *RequestDeduplicator
|
||||
vectorSyncChan chan *models.Observation
|
||||
vectorSyncDone chan struct{}
|
||||
sem chan struct{}
|
||||
claudePath string
|
||||
model string
|
||||
vectorSyncWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// SetBroadcastFunc sets the broadcast callback for SSE events.
|
||||
@@ -57,7 +241,7 @@ func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) {
|
||||
}
|
||||
|
||||
// broadcast sends an event via the broadcast callback if set.
|
||||
func (p *Processor) broadcast(event map[string]interface{}) {
|
||||
func (p *Processor) broadcast(event map[string]any) {
|
||||
if p.broadcastFunc != nil {
|
||||
p.broadcastFunc(event)
|
||||
}
|
||||
@@ -93,9 +277,65 @@ func NewProcessor(observationStore *gorm.ObservationStore, summaryStore *gorm.Su
|
||||
observationStore: observationStore,
|
||||
summaryStore: summaryStore,
|
||||
sem: make(chan struct{}, MaxConcurrentCLICalls),
|
||||
circuitBreaker: NewCircuitBreaker(5, 60), // Open after 5 failures, reset after 60s
|
||||
deduplicator: NewRequestDeduplicator(300, 1000), // 5-minute TTL, 1000 max entries
|
||||
vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2), // Buffered channel
|
||||
vectorSyncDone: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StartVectorSyncWorkers starts the bounded worker pool for vector sync operations.
|
||||
// Call this after setting the sync function via SetSyncObservationFunc.
|
||||
func (p *Processor) StartVectorSyncWorkers() {
|
||||
for i := 0; i < MaxVectorSyncWorkers; i++ {
|
||||
p.vectorSyncWg.Add(1)
|
||||
go p.vectorSyncWorker()
|
||||
}
|
||||
log.Info().Int("workers", MaxVectorSyncWorkers).Msg("Vector sync worker pool started")
|
||||
}
|
||||
|
||||
// StopVectorSyncWorkers gracefully stops the worker pool.
|
||||
func (p *Processor) StopVectorSyncWorkers() {
|
||||
close(p.vectorSyncDone)
|
||||
p.vectorSyncWg.Wait()
|
||||
log.Info().Msg("Vector sync worker pool stopped")
|
||||
}
|
||||
|
||||
// vectorSyncWorker is a worker goroutine that processes vector sync requests.
|
||||
func (p *Processor) vectorSyncWorker() {
|
||||
defer p.vectorSyncWg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-p.vectorSyncDone:
|
||||
// Drain remaining items before exiting
|
||||
for {
|
||||
select {
|
||||
case obs := <-p.vectorSyncChan:
|
||||
if p.syncObservationFunc != nil {
|
||||
p.syncObservationFunc(obs)
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
case obs := <-p.vectorSyncChan:
|
||||
if p.syncObservationFunc != nil {
|
||||
p.syncObservationFunc(obs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerState returns the current state of the circuit breaker.
|
||||
func (p *Processor) CircuitBreakerState() string {
|
||||
return p.circuitBreaker.State()
|
||||
}
|
||||
|
||||
// CircuitBreakerMetrics returns detailed metrics about the circuit breaker.
|
||||
func (p *Processor) CircuitBreakerMetrics() CircuitBreakerMetrics {
|
||||
return p.circuitBreaker.Metrics()
|
||||
}
|
||||
|
||||
// IsAvailable checks if the Claude CLI is available for processing.
|
||||
func (p *Processor) IsAvailable() bool {
|
||||
_, err := os.Stat(p.claudePath)
|
||||
@@ -103,7 +343,7 @@ func (p *Processor) IsAvailable() bool {
|
||||
}
|
||||
|
||||
// ProcessObservation processes a single tool observation and extracts insights.
|
||||
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error {
|
||||
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse any, promptNumber int, cwd string) error {
|
||||
// Skip certain tools that aren't worth processing
|
||||
if shouldSkipTool(toolName) {
|
||||
log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)")
|
||||
@@ -120,11 +360,23 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for duplicate request within TTL window
|
||||
reqHash := hashRequest(toolName, inputStr, outputStr)
|
||||
if p.deduplicator.IsDuplicate(reqHash) {
|
||||
log.Debug().Str("tool", toolName).Msg("Skipping duplicate request (dedup)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check circuit breaker before making CLI call
|
||||
if !p.circuitBreaker.Allow() {
|
||||
log.Warn().Str("tool", toolName).Msg("Circuit breaker open - skipping CLI call")
|
||||
return fmt.Errorf("circuit breaker open")
|
||||
}
|
||||
|
||||
log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI")
|
||||
|
||||
// Note: Removed the "file already has observations" check
|
||||
// Each tool execution can produce unique insights even for the same file
|
||||
// Similarity-based deduplication will handle true duplicates
|
||||
// Record this request to prevent duplicates
|
||||
p.deduplicator.Record(reqHash)
|
||||
|
||||
// Build the prompt
|
||||
exec := ToolExecution{
|
||||
@@ -146,9 +398,11 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
|
||||
// Call Claude Code CLI
|
||||
response, err := p.callClaudeCLI(ctx, prompt)
|
||||
if err != nil {
|
||||
p.circuitBreaker.RecordFailure()
|
||||
log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation")
|
||||
return err
|
||||
}
|
||||
p.circuitBreaker.RecordSuccess()
|
||||
|
||||
// Parse observations from response
|
||||
observations := ParseObservations(response, sdkSessionID)
|
||||
@@ -199,16 +453,26 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
|
||||
Int("trackedFiles", len(obs.FileMtimes)).
|
||||
Msg("Observation stored")
|
||||
|
||||
// Sync to vector DB if callback is set
|
||||
if p.syncObservationFunc != nil {
|
||||
// Sync to vector DB via bounded worker pool (non-blocking to reduce latency)
|
||||
if p.syncObservationFunc != nil && p.vectorSyncChan != nil {
|
||||
fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0)
|
||||
fullObs.ID = id
|
||||
fullObs.CreatedAtEpoch = createdAtEpoch
|
||||
p.syncObservationFunc(fullObs)
|
||||
// Non-blocking send to worker pool - drops if channel is full
|
||||
select {
|
||||
case p.vectorSyncChan <- fullObs:
|
||||
// Sent to worker pool
|
||||
default:
|
||||
// Channel full, fall back to direct sync in goroutine (bounded by channel buffer)
|
||||
log.Debug().Int64("obs_id", id).Msg("Vector sync channel full, using fallback goroutine")
|
||||
go func(obsToSync *models.Observation) {
|
||||
p.syncObservationFunc(obsToSync)
|
||||
}(fullObs)
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast new observation event for dashboard refresh
|
||||
p.broadcast(map[string]interface{}{
|
||||
p.broadcast(map[string]any{
|
||||
"type": "observation",
|
||||
"action": "created",
|
||||
"id": id,
|
||||
@@ -310,7 +574,7 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
|
||||
}
|
||||
|
||||
// Broadcast new summary event for dashboard refresh
|
||||
p.broadcast(map[string]interface{}{
|
||||
p.broadcast(map[string]any{
|
||||
"type": "summary",
|
||||
"action": "created",
|
||||
"id": id,
|
||||
@@ -320,8 +584,31 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaxPromptSize is the maximum size of a prompt that can be passed to the Claude CLI.
|
||||
// This prevents resource exhaustion from extremely large prompts.
|
||||
const MaxPromptSize = 100 * 1024 // 100KB
|
||||
|
||||
// sanitizePrompt removes null bytes and control characters from a prompt.
|
||||
// Keeps newlines, tabs, and carriage returns as they're valid in prompts.
|
||||
func sanitizePrompt(s string) string {
|
||||
return strings.Map(func(r rune) rune {
|
||||
// Keep printable ASCII, extended Unicode, and common whitespace
|
||||
if r >= 32 || r == '\n' || r == '\t' || r == '\r' {
|
||||
return r
|
||||
}
|
||||
// Remove null bytes and other control characters
|
||||
return -1
|
||||
}, s)
|
||||
}
|
||||
|
||||
// callClaudeCLI calls the Claude Code CLI with the given prompt.
|
||||
func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) {
|
||||
// Validate and sanitize prompt
|
||||
if len(prompt) > MaxPromptSize {
|
||||
return "", fmt.Errorf("prompt exceeds maximum size of %d bytes", MaxPromptSize)
|
||||
}
|
||||
prompt = sanitizePrompt(prompt)
|
||||
|
||||
// Build the full prompt with system instructions
|
||||
fullPrompt := systemPrompt + "\n\n" + prompt
|
||||
|
||||
@@ -418,8 +705,11 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip if output indicates an error or empty result
|
||||
// Pre-compute lowercase strings once to avoid repeated allocations
|
||||
lowerOutput := strings.ToLower(outputStr)
|
||||
lowerInput := strings.ToLower(inputStr)
|
||||
|
||||
// Skip if output indicates an error or empty result
|
||||
trivialOutputs := []string{
|
||||
"no matches found",
|
||||
"file not found",
|
||||
@@ -443,13 +733,13 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
// Skip reading config files that rarely contain project-specific insights
|
||||
boringFiles := []string{
|
||||
"package-lock.json", "yarn.lock", "pnpm-lock.yaml",
|
||||
"go.sum", "Cargo.lock", "Gemfile.lock", "poetry.lock",
|
||||
"go.sum", "cargo.lock", "gemfile.lock", "poetry.lock",
|
||||
".gitignore", ".dockerignore", ".eslintignore",
|
||||
"tsconfig.json", "jsconfig.json", "vite.config",
|
||||
"tailwind.config", "postcss.config",
|
||||
}
|
||||
for _, boring := range boringFiles {
|
||||
if strings.Contains(inputStr, boring) {
|
||||
if strings.Contains(lowerInput, boring) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -461,14 +751,14 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
}
|
||||
|
||||
case "Bash":
|
||||
// Skip simple status commands
|
||||
// Skip simple status commands (use pre-computed lowerInput)
|
||||
boringCommands := []string{
|
||||
"git status", "git diff", "git log", "git branch",
|
||||
"ls ", "pwd", "echo ", "cat ", "which ", "type ",
|
||||
"npm list", "npm outdated", "npm audit",
|
||||
}
|
||||
for _, boring := range boringCommands {
|
||||
if strings.Contains(strings.ToLower(inputStr), boring) {
|
||||
if strings.Contains(lowerInput, boring) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -478,7 +768,7 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
}
|
||||
|
||||
// toJSONString converts an interface to a JSON string.
|
||||
func toJSONString(v interface{}) string {
|
||||
func toJSONString(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
@@ -492,38 +782,132 @@ func toJSONString(v interface{}) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// safeResolvePath resolves a path relative to cwd and validates it doesn't escape the cwd directory.
|
||||
// Returns the resolved absolute path and true if valid, or empty string and false if path traversal detected.
|
||||
// This function is a security sanitizer for path traversal attacks.
|
||||
func safeResolvePath(path, cwd string) (string, bool) {
|
||||
// Clean the input path to normalize any .. or . components
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Reject paths that explicitly contain parent directory traversal after cleaning
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if filepath.IsAbs(cleanPath) {
|
||||
// For absolute paths, verify they're within cwd if cwd is specified
|
||||
if cwd != "" {
|
||||
cleanCwd := filepath.Clean(cwd)
|
||||
if !strings.HasPrefix(cleanPath, cleanCwd+string(filepath.Separator)) && cleanPath != cleanCwd {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
return cleanPath, true
|
||||
}
|
||||
|
||||
if cwd == "" {
|
||||
return cleanPath, true
|
||||
}
|
||||
|
||||
// Clean the cwd first
|
||||
cleanCwd := filepath.Clean(cwd)
|
||||
|
||||
// Join and clean the path
|
||||
absPath := filepath.Join(cleanCwd, cleanPath)
|
||||
|
||||
// Use filepath.Rel to verify the path is actually within cwd
|
||||
// If Rel returns a path starting with "..", it escapes the base
|
||||
rel, err := filepath.Rel(cleanCwd, absPath)
|
||||
if err != nil || strings.HasPrefix(rel, "..") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
// captureFileMtimes captures current modification times for tracked files.
|
||||
// Returns a map of absolute file paths to their mtime in epoch milliseconds.
|
||||
// For large file lists (>10 files), uses parallel stat calls for better performance.
|
||||
func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 {
|
||||
mtimes := make(map[string]int64)
|
||||
// Combine all unique file paths
|
||||
allPaths := make(map[string]struct{}, len(filesRead)+len(filesModified))
|
||||
for _, path := range filesRead {
|
||||
allPaths[path] = struct{}{}
|
||||
}
|
||||
for _, path := range filesModified {
|
||||
allPaths[path] = struct{}{}
|
||||
}
|
||||
|
||||
// Helper to get mtime for a file path
|
||||
getMtime := func(path string) (int64, bool) {
|
||||
// Resolve relative paths against cwd
|
||||
absPath := path
|
||||
if !filepath.IsAbs(path) && cwd != "" {
|
||||
absPath = filepath.Join(cwd, path)
|
||||
// For small lists, use sequential processing (goroutine overhead not worth it)
|
||||
if len(allPaths) <= 10 {
|
||||
return captureFileMtimesSequential(allPaths, cwd)
|
||||
}
|
||||
|
||||
// For larger lists, parallelize with bounded concurrency
|
||||
return captureFileMtimesParallel(allPaths, cwd)
|
||||
}
|
||||
|
||||
// captureFileMtimesSequential captures mtimes sequentially (efficient for small lists).
|
||||
func captureFileMtimesSequential(paths map[string]struct{}, cwd string) map[string]int64 {
|
||||
mtimes := make(map[string]int64, len(paths))
|
||||
|
||||
for path := range paths {
|
||||
absPath, ok := safeResolvePath(path, cwd)
|
||||
if !ok {
|
||||
// Skip paths that attempt directory traversal
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return info.ModTime().UnixMilli(), true
|
||||
}
|
||||
|
||||
// Capture mtimes for all read files
|
||||
for _, path := range filesRead {
|
||||
if mtime, ok := getMtime(path); ok {
|
||||
mtimes[path] = mtime
|
||||
if err == nil {
|
||||
mtimes[path] = info.ModTime().UnixMilli()
|
||||
}
|
||||
}
|
||||
|
||||
// Capture mtimes for all modified files
|
||||
for _, path := range filesModified {
|
||||
if mtime, ok := getMtime(path); ok {
|
||||
mtimes[path] = mtime
|
||||
}
|
||||
return mtimes
|
||||
}
|
||||
|
||||
// captureFileMtimesParallel captures mtimes in parallel with bounded concurrency.
|
||||
func captureFileMtimesParallel(paths map[string]struct{}, cwd string) map[string]int64 {
|
||||
type mtimeResult struct {
|
||||
path string
|
||||
mtime int64
|
||||
}
|
||||
|
||||
results := make(chan mtimeResult, len(paths))
|
||||
sem := make(chan struct{}, 8) // Limit to 8 concurrent stat calls
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for path := range paths {
|
||||
wg.Add(1)
|
||||
go func(p string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // Acquire
|
||||
defer func() { <-sem }() // Release
|
||||
|
||||
absPath, ok := safeResolvePath(p, cwd)
|
||||
if !ok {
|
||||
// Skip paths that attempt directory traversal
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(absPath)
|
||||
if err == nil {
|
||||
results <- mtimeResult{path: p, mtime: info.ModTime().UnixMilli()}
|
||||
}
|
||||
}(path)
|
||||
}
|
||||
|
||||
// Close results channel when all goroutines complete
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
// Collect results
|
||||
mtimes := make(map[string]int64, len(paths))
|
||||
for res := range results {
|
||||
mtimes[res.path] = res.mtime
|
||||
}
|
||||
|
||||
return mtimes
|
||||
@@ -538,12 +922,13 @@ func GetFileMtimes(paths []string, cwd string) map[string]int64 {
|
||||
// GetFileContent reads file content for verification purposes.
|
||||
// Returns content and ok status.
|
||||
func GetFileContent(path, cwd string) (string, bool) {
|
||||
absPath := path
|
||||
if !filepath.IsAbs(path) && cwd != "" {
|
||||
absPath = filepath.Join(cwd, path)
|
||||
absPath, ok := safeResolvePath(path, cwd)
|
||||
if !ok {
|
||||
// Reject paths that attempt directory traversal
|
||||
return "", false
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absPath) // #nosec G304 -- intentional file read for verification
|
||||
content, err := os.ReadFile(absPath) // #nosec G304 -- path validated by safeResolvePath
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user