From ccfbdc513f4512d75ca71d8490f5948afc89f36c Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Thu, 12 Mar 2026 19:21:13 +0000 Subject: [PATCH] fixup! fixup! fixup! fixup! Add Docker usage instructions to README --- internal/parser/symbols.go | 69 ++++++++++++ internal/server/handlers_edit.go | 43 ++++---- internal/server/handlers_file.go | 177 ++++++++++++++++++++++--------- internal/server/server.go | 49 +++++++-- 4 files changed, 258 insertions(+), 80 deletions(-) diff --git a/internal/parser/symbols.go b/internal/parser/symbols.go index e97516c..494ab20 100644 --- a/internal/parser/symbols.go +++ b/internal/parser/symbols.go @@ -878,3 +878,72 @@ func extractElixirImpl(n *sitter.Node, content []byte, filename string) *protoco Location: NodeLocation(n, filename), } } + +// kindMatchesNode returns true if the given SymbolKind matches the node type. +// An empty kind matches all symbol-bearing nodes. +func kindMatchesNode(kind protocol.SymbolKind, nodeType string) bool { + if kind == "" { + return true + } + switch kind { + case protocol.SymbolFunction: + return nodeType == "function_declaration" || nodeType == "function_definition" || nodeType == "function_item" + case protocol.SymbolMethod: + return nodeType == "method_declaration" || nodeType == "method_definition" + case protocol.SymbolClass: + return nodeType == "class_declaration" || nodeType == "class_definition" + case protocol.SymbolStruct: + return nodeType == "struct_item" || nodeType == "type_declaration" + case protocol.SymbolInterface: + return nodeType == "interface_declaration" + case protocol.SymbolType: + return nodeType == "type_declaration" || nodeType == "type_alias_declaration" || nodeType == "type_item" + case protocol.SymbolEnum: + return nodeType == "enum_item" + case protocol.SymbolTrait: + return nodeType == "trait_item" + case protocol.SymbolConstant: + return nodeType == "const_item" + case protocol.SymbolModule: + return nodeType == "mod_item" + } + return true +} + +// FindSymbolRange finds a named symbol in the AST and returns its line range (1-indexed, inclusive). +// symbolKind filters by symbol kind (e.g. "function", "struct"); empty string matches any kind. +// Returns (0, 0, false) if the symbol is not found. +func FindSymbolRange(tree *sitter.Tree, content []byte, filename, symbolName string, symbolKind protocol.SymbolKind) (startLine, endLine int, found bool) { + if tree == nil || symbolName == "" { + return + } + WalkTree(tree.RootNode(), func(n *sitter.Node) bool { + if found { + return false + } + nameNode := n.ChildByFieldName("name") + if nameNode == nil { + return true + } + if GetNodeText(nameNode, content) != symbolName { + return true + } + switch n.Type() { + case "function_declaration", "method_declaration", "type_declaration", + "function_definition", "class_definition", "class_declaration", + "interface_declaration", "type_alias_declaration", + "function_item", "struct_item", "enum_item", "trait_item", + "type_item", "const_item", "mod_item": + if !kindMatchesNode(symbolKind, n.Type()) { + return true // kind mismatch, keep searching + } + r := NodeRange(n, filename) + startLine = r.Start.Line + endLine = r.End.Line + found = true + return false + } + return true + }) + return +} diff --git a/internal/server/handlers_edit.go b/internal/server/handlers_edit.go index c14e241..d067915 100644 --- a/internal/server/handlers_edit.go +++ b/internal/server/handlers_edit.go @@ -4,17 +4,18 @@ package server import ( "context" "fmt" + "os" "strings" "github.com/lukaszraczylo/mcp-filepuff/internal/edit" "github.com/lukaszraczylo/mcp-filepuff/pkg/errors" + "github.com/lukaszraczylo/mcp-filepuff/pkg/protocol" "github.com/mark3labs/mcp-go/mcp" ) // unescapeNewlines converts literal \n, \t, \" sequences to actual characters. // This handles cases where MCP clients send double-escaped JSON strings. func unescapeNewlines(s string) string { - // Replace common escape sequences s = strings.ReplaceAll(s, "\\n", "\n") s = strings.ReplaceAll(s, "\\t", "\t") s = strings.ReplaceAll(s, "\\\"", "\"") @@ -39,7 +40,6 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest) (* return mcp.NewToolResultError("operation is required"), nil } - // Validate operation against known values switch edit.EditOperation(operation) { case edit.EditReplace, edit.EditInsertBefore, edit.EditInsertAfter, edit.EditDelete: // valid @@ -49,40 +49,31 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest) (* )), nil } - // Validate path if !s.cfg.IsPathAllowed(file) { return mcp.NewToolResultError("file is outside workspace root"), nil } - // Note: We no longer validate language support here. - // The edit engine automatically detects whether to use AST or text mode. - - // Build edit request with both AST and text-mode selectors newContent := request.GetString("new_content", "") - - // Unescape common escape sequences that may be double-encoded by MCP clients newContent = unescapeNewlines(newContent) + selectorName := request.GetString("selector_name", "") + astEdit := &edit.ASTEdit{ File: file, Operation: edit.EditOperation(operation), NewContent: newContent, Selector: edit.ASTSelector{ - // AST-mode selectors - Kind: request.GetString("selector_kind", ""), - Name: request.GetString("selector_name", ""), - AtLine: request.GetInt("selector_line", 0), - Index: request.GetInt("selector_index", 0), - // Text-mode selectors + Kind: request.GetString("selector_kind", ""), + Name: selectorName, + AtLine: request.GetInt("selector_line", 0), + Index: request.GetInt("selector_index", 0), LineEnd: request.GetInt("selector_line_end", 0), Text: request.GetString("selector_text", ""), TextPattern: request.GetString("selector_pattern", ""), }, } - // Perform edit (always apply) result, err := s.editor.Apply(ctx, astEdit) - if err != nil { return mcp.NewToolResultError(fmt.Sprintf("edit failed: %s", errors.SanitizeError(err))), nil } @@ -91,10 +82,24 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest) (* return mcp.NewToolResultError(result.Error), nil } - // Format output + // compact_response: return just the modified symbol instead of the full diff + if request.GetBool("compact_response", false) && selectorName != "" { + if content, readErr := os.ReadFile(file); readErr == nil { + if start, end, found := s.resolveSymbolLines(ctx, file, content, selectorName, protocol.SymbolKind("")); found { + lines := splitLines(string(content)) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("**Edit Applied** — %s (L%d-L%d):\n\n", selectorName, start, end)) + for i := start - 1; i < end && i < len(lines); i++ { + sb.WriteString(fmt.Sprintf("%4d| %s\n", i+1, lines[i])) + } + return mcp.NewToolResultText(sb.String()), nil + } + } + // fall through to diff if symbol lookup fails + } + var output strings.Builder output.WriteString("**Edit Applied Successfully**\n\n") - output.WriteString("Diff:\n```diff\n") output.WriteString(result.Diff) output.WriteString("```\n") diff --git a/internal/server/handlers_file.go b/internal/server/handlers_file.go index d7df5b9..a11c8c3 100644 --- a/internal/server/handlers_file.go +++ b/internal/server/handlers_file.go @@ -9,8 +9,11 @@ import ( "strings" "time" + xxhash "github.com/cespare/xxhash/v2" + "github.com/lukaszraczylo/mcp-filepuff/internal/parser" "github.com/lukaszraczylo/mcp-filepuff/internal/search" "github.com/lukaszraczylo/mcp-filepuff/pkg/errors" + "github.com/lukaszraczylo/mcp-filepuff/pkg/protocol" "github.com/mark3labs/mcp-go/mcp" ) @@ -27,7 +30,6 @@ func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolReque return mcp.NewToolResultError("ripgrep (rg) is not available. Please install it: https://github.com/BurntSushi/ripgrep#installation"), nil } - // Parse request arguments using SDK helpers pattern, err := request.RequireString("pattern") if err != nil { return mcp.NewToolResultError("pattern is required"), nil @@ -43,7 +45,6 @@ func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolReque MaxResults: request.GetInt("max_results", 0), } - // Execute search results, err := s.searcher.Search(ctx, req) if err != nil { s.logger.Warn("search error", "error", err) @@ -56,14 +57,12 @@ func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolReque "truncated", results.Truncated, ) - // Format results output := s.searcher.FormatResults(results) return mcp.NewToolResultText(output), nil } // handleFileRead handles the file_read tool. func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // Acquire semaphore to limit concurrent reads (prevents memory exhaustion) select { case s.readSem <- struct{}{}: defer func() { <-s.readSem }() @@ -71,47 +70,104 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest return mcp.NewToolResultError("request cancelled"), nil } - path, err := request.RequireString("path") + // Batch mode: paths[] takes precedence over path + if paths := request.GetStringSlice("paths", nil); len(paths) > 0 { + var output strings.Builder + for i, p := range paths { + if i > 0 { + output.WriteString("\n") + } + result, err := s.readOneFile(ctx, request, p) + if err != nil { + output.WriteString(fmt.Sprintf("--- %s ---\n[error: %s]\n", p, errors.SanitizeError(err))) + continue + } + output.WriteString(fmt.Sprintf("--- %s ---\n%s", p, result)) + } + return mcp.NewToolResultText(output.String()), nil + } + + path := request.GetString("path", "") + if path == "" { + return mcp.NewToolResultError("path or paths is required"), nil + } + + result, err := s.readOneFile(ctx, request, path) if err != nil { - return mcp.NewToolResultError("path is required"), nil + return mcp.NewToolResultError(errors.SanitizeError(err)), nil } + return mcp.NewToolResultText(result), nil +} - // Validate path is within workspace +// readOneFile reads a single file applying all formatting options from the request. +func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, path string) (string, error) { if !s.cfg.IsPathAllowed(path) { - return mcp.NewToolResultError("path is outside workspace root"), nil + return "", fmt.Errorf("path is outside workspace root") } - // Check file size before reading to avoid loading huge files into memory info, err := os.Stat(path) if err != nil { if os.IsNotExist(err) { - return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil + return "", fmt.Errorf("file not found: %s", path) } if os.IsPermission(err) { - return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil + return "", fmt.Errorf("permission denied: %s", path) } s.logger.Warn("file stat error", "path", path, "error", err) - return mcp.NewToolResultError("error accessing file"), nil + return "", fmt.Errorf("error accessing file") } if info.Size() > s.cfg.MaxFileSize { - return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize)), nil + return "", fmt.Errorf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize) } - // Read file content, err := os.ReadFile(path) if err != nil { if os.IsPermission(err) { - return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil + return "", fmt.Errorf("permission denied: %s", path) } s.logger.Warn("file read error", "path", path, "error", err) - return mcp.NewToolResultError("error reading file"), nil + return "", fmt.Errorf("error reading file") + } + + // Compute etag from content hash + etag := fmt.Sprintf("%016x", xxhash.Sum64(content)) + + // Short-circuit if caller has the current version + if prev := request.GetString("previous_etag", ""); prev != "" && prev == etag { + return fmt.Sprintf("[unchanged, etag: %s]\n", etag), nil + } + + // Parse request options + includeAST := request.GetBool("include_ast", false) + symbolsOnly := request.GetBool("symbols_only", false) + symbolName := request.GetString("symbol_name", "") + noLineNumbers := request.GetBool("no_line_numbers", false) + lineInterval := request.GetInt("line_number_interval", 1) + collapseBlank := request.GetBool("collapse_blank_lines", false) + maxLines := request.GetInt("max_lines", 0) + + if lineInterval == 0 { + noLineNumbers = true + } + if symbolsOnly && !includeAST { + return "", fmt.Errorf("symbols_only requires include_ast=true") } - // Handle line range lines := splitLines(string(content)) lineStart := request.GetInt("line_start", 1) lineEnd := request.GetInt("line_end", len(lines)) + // Symbol-based line range: find the symbol and use its exact bounds + if symbolName != "" { + symbolKind := protocol.SymbolKind(request.GetString("symbol_kind", "")) + start, end, found := s.resolveSymbolLines(ctx, path, content, symbolName, symbolKind) + if !found { + return "", fmt.Errorf("symbol %q not found in %s", symbolName, path) + } + lineStart = start + lineEnd = end + } + // Clamp to valid range if lineStart < 1 { lineStart = 1 @@ -125,47 +181,68 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest var output strings.Builder - // Include AST summary if requested - includeAST := request.GetBool("include_ast", false) - symbolsOnly := request.GetBool("symbols_only", false) - maxLines := request.GetInt("max_lines", 0) - - // Validate symbols_only requires include_ast - if symbolsOnly && !includeAST { - return mcp.NewToolResultError("symbols_only requires include_ast=true"), nil - } - if includeAST { - astSummary := s.generateASTSummary(ctx, path, content) - if astSummary != "" { - output.WriteString(astSummary) + if summary := s.generateASTSummary(ctx, path, content); summary != "" { + output.WriteString(summary) if !symbolsOnly { output.WriteString("\n---\n\n") } } } - // Skip file content if symbols_only mode - if !symbolsOnly { - // Apply max_lines limit if specified - effectiveEnd := lineEnd - if maxLines > 0 && (lineEnd-lineStart+1) > maxLines { - effectiveEnd = lineStart + maxLines - 1 - if effectiveEnd < lineEnd { - // Add note that output was truncated - defer func() { - output.WriteString(fmt.Sprintf("\n[... %d more lines omitted for token efficiency. Use line_start/line_end or increase max_lines to see more]\n", lineEnd-effectiveEnd)) - }() - } - } + if symbolsOnly { + output.WriteString(fmt.Sprintf("[etag: %s]\n", etag)) + return output.String(), nil + } - // Extract requested lines - for i := lineStart - 1; i < effectiveEnd && i < len(lines); i++ { - output.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, lines[i])) + writeLines(&output, lines, lineStart, lineEnd, maxLines, noLineNumbers, lineInterval, collapseBlank) + + output.WriteString(fmt.Sprintf("[etag: %s]\n", etag)) + return output.String(), nil +} + +// resolveSymbolLines parses the AST and returns the line range of the named symbol. +// symbolKind optionally filters by kind (empty = any). +func (s *Server) resolveSymbolLines(ctx context.Context, path string, content []byte, symbolName string, symbolKind protocol.SymbolKind) (startLine, endLine int, found bool) { + result, err := s.parser.Parse(ctx, path, content) + if err != nil { + return + } + return parser.FindSymbolRange(result.Tree, content, path, symbolName, symbolKind) +} + +// writeLines writes the selected line range into output, applying all formatting options. +func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, maxLines int, noLineNumbers bool, lineInterval int, collapseBlank bool) { + effectiveEnd := lineEnd + truncatedCount := 0 + if maxLines > 0 && (lineEnd-lineStart+1) > maxLines { + effectiveEnd = lineStart + maxLines - 1 + truncatedCount = lineEnd - effectiveEnd + } + + prevBlank := false + for i := lineStart - 1; i < effectiveEnd && i < len(lines); i++ { + line := lines[i] + isBlank := strings.TrimSpace(line) == "" + if collapseBlank && isBlank && prevBlank { + continue + } + prevBlank = isBlank + + lineNum := i + 1 + switch { + case noLineNumbers: + output.WriteString(line + "\n") + case lineInterval <= 1 || lineNum%lineInterval == 0 || i == lineStart-1 || i == effectiveEnd-1: + fmt.Fprintf(output, "%4d│ %s\n", lineNum, line) + default: + fmt.Fprintf(output, " │ %s\n", line) } } - return mcp.NewToolResultText(output.String()), nil + if truncatedCount > 0 { + fmt.Fprintf(output, "\n[... %d more lines omitted. Use line_start/line_end or increase max_lines to see more]\n", truncatedCount) + } } // splitLines splits a string into lines. @@ -175,24 +252,20 @@ func splitLines(s string) []string { const largeSizeThreshold = 1024 * 1024 // 1MB if len(s) > largeSizeThreshold { - // Use scanner for large files with increased buffer for long lines scanner := bufio.NewScanner(strings.NewReader(s)) - scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), 1024*1024) // up to 1MB per line + scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), 1024*1024) var lines []string for scanner.Scan() { lines = append(lines, scanner.Text()) } if err := scanner.Err(); err != nil { - // If scanning fails (e.g. line exceeds buffer), fall back to strings.Split return strings.Split(s, "\n") } - // Add empty line if string ended with newline if len(s) > 0 && s[len(s)-1] == '\n' { lines = append(lines, "") } return lines } - // Use optimized stdlib implementation for smaller files (2-3x faster than manual loop) return strings.Split(s, "\n") } diff --git a/internal/server/server.go b/internal/server/server.go index 80c038c..de4d913 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -138,19 +138,38 @@ func (s *Server) registerTools() { s.mcp.AddTool( mcp.NewTool("file_read", mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary.\n\n"+ - "Returns: File content with numbered lines (format: \" 12│ line text\"). "+ - "When include_ast=true: prepends symbol summary (\"**file.go** (N lines, go)\\nSymbols:\\n func Name L12\\n struct Config L45\"). "+ - "When symbols_only=true: returns only the symbol summary (~95% fewer tokens). "+ - "When max_lines is set: truncates output with \"[... N more lines omitted]\" notice.\n\n"+ + "Token-saving features:\n"+ + " previous_etag: skip re-reading unchanged files (returns '[unchanged, etag: ...]' if unchanged)\n"+ + " symbol_name: read only a named function/struct/class (eliminates ast_query round-trip)\n"+ + " symbols_only=true: return only symbol list, ~95% fewer tokens (requires include_ast=true)\n"+ + " no_line_numbers=true: strip the line-number prefix (~10%% savings)\n"+ + " line_number_interval=N: print line numbers only every N lines\n"+ + " collapse_blank_lines=true: collapse consecutive blank lines to one\n"+ + " max_lines=N: truncate output with omitted count notice\n"+ + " paths=[...]: read multiple files in one call\n\n"+ + "All responses include '[etag: hex]' footer for use as previous_etag in subsequent reads.\n\n"+ "Examples:\n"+ - " Full file: {\"path\": \"main.go\"}\n"+ - " With AST: {\"path\": \"main.go\", \"include_ast\": true}\n"+ - " Symbols only: {\"path\": \"main.go\", \"include_ast\": true, \"symbols_only\": true}\n"+ + " Full file: {\"path\": \"main.go\"}\n"+ + " Etag check: {\"path\": \"main.go\", \"previous_etag\": \"a3f9c2b1\"}\n"+ + " By symbol: {\"path\": \"server.go\", \"symbol_name\": \"handleFileRead\"}\n"+ + " Batch: {\"paths\": [\"a.go\", \"b.go\"]}\n"+ " Line range: {\"path\": \"main.go\", \"line_start\": 10, \"line_end\": 50}"), mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("path", - mcp.Required(), - mcp.Description("Path to the file to read"), + mcp.Description("Path to the file to read (required unless paths is provided)"), + ), + mcp.WithArray("paths", + mcp.Description("Read multiple files in one call. Each file gets a '--- path ---' header. Overrides path if both provided."), + mcp.WithStringItems(), + ), + mcp.WithString("previous_etag", + mcp.Description("Etag from a previous read of this file. If the file is unchanged, returns '[unchanged, etag: ...]' with no content — saving all content tokens."), + ), + mcp.WithString("symbol_name", + mcp.Description("Read only the named symbol (function, struct, class, etc.) instead of the whole file. Resolves line range via AST — eliminates an ast_query round-trip."), + ), + mcp.WithString("symbol_kind", + mcp.Description("Disambiguate symbol_name by kind when multiple symbols share the same name. Accepted values: function, method, struct, class, interface, type, enum, trait, constant, module."), ), mcp.WithNumber("line_start", mcp.Description("Starting line number (1-indexed)"), @@ -167,6 +186,15 @@ func (s *Server) registerTools() { mcp.WithNumber("max_lines", mcp.Description("Maximum number of lines to return (for token efficiency). Applied after line_start/line_end."), ), + mcp.WithBoolean("no_line_numbers", + mcp.Description("Omit the ' 12│ ' line number prefix entirely. Saves ~10% tokens. line_number_interval=0 has the same effect."), + ), + mcp.WithNumber("line_number_interval", + mcp.Description("Print line numbers only every N lines (default: 1 = every line). E.g. 10 = anchor every 10th line plus first/last. 0 = no line numbers."), + ), + mcp.WithBoolean("collapse_blank_lines", + mcp.Description("Collapse runs of consecutive blank lines to a single blank line. Useful for token savings on heavily-spaced code."), + ), ), s.handleFileRead, ) @@ -334,6 +362,9 @@ func (s *Server) registerTools() { mcp.WithString("selector_pattern", mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."), ), + mcp.WithBoolean("compact_response", + mcp.Description("Return only the modified symbol's content instead of a full diff. Requires selector_name. Saves tokens on large-file edits."), + ), ), s.handleEditApply, )