Update, bugfixes on diff and edit handling

This commit is contained in:
2026-02-18 21:49:05 +00:00
parent 9205b2bc26
commit 6980d3b294
23 changed files with 3406 additions and 2083 deletions
+13 -1
View File
@@ -67,9 +67,11 @@ func setupLogger(level string, logFile string) *slog.Logger {
}
var handler slog.Handler
var logFileErr error
if logFile != "" {
f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
logFileErr = err
// Fallback to stderr
handler = slog.NewJSONHandler(os.Stderr, opts)
} else {
@@ -80,5 +82,15 @@ func setupLogger(level string, logFile string) *slog.Logger {
handler = slog.NewJSONHandler(os.Stderr, opts)
}
return slog.New(handler)
logger := slog.New(handler)
// Warn if log file couldn't be opened
if logFileErr != nil {
logger.Warn("failed to open log file, using stderr",
"file", logFile,
"error", logFileErr,
)
}
return logger
}
+1300 -1257
View File
File diff suppressed because it is too large Load Diff
+46 -88
View File
@@ -482,13 +482,12 @@ func (e *Engine) matchesSelector(sel ASTSelector, n *sitter.Node, content []byte
}
// applyEdit applies the edit operation to the content.
// AST mode uses exact byte positions — new_content is inserted verbatim without auto-indentation.
func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]byte, error) {
startByte := node.StartByte()
endByte := node.EndByte()
// Detect and preserve indentation
indentation := detectIndentation(content, startByte)
newContent := indentContent(edit.NewContent, indentation)
newContent := edit.NewContent
var result []byte
@@ -499,15 +498,21 @@ func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]
result = append(result, content[endByte:]...)
case EditInsertBefore:
insertion := newContent
if !strings.HasSuffix(insertion, "\n") {
insertion += "\n"
}
result = append(result, content[:startByte]...)
result = append(result, []byte(newContent)...)
result = append(result, '\n')
result = append(result, []byte(insertion)...)
result = append(result, content[startByte:]...)
case EditInsertAfter:
insertion := newContent
if !strings.HasPrefix(insertion, "\n") {
insertion = "\n" + insertion
}
result = append(result, content[:endByte]...)
result = append(result, '\n')
result = append(result, []byte(newContent)...)
result = append(result, []byte(insertion)...)
result = append(result, content[endByte:]...)
case EditDelete:
@@ -522,16 +527,16 @@ func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]
}
// detectIndentation detects the indentation at a given byte position.
func detectIndentation(content []byte, bytePos uint32) string {
func detectIndentation(content []byte, bytePos int) string {
// Find the start of the line
lineStart := int(bytePos)
lineStart := bytePos
for lineStart > 0 && content[lineStart-1] != '\n' {
lineStart--
}
// Extract leading whitespace
var indent strings.Builder
for i := lineStart; i < int(bytePos) && i < len(content); i++ {
for i := lineStart; i < bytePos && i < len(content); i++ {
c := content[i]
if c == ' ' || c == '\t' {
indent.WriteByte(c)
@@ -560,10 +565,16 @@ func indentContent(content string, indent string) string {
}
// generateDiff creates a unified diff between original and modified content.
// Uses Myers diff algorithm for accurate and readable diffs.
// Uses line-level Myers diff algorithm for accurate and readable diffs.
func generateDiff(original, modified, filename string) string {
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(original, modified, false)
// Use line-level diffing: encode each line as a single character,
// diff the encoded strings, then decode back to real lines.
// This prevents character-level diffs from splitting lines incorrectly.
chars1, chars2, lineArray := dmp.DiffLinesToChars(original, modified)
diffs := dmp.DiffMain(chars1, chars2, false)
diffs = dmp.DiffCharsToLines(diffs, lineArray)
// Cleanup for readability
diffs = dmp.DiffCleanupSemantic(diffs)
@@ -573,24 +584,25 @@ func generateDiff(original, modified, filename string) string {
buf.WriteString(fmt.Sprintf("--- %s\n", filename))
buf.WriteString(fmt.Sprintf("+++ %s\n", filename))
// Group diffs into hunks
lineNum := 1
for _, diff := range diffs {
lines := strings.Split(diff.Text, "\n")
for i, line := range lines {
// Skip empty last line from split
if i == len(lines)-1 && line == "" {
// SplitAfter preserves the trailing \n on each line, so we can
// distinguish real lines from a trailing empty split artifact.
lines := strings.SplitAfter(diff.Text, "\n")
for _, line := range lines {
if line == "" {
continue
}
// Remove trailing newline for display — we add our own.
cleanLine := strings.TrimSuffix(line, "\n")
switch diff.Type {
case diffmatchpatch.DiffDelete:
buf.WriteString(fmt.Sprintf("-%s\n", line))
buf.WriteString(fmt.Sprintf("-%s\n", cleanLine))
case diffmatchpatch.DiffInsert:
buf.WriteString(fmt.Sprintf("+%s\n", line))
buf.WriteString(fmt.Sprintf("+%s\n", cleanLine))
case diffmatchpatch.DiffEqual:
buf.WriteString(fmt.Sprintf(" %s\n", line))
lineNum++
buf.WriteString(fmt.Sprintf(" %s\n", cleanLine))
}
}
}
@@ -639,24 +651,6 @@ func (e *Engine) findExactText(content []byte, text string, index int) (start, e
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text not found: %q", truncateString(text, 50)))
}
// If multiple matches and no index specified, require explicit selection
if len(matches) > 1 && index == 0 {
// Check if index was explicitly set to 0 or just defaulted
// Since we can't distinguish, we'll allow index 0 but warn about multiple matches
// Actually, let's be strict and require explicit index for multiple matches
locations := make([]string, 0, min(len(matches), 5))
for i, m := range matches {
if i >= 5 {
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
break
}
line := countLines(content[:m.start]) + 1
locations = append(locations, fmt.Sprintf("line %d", line))
}
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text matches %d locations (%s); use selector_index to specify which one (0-%d)",
len(matches), strings.Join(locations, ", "), len(matches)-1))
}
if index >= len(matches) {
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
}
@@ -676,21 +670,6 @@ func (e *Engine) findRegexPattern(content []byte, pattern string, index int) (st
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern not found: %q", truncateString(pattern, 50)))
}
// If multiple matches and index is 0 (default), show error with locations
if len(matches) > 1 && index == 0 {
locations := make([]string, 0, min(len(matches), 5))
for i, m := range matches {
if i >= 5 {
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
break
}
line := countLines(content[:m[0]]) + 1
locations = append(locations, fmt.Sprintf("line %d", line))
}
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern matches %d locations (%s); use selector_index to specify which one (0-%d)",
len(matches), strings.Join(locations, ", "), len(matches)-1))
}
if index >= len(matches) {
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
}
@@ -728,7 +707,7 @@ func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, e
// Calculate byte positions
start = 0
for i := 0; i < startIdx; i++ {
for i := range startIdx {
start += len(lines[i]) + 1 // +1 for newline
}
@@ -746,7 +725,7 @@ func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, e
// applyTextEditOperation applies a text edit operation.
func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start, end int, newContent string) ([]byte, error) {
// Detect indentation at the selection point
indentation := detectIndentationAtByte(content, start)
indentation := detectIndentation(content, start)
indentedContent := indentContent(newContent, indentation)
var result []byte
@@ -758,15 +737,21 @@ func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start,
result = append(result, content[end:]...)
case EditInsertBefore:
insertion := indentedContent
if !strings.HasSuffix(insertion, "\n") {
insertion += "\n"
}
result = append(result, content[:start]...)
result = append(result, []byte(indentedContent)...)
result = append(result, '\n')
result = append(result, []byte(insertion)...)
result = append(result, content[start:]...)
case EditInsertAfter:
insertion := indentedContent
if !strings.HasPrefix(insertion, "\n") {
insertion = "\n" + insertion
}
result = append(result, content[:end]...)
result = append(result, '\n')
result = append(result, []byte(indentedContent)...)
result = append(result, []byte(insertion)...)
result = append(result, content[end:]...)
case EditDelete:
@@ -780,28 +765,6 @@ func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start,
return result, nil
}
// detectIndentationAtByte detects indentation at a byte position.
func detectIndentationAtByte(content []byte, bytePos int) string {
// Find the start of the line
lineStart := bytePos
for lineStart > 0 && content[lineStart-1] != '\n' {
lineStart--
}
// Extract leading whitespace
var indent strings.Builder
for i := lineStart; i < bytePos && i < len(content); i++ {
c := content[i]
if c == ' ' || c == '\t' {
indent.WriteByte(c)
} else {
break
}
}
return indent.String()
}
// truncateString truncates a string to maxLen with ellipsis.
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
@@ -810,11 +773,6 @@ func truncateString(s string, maxLen int) string {
return s[:maxLen-3] + "..."
}
// countLines counts the number of newlines in content.
func countLines(content []byte) int {
return bytes.Count(content, []byte("\n"))
}
// ValidateLanguage checks if AST editing is supported for a file.
// Returns nil for supported languages, error for unsupported.
// Note: Text-based editing is always available regardless of this check.
+79 -7
View File
@@ -357,7 +357,7 @@ func TestDetectIndentation(t *testing.T) {
name string
content string
want string
pos uint32
pos int
}{
{
name: "no indent",
@@ -410,6 +410,71 @@ func TestGenerateDiff(t *testing.T) {
}
}
func TestGenerateDiffLineLevelAccuracy(t *testing.T) {
// Regression test: diff must operate at line level, not character level.
// A character-level diff would split "hello" and "hello world" mid-line,
// producing broken output like:
// fmt.Println("hello
// + world
// ")
original := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello\")\n}\n"
modified := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello world\")\n}\n"
diff := generateDiff(original, modified, "test.go")
// The diff must show whole-line removals and additions
if !strings.Contains(diff, "-\tfmt.Println(\"hello\")\n") {
t.Errorf("diff should show full removed line, got:\n%s", diff)
}
if !strings.Contains(diff, "+\tfmt.Println(\"hello world\")\n") {
t.Errorf("diff should show full added line, got:\n%s", diff)
}
// The diff must NOT split lines at character boundaries
if strings.Contains(diff, "+hello") && !strings.Contains(diff, "Println") {
t.Errorf("diff appears to be character-level (split mid-line), got:\n%s", diff)
}
// Context lines should not be marked as changed
for line := range strings.SplitSeq(diff, "\n") {
if strings.HasPrefix(line, "-") || strings.HasPrefix(line, "+") {
// Changed lines should only be the Println lines
if strings.Contains(line, "package main") ||
strings.Contains(line, "func hello()") {
t.Errorf("unchanged line incorrectly marked as changed: %q", line)
}
}
}
}
func TestGenerateDiffNoPhantomChanges(t *testing.T) {
// Regression test: replacing a line range should not produce phantom
// +/- lines for unchanged code after the edit region.
original := "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\n"
modified := "line1\nREPLACED\nline3\nline4\nline5\nline6\nline7\nline8\n"
diff := generateDiff(original, modified, "test.txt")
// Count changed lines (excluding headers)
addCount := 0
delCount := 0
for line := range strings.SplitSeq(diff, "\n") {
if strings.HasPrefix(line, "+") && !strings.HasPrefix(line, "+++") {
addCount++
}
if strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "---") {
delCount++
}
}
if addCount != 1 {
t.Errorf("expected 1 added line, got %d. Diff:\n%s", addCount, diff)
}
if delCount != 1 {
t.Errorf("expected 1 deleted line, got %d. Diff:\n%s", delCount, diff)
}
}
// ==================== Text-based editing tests ====================
func TestTextEditWithExactText(t *testing.T) {
@@ -593,7 +658,7 @@ SECRET_KEY=abc123
}
}
func TestTextEditMultipleMatchesError(t *testing.T) {
func TestTextEditMultipleMatchesSelectsFirst(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
@@ -624,12 +689,19 @@ more code
t.Fatalf("apply failed: %v", err)
}
// Should fail because of multiple matches
if result.Success {
t.Error("expected error for multiple matches without index")
// Index 0 (default) should select the first match
if !result.Success {
t.Fatalf("expected success for multiple matches with default index 0: %s", result.Error)
}
if !strings.Contains(result.Error, "matches") {
t.Errorf("error should mention multiple matches: %s", result.Error)
// Verify only first TODO was replaced
fileContent, _ := os.ReadFile(tmpFile)
contentStr := string(fileContent)
if !strings.Contains(contentStr, "DONE: fix this") {
t.Error("first TODO should be replaced with DONE")
}
if !strings.Contains(contentStr, "TODO: also fix this") {
t.Error("second TODO should be unchanged")
}
}
+29 -2
View File
@@ -16,6 +16,12 @@ import (
json "github.com/goccy/go-json"
)
// ProcessKillTimeout is the timeout for waiting for a process to exit before force killing.
const ProcessKillTimeout = 5 * time.Second
// StderrBufferSize is the buffer size for draining stderr.
const StderrBufferSize = 1024
// Client represents an LSP client connection.
type Client struct {
stdin io.WriteCloser
@@ -104,12 +110,33 @@ func NewClient(cmd *exec.Cmd) (*Client, error) {
notifications: make(chan *Notification, 100),
}
// Start reader goroutine
// Start reader goroutine for stdout
go c.readLoop()
// Start stderr drain goroutine to prevent pipe buffer from filling up
go c.drainStderr()
return c, nil
}
// drainStderr consumes stderr output to prevent the LSP server from blocking.
// LSP servers may write diagnostic messages to stderr which we discard.
func (c *Client) drainStderr() {
buf := make([]byte, StderrBufferSize)
for {
select {
case <-c.done:
return
default:
}
// Read and discard stderr output
_, err := c.stderr.Read(buf)
if err != nil {
return
}
}
}
// Call sends a request and waits for a response.
func (c *Client) Call(ctx context.Context, method string, params interface{}) (*Response, error) {
c.runningMu.RLock()
@@ -208,7 +235,7 @@ func (c *Client) Close() error {
select {
case <-done:
// Clean exit
case <-time.After(5 * time.Second):
case <-time.After(ProcessKillTimeout):
// Force kill
_ = c.cmd.Process.Kill()
}
+53 -7
View File
@@ -15,6 +15,18 @@ import (
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// LSP timeout and interval constants.
const (
// DefaultLSPTimeout is the default timeout for LSP requests.
DefaultLSPTimeout = 10 * time.Second
// DefaultIdleTimeout is the duration before idle LSP servers are reaped.
DefaultIdleTimeout = 5 * time.Minute
// ReaperInterval is how often the idle server reaper runs.
ReaperInterval = 60 * time.Second
// ShutdownTimeout is the timeout for graceful LSP server shutdown.
ShutdownTimeout = 2 * time.Second
)
// Manager manages LSP servers for different languages.
type Manager struct {
servers map[protocol.Language]*ManagedServer
@@ -70,12 +82,27 @@ var DefaultServerConfigs = map[protocol.Language]ServerConfig{
},
}
// AllowedLSPBinaries is a whitelist of allowed LSP server binary names.
// This prevents command injection by ensuring only known LSP servers can be executed.
var AllowedLSPBinaries = map[string]bool{
"gopls": true,
"typescript-language-server": true,
"pylsp": true,
"clangd": true,
// Common alternatives
"tsserver": true,
"pyright": true,
"ruff-lsp": true,
"rust-analyzer": true,
"ccls": true,
}
// NewManager creates a new LSP manager.
func NewManager(workspaceRoot string, logger *slog.Logger) *Manager {
m := &Manager{
servers: make(map[protocol.Language]*ManagedServer),
timeout: 10 * time.Second,
idleTimeout: 5 * time.Minute,
timeout: DefaultLSPTimeout,
idleTimeout: DefaultIdleTimeout,
workspaceRoot: workspaceRoot,
logger: logger,
stopReaper: make(chan struct{}),
@@ -127,12 +154,20 @@ func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*Manag
return nil, errors.NewLSPServerNotFound(string(lang), config.Command[0])
}
// Validate command against whitelist to prevent command injection
binaryName := filepath.Base(cmdPath)
if !AllowedLSPBinaries[binaryName] {
return nil, errors.New(errors.ErrLSPServerNotFound, fmt.Sprintf("LSP binary %q is not in the allowed list", binaryName)).
WithContext("language", string(lang)).
WithContext("binary", binaryName).
WithRemediation("Only whitelisted LSP server binaries are allowed for security reasons")
}
// Create command
args := append(config.Command[1:], config.Args...)
cmd := exec.CommandContext(ctx, cmdPath, args...)
cmd.Env = os.Environ()
cmd.Dir = m.workspaceRoot
// Create client
client, err := NewClient(cmd)
if err != nil {
@@ -447,7 +482,7 @@ func (m *Manager) CloseDocument(_ context.Context, lang protocol.Language, file
// reapIdleServers periodically closes idle servers.
func (m *Manager) reapIdleServers() {
ticker := time.NewTicker(60 * time.Second)
ticker := time.NewTicker(ReaperInterval)
defer ticker.Stop()
for {
@@ -455,6 +490,10 @@ func (m *Manager) reapIdleServers() {
case <-m.stopReaper:
return
case <-ticker.C:
// Collect idle servers first to avoid holding the lock while closing
var toClose []*ManagedServer
var toCloseLanguages []protocol.Language
m.mu.Lock()
for lang, srv := range m.servers {
// Check lastUsed with server's lock to avoid race condition
@@ -463,12 +502,19 @@ func (m *Manager) reapIdleServers() {
srv.mu.Unlock()
if idle {
m.logger.Info("closing idle LSP server", "language", lang)
_ = srv.client.Close()
toClose = append(toClose, srv)
toCloseLanguages = append(toCloseLanguages, lang)
delete(m.servers, lang)
}
}
m.mu.Unlock()
// Close servers outside the lock to prevent deadlock
// (Close can block waiting for the process to exit)
for i, srv := range toClose {
m.logger.Info("closing idle LSP server", "language", toCloseLanguages[i])
_ = srv.client.Close()
}
}
}
}
@@ -485,7 +531,7 @@ func (m *Manager) Close() error {
for lang, srv := range m.servers {
m.logger.Info("shutting down LSP server", "language", lang)
// Try graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout)
_, _ = srv.client.Call(ctx, "shutdown", nil)
cancel()
_ = srv.client.Notify("exit", nil)
+561
View File
@@ -0,0 +1,561 @@
package metrics
import (
"strings"
"sync"
"testing"
"time"
)
func TestCounter(t *testing.T) {
tests := []struct {
name string
ops func(c *Counter)
expected int64
}{
{
name: "initial value is zero",
ops: func(c *Counter) {},
expected: 0,
},
{
name: "single inc",
ops: func(c *Counter) {
c.Inc()
},
expected: 1,
},
{
name: "multiple inc",
ops: func(c *Counter) {
c.Inc()
c.Inc()
c.Inc()
},
expected: 3,
},
{
name: "add positive",
ops: func(c *Counter) {
c.Add(10)
},
expected: 10,
},
{
name: "mixed operations",
ops: func(c *Counter) {
c.Inc()
c.Add(5)
c.Inc()
},
expected: 7,
},
{
name: "reset",
ops: func(c *Counter) {
c.Add(100)
c.Reset()
},
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewCounter("test_counter", "test help", nil)
tt.ops(c)
if got := c.Value(); got != tt.expected {
t.Errorf("Counter.Value() = %d, want %d", got, tt.expected)
}
})
}
}
func TestCounterConcurrency(t *testing.T) {
c := NewCounter("concurrent_counter", "test", nil)
var wg sync.WaitGroup
numGoroutines := 100
incsPerGoroutine := 1000
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < incsPerGoroutine; j++ {
c.Inc()
}
}()
}
wg.Wait()
expected := int64(numGoroutines * incsPerGoroutine)
if got := c.Value(); got != expected {
t.Errorf("Counter.Value() = %d, want %d after concurrent increments", got, expected)
}
}
func TestGauge(t *testing.T) {
tests := []struct {
name string
ops func(g *Gauge)
expected int64
}{
{
name: "initial value is zero",
ops: func(g *Gauge) {},
expected: 0,
},
{
name: "set value",
ops: func(g *Gauge) {
g.Set(42)
},
expected: 42,
},
{
name: "inc",
ops: func(g *Gauge) {
g.Inc()
g.Inc()
},
expected: 2,
},
{
name: "dec",
ops: func(g *Gauge) {
g.Set(10)
g.Dec()
g.Dec()
},
expected: 8,
},
{
name: "add positive",
ops: func(g *Gauge) {
g.Add(5)
},
expected: 5,
},
{
name: "add negative",
ops: func(g *Gauge) {
g.Set(10)
g.Add(-3)
},
expected: 7,
},
{
name: "mixed operations",
ops: func(g *Gauge) {
g.Set(100)
g.Inc()
g.Dec()
g.Add(-50)
},
expected: 50,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewGauge("test_gauge", "test help", nil)
tt.ops(g)
if got := g.Value(); got != tt.expected {
t.Errorf("Gauge.Value() = %d, want %d", got, tt.expected)
}
})
}
}
func TestHistogram(t *testing.T) {
t.Run("default buckets", func(t *testing.T) {
h := NewHistogram("test_histogram", "test", nil, nil)
if len(h.buckets) != len(DefaultDurationBuckets) {
t.Errorf("expected default buckets, got %d buckets", len(h.buckets))
}
})
t.Run("custom buckets", func(t *testing.T) {
buckets := []float64{1.0, 5.0, 10.0}
h := NewHistogram("test_histogram", "test", nil, buckets)
if len(h.buckets) != 3 {
t.Errorf("expected 3 buckets, got %d", len(h.buckets))
}
})
t.Run("buckets are sorted", func(t *testing.T) {
buckets := []float64{10.0, 1.0, 5.0}
h := NewHistogram("test_histogram", "test", nil, buckets)
if h.buckets[0] != 1.0 || h.buckets[1] != 5.0 || h.buckets[2] != 10.0 {
t.Errorf("buckets not sorted: %v", h.buckets)
}
})
t.Run("observe values", func(t *testing.T) {
h := NewHistogram("test_histogram", "test", nil, []float64{1.0, 5.0, 10.0})
h.Observe(0.5) // goes to bucket 0 (<=1.0)
h.Observe(3.0) // goes to bucket 1 (<=5.0)
h.Observe(7.0) // goes to bucket 2 (<=10.0)
h.Observe(15.0) // goes to +Inf bucket
if h.Count() != 4 {
t.Errorf("expected count 4, got %d", h.Count())
}
})
t.Run("observe duration", func(t *testing.T) {
h := NewHistogram("test_histogram", "test", nil, []float64{0.001, 0.01, 0.1})
h.ObserveDuration(500 * time.Microsecond) // 0.0005s, goes to bucket 0
h.ObserveDuration(5 * time.Millisecond) // 0.005s, goes to bucket 1
if h.Count() != 2 {
t.Errorf("expected count 2, got %d", h.Count())
}
})
t.Run("sum tracking", func(t *testing.T) {
h := NewHistogram("test_histogram", "test", nil, []float64{1.0, 5.0, 10.0})
h.Observe(1.0)
h.Observe(2.0)
h.Observe(3.0)
expectedSum := 6.0
if got := h.Sum(); got != expectedSum {
t.Errorf("expected sum %f, got %f", expectedSum, got)
}
})
}
func TestRegistry(t *testing.T) {
t.Run("counter registration", func(t *testing.T) {
r := NewRegistry()
c1 := r.Counter("test_counter", "help", nil)
c2 := r.Counter("test_counter", "help", nil)
if c1 != c2 {
t.Error("expected same counter instance for same name")
}
})
t.Run("counter with labels", func(t *testing.T) {
r := NewRegistry()
labels1 := map[string]string{"method": "get"}
labels2 := map[string]string{"method": "post"}
c1 := r.Counter("http_requests", "help", labels1)
c2 := r.Counter("http_requests", "help", labels2)
if c1 == c2 {
t.Error("expected different counter instances for different labels")
}
})
t.Run("gauge registration", func(t *testing.T) {
r := NewRegistry()
g1 := r.Gauge("test_gauge", "help", nil)
g2 := r.Gauge("test_gauge", "help", nil)
if g1 != g2 {
t.Error("expected same gauge instance for same name")
}
})
t.Run("histogram registration", func(t *testing.T) {
r := NewRegistry()
h1 := r.Histogram("test_histogram", "help", nil, nil)
h2 := r.Histogram("test_histogram", "help", nil, nil)
if h1 != h2 {
t.Error("expected same histogram instance for same name")
}
})
}
func TestRegistryConcurrency(t *testing.T) {
r := NewRegistry()
var wg sync.WaitGroup
numGoroutines := 100
// Concurrent access to registry
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
c := r.Counter("concurrent_test", "test", nil)
c.Inc()
g := r.Gauge("concurrent_gauge", "test", nil)
g.Inc()
}(i)
}
wg.Wait()
c := r.Counter("concurrent_test", "test", nil)
if c.Value() != int64(numGoroutines) {
t.Errorf("expected counter value %d, got %d", numGoroutines, c.Value())
}
}
func TestRegistryExpose(t *testing.T) {
r := NewRegistry()
// Add some metrics
c := r.Counter("test_requests_total", "Total requests", nil)
c.Add(42)
g := r.Gauge("test_connections", "Active connections", nil)
g.Set(10)
h := r.Histogram("test_duration_seconds", "Request duration", nil, []float64{0.1, 0.5, 1.0})
h.Observe(0.05)
h.Observe(0.3)
h.Observe(0.8)
output := r.Expose()
// Check counter output
if !strings.Contains(output, "# TYPE test_requests_total counter") {
t.Error("expected counter type in output")
}
if !strings.Contains(output, "test_requests_total 42") {
t.Error("expected counter value in output")
}
// Check gauge output
if !strings.Contains(output, "# TYPE test_connections gauge") {
t.Error("expected gauge type in output")
}
if !strings.Contains(output, "test_connections 10") {
t.Error("expected gauge value in output")
}
// Check histogram output
if !strings.Contains(output, "# TYPE test_duration_seconds histogram") {
t.Error("expected histogram type in output")
}
if !strings.Contains(output, "test_duration_seconds_bucket") {
t.Error("expected histogram buckets in output")
}
if !strings.Contains(output, "test_duration_seconds_sum") {
t.Error("expected histogram sum in output")
}
if !strings.Contains(output, "test_duration_seconds_count") {
t.Error("expected histogram count in output")
}
}
func TestRegistryReset(t *testing.T) {
r := NewRegistry()
c := r.Counter("test_counter", "test", nil)
c.Add(100)
g := r.Gauge("test_gauge", "test", nil)
g.Set(50)
r.Reset()
if c.Value() != 0 {
t.Errorf("expected counter reset to 0, got %d", c.Value())
}
if g.Value() != 0 {
t.Errorf("expected gauge reset to 0, got %d", g.Value())
}
}
func TestServerMetrics(t *testing.T) {
t.Run("creation", func(t *testing.T) {
m := NewServerMetrics()
if m.RequestsTotal == nil {
t.Error("RequestsTotal should not be nil")
}
if m.RequestErrors == nil {
t.Error("RequestErrors should not be nil")
}
if m.RequestDuration == nil {
t.Error("RequestDuration should not be nil")
}
if m.CacheHits == nil {
t.Error("CacheHits should not be nil")
}
if m.CacheMisses == nil {
t.Error("CacheMisses should not be nil")
}
if m.ActiveLSPServers == nil {
t.Error("ActiveLSPServers should not be nil")
}
if m.ParseDuration == nil {
t.Error("ParseDuration should not be nil")
}
if m.ParseErrors == nil {
t.Error("ParseErrors should not be nil")
}
})
t.Run("record request success", func(t *testing.T) {
m := NewServerMetrics()
m.RecordRequest(100*time.Millisecond, nil)
if m.RequestsTotal.Value() != 1 {
t.Errorf("expected RequestsTotal 1, got %d", m.RequestsTotal.Value())
}
if m.RequestErrors.Value() != 0 {
t.Errorf("expected RequestErrors 0, got %d", m.RequestErrors.Value())
}
if m.RequestDuration.Count() != 1 {
t.Errorf("expected RequestDuration count 1, got %d", m.RequestDuration.Count())
}
})
t.Run("record request error", func(t *testing.T) {
m := NewServerMetrics()
m.RecordRequest(50*time.Millisecond, &testError{})
if m.RequestsTotal.Value() != 1 {
t.Errorf("expected RequestsTotal 1, got %d", m.RequestsTotal.Value())
}
if m.RequestErrors.Value() != 1 {
t.Errorf("expected RequestErrors 1, got %d", m.RequestErrors.Value())
}
})
t.Run("record parse", func(t *testing.T) {
m := NewServerMetrics()
m.RecordParse(10*time.Millisecond, nil)
m.RecordParse(5*time.Millisecond, &testError{})
if m.ParseDuration.Count() != 2 {
t.Errorf("expected ParseDuration count 2, got %d", m.ParseDuration.Count())
}
if m.ParseErrors.Value() != 1 {
t.Errorf("expected ParseErrors 1, got %d", m.ParseErrors.Value())
}
})
t.Run("record cache", func(t *testing.T) {
m := NewServerMetrics()
m.RecordCacheHit()
m.RecordCacheHit()
m.RecordCacheMiss()
if m.CacheHits.Value() != 2 {
t.Errorf("expected CacheHits 2, got %d", m.CacheHits.Value())
}
if m.CacheMisses.Value() != 1 {
t.Errorf("expected CacheMisses 1, got %d", m.CacheMisses.Value())
}
})
t.Run("set active LSP servers", func(t *testing.T) {
m := NewServerMetrics()
m.SetActiveLSPServers(5)
if m.ActiveLSPServers.Value() != 5 {
t.Errorf("expected ActiveLSPServers 5, got %d", m.ActiveLSPServers.Value())
}
m.SetActiveLSPServers(3)
if m.ActiveLSPServers.Value() != 3 {
t.Errorf("expected ActiveLSPServers 3, got %d", m.ActiveLSPServers.Value())
}
})
t.Run("expose", func(t *testing.T) {
m := NewServerMetrics()
m.RecordRequest(100*time.Millisecond, nil)
output := m.Expose()
if !strings.Contains(output, "mcp_requests_total") {
t.Error("expected mcp_requests_total in output")
}
if !strings.Contains(output, "mcp_request_duration_seconds") {
t.Error("expected mcp_request_duration_seconds in output")
}
})
}
func TestMetricKey(t *testing.T) {
tests := []struct {
name string
labels map[string]string
expected string
}{
{
name: "no labels",
labels: nil,
expected: "test_metric",
},
{
name: "single label",
labels: map[string]string{"method": "get"},
expected: `test_metric{method="get"}`,
},
{
name: "multiple labels sorted",
labels: map[string]string{"method": "get", "code": "200"},
expected: `test_metric{code="200",method="get"}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := metricKey("test_metric", tt.labels)
if got != tt.expected {
t.Errorf("metricKey() = %q, want %q", got, tt.expected)
}
})
}
}
func TestFormatLabels(t *testing.T) {
tests := []struct {
name string
labels map[string]string
expected string
}{
{
name: "no labels",
labels: nil,
expected: "",
},
{
name: "single label",
labels: map[string]string{"method": "get"},
expected: `{method="get"}`,
},
{
name: "multiple labels sorted",
labels: map[string]string{"method": "get", "code": "200"},
expected: `{code="200",method="get"}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := formatLabels(tt.labels)
if got != tt.expected {
t.Errorf("formatLabels() = %q, want %q", got, tt.expected)
}
})
}
}
// testError is a simple error type for testing
type testError struct{}
func (e *testError) Error() string { return "test error" }
+4
View File
@@ -104,6 +104,7 @@ func NoComment() {}
if doc == nil {
t.Fatal("expected doc, got nil")
return
}
if doc.Text != tt.wantText {
@@ -192,6 +193,7 @@ function validate(name) {}
if doc == nil {
t.Fatal("expected doc, got nil")
return
}
if doc.Text != tt.wantText {
@@ -291,6 +293,7 @@ func TestExtractPythonDocComment(t *testing.T) {
if doc == nil {
t.Fatal("expected doc, got nil")
return
}
if doc.Text != tt.wantText {
@@ -372,6 +375,7 @@ int simple() { return 1; }
if doc == nil {
t.Fatal("expected doc, got nil")
return
}
if doc.Text != tt.wantText {
+26 -6
View File
@@ -34,6 +34,8 @@ type Registry struct {
cache *lru.Cache[string, *CachedTree]
maxParseSize int64
mu sync.RWMutex
parserMu sync.Map // per-language mutexes for parse serialization
closed bool // Indicates if the registry has been closed
// Cache metrics (atomic for thread-safety)
cacheHits atomic.Int64
@@ -136,8 +138,13 @@ func getLanguage(lang protocol.Language) (*sitter.Language, error) {
}
// GetParser returns a parser for the given language.
// Returns an error if the registry has been closed.
func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
r.mu.RLock()
if r.closed {
r.mu.RUnlock()
return nil, errors.New(errors.ErrInternal, "parser registry is closed")
}
if p, ok := r.parsers[lang]; ok {
r.mu.RUnlock()
return p, nil
@@ -148,6 +155,11 @@ func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
r.mu.Lock()
defer r.mu.Unlock()
// Check closed again after acquiring write lock
if r.closed {
return nil, errors.New(errors.ErrInternal, "parser registry is closed")
}
// Double-check after acquiring write lock
if p, ok := r.parsers[lang]; ok {
return p, nil
@@ -196,7 +208,8 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
}
// Check cache (LRU cache is thread-safe)
hash := contentHash(content)
// Include language in cache key to prevent cross-language collisions
hash := fmt.Sprintf("%s:%016x", string(lang), xxhash.Sum64(content))
if cached, ok := r.cache.Get(hash); ok && cached.Language == lang {
r.cacheHits.Add(1)
errors := extractErrors(cached.Tree.RootNode(), content)
@@ -215,13 +228,16 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
return nil, err
}
// Parse content - tree-sitter parsers are not thread-safe,
// so we need to hold the lock during parsing
// Track parse duration
// Parse content - tree-sitter parsers are not thread-safe per instance,
// but parsers for different languages are independent.
// Use per-language locks to allow concurrent parsing of different languages.
muVal, _ := r.parserMu.LoadOrStore(lang, &sync.Mutex{})
langMu := muVal.(*sync.Mutex)
start := time.Now()
r.mu.Lock()
langMu.Lock()
tree, err := parser.ParseCtx(ctx, nil, content)
r.mu.Unlock()
langMu.Unlock()
duration := time.Since(start)
// Update duration metrics
@@ -351,10 +367,14 @@ func isBinary(content []byte) bool {
}
// Close closes all parsers and clears the cache.
// After Close is called, the registry cannot be used for parsing.
func (r *Registry) Close() {
r.mu.Lock()
defer r.mu.Unlock()
// Mark as closed first to prevent new parse operations
r.closed = true
for _, p := range r.parsers {
p.Close()
}
+23 -21
View File
@@ -214,9 +214,10 @@ func (s *Searcher) buildArgs(req *Request) []string {
args = append(args, "--no-ignore")
}
// Max count per file to limit results
// Global result cap — --max-total-count stops rg early across all files.
// Requires ripgrep >= 13.0. In-process truncation in parseOutput is kept as a safety net.
if req.MaxResults > 0 {
args = append(args, fmt.Sprintf("--max-count=%d", req.MaxResults))
args = append(args, fmt.Sprintf("--max-total-count=%d", req.MaxResults))
}
// Add pattern
@@ -243,9 +244,9 @@ func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchRes
Results: []Result{},
}
// Track context by file and line
contextBefore := make(map[string][]string) // file -> lines before current match
currentFile := ""
// Track before-context lines linearly: accumulate context lines until the next match consumes them.
var pendingBefore []string
pendingFile := ""
scanner := bufio.NewScanner(output)
for scanner.Scan() {
@@ -285,14 +286,14 @@ func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchRes
result.Column = match.Submatches[0].Start + 1 // 1-indexed
}
// Add context before
if ctx, ok := contextBefore[match.Path.Text]; ok {
result.Context.Before = ctx
delete(contextBefore, match.Path.Text)
// Attach pending before-context if it belongs to this file
if pendingFile == match.Path.Text && len(pendingBefore) > 0 {
result.Context.Before = pendingBefore
}
pendingBefore = nil
pendingFile = ""
results.Results = append(results.Results, result)
currentFile = match.Path.Text
case "context":
var ctx rgContext
@@ -302,19 +303,20 @@ func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchRes
lineText := strings.TrimRight(ctx.Lines.Text, "\n\r")
// Determine if this is before or after context
isAfter := false
if len(results.Results) > 0 {
lastResult := &results.Results[len(results.Results)-1]
if lastResult.File == ctx.Path.Text && ctx.LineNumber > lastResult.Line {
// This is after context
lastResult.Context.After = append(lastResult.Context.After, lineText)
} else if ctx.Path.Text == currentFile || currentFile == "" {
// This is before context for a potential upcoming match
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
last := &results.Results[len(results.Results)-1]
if last.File == ctx.Path.Text && ctx.LineNumber > last.Line {
last.Context.After = append(last.Context.After, lineText)
isAfter = true
}
} else {
// Before any match - store as potential before context
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
}
if !isAfter {
if pendingFile != ctx.Path.Text {
pendingBefore = nil
pendingFile = ctx.Path.Text
}
pendingBefore = append(pendingBefore, lineText)
}
case "summary":
+1 -1
View File
@@ -89,7 +89,7 @@ func TestBuildArgs(t *testing.T) {
MaxResults: 10,
Regex: true,
},
expected: []string{"--json", "--max-count=10", "--", "test", "."},
expected: []string{"--json", "--max-total-count=10", "--", "test", "."},
},
}
+226
View File
@@ -0,0 +1,226 @@
// Package server implements the MCP server for file operations.
package server
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
"github.com/mark3labs/mcp-go/mcp"
)
// handleASTQuery handles the ast_query tool.
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Acquire semaphore to limit concurrent queries (prevents CPU exhaustion)
select {
case s.querySem <- struct{}{}:
defer func() { <-s.querySem }()
case <-ctx.Done():
return mcp.NewToolResultError("request cancelled"), nil
}
pattern, err := request.RequireString("pattern")
if err != nil {
return mcp.NewToolResultError("pattern is required"), nil
}
language, err := request.RequireString("language")
if err != nil {
return mcp.NewToolResultError("language is required"), nil
}
// Build query
astQuery := &query.ASTQuery{
Pattern: pattern,
Language: language,
Filters: query.QueryFilters{
NameMatches: request.GetString("name_matches", ""),
NameExact: request.GetString("name_exact", ""),
KindIn: request.GetStringSlice("kind_in", nil),
},
}
maxResults := request.GetInt("max_results", 100)
paths := request.GetStringSlice("paths", nil)
// Default to workspace root if no paths specified
if len(paths) == 0 {
paths = []string{s.cfg.WorkspaceRoot}
}
// Find files to search based on language
ext := languageToExtension(language)
if ext == "" {
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil
}
var allResults []query.MatchResult
// Walk through paths and find matching files
for _, searchPath := range paths {
// Validate path is within workspace
if !s.cfg.IsPathAllowed(searchPath) {
continue
}
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
// Check for context cancellation
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err != nil {
return nil // Skip files with errors
}
if info.IsDir() {
// Skip hidden directories
if strings.HasPrefix(info.Name(), ".") {
return filepath.SkipDir
}
return nil
}
// Check file extension matches language
if !strings.HasSuffix(path, ext) {
return nil
}
// Read and parse file
content, err := os.ReadFile(path)
if err != nil {
return nil // Skip unreadable files
}
// Check file size
if int64(len(content)) > s.cfg.MaxFileSize {
return nil // Skip large files
}
// Parse file
result, err := s.parser.Parse(ctx, path, content)
if err != nil {
return nil // Skip unparseable files
}
// Run query
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
if err != nil {
return nil // Skip on error
}
allResults = append(allResults, matches...)
// Stop if we have enough results
if maxResults > 0 && len(allResults) >= maxResults {
return filepath.SkipAll
}
return nil
})
if err != nil {
s.logger.Warn("error walking path", "path", searchPath, "error", err)
}
}
// Format and return results
output := query.FormatResults(allResults, maxResults)
return mcp.NewToolResultText(output), nil
}
// generateASTSummary generates a summary of symbols in the file.
func (s *Server) generateASTSummary(ctx context.Context, path string, content []byte) string {
// Parse the file
result, err := s.parser.Parse(ctx, path, content)
if err != nil {
return "" // Silently skip AST if parsing fails
}
// Extract symbols
lang := protocol.DetectLanguage(path)
symbols := parser.ExtractSymbols(result.Tree, content, lang, path)
if len(symbols) == 0 {
return ""
}
var sb strings.Builder
// Get relative path
relPath := path
if absPath, err := filepath.Abs(path); err == nil {
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
relPath = rel
}
}
sb.WriteString(fmt.Sprintf("**%s** (%d lines, %s)\n\n", relPath, len(splitLines(string(content))), lang))
sb.WriteString("Symbols:\n")
for _, sym := range symbols {
kindStr := symbolKindIcon(sym.Kind)
sb.WriteString(fmt.Sprintf(" %s %s L%d\n", kindStr, sym.Name, sym.Location.Line))
}
return sb.String()
}
// symbolKindIcon returns an icon/prefix for a symbol kind.
func symbolKindIcon(kind protocol.SymbolKind) string {
switch kind {
case protocol.SymbolFunction:
return "func"
case protocol.SymbolMethod:
return "meth"
case protocol.SymbolClass:
return "class"
case protocol.SymbolStruct:
return "struct"
case protocol.SymbolInterface:
return "iface"
case protocol.SymbolVariable:
return "var"
case protocol.SymbolConstant:
return "const"
case protocol.SymbolType:
return "type"
case protocol.SymbolField:
return "field"
case protocol.SymbolProperty:
return "prop"
case protocol.SymbolModule:
return "mod"
case protocol.SymbolPackage:
return "pkg"
default:
return "sym"
}
}
// languageToExtension maps language names to file extensions.
func languageToExtension(language string) string {
switch strings.ToLower(language) {
case "go":
return ".go"
case "typescript":
return ".ts"
case "javascript":
return ".js"
case "python":
return ".py"
case "c":
return ".c"
case "cpp", "c++":
return ".cpp"
case "elixir":
return ".ex"
default:
return ""
}
}
+91
View File
@@ -0,0 +1,91 @@
// Package server implements the MCP server for file operations.
package server
import (
"context"
"fmt"
"strings"
"github.com/lukaszraczylo/mcp-filepuff/internal/edit"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/mark3labs/mcp-go/mcp"
)
// handleEditPreview handles the edit_preview tool.
func (s *Server) handleEditPreview(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return s.handleEdit(ctx, request, false)
}
// handleEditApply handles the edit_apply tool.
func (s *Server) handleEditApply(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return s.handleEdit(ctx, request, true)
}
// handleEdit is the shared implementation for edit_preview and edit_apply.
func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, apply bool) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
operation, err := request.RequireString("operation")
if err != nil {
return mcp.NewToolResultError("operation is required"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
// Note: We no longer validate language support here.
// The edit engine automatically detects whether to use AST or text mode.
// Build edit request with both AST and text-mode selectors
astEdit := &edit.ASTEdit{
File: file,
Operation: edit.EditOperation(operation),
NewContent: request.GetString("new_content", ""),
Selector: edit.ASTSelector{
// AST-mode selectors
Kind: request.GetString("selector_kind", ""),
Name: request.GetString("selector_name", ""),
AtLine: request.GetInt("selector_line", 0),
Index: request.GetInt("selector_index", 0),
// Text-mode selectors
LineEnd: request.GetInt("selector_line_end", 0),
Text: request.GetString("selector_text", ""),
TextPattern: request.GetString("selector_pattern", ""),
},
}
// Perform edit
var result *edit.EditResult
if apply {
result, err = s.editor.Apply(ctx, astEdit)
} else {
result, err = s.editor.Preview(ctx, astEdit)
}
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("edit failed: %s", errors.SanitizeError(err))), nil
}
if !result.Success {
return mcp.NewToolResultError(result.Error), nil
}
// Format output
var output strings.Builder
if apply {
output.WriteString("**Edit Applied Successfully**\n\n")
} else {
output.WriteString("**Edit Preview**\n\n")
}
output.WriteString("Diff:\n```diff\n")
output.WriteString(result.Diff)
output.WriteString("```\n")
return mcp.NewToolResultText(output.String()), nil
}
+185
View File
@@ -0,0 +1,185 @@
// Package server implements the MCP server for file operations.
package server
import (
"bufio"
"context"
"fmt"
"os"
"strings"
"time"
"github.com/lukaszraczylo/mcp-filepuff/internal/search"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/mark3labs/mcp-go/mcp"
)
// handleFileSearch handles the file_search tool.
func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
start := time.Now()
defer func() {
s.logger.Debug("file_search completed",
"duration_ms", time.Since(start).Milliseconds(),
)
}()
if s.searcher == nil {
return mcp.NewToolResultError("ripgrep (rg) is not available. Please install it: https://github.com/BurntSushi/ripgrep#installation"), nil
}
// Parse request arguments using SDK helpers
pattern, err := request.RequireString("pattern")
if err != nil {
return mcp.NewToolResultError("pattern is required"), nil
}
req := &search.Request{
Pattern: pattern,
Paths: request.GetStringSlice("paths", nil),
FileTypes: request.GetStringSlice("file_types", nil),
IgnoreCase: request.GetBool("ignore_case", false),
Regex: request.GetBool("regex", true),
ContextLines: request.GetInt("context_lines", 2),
MaxResults: request.GetInt("max_results", 0),
}
// Execute search
results, err := s.searcher.Search(ctx, req)
if err != nil {
s.logger.Warn("search error", "error", err)
return mcp.NewToolResultError(fmt.Sprintf("search error: %s", errors.SanitizeError(err))), nil
}
s.logger.Info("search completed",
"pattern", pattern,
"results_count", len(results.Results),
"truncated", results.Truncated,
)
// Format results
output := s.searcher.FormatResults(results)
return mcp.NewToolResultText(output), nil
}
// handleFileRead handles the file_read tool.
func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Acquire semaphore to limit concurrent reads (prevents memory exhaustion)
select {
case s.readSem <- struct{}{}:
defer func() { <-s.readSem }()
case <-ctx.Done():
return mcp.NewToolResultError("request cancelled"), nil
}
path, err := request.RequireString("path")
if err != nil {
return mcp.NewToolResultError("path is required"), nil
}
// Validate path is within workspace
if !s.cfg.IsPathAllowed(path) {
return mcp.NewToolResultError("path is outside workspace root"), nil
}
// Read file
content, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil
}
if os.IsPermission(err) {
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
}
s.logger.Warn("file read error", "path", path, "error", err)
return mcp.NewToolResultError("error reading file"), nil
}
// Check file size
if int64(len(content)) > s.cfg.MaxFileSize {
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil
}
// Handle line range
lines := splitLines(string(content))
lineStart := request.GetInt("line_start", 1)
lineEnd := request.GetInt("line_end", len(lines))
// Clamp to valid range
if lineStart < 1 {
lineStart = 1
}
if lineEnd > len(lines) {
lineEnd = len(lines)
}
if lineStart > lineEnd {
lineStart = lineEnd
}
var output strings.Builder
// Include AST summary if requested
includeAST := request.GetBool("include_ast", false)
symbolsOnly := request.GetBool("symbols_only", false)
maxLines := request.GetInt("max_lines", 0)
// Validate symbols_only requires include_ast
if symbolsOnly && !includeAST {
return mcp.NewToolResultError("symbols_only requires include_ast=true"), nil
}
if includeAST {
astSummary := s.generateASTSummary(ctx, path, content)
if astSummary != "" {
output.WriteString(astSummary)
if !symbolsOnly {
output.WriteString("\n---\n\n")
}
}
}
// Skip file content if symbols_only mode
if !symbolsOnly {
// Apply max_lines limit if specified
effectiveEnd := lineEnd
if maxLines > 0 && (lineEnd-lineStart+1) > maxLines {
effectiveEnd = lineStart + maxLines - 1
if effectiveEnd < lineEnd {
// Add note that output was truncated
defer func() {
output.WriteString(fmt.Sprintf("\n[... %d more lines omitted for token efficiency. Use line_start/line_end or increase max_lines to see more]\n", lineEnd-effectiveEnd))
}()
}
}
// Extract requested lines
for i := lineStart - 1; i < effectiveEnd && i < len(lines); i++ {
output.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, lines[i]))
}
}
return mcp.NewToolResultText(output.String()), nil
}
// splitLines splits a string into lines.
// For large files (> 1MB), uses bufio.Scanner which is more memory efficient.
// For smaller files, uses simple string split which is faster.
func splitLines(s string) []string {
const largeSizeThreshold = 1024 * 1024 // 1MB
if len(s) > largeSizeThreshold {
// Use scanner for large files
scanner := bufio.NewScanner(strings.NewReader(s))
var lines []string
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
// Handle potential error and add empty line if string ended with newline
if len(s) > 0 && s[len(s)-1] == '\n' {
lines = append(lines, "")
}
return lines
}
// Use optimized stdlib implementation for smaller files (2-3x faster than manual loop)
return strings.Split(s, "\n")
}
+211
View File
@@ -0,0 +1,211 @@
// Package server implements the MCP server for file operations.
package server
import (
"context"
"fmt"
"os"
"strings"
"github.com/lukaszraczylo/mcp-filepuff/internal/lsp"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/mark3labs/mcp-go/mcp"
)
// handleSymbolAt handles the symbol_at tool.
func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
// Try LSP hover
hover, err := s.lspManager.Hover(ctx, file, line, col)
if err != nil {
// Fall back to AST-based info
return s.handleSymbolAtFallback(ctx, file, line, col)
}
if hover == nil {
return mcp.NewToolResultText("No symbol information available at this position."), nil
}
var output strings.Builder
output.WriteString("**Symbol Information**\n\n")
output.WriteString(hover.Contents.Value)
return mcp.NewToolResultText(output.String()), nil
}
// handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable.
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) {
content, err := os.ReadFile(file)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %s", errors.SanitizeError(err))), nil
}
result, err := s.parser.Parse(ctx, file, content)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to parse file: %s", errors.SanitizeError(err))), nil
}
node := parser.FindNodeAtPosition(result.Tree, line, col)
if node == nil {
return mcp.NewToolResultText("No symbol at this position."), nil
}
var output strings.Builder
output.WriteString("**Symbol Information** (AST fallback)\n\n")
output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type()))
output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content)))
return mcp.NewToolResultText(output.String()), nil
}
// handleFindDefinition handles the find_definition tool.
func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
locations, err := s.lspManager.Definition(ctx, file, line, col)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %s", errors.SanitizeError(err))), nil
}
if len(locations) == 0 {
return mcp.NewToolResultText("No definition found."), nil
}
var output strings.Builder
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
for _, loc := range locations {
filePath := lsp.URIToFile(loc.URI)
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
// Try to read a preview snippet
preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3)
if preview != "" {
output.WriteString("```\n")
output.WriteString(preview)
output.WriteString("```\n")
}
output.WriteString("\n")
}
return mcp.NewToolResultText(output.String()), nil
}
// handleFindReferences handles the find_references tool.
func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
includeDecl := request.GetBool("include_declaration", true)
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
locations, err := s.lspManager.References(ctx, file, line, col, includeDecl)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %s", errors.SanitizeError(err))), nil
}
if len(locations) == 0 {
return mcp.NewToolResultText("No references found."), nil
}
var output strings.Builder
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations)))
// Group by file
fileGroups := make(map[string][]lsp.Location)
for _, loc := range locations {
filePath := lsp.URIToFile(loc.URI)
fileGroups[filePath] = append(fileGroups[filePath], loc)
}
for filePath, locs := range fileGroups {
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
for _, loc := range locs {
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
}
output.WriteString("\n")
}
return mcp.NewToolResultText(output.String()), nil
}
// readFilePreview reads a few lines from a file around the given line.
func readFilePreview(file string, line, contextLines int) string {
content, err := os.ReadFile(file)
if err != nil {
return ""
}
lines := splitLines(string(content))
startLine := max(1, line-contextLines)
endLine := min(line+contextLines, len(lines))
var preview strings.Builder
for i := startLine - 1; i < endLine && i < len(lines); i++ {
lineText := lines[i]
if len(lineText) > PreviewLineMaxLength {
lineText = lineText[:PreviewLineMaxLength] + "..."
}
prefix := " "
if i+1 == line {
prefix = "> "
}
preview.WriteString(fmt.Sprintf("%s%4d: %s\n", prefix, i+1, lineText))
}
return preview.String()
}
+8
View File
@@ -52,6 +52,7 @@ func Hello() string {
}
if result == nil {
t.Fatal("handlePing() returned nil")
return
}
if len(result.Content) == 0 {
t.Fatal("handlePing() returned empty content")
@@ -70,6 +71,7 @@ func Hello() string {
}
if result == nil {
t.Fatal("handleFileRead() returned nil")
return
}
if len(result.Content) == 0 {
t.Fatal("handleFileRead() returned empty content")
@@ -90,6 +92,7 @@ func Hello() string {
}
if result == nil {
t.Fatal("handleASTQuery() returned nil")
return
}
})
@@ -110,6 +113,7 @@ func Hello() string {
}
if previewResult == nil {
t.Fatal("handleEditPreview() returned nil")
return
}
// Verify file unchanged after preview
@@ -127,6 +131,7 @@ func Hello() string {
}
if applyResult == nil {
t.Fatal("handleEditApply() returned nil")
return
}
// Verify file changed after apply
@@ -352,6 +357,7 @@ func Add(a, b int) int {
}
if readResult == nil {
t.Fatal("handleFileRead() returned nil")
return
}
// 2. Query AST
@@ -367,6 +373,7 @@ func Add(a, b int) int {
}
if queryResult == nil {
t.Fatal("handleASTQuery() returned nil")
return
}
// 3. Preview edit
@@ -384,6 +391,7 @@ func Add(a, b int) int {
}
if editResult == nil {
t.Fatal("handleEditPreview() returned nil")
return
}
})
}
+23 -650
View File
@@ -2,14 +2,10 @@
package server
import (
"bufio"
"context"
"fmt"
"log/slog"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
@@ -19,11 +15,22 @@ import (
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
"github.com/lukaszraczylo/mcp-filepuff/internal/search"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// MaxConcurrentReads limits concurrent file read operations to prevent memory exhaustion.
const MaxConcurrentReads = 10
// MaxConcurrentQueries limits concurrent AST query operations to prevent CPU exhaustion.
const MaxConcurrentQueries = 5
// ServerShutdownTimeout is the timeout for graceful server shutdown.
const ServerShutdownTimeout = 10 * time.Second
// PreviewLineMaxLength is the maximum length for preview lines before truncation.
const PreviewLineMaxLength = 100
// Server represents the MCP file operations server.
type Server struct {
cfg *config.Config
@@ -34,17 +41,21 @@ type Server struct {
matcher *query.Matcher
lspManager *lsp.Manager
editor *edit.Engine
readSem chan struct{} // Semaphore for limiting concurrent file reads
querySem chan struct{} // Semaphore for limiting concurrent AST queries
}
// New creates a new MCP server instance.
func New(cfg *config.Config, logger *slog.Logger) (*Server, error) {
parserRegistry := parser.NewRegistryWithSize(cfg.MaxParseSize)
s := &Server{
cfg: cfg,
logger: logger,
parser: parserRegistry,
matcher: query.NewMatcher(parserRegistry),
editor: edit.NewEngine(parserRegistry),
cfg: cfg,
logger: logger,
parser: parserRegistry,
matcher: query.NewMatcher(parserRegistry),
editor: edit.NewEngine(parserRegistry),
readSem: make(chan struct{}, MaxConcurrentReads),
querySem: make(chan struct{}, MaxConcurrentQueries),
}
// Initialize searcher
@@ -341,644 +352,6 @@ func (s *Server) handlePing(ctx context.Context, request mcp.CallToolRequest) (*
return mcp.NewToolResultText("pong"), nil
}
// handleFileSearch handles the file_search tool.
func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
start := time.Now()
defer func() {
s.logger.Debug("file_search completed",
"duration_ms", time.Since(start).Milliseconds(),
)
}()
if s.searcher == nil {
return mcp.NewToolResultError("ripgrep (rg) is not available. Please install it: https://github.com/BurntSushi/ripgrep#installation"), nil
}
// Parse request arguments using SDK helpers
pattern, err := request.RequireString("pattern")
if err != nil {
return mcp.NewToolResultError("pattern is required"), nil
}
req := &search.Request{
Pattern: pattern,
Paths: request.GetStringSlice("paths", nil),
FileTypes: request.GetStringSlice("file_types", nil),
IgnoreCase: request.GetBool("ignore_case", false),
Regex: request.GetBool("regex", true),
ContextLines: request.GetInt("context_lines", 2),
MaxResults: request.GetInt("max_results", 0),
}
// Execute search
results, err := s.searcher.Search(ctx, req)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("search error: %v", err)), nil
}
s.logger.Info("search completed",
"pattern", pattern,
"results_count", len(results.Results),
"truncated", results.Truncated,
)
// Format results
output := s.searcher.FormatResults(results)
return mcp.NewToolResultText(output), nil
}
// handleFileRead handles the file_read tool.
func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
path, err := request.RequireString("path")
if err != nil {
return mcp.NewToolResultError("path is required"), nil
}
// Validate path is within workspace
if !s.cfg.IsPathAllowed(path) {
return mcp.NewToolResultError("path is outside workspace root"), nil
}
// Read file
content, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil
}
if os.IsPermission(err) {
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
}
return mcp.NewToolResultError(fmt.Sprintf("error reading file: %v", err)), nil
}
// Check file size
if int64(len(content)) > s.cfg.MaxFileSize {
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil
}
// Handle line range
lines := splitLines(string(content))
lineStart := request.GetInt("line_start", 1)
lineEnd := request.GetInt("line_end", len(lines))
// Clamp to valid range
if lineStart < 1 {
lineStart = 1
}
if lineEnd > len(lines) {
lineEnd = len(lines)
}
if lineStart > lineEnd {
lineStart = lineEnd
}
var output strings.Builder
// Include AST summary if requested
includeAST := request.GetBool("include_ast", false)
symbolsOnly := request.GetBool("symbols_only", false)
maxLines := request.GetInt("max_lines", 0)
// Validate symbols_only requires include_ast
if symbolsOnly && !includeAST {
return mcp.NewToolResultError("symbols_only requires include_ast=true"), nil
}
if includeAST {
astSummary := s.generateASTSummary(ctx, path, content)
if astSummary != "" {
output.WriteString(astSummary)
if !symbolsOnly {
output.WriteString("\n---\n\n")
}
}
}
// Skip file content if symbols_only mode
if !symbolsOnly {
// Apply max_lines limit if specified
effectiveEnd := lineEnd
if maxLines > 0 && (lineEnd-lineStart+1) > maxLines {
effectiveEnd = lineStart + maxLines - 1
if effectiveEnd < lineEnd {
// Add note that output was truncated
defer func() {
output.WriteString(fmt.Sprintf("\n[... %d more lines omitted for token efficiency. Use line_start/line_end or increase max_lines to see more]\n", lineEnd-effectiveEnd))
}()
}
}
// Extract requested lines
for i := lineStart - 1; i < effectiveEnd && i < len(lines); i++ {
output.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, lines[i]))
}
}
return mcp.NewToolResultText(output.String()), nil
}
// generateASTSummary generates a summary of symbols in the file.
func (s *Server) generateASTSummary(ctx context.Context, path string, content []byte) string {
// Parse the file
result, err := s.parser.Parse(ctx, path, content)
if err != nil {
return "" // Silently skip AST if parsing fails
}
// Extract symbols
lang := protocol.DetectLanguage(path)
symbols := parser.ExtractSymbols(result.Tree, content, lang, path)
if len(symbols) == 0 {
return ""
}
var sb strings.Builder
// Get relative path
relPath := path
if absPath, err := filepath.Abs(path); err == nil {
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
relPath = rel
}
}
sb.WriteString(fmt.Sprintf("**%s** (%d lines, %s)\n\n", relPath, len(splitLines(string(content))), lang))
sb.WriteString("Symbols:\n")
for _, sym := range symbols {
kindStr := symbolKindIcon(sym.Kind)
sb.WriteString(fmt.Sprintf(" %s %s L%d\n", kindStr, sym.Name, sym.Location.Line))
}
return sb.String()
}
// symbolKindIcon returns an icon/prefix for a symbol kind.
func symbolKindIcon(kind protocol.SymbolKind) string {
switch kind {
case protocol.SymbolFunction:
return "func"
case protocol.SymbolMethod:
return "meth"
case protocol.SymbolClass:
return "class"
case protocol.SymbolStruct:
return "struct"
case protocol.SymbolInterface:
return "iface"
case protocol.SymbolVariable:
return "var"
case protocol.SymbolConstant:
return "const"
case protocol.SymbolType:
return "type"
case protocol.SymbolField:
return "field"
case protocol.SymbolProperty:
return "prop"
case protocol.SymbolModule:
return "mod"
case protocol.SymbolPackage:
return "pkg"
default:
return "sym"
}
}
func splitLines(s string) []string {
// For large files (> 1MB), use bufio.Scanner which is more memory efficient
// For smaller files, use simple string split which is faster
const largeSizeThreshold = 1024 * 1024 // 1MB
if len(s) > largeSizeThreshold {
// Use scanner for large files
scanner := bufio.NewScanner(strings.NewReader(s))
var lines []string
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
// Handle potential error and add empty line if string ended with newline
if len(s) > 0 && s[len(s)-1] == '\n' {
lines = append(lines, "")
}
return lines
}
// Use optimized stdlib implementation for smaller files (2-3x faster than manual loop)
return strings.Split(s, "\n")
}
// handleASTQuery handles the ast_query tool.
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
pattern, err := request.RequireString("pattern")
if err != nil {
return mcp.NewToolResultError("pattern is required"), nil
}
language, err := request.RequireString("language")
if err != nil {
return mcp.NewToolResultError("language is required"), nil
}
// Build query
astQuery := &query.ASTQuery{
Pattern: pattern,
Language: language,
Filters: query.QueryFilters{
NameMatches: request.GetString("name_matches", ""),
NameExact: request.GetString("name_exact", ""),
KindIn: request.GetStringSlice("kind_in", nil),
},
}
maxResults := request.GetInt("max_results", 100)
paths := request.GetStringSlice("paths", nil)
// Default to workspace root if no paths specified
if len(paths) == 0 {
paths = []string{s.cfg.WorkspaceRoot}
}
// Find files to search based on language
ext := languageToExtension(language)
if ext == "" {
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil
}
var allResults []query.MatchResult
// Walk through paths and find matching files
for _, searchPath := range paths {
// Validate path is within workspace
if !s.cfg.IsPathAllowed(searchPath) {
continue
}
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
// Check for context cancellation
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err != nil {
return nil // Skip files with errors
}
if info.IsDir() {
// Skip hidden directories
if strings.HasPrefix(info.Name(), ".") {
return filepath.SkipDir
}
return nil
}
// Check file extension matches language
if !strings.HasSuffix(path, ext) {
return nil
}
// Read and parse file
content, err := os.ReadFile(path)
if err != nil {
return nil // Skip unreadable files
}
// Check file size
if int64(len(content)) > s.cfg.MaxFileSize {
return nil // Skip large files
}
// Parse file
result, err := s.parser.Parse(ctx, path, content)
if err != nil {
return nil // Skip unparseable files
}
// Run query
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
if err != nil {
return nil // Skip on error
}
allResults = append(allResults, matches...)
// Stop if we have enough results
if maxResults > 0 && len(allResults) >= maxResults {
return filepath.SkipAll
}
return nil
})
if err != nil {
s.logger.Warn("error walking path", "path", searchPath, "error", err)
}
}
// Format and return results
output := query.FormatResults(allResults, maxResults)
return mcp.NewToolResultText(output), nil
}
// languageToExtension maps language names to file extensions.
func languageToExtension(language string) string {
switch strings.ToLower(language) {
case "go":
return ".go"
case "typescript":
return ".ts"
case "javascript":
return ".js"
case "python":
return ".py"
case "c":
return ".c"
case "cpp", "c++":
return ".cpp"
case "elixir":
return ".ex"
default:
return ""
}
}
// handleSymbolAt handles the symbol_at tool.
func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
// Try LSP hover
hover, err := s.lspManager.Hover(ctx, file, line, col)
if err != nil {
// Fall back to AST-based info
return s.handleSymbolAtFallback(ctx, file, line, col)
}
if hover == nil {
return mcp.NewToolResultText("No symbol information available at this position."), nil
}
var output strings.Builder
output.WriteString("**Symbol Information**\n\n")
output.WriteString(hover.Contents.Value)
return mcp.NewToolResultText(output.String()), nil
}
// handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable.
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) {
content, err := os.ReadFile(file)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %v", err)), nil
}
result, err := s.parser.Parse(ctx, file, content)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to parse file: %v", err)), nil
}
node := parser.FindNodeAtPosition(result.Tree, line, col)
if node == nil {
return mcp.NewToolResultText("No symbol at this position."), nil
}
var output strings.Builder
output.WriteString("**Symbol Information** (AST fallback)\n\n")
output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type()))
output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content)))
return mcp.NewToolResultText(output.String()), nil
}
// handleFindDefinition handles the find_definition tool.
func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
locations, err := s.lspManager.Definition(ctx, file, line, col)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %v", err)), nil
}
if len(locations) == 0 {
return mcp.NewToolResultText("No definition found."), nil
}
var output strings.Builder
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
for _, loc := range locations {
filePath := lsp.URIToFile(loc.URI)
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
// Try to read a preview snippet
preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3)
if preview != "" {
output.WriteString("```\n")
output.WriteString(preview)
output.WriteString("```\n")
}
output.WriteString("\n")
}
return mcp.NewToolResultText(output.String()), nil
}
// handleFindReferences handles the find_references tool.
func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
includeDecl := request.GetBool("include_declaration", true)
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
locations, err := s.lspManager.References(ctx, file, line, col, includeDecl)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %v", err)), nil
}
if len(locations) == 0 {
return mcp.NewToolResultText("No references found."), nil
}
var output strings.Builder
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations)))
// Group by file
fileGroups := make(map[string][]lsp.Location)
for _, loc := range locations {
filePath := lsp.URIToFile(loc.URI)
fileGroups[filePath] = append(fileGroups[filePath], loc)
}
for filePath, locs := range fileGroups {
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
for _, loc := range locs {
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
}
output.WriteString("\n")
}
return mcp.NewToolResultText(output.String()), nil
}
// readFilePreview reads a few lines from a file around the given line.
func readFilePreview(file string, line, contextLines int) string {
content, err := os.ReadFile(file)
if err != nil {
return ""
}
lines := splitLines(string(content))
startLine := max(1, line-contextLines)
endLine := min(line+contextLines, len(lines))
var preview strings.Builder
for i := startLine - 1; i < endLine && i < len(lines); i++ {
lineText := lines[i]
if len(lineText) > 100 {
lineText = lineText[:100] + "..."
}
prefix := " "
if i+1 == line {
prefix = "> "
}
preview.WriteString(fmt.Sprintf("%s%4d: %s\n", prefix, i+1, lineText))
}
return preview.String()
}
// handleEditPreview handles the edit_preview tool.
func (s *Server) handleEditPreview(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return s.handleEdit(ctx, request, false)
}
// handleEditApply handles the edit_apply tool.
func (s *Server) handleEditApply(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return s.handleEdit(ctx, request, true)
}
// handleEdit is the shared implementation for edit_preview and edit_apply.
func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, apply bool) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
operation, err := request.RequireString("operation")
if err != nil {
return mcp.NewToolResultError("operation is required"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
// Note: We no longer validate language support here.
// The edit engine automatically detects whether to use AST or text mode.
// Build edit request with both AST and text-mode selectors
astEdit := &edit.ASTEdit{
File: file,
Operation: edit.EditOperation(operation),
NewContent: request.GetString("new_content", ""),
Selector: edit.ASTSelector{
// AST-mode selectors
Kind: request.GetString("selector_kind", ""),
Name: request.GetString("selector_name", ""),
AtLine: request.GetInt("selector_line", 0),
Index: request.GetInt("selector_index", 0),
// Text-mode selectors
LineEnd: request.GetInt("selector_line_end", 0),
Text: request.GetString("selector_text", ""),
TextPattern: request.GetString("selector_pattern", ""),
},
}
// Perform edit
var result *edit.EditResult
if apply {
result, err = s.editor.Apply(ctx, astEdit)
} else {
result, err = s.editor.Preview(ctx, astEdit)
}
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("edit failed: %v", err)), nil
}
if !result.Success {
return mcp.NewToolResultError(result.Error), nil
}
// Format output
var output strings.Builder
if apply {
output.WriteString("**Edit Applied Successfully**\n\n")
} else {
output.WriteString("**Edit Preview**\n\n")
}
output.WriteString("Diff:\n```diff\n")
output.WriteString(result.Diff)
output.WriteString("```\n")
return mcp.NewToolResultText(output.String()), nil
}
// Run starts the MCP server and blocks until shutdown.
func (s *Server) Run(ctx context.Context) error {
// Set up signal handling for graceful shutdown
@@ -1007,7 +380,7 @@ func (s *Server) Run(ctx context.Context) error {
s.logger.Info("received shutdown signal", "signal", sig)
// Create timeout context for shutdown
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), ServerShutdownTimeout)
defer shutdownCancel()
// Call graceful shutdown
@@ -1025,7 +398,7 @@ func (s *Server) Run(ctx context.Context) error {
case <-ctx.Done():
// Context cancelled externally
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), ServerShutdownTimeout)
defer shutdownCancel()
if err := s.Shutdown(shutdownCtx); err != nil {
+7 -1
View File
@@ -29,8 +29,8 @@ func TestNew(t *testing.T) {
if srv == nil {
t.Fatal("New() returned nil server")
return
}
if srv.cfg != cfg {
t.Error("server config mismatch")
}
@@ -68,6 +68,7 @@ func TestHandlePing(t *testing.T) {
if result == nil {
t.Fatal("handlePing() returned nil result")
return
}
// Check that the result contains "pong"
@@ -123,6 +124,7 @@ func Hello() {
if result == nil {
t.Fatal("handleFileRead() returned nil result")
return
}
contents := result.Content
@@ -179,6 +181,7 @@ func Hello() {
if result == nil {
t.Fatal("handleFileRead() returned nil result")
return
}
contents := result.Content
@@ -270,6 +273,7 @@ func Goodbye() error {
if result == nil {
t.Fatal("handleASTQuery() returned nil result")
return
}
contents := result.Content
@@ -318,6 +322,7 @@ func Hello() {
if result == nil {
t.Fatal("handleEdit(preview) returned nil result")
return
}
// Verify file was NOT modified (it's just a preview)
@@ -367,6 +372,7 @@ func Hello() {
if result == nil {
t.Fatal("handleEdit(apply) returned nil result")
return
}
// Verify file WAS modified
+106 -6
View File
@@ -2,18 +2,74 @@
package util
import (
"fmt"
"regexp"
"sync"
"sync/atomic"
)
const (
// MaxPatternLength is the maximum allowed length for regex patterns.
// This prevents memory issues from extremely long patterns.
MaxPatternLength = 1000
// MaxCacheSize is the maximum number of patterns to cache.
// When exceeded, the cache is cleared to prevent unbounded memory growth.
MaxCacheSize = 10000
)
// regexCache is a global thread-safe cache for compiled regular expressions.
// Caching regex compilation provides 10-50x speedup for repeated patterns.
var regexCache sync.Map // string -> *regexp.Regexp
var (
regexCache sync.Map // string -> *regexp.Regexp
cacheSize atomic.Int64
)
// CompileRegex compiles a regex pattern with caching for performance.
// RegexError represents an error during regex compilation or validation.
type RegexError struct {
Pattern string
Reason string
Err error
}
func (e *RegexError) Error() string {
if e.Err != nil {
return fmt.Sprintf("regex error for pattern %q: %s: %v", e.Pattern, e.Reason, e.Err)
}
return fmt.Sprintf("regex error for pattern %q: %s", e.Pattern, e.Reason)
}
func (e *RegexError) Unwrap() error {
return e.Err
}
// ValidatePattern validates a regex pattern for safety.
// Returns an error if the pattern is too long or appears malicious.
func ValidatePattern(pattern string) error {
// Check pattern length
if len(pattern) > MaxPatternLength {
return &RegexError{
Pattern: truncatePattern(pattern),
Reason: fmt.Sprintf("pattern too long (%d chars, max %d)", len(pattern), MaxPatternLength),
}
}
// Note: Go's regexp package uses Thompson NFA which guarantees O(n) matching time,
// making it inherently resistant to ReDoS attacks. However, we still validate
// pattern length to prevent memory issues during compilation.
return nil
}
// CompileRegex compiles a regex pattern with caching and validation for security.
// Thread-safe: uses LoadOrStore to prevent race conditions.
// Returns the compiled regex or an error if the pattern is invalid.
// Returns the compiled regex or an error if the pattern is invalid or unsafe.
func CompileRegex(pattern string) (*regexp.Regexp, error) {
// Validate pattern first
if err := ValidatePattern(pattern); err != nil {
return nil, err
}
// Check cache first
if cached, ok := regexCache.Load(pattern); ok {
return cached.(*regexp.Regexp), nil
@@ -22,20 +78,64 @@ func CompileRegex(pattern string) (*regexp.Regexp, error) {
// Compile regex
re, err := regexp.Compile(pattern)
if err != nil {
return nil, err
return nil, &RegexError{
Pattern: truncatePattern(pattern),
Reason: "invalid regex syntax",
Err: err,
}
}
// Check cache size and clear if too large
if cacheSize.Load() >= MaxCacheSize {
ClearRegexCache()
}
// Try to store - if another goroutine already stored it, use theirs
// This prevents race conditions where multiple goroutines compile the same pattern
actual, _ := regexCache.LoadOrStore(pattern, re)
actual, loaded := regexCache.LoadOrStore(pattern, re)
if !loaded {
cacheSize.Add(1)
}
return actual.(*regexp.Regexp), nil
}
// CompileRegexUncached compiles a regex pattern without caching.
// Useful for one-off patterns that shouldn't pollute the cache.
func CompileRegexUncached(pattern string) (*regexp.Regexp, error) {
if err := ValidatePattern(pattern); err != nil {
return nil, err
}
re, err := regexp.Compile(pattern)
if err != nil {
return nil, &RegexError{
Pattern: truncatePattern(pattern),
Reason: "invalid regex syntax",
Err: err,
}
}
return re, nil
}
// ClearRegexCache clears all cached compiled regular expressions.
// Useful for testing or when memory usage needs to be reduced.
func ClearRegexCache() {
regexCache.Range(func(key, value interface{}) bool {
regexCache.Range(func(key, _ interface{}) bool {
regexCache.Delete(key)
return true
})
cacheSize.Store(0)
}
// CacheStats returns the current number of cached patterns.
func CacheStats() int64 {
return cacheSize.Load()
}
// truncatePattern truncates a pattern for display in error messages.
func truncatePattern(pattern string) string {
if len(pattern) > 50 {
return pattern[:47] + "..."
}
return pattern
}
+375
View File
@@ -0,0 +1,375 @@
package util
import (
"errors"
"strings"
"sync"
"testing"
)
func TestValidatePattern(t *testing.T) {
tests := []struct {
name string
pattern string
expectErr bool
}{
{
name: "valid short pattern",
pattern: "^hello.*world$",
expectErr: false,
},
{
name: "valid empty pattern",
pattern: "",
expectErr: false,
},
{
name: "valid pattern at max length",
pattern: strings.Repeat("a", MaxPatternLength),
expectErr: false,
},
{
name: "pattern too long",
pattern: strings.Repeat("a", MaxPatternLength+1),
expectErr: true,
},
{
name: "very long pattern",
pattern: strings.Repeat("x", MaxPatternLength*2),
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidatePattern(tt.pattern)
if tt.expectErr && err == nil {
t.Error("expected error but got nil")
}
if !tt.expectErr && err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestCompileRegex(t *testing.T) {
// Clear cache before each test
ClearRegexCache()
t.Run("valid pattern is compiled and cached", func(t *testing.T) {
ClearRegexCache()
pattern := "^test.*pattern$"
re1, err := CompileRegex(pattern)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if re1 == nil {
t.Fatal("expected non-nil regex")
}
// Second call should return cached version
re2, err := CompileRegex(pattern)
if err != nil {
t.Fatalf("unexpected error on second call: %v", err)
}
// Should be the same pointer
if re1 != re2 {
t.Error("expected same regex instance from cache")
}
// Cache should have one entry
if stats := CacheStats(); stats != 1 {
t.Errorf("expected cache size 1, got %d", stats)
}
})
t.Run("invalid pattern returns error", func(t *testing.T) {
ClearRegexCache()
pattern := "[invalid(regex"
_, err := CompileRegex(pattern)
if err == nil {
t.Fatal("expected error for invalid regex")
}
var regexErr *RegexError
if !errors.As(err, &regexErr) {
t.Errorf("expected RegexError, got %T", err)
}
})
t.Run("pattern too long returns error", func(t *testing.T) {
ClearRegexCache()
pattern := strings.Repeat("a", MaxPatternLength+1)
_, err := CompileRegex(pattern)
if err == nil {
t.Fatal("expected error for long pattern")
}
var regexErr *RegexError
if !errors.As(err, &regexErr) {
t.Errorf("expected RegexError, got %T", err)
}
})
t.Run("different patterns are cached separately", func(t *testing.T) {
ClearRegexCache()
re1, _ := CompileRegex("pattern1")
re2, _ := CompileRegex("pattern2")
if re1 == re2 {
t.Error("different patterns should produce different regex instances")
}
if stats := CacheStats(); stats != 2 {
t.Errorf("expected cache size 2, got %d", stats)
}
})
t.Run("regex matches correctly", func(t *testing.T) {
ClearRegexCache()
re, err := CompileRegex("^hello\\s+world$")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !re.MatchString("hello world") {
t.Error("expected match for 'hello world'")
}
if !re.MatchString("hello world") {
t.Error("expected match for 'hello world'")
}
if re.MatchString("helloworld") {
t.Error("unexpected match for 'helloworld'")
}
})
}
func TestCompileRegexUncached(t *testing.T) {
ClearRegexCache()
t.Run("valid pattern compiles without caching", func(t *testing.T) {
initialSize := CacheStats()
re, err := CompileRegexUncached("^uncached.*pattern$")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if re == nil {
t.Fatal("expected non-nil regex")
}
// Cache size should not change
if stats := CacheStats(); stats != initialSize {
t.Errorf("cache size changed from %d to %d", initialSize, stats)
}
})
t.Run("invalid pattern returns error", func(t *testing.T) {
_, err := CompileRegexUncached("[invalid")
if err == nil {
t.Fatal("expected error for invalid regex")
}
})
t.Run("pattern too long returns error", func(t *testing.T) {
pattern := strings.Repeat("x", MaxPatternLength+1)
_, err := CompileRegexUncached(pattern)
if err == nil {
t.Fatal("expected error for long pattern")
}
})
}
func TestClearRegexCache(t *testing.T) {
// Add some patterns
_, _ = CompileRegex("pattern1")
_, _ = CompileRegex("pattern2")
_, _ = CompileRegex("pattern3")
if stats := CacheStats(); stats < 3 {
t.Fatalf("expected at least 3 cached patterns, got %d", stats)
}
ClearRegexCache()
if stats := CacheStats(); stats != 0 {
t.Errorf("expected cache size 0 after clear, got %d", stats)
}
}
func TestCacheStats(t *testing.T) {
ClearRegexCache()
if stats := CacheStats(); stats != 0 {
t.Errorf("expected initial cache size 0, got %d", stats)
}
_, _ = CompileRegex("a")
if stats := CacheStats(); stats != 1 {
t.Errorf("expected cache size 1, got %d", stats)
}
_, _ = CompileRegex("b")
if stats := CacheStats(); stats != 2 {
t.Errorf("expected cache size 2, got %d", stats)
}
// Same pattern should not increase cache size
_, _ = CompileRegex("a")
if stats := CacheStats(); stats != 2 {
t.Errorf("expected cache size 2 after duplicate, got %d", stats)
}
}
func TestConcurrentAccess(t *testing.T) {
ClearRegexCache()
var wg sync.WaitGroup
numGoroutines := 100
numPatterns := 10
// Generate some patterns
patterns := make([]string, numPatterns)
for i := range patterns {
patterns[i] = strings.Repeat("p", i+1)
}
// Concurrent compilation of same patterns
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
pattern := patterns[id%numPatterns]
re, err := CompileRegex(pattern)
if err != nil {
t.Errorf("goroutine %d: unexpected error: %v", id, err)
return
}
if re == nil {
t.Errorf("goroutine %d: nil regex returned", id)
}
}(i)
}
wg.Wait()
// Should have exactly numPatterns cached
if stats := CacheStats(); stats != int64(numPatterns) {
t.Errorf("expected cache size %d, got %d", numPatterns, stats)
}
}
func TestRegexError(t *testing.T) {
t.Run("error message with underlying error", func(t *testing.T) {
underlying := errors.New("underlying error")
err := &RegexError{
Pattern: "test.*",
Reason: "test reason",
Err: underlying,
}
msg := err.Error()
if !strings.Contains(msg, "test.*") {
t.Error("error message should contain pattern")
}
if !strings.Contains(msg, "test reason") {
t.Error("error message should contain reason")
}
if !strings.Contains(msg, "underlying error") {
t.Error("error message should contain underlying error")
}
})
t.Run("error message without underlying error", func(t *testing.T) {
err := &RegexError{
Pattern: "test.*",
Reason: "test reason",
Err: nil,
}
msg := err.Error()
if !strings.Contains(msg, "test.*") {
t.Error("error message should contain pattern")
}
if !strings.Contains(msg, "test reason") {
t.Error("error message should contain reason")
}
})
t.Run("error unwrap", func(t *testing.T) {
underlying := errors.New("underlying")
err := &RegexError{
Pattern: "test",
Reason: "reason",
Err: underlying,
}
if errors.Unwrap(err) != underlying {
t.Error("Unwrap should return underlying error")
}
})
}
func TestTruncatePattern(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "short pattern unchanged",
input: "short",
expected: "short",
},
{
name: "exactly 50 chars unchanged",
input: strings.Repeat("x", 50),
expected: strings.Repeat("x", 50),
},
{
name: "long pattern truncated",
input: strings.Repeat("x", 60),
expected: strings.Repeat("x", 47) + "...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := truncatePattern(tt.input)
if got != tt.expected {
t.Errorf("truncatePattern() = %q (len %d), want %q (len %d)",
got, len(got), tt.expected, len(tt.expected))
}
})
}
}
// BenchmarkCompileRegex benchmarks regex compilation with caching
func BenchmarkCompileRegex(b *testing.B) {
ClearRegexCache()
pattern := "^test.*pattern\\d+$"
// First call to populate cache
_, _ = CompileRegex(pattern)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = CompileRegex(pattern)
}
}
// BenchmarkCompileRegexUncached benchmarks regex compilation without caching
func BenchmarkCompileRegexUncached(b *testing.B) {
pattern := "^test.*pattern\\d+$"
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = CompileRegexUncached(pattern)
}
}
+36
View File
@@ -120,6 +120,42 @@ func (e *StructuredError) WithRemediation(msg string) *StructuredError {
return e
}
// UserMessage returns a user-safe error message without internal details.
// This should be used for API responses to avoid leaking implementation details.
func (e *StructuredError) UserMessage() string {
// Return only the message and remediation, no stack trace, no cause details
if e.Remediation != "" {
return fmt.Sprintf("%s. %s", e.Message, e.Remediation)
}
return e.Message
}
// SanitizeError returns a user-safe error message from any error.
// For StructuredError, it returns UserMessage().
// For other errors, it returns a generic message to avoid leaking internals.
func SanitizeError(err error) string {
if err == nil {
return ""
}
// Check if it's a StructuredError
if se, ok := err.(*StructuredError); ok {
return se.UserMessage()
}
// For other errors, extract only the basic message
// Avoid exposing full paths or implementation details
msg := err.Error()
// Truncate very long messages that might contain stack traces
const maxLen = 200
if len(msg) > maxLen {
msg = msg[:maxLen] + "..."
}
return msg
}
// New creates a new structured error with stack trace.
func New(code ErrorCode, message string) *StructuredError {
return &StructuredError{
+3 -13
View File
@@ -1,6 +1,8 @@
// Package protocol defines shared types used across the MCP file operations server.
package protocol
import "path/filepath"
// Location represents a position in a file.
type Location struct {
File string `json:"file"`
@@ -66,7 +68,7 @@ const (
// DetectLanguage detects the language from a filename.
func DetectLanguage(filename string) Language {
ext := getExtension(filename)
ext := filepath.Ext(filename)
switch ext {
case ".go":
return LangGo
@@ -94,15 +96,3 @@ func DetectLanguage(filename string) Language {
return LangUnknown
}
}
func getExtension(filename string) string {
for i := len(filename) - 1; i >= 0; i-- {
if filename[i] == '.' {
return filename[i:]
}
if filename[i] == '/' || filename[i] == '\\' {
break
}
}
return ""
}
-23
View File
@@ -48,26 +48,3 @@ func TestDetectLanguage(t *testing.T) {
})
}
}
func TestGetExtension(t *testing.T) {
tests := []struct {
filename string
expected string
}{
{"file.go", ".go"},
{"file.test.go", ".go"},
{"path/to/file.ts", ".ts"},
{"noextension", ""},
{".hidden", ".hidden"},
{"file.", "."},
}
for _, tt := range tests {
t.Run(tt.filename, func(t *testing.T) {
result := getExtension(tt.filename)
if result != tt.expected {
t.Errorf("getExtension(%q) = %q, want %q", tt.filename, result, tt.expected)
}
})
}
}