mirror of
https://github.com/lukaszraczylo/filepuff-mcp.git
synced 2026-06-05 22:23:50 +00:00
Update, bugfixes on diff and edit handling
This commit is contained in:
@@ -67,9 +67,11 @@ func setupLogger(level string, logFile string) *slog.Logger {
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
var logFileErr error
|
||||
if logFile != "" {
|
||||
f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
logFileErr = err
|
||||
// Fallback to stderr
|
||||
handler = slog.NewJSONHandler(os.Stderr, opts)
|
||||
} else {
|
||||
@@ -80,5 +82,15 @@ func setupLogger(level string, logFile string) *slog.Logger {
|
||||
handler = slog.NewJSONHandler(os.Stderr, opts)
|
||||
}
|
||||
|
||||
return slog.New(handler)
|
||||
logger := slog.New(handler)
|
||||
|
||||
// Warn if log file couldn't be opened
|
||||
if logFileErr != nil {
|
||||
logger.Warn("failed to open log file, using stderr",
|
||||
"file", logFile,
|
||||
"error", logFileErr,
|
||||
)
|
||||
}
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
+1300
-1257
File diff suppressed because it is too large
Load Diff
+46
-88
@@ -482,13 +482,12 @@ func (e *Engine) matchesSelector(sel ASTSelector, n *sitter.Node, content []byte
|
||||
}
|
||||
|
||||
// applyEdit applies the edit operation to the content.
|
||||
// AST mode uses exact byte positions — new_content is inserted verbatim without auto-indentation.
|
||||
func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]byte, error) {
|
||||
startByte := node.StartByte()
|
||||
endByte := node.EndByte()
|
||||
|
||||
// Detect and preserve indentation
|
||||
indentation := detectIndentation(content, startByte)
|
||||
newContent := indentContent(edit.NewContent, indentation)
|
||||
newContent := edit.NewContent
|
||||
|
||||
var result []byte
|
||||
|
||||
@@ -499,15 +498,21 @@ func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]
|
||||
result = append(result, content[endByte:]...)
|
||||
|
||||
case EditInsertBefore:
|
||||
insertion := newContent
|
||||
if !strings.HasSuffix(insertion, "\n") {
|
||||
insertion += "\n"
|
||||
}
|
||||
result = append(result, content[:startByte]...)
|
||||
result = append(result, []byte(newContent)...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, []byte(insertion)...)
|
||||
result = append(result, content[startByte:]...)
|
||||
|
||||
case EditInsertAfter:
|
||||
insertion := newContent
|
||||
if !strings.HasPrefix(insertion, "\n") {
|
||||
insertion = "\n" + insertion
|
||||
}
|
||||
result = append(result, content[:endByte]...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, []byte(newContent)...)
|
||||
result = append(result, []byte(insertion)...)
|
||||
result = append(result, content[endByte:]...)
|
||||
|
||||
case EditDelete:
|
||||
@@ -522,16 +527,16 @@ func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]
|
||||
}
|
||||
|
||||
// detectIndentation detects the indentation at a given byte position.
|
||||
func detectIndentation(content []byte, bytePos uint32) string {
|
||||
func detectIndentation(content []byte, bytePos int) string {
|
||||
// Find the start of the line
|
||||
lineStart := int(bytePos)
|
||||
lineStart := bytePos
|
||||
for lineStart > 0 && content[lineStart-1] != '\n' {
|
||||
lineStart--
|
||||
}
|
||||
|
||||
// Extract leading whitespace
|
||||
var indent strings.Builder
|
||||
for i := lineStart; i < int(bytePos) && i < len(content); i++ {
|
||||
for i := lineStart; i < bytePos && i < len(content); i++ {
|
||||
c := content[i]
|
||||
if c == ' ' || c == '\t' {
|
||||
indent.WriteByte(c)
|
||||
@@ -560,10 +565,16 @@ func indentContent(content string, indent string) string {
|
||||
}
|
||||
|
||||
// generateDiff creates a unified diff between original and modified content.
|
||||
// Uses Myers diff algorithm for accurate and readable diffs.
|
||||
// Uses line-level Myers diff algorithm for accurate and readable diffs.
|
||||
func generateDiff(original, modified, filename string) string {
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(original, modified, false)
|
||||
|
||||
// Use line-level diffing: encode each line as a single character,
|
||||
// diff the encoded strings, then decode back to real lines.
|
||||
// This prevents character-level diffs from splitting lines incorrectly.
|
||||
chars1, chars2, lineArray := dmp.DiffLinesToChars(original, modified)
|
||||
diffs := dmp.DiffMain(chars1, chars2, false)
|
||||
diffs = dmp.DiffCharsToLines(diffs, lineArray)
|
||||
|
||||
// Cleanup for readability
|
||||
diffs = dmp.DiffCleanupSemantic(diffs)
|
||||
@@ -573,24 +584,25 @@ func generateDiff(original, modified, filename string) string {
|
||||
buf.WriteString(fmt.Sprintf("--- %s\n", filename))
|
||||
buf.WriteString(fmt.Sprintf("+++ %s\n", filename))
|
||||
|
||||
// Group diffs into hunks
|
||||
lineNum := 1
|
||||
for _, diff := range diffs {
|
||||
lines := strings.Split(diff.Text, "\n")
|
||||
for i, line := range lines {
|
||||
// Skip empty last line from split
|
||||
if i == len(lines)-1 && line == "" {
|
||||
// SplitAfter preserves the trailing \n on each line, so we can
|
||||
// distinguish real lines from a trailing empty split artifact.
|
||||
lines := strings.SplitAfter(diff.Text, "\n")
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove trailing newline for display — we add our own.
|
||||
cleanLine := strings.TrimSuffix(line, "\n")
|
||||
|
||||
switch diff.Type {
|
||||
case diffmatchpatch.DiffDelete:
|
||||
buf.WriteString(fmt.Sprintf("-%s\n", line))
|
||||
buf.WriteString(fmt.Sprintf("-%s\n", cleanLine))
|
||||
case diffmatchpatch.DiffInsert:
|
||||
buf.WriteString(fmt.Sprintf("+%s\n", line))
|
||||
buf.WriteString(fmt.Sprintf("+%s\n", cleanLine))
|
||||
case diffmatchpatch.DiffEqual:
|
||||
buf.WriteString(fmt.Sprintf(" %s\n", line))
|
||||
lineNum++
|
||||
buf.WriteString(fmt.Sprintf(" %s\n", cleanLine))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -639,24 +651,6 @@ func (e *Engine) findExactText(content []byte, text string, index int) (start, e
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text not found: %q", truncateString(text, 50)))
|
||||
}
|
||||
|
||||
// If multiple matches and no index specified, require explicit selection
|
||||
if len(matches) > 1 && index == 0 {
|
||||
// Check if index was explicitly set to 0 or just defaulted
|
||||
// Since we can't distinguish, we'll allow index 0 but warn about multiple matches
|
||||
// Actually, let's be strict and require explicit index for multiple matches
|
||||
locations := make([]string, 0, min(len(matches), 5))
|
||||
for i, m := range matches {
|
||||
if i >= 5 {
|
||||
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
|
||||
break
|
||||
}
|
||||
line := countLines(content[:m.start]) + 1
|
||||
locations = append(locations, fmt.Sprintf("line %d", line))
|
||||
}
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text matches %d locations (%s); use selector_index to specify which one (0-%d)",
|
||||
len(matches), strings.Join(locations, ", "), len(matches)-1))
|
||||
}
|
||||
|
||||
if index >= len(matches) {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
|
||||
}
|
||||
@@ -676,21 +670,6 @@ func (e *Engine) findRegexPattern(content []byte, pattern string, index int) (st
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern not found: %q", truncateString(pattern, 50)))
|
||||
}
|
||||
|
||||
// If multiple matches and index is 0 (default), show error with locations
|
||||
if len(matches) > 1 && index == 0 {
|
||||
locations := make([]string, 0, min(len(matches), 5))
|
||||
for i, m := range matches {
|
||||
if i >= 5 {
|
||||
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
|
||||
break
|
||||
}
|
||||
line := countLines(content[:m[0]]) + 1
|
||||
locations = append(locations, fmt.Sprintf("line %d", line))
|
||||
}
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern matches %d locations (%s); use selector_index to specify which one (0-%d)",
|
||||
len(matches), strings.Join(locations, ", "), len(matches)-1))
|
||||
}
|
||||
|
||||
if index >= len(matches) {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
|
||||
}
|
||||
@@ -728,7 +707,7 @@ func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, e
|
||||
|
||||
// Calculate byte positions
|
||||
start = 0
|
||||
for i := 0; i < startIdx; i++ {
|
||||
for i := range startIdx {
|
||||
start += len(lines[i]) + 1 // +1 for newline
|
||||
}
|
||||
|
||||
@@ -746,7 +725,7 @@ func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, e
|
||||
// applyTextEditOperation applies a text edit operation.
|
||||
func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start, end int, newContent string) ([]byte, error) {
|
||||
// Detect indentation at the selection point
|
||||
indentation := detectIndentationAtByte(content, start)
|
||||
indentation := detectIndentation(content, start)
|
||||
indentedContent := indentContent(newContent, indentation)
|
||||
|
||||
var result []byte
|
||||
@@ -758,15 +737,21 @@ func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start,
|
||||
result = append(result, content[end:]...)
|
||||
|
||||
case EditInsertBefore:
|
||||
insertion := indentedContent
|
||||
if !strings.HasSuffix(insertion, "\n") {
|
||||
insertion += "\n"
|
||||
}
|
||||
result = append(result, content[:start]...)
|
||||
result = append(result, []byte(indentedContent)...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, []byte(insertion)...)
|
||||
result = append(result, content[start:]...)
|
||||
|
||||
case EditInsertAfter:
|
||||
insertion := indentedContent
|
||||
if !strings.HasPrefix(insertion, "\n") {
|
||||
insertion = "\n" + insertion
|
||||
}
|
||||
result = append(result, content[:end]...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, []byte(indentedContent)...)
|
||||
result = append(result, []byte(insertion)...)
|
||||
result = append(result, content[end:]...)
|
||||
|
||||
case EditDelete:
|
||||
@@ -780,28 +765,6 @@ func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start,
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// detectIndentationAtByte detects indentation at a byte position.
|
||||
func detectIndentationAtByte(content []byte, bytePos int) string {
|
||||
// Find the start of the line
|
||||
lineStart := bytePos
|
||||
for lineStart > 0 && content[lineStart-1] != '\n' {
|
||||
lineStart--
|
||||
}
|
||||
|
||||
// Extract leading whitespace
|
||||
var indent strings.Builder
|
||||
for i := lineStart; i < bytePos && i < len(content); i++ {
|
||||
c := content[i]
|
||||
if c == ' ' || c == '\t' {
|
||||
indent.WriteByte(c)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return indent.String()
|
||||
}
|
||||
|
||||
// truncateString truncates a string to maxLen with ellipsis.
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
@@ -810,11 +773,6 @@ func truncateString(s string, maxLen int) string {
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
// countLines counts the number of newlines in content.
|
||||
func countLines(content []byte) int {
|
||||
return bytes.Count(content, []byte("\n"))
|
||||
}
|
||||
|
||||
// ValidateLanguage checks if AST editing is supported for a file.
|
||||
// Returns nil for supported languages, error for unsupported.
|
||||
// Note: Text-based editing is always available regardless of this check.
|
||||
|
||||
@@ -357,7 +357,7 @@ func TestDetectIndentation(t *testing.T) {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
pos uint32
|
||||
pos int
|
||||
}{
|
||||
{
|
||||
name: "no indent",
|
||||
@@ -410,6 +410,71 @@ func TestGenerateDiff(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDiffLineLevelAccuracy(t *testing.T) {
|
||||
// Regression test: diff must operate at line level, not character level.
|
||||
// A character-level diff would split "hello" and "hello world" mid-line,
|
||||
// producing broken output like:
|
||||
// fmt.Println("hello
|
||||
// + world
|
||||
// ")
|
||||
original := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello\")\n}\n"
|
||||
modified := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello world\")\n}\n"
|
||||
|
||||
diff := generateDiff(original, modified, "test.go")
|
||||
|
||||
// The diff must show whole-line removals and additions
|
||||
if !strings.Contains(diff, "-\tfmt.Println(\"hello\")\n") {
|
||||
t.Errorf("diff should show full removed line, got:\n%s", diff)
|
||||
}
|
||||
if !strings.Contains(diff, "+\tfmt.Println(\"hello world\")\n") {
|
||||
t.Errorf("diff should show full added line, got:\n%s", diff)
|
||||
}
|
||||
|
||||
// The diff must NOT split lines at character boundaries
|
||||
if strings.Contains(diff, "+hello") && !strings.Contains(diff, "Println") {
|
||||
t.Errorf("diff appears to be character-level (split mid-line), got:\n%s", diff)
|
||||
}
|
||||
|
||||
// Context lines should not be marked as changed
|
||||
for line := range strings.SplitSeq(diff, "\n") {
|
||||
if strings.HasPrefix(line, "-") || strings.HasPrefix(line, "+") {
|
||||
// Changed lines should only be the Println lines
|
||||
if strings.Contains(line, "package main") ||
|
||||
strings.Contains(line, "func hello()") {
|
||||
t.Errorf("unchanged line incorrectly marked as changed: %q", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDiffNoPhantomChanges(t *testing.T) {
|
||||
// Regression test: replacing a line range should not produce phantom
|
||||
// +/- lines for unchanged code after the edit region.
|
||||
original := "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\n"
|
||||
modified := "line1\nREPLACED\nline3\nline4\nline5\nline6\nline7\nline8\n"
|
||||
|
||||
diff := generateDiff(original, modified, "test.txt")
|
||||
|
||||
// Count changed lines (excluding headers)
|
||||
addCount := 0
|
||||
delCount := 0
|
||||
for line := range strings.SplitSeq(diff, "\n") {
|
||||
if strings.HasPrefix(line, "+") && !strings.HasPrefix(line, "+++") {
|
||||
addCount++
|
||||
}
|
||||
if strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "---") {
|
||||
delCount++
|
||||
}
|
||||
}
|
||||
|
||||
if addCount != 1 {
|
||||
t.Errorf("expected 1 added line, got %d. Diff:\n%s", addCount, diff)
|
||||
}
|
||||
if delCount != 1 {
|
||||
t.Errorf("expected 1 deleted line, got %d. Diff:\n%s", delCount, diff)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Text-based editing tests ====================
|
||||
|
||||
func TestTextEditWithExactText(t *testing.T) {
|
||||
@@ -593,7 +658,7 @@ SECRET_KEY=abc123
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditMultipleMatchesError(t *testing.T) {
|
||||
func TestTextEditMultipleMatchesSelectsFirst(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
@@ -624,12 +689,19 @@ more code
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
// Should fail because of multiple matches
|
||||
if result.Success {
|
||||
t.Error("expected error for multiple matches without index")
|
||||
// Index 0 (default) should select the first match
|
||||
if !result.Success {
|
||||
t.Fatalf("expected success for multiple matches with default index 0: %s", result.Error)
|
||||
}
|
||||
if !strings.Contains(result.Error, "matches") {
|
||||
t.Errorf("error should mention multiple matches: %s", result.Error)
|
||||
|
||||
// Verify only first TODO was replaced
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
contentStr := string(fileContent)
|
||||
if !strings.Contains(contentStr, "DONE: fix this") {
|
||||
t.Error("first TODO should be replaced with DONE")
|
||||
}
|
||||
if !strings.Contains(contentStr, "TODO: also fix this") {
|
||||
t.Error("second TODO should be unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+29
-2
@@ -16,6 +16,12 @@ import (
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
// ProcessKillTimeout is the timeout for waiting for a process to exit before force killing.
|
||||
const ProcessKillTimeout = 5 * time.Second
|
||||
|
||||
// StderrBufferSize is the buffer size for draining stderr.
|
||||
const StderrBufferSize = 1024
|
||||
|
||||
// Client represents an LSP client connection.
|
||||
type Client struct {
|
||||
stdin io.WriteCloser
|
||||
@@ -104,12 +110,33 @@ func NewClient(cmd *exec.Cmd) (*Client, error) {
|
||||
notifications: make(chan *Notification, 100),
|
||||
}
|
||||
|
||||
// Start reader goroutine
|
||||
// Start reader goroutine for stdout
|
||||
go c.readLoop()
|
||||
|
||||
// Start stderr drain goroutine to prevent pipe buffer from filling up
|
||||
go c.drainStderr()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// drainStderr consumes stderr output to prevent the LSP server from blocking.
|
||||
// LSP servers may write diagnostic messages to stderr which we discard.
|
||||
func (c *Client) drainStderr() {
|
||||
buf := make([]byte, StderrBufferSize)
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
// Read and discard stderr output
|
||||
_, err := c.stderr.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Call sends a request and waits for a response.
|
||||
func (c *Client) Call(ctx context.Context, method string, params interface{}) (*Response, error) {
|
||||
c.runningMu.RLock()
|
||||
@@ -208,7 +235,7 @@ func (c *Client) Close() error {
|
||||
select {
|
||||
case <-done:
|
||||
// Clean exit
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-time.After(ProcessKillTimeout):
|
||||
// Force kill
|
||||
_ = c.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
+53
-7
@@ -15,6 +15,18 @@ import (
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
// LSP timeout and interval constants.
|
||||
const (
|
||||
// DefaultLSPTimeout is the default timeout for LSP requests.
|
||||
DefaultLSPTimeout = 10 * time.Second
|
||||
// DefaultIdleTimeout is the duration before idle LSP servers are reaped.
|
||||
DefaultIdleTimeout = 5 * time.Minute
|
||||
// ReaperInterval is how often the idle server reaper runs.
|
||||
ReaperInterval = 60 * time.Second
|
||||
// ShutdownTimeout is the timeout for graceful LSP server shutdown.
|
||||
ShutdownTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// Manager manages LSP servers for different languages.
|
||||
type Manager struct {
|
||||
servers map[protocol.Language]*ManagedServer
|
||||
@@ -70,12 +82,27 @@ var DefaultServerConfigs = map[protocol.Language]ServerConfig{
|
||||
},
|
||||
}
|
||||
|
||||
// AllowedLSPBinaries is a whitelist of allowed LSP server binary names.
|
||||
// This prevents command injection by ensuring only known LSP servers can be executed.
|
||||
var AllowedLSPBinaries = map[string]bool{
|
||||
"gopls": true,
|
||||
"typescript-language-server": true,
|
||||
"pylsp": true,
|
||||
"clangd": true,
|
||||
// Common alternatives
|
||||
"tsserver": true,
|
||||
"pyright": true,
|
||||
"ruff-lsp": true,
|
||||
"rust-analyzer": true,
|
||||
"ccls": true,
|
||||
}
|
||||
|
||||
// NewManager creates a new LSP manager.
|
||||
func NewManager(workspaceRoot string, logger *slog.Logger) *Manager {
|
||||
m := &Manager{
|
||||
servers: make(map[protocol.Language]*ManagedServer),
|
||||
timeout: 10 * time.Second,
|
||||
idleTimeout: 5 * time.Minute,
|
||||
timeout: DefaultLSPTimeout,
|
||||
idleTimeout: DefaultIdleTimeout,
|
||||
workspaceRoot: workspaceRoot,
|
||||
logger: logger,
|
||||
stopReaper: make(chan struct{}),
|
||||
@@ -127,12 +154,20 @@ func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*Manag
|
||||
return nil, errors.NewLSPServerNotFound(string(lang), config.Command[0])
|
||||
}
|
||||
|
||||
// Validate command against whitelist to prevent command injection
|
||||
binaryName := filepath.Base(cmdPath)
|
||||
if !AllowedLSPBinaries[binaryName] {
|
||||
return nil, errors.New(errors.ErrLSPServerNotFound, fmt.Sprintf("LSP binary %q is not in the allowed list", binaryName)).
|
||||
WithContext("language", string(lang)).
|
||||
WithContext("binary", binaryName).
|
||||
WithRemediation("Only whitelisted LSP server binaries are allowed for security reasons")
|
||||
}
|
||||
|
||||
// Create command
|
||||
args := append(config.Command[1:], config.Args...)
|
||||
cmd := exec.CommandContext(ctx, cmdPath, args...)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Dir = m.workspaceRoot
|
||||
|
||||
// Create client
|
||||
client, err := NewClient(cmd)
|
||||
if err != nil {
|
||||
@@ -447,7 +482,7 @@ func (m *Manager) CloseDocument(_ context.Context, lang protocol.Language, file
|
||||
|
||||
// reapIdleServers periodically closes idle servers.
|
||||
func (m *Manager) reapIdleServers() {
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
ticker := time.NewTicker(ReaperInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
@@ -455,6 +490,10 @@ func (m *Manager) reapIdleServers() {
|
||||
case <-m.stopReaper:
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Collect idle servers first to avoid holding the lock while closing
|
||||
var toClose []*ManagedServer
|
||||
var toCloseLanguages []protocol.Language
|
||||
|
||||
m.mu.Lock()
|
||||
for lang, srv := range m.servers {
|
||||
// Check lastUsed with server's lock to avoid race condition
|
||||
@@ -463,12 +502,19 @@ func (m *Manager) reapIdleServers() {
|
||||
srv.mu.Unlock()
|
||||
|
||||
if idle {
|
||||
m.logger.Info("closing idle LSP server", "language", lang)
|
||||
_ = srv.client.Close()
|
||||
toClose = append(toClose, srv)
|
||||
toCloseLanguages = append(toCloseLanguages, lang)
|
||||
delete(m.servers, lang)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close servers outside the lock to prevent deadlock
|
||||
// (Close can block waiting for the process to exit)
|
||||
for i, srv := range toClose {
|
||||
m.logger.Info("closing idle LSP server", "language", toCloseLanguages[i])
|
||||
_ = srv.client.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -485,7 +531,7 @@ func (m *Manager) Close() error {
|
||||
for lang, srv := range m.servers {
|
||||
m.logger.Info("shutting down LSP server", "language", lang)
|
||||
// Try graceful shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout)
|
||||
_, _ = srv.client.Call(ctx, "shutdown", nil)
|
||||
cancel()
|
||||
_ = srv.client.Notify("exit", nil)
|
||||
|
||||
@@ -0,0 +1,561 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCounter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ops func(c *Counter)
|
||||
expected int64
|
||||
}{
|
||||
{
|
||||
name: "initial value is zero",
|
||||
ops: func(c *Counter) {},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "single inc",
|
||||
ops: func(c *Counter) {
|
||||
c.Inc()
|
||||
},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple inc",
|
||||
ops: func(c *Counter) {
|
||||
c.Inc()
|
||||
c.Inc()
|
||||
c.Inc()
|
||||
},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "add positive",
|
||||
ops: func(c *Counter) {
|
||||
c.Add(10)
|
||||
},
|
||||
expected: 10,
|
||||
},
|
||||
{
|
||||
name: "mixed operations",
|
||||
ops: func(c *Counter) {
|
||||
c.Inc()
|
||||
c.Add(5)
|
||||
c.Inc()
|
||||
},
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "reset",
|
||||
ops: func(c *Counter) {
|
||||
c.Add(100)
|
||||
c.Reset()
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := NewCounter("test_counter", "test help", nil)
|
||||
tt.ops(c)
|
||||
if got := c.Value(); got != tt.expected {
|
||||
t.Errorf("Counter.Value() = %d, want %d", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCounterConcurrency(t *testing.T) {
|
||||
c := NewCounter("concurrent_counter", "test", nil)
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
incsPerGoroutine := 1000
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < incsPerGoroutine; j++ {
|
||||
c.Inc()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
expected := int64(numGoroutines * incsPerGoroutine)
|
||||
if got := c.Value(); got != expected {
|
||||
t.Errorf("Counter.Value() = %d, want %d after concurrent increments", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGauge(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ops func(g *Gauge)
|
||||
expected int64
|
||||
}{
|
||||
{
|
||||
name: "initial value is zero",
|
||||
ops: func(g *Gauge) {},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "set value",
|
||||
ops: func(g *Gauge) {
|
||||
g.Set(42)
|
||||
},
|
||||
expected: 42,
|
||||
},
|
||||
{
|
||||
name: "inc",
|
||||
ops: func(g *Gauge) {
|
||||
g.Inc()
|
||||
g.Inc()
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "dec",
|
||||
ops: func(g *Gauge) {
|
||||
g.Set(10)
|
||||
g.Dec()
|
||||
g.Dec()
|
||||
},
|
||||
expected: 8,
|
||||
},
|
||||
{
|
||||
name: "add positive",
|
||||
ops: func(g *Gauge) {
|
||||
g.Add(5)
|
||||
},
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "add negative",
|
||||
ops: func(g *Gauge) {
|
||||
g.Set(10)
|
||||
g.Add(-3)
|
||||
},
|
||||
expected: 7,
|
||||
},
|
||||
{
|
||||
name: "mixed operations",
|
||||
ops: func(g *Gauge) {
|
||||
g.Set(100)
|
||||
g.Inc()
|
||||
g.Dec()
|
||||
g.Add(-50)
|
||||
},
|
||||
expected: 50,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := NewGauge("test_gauge", "test help", nil)
|
||||
tt.ops(g)
|
||||
if got := g.Value(); got != tt.expected {
|
||||
t.Errorf("Gauge.Value() = %d, want %d", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHistogram(t *testing.T) {
|
||||
t.Run("default buckets", func(t *testing.T) {
|
||||
h := NewHistogram("test_histogram", "test", nil, nil)
|
||||
if len(h.buckets) != len(DefaultDurationBuckets) {
|
||||
t.Errorf("expected default buckets, got %d buckets", len(h.buckets))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom buckets", func(t *testing.T) {
|
||||
buckets := []float64{1.0, 5.0, 10.0}
|
||||
h := NewHistogram("test_histogram", "test", nil, buckets)
|
||||
if len(h.buckets) != 3 {
|
||||
t.Errorf("expected 3 buckets, got %d", len(h.buckets))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("buckets are sorted", func(t *testing.T) {
|
||||
buckets := []float64{10.0, 1.0, 5.0}
|
||||
h := NewHistogram("test_histogram", "test", nil, buckets)
|
||||
if h.buckets[0] != 1.0 || h.buckets[1] != 5.0 || h.buckets[2] != 10.0 {
|
||||
t.Errorf("buckets not sorted: %v", h.buckets)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("observe values", func(t *testing.T) {
|
||||
h := NewHistogram("test_histogram", "test", nil, []float64{1.0, 5.0, 10.0})
|
||||
|
||||
h.Observe(0.5) // goes to bucket 0 (<=1.0)
|
||||
h.Observe(3.0) // goes to bucket 1 (<=5.0)
|
||||
h.Observe(7.0) // goes to bucket 2 (<=10.0)
|
||||
h.Observe(15.0) // goes to +Inf bucket
|
||||
|
||||
if h.Count() != 4 {
|
||||
t.Errorf("expected count 4, got %d", h.Count())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("observe duration", func(t *testing.T) {
|
||||
h := NewHistogram("test_histogram", "test", nil, []float64{0.001, 0.01, 0.1})
|
||||
|
||||
h.ObserveDuration(500 * time.Microsecond) // 0.0005s, goes to bucket 0
|
||||
h.ObserveDuration(5 * time.Millisecond) // 0.005s, goes to bucket 1
|
||||
|
||||
if h.Count() != 2 {
|
||||
t.Errorf("expected count 2, got %d", h.Count())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sum tracking", func(t *testing.T) {
|
||||
h := NewHistogram("test_histogram", "test", nil, []float64{1.0, 5.0, 10.0})
|
||||
|
||||
h.Observe(1.0)
|
||||
h.Observe(2.0)
|
||||
h.Observe(3.0)
|
||||
|
||||
expectedSum := 6.0
|
||||
if got := h.Sum(); got != expectedSum {
|
||||
t.Errorf("expected sum %f, got %f", expectedSum, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistry(t *testing.T) {
|
||||
t.Run("counter registration", func(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
c1 := r.Counter("test_counter", "help", nil)
|
||||
c2 := r.Counter("test_counter", "help", nil)
|
||||
|
||||
if c1 != c2 {
|
||||
t.Error("expected same counter instance for same name")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("counter with labels", func(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
labels1 := map[string]string{"method": "get"}
|
||||
labels2 := map[string]string{"method": "post"}
|
||||
|
||||
c1 := r.Counter("http_requests", "help", labels1)
|
||||
c2 := r.Counter("http_requests", "help", labels2)
|
||||
|
||||
if c1 == c2 {
|
||||
t.Error("expected different counter instances for different labels")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("gauge registration", func(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
g1 := r.Gauge("test_gauge", "help", nil)
|
||||
g2 := r.Gauge("test_gauge", "help", nil)
|
||||
|
||||
if g1 != g2 {
|
||||
t.Error("expected same gauge instance for same name")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("histogram registration", func(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
h1 := r.Histogram("test_histogram", "help", nil, nil)
|
||||
h2 := r.Histogram("test_histogram", "help", nil, nil)
|
||||
|
||||
if h1 != h2 {
|
||||
t.Error("expected same histogram instance for same name")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistryConcurrency(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
|
||||
// Concurrent access to registry
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
c := r.Counter("concurrent_test", "test", nil)
|
||||
c.Inc()
|
||||
|
||||
g := r.Gauge("concurrent_gauge", "test", nil)
|
||||
g.Inc()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
c := r.Counter("concurrent_test", "test", nil)
|
||||
if c.Value() != int64(numGoroutines) {
|
||||
t.Errorf("expected counter value %d, got %d", numGoroutines, c.Value())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryExpose(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Add some metrics
|
||||
c := r.Counter("test_requests_total", "Total requests", nil)
|
||||
c.Add(42)
|
||||
|
||||
g := r.Gauge("test_connections", "Active connections", nil)
|
||||
g.Set(10)
|
||||
|
||||
h := r.Histogram("test_duration_seconds", "Request duration", nil, []float64{0.1, 0.5, 1.0})
|
||||
h.Observe(0.05)
|
||||
h.Observe(0.3)
|
||||
h.Observe(0.8)
|
||||
|
||||
output := r.Expose()
|
||||
|
||||
// Check counter output
|
||||
if !strings.Contains(output, "# TYPE test_requests_total counter") {
|
||||
t.Error("expected counter type in output")
|
||||
}
|
||||
if !strings.Contains(output, "test_requests_total 42") {
|
||||
t.Error("expected counter value in output")
|
||||
}
|
||||
|
||||
// Check gauge output
|
||||
if !strings.Contains(output, "# TYPE test_connections gauge") {
|
||||
t.Error("expected gauge type in output")
|
||||
}
|
||||
if !strings.Contains(output, "test_connections 10") {
|
||||
t.Error("expected gauge value in output")
|
||||
}
|
||||
|
||||
// Check histogram output
|
||||
if !strings.Contains(output, "# TYPE test_duration_seconds histogram") {
|
||||
t.Error("expected histogram type in output")
|
||||
}
|
||||
if !strings.Contains(output, "test_duration_seconds_bucket") {
|
||||
t.Error("expected histogram buckets in output")
|
||||
}
|
||||
if !strings.Contains(output, "test_duration_seconds_sum") {
|
||||
t.Error("expected histogram sum in output")
|
||||
}
|
||||
if !strings.Contains(output, "test_duration_seconds_count") {
|
||||
t.Error("expected histogram count in output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryReset(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
c := r.Counter("test_counter", "test", nil)
|
||||
c.Add(100)
|
||||
|
||||
g := r.Gauge("test_gauge", "test", nil)
|
||||
g.Set(50)
|
||||
|
||||
r.Reset()
|
||||
|
||||
if c.Value() != 0 {
|
||||
t.Errorf("expected counter reset to 0, got %d", c.Value())
|
||||
}
|
||||
if g.Value() != 0 {
|
||||
t.Errorf("expected gauge reset to 0, got %d", g.Value())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerMetrics(t *testing.T) {
|
||||
t.Run("creation", func(t *testing.T) {
|
||||
m := NewServerMetrics()
|
||||
|
||||
if m.RequestsTotal == nil {
|
||||
t.Error("RequestsTotal should not be nil")
|
||||
}
|
||||
if m.RequestErrors == nil {
|
||||
t.Error("RequestErrors should not be nil")
|
||||
}
|
||||
if m.RequestDuration == nil {
|
||||
t.Error("RequestDuration should not be nil")
|
||||
}
|
||||
if m.CacheHits == nil {
|
||||
t.Error("CacheHits should not be nil")
|
||||
}
|
||||
if m.CacheMisses == nil {
|
||||
t.Error("CacheMisses should not be nil")
|
||||
}
|
||||
if m.ActiveLSPServers == nil {
|
||||
t.Error("ActiveLSPServers should not be nil")
|
||||
}
|
||||
if m.ParseDuration == nil {
|
||||
t.Error("ParseDuration should not be nil")
|
||||
}
|
||||
if m.ParseErrors == nil {
|
||||
t.Error("ParseErrors should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("record request success", func(t *testing.T) {
|
||||
m := NewServerMetrics()
|
||||
|
||||
m.RecordRequest(100*time.Millisecond, nil)
|
||||
|
||||
if m.RequestsTotal.Value() != 1 {
|
||||
t.Errorf("expected RequestsTotal 1, got %d", m.RequestsTotal.Value())
|
||||
}
|
||||
if m.RequestErrors.Value() != 0 {
|
||||
t.Errorf("expected RequestErrors 0, got %d", m.RequestErrors.Value())
|
||||
}
|
||||
if m.RequestDuration.Count() != 1 {
|
||||
t.Errorf("expected RequestDuration count 1, got %d", m.RequestDuration.Count())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("record request error", func(t *testing.T) {
|
||||
m := NewServerMetrics()
|
||||
|
||||
m.RecordRequest(50*time.Millisecond, &testError{})
|
||||
|
||||
if m.RequestsTotal.Value() != 1 {
|
||||
t.Errorf("expected RequestsTotal 1, got %d", m.RequestsTotal.Value())
|
||||
}
|
||||
if m.RequestErrors.Value() != 1 {
|
||||
t.Errorf("expected RequestErrors 1, got %d", m.RequestErrors.Value())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("record parse", func(t *testing.T) {
|
||||
m := NewServerMetrics()
|
||||
|
||||
m.RecordParse(10*time.Millisecond, nil)
|
||||
m.RecordParse(5*time.Millisecond, &testError{})
|
||||
|
||||
if m.ParseDuration.Count() != 2 {
|
||||
t.Errorf("expected ParseDuration count 2, got %d", m.ParseDuration.Count())
|
||||
}
|
||||
if m.ParseErrors.Value() != 1 {
|
||||
t.Errorf("expected ParseErrors 1, got %d", m.ParseErrors.Value())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("record cache", func(t *testing.T) {
|
||||
m := NewServerMetrics()
|
||||
|
||||
m.RecordCacheHit()
|
||||
m.RecordCacheHit()
|
||||
m.RecordCacheMiss()
|
||||
|
||||
if m.CacheHits.Value() != 2 {
|
||||
t.Errorf("expected CacheHits 2, got %d", m.CacheHits.Value())
|
||||
}
|
||||
if m.CacheMisses.Value() != 1 {
|
||||
t.Errorf("expected CacheMisses 1, got %d", m.CacheMisses.Value())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set active LSP servers", func(t *testing.T) {
|
||||
m := NewServerMetrics()
|
||||
|
||||
m.SetActiveLSPServers(5)
|
||||
if m.ActiveLSPServers.Value() != 5 {
|
||||
t.Errorf("expected ActiveLSPServers 5, got %d", m.ActiveLSPServers.Value())
|
||||
}
|
||||
|
||||
m.SetActiveLSPServers(3)
|
||||
if m.ActiveLSPServers.Value() != 3 {
|
||||
t.Errorf("expected ActiveLSPServers 3, got %d", m.ActiveLSPServers.Value())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expose", func(t *testing.T) {
|
||||
m := NewServerMetrics()
|
||||
m.RecordRequest(100*time.Millisecond, nil)
|
||||
|
||||
output := m.Expose()
|
||||
|
||||
if !strings.Contains(output, "mcp_requests_total") {
|
||||
t.Error("expected mcp_requests_total in output")
|
||||
}
|
||||
if !strings.Contains(output, "mcp_request_duration_seconds") {
|
||||
t.Error("expected mcp_request_duration_seconds in output")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
labels map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no labels",
|
||||
labels: nil,
|
||||
expected: "test_metric",
|
||||
},
|
||||
{
|
||||
name: "single label",
|
||||
labels: map[string]string{"method": "get"},
|
||||
expected: `test_metric{method="get"}`,
|
||||
},
|
||||
{
|
||||
name: "multiple labels sorted",
|
||||
labels: map[string]string{"method": "get", "code": "200"},
|
||||
expected: `test_metric{code="200",method="get"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := metricKey("test_metric", tt.labels)
|
||||
if got != tt.expected {
|
||||
t.Errorf("metricKey() = %q, want %q", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatLabels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
labels map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no labels",
|
||||
labels: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single label",
|
||||
labels: map[string]string{"method": "get"},
|
||||
expected: `{method="get"}`,
|
||||
},
|
||||
{
|
||||
name: "multiple labels sorted",
|
||||
labels: map[string]string{"method": "get", "code": "200"},
|
||||
expected: `{code="200",method="get"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := formatLabels(tt.labels)
|
||||
if got != tt.expected {
|
||||
t.Errorf("formatLabels() = %q, want %q", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testError is a simple error type for testing
|
||||
type testError struct{}
|
||||
|
||||
func (e *testError) Error() string { return "test error" }
|
||||
@@ -104,6 +104,7 @@ func NoComment() {}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
@@ -192,6 +193,7 @@ function validate(name) {}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
@@ -291,6 +293,7 @@ func TestExtractPythonDocComment(t *testing.T) {
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
@@ -372,6 +375,7 @@ int simple() { return 1; }
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
|
||||
@@ -34,6 +34,8 @@ type Registry struct {
|
||||
cache *lru.Cache[string, *CachedTree]
|
||||
maxParseSize int64
|
||||
mu sync.RWMutex
|
||||
parserMu sync.Map // per-language mutexes for parse serialization
|
||||
closed bool // Indicates if the registry has been closed
|
||||
|
||||
// Cache metrics (atomic for thread-safety)
|
||||
cacheHits atomic.Int64
|
||||
@@ -136,8 +138,13 @@ func getLanguage(lang protocol.Language) (*sitter.Language, error) {
|
||||
}
|
||||
|
||||
// GetParser returns a parser for the given language.
|
||||
// Returns an error if the registry has been closed.
|
||||
func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
|
||||
r.mu.RLock()
|
||||
if r.closed {
|
||||
r.mu.RUnlock()
|
||||
return nil, errors.New(errors.ErrInternal, "parser registry is closed")
|
||||
}
|
||||
if p, ok := r.parsers[lang]; ok {
|
||||
r.mu.RUnlock()
|
||||
return p, nil
|
||||
@@ -148,6 +155,11 @@ func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Check closed again after acquiring write lock
|
||||
if r.closed {
|
||||
return nil, errors.New(errors.ErrInternal, "parser registry is closed")
|
||||
}
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if p, ok := r.parsers[lang]; ok {
|
||||
return p, nil
|
||||
@@ -196,7 +208,8 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
|
||||
}
|
||||
|
||||
// Check cache (LRU cache is thread-safe)
|
||||
hash := contentHash(content)
|
||||
// Include language in cache key to prevent cross-language collisions
|
||||
hash := fmt.Sprintf("%s:%016x", string(lang), xxhash.Sum64(content))
|
||||
if cached, ok := r.cache.Get(hash); ok && cached.Language == lang {
|
||||
r.cacheHits.Add(1)
|
||||
errors := extractErrors(cached.Tree.RootNode(), content)
|
||||
@@ -215,13 +228,16 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse content - tree-sitter parsers are not thread-safe,
|
||||
// so we need to hold the lock during parsing
|
||||
// Track parse duration
|
||||
// Parse content - tree-sitter parsers are not thread-safe per instance,
|
||||
// but parsers for different languages are independent.
|
||||
// Use per-language locks to allow concurrent parsing of different languages.
|
||||
muVal, _ := r.parserMu.LoadOrStore(lang, &sync.Mutex{})
|
||||
langMu := muVal.(*sync.Mutex)
|
||||
|
||||
start := time.Now()
|
||||
r.mu.Lock()
|
||||
langMu.Lock()
|
||||
tree, err := parser.ParseCtx(ctx, nil, content)
|
||||
r.mu.Unlock()
|
||||
langMu.Unlock()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Update duration metrics
|
||||
@@ -351,10 +367,14 @@ func isBinary(content []byte) bool {
|
||||
}
|
||||
|
||||
// Close closes all parsers and clears the cache.
|
||||
// After Close is called, the registry cannot be used for parsing.
|
||||
func (r *Registry) Close() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Mark as closed first to prevent new parse operations
|
||||
r.closed = true
|
||||
|
||||
for _, p := range r.parsers {
|
||||
p.Close()
|
||||
}
|
||||
|
||||
+23
-21
@@ -214,9 +214,10 @@ func (s *Searcher) buildArgs(req *Request) []string {
|
||||
args = append(args, "--no-ignore")
|
||||
}
|
||||
|
||||
// Max count per file to limit results
|
||||
// Global result cap — --max-total-count stops rg early across all files.
|
||||
// Requires ripgrep >= 13.0. In-process truncation in parseOutput is kept as a safety net.
|
||||
if req.MaxResults > 0 {
|
||||
args = append(args, fmt.Sprintf("--max-count=%d", req.MaxResults))
|
||||
args = append(args, fmt.Sprintf("--max-total-count=%d", req.MaxResults))
|
||||
}
|
||||
|
||||
// Add pattern
|
||||
@@ -243,9 +244,9 @@ func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchRes
|
||||
Results: []Result{},
|
||||
}
|
||||
|
||||
// Track context by file and line
|
||||
contextBefore := make(map[string][]string) // file -> lines before current match
|
||||
currentFile := ""
|
||||
// Track before-context lines linearly: accumulate context lines until the next match consumes them.
|
||||
var pendingBefore []string
|
||||
pendingFile := ""
|
||||
|
||||
scanner := bufio.NewScanner(output)
|
||||
for scanner.Scan() {
|
||||
@@ -285,14 +286,14 @@ func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchRes
|
||||
result.Column = match.Submatches[0].Start + 1 // 1-indexed
|
||||
}
|
||||
|
||||
// Add context before
|
||||
if ctx, ok := contextBefore[match.Path.Text]; ok {
|
||||
result.Context.Before = ctx
|
||||
delete(contextBefore, match.Path.Text)
|
||||
// Attach pending before-context if it belongs to this file
|
||||
if pendingFile == match.Path.Text && len(pendingBefore) > 0 {
|
||||
result.Context.Before = pendingBefore
|
||||
}
|
||||
pendingBefore = nil
|
||||
pendingFile = ""
|
||||
|
||||
results.Results = append(results.Results, result)
|
||||
currentFile = match.Path.Text
|
||||
|
||||
case "context":
|
||||
var ctx rgContext
|
||||
@@ -302,19 +303,20 @@ func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchRes
|
||||
|
||||
lineText := strings.TrimRight(ctx.Lines.Text, "\n\r")
|
||||
|
||||
// Determine if this is before or after context
|
||||
isAfter := false
|
||||
if len(results.Results) > 0 {
|
||||
lastResult := &results.Results[len(results.Results)-1]
|
||||
if lastResult.File == ctx.Path.Text && ctx.LineNumber > lastResult.Line {
|
||||
// This is after context
|
||||
lastResult.Context.After = append(lastResult.Context.After, lineText)
|
||||
} else if ctx.Path.Text == currentFile || currentFile == "" {
|
||||
// This is before context for a potential upcoming match
|
||||
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
|
||||
last := &results.Results[len(results.Results)-1]
|
||||
if last.File == ctx.Path.Text && ctx.LineNumber > last.Line {
|
||||
last.Context.After = append(last.Context.After, lineText)
|
||||
isAfter = true
|
||||
}
|
||||
} else {
|
||||
// Before any match - store as potential before context
|
||||
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
|
||||
}
|
||||
if !isAfter {
|
||||
if pendingFile != ctx.Path.Text {
|
||||
pendingBefore = nil
|
||||
pendingFile = ctx.Path.Text
|
||||
}
|
||||
pendingBefore = append(pendingBefore, lineText)
|
||||
}
|
||||
|
||||
case "summary":
|
||||
|
||||
@@ -89,7 +89,7 @@ func TestBuildArgs(t *testing.T) {
|
||||
MaxResults: 10,
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--max-count=10", "--", "test", "."},
|
||||
expected: []string{"--json", "--max-total-count=10", "--", "test", "."},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
// Package server implements the MCP server for file operations.
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// handleASTQuery handles the ast_query tool.
|
||||
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// Acquire semaphore to limit concurrent queries (prevents CPU exhaustion)
|
||||
select {
|
||||
case s.querySem <- struct{}{}:
|
||||
defer func() { <-s.querySem }()
|
||||
case <-ctx.Done():
|
||||
return mcp.NewToolResultError("request cancelled"), nil
|
||||
}
|
||||
|
||||
pattern, err := request.RequireString("pattern")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("pattern is required"), nil
|
||||
}
|
||||
|
||||
language, err := request.RequireString("language")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("language is required"), nil
|
||||
}
|
||||
|
||||
// Build query
|
||||
astQuery := &query.ASTQuery{
|
||||
Pattern: pattern,
|
||||
Language: language,
|
||||
Filters: query.QueryFilters{
|
||||
NameMatches: request.GetString("name_matches", ""),
|
||||
NameExact: request.GetString("name_exact", ""),
|
||||
KindIn: request.GetStringSlice("kind_in", nil),
|
||||
},
|
||||
}
|
||||
|
||||
maxResults := request.GetInt("max_results", 100)
|
||||
paths := request.GetStringSlice("paths", nil)
|
||||
|
||||
// Default to workspace root if no paths specified
|
||||
if len(paths) == 0 {
|
||||
paths = []string{s.cfg.WorkspaceRoot}
|
||||
}
|
||||
|
||||
// Find files to search based on language
|
||||
ext := languageToExtension(language)
|
||||
if ext == "" {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil
|
||||
}
|
||||
|
||||
var allResults []query.MatchResult
|
||||
|
||||
// Walk through paths and find matching files
|
||||
for _, searchPath := range paths {
|
||||
// Validate path is within workspace
|
||||
if !s.cfg.IsPathAllowed(searchPath) {
|
||||
continue
|
||||
}
|
||||
|
||||
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil // Skip files with errors
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// Skip hidden directories
|
||||
if strings.HasPrefix(info.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check file extension matches language
|
||||
if !strings.HasSuffix(path, ext) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read and parse file
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil // Skip unreadable files
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||
return nil // Skip large files
|
||||
}
|
||||
|
||||
// Parse file
|
||||
result, err := s.parser.Parse(ctx, path, content)
|
||||
if err != nil {
|
||||
return nil // Skip unparseable files
|
||||
}
|
||||
|
||||
// Run query
|
||||
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
|
||||
if err != nil {
|
||||
return nil // Skip on error
|
||||
}
|
||||
|
||||
allResults = append(allResults, matches...)
|
||||
|
||||
// Stop if we have enough results
|
||||
if maxResults > 0 && len(allResults) >= maxResults {
|
||||
return filepath.SkipAll
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn("error walking path", "path", searchPath, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Format and return results
|
||||
output := query.FormatResults(allResults, maxResults)
|
||||
return mcp.NewToolResultText(output), nil
|
||||
}
|
||||
|
||||
// generateASTSummary generates a summary of symbols in the file.
|
||||
func (s *Server) generateASTSummary(ctx context.Context, path string, content []byte) string {
|
||||
// Parse the file
|
||||
result, err := s.parser.Parse(ctx, path, content)
|
||||
if err != nil {
|
||||
return "" // Silently skip AST if parsing fails
|
||||
}
|
||||
|
||||
// Extract symbols
|
||||
lang := protocol.DetectLanguage(path)
|
||||
symbols := parser.ExtractSymbols(result.Tree, content, lang, path)
|
||||
if len(symbols) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Get relative path
|
||||
relPath := path
|
||||
if absPath, err := filepath.Abs(path); err == nil {
|
||||
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
|
||||
relPath = rel
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("**%s** (%d lines, %s)\n\n", relPath, len(splitLines(string(content))), lang))
|
||||
sb.WriteString("Symbols:\n")
|
||||
|
||||
for _, sym := range symbols {
|
||||
kindStr := symbolKindIcon(sym.Kind)
|
||||
sb.WriteString(fmt.Sprintf(" %s %s L%d\n", kindStr, sym.Name, sym.Location.Line))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// symbolKindIcon returns an icon/prefix for a symbol kind.
|
||||
func symbolKindIcon(kind protocol.SymbolKind) string {
|
||||
switch kind {
|
||||
case protocol.SymbolFunction:
|
||||
return "func"
|
||||
case protocol.SymbolMethod:
|
||||
return "meth"
|
||||
case protocol.SymbolClass:
|
||||
return "class"
|
||||
case protocol.SymbolStruct:
|
||||
return "struct"
|
||||
case protocol.SymbolInterface:
|
||||
return "iface"
|
||||
case protocol.SymbolVariable:
|
||||
return "var"
|
||||
case protocol.SymbolConstant:
|
||||
return "const"
|
||||
case protocol.SymbolType:
|
||||
return "type"
|
||||
case protocol.SymbolField:
|
||||
return "field"
|
||||
case protocol.SymbolProperty:
|
||||
return "prop"
|
||||
case protocol.SymbolModule:
|
||||
return "mod"
|
||||
case protocol.SymbolPackage:
|
||||
return "pkg"
|
||||
default:
|
||||
return "sym"
|
||||
}
|
||||
}
|
||||
|
||||
// languageToExtension maps language names to file extensions.
|
||||
func languageToExtension(language string) string {
|
||||
switch strings.ToLower(language) {
|
||||
case "go":
|
||||
return ".go"
|
||||
case "typescript":
|
||||
return ".ts"
|
||||
case "javascript":
|
||||
return ".js"
|
||||
case "python":
|
||||
return ".py"
|
||||
case "c":
|
||||
return ".c"
|
||||
case "cpp", "c++":
|
||||
return ".cpp"
|
||||
case "elixir":
|
||||
return ".ex"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
// Package server implements the MCP server for file operations.
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/edit"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// handleEditPreview handles the edit_preview tool.
|
||||
func (s *Server) handleEditPreview(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return s.handleEdit(ctx, request, false)
|
||||
}
|
||||
|
||||
// handleEditApply handles the edit_apply tool.
|
||||
func (s *Server) handleEditApply(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return s.handleEdit(ctx, request, true)
|
||||
}
|
||||
|
||||
// handleEdit is the shared implementation for edit_preview and edit_apply.
|
||||
func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, apply bool) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
operation, err := request.RequireString("operation")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("operation is required"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Note: We no longer validate language support here.
|
||||
// The edit engine automatically detects whether to use AST or text mode.
|
||||
|
||||
// Build edit request with both AST and text-mode selectors
|
||||
astEdit := &edit.ASTEdit{
|
||||
File: file,
|
||||
Operation: edit.EditOperation(operation),
|
||||
NewContent: request.GetString("new_content", ""),
|
||||
Selector: edit.ASTSelector{
|
||||
// AST-mode selectors
|
||||
Kind: request.GetString("selector_kind", ""),
|
||||
Name: request.GetString("selector_name", ""),
|
||||
AtLine: request.GetInt("selector_line", 0),
|
||||
Index: request.GetInt("selector_index", 0),
|
||||
// Text-mode selectors
|
||||
LineEnd: request.GetInt("selector_line_end", 0),
|
||||
Text: request.GetString("selector_text", ""),
|
||||
TextPattern: request.GetString("selector_pattern", ""),
|
||||
},
|
||||
}
|
||||
|
||||
// Perform edit
|
||||
var result *edit.EditResult
|
||||
if apply {
|
||||
result, err = s.editor.Apply(ctx, astEdit)
|
||||
} else {
|
||||
result, err = s.editor.Preview(ctx, astEdit)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("edit failed: %s", errors.SanitizeError(err))), nil
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
return mcp.NewToolResultError(result.Error), nil
|
||||
}
|
||||
|
||||
// Format output
|
||||
var output strings.Builder
|
||||
if apply {
|
||||
output.WriteString("**Edit Applied Successfully**\n\n")
|
||||
} else {
|
||||
output.WriteString("**Edit Preview**\n\n")
|
||||
}
|
||||
|
||||
output.WriteString("Diff:\n```diff\n")
|
||||
output.WriteString(result.Diff)
|
||||
output.WriteString("```\n")
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
// Package server implements the MCP server for file operations.
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/search"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// handleFileSearch handles the file_search tool.
|
||||
func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
s.logger.Debug("file_search completed",
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
}()
|
||||
|
||||
if s.searcher == nil {
|
||||
return mcp.NewToolResultError("ripgrep (rg) is not available. Please install it: https://github.com/BurntSushi/ripgrep#installation"), nil
|
||||
}
|
||||
|
||||
// Parse request arguments using SDK helpers
|
||||
pattern, err := request.RequireString("pattern")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("pattern is required"), nil
|
||||
}
|
||||
|
||||
req := &search.Request{
|
||||
Pattern: pattern,
|
||||
Paths: request.GetStringSlice("paths", nil),
|
||||
FileTypes: request.GetStringSlice("file_types", nil),
|
||||
IgnoreCase: request.GetBool("ignore_case", false),
|
||||
Regex: request.GetBool("regex", true),
|
||||
ContextLines: request.GetInt("context_lines", 2),
|
||||
MaxResults: request.GetInt("max_results", 0),
|
||||
}
|
||||
|
||||
// Execute search
|
||||
results, err := s.searcher.Search(ctx, req)
|
||||
if err != nil {
|
||||
s.logger.Warn("search error", "error", err)
|
||||
return mcp.NewToolResultError(fmt.Sprintf("search error: %s", errors.SanitizeError(err))), nil
|
||||
}
|
||||
|
||||
s.logger.Info("search completed",
|
||||
"pattern", pattern,
|
||||
"results_count", len(results.Results),
|
||||
"truncated", results.Truncated,
|
||||
)
|
||||
|
||||
// Format results
|
||||
output := s.searcher.FormatResults(results)
|
||||
return mcp.NewToolResultText(output), nil
|
||||
}
|
||||
|
||||
// handleFileRead handles the file_read tool.
|
||||
func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// Acquire semaphore to limit concurrent reads (prevents memory exhaustion)
|
||||
select {
|
||||
case s.readSem <- struct{}{}:
|
||||
defer func() { <-s.readSem }()
|
||||
case <-ctx.Done():
|
||||
return mcp.NewToolResultError("request cancelled"), nil
|
||||
}
|
||||
|
||||
path, err := request.RequireString("path")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("path is required"), nil
|
||||
}
|
||||
|
||||
// Validate path is within workspace
|
||||
if !s.cfg.IsPathAllowed(path) {
|
||||
return mcp.NewToolResultError("path is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Read file
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil
|
||||
}
|
||||
if os.IsPermission(err) {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
|
||||
}
|
||||
s.logger.Warn("file read error", "path", path, "error", err)
|
||||
return mcp.NewToolResultError("error reading file"), nil
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil
|
||||
}
|
||||
|
||||
// Handle line range
|
||||
lines := splitLines(string(content))
|
||||
lineStart := request.GetInt("line_start", 1)
|
||||
lineEnd := request.GetInt("line_end", len(lines))
|
||||
|
||||
// Clamp to valid range
|
||||
if lineStart < 1 {
|
||||
lineStart = 1
|
||||
}
|
||||
if lineEnd > len(lines) {
|
||||
lineEnd = len(lines)
|
||||
}
|
||||
if lineStart > lineEnd {
|
||||
lineStart = lineEnd
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
|
||||
// Include AST summary if requested
|
||||
includeAST := request.GetBool("include_ast", false)
|
||||
symbolsOnly := request.GetBool("symbols_only", false)
|
||||
maxLines := request.GetInt("max_lines", 0)
|
||||
|
||||
// Validate symbols_only requires include_ast
|
||||
if symbolsOnly && !includeAST {
|
||||
return mcp.NewToolResultError("symbols_only requires include_ast=true"), nil
|
||||
}
|
||||
|
||||
if includeAST {
|
||||
astSummary := s.generateASTSummary(ctx, path, content)
|
||||
if astSummary != "" {
|
||||
output.WriteString(astSummary)
|
||||
if !symbolsOnly {
|
||||
output.WriteString("\n---\n\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip file content if symbols_only mode
|
||||
if !symbolsOnly {
|
||||
// Apply max_lines limit if specified
|
||||
effectiveEnd := lineEnd
|
||||
if maxLines > 0 && (lineEnd-lineStart+1) > maxLines {
|
||||
effectiveEnd = lineStart + maxLines - 1
|
||||
if effectiveEnd < lineEnd {
|
||||
// Add note that output was truncated
|
||||
defer func() {
|
||||
output.WriteString(fmt.Sprintf("\n[... %d more lines omitted for token efficiency. Use line_start/line_end or increase max_lines to see more]\n", lineEnd-effectiveEnd))
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Extract requested lines
|
||||
for i := lineStart - 1; i < effectiveEnd && i < len(lines); i++ {
|
||||
output.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, lines[i]))
|
||||
}
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// splitLines splits a string into lines.
|
||||
// For large files (> 1MB), uses bufio.Scanner which is more memory efficient.
|
||||
// For smaller files, uses simple string split which is faster.
|
||||
func splitLines(s string) []string {
|
||||
const largeSizeThreshold = 1024 * 1024 // 1MB
|
||||
|
||||
if len(s) > largeSizeThreshold {
|
||||
// Use scanner for large files
|
||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||
var lines []string
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
// Handle potential error and add empty line if string ended with newline
|
||||
if len(s) > 0 && s[len(s)-1] == '\n' {
|
||||
lines = append(lines, "")
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
// Use optimized stdlib implementation for smaller files (2-3x faster than manual loop)
|
||||
return strings.Split(s, "\n")
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
// Package server implements the MCP server for file operations.
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/lsp"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// handleSymbolAt handles the symbol_at tool.
|
||||
func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Try LSP hover
|
||||
hover, err := s.lspManager.Hover(ctx, file, line, col)
|
||||
if err != nil {
|
||||
// Fall back to AST-based info
|
||||
return s.handleSymbolAtFallback(ctx, file, line, col)
|
||||
}
|
||||
|
||||
if hover == nil {
|
||||
return mcp.NewToolResultText("No symbol information available at this position."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("**Symbol Information**\n\n")
|
||||
output.WriteString(hover.Contents.Value)
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable.
|
||||
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %s", errors.SanitizeError(err))), nil
|
||||
}
|
||||
|
||||
result, err := s.parser.Parse(ctx, file, content)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to parse file: %s", errors.SanitizeError(err))), nil
|
||||
}
|
||||
|
||||
node := parser.FindNodeAtPosition(result.Tree, line, col)
|
||||
if node == nil {
|
||||
return mcp.NewToolResultText("No symbol at this position."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("**Symbol Information** (AST fallback)\n\n")
|
||||
output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type()))
|
||||
output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content)))
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleFindDefinition handles the find_definition tool.
|
||||
func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
locations, err := s.lspManager.Definition(ctx, file, line, col)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %s", errors.SanitizeError(err))), nil
|
||||
}
|
||||
|
||||
if len(locations) == 0 {
|
||||
return mcp.NewToolResultText("No definition found."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
|
||||
|
||||
for _, loc := range locations {
|
||||
filePath := lsp.URIToFile(loc.URI)
|
||||
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
|
||||
// Try to read a preview snippet
|
||||
preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3)
|
||||
if preview != "" {
|
||||
output.WriteString("```\n")
|
||||
output.WriteString(preview)
|
||||
output.WriteString("```\n")
|
||||
}
|
||||
output.WriteString("\n")
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleFindReferences handles the find_references tool.
|
||||
func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
includeDecl := request.GetBool("include_declaration", true)
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
locations, err := s.lspManager.References(ctx, file, line, col, includeDecl)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %s", errors.SanitizeError(err))), nil
|
||||
}
|
||||
|
||||
if len(locations) == 0 {
|
||||
return mcp.NewToolResultText("No references found."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations)))
|
||||
|
||||
// Group by file
|
||||
fileGroups := make(map[string][]lsp.Location)
|
||||
for _, loc := range locations {
|
||||
filePath := lsp.URIToFile(loc.URI)
|
||||
fileGroups[filePath] = append(fileGroups[filePath], loc)
|
||||
}
|
||||
|
||||
for filePath, locs := range fileGroups {
|
||||
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
|
||||
for _, loc := range locs {
|
||||
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
}
|
||||
output.WriteString("\n")
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// readFilePreview reads a few lines from a file around the given line.
|
||||
func readFilePreview(file string, line, contextLines int) string {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
lines := splitLines(string(content))
|
||||
startLine := max(1, line-contextLines)
|
||||
endLine := min(line+contextLines, len(lines))
|
||||
|
||||
var preview strings.Builder
|
||||
for i := startLine - 1; i < endLine && i < len(lines); i++ {
|
||||
lineText := lines[i]
|
||||
if len(lineText) > PreviewLineMaxLength {
|
||||
lineText = lineText[:PreviewLineMaxLength] + "..."
|
||||
}
|
||||
prefix := " "
|
||||
if i+1 == line {
|
||||
prefix = "> "
|
||||
}
|
||||
preview.WriteString(fmt.Sprintf("%s%4d: %s\n", prefix, i+1, lineText))
|
||||
}
|
||||
|
||||
return preview.String()
|
||||
}
|
||||
@@ -52,6 +52,7 @@ func Hello() string {
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("handlePing() returned nil")
|
||||
return
|
||||
}
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("handlePing() returned empty content")
|
||||
@@ -70,6 +71,7 @@ func Hello() string {
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("handleFileRead() returned nil")
|
||||
return
|
||||
}
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("handleFileRead() returned empty content")
|
||||
@@ -90,6 +92,7 @@ func Hello() string {
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("handleASTQuery() returned nil")
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
@@ -110,6 +113,7 @@ func Hello() string {
|
||||
}
|
||||
if previewResult == nil {
|
||||
t.Fatal("handleEditPreview() returned nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify file unchanged after preview
|
||||
@@ -127,6 +131,7 @@ func Hello() string {
|
||||
}
|
||||
if applyResult == nil {
|
||||
t.Fatal("handleEditApply() returned nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify file changed after apply
|
||||
@@ -352,6 +357,7 @@ func Add(a, b int) int {
|
||||
}
|
||||
if readResult == nil {
|
||||
t.Fatal("handleFileRead() returned nil")
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Query AST
|
||||
@@ -367,6 +373,7 @@ func Add(a, b int) int {
|
||||
}
|
||||
if queryResult == nil {
|
||||
t.Fatal("handleASTQuery() returned nil")
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Preview edit
|
||||
@@ -384,6 +391,7 @@ func Add(a, b int) int {
|
||||
}
|
||||
if editResult == nil {
|
||||
t.Fatal("handleEditPreview() returned nil")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+23
-650
@@ -2,14 +2,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -19,11 +15,22 @@ import (
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/search"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// MaxConcurrentReads limits concurrent file read operations to prevent memory exhaustion.
|
||||
const MaxConcurrentReads = 10
|
||||
|
||||
// MaxConcurrentQueries limits concurrent AST query operations to prevent CPU exhaustion.
|
||||
const MaxConcurrentQueries = 5
|
||||
|
||||
// ServerShutdownTimeout is the timeout for graceful server shutdown.
|
||||
const ServerShutdownTimeout = 10 * time.Second
|
||||
|
||||
// PreviewLineMaxLength is the maximum length for preview lines before truncation.
|
||||
const PreviewLineMaxLength = 100
|
||||
|
||||
// Server represents the MCP file operations server.
|
||||
type Server struct {
|
||||
cfg *config.Config
|
||||
@@ -34,17 +41,21 @@ type Server struct {
|
||||
matcher *query.Matcher
|
||||
lspManager *lsp.Manager
|
||||
editor *edit.Engine
|
||||
readSem chan struct{} // Semaphore for limiting concurrent file reads
|
||||
querySem chan struct{} // Semaphore for limiting concurrent AST queries
|
||||
}
|
||||
|
||||
// New creates a new MCP server instance.
|
||||
func New(cfg *config.Config, logger *slog.Logger) (*Server, error) {
|
||||
parserRegistry := parser.NewRegistryWithSize(cfg.MaxParseSize)
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
parser: parserRegistry,
|
||||
matcher: query.NewMatcher(parserRegistry),
|
||||
editor: edit.NewEngine(parserRegistry),
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
parser: parserRegistry,
|
||||
matcher: query.NewMatcher(parserRegistry),
|
||||
editor: edit.NewEngine(parserRegistry),
|
||||
readSem: make(chan struct{}, MaxConcurrentReads),
|
||||
querySem: make(chan struct{}, MaxConcurrentQueries),
|
||||
}
|
||||
|
||||
// Initialize searcher
|
||||
@@ -341,644 +352,6 @@ func (s *Server) handlePing(ctx context.Context, request mcp.CallToolRequest) (*
|
||||
return mcp.NewToolResultText("pong"), nil
|
||||
}
|
||||
|
||||
// handleFileSearch handles the file_search tool.
|
||||
func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
s.logger.Debug("file_search completed",
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
}()
|
||||
|
||||
if s.searcher == nil {
|
||||
return mcp.NewToolResultError("ripgrep (rg) is not available. Please install it: https://github.com/BurntSushi/ripgrep#installation"), nil
|
||||
}
|
||||
|
||||
// Parse request arguments using SDK helpers
|
||||
pattern, err := request.RequireString("pattern")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("pattern is required"), nil
|
||||
}
|
||||
|
||||
req := &search.Request{
|
||||
Pattern: pattern,
|
||||
Paths: request.GetStringSlice("paths", nil),
|
||||
FileTypes: request.GetStringSlice("file_types", nil),
|
||||
IgnoreCase: request.GetBool("ignore_case", false),
|
||||
Regex: request.GetBool("regex", true),
|
||||
ContextLines: request.GetInt("context_lines", 2),
|
||||
MaxResults: request.GetInt("max_results", 0),
|
||||
}
|
||||
|
||||
// Execute search
|
||||
results, err := s.searcher.Search(ctx, req)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("search error: %v", err)), nil
|
||||
}
|
||||
|
||||
s.logger.Info("search completed",
|
||||
"pattern", pattern,
|
||||
"results_count", len(results.Results),
|
||||
"truncated", results.Truncated,
|
||||
)
|
||||
|
||||
// Format results
|
||||
output := s.searcher.FormatResults(results)
|
||||
return mcp.NewToolResultText(output), nil
|
||||
}
|
||||
|
||||
// handleFileRead handles the file_read tool.
|
||||
func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
path, err := request.RequireString("path")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("path is required"), nil
|
||||
}
|
||||
|
||||
// Validate path is within workspace
|
||||
if !s.cfg.IsPathAllowed(path) {
|
||||
return mcp.NewToolResultError("path is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Read file
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil
|
||||
}
|
||||
if os.IsPermission(err) {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
|
||||
}
|
||||
return mcp.NewToolResultError(fmt.Sprintf("error reading file: %v", err)), nil
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil
|
||||
}
|
||||
|
||||
// Handle line range
|
||||
lines := splitLines(string(content))
|
||||
lineStart := request.GetInt("line_start", 1)
|
||||
lineEnd := request.GetInt("line_end", len(lines))
|
||||
|
||||
// Clamp to valid range
|
||||
if lineStart < 1 {
|
||||
lineStart = 1
|
||||
}
|
||||
if lineEnd > len(lines) {
|
||||
lineEnd = len(lines)
|
||||
}
|
||||
if lineStart > lineEnd {
|
||||
lineStart = lineEnd
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
|
||||
// Include AST summary if requested
|
||||
includeAST := request.GetBool("include_ast", false)
|
||||
symbolsOnly := request.GetBool("symbols_only", false)
|
||||
maxLines := request.GetInt("max_lines", 0)
|
||||
|
||||
// Validate symbols_only requires include_ast
|
||||
if symbolsOnly && !includeAST {
|
||||
return mcp.NewToolResultError("symbols_only requires include_ast=true"), nil
|
||||
}
|
||||
|
||||
if includeAST {
|
||||
astSummary := s.generateASTSummary(ctx, path, content)
|
||||
if astSummary != "" {
|
||||
output.WriteString(astSummary)
|
||||
if !symbolsOnly {
|
||||
output.WriteString("\n---\n\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip file content if symbols_only mode
|
||||
if !symbolsOnly {
|
||||
// Apply max_lines limit if specified
|
||||
effectiveEnd := lineEnd
|
||||
if maxLines > 0 && (lineEnd-lineStart+1) > maxLines {
|
||||
effectiveEnd = lineStart + maxLines - 1
|
||||
if effectiveEnd < lineEnd {
|
||||
// Add note that output was truncated
|
||||
defer func() {
|
||||
output.WriteString(fmt.Sprintf("\n[... %d more lines omitted for token efficiency. Use line_start/line_end or increase max_lines to see more]\n", lineEnd-effectiveEnd))
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Extract requested lines
|
||||
for i := lineStart - 1; i < effectiveEnd && i < len(lines); i++ {
|
||||
output.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, lines[i]))
|
||||
}
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// generateASTSummary generates a summary of symbols in the file.
|
||||
func (s *Server) generateASTSummary(ctx context.Context, path string, content []byte) string {
|
||||
// Parse the file
|
||||
result, err := s.parser.Parse(ctx, path, content)
|
||||
if err != nil {
|
||||
return "" // Silently skip AST if parsing fails
|
||||
}
|
||||
|
||||
// Extract symbols
|
||||
lang := protocol.DetectLanguage(path)
|
||||
symbols := parser.ExtractSymbols(result.Tree, content, lang, path)
|
||||
if len(symbols) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Get relative path
|
||||
relPath := path
|
||||
if absPath, err := filepath.Abs(path); err == nil {
|
||||
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
|
||||
relPath = rel
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("**%s** (%d lines, %s)\n\n", relPath, len(splitLines(string(content))), lang))
|
||||
sb.WriteString("Symbols:\n")
|
||||
|
||||
for _, sym := range symbols {
|
||||
kindStr := symbolKindIcon(sym.Kind)
|
||||
sb.WriteString(fmt.Sprintf(" %s %s L%d\n", kindStr, sym.Name, sym.Location.Line))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// symbolKindIcon returns an icon/prefix for a symbol kind.
|
||||
func symbolKindIcon(kind protocol.SymbolKind) string {
|
||||
switch kind {
|
||||
case protocol.SymbolFunction:
|
||||
return "func"
|
||||
case protocol.SymbolMethod:
|
||||
return "meth"
|
||||
case protocol.SymbolClass:
|
||||
return "class"
|
||||
case protocol.SymbolStruct:
|
||||
return "struct"
|
||||
case protocol.SymbolInterface:
|
||||
return "iface"
|
||||
case protocol.SymbolVariable:
|
||||
return "var"
|
||||
case protocol.SymbolConstant:
|
||||
return "const"
|
||||
case protocol.SymbolType:
|
||||
return "type"
|
||||
case protocol.SymbolField:
|
||||
return "field"
|
||||
case protocol.SymbolProperty:
|
||||
return "prop"
|
||||
case protocol.SymbolModule:
|
||||
return "mod"
|
||||
case protocol.SymbolPackage:
|
||||
return "pkg"
|
||||
default:
|
||||
return "sym"
|
||||
}
|
||||
}
|
||||
|
||||
func splitLines(s string) []string {
|
||||
// For large files (> 1MB), use bufio.Scanner which is more memory efficient
|
||||
// For smaller files, use simple string split which is faster
|
||||
const largeSizeThreshold = 1024 * 1024 // 1MB
|
||||
|
||||
if len(s) > largeSizeThreshold {
|
||||
// Use scanner for large files
|
||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||
var lines []string
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
// Handle potential error and add empty line if string ended with newline
|
||||
if len(s) > 0 && s[len(s)-1] == '\n' {
|
||||
lines = append(lines, "")
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
// Use optimized stdlib implementation for smaller files (2-3x faster than manual loop)
|
||||
return strings.Split(s, "\n")
|
||||
}
|
||||
|
||||
// handleASTQuery handles the ast_query tool.
|
||||
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
pattern, err := request.RequireString("pattern")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("pattern is required"), nil
|
||||
}
|
||||
|
||||
language, err := request.RequireString("language")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("language is required"), nil
|
||||
}
|
||||
|
||||
// Build query
|
||||
astQuery := &query.ASTQuery{
|
||||
Pattern: pattern,
|
||||
Language: language,
|
||||
Filters: query.QueryFilters{
|
||||
NameMatches: request.GetString("name_matches", ""),
|
||||
NameExact: request.GetString("name_exact", ""),
|
||||
KindIn: request.GetStringSlice("kind_in", nil),
|
||||
},
|
||||
}
|
||||
|
||||
maxResults := request.GetInt("max_results", 100)
|
||||
paths := request.GetStringSlice("paths", nil)
|
||||
|
||||
// Default to workspace root if no paths specified
|
||||
if len(paths) == 0 {
|
||||
paths = []string{s.cfg.WorkspaceRoot}
|
||||
}
|
||||
|
||||
// Find files to search based on language
|
||||
ext := languageToExtension(language)
|
||||
if ext == "" {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil
|
||||
}
|
||||
|
||||
var allResults []query.MatchResult
|
||||
|
||||
// Walk through paths and find matching files
|
||||
for _, searchPath := range paths {
|
||||
// Validate path is within workspace
|
||||
if !s.cfg.IsPathAllowed(searchPath) {
|
||||
continue
|
||||
}
|
||||
|
||||
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil // Skip files with errors
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// Skip hidden directories
|
||||
if strings.HasPrefix(info.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check file extension matches language
|
||||
if !strings.HasSuffix(path, ext) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read and parse file
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil // Skip unreadable files
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||
return nil // Skip large files
|
||||
}
|
||||
|
||||
// Parse file
|
||||
result, err := s.parser.Parse(ctx, path, content)
|
||||
if err != nil {
|
||||
return nil // Skip unparseable files
|
||||
}
|
||||
|
||||
// Run query
|
||||
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
|
||||
if err != nil {
|
||||
return nil // Skip on error
|
||||
}
|
||||
|
||||
allResults = append(allResults, matches...)
|
||||
|
||||
// Stop if we have enough results
|
||||
if maxResults > 0 && len(allResults) >= maxResults {
|
||||
return filepath.SkipAll
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn("error walking path", "path", searchPath, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Format and return results
|
||||
output := query.FormatResults(allResults, maxResults)
|
||||
return mcp.NewToolResultText(output), nil
|
||||
}
|
||||
|
||||
// languageToExtension maps language names to file extensions.
|
||||
func languageToExtension(language string) string {
|
||||
switch strings.ToLower(language) {
|
||||
case "go":
|
||||
return ".go"
|
||||
case "typescript":
|
||||
return ".ts"
|
||||
case "javascript":
|
||||
return ".js"
|
||||
case "python":
|
||||
return ".py"
|
||||
case "c":
|
||||
return ".c"
|
||||
case "cpp", "c++":
|
||||
return ".cpp"
|
||||
case "elixir":
|
||||
return ".ex"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// handleSymbolAt handles the symbol_at tool.
|
||||
func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Try LSP hover
|
||||
hover, err := s.lspManager.Hover(ctx, file, line, col)
|
||||
if err != nil {
|
||||
// Fall back to AST-based info
|
||||
return s.handleSymbolAtFallback(ctx, file, line, col)
|
||||
}
|
||||
|
||||
if hover == nil {
|
||||
return mcp.NewToolResultText("No symbol information available at this position."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("**Symbol Information**\n\n")
|
||||
output.WriteString(hover.Contents.Value)
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable.
|
||||
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %v", err)), nil
|
||||
}
|
||||
|
||||
result, err := s.parser.Parse(ctx, file, content)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to parse file: %v", err)), nil
|
||||
}
|
||||
|
||||
node := parser.FindNodeAtPosition(result.Tree, line, col)
|
||||
if node == nil {
|
||||
return mcp.NewToolResultText("No symbol at this position."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("**Symbol Information** (AST fallback)\n\n")
|
||||
output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type()))
|
||||
output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content)))
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleFindDefinition handles the find_definition tool.
|
||||
func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
locations, err := s.lspManager.Definition(ctx, file, line, col)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %v", err)), nil
|
||||
}
|
||||
|
||||
if len(locations) == 0 {
|
||||
return mcp.NewToolResultText("No definition found."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
|
||||
|
||||
for _, loc := range locations {
|
||||
filePath := lsp.URIToFile(loc.URI)
|
||||
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
|
||||
// Try to read a preview snippet
|
||||
preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3)
|
||||
if preview != "" {
|
||||
output.WriteString("```\n")
|
||||
output.WriteString(preview)
|
||||
output.WriteString("```\n")
|
||||
}
|
||||
output.WriteString("\n")
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleFindReferences handles the find_references tool.
|
||||
func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
includeDecl := request.GetBool("include_declaration", true)
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
locations, err := s.lspManager.References(ctx, file, line, col, includeDecl)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %v", err)), nil
|
||||
}
|
||||
|
||||
if len(locations) == 0 {
|
||||
return mcp.NewToolResultText("No references found."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations)))
|
||||
|
||||
// Group by file
|
||||
fileGroups := make(map[string][]lsp.Location)
|
||||
for _, loc := range locations {
|
||||
filePath := lsp.URIToFile(loc.URI)
|
||||
fileGroups[filePath] = append(fileGroups[filePath], loc)
|
||||
}
|
||||
|
||||
for filePath, locs := range fileGroups {
|
||||
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
|
||||
for _, loc := range locs {
|
||||
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
}
|
||||
output.WriteString("\n")
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// readFilePreview reads a few lines from a file around the given line.
|
||||
func readFilePreview(file string, line, contextLines int) string {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
lines := splitLines(string(content))
|
||||
startLine := max(1, line-contextLines)
|
||||
endLine := min(line+contextLines, len(lines))
|
||||
|
||||
var preview strings.Builder
|
||||
for i := startLine - 1; i < endLine && i < len(lines); i++ {
|
||||
lineText := lines[i]
|
||||
if len(lineText) > 100 {
|
||||
lineText = lineText[:100] + "..."
|
||||
}
|
||||
prefix := " "
|
||||
if i+1 == line {
|
||||
prefix = "> "
|
||||
}
|
||||
preview.WriteString(fmt.Sprintf("%s%4d: %s\n", prefix, i+1, lineText))
|
||||
}
|
||||
|
||||
return preview.String()
|
||||
}
|
||||
|
||||
// handleEditPreview handles the edit_preview tool.
|
||||
func (s *Server) handleEditPreview(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return s.handleEdit(ctx, request, false)
|
||||
}
|
||||
|
||||
// handleEditApply handles the edit_apply tool.
|
||||
func (s *Server) handleEditApply(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return s.handleEdit(ctx, request, true)
|
||||
}
|
||||
|
||||
// handleEdit is the shared implementation for edit_preview and edit_apply.
|
||||
func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, apply bool) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
operation, err := request.RequireString("operation")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("operation is required"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Note: We no longer validate language support here.
|
||||
// The edit engine automatically detects whether to use AST or text mode.
|
||||
|
||||
// Build edit request with both AST and text-mode selectors
|
||||
astEdit := &edit.ASTEdit{
|
||||
File: file,
|
||||
Operation: edit.EditOperation(operation),
|
||||
NewContent: request.GetString("new_content", ""),
|
||||
Selector: edit.ASTSelector{
|
||||
// AST-mode selectors
|
||||
Kind: request.GetString("selector_kind", ""),
|
||||
Name: request.GetString("selector_name", ""),
|
||||
AtLine: request.GetInt("selector_line", 0),
|
||||
Index: request.GetInt("selector_index", 0),
|
||||
// Text-mode selectors
|
||||
LineEnd: request.GetInt("selector_line_end", 0),
|
||||
Text: request.GetString("selector_text", ""),
|
||||
TextPattern: request.GetString("selector_pattern", ""),
|
||||
},
|
||||
}
|
||||
|
||||
// Perform edit
|
||||
var result *edit.EditResult
|
||||
if apply {
|
||||
result, err = s.editor.Apply(ctx, astEdit)
|
||||
} else {
|
||||
result, err = s.editor.Preview(ctx, astEdit)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("edit failed: %v", err)), nil
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
return mcp.NewToolResultError(result.Error), nil
|
||||
}
|
||||
|
||||
// Format output
|
||||
var output strings.Builder
|
||||
if apply {
|
||||
output.WriteString("**Edit Applied Successfully**\n\n")
|
||||
} else {
|
||||
output.WriteString("**Edit Preview**\n\n")
|
||||
}
|
||||
|
||||
output.WriteString("Diff:\n```diff\n")
|
||||
output.WriteString(result.Diff)
|
||||
output.WriteString("```\n")
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// Run starts the MCP server and blocks until shutdown.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
// Set up signal handling for graceful shutdown
|
||||
@@ -1007,7 +380,7 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
s.logger.Info("received shutdown signal", "signal", sig)
|
||||
|
||||
// Create timeout context for shutdown
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), ServerShutdownTimeout)
|
||||
defer shutdownCancel()
|
||||
|
||||
// Call graceful shutdown
|
||||
@@ -1025,7 +398,7 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
|
||||
case <-ctx.Done():
|
||||
// Context cancelled externally
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), ServerShutdownTimeout)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := s.Shutdown(shutdownCtx); err != nil {
|
||||
|
||||
@@ -29,8 +29,8 @@ func TestNew(t *testing.T) {
|
||||
|
||||
if srv == nil {
|
||||
t.Fatal("New() returned nil server")
|
||||
return
|
||||
}
|
||||
|
||||
if srv.cfg != cfg {
|
||||
t.Error("server config mismatch")
|
||||
}
|
||||
@@ -68,6 +68,7 @@ func TestHandlePing(t *testing.T) {
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handlePing() returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the result contains "pong"
|
||||
@@ -123,6 +124,7 @@ func Hello() {
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleFileRead() returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
contents := result.Content
|
||||
@@ -179,6 +181,7 @@ func Hello() {
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleFileRead() returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
contents := result.Content
|
||||
@@ -270,6 +273,7 @@ func Goodbye() error {
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleASTQuery() returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
contents := result.Content
|
||||
@@ -318,6 +322,7 @@ func Hello() {
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleEdit(preview) returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify file was NOT modified (it's just a preview)
|
||||
@@ -367,6 +372,7 @@ func Hello() {
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleEdit(apply) returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify file WAS modified
|
||||
|
||||
@@ -2,18 +2,74 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxPatternLength is the maximum allowed length for regex patterns.
|
||||
// This prevents memory issues from extremely long patterns.
|
||||
MaxPatternLength = 1000
|
||||
|
||||
// MaxCacheSize is the maximum number of patterns to cache.
|
||||
// When exceeded, the cache is cleared to prevent unbounded memory growth.
|
||||
MaxCacheSize = 10000
|
||||
)
|
||||
|
||||
// regexCache is a global thread-safe cache for compiled regular expressions.
|
||||
// Caching regex compilation provides 10-50x speedup for repeated patterns.
|
||||
var regexCache sync.Map // string -> *regexp.Regexp
|
||||
var (
|
||||
regexCache sync.Map // string -> *regexp.Regexp
|
||||
cacheSize atomic.Int64
|
||||
)
|
||||
|
||||
// CompileRegex compiles a regex pattern with caching for performance.
|
||||
// RegexError represents an error during regex compilation or validation.
|
||||
type RegexError struct {
|
||||
Pattern string
|
||||
Reason string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *RegexError) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("regex error for pattern %q: %s: %v", e.Pattern, e.Reason, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("regex error for pattern %q: %s", e.Pattern, e.Reason)
|
||||
}
|
||||
|
||||
func (e *RegexError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// ValidatePattern validates a regex pattern for safety.
|
||||
// Returns an error if the pattern is too long or appears malicious.
|
||||
func ValidatePattern(pattern string) error {
|
||||
// Check pattern length
|
||||
if len(pattern) > MaxPatternLength {
|
||||
return &RegexError{
|
||||
Pattern: truncatePattern(pattern),
|
||||
Reason: fmt.Sprintf("pattern too long (%d chars, max %d)", len(pattern), MaxPatternLength),
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Go's regexp package uses Thompson NFA which guarantees O(n) matching time,
|
||||
// making it inherently resistant to ReDoS attacks. However, we still validate
|
||||
// pattern length to prevent memory issues during compilation.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompileRegex compiles a regex pattern with caching and validation for security.
|
||||
// Thread-safe: uses LoadOrStore to prevent race conditions.
|
||||
// Returns the compiled regex or an error if the pattern is invalid.
|
||||
// Returns the compiled regex or an error if the pattern is invalid or unsafe.
|
||||
func CompileRegex(pattern string) (*regexp.Regexp, error) {
|
||||
// Validate pattern first
|
||||
if err := ValidatePattern(pattern); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
if cached, ok := regexCache.Load(pattern); ok {
|
||||
return cached.(*regexp.Regexp), nil
|
||||
@@ -22,20 +78,64 @@ func CompileRegex(pattern string) (*regexp.Regexp, error) {
|
||||
// Compile regex
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &RegexError{
|
||||
Pattern: truncatePattern(pattern),
|
||||
Reason: "invalid regex syntax",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache size and clear if too large
|
||||
if cacheSize.Load() >= MaxCacheSize {
|
||||
ClearRegexCache()
|
||||
}
|
||||
|
||||
// Try to store - if another goroutine already stored it, use theirs
|
||||
// This prevents race conditions where multiple goroutines compile the same pattern
|
||||
actual, _ := regexCache.LoadOrStore(pattern, re)
|
||||
actual, loaded := regexCache.LoadOrStore(pattern, re)
|
||||
if !loaded {
|
||||
cacheSize.Add(1)
|
||||
}
|
||||
return actual.(*regexp.Regexp), nil
|
||||
}
|
||||
|
||||
// CompileRegexUncached compiles a regex pattern without caching.
|
||||
// Useful for one-off patterns that shouldn't pollute the cache.
|
||||
func CompileRegexUncached(pattern string) (*regexp.Regexp, error) {
|
||||
if err := ValidatePattern(pattern); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, &RegexError{
|
||||
Pattern: truncatePattern(pattern),
|
||||
Reason: "invalid regex syntax",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return re, nil
|
||||
}
|
||||
|
||||
// ClearRegexCache clears all cached compiled regular expressions.
|
||||
// Useful for testing or when memory usage needs to be reduced.
|
||||
func ClearRegexCache() {
|
||||
regexCache.Range(func(key, value interface{}) bool {
|
||||
regexCache.Range(func(key, _ interface{}) bool {
|
||||
regexCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
cacheSize.Store(0)
|
||||
}
|
||||
|
||||
// CacheStats returns the current number of cached patterns.
|
||||
func CacheStats() int64 {
|
||||
return cacheSize.Load()
|
||||
}
|
||||
|
||||
// truncatePattern truncates a pattern for display in error messages.
|
||||
func truncatePattern(pattern string) string {
|
||||
if len(pattern) > 50 {
|
||||
return pattern[:47] + "..."
|
||||
}
|
||||
return pattern
|
||||
}
|
||||
|
||||
@@ -0,0 +1,375 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidatePattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid short pattern",
|
||||
pattern: "^hello.*world$",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid empty pattern",
|
||||
pattern: "",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid pattern at max length",
|
||||
pattern: strings.Repeat("a", MaxPatternLength),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "pattern too long",
|
||||
pattern: strings.Repeat("a", MaxPatternLength+1),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "very long pattern",
|
||||
pattern: strings.Repeat("x", MaxPatternLength*2),
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidatePattern(tt.pattern)
|
||||
if tt.expectErr && err == nil {
|
||||
t.Error("expected error but got nil")
|
||||
}
|
||||
if !tt.expectErr && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileRegex(t *testing.T) {
|
||||
// Clear cache before each test
|
||||
ClearRegexCache()
|
||||
|
||||
t.Run("valid pattern is compiled and cached", func(t *testing.T) {
|
||||
ClearRegexCache()
|
||||
|
||||
pattern := "^test.*pattern$"
|
||||
re1, err := CompileRegex(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if re1 == nil {
|
||||
t.Fatal("expected non-nil regex")
|
||||
}
|
||||
|
||||
// Second call should return cached version
|
||||
re2, err := CompileRegex(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on second call: %v", err)
|
||||
}
|
||||
|
||||
// Should be the same pointer
|
||||
if re1 != re2 {
|
||||
t.Error("expected same regex instance from cache")
|
||||
}
|
||||
|
||||
// Cache should have one entry
|
||||
if stats := CacheStats(); stats != 1 {
|
||||
t.Errorf("expected cache size 1, got %d", stats)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid pattern returns error", func(t *testing.T) {
|
||||
ClearRegexCache()
|
||||
|
||||
pattern := "[invalid(regex"
|
||||
_, err := CompileRegex(pattern)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid regex")
|
||||
}
|
||||
|
||||
var regexErr *RegexError
|
||||
if !errors.As(err, ®exErr) {
|
||||
t.Errorf("expected RegexError, got %T", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pattern too long returns error", func(t *testing.T) {
|
||||
ClearRegexCache()
|
||||
|
||||
pattern := strings.Repeat("a", MaxPatternLength+1)
|
||||
_, err := CompileRegex(pattern)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for long pattern")
|
||||
}
|
||||
|
||||
var regexErr *RegexError
|
||||
if !errors.As(err, ®exErr) {
|
||||
t.Errorf("expected RegexError, got %T", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different patterns are cached separately", func(t *testing.T) {
|
||||
ClearRegexCache()
|
||||
|
||||
re1, _ := CompileRegex("pattern1")
|
||||
re2, _ := CompileRegex("pattern2")
|
||||
|
||||
if re1 == re2 {
|
||||
t.Error("different patterns should produce different regex instances")
|
||||
}
|
||||
|
||||
if stats := CacheStats(); stats != 2 {
|
||||
t.Errorf("expected cache size 2, got %d", stats)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("regex matches correctly", func(t *testing.T) {
|
||||
ClearRegexCache()
|
||||
|
||||
re, err := CompileRegex("^hello\\s+world$")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !re.MatchString("hello world") {
|
||||
t.Error("expected match for 'hello world'")
|
||||
}
|
||||
if !re.MatchString("hello world") {
|
||||
t.Error("expected match for 'hello world'")
|
||||
}
|
||||
if re.MatchString("helloworld") {
|
||||
t.Error("unexpected match for 'helloworld'")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCompileRegexUncached(t *testing.T) {
|
||||
ClearRegexCache()
|
||||
|
||||
t.Run("valid pattern compiles without caching", func(t *testing.T) {
|
||||
initialSize := CacheStats()
|
||||
|
||||
re, err := CompileRegexUncached("^uncached.*pattern$")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if re == nil {
|
||||
t.Fatal("expected non-nil regex")
|
||||
}
|
||||
|
||||
// Cache size should not change
|
||||
if stats := CacheStats(); stats != initialSize {
|
||||
t.Errorf("cache size changed from %d to %d", initialSize, stats)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid pattern returns error", func(t *testing.T) {
|
||||
_, err := CompileRegexUncached("[invalid")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid regex")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pattern too long returns error", func(t *testing.T) {
|
||||
pattern := strings.Repeat("x", MaxPatternLength+1)
|
||||
_, err := CompileRegexUncached(pattern)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for long pattern")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearRegexCache(t *testing.T) {
|
||||
// Add some patterns
|
||||
_, _ = CompileRegex("pattern1")
|
||||
_, _ = CompileRegex("pattern2")
|
||||
_, _ = CompileRegex("pattern3")
|
||||
|
||||
if stats := CacheStats(); stats < 3 {
|
||||
t.Fatalf("expected at least 3 cached patterns, got %d", stats)
|
||||
}
|
||||
|
||||
ClearRegexCache()
|
||||
|
||||
if stats := CacheStats(); stats != 0 {
|
||||
t.Errorf("expected cache size 0 after clear, got %d", stats)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheStats(t *testing.T) {
|
||||
ClearRegexCache()
|
||||
|
||||
if stats := CacheStats(); stats != 0 {
|
||||
t.Errorf("expected initial cache size 0, got %d", stats)
|
||||
}
|
||||
|
||||
_, _ = CompileRegex("a")
|
||||
if stats := CacheStats(); stats != 1 {
|
||||
t.Errorf("expected cache size 1, got %d", stats)
|
||||
}
|
||||
|
||||
_, _ = CompileRegex("b")
|
||||
if stats := CacheStats(); stats != 2 {
|
||||
t.Errorf("expected cache size 2, got %d", stats)
|
||||
}
|
||||
|
||||
// Same pattern should not increase cache size
|
||||
_, _ = CompileRegex("a")
|
||||
if stats := CacheStats(); stats != 2 {
|
||||
t.Errorf("expected cache size 2 after duplicate, got %d", stats)
|
||||
}
|
||||
}
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
ClearRegexCache()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
numPatterns := 10
|
||||
|
||||
// Generate some patterns
|
||||
patterns := make([]string, numPatterns)
|
||||
for i := range patterns {
|
||||
patterns[i] = strings.Repeat("p", i+1)
|
||||
}
|
||||
|
||||
// Concurrent compilation of same patterns
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
pattern := patterns[id%numPatterns]
|
||||
re, err := CompileRegex(pattern)
|
||||
if err != nil {
|
||||
t.Errorf("goroutine %d: unexpected error: %v", id, err)
|
||||
return
|
||||
}
|
||||
if re == nil {
|
||||
t.Errorf("goroutine %d: nil regex returned", id)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should have exactly numPatterns cached
|
||||
if stats := CacheStats(); stats != int64(numPatterns) {
|
||||
t.Errorf("expected cache size %d, got %d", numPatterns, stats)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexError(t *testing.T) {
|
||||
t.Run("error message with underlying error", func(t *testing.T) {
|
||||
underlying := errors.New("underlying error")
|
||||
err := &RegexError{
|
||||
Pattern: "test.*",
|
||||
Reason: "test reason",
|
||||
Err: underlying,
|
||||
}
|
||||
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "test.*") {
|
||||
t.Error("error message should contain pattern")
|
||||
}
|
||||
if !strings.Contains(msg, "test reason") {
|
||||
t.Error("error message should contain reason")
|
||||
}
|
||||
if !strings.Contains(msg, "underlying error") {
|
||||
t.Error("error message should contain underlying error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error message without underlying error", func(t *testing.T) {
|
||||
err := &RegexError{
|
||||
Pattern: "test.*",
|
||||
Reason: "test reason",
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "test.*") {
|
||||
t.Error("error message should contain pattern")
|
||||
}
|
||||
if !strings.Contains(msg, "test reason") {
|
||||
t.Error("error message should contain reason")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error unwrap", func(t *testing.T) {
|
||||
underlying := errors.New("underlying")
|
||||
err := &RegexError{
|
||||
Pattern: "test",
|
||||
Reason: "reason",
|
||||
Err: underlying,
|
||||
}
|
||||
|
||||
if errors.Unwrap(err) != underlying {
|
||||
t.Error("Unwrap should return underlying error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTruncatePattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "short pattern unchanged",
|
||||
input: "short",
|
||||
expected: "short",
|
||||
},
|
||||
{
|
||||
name: "exactly 50 chars unchanged",
|
||||
input: strings.Repeat("x", 50),
|
||||
expected: strings.Repeat("x", 50),
|
||||
},
|
||||
{
|
||||
name: "long pattern truncated",
|
||||
input: strings.Repeat("x", 60),
|
||||
expected: strings.Repeat("x", 47) + "...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := truncatePattern(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("truncatePattern() = %q (len %d), want %q (len %d)",
|
||||
got, len(got), tt.expected, len(tt.expected))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompileRegex benchmarks regex compilation with caching
|
||||
func BenchmarkCompileRegex(b *testing.B) {
|
||||
ClearRegexCache()
|
||||
pattern := "^test.*pattern\\d+$"
|
||||
|
||||
// First call to populate cache
|
||||
_, _ = CompileRegex(pattern)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = CompileRegex(pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompileRegexUncached benchmarks regex compilation without caching
|
||||
func BenchmarkCompileRegexUncached(b *testing.B) {
|
||||
pattern := "^test.*pattern\\d+$"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = CompileRegexUncached(pattern)
|
||||
}
|
||||
}
|
||||
@@ -120,6 +120,42 @@ func (e *StructuredError) WithRemediation(msg string) *StructuredError {
|
||||
return e
|
||||
}
|
||||
|
||||
// UserMessage returns a user-safe error message without internal details.
|
||||
// This should be used for API responses to avoid leaking implementation details.
|
||||
func (e *StructuredError) UserMessage() string {
|
||||
// Return only the message and remediation, no stack trace, no cause details
|
||||
if e.Remediation != "" {
|
||||
return fmt.Sprintf("%s. %s", e.Message, e.Remediation)
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// SanitizeError returns a user-safe error message from any error.
|
||||
// For StructuredError, it returns UserMessage().
|
||||
// For other errors, it returns a generic message to avoid leaking internals.
|
||||
func SanitizeError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if it's a StructuredError
|
||||
if se, ok := err.(*StructuredError); ok {
|
||||
return se.UserMessage()
|
||||
}
|
||||
|
||||
// For other errors, extract only the basic message
|
||||
// Avoid exposing full paths or implementation details
|
||||
msg := err.Error()
|
||||
|
||||
// Truncate very long messages that might contain stack traces
|
||||
const maxLen = 200
|
||||
if len(msg) > maxLen {
|
||||
msg = msg[:maxLen] + "..."
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// New creates a new structured error with stack trace.
|
||||
func New(code ErrorCode, message string) *StructuredError {
|
||||
return &StructuredError{
|
||||
|
||||
+3
-13
@@ -1,6 +1,8 @@
|
||||
// Package protocol defines shared types used across the MCP file operations server.
|
||||
package protocol
|
||||
|
||||
import "path/filepath"
|
||||
|
||||
// Location represents a position in a file.
|
||||
type Location struct {
|
||||
File string `json:"file"`
|
||||
@@ -66,7 +68,7 @@ const (
|
||||
|
||||
// DetectLanguage detects the language from a filename.
|
||||
func DetectLanguage(filename string) Language {
|
||||
ext := getExtension(filename)
|
||||
ext := filepath.Ext(filename)
|
||||
switch ext {
|
||||
case ".go":
|
||||
return LangGo
|
||||
@@ -94,15 +96,3 @@ func DetectLanguage(filename string) Language {
|
||||
return LangUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func getExtension(filename string) string {
|
||||
for i := len(filename) - 1; i >= 0; i-- {
|
||||
if filename[i] == '.' {
|
||||
return filename[i:]
|
||||
}
|
||||
if filename[i] == '/' || filename[i] == '\\' {
|
||||
break
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -48,26 +48,3 @@ func TestDetectLanguage(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExtension(t *testing.T) {
|
||||
tests := []struct {
|
||||
filename string
|
||||
expected string
|
||||
}{
|
||||
{"file.go", ".go"},
|
||||
{"file.test.go", ".go"},
|
||||
{"path/to/file.ts", ".ts"},
|
||||
{"noextension", ""},
|
||||
{".hidden", ".hidden"},
|
||||
{"file.", "."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.filename, func(t *testing.T) {
|
||||
result := getExtension(tt.filename)
|
||||
if result != tt.expected {
|
||||
t.Errorf("getExtension(%q) = %q, want %q", tt.filename, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user