mirror of
https://github.com/lukaszraczylo/filepuff-mcp.git
synced 2026-06-10 22:59:01 +00:00
V2/token optimization (#11)
* v2.0: token-optimization overhaul Additive (backward-compatible flags): - file_read: skeleton mode, strip (imports/license/block_comments), compact_line_numbers, 8-char etag with prefix-match compat - ast_query: format=verbose|compact|location, pagination cursor - file_search: cluster mode, pagination cursor - lsp_query (references): compact output Breaking (v2): - Preambles removed; opt-in verbose=true restores - edit_apply: response=count|diff|none, default count - ping tool removed - symbol_at/find_definition/find_references merged into lsp_query - Tool descriptions trimmed -83%, help moved to filepuff://help/<tool> - Batch file_read dedups by etag Protocol: - ResourceLink returned for file_read >64 KiB (force_inline override) - OnAfterInitialize hook reads capabilities.experimental.filepuff for session defaults (default_format, default_max_results, default_cluster, compact_refs, line_numbers, resource_link_threshold) * fix: drop --max-total-count from ripgrep args The flag does not exist in stable ripgrep (confirmed up to 15.1.0 -- "unrecognized flag --max-total-count, similar flags that are available: --max-count"). Every file_search call failed on hosts with stock rg. --max-count is per-file, not a drop-in replacement, so rely on the in-process truncation in parseOutput that was already the documented safety net.
This commit is contained in:
+31
-28
@@ -13,43 +13,46 @@ import (
|
||||
|
||||
// Config holds all configuration options for the MCP server.
|
||||
type Config struct {
|
||||
Formatters map[string]string `json:"formatters"`
|
||||
WorkspaceRoot string `json:"workspace_root"`
|
||||
LSPTimeout time.Duration `json:"lsp_timeout"`
|
||||
SearchTimeout time.Duration `json:"search_timeout"`
|
||||
MaxFileSize int64 `json:"max_file_size"`
|
||||
MaxParseSize int64 `json:"max_parse_size"`
|
||||
MaxSearchResults int `json:"max_search_results"`
|
||||
MaxEditSize int64 `json:"max_edit_size"`
|
||||
EnableLSP bool `json:"enable_lsp"`
|
||||
FollowSymlinks bool `json:"follow_symlinks"`
|
||||
RespectGitignore bool `json:"respect_gitignore"`
|
||||
Formatters map[string]string `json:"formatters"`
|
||||
WorkspaceRoot string `json:"workspace_root"`
|
||||
LSPTimeout time.Duration `json:"lsp_timeout"`
|
||||
SearchTimeout time.Duration `json:"search_timeout"`
|
||||
MaxFileSize int64 `json:"max_file_size"`
|
||||
MaxParseSize int64 `json:"max_parse_size"`
|
||||
MaxSearchResults int `json:"max_search_results"`
|
||||
MaxEditSize int64 `json:"max_edit_size"`
|
||||
EnableLSP bool `json:"enable_lsp"`
|
||||
FollowSymlinks bool `json:"follow_symlinks"`
|
||||
RespectGitignore bool `json:"respect_gitignore"`
|
||||
ResourceLinkThresholdBytes int `json:"resource_link_threshold_bytes"`
|
||||
}
|
||||
|
||||
// Default values for configuration.
|
||||
const (
|
||||
DefaultLSPTimeout = 5 * time.Minute
|
||||
DefaultSearchTimeout = 30 * time.Second
|
||||
DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB
|
||||
DefaultMaxParseSize = 10 * 1024 * 1024 // 10 MB
|
||||
DefaultMaxSearchResults = 1000
|
||||
DefaultMaxEditSize = 100 * 1024 // 100 KB
|
||||
DefaultLSPTimeout = 5 * time.Minute
|
||||
DefaultSearchTimeout = 30 * time.Second
|
||||
DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB
|
||||
DefaultMaxParseSize = 10 * 1024 * 1024 // 10 MB
|
||||
DefaultMaxSearchResults = 1000
|
||||
DefaultMaxEditSize = 100 * 1024 // 100 KB
|
||||
DefaultResourceLinkThresholdBytes = 64 * 1024 // 64 KiB
|
||||
)
|
||||
|
||||
// Default returns a Config with default values.
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
WorkspaceRoot: ".",
|
||||
LSPTimeout: DefaultLSPTimeout,
|
||||
SearchTimeout: DefaultSearchTimeout,
|
||||
MaxFileSize: DefaultMaxFileSize,
|
||||
MaxParseSize: DefaultMaxParseSize,
|
||||
MaxSearchResults: DefaultMaxSearchResults,
|
||||
MaxEditSize: DefaultMaxEditSize,
|
||||
EnableLSP: true,
|
||||
Formatters: make(map[string]string),
|
||||
FollowSymlinks: true,
|
||||
RespectGitignore: true,
|
||||
WorkspaceRoot: ".",
|
||||
LSPTimeout: DefaultLSPTimeout,
|
||||
SearchTimeout: DefaultSearchTimeout,
|
||||
MaxFileSize: DefaultMaxFileSize,
|
||||
MaxParseSize: DefaultMaxParseSize,
|
||||
MaxSearchResults: DefaultMaxSearchResults,
|
||||
MaxEditSize: DefaultMaxEditSize,
|
||||
EnableLSP: true,
|
||||
Formatters: make(map[string]string),
|
||||
FollowSymlinks: true,
|
||||
RespectGitignore: true,
|
||||
ResourceLinkThresholdBytes: DefaultResourceLinkThresholdBytes,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
// Package cursor implements opaque pagination cursors for MCP tools.
|
||||
// A cursor encodes an offset into a result stream plus a query hash so stale
|
||||
// cursors from different queries fail cleanly.
|
||||
//
|
||||
// Encoding: base64url(json({"offset":N,"query_hash":"hex"}))
|
||||
// The query_hash is a hex-encoded sha256 over the deterministic query params.
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
// payload is the JSON structure inside a cursor.
|
||||
type payload struct {
|
||||
Offset int `json:"offset"`
|
||||
QueryHash string `json:"query_hash"`
|
||||
}
|
||||
|
||||
// Encode creates an opaque cursor string from an offset and query hash.
|
||||
func Encode(offset int, queryHash string) string {
|
||||
p := payload{Offset: offset, QueryHash: queryHash}
|
||||
b, _ := json.Marshal(p)
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// Decode parses a cursor string. Returns offset, queryHash, error.
|
||||
func Decode(cursor string) (int, string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(cursor)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("invalid cursor encoding: %w", err)
|
||||
}
|
||||
var p payload
|
||||
if err := json.Unmarshal(b, &p); err != nil {
|
||||
return 0, "", fmt.Errorf("invalid cursor payload: %w", err)
|
||||
}
|
||||
return p.Offset, p.QueryHash, nil
|
||||
}
|
||||
|
||||
// HashParams computes a deterministic query hash from a set of key=value params.
|
||||
// Keys are sorted before hashing so order doesn't matter.
|
||||
func HashParams(params map[string]string) string {
|
||||
keys := make([]string, 0, len(params))
|
||||
for k := range params {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
var sb strings.Builder
|
||||
for _, k := range keys {
|
||||
sb.WriteString(k)
|
||||
sb.WriteByte('=')
|
||||
sb.WriteString(params[k])
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
|
||||
sum := sha256.Sum256([]byte(sb.String()))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package cursor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncodeDecodRoundTrip(t *testing.T) {
|
||||
hash := HashParams(map[string]string{"a": "1", "b": "2"})
|
||||
|
||||
encoded := Encode(42, hash)
|
||||
if encoded == "" {
|
||||
t.Fatal("Encode returned empty string")
|
||||
}
|
||||
|
||||
offset, gotHash, err := Decode(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode error: %v", err)
|
||||
}
|
||||
if offset != 42 {
|
||||
t.Errorf("offset: got %d, want 42", offset)
|
||||
}
|
||||
if gotHash != hash {
|
||||
t.Errorf("hash mismatch: got %s, want %s", gotHash, hash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeInvalid(t *testing.T) {
|
||||
_, _, err := Decode("!!!notbase64!!!")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid base64, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeCorruptPayload(t *testing.T) {
|
||||
import64 := "bm90anNvbg" // "notjson" in base64
|
||||
_, _, err := Decode(import64)
|
||||
if err == nil {
|
||||
t.Error("expected error for corrupt payload, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashParamsDeterministic(t *testing.T) {
|
||||
// Same params regardless of insertion order
|
||||
h1 := HashParams(map[string]string{"z": "last", "a": "first"})
|
||||
h2 := HashParams(map[string]string{"a": "first", "z": "last"})
|
||||
if h1 != h2 {
|
||||
t.Errorf("hash not deterministic: %s != %s", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashParamsDifferentForDifferentQueries(t *testing.T) {
|
||||
h1 := HashParams(map[string]string{"pattern": "foo"})
|
||||
h2 := HashParams(map[string]string{"pattern": "bar"})
|
||||
if h1 == h2 {
|
||||
t.Error("different queries should produce different hashes")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
// Package parser provides skeleton rendering for source files.
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// SkeletonFile returns a skeleton representation of the file: top-level declarations
|
||||
// with signatures and doc-comments intact, but function/method bodies replaced with
|
||||
// a language-appropriate placeholder.
|
||||
//
|
||||
// Supported: go, typescript, javascript, python, rust.
|
||||
// Other languages fall back to symbols_only (AST summary text).
|
||||
// Returns (skeletonText, isFullSkeleton, error).
|
||||
func SkeletonFile(ctx context.Context, reg *Registry, filename string, content []byte) (string, bool, error) {
|
||||
result, err := reg.Parse(ctx, filename, content)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
lang := protocol.DetectLanguage(filename)
|
||||
|
||||
switch lang {
|
||||
case protocol.LangGo:
|
||||
return skeletonGo(result.Tree, content), true, nil
|
||||
case protocol.LangTypeScript, protocol.LangJavaScript:
|
||||
return skeletonTS(result.Tree, content), true, nil
|
||||
case protocol.LangPython:
|
||||
return skeletonPython(result.Tree, content), true, nil
|
||||
case protocol.LangRust:
|
||||
return skeletonRust(result.Tree, content), true, nil
|
||||
default:
|
||||
// TODO: skeleton for c, cpp, elixir, html, vue — fall back to symbols_only
|
||||
syms := ExtractSymbols(result.Tree, content, lang, filename)
|
||||
return renderSymbolsOnly(syms, filename, lang, content), false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// renderSymbolsOnly renders a simple symbol list (fallback for unsupported languages).
|
||||
func renderSymbolsOnly(syms []protocol.Symbol, _ string, lang protocol.Language, content []byte) string {
|
||||
lines := strings.Split(string(content), "\n")
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("// skeleton unavailable for %s — symbol list only\n", lang))
|
||||
for _, s := range syms {
|
||||
sb.WriteString(fmt.Sprintf("// %s %s (line %d)\n", s.Kind, s.Name, s.Location.Line))
|
||||
if s.Doc != "" {
|
||||
sb.WriteString(fmt.Sprintf("// doc: %s\n", s.Doc))
|
||||
}
|
||||
if s.Location.Line >= 1 && s.Location.Line <= len(lines) {
|
||||
sb.WriteString(lines[s.Location.Line-1] + " { ... }\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// ---- Go skeleton ----
|
||||
|
||||
// skeletonGoBodyNodes lists Go node types that have a body field to replace.
|
||||
var skeletonGoBodyNodes = map[string]string{
|
||||
"function_declaration": "body",
|
||||
"method_declaration": "body",
|
||||
}
|
||||
|
||||
func skeletonGo(tree *sitter.Tree, content []byte) string {
|
||||
if tree == nil {
|
||||
return string(content)
|
||||
}
|
||||
root := tree.RootNode()
|
||||
var sb strings.Builder
|
||||
skeletonGoNode(root, content, &sb)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func skeletonGoNode(node *sitter.Node, content []byte, sb *strings.Builder) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
nodeType := node.Type()
|
||||
|
||||
if bodyField, ok := skeletonGoBodyNodes[nodeType]; ok {
|
||||
body := node.ChildByFieldName(bodyField)
|
||||
if body != nil {
|
||||
sig := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
|
||||
sb.WriteString(sig)
|
||||
sb.WriteString("{ ... }\n\n")
|
||||
return
|
||||
}
|
||||
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
return
|
||||
}
|
||||
|
||||
switch nodeType {
|
||||
case "source_file":
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
skeletonGoNode(child, content, sb)
|
||||
}
|
||||
return
|
||||
case "comment":
|
||||
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
return
|
||||
default:
|
||||
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- TypeScript / JavaScript skeleton ----
|
||||
|
||||
func skeletonTS(tree *sitter.Tree, content []byte) string {
|
||||
if tree == nil {
|
||||
return string(content)
|
||||
}
|
||||
root := tree.RootNode()
|
||||
var sb strings.Builder
|
||||
skeletonTSNode(root, content, &sb)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func skeletonTSNode(node *sitter.Node, content []byte, sb *strings.Builder) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
nodeType := node.Type()
|
||||
|
||||
switch nodeType {
|
||||
case "program":
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
skeletonTSNode(child, content, sb)
|
||||
}
|
||||
return
|
||||
case "function_declaration":
|
||||
body := node.ChildByFieldName("body")
|
||||
if body != nil {
|
||||
sig := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
|
||||
sb.WriteString(sig)
|
||||
sb.WriteString("{ ... }\n\n")
|
||||
return
|
||||
}
|
||||
case "class_declaration":
|
||||
body := node.ChildByFieldName("body")
|
||||
if body != nil {
|
||||
header := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
|
||||
sb.WriteString(header)
|
||||
sb.WriteString("{\n")
|
||||
for i := 0; i < int(body.ChildCount()); i++ {
|
||||
child := body.Child(i)
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
if child.Type() == "method_definition" {
|
||||
methBody := child.ChildByFieldName("body")
|
||||
if methBody != nil {
|
||||
methSig := strings.TrimRight(string(content[child.StartByte():methBody.StartByte()]), " \t")
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(methSig)
|
||||
sb.WriteString("{ ... }\n")
|
||||
continue
|
||||
}
|
||||
}
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(string(content[child.StartByte():child.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("}\n\n")
|
||||
return
|
||||
}
|
||||
case "comment":
|
||||
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
return
|
||||
}
|
||||
|
||||
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// ---- Python skeleton ----
|
||||
|
||||
func skeletonPython(tree *sitter.Tree, content []byte) string {
|
||||
if tree == nil {
|
||||
return string(content)
|
||||
}
|
||||
root := tree.RootNode()
|
||||
var sb strings.Builder
|
||||
skeletonPythonNode(root, content, &sb, "")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func skeletonPythonNode(node *sitter.Node, content []byte, sb *strings.Builder, indent string) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
nodeType := node.Type()
|
||||
|
||||
switch nodeType {
|
||||
case "module":
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
skeletonPythonNode(child, content, sb, indent)
|
||||
}
|
||||
return
|
||||
case "function_definition", "decorated_definition":
|
||||
nodeText := string(content[node.StartByte():node.EndByte()])
|
||||
lines := strings.SplitN(nodeText, "\n", 2)
|
||||
sb.WriteString(indent)
|
||||
sb.WriteString(lines[0])
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(indent)
|
||||
sb.WriteString(" ...\n\n")
|
||||
return
|
||||
case "class_definition":
|
||||
body := node.ChildByFieldName("body")
|
||||
if body != nil {
|
||||
header := string(content[node.StartByte():body.StartByte()])
|
||||
firstLine := strings.SplitN(header, "\n", 2)[0]
|
||||
sb.WriteString(indent)
|
||||
sb.WriteString(firstLine)
|
||||
sb.WriteString("\n")
|
||||
for i := 0; i < int(body.ChildCount()); i++ {
|
||||
child := body.Child(i)
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
if child.Type() == "function_definition" || child.Type() == "decorated_definition" {
|
||||
childText := string(content[child.StartByte():child.EndByte()])
|
||||
childLines := strings.SplitN(childText, "\n", 2)
|
||||
sb.WriteString(indent + " ")
|
||||
sb.WriteString(childLines[0])
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(indent + " ...")
|
||||
sb.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
if child.Type() == "expression_statement" {
|
||||
sb.WriteString(indent + " ")
|
||||
sb.WriteString(string(content[child.StartByte():child.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
return
|
||||
}
|
||||
case "comment":
|
||||
sb.WriteString(indent)
|
||||
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
return
|
||||
}
|
||||
|
||||
sb.WriteString(indent)
|
||||
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// ---- Rust skeleton ----
|
||||
|
||||
func skeletonRust(tree *sitter.Tree, content []byte) string {
|
||||
if tree == nil {
|
||||
return string(content)
|
||||
}
|
||||
root := tree.RootNode()
|
||||
var sb strings.Builder
|
||||
skeletonRustNode(root, content, &sb)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func skeletonRustNode(node *sitter.Node, content []byte, sb *strings.Builder) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
nodeType := node.Type()
|
||||
|
||||
switch nodeType {
|
||||
case "source_file":
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
skeletonRustNode(child, content, sb)
|
||||
}
|
||||
return
|
||||
case "function_item":
|
||||
body := node.ChildByFieldName("body")
|
||||
if body != nil {
|
||||
sig := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
|
||||
sb.WriteString(sig)
|
||||
sb.WriteString("{ ... }\n\n")
|
||||
return
|
||||
}
|
||||
case "impl_item":
|
||||
body := node.ChildByFieldName("body")
|
||||
if body != nil {
|
||||
header := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
|
||||
sb.WriteString(header)
|
||||
sb.WriteString("{\n")
|
||||
for i := 0; i < int(body.ChildCount()); i++ {
|
||||
child := body.Child(i)
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
if child.Type() == "function_item" {
|
||||
methBody := child.ChildByFieldName("body")
|
||||
if methBody != nil {
|
||||
methSig := strings.TrimRight(string(content[child.StartByte():methBody.StartByte()]), " \t")
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(methSig)
|
||||
sb.WriteString("{ ... }\n")
|
||||
continue
|
||||
}
|
||||
}
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(string(content[child.StartByte():child.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("}\n\n")
|
||||
return
|
||||
}
|
||||
case "line_comment", "block_comment":
|
||||
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
return
|
||||
}
|
||||
|
||||
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
// StripFlag names the categories of content to remove.
|
||||
type StripFlag string
|
||||
|
||||
const (
|
||||
StripImports StripFlag = "imports"
|
||||
StripLicense StripFlag = "license"
|
||||
StripBlockComments StripFlag = "block_comments"
|
||||
)
|
||||
|
||||
// StripResult holds the stripped content and which flags actually removed content.
|
||||
type StripResult struct {
|
||||
Content string
|
||||
Stripped []StripFlag
|
||||
}
|
||||
|
||||
// StripContent applies requested strip operations to content, in order:
|
||||
// license → imports → block_comments.
|
||||
// lang is used to pick language-specific heuristics.
|
||||
func StripContent(content string, flags []StripFlag, lang protocol.Language) StripResult {
|
||||
flagSet := make(map[StripFlag]bool, len(flags))
|
||||
for _, f := range flags {
|
||||
flagSet[f] = true
|
||||
}
|
||||
|
||||
var stripped []StripFlag
|
||||
|
||||
if flagSet[StripLicense] {
|
||||
next, removed := stripLicense(content)
|
||||
if removed {
|
||||
content = next
|
||||
stripped = append(stripped, StripLicense)
|
||||
}
|
||||
}
|
||||
|
||||
if flagSet[StripImports] {
|
||||
next, removed := stripImports(content, lang)
|
||||
if removed {
|
||||
content = next
|
||||
stripped = append(stripped, StripImports)
|
||||
}
|
||||
}
|
||||
|
||||
if flagSet[StripBlockComments] {
|
||||
next, removed := stripBlockComments(content, lang)
|
||||
if removed {
|
||||
content = next
|
||||
stripped = append(stripped, StripBlockComments)
|
||||
}
|
||||
}
|
||||
|
||||
return StripResult{Content: content, Stripped: stripped}
|
||||
}
|
||||
|
||||
// stripLicense removes a leading block comment that looks like a license header.
|
||||
// A comment qualifies if it contains "copyright", "license", or "spdx-license-identifier" (case-insensitive).
|
||||
func stripLicense(content string) (string, bool) {
|
||||
trimmed := strings.TrimLeft(content, " \t\n\r")
|
||||
|
||||
// C-style block comment at top
|
||||
if strings.HasPrefix(trimmed, "/*") {
|
||||
end := strings.Index(trimmed, "*/")
|
||||
if end >= 0 {
|
||||
candidate := trimmed[:end+2]
|
||||
lower := strings.ToLower(candidate)
|
||||
if strings.Contains(lower, "copyright") ||
|
||||
strings.Contains(lower, "license") ||
|
||||
strings.Contains(lower, "spdx-license-identifier") {
|
||||
rest := trimmed[end+2:]
|
||||
// Consume trailing newline(s)
|
||||
rest = strings.TrimLeft(rest, "\r\n")
|
||||
return rest, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Python/hash-style leading comment block
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
lines := strings.Split(trimmed, "\n")
|
||||
var commentLines []string
|
||||
var rest []string
|
||||
inComment := true
|
||||
for i, l := range lines {
|
||||
if inComment && (strings.HasPrefix(l, "#") || strings.TrimSpace(l) == "") {
|
||||
commentLines = append(commentLines, l)
|
||||
} else {
|
||||
rest = lines[i:]
|
||||
break
|
||||
}
|
||||
}
|
||||
block := strings.Join(commentLines, "\n")
|
||||
lower := strings.ToLower(block)
|
||||
if strings.Contains(lower, "copyright") ||
|
||||
strings.Contains(lower, "license") ||
|
||||
strings.Contains(lower, "spdx-license-identifier") {
|
||||
return strings.Join(rest, "\n"), true
|
||||
}
|
||||
}
|
||||
|
||||
return content, false
|
||||
}
|
||||
|
||||
// stripImports removes top-of-file import blocks, language-specific.
|
||||
func stripImports(content string, lang protocol.Language) (string, bool) {
|
||||
switch lang {
|
||||
case protocol.LangGo:
|
||||
return stripGoImports(content)
|
||||
case protocol.LangTypeScript, protocol.LangJavaScript:
|
||||
return stripTSImports(content)
|
||||
case protocol.LangPython:
|
||||
return stripPythonImports(content)
|
||||
case protocol.LangRust:
|
||||
return stripRustImports(content)
|
||||
default:
|
||||
return content, false
|
||||
}
|
||||
}
|
||||
|
||||
// stripGoImports removes Go import(...) or single import "..." declarations.
|
||||
func stripGoImports(content string) (string, bool) {
|
||||
lines := strings.Split(content, "\n")
|
||||
var out []string
|
||||
removed := false
|
||||
i := 0
|
||||
for i < len(lines) {
|
||||
trimLine := strings.TrimSpace(lines[i])
|
||||
if strings.HasPrefix(trimLine, "import (") || trimLine == "import (" {
|
||||
// multi-line import block
|
||||
removed = true
|
||||
i++ // skip "import ("
|
||||
for i < len(lines) {
|
||||
if strings.TrimSpace(lines[i]) == ")" {
|
||||
i++ // skip closing ")"
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
// skip one blank line after
|
||||
if i < len(lines) && strings.TrimSpace(lines[i]) == "" {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(trimLine, `import "`) || strings.HasPrefix(trimLine, "import `") {
|
||||
removed = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
out = append(out, lines[i])
|
||||
i++
|
||||
}
|
||||
if !removed {
|
||||
return content, false
|
||||
}
|
||||
return strings.Join(out, "\n"), true
|
||||
}
|
||||
|
||||
// stripTSImports removes TypeScript/JavaScript "import ... from ..." and "require(...)" lines.
|
||||
func stripTSImports(content string) (string, bool) {
|
||||
lines := strings.Split(content, "\n")
|
||||
var out []string
|
||||
removed := false
|
||||
for _, l := range lines {
|
||||
trimLine := strings.TrimSpace(l)
|
||||
if strings.HasPrefix(trimLine, "import ") || strings.HasPrefix(trimLine, "const {") && strings.Contains(trimLine, "require(") {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, l)
|
||||
}
|
||||
if !removed {
|
||||
return content, false
|
||||
}
|
||||
return strings.Join(out, "\n"), true
|
||||
}
|
||||
|
||||
// stripPythonImports removes Python "import ..." and "from ... import ..." lines.
|
||||
func stripPythonImports(content string) (string, bool) {
|
||||
lines := strings.Split(content, "\n")
|
||||
var out []string
|
||||
removed := false
|
||||
for _, l := range lines {
|
||||
trimLine := strings.TrimSpace(l)
|
||||
if strings.HasPrefix(trimLine, "import ") || strings.HasPrefix(trimLine, "from ") {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, l)
|
||||
}
|
||||
if !removed {
|
||||
return content, false
|
||||
}
|
||||
return strings.Join(out, "\n"), true
|
||||
}
|
||||
|
||||
// stripRustImports removes Rust "use ..." declarations.
|
||||
func stripRustImports(content string) (string, bool) {
|
||||
lines := strings.Split(content, "\n")
|
||||
var out []string
|
||||
removed := false
|
||||
inMulti := false
|
||||
for _, l := range lines {
|
||||
trimLine := strings.TrimSpace(l)
|
||||
if inMulti {
|
||||
// look for semicolon terminating multi-line use
|
||||
if strings.Contains(trimLine, ";") {
|
||||
inMulti = false
|
||||
}
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(trimLine, "use ") {
|
||||
removed = true
|
||||
if !strings.HasSuffix(trimLine, ";") {
|
||||
inMulti = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
out = append(out, l)
|
||||
}
|
||||
if !removed {
|
||||
return content, false
|
||||
}
|
||||
return strings.Join(out, "\n"), true
|
||||
}
|
||||
|
||||
// stripBlockComments removes /* ... */ block comments (Go/TS/C/Rust)
|
||||
// and Python triple-quoted docstrings.
|
||||
func stripBlockComments(content string, lang protocol.Language) (string, bool) {
|
||||
if lang == protocol.LangPython {
|
||||
return stripPythonDocstrings(content)
|
||||
}
|
||||
return stripCStyleBlockComments(content)
|
||||
}
|
||||
|
||||
// stripCStyleBlockComments removes /* ... */ from content.
|
||||
func stripCStyleBlockComments(content string) (string, bool) {
|
||||
removed := false
|
||||
var sb strings.Builder
|
||||
i := 0
|
||||
for i < len(content) {
|
||||
if i+1 < len(content) && content[i] == '/' && content[i+1] == '*' {
|
||||
// find closing */
|
||||
end := strings.Index(content[i+2:], "*/")
|
||||
if end >= 0 {
|
||||
removed = true
|
||||
// advance past */
|
||||
i = i + 2 + end + 2
|
||||
// consume trailing newline
|
||||
if i < len(content) && content[i] == '\n' {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
sb.WriteByte(content[i])
|
||||
i++
|
||||
}
|
||||
if !removed {
|
||||
return content, false
|
||||
}
|
||||
return sb.String(), true
|
||||
}
|
||||
|
||||
// stripPythonDocstrings removes triple-quoted strings (""" and ”').
|
||||
func stripPythonDocstrings(content string) (string, bool) {
|
||||
removed := false
|
||||
var sb strings.Builder
|
||||
i := 0
|
||||
for i < len(content) {
|
||||
if i+2 < len(content) {
|
||||
triple := content[i : i+3]
|
||||
if triple == `"""` || triple == `'''` {
|
||||
end := strings.Index(content[i+3:], triple)
|
||||
if end >= 0 {
|
||||
removed = true
|
||||
i = i + 3 + end + 3
|
||||
if i < len(content) && content[i] == '\n' {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
sb.WriteByte(content[i])
|
||||
i++
|
||||
}
|
||||
if !removed {
|
||||
return content, false
|
||||
}
|
||||
return sb.String(), true
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
// makeResults builds N dummy MatchResults.
|
||||
func makeResults(n int) []MatchResult {
|
||||
out := make([]MatchResult, n)
|
||||
for i := range out {
|
||||
out[i] = MatchResult{
|
||||
File: "file.go",
|
||||
Location: protocol.Location{
|
||||
Line: i + 1,
|
||||
Column: 1,
|
||||
},
|
||||
Text: "func Foo() {}\nline2",
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TestFormatResultsVerboseDefault verifies verbose format includes code blocks (no preamble by default).
|
||||
func TestFormatResultsVerboseDefault(t *testing.T) {
|
||||
results := makeResults(2)
|
||||
out := FormatResultsWithOptions(results, 0, "verbose", 0)
|
||||
// v2 default: no preamble
|
||||
if strings.Contains(out, "Found ") {
|
||||
t.Errorf("v2 default should NOT emit preamble, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "```") {
|
||||
t.Error("verbose mode should include code blocks")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatResultsVerbosePreamble verifies verbose=true restores the preamble.
|
||||
func TestFormatResultsVerbosePreamble(t *testing.T) {
|
||||
results := makeResults(2)
|
||||
out := FormatResultsWithOptions(results, 0, "verbose", 0, true)
|
||||
if !strings.Contains(out, "Found 2 match(es):") {
|
||||
t.Errorf("expected preamble with verbose=true, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsCompact(t *testing.T) {
|
||||
results := makeResults(3)
|
||||
out := FormatResultsWithOptions(results, 0, "compact", 0)
|
||||
// v2 default: no preamble in compact mode
|
||||
// Should NOT have code blocks
|
||||
if strings.Contains(out, "```") {
|
||||
t.Error("compact mode should not have code blocks")
|
||||
}
|
||||
// Should have one line per match (beyond header)
|
||||
lines := strings.Split(strings.TrimSpace(out), "\n")
|
||||
// First two lines are "Found..." and blank, then 3 match lines
|
||||
matchLines := 0
|
||||
for _, l := range lines {
|
||||
if strings.Contains(l, "file.go:") {
|
||||
matchLines++
|
||||
}
|
||||
}
|
||||
if matchLines != 3 {
|
||||
t.Errorf("expected 3 match lines in compact mode, got %d\nOutput:\n%s", matchLines, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsLocation(t *testing.T) {
|
||||
results := makeResults(3)
|
||||
out := FormatResultsWithOptions(results, 0, "location", 0)
|
||||
if strings.Contains(out, "```") {
|
||||
t.Error("location mode should not have code blocks")
|
||||
}
|
||||
// Should be file:line only
|
||||
for i := 1; i <= 3; i++ {
|
||||
expected := "file.go:" + itoa(i)
|
||||
if !strings.Contains(out, expected) {
|
||||
t.Errorf("location output missing %s", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsMaxResults(t *testing.T) {
|
||||
results := makeResults(5)
|
||||
out := FormatResultsWithOptions(results, 3, "verbose", 0)
|
||||
// v2 default: no preamble — check that exactly 3 code blocks are present
|
||||
codeBlockCount := strings.Count(out, "```")
|
||||
if codeBlockCount != 6 { // 3 opening + 3 closing = 6
|
||||
t.Errorf("expected 3 matches (6 backtick markers), got %d in:\n%s", codeBlockCount, out)
|
||||
}
|
||||
if !strings.Contains(out, "[remaining: 2]") {
|
||||
t.Errorf("expected [remaining: 2] footer, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsOffset(t *testing.T) {
|
||||
results := makeResults(5)
|
||||
// Skip first 2, show all remaining
|
||||
out := FormatResultsWithOptions(results, 0, "verbose", 2)
|
||||
// offset=2 from 5 results → 3 results; check 3 code blocks
|
||||
codeBlockCount := strings.Count(out, "```")
|
||||
if codeBlockCount != 6 {
|
||||
t.Errorf("expected 3 matches (6 backtick markers) after offset=2, got %d in:\n%s", codeBlockCount, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsOffsetBeyondEnd(t *testing.T) {
|
||||
results := makeResults(3)
|
||||
out := FormatResultsWithOptions(results, 0, "verbose", 10)
|
||||
if out != "No matches found." {
|
||||
t.Errorf("expected 'No matches found.' for offset beyond end, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsPaginationCursor(t *testing.T) {
|
||||
// Offset=2, maxResults=2, 5 total → show items 3&4, remaining=1
|
||||
results := makeResults(5)
|
||||
out := FormatResultsWithOptions(results, 2, "verbose", 2)
|
||||
// offset=2, maxResults=2 → items 3&4; check 2 code blocks
|
||||
codeBlockCount := strings.Count(out, "```")
|
||||
if codeBlockCount != 4 {
|
||||
t.Errorf("expected 2 matches (4 backtick markers), got %d in:\n%s", codeBlockCount, out)
|
||||
}
|
||||
if !strings.Contains(out, "[remaining: 1]") {
|
||||
t.Errorf("expected [remaining: 1], got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsEmpty(t *testing.T) {
|
||||
out := FormatResultsWithOptions(nil, 0, "verbose", 0)
|
||||
if out != "No matches found." {
|
||||
t.Errorf("expected 'No matches found.', got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsBackwardCompat(t *testing.T) {
|
||||
// FormatResults wrapper should produce same output as FormatResultsWithOptions with verbose=false (default).
|
||||
results := makeResults(2)
|
||||
a := FormatResults(results, 0)
|
||||
b := FormatResultsWithOptions(results, 0, "verbose", 0)
|
||||
if a != b {
|
||||
t.Error("FormatResults and FormatResultsWithOptions(verbose,0) should be identical")
|
||||
}
|
||||
// Both should have no preamble.
|
||||
if strings.Contains(a, "Found ") {
|
||||
t.Error("FormatResults should not emit preamble by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstLineOf(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{"hello world", 20, "hello world"},
|
||||
{"line1\nline2", 20, "line1"},
|
||||
{"\n\nfoo", 20, "foo"},
|
||||
{"abcdefghij", 5, "abcd…"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := firstLineOf(c.input, c.maxLen)
|
||||
if got != c.want {
|
||||
t.Errorf("firstLineOf(%q, %d) = %q, want %q", c.input, c.maxLen, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func itoa(n int) string {
|
||||
if n < 10 {
|
||||
return string(rune('0' + n))
|
||||
}
|
||||
return strings.TrimRight(strings.TrimRight(
|
||||
func() string {
|
||||
buf := make([]byte, 20)
|
||||
pos := 20
|
||||
for n > 0 {
|
||||
pos--
|
||||
buf[pos] = byte('0' + n%10)
|
||||
n /= 10
|
||||
}
|
||||
return string(buf[pos:])
|
||||
}(), ""), "")
|
||||
}
|
||||
+99
-40
@@ -451,62 +451,121 @@ func passesFilters(node *sitter.Node, filters QueryFilters, content []byte) bool
|
||||
return true
|
||||
}
|
||||
|
||||
// FormatResults formats match results for display.
|
||||
// FormatResults formats match results for display (backward-compat wrapper, verbose mode).
|
||||
func FormatResults(results []MatchResult, maxResults int) string {
|
||||
return FormatResultsWithOptions(results, maxResults, "verbose", 0)
|
||||
}
|
||||
|
||||
// FormatResultsWithOptions formats match results with configurable output format.
|
||||
// format: "verbose" (default) | "compact" | "location"
|
||||
// offset: skip this many results before rendering (used for cursor pagination).
|
||||
// verbose: opt-in variadic — pass true to restore "Found N match(es):" preamble (v1 behaviour).
|
||||
func FormatResultsWithOptions(results []MatchResult, maxResults int, format string, offset int, verbose ...bool) string {
|
||||
if len(results) == 0 {
|
||||
return "No matches found."
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Found %d match(es):\n\n", len(results)))
|
||||
|
||||
displayCount := len(results)
|
||||
truncated := false
|
||||
if maxResults > 0 && displayCount > maxResults {
|
||||
displayCount = maxResults
|
||||
truncated = true
|
||||
// Apply offset (pagination skip).
|
||||
if offset > 0 {
|
||||
if offset >= len(results) {
|
||||
return "No matches found."
|
||||
}
|
||||
results = results[offset:]
|
||||
}
|
||||
|
||||
for i := 0; i < displayCount; i++ {
|
||||
r := results[i]
|
||||
nodeType := "unknown"
|
||||
if r.Node != nil {
|
||||
nodeType = r.Node.Type()
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType))
|
||||
// Determine how many to render and whether more remain.
|
||||
renderCount := len(results)
|
||||
remaining := 0
|
||||
if maxResults > 0 && renderCount > maxResults {
|
||||
remaining = renderCount - maxResults
|
||||
renderCount = maxResults
|
||||
}
|
||||
|
||||
// Truncate very long text
|
||||
text := r.Text
|
||||
if len(text) > 500 {
|
||||
text = text[:500] + "..."
|
||||
}
|
||||
sb.WriteString("```\n")
|
||||
sb.WriteString(text)
|
||||
sb.WriteString("\n```\n")
|
||||
var sb strings.Builder
|
||||
|
||||
// Show captures
|
||||
if len(r.Captures) > 0 {
|
||||
sb.WriteString("Captures: ")
|
||||
first := true
|
||||
for name, cap := range r.Captures {
|
||||
if !first {
|
||||
sb.WriteString(", ")
|
||||
// Emit preamble only when verbose=true is explicitly passed (opt-in, default off).
|
||||
wantVerbose := len(verbose) > 0 && verbose[0]
|
||||
if wantVerbose {
|
||||
sb.WriteString(fmt.Sprintf("Found %d match(es):\n", renderCount))
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "compact":
|
||||
for i := 0; i < renderCount; i++ {
|
||||
r := results[i]
|
||||
nodeType := "unknown"
|
||||
if r.Node != nil {
|
||||
nodeType = r.Node.Type()
|
||||
}
|
||||
firstLine := firstLineOf(r.Text, 80)
|
||||
sb.WriteString(fmt.Sprintf("%s:%d (%s) %s\n", r.File, r.Location.Line, nodeType, firstLine))
|
||||
}
|
||||
|
||||
case "location":
|
||||
for i := 0; i < renderCount; i++ {
|
||||
r := results[i]
|
||||
sb.WriteString(fmt.Sprintf("%s:%d\n", r.File, r.Location.Line))
|
||||
}
|
||||
|
||||
default: // "verbose"
|
||||
for i := 0; i < renderCount; i++ {
|
||||
r := results[i]
|
||||
nodeType := "unknown"
|
||||
if r.Node != nil {
|
||||
nodeType = r.Node.Type()
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType))
|
||||
|
||||
// Truncate very long text
|
||||
text := r.Text
|
||||
if len(text) > 500 {
|
||||
text = text[:500] + "..."
|
||||
}
|
||||
sb.WriteString("```\n")
|
||||
sb.WriteString(text)
|
||||
sb.WriteString("\n```\n")
|
||||
|
||||
// Show captures
|
||||
if len(r.Captures) > 0 {
|
||||
sb.WriteString("Captures: ")
|
||||
first := true
|
||||
for name, cap := range r.Captures {
|
||||
if !first {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
first = false
|
||||
capText := cap.Text
|
||||
if len(capText) > 50 {
|
||||
capText = capText[:50] + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("$%s=%s", name, capText))
|
||||
}
|
||||
first = false
|
||||
capText := cap.Text
|
||||
if len(capText) > 50 {
|
||||
capText = capText[:50] + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("$%s=%s", name, capText))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
if truncated {
|
||||
sb.WriteString(fmt.Sprintf("... and %d more matches (truncated)\n", len(results)-maxResults))
|
||||
if remaining > 0 {
|
||||
// Caller must embed the cursor token; we just append the remaining count hint.
|
||||
// The actual [cursor: ...] line is written by the handler after calling MakeCursor.
|
||||
sb.WriteString(fmt.Sprintf("[remaining: %d]\n", remaining))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// firstLineOf returns the first non-empty line of s, trimmed and capped at maxLen chars.
|
||||
func firstLineOf(s string, maxLen int) string {
|
||||
for _, line := range strings.Split(s, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
if len(line) > maxLen {
|
||||
return line[:maxLen-1] + "…"
|
||||
}
|
||||
return line
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
func newTestSearcher(t *testing.T) *Searcher {
|
||||
t.Helper()
|
||||
cfg := &config.Config{
|
||||
WorkspaceRoot: t.TempDir(),
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
// Create a Searcher directly without requiring rg
|
||||
return &Searcher{cfg: cfg, logger: logger, rgPath: "rg"}
|
||||
}
|
||||
|
||||
func makeSearchResults(lines []int) *SearchResults {
|
||||
results := make([]Result, len(lines))
|
||||
for i, l := range lines {
|
||||
results[i] = Result{
|
||||
File: "/workspace/foo.go",
|
||||
Line: l,
|
||||
MatchText: "match at line " + itoa2(l),
|
||||
}
|
||||
}
|
||||
return &SearchResults{Results: results}
|
||||
}
|
||||
|
||||
// itoa2 simple int to string for test
|
||||
func itoa2(n int) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
digits := []byte{}
|
||||
for n > 0 {
|
||||
digits = append([]byte{byte('0' + n%10)}, digits...)
|
||||
n /= 10
|
||||
}
|
||||
return string(digits)
|
||||
}
|
||||
|
||||
// TestFormatResultsVerboseBackwardCompat verifies that Verbose=true restores the v1 preamble.
|
||||
func TestFormatResultsVerboseBackwardCompat(t *testing.T) {
|
||||
s := newTestSearcher(t)
|
||||
sr := makeSearchResults([]int{1, 5, 10})
|
||||
// With Verbose=true, the preamble is emitted (v1 behaviour).
|
||||
out := s.FormatResultsWithOptions(sr, FormatOptions{Verbose: true})
|
||||
if !strings.Contains(out, "Found 3 matches in 1 files") {
|
||||
t.Errorf("expected header with Verbose=true, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "L1│") {
|
||||
t.Errorf("expected L1 match line, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatResultsDefaultNoPreamble verifies the v2 default has no preamble.
|
||||
func TestFormatResultsDefaultNoPreamble(t *testing.T) {
|
||||
s := newTestSearcher(t)
|
||||
sr := makeSearchResults([]int{1, 5, 10})
|
||||
out := s.FormatResults(sr)
|
||||
if strings.Contains(out, "Found ") {
|
||||
t.Errorf("v2 default should NOT emit preamble, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "L1│") {
|
||||
t.Errorf("expected L1 match line, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsClusterSingleLine(t *testing.T) {
|
||||
s := newTestSearcher(t)
|
||||
// Non-consecutive lines → each should appear separately
|
||||
sr := makeSearchResults([]int{1, 5, 10})
|
||||
out := s.FormatResultsWithOptions(sr, FormatOptions{Cluster: true})
|
||||
if !strings.Contains(out, "L1│") {
|
||||
t.Errorf("expected L1 cluster, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "L5│") {
|
||||
t.Errorf("expected L5 cluster, got:\n%s", out)
|
||||
}
|
||||
// Should NOT have " │" context lines in cluster mode
|
||||
if strings.Contains(out, " │") {
|
||||
t.Errorf("cluster mode should not have context-line decoration, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsClusterConsecutive(t *testing.T) {
|
||||
s := newTestSearcher(t)
|
||||
// Lines 3,4,5 are consecutive → should be clustered as L3-5
|
||||
sr := makeSearchResults([]int{3, 4, 5, 10})
|
||||
out := s.FormatResultsWithOptions(sr, FormatOptions{Cluster: true})
|
||||
if !strings.Contains(out, "L3-5│") {
|
||||
t.Errorf("expected L3-5 cluster, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "L10│") {
|
||||
t.Errorf("expected L10 separate cluster, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsClusterAdjacentMerge(t *testing.T) {
|
||||
s := newTestSearcher(t)
|
||||
// Lines 7 and 8 differ by 1 → merge
|
||||
sr := makeSearchResults([]int{7, 8})
|
||||
out := s.FormatResultsWithOptions(sr, FormatOptions{Cluster: true})
|
||||
if !strings.Contains(out, "L7-8│") {
|
||||
t.Errorf("expected L7-8 cluster, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsCursorLine(t *testing.T) {
|
||||
s := newTestSearcher(t)
|
||||
sr := makeSearchResults([]int{1})
|
||||
cursorText := "[cursor: abc123, remaining: 5]"
|
||||
out := s.FormatResultsWithOptions(sr, FormatOptions{CursorLine: cursorText})
|
||||
if !strings.Contains(out, cursorText) {
|
||||
t.Errorf("expected cursor footer in output, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResultsNoMatchesVerbose(t *testing.T) {
|
||||
s := newTestSearcher(t)
|
||||
sr := &SearchResults{Results: nil}
|
||||
out := s.FormatResults(sr)
|
||||
if out != "No matches found." {
|
||||
t.Errorf("expected 'No matches found.', got: %s", out)
|
||||
}
|
||||
}
|
||||
+97
-26
@@ -219,11 +219,9 @@ func (s *Searcher) buildArgs(req *Request) []string {
|
||||
args = append(args, "--no-ignore")
|
||||
}
|
||||
|
||||
// Global result cap — --max-total-count stops rg early across all files.
|
||||
// Requires ripgrep >= 13.0. In-process truncation in parseOutput is kept as a safety net.
|
||||
if req.MaxResults > 0 {
|
||||
args = append(args, fmt.Sprintf("--max-total-count=%d", req.MaxResults))
|
||||
}
|
||||
// Result cap enforced in-process by parseOutput. rg has no cross-file
|
||||
// total-count flag in stable releases, so we don't pass one; --max-count is
|
||||
// per-file and would miss results unevenly.
|
||||
|
||||
// Add pattern
|
||||
args = append(args, "--", req.Pattern)
|
||||
@@ -356,8 +354,21 @@ func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchRes
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// FormatResults formats search results for display.
|
||||
// FormatOptions controls how search results are rendered.
|
||||
type FormatOptions struct {
|
||||
Cluster bool // coalesce consecutive matches into line-range blocks
|
||||
CursorLine string // if non-empty, appended as a footer line
|
||||
Verbose bool // if true, emit "Found N matches in M files:" preamble (opt-in)
|
||||
}
|
||||
|
||||
// FormatResults formats search results for display (backward-compat wrapper).
|
||||
func (s *Searcher) FormatResults(results *SearchResults) string {
|
||||
return s.FormatResultsWithOptions(results, FormatOptions{})
|
||||
}
|
||||
|
||||
// FormatResultsWithOptions formats search results with configurable output.
|
||||
// By default the "Found N matches in M files:" preamble is omitted; set opts.Verbose=true to restore it.
|
||||
func (s *Searcher) FormatResultsWithOptions(results *SearchResults, opts FormatOptions) string {
|
||||
if len(results.Results) == 0 {
|
||||
return "No matches found."
|
||||
}
|
||||
@@ -374,14 +385,18 @@ func (s *Searcher) FormatResults(results *SearchResults) string {
|
||||
fileResults[r.File] = append(fileResults[r.File], r)
|
||||
}
|
||||
|
||||
// Write summary
|
||||
totalMatches := len(results.Results)
|
||||
fileCount := len(fileResults)
|
||||
sb.WriteString(fmt.Sprintf("Found %d matches in %d files", totalMatches, fileCount))
|
||||
if results.Truncated {
|
||||
sb.WriteString(fmt.Sprintf(" (truncated, total: %d)", results.Total))
|
||||
// Write preamble only when Verbose is requested.
|
||||
if opts.Verbose {
|
||||
totalMatches := len(results.Results)
|
||||
fileCount := len(fileResults)
|
||||
sb.WriteString(fmt.Sprintf("Found %d matches in %d files", totalMatches, fileCount))
|
||||
if results.Truncated {
|
||||
sb.WriteString(fmt.Sprintf(" (truncated, total: %d)", results.Total))
|
||||
}
|
||||
sb.WriteString(":\n\n")
|
||||
} else if results.Truncated {
|
||||
sb.WriteString(fmt.Sprintf("(truncated, showing subset of %d total matches)\n\n", results.Total))
|
||||
}
|
||||
sb.WriteString(":\n\n")
|
||||
|
||||
// Write results grouped by file
|
||||
for _, file := range fileOrder {
|
||||
@@ -395,26 +410,82 @@ func (s *Searcher) FormatResults(results *SearchResults) string {
|
||||
|
||||
sb.WriteString(fmt.Sprintf("**%s**\n", relPath))
|
||||
|
||||
for _, r := range fileResults[file] {
|
||||
// Write context before
|
||||
for _, ctx := range r.Context.Before {
|
||||
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
|
||||
}
|
||||
|
||||
// Write match line
|
||||
sb.WriteString(fmt.Sprintf("L%d│ %s\n", r.Line, truncateLine(r.MatchText, 200)))
|
||||
|
||||
// Write context after
|
||||
for _, ctx := range r.Context.After {
|
||||
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
|
||||
}
|
||||
if opts.Cluster {
|
||||
writeClusteredResults(&sb, fileResults[file])
|
||||
} else {
|
||||
writeVerboseResults(&sb, fileResults[file])
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
if opts.CursorLine != "" {
|
||||
sb.WriteString(opts.CursorLine)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// writeVerboseResults writes results in the standard verbose format.
|
||||
func writeVerboseResults(sb *strings.Builder, results []Result) {
|
||||
for _, r := range results {
|
||||
// Write context before
|
||||
for _, ctx := range r.Context.Before {
|
||||
fmt.Fprintf(sb, " │ %s\n", truncateLine(ctx, 200))
|
||||
}
|
||||
// Write match line
|
||||
fmt.Fprintf(sb, "L%d│ %s\n", r.Line, truncateLine(r.MatchText, 200))
|
||||
// Write context after
|
||||
for _, ctx := range r.Context.After {
|
||||
fmt.Fprintf(sb, " │ %s\n", truncateLine(ctx, 200))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeClusteredResults coalesces consecutive or adjacent match lines into
|
||||
// a single "L12-14│ <first-match-text>" entry. Context lines are dropped
|
||||
// in cluster mode to maximise information density.
|
||||
func writeClusteredResults(sb *strings.Builder, results []Result) {
|
||||
if len(results) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
type clusterEntry struct {
|
||||
startLine int
|
||||
endLine int
|
||||
firstText string
|
||||
}
|
||||
|
||||
var clusters []clusterEntry
|
||||
cur := clusterEntry{
|
||||
startLine: results[0].Line,
|
||||
endLine: results[0].Line,
|
||||
firstText: results[0].MatchText,
|
||||
}
|
||||
|
||||
for _, r := range results[1:] {
|
||||
// Merge if adjacent (within 1 line gap)
|
||||
if r.Line <= cur.endLine+1 {
|
||||
if r.Line > cur.endLine {
|
||||
cur.endLine = r.Line
|
||||
}
|
||||
} else {
|
||||
clusters = append(clusters, cur)
|
||||
cur = clusterEntry{startLine: r.Line, endLine: r.Line, firstText: r.MatchText}
|
||||
}
|
||||
}
|
||||
clusters = append(clusters, cur)
|
||||
|
||||
for _, c := range clusters {
|
||||
text := truncateLine(c.firstText, 200)
|
||||
if c.startLine == c.endLine {
|
||||
fmt.Fprintf(sb, "L%d│ %s\n", c.startLine, text)
|
||||
} else {
|
||||
fmt.Fprintf(sb, "L%d-%d│ %s\n", c.startLine, c.endLine, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// truncateLine truncates a line if it exceeds maxLen.
|
||||
func truncateLine(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
|
||||
@@ -94,8 +94,8 @@ func TestBuildArgs(t *testing.T) {
|
||||
MaxResults: 10,
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--max-total-count=10", "--", "test", "."},
|
||||
notExpected: []string{"--ignore-case", "--fixed-strings"},
|
||||
expected: []string{"--json", "--", "test", "."},
|
||||
notExpected: []string{"--ignore-case", "--fixed-strings", "--max-total-count=10", "--max-count=10"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,344 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// newTestServer creates a server with a temp workspace containing Go files.
|
||||
func newFeaturesServer(t *testing.T) (*Server, string) {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Write a few Go files for AST queries
|
||||
goFile1 := filepath.Join(tmpDir, "a.go")
|
||||
if err := os.WriteFile(goFile1, []byte(`package main
|
||||
|
||||
func Alpha() string { return "alpha" }
|
||||
func Beta() string { return "beta" }
|
||||
func Gamma() string { return "gamma" }
|
||||
`), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goFile2 := filepath.Join(tmpDir, "b.go")
|
||||
if err := os.WriteFile(goFile2, []byte(`package main
|
||||
|
||||
func Delta() int { return 1 }
|
||||
func Epsilon() int { return 2 }
|
||||
`), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
WorkspaceRoot: tmpDir,
|
||||
EnableLSP: false,
|
||||
MaxFileSize: config.DefaultMaxFileSize,
|
||||
MaxParseSize: config.DefaultMaxParseSize,
|
||||
SearchTimeout: 30 * time.Second,
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error: %v", err)
|
||||
}
|
||||
return srv, tmpDir
|
||||
}
|
||||
|
||||
// ---- Feature 1: ast_query format flag ----
|
||||
|
||||
func TestASTQueryFormatVerboseDefault(t *testing.T) {
|
||||
srv, tmpDir := newFeaturesServer(t)
|
||||
ctx := context.Background()
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME() string",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
}
|
||||
res, err := srv.handleASTQuery(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if res == nil || len(res.Content) == 0 {
|
||||
t.Fatal("nil/empty result")
|
||||
}
|
||||
text := res.Content[0].(mcp.TextContent).Text
|
||||
// verbose mode has code blocks
|
||||
if !strings.Contains(text, "```") {
|
||||
t.Errorf("verbose mode should have code blocks, got:\n%s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestASTQueryFormatCompact(t *testing.T) {
|
||||
srv, tmpDir := newFeaturesServer(t)
|
||||
ctx := context.Background()
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME() string",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"format": "compact",
|
||||
}
|
||||
res, err := srv.handleASTQuery(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
text := res.Content[0].(mcp.TextContent).Text
|
||||
if strings.Contains(text, "```") {
|
||||
t.Errorf("compact mode should NOT have code blocks, got:\n%s", text)
|
||||
}
|
||||
// Each line should contain file:line (kind) text
|
||||
for _, line := range strings.Split(strings.TrimSpace(text), "\n") {
|
||||
if line == "" || strings.HasPrefix(line, "Found") {
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(line, ":") {
|
||||
t.Errorf("compact line missing ':' separator: %q", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestASTQueryFormatLocation(t *testing.T) {
|
||||
srv, tmpDir := newFeaturesServer(t)
|
||||
ctx := context.Background()
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME() string",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"format": "location",
|
||||
}
|
||||
res, err := srv.handleASTQuery(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
text := res.Content[0].(mcp.TextContent).Text
|
||||
if strings.Contains(text, "```") {
|
||||
t.Errorf("location mode should NOT have code blocks, got:\n%s", text)
|
||||
}
|
||||
// Lines should be file:linenum only (no parentheses with kind)
|
||||
for _, line := range strings.Split(strings.TrimSpace(text), "\n") {
|
||||
if line == "" || strings.HasPrefix(line, "Found") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(line, "(") {
|
||||
t.Errorf("location mode should not have node type in parens: %q", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Feature 3 (ast_query): pagination cursor ----
|
||||
|
||||
func TestASTQueryPaginationCursor(t *testing.T) {
|
||||
srv, tmpDir := newFeaturesServer(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Page 1: max_results=2
|
||||
req1 := mcp.CallToolRequest{}
|
||||
req1.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME() $RET",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"max_results": float64(2),
|
||||
}
|
||||
res1, err := srv.handleASTQuery(ctx, req1)
|
||||
if err != nil {
|
||||
t.Fatalf("page1 error: %v", err)
|
||||
}
|
||||
text1 := res1.Content[0].(mcp.TextContent).Text
|
||||
|
||||
// Should contain cursor footer if there are more results
|
||||
if !strings.Contains(text1, "[cursor:") {
|
||||
// Might have fewer than 2 total results — skip cursor test
|
||||
t.Logf("no cursor footer (fewer than 2 total matches), skipping pagination round-trip")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract cursor token
|
||||
var cursorToken string
|
||||
for _, line := range strings.Split(text1, "\n") {
|
||||
if strings.HasPrefix(line, "[cursor:") {
|
||||
// [cursor: <token>, remaining: N]
|
||||
parts := strings.Split(line, " ")
|
||||
if len(parts) >= 2 {
|
||||
cursorToken = strings.TrimSuffix(parts[1], ",")
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if cursorToken == "" {
|
||||
t.Fatal("could not extract cursor token from output")
|
||||
}
|
||||
|
||||
// Page 2: pass cursor back
|
||||
req2 := mcp.CallToolRequest{}
|
||||
req2.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME() $RET",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"max_results": float64(2),
|
||||
"cursor": cursorToken,
|
||||
}
|
||||
res2, err := srv.handleASTQuery(ctx, req2)
|
||||
if err != nil {
|
||||
t.Fatalf("page2 error: %v", err)
|
||||
}
|
||||
text2 := res2.Content[0].(mcp.TextContent).Text
|
||||
if strings.Contains(text2, "cursor is for a different query") {
|
||||
t.Error("cursor was rejected as mismatched query")
|
||||
}
|
||||
// Page 2 should have results
|
||||
if strings.Contains(text2, "No matches found.") {
|
||||
t.Error("page2 should have some results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestASTQueryCursorStaleMismatch(t *testing.T) {
|
||||
srv, tmpDir := newFeaturesServer(t)
|
||||
ctx := context.Background()
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME() string",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"cursor": "eyJvZmZzZXQiOjIsInF1ZXJ5X2hhc2giOiJkZWFkYmVlZiJ9", // offset=2, hash=deadbeef
|
||||
}
|
||||
res, err := srv.handleASTQuery(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
text := res.Content[0].(mcp.TextContent).Text
|
||||
if !strings.Contains(text, "cursor is for a different query") {
|
||||
t.Errorf("expected stale cursor error, got:\n%s", text)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Feature 2: file_search cluster ----
|
||||
|
||||
func TestFileSearchClusterFlag(t *testing.T) {
|
||||
srv, tmpDir := newFeaturesServer(t)
|
||||
if srv.searcher == nil {
|
||||
t.Skip("rg not available")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"cluster": true,
|
||||
}
|
||||
res, err := srv.handleFileSearch(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if res == nil || len(res.Content) == 0 {
|
||||
t.Fatal("nil/empty result")
|
||||
}
|
||||
text := res.Content[0].(mcp.TextContent).Text
|
||||
if text == "No matches found." {
|
||||
t.Skip("no matches found (unexpected)")
|
||||
}
|
||||
// In cluster mode, should NOT have " │" context decorations
|
||||
if strings.Contains(text, " │") {
|
||||
t.Errorf("cluster mode should not contain context-line decoration ' │', got:\n%s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSearchCursorStaleHash(t *testing.T) {
|
||||
srv, tmpDir := newFeaturesServer(t)
|
||||
if srv.searcher == nil {
|
||||
t.Skip("rg not available")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"cursor": "eyJvZmZzZXQiOjEsInF1ZXJ5X2hhc2giOiJiYWRoYXNoIn0", // offset=1, hash=badhash
|
||||
}
|
||||
res, err := srv.handleFileSearch(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
text := res.Content[0].(mcp.TextContent).Text
|
||||
if !strings.Contains(text, "cursor is for a different query") {
|
||||
t.Errorf("expected stale cursor error, got:\n%s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSearchPaginationCursor(t *testing.T) {
|
||||
srv, tmpDir := newFeaturesServer(t)
|
||||
if srv.searcher == nil {
|
||||
t.Skip("rg not available")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
// Page 1: get 1 result
|
||||
req1 := mcp.CallToolRequest{}
|
||||
req1.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"max_results": float64(1),
|
||||
"context_lines": float64(0),
|
||||
}
|
||||
res1, err := srv.handleFileSearch(ctx, req1)
|
||||
if err != nil {
|
||||
t.Fatalf("page1 error: %v", err)
|
||||
}
|
||||
text1 := res1.Content[0].(mcp.TextContent).Text
|
||||
if !strings.Contains(text1, "[cursor:") {
|
||||
t.Logf("no cursor in page1 (only 1 total result), skipping round-trip:\n%s", text1)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract cursor
|
||||
var cursorToken string
|
||||
for _, line := range strings.Split(text1, "\n") {
|
||||
if strings.HasPrefix(line, "[cursor:") {
|
||||
parts := strings.Split(line, " ")
|
||||
if len(parts) >= 2 {
|
||||
cursorToken = strings.TrimSuffix(parts[1], ",")
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if cursorToken == "" {
|
||||
t.Fatal("could not extract cursor from page1")
|
||||
}
|
||||
|
||||
// Page 2
|
||||
req2 := mcp.CallToolRequest{}
|
||||
req2.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"max_results": float64(1),
|
||||
"context_lines": float64(0),
|
||||
"cursor": cursorToken,
|
||||
}
|
||||
res2, err := srv.handleFileSearch(ctx, req2)
|
||||
if err != nil {
|
||||
t.Fatalf("page2 error: %v", err)
|
||||
}
|
||||
text2 := res2.Content[0].(mcp.TextContent).Text
|
||||
if strings.Contains(text2, "cursor is for a different query") {
|
||||
t.Error("cursor was rejected as mismatched")
|
||||
}
|
||||
if strings.Contains(text2, "No matches found.") {
|
||||
t.Error("page2 should have matches")
|
||||
}
|
||||
}
|
||||
+198
-105
@@ -8,12 +8,193 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/cursor"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// astQueryParams holds resolved parameters for an AST query invocation.
|
||||
type astQueryParams struct {
|
||||
pattern string
|
||||
language string
|
||||
nameMatches string
|
||||
nameExact string
|
||||
kindIn []string
|
||||
paths []string
|
||||
maxResults int
|
||||
format string
|
||||
offset int
|
||||
queryHash string
|
||||
verbose bool
|
||||
}
|
||||
|
||||
// resolveASTQueryParams parses the request and resolves session-pref defaults.
|
||||
// Returns (params, errorResult, error); when errorResult is non-nil, the caller
|
||||
// should return it directly.
|
||||
func (s *Server) resolveASTQueryParams(request mcp.CallToolRequest) (*astQueryParams, *mcp.CallToolResult) {
|
||||
pattern, err := request.RequireString("pattern")
|
||||
if err != nil {
|
||||
return nil, mcp.NewToolResultError("pattern is required")
|
||||
}
|
||||
language, err := request.RequireString("language")
|
||||
if err != nil {
|
||||
return nil, mcp.NewToolResultError("language is required")
|
||||
}
|
||||
|
||||
p := &astQueryParams{
|
||||
pattern: pattern,
|
||||
language: language,
|
||||
nameMatches: request.GetString("name_matches", ""),
|
||||
nameExact: request.GetString("name_exact", ""),
|
||||
kindIn: request.GetStringSlice("kind_in", nil),
|
||||
paths: request.GetStringSlice("paths", nil),
|
||||
verbose: request.GetBool("verbose", false),
|
||||
}
|
||||
|
||||
sp := s.sessionPrefs.Load()
|
||||
var prefsMaxResults int
|
||||
var prefsFormat string
|
||||
if sp != nil {
|
||||
prefsMaxResults = sp.DefaultMaxResults
|
||||
prefsFormat = sp.ASTQueryFormat
|
||||
}
|
||||
p.maxResults = effectiveInt(request, "max_results", prefsMaxResults, 100)
|
||||
|
||||
p.format = request.GetString("format", "")
|
||||
if p.format == "" {
|
||||
if prefsFormat != "" {
|
||||
p.format = prefsFormat
|
||||
} else {
|
||||
p.format = "verbose"
|
||||
}
|
||||
}
|
||||
|
||||
p.queryHash = cursor.HashParams(map[string]string{
|
||||
"pattern": p.pattern,
|
||||
"language": p.language,
|
||||
"name_matches": p.nameMatches,
|
||||
"name_exact": p.nameExact,
|
||||
"kind_in": strings.Join(p.kindIn, ","),
|
||||
"paths": strings.Join(p.paths, ","),
|
||||
})
|
||||
|
||||
if cursorStr := request.GetString("cursor", ""); cursorStr != "" {
|
||||
off, hash, decErr := cursor.Decode(cursorStr)
|
||||
if decErr != nil {
|
||||
return nil, mcp.NewToolResultError(fmt.Sprintf("invalid cursor: %s", decErr))
|
||||
}
|
||||
if hash != p.queryHash {
|
||||
return nil, mcp.NewToolResultError("cursor is for a different query, re-run without cursor")
|
||||
}
|
||||
p.offset = off
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// runASTQueryWalk walks the configured paths and collects matches.
|
||||
func (s *Server) runASTQueryWalk(ctx context.Context, p *astQueryParams, exts []string) []query.MatchResult {
|
||||
astQuery := &query.ASTQuery{
|
||||
Pattern: p.pattern,
|
||||
Language: p.language,
|
||||
Filters: query.QueryFilters{
|
||||
NameMatches: p.nameMatches,
|
||||
NameExact: p.nameExact,
|
||||
KindIn: p.kindIn,
|
||||
},
|
||||
}
|
||||
|
||||
// Collect limit for early-exit: when paginating we need all results first.
|
||||
collectLimit := p.maxResults
|
||||
if p.offset > 0 {
|
||||
collectLimit = 0
|
||||
}
|
||||
|
||||
var allResults []query.MatchResult
|
||||
for _, searchPath := range p.paths {
|
||||
if !s.cfg.IsPathAllowed(searchPath) {
|
||||
continue
|
||||
}
|
||||
|
||||
walkErr := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if info.IsDir() {
|
||||
if strings.HasPrefix(info.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if !hasAnySuffix(path, exts) {
|
||||
return nil
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||
return nil
|
||||
}
|
||||
|
||||
result, err := s.parser.Parse(ctx, path, content)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
allResults = append(allResults, matches...)
|
||||
|
||||
if collectLimit > 0 && len(allResults) >= collectLimit {
|
||||
return filepath.SkipAll
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if walkErr != nil {
|
||||
s.logger.Warn("error walking path", "path", searchPath, "error", walkErr)
|
||||
}
|
||||
}
|
||||
return allResults
|
||||
}
|
||||
|
||||
// hasAnySuffix reports whether path ends with any of the given suffixes.
|
||||
func hasAnySuffix(path string, suffixes []string) bool {
|
||||
for _, ext := range suffixes {
|
||||
if strings.HasSuffix(path, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// buildASTCursorFooter computes the cursor footer line for truncated results.
|
||||
func buildASTCursorFooter(total, offset, maxResults int, queryHash string) string {
|
||||
if offset > 0 && offset < total {
|
||||
totalAfterOffset := total - offset
|
||||
if maxResults > 0 && totalAfterOffset > maxResults {
|
||||
remaining := totalAfterOffset - maxResults
|
||||
nextOffset := offset + maxResults
|
||||
nextCursor := cursor.Encode(nextOffset, queryHash)
|
||||
return fmt.Sprintf("[cursor: %s, remaining: %d]", nextCursor, remaining)
|
||||
}
|
||||
} else if offset == 0 && maxResults > 0 && total > maxResults {
|
||||
remaining := total - maxResults
|
||||
nextCursor := cursor.Encode(maxResults, queryHash)
|
||||
return fmt.Sprintf("[cursor: %s, remaining: %d]", nextCursor, remaining)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// handleASTQuery handles the ast_query tool.
|
||||
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// Acquire semaphore to limit concurrent queries (prevents CPU exhaustion)
|
||||
@@ -24,121 +205,33 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest
|
||||
return mcp.NewToolResultError("request cancelled"), nil
|
||||
}
|
||||
|
||||
pattern, err := request.RequireString("pattern")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("pattern is required"), nil
|
||||
p, errResult := s.resolveASTQueryParams(request)
|
||||
if errResult != nil {
|
||||
return errResult, nil
|
||||
}
|
||||
|
||||
language, err := request.RequireString("language")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("language is required"), nil
|
||||
if len(p.paths) == 0 {
|
||||
p.paths = []string{s.cfg.WorkspaceRoot}
|
||||
}
|
||||
|
||||
// Build query
|
||||
astQuery := &query.ASTQuery{
|
||||
Pattern: pattern,
|
||||
Language: language,
|
||||
Filters: query.QueryFilters{
|
||||
NameMatches: request.GetString("name_matches", ""),
|
||||
NameExact: request.GetString("name_exact", ""),
|
||||
KindIn: request.GetStringSlice("kind_in", nil),
|
||||
},
|
||||
}
|
||||
|
||||
maxResults := request.GetInt("max_results", 100)
|
||||
paths := request.GetStringSlice("paths", nil)
|
||||
|
||||
// Default to workspace root if no paths specified
|
||||
if len(paths) == 0 {
|
||||
paths = []string{s.cfg.WorkspaceRoot}
|
||||
}
|
||||
|
||||
// Find files to search based on language
|
||||
exts := languageToExtensions(language)
|
||||
exts := languageToExtensions(p.language)
|
||||
if len(exts) == 0 {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir, rust)", language)), nil
|
||||
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir, rust)", p.language)), nil
|
||||
}
|
||||
|
||||
var allResults []query.MatchResult
|
||||
allResults := s.runASTQueryWalk(ctx, p, exts)
|
||||
cursorFooter := buildASTCursorFooter(len(allResults), p.offset, p.maxResults, p.queryHash)
|
||||
|
||||
// Walk through paths and find matching files
|
||||
for _, searchPath := range paths {
|
||||
// Validate path is within workspace
|
||||
if !s.cfg.IsPathAllowed(searchPath) {
|
||||
continue
|
||||
}
|
||||
|
||||
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil // Skip files with errors
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// Skip hidden directories
|
||||
if strings.HasPrefix(info.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check file extension matches language
|
||||
matched := false
|
||||
for _, ext := range exts {
|
||||
if strings.HasSuffix(path, ext) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read and parse file
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil // Skip unreadable files
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||
return nil // Skip large files
|
||||
}
|
||||
|
||||
// Parse file
|
||||
result, err := s.parser.Parse(ctx, path, content)
|
||||
if err != nil {
|
||||
return nil // Skip unparseable files
|
||||
}
|
||||
|
||||
// Run query
|
||||
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
|
||||
if err != nil {
|
||||
return nil // Skip on error
|
||||
}
|
||||
|
||||
allResults = append(allResults, matches...)
|
||||
|
||||
// Stop if we have enough results
|
||||
if maxResults > 0 && len(allResults) >= maxResults {
|
||||
return filepath.SkipAll
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn("error walking path", "path", searchPath, "error", err)
|
||||
output := query.FormatResultsWithOptions(allResults, p.maxResults, p.format, p.offset, p.verbose)
|
||||
if cursorFooter != "" {
|
||||
// Replace the [remaining: N] placeholder emitted by FormatResultsWithOptions
|
||||
// with the full [cursor: ..., remaining: N] line.
|
||||
output = strings.ReplaceAll(output, fmt.Sprintf("[remaining: %d]\n", len(allResults)-p.offset-p.maxResults), cursorFooter+"\n")
|
||||
// Fallback: if placeholder wasn't present (e.g. format=location), append footer.
|
||||
if !strings.Contains(output, cursorFooter) {
|
||||
output = strings.TrimRight(output, "\n") + "\n" + cursorFooter + "\n"
|
||||
}
|
||||
}
|
||||
|
||||
// Format and return results
|
||||
output := query.FormatResults(allResults, maxResults)
|
||||
return mcp.NewToolResultText(output), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,12 +4,10 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/edit"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
@@ -82,27 +80,56 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest) (*
|
||||
return mcp.NewToolResultError(result.Error), nil
|
||||
}
|
||||
|
||||
// compact_response: return just the modified symbol instead of the full diff
|
||||
if request.GetBool("compact_response", false) && selectorName != "" {
|
||||
if content, readErr := os.ReadFile(file); readErr == nil {
|
||||
if start, end, found := s.resolveSymbolLines(ctx, file, content, selectorName, protocol.SymbolKind("")); found {
|
||||
lines := splitLines(string(content))
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("**Edit Applied** — %s (L%d-L%d):\n\n", selectorName, start, end))
|
||||
for i := start - 1; i < end && i < len(lines); i++ {
|
||||
sb.WriteString(fmt.Sprintf("%4d| %s\n", i+1, lines[i]))
|
||||
}
|
||||
return mcp.NewToolResultText(sb.String()), nil
|
||||
}
|
||||
}
|
||||
// fall through to diff if symbol lookup fails
|
||||
// Determine response mode.
|
||||
// compact_response is a deprecated alias for response="count".
|
||||
respMode := request.GetString("response", "count")
|
||||
if request.GetBool("compact_response", false) {
|
||||
// Deprecated: use response=count
|
||||
respMode = "count"
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("**Edit Applied Successfully**\n\n")
|
||||
output.WriteString("Diff:\n```diff\n")
|
||||
output.WriteString(result.Diff)
|
||||
output.WriteString("```\n")
|
||||
switch respMode {
|
||||
case "none":
|
||||
return mcp.NewToolResultText(""), nil
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
case "count":
|
||||
// Compute +added/-removed line counts from the unified diff.
|
||||
added, removed := countDiffLines(result.Diff)
|
||||
return mcp.NewToolResultText(fmt.Sprintf("+%d -%d", added, removed)), nil
|
||||
|
||||
case "diff":
|
||||
var output strings.Builder
|
||||
output.WriteString("Diff:\n```diff\n")
|
||||
output.WriteString(result.Diff)
|
||||
output.WriteString("```\n")
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
|
||||
default:
|
||||
// Fallback: treat unknown values as "diff" for safety.
|
||||
var output strings.Builder
|
||||
output.WriteString("Diff:\n```diff\n")
|
||||
output.WriteString(result.Diff)
|
||||
output.WriteString("```\n")
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
}
|
||||
|
||||
// countDiffLines counts added (+) and removed (-) lines in a unified diff string.
|
||||
func countDiffLines(diff string) (added, removed int) {
|
||||
for _, line := range strings.Split(diff, "\n") {
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
switch line[0] {
|
||||
case '+':
|
||||
if !strings.HasPrefix(line, "+++") {
|
||||
added++
|
||||
}
|
||||
case '-':
|
||||
if !strings.HasPrefix(line, "---") {
|
||||
removed++
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -6,10 +6,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
xxhash "github.com/cespare/xxhash/v2"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/cursor"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/search"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
@@ -35,14 +37,63 @@ func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolReque
|
||||
return mcp.NewToolResultError("pattern is required"), nil
|
||||
}
|
||||
|
||||
paths := request.GetStringSlice("paths", nil)
|
||||
fileTypes := request.GetStringSlice("file_types", nil)
|
||||
ignoreCase := request.GetBool("ignore_case", false)
|
||||
regex := request.GetBool("regex", true)
|
||||
contextLines := request.GetInt("context_lines", 2)
|
||||
|
||||
// Consult session prefs for max_results and cluster when not explicitly supplied.
|
||||
prefs := s.sessionPrefs.Load()
|
||||
var prefsMaxResults int
|
||||
var prefsCluster *bool
|
||||
if prefs != nil {
|
||||
prefsMaxResults = prefs.DefaultMaxResults
|
||||
prefsCluster = prefs.DefaultCluster
|
||||
}
|
||||
maxResults := effectiveInt(request, "max_results", prefsMaxResults, 0)
|
||||
cluster := effectiveBool(request, "cluster", prefsCluster, false)
|
||||
|
||||
cursorStr := request.GetString("cursor", "")
|
||||
|
||||
// Compute query hash for cursor validation.
|
||||
queryHash := cursor.HashParams(map[string]string{
|
||||
"pattern": pattern,
|
||||
"paths": strings.Join(paths, ","),
|
||||
"file_types": strings.Join(fileTypes, ","),
|
||||
"ignore_case": strconv.FormatBool(ignoreCase),
|
||||
"regex": strconv.FormatBool(regex),
|
||||
"context_lines": strconv.Itoa(contextLines),
|
||||
})
|
||||
|
||||
// Resolve cursor offset.
|
||||
offset := 0
|
||||
if cursorStr != "" {
|
||||
off, hash, decErr := cursor.Decode(cursorStr)
|
||||
if decErr != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("invalid cursor: %s", decErr)), nil
|
||||
}
|
||||
if hash != queryHash {
|
||||
return mcp.NewToolResultError("cursor is for a different query, re-run without cursor"), nil
|
||||
}
|
||||
offset = off
|
||||
}
|
||||
|
||||
// When paginating with a cursor, fetch all results (no rg-level cap) so we
|
||||
// can apply the offset in-process. Without a cursor, let rg cap at maxResults.
|
||||
rgMaxResults := maxResults
|
||||
if offset > 0 {
|
||||
rgMaxResults = 0 // fetch all, apply cap after skipping
|
||||
}
|
||||
|
||||
req := &search.Request{
|
||||
Pattern: pattern,
|
||||
Paths: request.GetStringSlice("paths", nil),
|
||||
FileTypes: request.GetStringSlice("file_types", nil),
|
||||
IgnoreCase: request.GetBool("ignore_case", false),
|
||||
Regex: request.GetBool("regex", true),
|
||||
ContextLines: request.GetInt("context_lines", 2),
|
||||
MaxResults: request.GetInt("max_results", 0),
|
||||
Paths: paths,
|
||||
FileTypes: fileTypes,
|
||||
IgnoreCase: ignoreCase,
|
||||
Regex: regex,
|
||||
ContextLines: contextLines,
|
||||
MaxResults: rgMaxResults,
|
||||
}
|
||||
|
||||
results, err := s.searcher.Search(ctx, req)
|
||||
@@ -51,16 +102,66 @@ func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolReque
|
||||
return mcp.NewToolResultError(fmt.Sprintf("search error: %s", errors.SanitizeError(err))), nil
|
||||
}
|
||||
|
||||
// Apply cursor offset.
|
||||
if offset > 0 && offset < len(results.Results) {
|
||||
results.Results = results.Results[offset:]
|
||||
results.Truncated = false // will re-evaluate below
|
||||
} else if offset > 0 {
|
||||
results.Results = nil
|
||||
results.Truncated = false
|
||||
}
|
||||
|
||||
// Apply in-process max_results cap and compute cursor footer.
|
||||
var cursorLine string
|
||||
if maxResults > 0 && len(results.Results) > maxResults {
|
||||
remaining := len(results.Results) - maxResults
|
||||
results.Results = results.Results[:maxResults]
|
||||
results.Truncated = true
|
||||
nextOffset := offset + maxResults
|
||||
nextCursor := cursor.Encode(nextOffset, queryHash)
|
||||
cursorLine = fmt.Sprintf("[cursor: %s, remaining: %d]", nextCursor, remaining)
|
||||
}
|
||||
|
||||
s.logger.Info("search completed",
|
||||
"pattern", pattern,
|
||||
"results_count", len(results.Results),
|
||||
"truncated", results.Truncated,
|
||||
)
|
||||
|
||||
output := s.searcher.FormatResults(results)
|
||||
verbose := request.GetBool("verbose", false)
|
||||
opts := search.FormatOptions{
|
||||
Cluster: cluster,
|
||||
CursorLine: cursorLine,
|
||||
Verbose: verbose,
|
||||
}
|
||||
output := s.searcher.FormatResultsWithOptions(results, opts)
|
||||
return mcp.NewToolResultText(output), nil
|
||||
}
|
||||
|
||||
// effectiveInt returns the per-call value if the key is explicitly present in the
|
||||
// request arguments, otherwise falls back to sessionDefault (if > 0), then builtIn.
|
||||
func effectiveInt(request mcp.CallToolRequest, key string, sessionDefault, builtIn int) int {
|
||||
if _, explicit := request.GetArguments()[key]; explicit {
|
||||
return request.GetInt(key, builtIn)
|
||||
}
|
||||
if sessionDefault > 0 {
|
||||
return sessionDefault
|
||||
}
|
||||
return builtIn
|
||||
}
|
||||
|
||||
// effectiveBool returns the per-call value if the key is explicitly present in the
|
||||
// request arguments, otherwise falls back to sessionDefault (if non-nil), then builtIn.
|
||||
func effectiveBool(request mcp.CallToolRequest, key string, sessionDefault *bool, builtIn bool) bool {
|
||||
if _, explicit := request.GetArguments()[key]; explicit {
|
||||
return request.GetBool(key, builtIn)
|
||||
}
|
||||
if sessionDefault != nil {
|
||||
return *sessionDefault
|
||||
}
|
||||
return builtIn
|
||||
}
|
||||
|
||||
// handleFileRead handles the file_read tool.
|
||||
func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
select {
|
||||
@@ -70,9 +171,13 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
|
||||
return mcp.NewToolResultError("request cancelled"), nil
|
||||
}
|
||||
|
||||
// Batch mode: paths[] takes precedence over path
|
||||
// Batch mode: paths[] takes precedence over path.
|
||||
// NOTE: batch reads are always inlined — mixing dedup + resource_links is
|
||||
// too complex and the savings are unclear for multi-file calls.
|
||||
if paths := request.GetStringSlice("paths", nil); len(paths) > 0 {
|
||||
var output strings.Builder
|
||||
// Dedup: track etag -> first path that produced it.
|
||||
seenEtag := make(map[string]string) // etag -> first path
|
||||
for i, p := range paths {
|
||||
if i > 0 {
|
||||
output.WriteString("\n")
|
||||
@@ -82,6 +187,16 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
|
||||
output.WriteString(fmt.Sprintf("--- %s ---\n[error: %s]\n", p, errors.SanitizeError(err)))
|
||||
continue
|
||||
}
|
||||
// Extract etag from result footer for dedup check.
|
||||
etag := extractEtag(result)
|
||||
if etag != "" {
|
||||
if firstPath, seen := seenEtag[etag]; seen {
|
||||
// Duplicate content: emit pointer instead of full content.
|
||||
output.WriteString(fmt.Sprintf("--- %s ---\n[duplicate of %s, etag: %s]\n", p, firstPath, etag))
|
||||
continue
|
||||
}
|
||||
seenEtag[etag] = p
|
||||
}
|
||||
output.WriteString(fmt.Sprintf("--- %s ---\n%s", p, result))
|
||||
}
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
@@ -96,59 +211,223 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(errors.SanitizeError(err)), nil
|
||||
}
|
||||
|
||||
// Resource-link threshold check: for single-file reads, if the result
|
||||
// exceeds the configured threshold, return a ResourceLink instead of
|
||||
// inlining the content. The client can fetch the resource on demand.
|
||||
// Bypassed when:
|
||||
// - force_inline=true
|
||||
// - max_inline_bytes is set and result fits within it
|
||||
// - threshold is 0 (disabled)
|
||||
// - result is already small (skeleton/symbols_only/line-range paths
|
||||
// produce small output; threshold is on result bytes, not file bytes)
|
||||
// Determine resource-link threshold: session pref overrides cfg, per-call overrides session.
|
||||
threshold := s.cfg.ResourceLinkThresholdBytes
|
||||
if sp := s.sessionPrefs.Load(); sp != nil && sp.ResourceLinkThreshold > 0 {
|
||||
threshold = sp.ResourceLinkThreshold
|
||||
}
|
||||
forceInline := request.GetBool("force_inline", false)
|
||||
maxInlineBytes := request.GetInt("max_inline_bytes", 0)
|
||||
if maxInlineBytes > 0 {
|
||||
threshold = maxInlineBytes
|
||||
}
|
||||
|
||||
if !forceInline && threshold > 0 && len(result) > threshold {
|
||||
etag := extractEtag(result)
|
||||
uri := buildReadResourceURI(path, etag)
|
||||
lineCount := strings.Count(result, "\n")
|
||||
desc := fmt.Sprintf("etag=%s, size=%d bytes, lines=%d", etag, len(result), lineCount)
|
||||
mimeType := detectMIMEType(path)
|
||||
link := mcp.NewResourceLink(uri, path, desc, mimeType)
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{link},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(result), nil
|
||||
}
|
||||
|
||||
// readOneFile reads a single file applying all formatting options from the request.
|
||||
func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, path string) (string, error) {
|
||||
// buildReadResourceURI constructs the filepuff://read URI for a file + etag pair.
|
||||
func buildReadResourceURI(path, etag string) string {
|
||||
if etag == "" {
|
||||
return "filepuff://read/" + path
|
||||
}
|
||||
return "filepuff://read/" + path + "?etag=" + etag
|
||||
}
|
||||
|
||||
// detectMIMEType returns a best-effort MIME type for the given file path.
|
||||
func detectMIMEType(path string) string {
|
||||
ext := strings.ToLower(path)
|
||||
switch {
|
||||
case strings.HasSuffix(ext, ".go"):
|
||||
return "text/x-go"
|
||||
case strings.HasSuffix(ext, ".ts"), strings.HasSuffix(ext, ".tsx"):
|
||||
return "text/typescript"
|
||||
case strings.HasSuffix(ext, ".js"), strings.HasSuffix(ext, ".jsx"):
|
||||
return "text/javascript"
|
||||
case strings.HasSuffix(ext, ".py"):
|
||||
return "text/x-python"
|
||||
case strings.HasSuffix(ext, ".rs"):
|
||||
return "text/x-rust"
|
||||
case strings.HasSuffix(ext, ".md"):
|
||||
return "text/markdown"
|
||||
case strings.HasSuffix(ext, ".json"):
|
||||
return "application/json"
|
||||
case strings.HasSuffix(ext, ".yaml"), strings.HasSuffix(ext, ".yml"):
|
||||
return "text/yaml"
|
||||
case strings.HasSuffix(ext, ".toml"):
|
||||
return "text/toml"
|
||||
case strings.HasSuffix(ext, ".html"), strings.HasSuffix(ext, ".htm"):
|
||||
return "text/html"
|
||||
case strings.HasSuffix(ext, ".css"):
|
||||
return "text/css"
|
||||
case strings.HasSuffix(ext, ".sh"):
|
||||
return "text/x-sh"
|
||||
case strings.HasSuffix(ext, ".c"), strings.HasSuffix(ext, ".h"):
|
||||
return "text/x-c"
|
||||
case strings.HasSuffix(ext, ".cpp"), strings.HasSuffix(ext, ".cc"), strings.HasSuffix(ext, ".cxx"):
|
||||
return "text/x-c++"
|
||||
default:
|
||||
return "text/plain"
|
||||
}
|
||||
}
|
||||
|
||||
// lineNumberOpts holds resolved line-numbering preferences for readOneFile.
|
||||
type lineNumberOpts struct {
|
||||
noLineNumbers bool
|
||||
compactLineNums bool
|
||||
lineInterval int
|
||||
}
|
||||
|
||||
// resolveLineNumberOpts resolves per-call vs session-pref line-number options.
|
||||
func (s *Server) resolveLineNumberOpts(request mcp.CallToolRequest) lineNumberOpts {
|
||||
opts := lineNumberOpts{
|
||||
noLineNumbers: request.GetBool("no_line_numbers", false),
|
||||
lineInterval: request.GetInt("line_number_interval", 1),
|
||||
compactLineNums: request.GetBool("compact_line_numbers", false),
|
||||
}
|
||||
|
||||
// Apply session line_numbers pref when no explicit per-call override was supplied.
|
||||
if sp := s.sessionPrefs.Load(); sp != nil && sp.LineNumbers != "" {
|
||||
_, hasNoLN := request.GetArguments()["no_line_numbers"]
|
||||
_, hasCompact := request.GetArguments()["compact_line_numbers"]
|
||||
_, hasInterval := request.GetArguments()["line_number_interval"]
|
||||
if !hasNoLN && !hasCompact && !hasInterval {
|
||||
switch sp.LineNumbers {
|
||||
case "none":
|
||||
opts.noLineNumbers = true
|
||||
opts.compactLineNums = false
|
||||
case "compact":
|
||||
opts.noLineNumbers = false
|
||||
opts.compactLineNums = true
|
||||
case "full":
|
||||
opts.noLineNumbers = false
|
||||
opts.compactLineNums = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if opts.lineInterval == 0 {
|
||||
opts.noLineNumbers = true
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// applyStrip applies strip flags to the selected line range and returns the
|
||||
// possibly-rewritten lines, new bounds, and a stripped-footer annotation.
|
||||
func applyStrip(lines []string, lineStart, lineEnd int, stripFlags []parser.StripFlag, path string) (newLines []string, newStart, newEnd int, footer string) {
|
||||
if len(stripFlags) == 0 {
|
||||
return lines, lineStart, lineEnd, ""
|
||||
}
|
||||
selectedContent := strings.Join(lines[lineStart-1:lineEnd], "\n")
|
||||
lang := protocol.DetectLanguage(path)
|
||||
stripped := parser.StripContent(selectedContent, stripFlags, lang)
|
||||
if len(stripped.Stripped) > 0 {
|
||||
names := make([]string, len(stripped.Stripped))
|
||||
for i, f := range stripped.Stripped {
|
||||
names[i] = string(f)
|
||||
}
|
||||
footer = "[stripped: " + strings.Join(names, ", ") + "]\n"
|
||||
}
|
||||
newLines = splitLines(stripped.Content)
|
||||
return newLines, 1, len(newLines), footer
|
||||
}
|
||||
|
||||
// loadFileForRead performs workspace, stat, size, and read checks for a path.
|
||||
func (s *Server) loadFileForRead(path string) ([]byte, error) {
|
||||
if !s.cfg.IsPathAllowed(path) {
|
||||
return "", fmt.Errorf("path is outside workspace root")
|
||||
return nil, fmt.Errorf("path is outside workspace root")
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("file not found: %s", path)
|
||||
return nil, fmt.Errorf("file not found: %s", path)
|
||||
}
|
||||
if os.IsPermission(err) {
|
||||
return "", fmt.Errorf("permission denied: %s", path)
|
||||
return nil, fmt.Errorf("permission denied: %s", path)
|
||||
}
|
||||
s.logger.Warn("file stat error", "path", path, "error", err)
|
||||
return "", fmt.Errorf("error accessing file")
|
||||
return nil, fmt.Errorf("error accessing file")
|
||||
}
|
||||
if info.Size() > s.cfg.MaxFileSize {
|
||||
return "", fmt.Errorf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize)
|
||||
return nil, fmt.Errorf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsPermission(err) {
|
||||
return "", fmt.Errorf("permission denied: %s", path)
|
||||
return nil, fmt.Errorf("permission denied: %s", path)
|
||||
}
|
||||
s.logger.Warn("file read error", "path", path, "error", err)
|
||||
return "", fmt.Errorf("error reading file")
|
||||
return nil, fmt.Errorf("error reading file")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// readOneFile reads a single file applying all formatting options from the request.
|
||||
func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, path string) (string, error) {
|
||||
content, err := s.loadFileForRead(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Compute etag from content hash
|
||||
etag := fmt.Sprintf("%016x", xxhash.Sum64(content))
|
||||
// Feature 3: short etag — 8 hex chars (32-bit).
|
||||
// Accept previous_etag by prefix match so old 16-char etags keep working.
|
||||
fullHash := fmt.Sprintf("%016x", xxhash.Sum64(content))
|
||||
etag := fullHash[:8]
|
||||
|
||||
// Short-circuit if caller has the current version
|
||||
if prev := request.GetString("previous_etag", ""); prev != "" && prev == etag {
|
||||
return fmt.Sprintf("[unchanged, etag: %s]\n", etag), nil
|
||||
if prev := request.GetString("previous_etag", ""); prev != "" {
|
||||
// Match: exact 8-char match, old client sent full 16-char etag, or new client sent 8-char prefix of old.
|
||||
if prev == etag || strings.HasPrefix(fullHash, prev) || strings.HasPrefix(prev, etag) {
|
||||
return fmt.Sprintf("[unchanged, etag: %s]\n", etag), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Parse request options
|
||||
includeAST := request.GetBool("include_ast", false)
|
||||
symbolsOnly := request.GetBool("symbols_only", false)
|
||||
symbolName := request.GetString("symbol_name", "")
|
||||
noLineNumbers := request.GetBool("no_line_numbers", false)
|
||||
lineInterval := request.GetInt("line_number_interval", 1)
|
||||
collapseBlank := request.GetBool("collapse_blank_lines", false)
|
||||
maxLines := request.GetInt("max_lines", 0)
|
||||
|
||||
if lineInterval == 0 {
|
||||
noLineNumbers = true
|
||||
lnOpts := s.resolveLineNumberOpts(request)
|
||||
|
||||
// Feature 1: mode flag — "full" (default) | "skeleton" | "symbols_only".
|
||||
// symbols_only mode is an alias for include_ast+symbols_only.
|
||||
mode := request.GetString("mode", "full")
|
||||
if mode == "symbols_only" {
|
||||
symbolsOnly = true
|
||||
includeAST = true
|
||||
}
|
||||
|
||||
// Feature 2: strip — remove selected content classes before line-numbering.
|
||||
stripRaw := request.GetStringSlice("strip", nil)
|
||||
var stripFlags []parser.StripFlag
|
||||
for _, sf := range stripRaw {
|
||||
stripFlags = append(stripFlags, parser.StripFlag(sf))
|
||||
}
|
||||
|
||||
if symbolsOnly && !includeAST {
|
||||
return "", fmt.Errorf("symbols_only requires include_ast=true")
|
||||
}
|
||||
@@ -157,7 +436,7 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p
|
||||
lineStart := request.GetInt("line_start", 1)
|
||||
lineEnd := request.GetInt("line_end", len(lines))
|
||||
|
||||
// Symbol-based line range: find the symbol and use its exact bounds
|
||||
// Symbol-based line range: find the symbol and use its exact bounds.
|
||||
if symbolName != "" {
|
||||
symbolKind := protocol.SymbolKind(request.GetString("symbol_kind", ""))
|
||||
start, end, found := s.resolveSymbolLines(ctx, path, content, symbolName, symbolKind)
|
||||
@@ -168,7 +447,7 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p
|
||||
lineEnd = end
|
||||
}
|
||||
|
||||
// Clamp to valid range
|
||||
// Clamp to valid range.
|
||||
if lineStart < 1 {
|
||||
lineStart = 1
|
||||
}
|
||||
@@ -181,6 +460,19 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p
|
||||
|
||||
var output strings.Builder
|
||||
|
||||
// Feature 1: skeleton mode — replace function bodies with { ... }.
|
||||
if mode == "skeleton" {
|
||||
skText, _, skErr := parser.SkeletonFile(ctx, s.parser, path, content)
|
||||
if skErr != nil {
|
||||
// Fall back to full mode on parse error.
|
||||
s.logger.Warn("skeleton mode failed, falling back to full", "path", path, "error", skErr)
|
||||
} else {
|
||||
output.WriteString(skText)
|
||||
fmt.Fprintf(&output, "[etag: %s]\n", etag)
|
||||
return output.String(), nil
|
||||
}
|
||||
}
|
||||
|
||||
if includeAST {
|
||||
if summary := s.generateASTSummary(ctx, path, content); summary != "" {
|
||||
output.WriteString(summary)
|
||||
@@ -191,13 +483,19 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p
|
||||
}
|
||||
|
||||
if symbolsOnly {
|
||||
output.WriteString(fmt.Sprintf("[etag: %s]\n", etag))
|
||||
fmt.Fprintf(&output, "[etag: %s]\n", etag)
|
||||
return output.String(), nil
|
||||
}
|
||||
|
||||
writeLines(&output, lines, lineStart, lineEnd, maxLines, noLineNumbers, lineInterval, collapseBlank)
|
||||
// Feature 2: apply strip AFTER line-range selection, BEFORE line numbering.
|
||||
lines, lineStart, lineEnd, strippedFooter := applyStrip(lines, lineStart, lineEnd, stripFlags, path)
|
||||
|
||||
output.WriteString(fmt.Sprintf("[etag: %s]\n", etag))
|
||||
writeLines(&output, lines, lineStart, lineEnd, maxLines, lnOpts.noLineNumbers, lnOpts.lineInterval, collapseBlank, lnOpts.compactLineNums)
|
||||
|
||||
if strippedFooter != "" {
|
||||
output.WriteString(strippedFooter)
|
||||
}
|
||||
fmt.Fprintf(&output, "[etag: %s]\n", etag)
|
||||
return output.String(), nil
|
||||
}
|
||||
|
||||
@@ -212,7 +510,8 @@ func (s *Server) resolveSymbolLines(ctx context.Context, path string, content []
|
||||
}
|
||||
|
||||
// writeLines writes the selected line range into output, applying all formatting options.
|
||||
func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, maxLines int, noLineNumbers bool, lineInterval int, collapseBlank bool) {
|
||||
// compactLineNums=true emits "12│" instead of " 12│ " (no padding, no trailing space).
|
||||
func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, maxLines int, noLineNumbers bool, lineInterval int, collapseBlank bool, compactLineNums bool) {
|
||||
effectiveEnd := lineEnd
|
||||
truncatedCount := 0
|
||||
if maxLines > 0 && (lineEnd-lineStart+1) > maxLines {
|
||||
@@ -233,6 +532,13 @@ func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, max
|
||||
switch {
|
||||
case noLineNumbers:
|
||||
output.WriteString(line + "\n")
|
||||
case compactLineNums:
|
||||
// Feature 4: compact prefix — "12│content"
|
||||
if lineInterval <= 1 || lineNum%lineInterval == 0 || i == lineStart-1 || i == effectiveEnd-1 {
|
||||
fmt.Fprintf(output, "%d│%s\n", lineNum, line)
|
||||
} else {
|
||||
fmt.Fprintf(output, "│%s\n", line)
|
||||
}
|
||||
case lineInterval <= 1 || lineNum%lineInterval == 0 || i == lineStart-1 || i == effectiveEnd-1:
|
||||
fmt.Fprintf(output, "%4d│ %s\n", lineNum, line)
|
||||
default:
|
||||
@@ -245,6 +551,23 @@ func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, max
|
||||
}
|
||||
}
|
||||
|
||||
// extractEtag extracts the etag value from a readOneFile result string.
|
||||
// Returns empty string if not found.
|
||||
func extractEtag(result string) string {
|
||||
// Look for "[etag: XXXXXXXX]" at end of result.
|
||||
const prefix = "[etag: "
|
||||
idx := strings.LastIndex(result, prefix)
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
rest := result[idx+len(prefix):]
|
||||
close := strings.Index(rest, "]")
|
||||
if close < 0 {
|
||||
return ""
|
||||
}
|
||||
return rest[:close]
|
||||
}
|
||||
|
||||
// splitLines splits a string into lines.
|
||||
// For large files (> 1MB), uses bufio.Scanner which is more memory efficient.
|
||||
// For smaller files, uses simple string split which is faster.
|
||||
|
||||
@@ -0,0 +1,909 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// newTestServer creates a minimal server pointing at tmpDir.
|
||||
func newTestServer(t *testing.T, tmpDir string) *Server {
|
||||
t.Helper()
|
||||
cfg := &config.Config{
|
||||
WorkspaceRoot: tmpDir,
|
||||
EnableLSP: false,
|
||||
MaxFileSize: 10 * 1024 * 1024, // 10 MB — required for file reads to succeed
|
||||
MaxParseSize: 10 * 1024 * 1024,
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
// callRead calls handleFileRead with the given args map and returns the text content.
|
||||
func callRead(t *testing.T, srv *Server, args map[string]interface{}) string {
|
||||
t.Helper()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = args
|
||||
result, err := srv.handleFileRead(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleFileRead error: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("handleFileRead returned nil")
|
||||
}
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("handleFileRead returned empty content")
|
||||
}
|
||||
tc, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("handleFileRead did not return TextContent")
|
||||
}
|
||||
return tc.Text
|
||||
}
|
||||
|
||||
// callReadResult calls handleFileRead and returns the raw CallToolResult (not just text).
|
||||
func callReadResult(t *testing.T, srv *Server, args map[string]interface{}) *mcp.CallToolResult {
|
||||
t.Helper()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = args
|
||||
result, err := srv.handleFileRead(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleFileRead error: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("handleFileRead returned nil")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// newTestServerWithThreshold creates a server with a custom ResourceLinkThresholdBytes.
|
||||
func newTestServerWithThreshold(t *testing.T, tmpDir string, thresholdBytes int) *Server {
|
||||
t.Helper()
|
||||
cfg := &config.Config{
|
||||
WorkspaceRoot: tmpDir,
|
||||
EnableLSP: false,
|
||||
MaxFileSize: 10 * 1024 * 1024,
|
||||
MaxParseSize: 10 * 1024 * 1024,
|
||||
ResourceLinkThresholdBytes: thresholdBytes,
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
// writeFile writes content to a file in tmpDir and returns its absolute path.
|
||||
func writeFile(t *testing.T, dir, name, content string) string {
|
||||
t.Helper()
|
||||
p := filepath.Join(dir, name)
|
||||
if err := os.WriteFile(p, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("WriteFile(%s): %v", name, err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ---- Feature 1: skeleton mode ----
|
||||
|
||||
func TestSkeletonModeGo(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := `package main
|
||||
|
||||
// Hello says hello
|
||||
func Hello() {
|
||||
println("Hello, World!")
|
||||
println("more body")
|
||||
}
|
||||
|
||||
func Add(a, b int) int {
|
||||
return a + b
|
||||
}
|
||||
`
|
||||
f := writeFile(t, tmpDir, "test.go", goSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"mode": "skeleton",
|
||||
})
|
||||
|
||||
// Should contain function signatures
|
||||
if !strings.Contains(out, "func Hello()") {
|
||||
t.Errorf("skeleton output missing Hello signature, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "func Add(") {
|
||||
t.Errorf("skeleton output missing Add signature, got:\n%s", out)
|
||||
}
|
||||
// Should NOT contain body contents
|
||||
if strings.Contains(out, `println("more body")`) {
|
||||
t.Errorf("skeleton output should not contain body contents, got:\n%s", out)
|
||||
}
|
||||
// Should contain placeholder
|
||||
if !strings.Contains(out, "{ ... }") {
|
||||
t.Errorf("skeleton output missing { ... } placeholder, got:\n%s", out)
|
||||
}
|
||||
// Should contain etag footer
|
||||
if !strings.Contains(out, "[etag:") {
|
||||
t.Errorf("skeleton output missing etag footer, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkeletonModeFullFlagAlias(t *testing.T) {
|
||||
// mode="full" should behave identically to not specifying mode
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := "package main\nfunc F() { println(1) }\n"
|
||||
f := writeFile(t, tmpDir, "test.go", goSrc)
|
||||
|
||||
outFull := callRead(t, srv, map[string]interface{}{"path": f, "mode": "full"})
|
||||
outDefault := callRead(t, srv, map[string]interface{}{"path": f})
|
||||
|
||||
// Both should have same content (etag will be same, line content same)
|
||||
if outFull != outDefault {
|
||||
t.Errorf("mode=full differs from default\nfull: %q\ndefault: %q", outFull, outDefault)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkeletonModeSymbolsOnlyAlias(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := "package main\nfunc F() { println(1) }\n"
|
||||
f := writeFile(t, tmpDir, "test.go", goSrc)
|
||||
|
||||
// mode=symbols_only should return symbols summary (needs include_ast implicitly)
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"mode": "symbols_only",
|
||||
})
|
||||
// Should contain etag but NOT the function body
|
||||
if !strings.Contains(out, "[etag:") {
|
||||
t.Errorf("symbols_only output missing etag, got:\n%s", out)
|
||||
}
|
||||
if strings.Contains(out, "println") {
|
||||
t.Errorf("symbols_only should not contain body, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkeletonModeTypeScript(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
tsSrc := `// a function
|
||||
function greet(name: string): string {
|
||||
return "Hello " + name;
|
||||
}
|
||||
|
||||
class Greeter {
|
||||
greet(name: string) {
|
||||
return "hi " + name;
|
||||
}
|
||||
}
|
||||
`
|
||||
f := writeFile(t, tmpDir, "test.ts", tsSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{"path": f, "mode": "skeleton"})
|
||||
|
||||
if !strings.Contains(out, "function greet") {
|
||||
t.Errorf("TS skeleton missing function signature, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "{ ... }") {
|
||||
t.Errorf("TS skeleton missing placeholder, got:\n%s", out)
|
||||
}
|
||||
if strings.Contains(out, `"Hello " + name`) {
|
||||
t.Errorf("TS skeleton should not contain body, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkeletonModePython(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
pySrc := `def greet(name):
|
||||
print("Hello " + name)
|
||||
print("extra line")
|
||||
|
||||
class Foo:
|
||||
def bar(self):
|
||||
return 42
|
||||
`
|
||||
f := writeFile(t, tmpDir, "test.py", pySrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{"path": f, "mode": "skeleton"})
|
||||
|
||||
if !strings.Contains(out, "def greet") {
|
||||
t.Errorf("Python skeleton missing greet signature, got:\n%s", out)
|
||||
}
|
||||
if strings.Contains(out, "extra line") {
|
||||
t.Errorf("Python skeleton should not contain body, got:\n%s", out)
|
||||
}
|
||||
// Python uses "..." as placeholder
|
||||
if !strings.Contains(out, "...") {
|
||||
t.Errorf("Python skeleton missing ... placeholder, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkeletonModeRust(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
rsSrc := `fn add(a: i32, b: i32) -> i32 {
|
||||
let result = a + b;
|
||||
result
|
||||
}
|
||||
|
||||
struct Foo {
|
||||
x: i32,
|
||||
}
|
||||
`
|
||||
f := writeFile(t, tmpDir, "test.rs", rsSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{"path": f, "mode": "skeleton"})
|
||||
|
||||
if !strings.Contains(out, "fn add(") {
|
||||
t.Errorf("Rust skeleton missing fn signature, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "{ ... }") {
|
||||
t.Errorf("Rust skeleton missing placeholder, got:\n%s", out)
|
||||
}
|
||||
if strings.Contains(out, "let result") {
|
||||
t.Errorf("Rust skeleton should not contain body, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Feature 2: strip flag ----
|
||||
|
||||
func TestStripImportsGo(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := `package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("hello")
|
||||
}
|
||||
`
|
||||
f := writeFile(t, tmpDir, "test.go", goSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"strip": []interface{}{"imports"},
|
||||
})
|
||||
|
||||
if strings.Contains(out, `"fmt"`) {
|
||||
t.Errorf("strip=imports should remove import block, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "func main") {
|
||||
t.Errorf("strip=imports should keep function, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "[stripped: imports]") {
|
||||
t.Errorf("strip footer missing, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripLicenseGoBlockComment(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := `/* Copyright 2024 Acme Corp. All rights reserved.
|
||||
License: MIT
|
||||
*/
|
||||
package main
|
||||
|
||||
func main() {}
|
||||
`
|
||||
f := writeFile(t, tmpDir, "main.go", goSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"strip": []interface{}{"license"},
|
||||
})
|
||||
|
||||
if strings.Contains(out, "Copyright") {
|
||||
t.Errorf("strip=license should remove license comment, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "func main") {
|
||||
t.Errorf("strip=license should keep code, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "[stripped: license]") {
|
||||
t.Errorf("license strip footer missing, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripBlockCommentsGo(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := `package main
|
||||
|
||||
/* This is a block comment
|
||||
spanning multiple lines */
|
||||
func main() {
|
||||
/* inline block */
|
||||
println("hi")
|
||||
}
|
||||
`
|
||||
f := writeFile(t, tmpDir, "test.go", goSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"strip": []interface{}{"block_comments"},
|
||||
})
|
||||
|
||||
if strings.Contains(out, "This is a block comment") {
|
||||
t.Errorf("strip=block_comments should remove block comments, got:\n%s", out)
|
||||
}
|
||||
if strings.Contains(out, "inline block") {
|
||||
t.Errorf("strip=block_comments should remove inline block comment, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "func main") {
|
||||
t.Errorf("strip=block_comments should keep code, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "[stripped: block_comments]") {
|
||||
t.Errorf("block_comments strip footer missing, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripMultipleFlags(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := `/* Copyright 2024. License: MIT */
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Println("hello")
|
||||
}
|
||||
`
|
||||
f := writeFile(t, tmpDir, "main.go", goSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"strip": []interface{}{"license", "imports"},
|
||||
})
|
||||
|
||||
if strings.Contains(out, "Copyright") {
|
||||
t.Errorf("license not stripped, got:\n%s", out)
|
||||
}
|
||||
if strings.Contains(out, `"fmt"`) {
|
||||
t.Errorf("imports not stripped, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "[stripped:") {
|
||||
t.Errorf("strip footer missing, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripNoRemovalProducesNoFooter(t *testing.T) {
|
||||
// A file with no imports: strip=imports should not add footer
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := "package main\nfunc main() {}\n"
|
||||
f := writeFile(t, tmpDir, "test.go", goSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"strip": []interface{}{"imports"},
|
||||
})
|
||||
|
||||
if strings.Contains(out, "[stripped:") {
|
||||
t.Errorf("should not have stripped footer when nothing removed, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripImportsPython(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
pySrc := `import os
|
||||
from sys import argv
|
||||
|
||||
def main():
|
||||
print(argv[0])
|
||||
`
|
||||
f := writeFile(t, tmpDir, "test.py", pySrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"strip": []interface{}{"imports"},
|
||||
})
|
||||
|
||||
if strings.Contains(out, "import os") {
|
||||
t.Errorf("Python imports not stripped, got:\n%s", out)
|
||||
}
|
||||
if strings.Contains(out, "from sys") {
|
||||
t.Errorf("Python from-import not stripped, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "def main") {
|
||||
t.Errorf("Python function missing after strip, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Feature 3: short etag (8 hex chars) ----
|
||||
|
||||
func TestEtagIs8Chars(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
f := writeFile(t, tmpDir, "test.go", "package main\n")
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{"path": f})
|
||||
|
||||
// Find "[etag: XXXXXXXX]"
|
||||
idx := strings.Index(out, "[etag: ")
|
||||
if idx < 0 {
|
||||
t.Fatalf("no etag in output: %q", out)
|
||||
}
|
||||
rest := out[idx+7:]
|
||||
end := strings.Index(rest, "]")
|
||||
if end < 0 {
|
||||
t.Fatalf("malformed etag in output: %q", out)
|
||||
}
|
||||
etagVal := rest[:end]
|
||||
if len(etagVal) != 8 {
|
||||
t.Errorf("etag should be 8 hex chars, got %d chars: %q", len(etagVal), etagVal)
|
||||
}
|
||||
// Validate hex
|
||||
for _, c := range etagVal {
|
||||
isDigit := c >= '0' && c <= '9'
|
||||
isHexLower := c >= 'a' && c <= 'f'
|
||||
if !isDigit && !isHexLower {
|
||||
t.Errorf("etag contains non-hex char %q in %q", c, etagVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEtagPreviousEtagShortCircuit(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
f := writeFile(t, tmpDir, "test.go", "package main\n")
|
||||
|
||||
// First read: get the etag
|
||||
out1 := callRead(t, srv, map[string]interface{}{"path": f})
|
||||
idx := strings.Index(out1, "[etag: ")
|
||||
etag := out1[idx+7 : idx+7+8]
|
||||
|
||||
// Second read with same etag: should short-circuit
|
||||
out2 := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"previous_etag": etag,
|
||||
})
|
||||
if !strings.Contains(out2, "[unchanged, etag:") {
|
||||
t.Errorf("expected [unchanged, etag:] for same etag, got: %q", out2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEtagOldLongEtagStillWorks(t *testing.T) {
|
||||
// Simulate old client sending 16-char etag: should still short-circuit via prefix match.
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
f := writeFile(t, tmpDir, "test.go", "package main\n")
|
||||
|
||||
// Get the 8-char etag
|
||||
out1 := callRead(t, srv, map[string]interface{}{"path": f})
|
||||
idx := strings.Index(out1, "[etag: ")
|
||||
shortEtag := out1[idx+7 : idx+7+8]
|
||||
|
||||
// Construct a fake 16-char etag that starts with the short one
|
||||
fakeOldEtag := shortEtag + "00000000"
|
||||
|
||||
out2 := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"previous_etag": fakeOldEtag,
|
||||
})
|
||||
if !strings.Contains(out2, "[unchanged, etag:") {
|
||||
t.Errorf("old 16-char etag should still short-circuit, got: %q", out2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEtagDifferentFileChanges(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
f := writeFile(t, tmpDir, "test.go", "package main\n")
|
||||
|
||||
out1 := callRead(t, srv, map[string]interface{}{"path": f})
|
||||
idx := strings.Index(out1, "[etag: ")
|
||||
etag1 := out1[idx+7 : idx+7+8]
|
||||
|
||||
// Modify the file
|
||||
if err := os.WriteFile(f, []byte("package main\n// changed\n"), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read again with old etag: should NOT short-circuit
|
||||
out2 := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"previous_etag": etag1,
|
||||
})
|
||||
if strings.Contains(out2, "[unchanged") {
|
||||
t.Errorf("modified file should not return unchanged, got: %q", out2)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Feature 4: compact_line_numbers ----
|
||||
|
||||
func TestCompactLineNumbers(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
content := "line one\nline two\nline three\n"
|
||||
f := writeFile(t, tmpDir, "test.txt", content)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"compact_line_numbers": true,
|
||||
})
|
||||
|
||||
// Should have "1│" not " 1│ "
|
||||
if !strings.Contains(out, "1│line one") {
|
||||
t.Errorf("compact prefix not found, got:\n%s", out)
|
||||
}
|
||||
// Should NOT have padded format
|
||||
if strings.Contains(out, " 1│ line one") {
|
||||
t.Errorf("compact should not have padded prefix, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLineNumbersOffByDefault(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
content := "line one\nline two\n"
|
||||
f := writeFile(t, tmpDir, "test.txt", content)
|
||||
|
||||
// Default: no compact_line_numbers
|
||||
out := callRead(t, srv, map[string]interface{}{"path": f})
|
||||
|
||||
// Should have padded format
|
||||
if !strings.Contains(out, " 1│ line one") {
|
||||
t.Errorf("default should have padded format, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLineNumbersWithInterval(t *testing.T) {
|
||||
// compact_line_numbers + line_number_interval should work together
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
var sb strings.Builder
|
||||
for i := 1; i <= 10; i++ {
|
||||
sb.WriteString("line\n")
|
||||
}
|
||||
f := writeFile(t, tmpDir, "test.txt", sb.String())
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"compact_line_numbers": true,
|
||||
"line_number_interval": 5,
|
||||
})
|
||||
|
||||
// Line 5 should have number prefix
|
||||
if !strings.Contains(out, "5│line") {
|
||||
t.Errorf("compact+interval: line 5 should have number, got:\n%s", out)
|
||||
}
|
||||
// Line 1 (first) should have number
|
||||
if !strings.Contains(out, "1│line") {
|
||||
t.Errorf("compact+interval: line 1 should have number, got:\n%s", out)
|
||||
}
|
||||
// Line 10 (last) should have number
|
||||
if !strings.Contains(out, "10│line") {
|
||||
t.Errorf("compact+interval: line 10 should have number, got:\n%s", out)
|
||||
}
|
||||
// Non-interval line should have bare │ prefix (no number)
|
||||
if !strings.Contains(out, "│line") {
|
||||
t.Errorf("compact+interval: non-interval lines should have bare │, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLineNumbersWithNoLineNumbers(t *testing.T) {
|
||||
// compact_line_numbers + no_line_numbers: no_line_numbers wins
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
f := writeFile(t, tmpDir, "test.txt", "line one\n")
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"compact_line_numbers": true,
|
||||
"no_line_numbers": true,
|
||||
})
|
||||
|
||||
// Should have NO prefix at all
|
||||
if strings.Contains(out, "│") {
|
||||
t.Errorf("no_line_numbers should suppress all prefixes, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Backward compatibility: existing behavior unchanged ----
|
||||
|
||||
func TestDefaultBehaviorUnchanged(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := `package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
`
|
||||
f := writeFile(t, tmpDir, "test.go", goSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{"path": f})
|
||||
|
||||
// All lines present
|
||||
if !strings.Contains(out, `println("hello")`) {
|
||||
t.Errorf("full mode missing body, got:\n%s", out)
|
||||
}
|
||||
// Padded line numbers
|
||||
if !strings.Contains(out, " 1│ package main") {
|
||||
t.Errorf("default should have padded line numbers, got:\n%s", out)
|
||||
}
|
||||
// etag present
|
||||
if !strings.Contains(out, "[etag:") {
|
||||
t.Errorf("missing etag footer, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSymbolsOnlyFlagStillWorks(t *testing.T) {
|
||||
// Old symbols_only=true + include_ast=true should still work
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServer(t, tmpDir)
|
||||
|
||||
goSrc := "package main\nfunc Hello() { println(1) }\n"
|
||||
f := writeFile(t, tmpDir, "test.go", goSrc)
|
||||
|
||||
out := callRead(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"include_ast": true,
|
||||
"symbols_only": true,
|
||||
})
|
||||
|
||||
if strings.Contains(out, "println") {
|
||||
t.Errorf("symbols_only should suppress body, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "[etag:") {
|
||||
t.Errorf("symbols_only missing etag, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Feature: resource_link for big reads ----
|
||||
|
||||
func TestResourceLinkThresholdTrip(t *testing.T) {
|
||||
// When result bytes > threshold, handleFileRead returns a ResourceLink content block.
|
||||
tmpDir := t.TempDir()
|
||||
// Low threshold (10 bytes) guarantees even a tiny file trips it.
|
||||
srv := newTestServerWithThreshold(t, tmpDir, 10)
|
||||
|
||||
f := writeFile(t, tmpDir, "big.txt", strings.Repeat("x", 200))
|
||||
|
||||
result := callReadResult(t, srv, map[string]interface{}{"path": f})
|
||||
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("expected content, got none")
|
||||
}
|
||||
link, ok := result.Content[0].(mcp.ResourceLink)
|
||||
if !ok {
|
||||
t.Fatalf("expected ResourceLink content, got %T", result.Content[0])
|
||||
}
|
||||
if !strings.HasPrefix(link.URI, "filepuff://read/") {
|
||||
t.Errorf("ResourceLink URI should start with filepuff://read/, got: %q", link.URI)
|
||||
}
|
||||
if link.Name != f {
|
||||
t.Errorf("ResourceLink Name should be file path %q, got %q", f, link.Name)
|
||||
}
|
||||
if !strings.Contains(link.Description, "etag=") {
|
||||
t.Errorf("ResourceLink Description should contain etag=, got %q", link.Description)
|
||||
}
|
||||
if !strings.Contains(link.Description, "size=") {
|
||||
t.Errorf("ResourceLink Description should contain size=, got %q", link.Description)
|
||||
}
|
||||
if !strings.Contains(link.Description, "lines=") {
|
||||
t.Errorf("ResourceLink Description should contain lines=, got %q", link.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceLinkForceInlineBypass(t *testing.T) {
|
||||
// force_inline=true must always return TextContent regardless of threshold.
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServerWithThreshold(t, tmpDir, 1) // threshold = 1 byte, always trips
|
||||
|
||||
f := writeFile(t, tmpDir, "test.txt", "hello world")
|
||||
|
||||
result := callReadResult(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"force_inline": true,
|
||||
})
|
||||
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("expected content, got none")
|
||||
}
|
||||
tc, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("force_inline=true should return TextContent, got %T", result.Content[0])
|
||||
}
|
||||
if !strings.Contains(tc.Text, "hello world") {
|
||||
t.Errorf("force_inline result should contain file content, got: %q", tc.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceLinkMaxInlineBytesOverride(t *testing.T) {
|
||||
// max_inline_bytes overrides server threshold per-call.
|
||||
tmpDir := t.TempDir()
|
||||
// Server threshold = 1 byte (always trips), but max_inline_bytes allows bigger.
|
||||
srv := newTestServerWithThreshold(t, tmpDir, 1)
|
||||
|
||||
content := strings.Repeat("a", 50) // 50 bytes
|
||||
f := writeFile(t, tmpDir, "test.txt", content)
|
||||
|
||||
// max_inline_bytes=100 > 50 bytes result → should inline
|
||||
result := callReadResult(t, srv, map[string]interface{}{
|
||||
"path": f,
|
||||
"max_inline_bytes": 100,
|
||||
})
|
||||
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("expected content, got none")
|
||||
}
|
||||
_, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("max_inline_bytes=100 with 50-byte file should return TextContent, got %T", result.Content[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceLinkStaleEtagRejection(t *testing.T) {
|
||||
// handleReadResource should reject fetch when file has changed since link was issued.
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServerWithThreshold(t, tmpDir, 1) // always trips
|
||||
|
||||
f := writeFile(t, tmpDir, "stale.txt", "original content")
|
||||
|
||||
// Get a ResourceLink — captures etag of "original content"
|
||||
result := callReadResult(t, srv, map[string]interface{}{"path": f})
|
||||
link, ok := result.Content[0].(mcp.ResourceLink)
|
||||
if !ok {
|
||||
t.Fatalf("expected ResourceLink, got %T", result.Content[0])
|
||||
}
|
||||
// URI contains ?etag=<hash-of-original>
|
||||
if !strings.Contains(link.URI, "?etag=") {
|
||||
t.Fatalf("expected etag in URI, got %q", link.URI)
|
||||
}
|
||||
|
||||
// Overwrite the file with new content.
|
||||
if err := os.WriteFile(f, []byte("modified content — different"), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Fetch the resource using the stale URI — should error.
|
||||
req := mcp.ReadResourceRequest{}
|
||||
req.Params.URI = link.URI
|
||||
_, err := srv.handleReadResource(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for stale etag, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "file changed") {
|
||||
t.Errorf("error should mention 'file changed', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceLinkBelowThresholdInlines(t *testing.T) {
|
||||
// When result is small (below threshold), always inline regardless of threshold setting.
|
||||
tmpDir := t.TempDir()
|
||||
// Large threshold — small file should be inlined.
|
||||
srv := newTestServerWithThreshold(t, tmpDir, 64*1024)
|
||||
|
||||
f := writeFile(t, tmpDir, "small.txt", "tiny")
|
||||
|
||||
result := callReadResult(t, srv, map[string]interface{}{"path": f})
|
||||
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("expected content, got none")
|
||||
}
|
||||
_, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("small file should return TextContent, got %T", result.Content[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceLinkThresholdZeroDisabled(t *testing.T) {
|
||||
// threshold=0 disables the feature entirely — always inline.
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServerWithThreshold(t, tmpDir, 0)
|
||||
|
||||
f := writeFile(t, tmpDir, "test.txt", strings.Repeat("z", 10000))
|
||||
|
||||
result := callReadResult(t, srv, map[string]interface{}{"path": f})
|
||||
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("expected content")
|
||||
}
|
||||
_, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("threshold=0 should always inline, got %T", result.Content[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleReadResource_ValidFetch(t *testing.T) {
|
||||
// handleReadResource fetches file content when etag matches.
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServerWithThreshold(t, tmpDir, 1)
|
||||
|
||||
f := writeFile(t, tmpDir, "fetch.txt", "fetch me please")
|
||||
|
||||
// Trigger a ResourceLink to get a valid URI with correct etag.
|
||||
toolResult := callReadResult(t, srv, map[string]interface{}{"path": f})
|
||||
link, ok := toolResult.Content[0].(mcp.ResourceLink)
|
||||
if !ok {
|
||||
t.Fatalf("expected ResourceLink, got %T", toolResult.Content[0])
|
||||
}
|
||||
|
||||
req := mcp.ReadResourceRequest{}
|
||||
req.Params.URI = link.URI
|
||||
contents, err := srv.handleReadResource(req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleReadResource error: %v", err)
|
||||
}
|
||||
if len(contents) == 0 {
|
||||
t.Fatal("expected resource contents, got none")
|
||||
}
|
||||
tc, ok := contents[0].(mcp.TextResourceContents)
|
||||
if !ok {
|
||||
t.Fatalf("expected TextResourceContents, got %T", contents[0])
|
||||
}
|
||||
if !strings.Contains(tc.Text, "fetch me please") {
|
||||
t.Errorf("resource contents should include file content, got: %q", tc.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceLinkMIMEType(t *testing.T) {
|
||||
// Verify MIME types for common extensions.
|
||||
tmpDir := t.TempDir()
|
||||
srv := newTestServerWithThreshold(t, tmpDir, 1)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
content string
|
||||
wantMIME string
|
||||
}{
|
||||
{"test.go", "package main\n", "text/x-go"},
|
||||
{"test.py", "# py\n", "text/x-python"},
|
||||
{"test.ts", "// ts\n", "text/typescript"},
|
||||
{"test.json", "{}\n", "application/json"},
|
||||
{"test.md", "# hi\n", "text/markdown"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
f := writeFile(t, tmpDir, c.name, c.content)
|
||||
result := callReadResult(t, srv, map[string]interface{}{"path": f})
|
||||
link, ok := result.Content[0].(mcp.ResourceLink)
|
||||
if !ok {
|
||||
t.Fatalf("expected ResourceLink for %s, got %T", c.name, result.Content[0])
|
||||
}
|
||||
if link.MIMEType != c.wantMIME {
|
||||
t.Errorf("%s: MIMEType = %q, want %q", c.name, link.MIMEType, c.wantMIME)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+131
-68
@@ -13,8 +13,14 @@ import (
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// handleSymbolAt handles the symbol_at tool.
|
||||
func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// handleLSPQuery is the unified dispatcher for all LSP operations.
|
||||
// action must be one of: "hover", "definition", "references".
|
||||
func (s *Server) handleLSPQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
action, err := request.RequireString("action")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("action is required (hover | definition | references)"), nil
|
||||
}
|
||||
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
@@ -30,16 +36,51 @@ func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Try LSP hover
|
||||
verbose := request.GetBool("verbose", false)
|
||||
|
||||
switch action {
|
||||
case "hover":
|
||||
if _, ok := request.GetArguments()["include_declaration"]; ok {
|
||||
return mcp.NewToolResultError("include_declaration is only valid for action=references"), nil
|
||||
}
|
||||
if _, ok := request.GetArguments()["compact"]; ok {
|
||||
return mcp.NewToolResultError("compact is only valid for action=references"), nil
|
||||
}
|
||||
return s.lspHover(ctx, file, line, col, verbose)
|
||||
|
||||
case "definition":
|
||||
if _, ok := request.GetArguments()["include_declaration"]; ok {
|
||||
return mcp.NewToolResultError("include_declaration is only valid for action=references"), nil
|
||||
}
|
||||
if _, ok := request.GetArguments()["compact"]; ok {
|
||||
return mcp.NewToolResultError("compact is only valid for action=references"), nil
|
||||
}
|
||||
return s.lspDefinition(ctx, file, line, col, verbose)
|
||||
|
||||
case "references":
|
||||
includeDecl := request.GetBool("include_declaration", true)
|
||||
// compact: explicit call-time > session compact_refs pref > false
|
||||
var prefsCompact *bool
|
||||
if sp := s.sessionPrefs.Load(); sp != nil {
|
||||
prefsCompact = sp.CompactRefs
|
||||
}
|
||||
compact := effectiveBool(request, "compact", prefsCompact, false)
|
||||
return s.lspReferences(ctx, file, line, col, includeDecl, compact, verbose)
|
||||
|
||||
default:
|
||||
return mcp.NewToolResultError(fmt.Sprintf("unknown action %q: must be hover | definition | references", action)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// lspHover performs hover (symbol info) for the given position.
|
||||
func (s *Server) lspHover(ctx context.Context, file string, line, col int, verbose bool) (*mcp.CallToolResult, error) {
|
||||
hover, err := s.lspManager.Hover(ctx, file, line, col)
|
||||
if err != nil {
|
||||
// Fall back to AST-based info
|
||||
return s.handleSymbolAtFallback(ctx, file, line, col)
|
||||
return s.handleSymbolAtFallback(ctx, file, line, col, verbose)
|
||||
}
|
||||
|
||||
if hover == nil {
|
||||
@@ -47,14 +88,16 @@ func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("**Symbol Information**\n\n")
|
||||
if verbose {
|
||||
output.WriteString("**Symbol Information**\n\n")
|
||||
}
|
||||
output.WriteString(hover.Contents.Value)
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable.
|
||||
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) {
|
||||
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int, verbose bool) (*mcp.CallToolResult, error) {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %s", errors.SanitizeError(err))), nil
|
||||
@@ -71,35 +114,17 @@ func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line,
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("**Symbol Information** (AST fallback)\n\n")
|
||||
if verbose {
|
||||
output.WriteString("**Symbol Information** (AST fallback)\n\n")
|
||||
}
|
||||
output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type()))
|
||||
output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content)))
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleFindDefinition handles the find_definition tool.
|
||||
func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// lspDefinition finds the definition of the symbol at the given position.
|
||||
func (s *Server) lspDefinition(ctx context.Context, file string, line, col int, verbose bool) (*mcp.CallToolResult, error) {
|
||||
locations, err := s.lspManager.Definition(ctx, file, line, col)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %s", errors.SanitizeError(err))), nil
|
||||
@@ -110,13 +135,14 @@ func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolR
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
|
||||
if verbose {
|
||||
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
|
||||
}
|
||||
|
||||
for _, loc := range locations {
|
||||
filePath := lsp.URIToFile(loc.URI)
|
||||
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
|
||||
// Try to read a preview snippet
|
||||
preview := s.readFilePreview(filePath, loc.Range.Start.Line+1, 3)
|
||||
if preview != "" {
|
||||
output.WriteString("```\n")
|
||||
@@ -129,30 +155,8 @@ func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolR
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleFindReferences handles the find_references tool.
|
||||
func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
includeDecl := request.GetBool("include_declaration", true)
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// lspReferences finds all references to the symbol at the given position.
|
||||
func (s *Server) lspReferences(ctx context.Context, file string, line, col int, includeDecl, compact, verbose bool) (*mcp.CallToolResult, error) {
|
||||
locations, err := s.lspManager.References(ctx, file, line, col, includeDecl)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %s", errors.SanitizeError(err))), nil
|
||||
@@ -162,25 +166,84 @@ func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolR
|
||||
return mcp.NewToolResultText("No references found."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations)))
|
||||
|
||||
// Group by file
|
||||
// Group by file, preserving encounter order.
|
||||
fileGroups := make(map[string][]lsp.Location)
|
||||
fileOrder := make([]string, 0)
|
||||
for _, loc := range locations {
|
||||
filePath := lsp.URIToFile(loc.URI)
|
||||
if _, seen := fileGroups[filePath]; !seen {
|
||||
fileOrder = append(fileOrder, filePath)
|
||||
}
|
||||
fileGroups[filePath] = append(fileGroups[filePath], loc)
|
||||
}
|
||||
|
||||
for filePath, locs := range fileGroups {
|
||||
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
|
||||
for _, loc := range locs {
|
||||
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
}
|
||||
output.WriteString("\n")
|
||||
return mcp.NewToolResultText(formatReferences(fileGroups, fileOrder, len(locations), compact, verbose)), nil
|
||||
}
|
||||
|
||||
// formatReferences formats grouped reference locations as a string.
|
||||
// compact=false: verbose multi-line format with L{line}:{col} per entry.
|
||||
// compact=true: one line per file — file:[line:col, ...] (N), with same-line
|
||||
// columns collapsed to line:{col1,col2,...}.
|
||||
func formatReferences(fileGroups map[string][]lsp.Location, fileOrder []string, total int, compact bool, verbose bool) string {
|
||||
var output strings.Builder
|
||||
if verbose {
|
||||
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", total))
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
for _, filePath := range fileOrder {
|
||||
locs := fileGroups[filePath]
|
||||
if compact {
|
||||
output.WriteString(formatReferencesCompact(filePath, locs))
|
||||
} else {
|
||||
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
|
||||
for _, loc := range locs {
|
||||
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
}
|
||||
output.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return output.String()
|
||||
}
|
||||
|
||||
// formatReferencesCompact formats one file's references as a single compact line.
|
||||
// Same-line references are collapsed: 12:{5,8} instead of 12:5, 12:8.
|
||||
func formatReferencesCompact(filePath string, locs []lsp.Location) string {
|
||||
// Build ordered line->col map preserving encounter order per line.
|
||||
type lineEntry struct {
|
||||
lineNum int
|
||||
cols []int
|
||||
}
|
||||
lineMap := make(map[int]*lineEntry)
|
||||
lineOrder := make([]int, 0, len(locs))
|
||||
|
||||
for _, loc := range locs {
|
||||
ln := loc.Range.Start.Line + 1
|
||||
col := loc.Range.Start.Character + 1
|
||||
if e, ok := lineMap[ln]; ok {
|
||||
e.cols = append(e.cols, col)
|
||||
} else {
|
||||
lineMap[ln] = &lineEntry{lineNum: ln, cols: []int{col}}
|
||||
lineOrder = append(lineOrder, ln)
|
||||
}
|
||||
}
|
||||
|
||||
// Build the bracket contents.
|
||||
parts := make([]string, 0, len(lineOrder))
|
||||
for _, ln := range lineOrder {
|
||||
e := lineMap[ln]
|
||||
if len(e.cols) == 1 {
|
||||
parts = append(parts, fmt.Sprintf("%d:%d", ln, e.cols[0]))
|
||||
} else {
|
||||
colStrs := make([]string, len(e.cols))
|
||||
for i, c := range e.cols {
|
||||
colStrs[i] = fmt.Sprintf("%d", c)
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%d:{%s}", ln, strings.Join(colStrs, ",")))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:[%s] (%d)\n", filePath, strings.Join(parts, ", "), len(locs))
|
||||
}
|
||||
|
||||
// readFilePreview reads a few lines from a file around the given line.
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/lsp"
|
||||
)
|
||||
|
||||
// makeLocation builds an lsp.Location with a file:// URI.
|
||||
func makeLocation(file string, line, col int) lsp.Location {
|
||||
return lsp.Location{
|
||||
URI: "file://" + file,
|
||||
Range: lsp.Range{
|
||||
Start: lsp.Position{Line: line - 1, Character: col - 1},
|
||||
End: lsp.Position{Line: line - 1, Character: col - 1},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// groupLocations is a helper that mirrors the grouping logic in lspReferences.
|
||||
func groupLocations(locations []lsp.Location) (map[string][]lsp.Location, []string) {
|
||||
fileGroups := make(map[string][]lsp.Location)
|
||||
fileOrder := make([]string, 0)
|
||||
for _, loc := range locations {
|
||||
filePath := lsp.URIToFile(loc.URI)
|
||||
if _, seen := fileGroups[filePath]; !seen {
|
||||
fileOrder = append(fileOrder, filePath)
|
||||
}
|
||||
fileGroups[filePath] = append(fileGroups[filePath], loc)
|
||||
}
|
||||
return fileGroups, fileOrder
|
||||
}
|
||||
|
||||
// TestFormatReferencesVerbose verifies the default (non-compact) format is unchanged.
|
||||
func TestFormatReferencesVerbose(t *testing.T) {
|
||||
locs := []lsp.Location{
|
||||
makeLocation("/a/foo.go", 12, 5),
|
||||
makeLocation("/a/foo.go", 13, 8),
|
||||
makeLocation("/a/bar.go", 15, 1),
|
||||
}
|
||||
groups, order := groupLocations(locs)
|
||||
out := formatReferences(groups, order, len(locs), false, true)
|
||||
|
||||
if !strings.Contains(out, "Found 3 reference(s):") {
|
||||
t.Errorf("missing header, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "**") {
|
||||
t.Errorf("verbose format should use **file** markers, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "L12:5") {
|
||||
t.Errorf("missing L12:5 in verbose output, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "L13:8") {
|
||||
t.Errorf("missing L13:8 in verbose output, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "L15:1") {
|
||||
t.Errorf("missing L15:1 in verbose output, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatReferencesCompactBasic verifies compact output for distinct lines.
|
||||
func TestFormatReferencesCompactBasic(t *testing.T) {
|
||||
locs := []lsp.Location{
|
||||
makeLocation("/a/foo.go", 12, 5),
|
||||
makeLocation("/a/foo.go", 13, 8),
|
||||
makeLocation("/a/foo.go", 15, 12),
|
||||
}
|
||||
groups, order := groupLocations(locs)
|
||||
out := formatReferences(groups, order, len(locs), true, true)
|
||||
|
||||
if !strings.Contains(out, "Found 3 reference(s):") {
|
||||
t.Errorf("missing header, got:\n%s", out)
|
||||
}
|
||||
// Should contain "foo.go:[12:5, 13:8, 15:12] (3)"
|
||||
if !strings.Contains(out, "12:5") {
|
||||
t.Errorf("missing 12:5 in compact output, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "13:8") {
|
||||
t.Errorf("missing 13:8 in compact output, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "15:12") {
|
||||
t.Errorf("missing 15:12 in compact output, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "(3)") {
|
||||
t.Errorf("missing (3) count, got:\n%s", out)
|
||||
}
|
||||
// Compact format must NOT use ** markers
|
||||
if strings.Contains(out, "**") {
|
||||
t.Errorf("compact format must not use ** markers, got:\n%s", out)
|
||||
}
|
||||
// Must not have L prefix
|
||||
if strings.Contains(out, "L12") {
|
||||
t.Errorf("compact format must not have L prefix, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatReferencesCompactSameLineCollapse verifies same-line columns collapse.
|
||||
func TestFormatReferencesCompactSameLineCollapse(t *testing.T) {
|
||||
locs := []lsp.Location{
|
||||
makeLocation("/a/foo.go", 12, 5),
|
||||
makeLocation("/a/foo.go", 12, 8),
|
||||
makeLocation("/a/foo.go", 15, 3),
|
||||
}
|
||||
groups, order := groupLocations(locs)
|
||||
out := formatReferences(groups, order, len(locs), true, false)
|
||||
|
||||
// Line 12 has two refs → should be collapsed: 12:{5,8}
|
||||
if !strings.Contains(out, "12:{5,8}") {
|
||||
t.Errorf("same-line refs should collapse to 12:{5,8}, got:\n%s", out)
|
||||
}
|
||||
// Line 15 is single → 15:3
|
||||
if !strings.Contains(out, "15:3") {
|
||||
t.Errorf("single ref on line 15 should be 15:3, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatReferencesCompactMultiFile verifies compact output across multiple files.
|
||||
func TestFormatReferencesCompactMultiFile(t *testing.T) {
|
||||
locs := []lsp.Location{
|
||||
makeLocation("/a/foo.go", 5, 1),
|
||||
makeLocation("/b/bar.go", 10, 3),
|
||||
makeLocation("/b/bar.go", 10, 7),
|
||||
}
|
||||
groups, order := groupLocations(locs)
|
||||
out := formatReferences(groups, order, len(locs), true, true)
|
||||
|
||||
if !strings.Contains(out, "Found 3 reference(s):") {
|
||||
t.Errorf("missing header, got:\n%s", out)
|
||||
}
|
||||
// foo.go: one ref
|
||||
if !strings.Contains(out, "5:1") {
|
||||
t.Errorf("missing 5:1 for foo.go, got:\n%s", out)
|
||||
}
|
||||
// bar.go: two refs on same line → collapsed
|
||||
if !strings.Contains(out, "10:{3,7}") {
|
||||
t.Errorf("missing 10:{3,7} collapse for bar.go, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatReferencesCompactSingleRef verifies single-reference compact output.
|
||||
func TestFormatReferencesCompactSingleRef(t *testing.T) {
|
||||
locs := []lsp.Location{
|
||||
makeLocation("/a/only.go", 7, 2),
|
||||
}
|
||||
groups, order := groupLocations(locs)
|
||||
out := formatReferences(groups, order, len(locs), true, false)
|
||||
|
||||
if !strings.Contains(out, "7:2") {
|
||||
t.Errorf("missing 7:2, got:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "(1)") {
|
||||
t.Errorf("missing (1), got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatReferencesCompactNoLPrefix verifies the L prefix is absent in compact mode.
|
||||
func TestFormatReferencesCompactNoLPrefix(t *testing.T) {
|
||||
locs := []lsp.Location{
|
||||
makeLocation("/a/x.go", 3, 4),
|
||||
}
|
||||
groups, order := groupLocations(locs)
|
||||
out := formatReferences(groups, order, len(locs), true, false)
|
||||
|
||||
if strings.Contains(out, "L3") {
|
||||
t.Errorf("compact output must not contain L prefix, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFormatReferencesVerboseNoChange verifies compact=false preserves old L-prefix format.
|
||||
func TestFormatReferencesVerbosePreservesLPrefix(t *testing.T) {
|
||||
locs := []lsp.Location{
|
||||
makeLocation("/a/x.go", 3, 4),
|
||||
}
|
||||
groups, order := groupLocations(locs)
|
||||
out := formatReferences(groups, order, len(locs), false, true)
|
||||
|
||||
if !strings.Contains(out, "L3:4") {
|
||||
t.Errorf("verbose output must contain L3:4, got:\n%s", out)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
// Package server implements the MCP server for file operations.
|
||||
package server
|
||||
|
||||
// helpFileRead is the full flag documentation and examples for the file_read tool,
|
||||
// served at filepuff://help/file_read.
|
||||
const helpFileRead = "# file_read — flags and examples\n\n" +
|
||||
"## Token-saving features\n\n" +
|
||||
"| Flag | Effect |\n" +
|
||||
"|------|--------|\n" +
|
||||
"| `previous_etag` | Skip re-reading unchanged files. Returns `[unchanged, etag: ...]` if file is unchanged. |\n" +
|
||||
"| `symbol_name` | Read only a named function/struct/class — eliminates an ast_query round-trip. |\n" +
|
||||
"| `symbols_only=true` | Return only symbol list (~95% fewer tokens). Requires `include_ast=true`. Alias: `mode='symbols_only'`. |\n" +
|
||||
"| `mode` | `full` (default) \\| `skeleton` (signatures + `{ ... }` stubs, bodies elided) \\| `symbols_only` |\n" +
|
||||
"| `strip` | Remove content classes before line-numbering: `imports`, `license`, `block_comments`. Emits `[stripped: ...]` footer. |\n" +
|
||||
"| `no_line_numbers=true` | Omit the ` 12│ ` line-number prefix (~10% savings). `line_number_interval=0` has the same effect. |\n" +
|
||||
"| `line_number_interval=N` | Print line numbers only every N lines. |\n" +
|
||||
"| `compact_line_numbers=true` | Use compact `12│` prefix instead of ` 12│ ` (no padding, no trailing space). |\n" +
|
||||
"| `collapse_blank_lines=true` | Collapse consecutive blank lines to one. |\n" +
|
||||
"| `max_lines=N` | Truncate output with omitted count notice. Applied after `line_start`/`line_end`. |\n" +
|
||||
"| `paths=[...]` | Read multiple files in one call. Each file gets a `--- path ---` header. |\n\n" +
|
||||
"All responses include `[etag: hex]` footer (8 hex chars) for use as `previous_etag` in subsequent reads.\n\n" +
|
||||
"## Examples\n\n" +
|
||||
"```json\n" +
|
||||
"// Full file\n" +
|
||||
`{"path": "main.go"}` + "\n\n" +
|
||||
"// Etag check — returns unchanged notice if file hasn't changed\n" +
|
||||
`{"path": "main.go", "previous_etag": "a3f9c2b1"}` + "\n\n" +
|
||||
"// Read only one named symbol\n" +
|
||||
`{"path": "server.go", "symbol_name": "handleFileRead"}` + "\n\n" +
|
||||
"// Skeleton mode — signatures only, bodies elided\n" +
|
||||
`{"path": "server.go", "mode": "skeleton"}` + "\n\n" +
|
||||
"// Strip imports and license header\n" +
|
||||
`{"path": "main.go", "strip": ["imports", "license"]}` + "\n\n" +
|
||||
"// Batch read multiple files\n" +
|
||||
`{"paths": ["a.go", "b.go"]}` + "\n\n" +
|
||||
"// Specific line range\n" +
|
||||
`{"path": "main.go", "line_start": 10, "line_end": 50}` + "\n" +
|
||||
"```\n"
|
||||
|
||||
// helpFileSearch is the full flag documentation and examples for the file_search tool,
|
||||
// served at filepuff://help/file_search.
|
||||
const helpFileSearch = "# file_search — flags and examples\n\n" +
|
||||
"## Output format\n\n" +
|
||||
"Matches grouped by file. Each file section has matching lines prefixed by `L{line}│` and context lines prefixed by ` │`. Zero matches: `No matches found.`\n\n" +
|
||||
"## Flags\n\n" +
|
||||
"| Flag | Effect |\n" +
|
||||
"|------|--------|\n" +
|
||||
"| `verbose=true` | Emit `Found N matches in M files:` preamble (v1 behaviour). Default: false. |\n" +
|
||||
"| `cluster=true` | Coalesce consecutive match lines into ranges (`L12-14│ text`). Drops context lines for density. |\n" +
|
||||
"| `cursor` | Opaque pagination token from a previous truncated response — fetches next page. |\n" +
|
||||
"| `max_results` | Page size for pagination. Re-run with `cursor` to get next page. |\n" +
|
||||
"| `context_lines` | Number of context lines around matches (default: 2). |\n" +
|
||||
"| `ignore_case` | Case-insensitive search. |\n" +
|
||||
"| `regex` | Treat pattern as regex (default: true). |\n" +
|
||||
"| `file_types` | Restrict to file extensions, e.g. `[\"go\", \"ts\"]`. |\n" +
|
||||
"| `paths` | Paths to search in (defaults to workspace root). |\n\n" +
|
||||
"## Examples\n\n" +
|
||||
"```json\n" +
|
||||
"// Search for error-returning functions in Go files\n" +
|
||||
`{"pattern": "func.*Error", "file_types": ["go"], "max_results": 20}` + "\n\n" +
|
||||
"// Case-insensitive literal search\n" +
|
||||
`{"pattern": "TODO", "ignore_case": true}` + "\n\n" +
|
||||
"// Paginated search — fetch next page\n" +
|
||||
`{"pattern": "import", "max_results": 50, "cursor": "<token from previous response>"}` + "\n\n" +
|
||||
"// Clustered — dense view of many matches\n" +
|
||||
`{"pattern": "return err", "file_types": ["go"], "cluster": true}` + "\n" +
|
||||
"```\n"
|
||||
|
||||
// helpASTQuery is the full flag documentation and examples for the ast_query tool,
|
||||
// served at filepuff://help/ast_query.
|
||||
const helpASTQuery = "# ast_query — flags and examples\n\n" +
|
||||
"## Output format\n\n" +
|
||||
"Entries in format `**file:line** (node_type)` with code blocks and captured variables (`$NAME=value`). Zero matches: `No matches found.`\n\n" +
|
||||
"## Flags\n\n" +
|
||||
"| Flag | Effect |\n" +
|
||||
"|------|--------|\n" +
|
||||
"| `verbose=true` | Emit `Found N match(es):` preamble (v1 behaviour). Default: false. |\n" +
|
||||
"| `format` | `verbose` (default, full code+captures) \\| `compact` (one line per match) \\| `location` (file:line only) |\n" +
|
||||
"| `cursor` | Opaque pagination token from a previous truncated response — fetches next page. |\n" +
|
||||
"| `max_results` | Page size (default: 100). |\n" +
|
||||
"| `name_exact` | Exact symbol name to match. |\n" +
|
||||
"| `name_matches` | Regex pattern to filter by name. |\n" +
|
||||
"| `kind_in` | Node types to match (e.g. `function_declaration`, `class_declaration`). |\n" +
|
||||
"| `paths` | Paths to search in (defaults to workspace root). |\n\n" +
|
||||
"## Pattern placeholders\n\n" +
|
||||
"| Placeholder | Meaning |\n" +
|
||||
"|-------------|----------|\n" +
|
||||
"| `$NAME` | Matches a single node, captures as `$NAME` |\n" +
|
||||
"| `$$$ARGS` | Matches zero or more nodes (variadic capture) |\n" +
|
||||
"| `$_` | Wildcard — matches any single node, no capture |\n\n" +
|
||||
"## Examples\n\n" +
|
||||
"```json\n" +
|
||||
"// All Go functions returning error\n" +
|
||||
`{"pattern": "func $NAME($$$ARGS) error", "language": "go"}` + "\n\n" +
|
||||
"// Python classes\n" +
|
||||
`{"pattern": "class $NAME: $$$BODY", "language": "python"}` + "\n\n" +
|
||||
"// Specific named function\n" +
|
||||
`{"pattern": "func $NAME($$$ARGS)", "language": "go", "name_exact": "NewServer"}` + "\n\n" +
|
||||
"// Compact output — one line per match\n" +
|
||||
`{"pattern": "func $NAME($$$ARGS) error", "language": "go", "format": "compact"}` + "\n" +
|
||||
"```\n"
|
||||
|
||||
// helpLSPQuery is the full flag documentation and examples for the lsp_query tool,
|
||||
// served at filepuff://help/lsp_query.
|
||||
const helpLSPQuery = "# lsp_query — flags and examples\n\n" +
|
||||
"## Actions\n\n" +
|
||||
"### hover\n" +
|
||||
"Returns type/doc from LSP, falls back to AST node info. `verbose=true` adds `**Symbol Information**` header.\n\n" +
|
||||
"### definition\n" +
|
||||
"Returns `file:line:col` + 3-line code preview for each definition. `verbose=true` adds `Found N definition(s):` header.\n\n" +
|
||||
"### references\n" +
|
||||
"Returns references grouped by file. Flags:\n" +
|
||||
"- `include_declaration` (default true) — include the declaration itself\n" +
|
||||
"- `compact=true` — collapse to one line per file\n" +
|
||||
"- `verbose=true` — add `Found N reference(s):` header\n\n" +
|
||||
"Note: `include_declaration` and `compact` are errors when used with actions other than `references`.\n\n" +
|
||||
"## Examples\n\n" +
|
||||
"```json\n" +
|
||||
"// Hover — type/doc at position\n" +
|
||||
`{"action": "hover", "file": "server.go", "line": 45, "column": 6}` + "\n\n" +
|
||||
"// Definition — where is this symbol defined?\n" +
|
||||
`{"action": "definition", "file": "handler.go", "line": 23, "column": 10}` + "\n\n" +
|
||||
"// References — all usages\n" +
|
||||
`{"action": "references", "file": "types.go", "line": 5, "column": 6}` + "\n\n" +
|
||||
"// References — compact (one line per file)\n" +
|
||||
`{"action": "references", "file": "types.go", "line": 5, "column": 6, "compact": true}` + "\n" +
|
||||
"```\n"
|
||||
|
||||
// helpEditApply is the full flag documentation and examples for the edit_apply tool,
|
||||
// served at filepuff://help/edit_apply.
|
||||
const helpEditApply = "# edit_apply — flags and examples\n\n" +
|
||||
"## Response format (`response` flag)\n\n" +
|
||||
"| Value | Output |\n" +
|
||||
"|-------|--------|\n" +
|
||||
"| `count` (default) | `+3 -1` added/removed line counts only |\n" +
|
||||
"| `diff` | Full unified diff of changes made |\n" +
|
||||
"| `none` | Empty response (silent success) |\n\n" +
|
||||
"`compact_response=true` is a deprecated alias for `response=\"count\"` kept for pre-v2 compatibility.\n\n" +
|
||||
"For code files (Go, TypeScript, JavaScript, Python, C, C++, Rust) syntax is validated before writing — the edit is rejected if it would produce invalid syntax.\n\n" +
|
||||
"## Selector types\n\n" +
|
||||
"### AST-mode selectors (code files)\n" +
|
||||
"- `selector_kind` — AST node type (e.g. `function_declaration`, `class_declaration`)\n" +
|
||||
"- `selector_name` — symbol name to match\n\n" +
|
||||
"### Text-mode selectors (all files)\n" +
|
||||
"- `selector_text` — exact text to match (must be unique, or use `selector_index`)\n" +
|
||||
"- `selector_pattern` — regex pattern to match\n" +
|
||||
"- `selector_line` / `selector_line_end` — line range\n\n" +
|
||||
"### Shared\n" +
|
||||
"- `selector_index` — index of match when multiple exist (default: 0)\n\n" +
|
||||
"## Examples\n\n" +
|
||||
"```json\n" +
|
||||
"// AST mode — replace a named function\n" +
|
||||
"{\n" +
|
||||
` "file": "main.go",` + "\n" +
|
||||
` "operation": "replace",` + "\n" +
|
||||
` "selector_kind": "function_declaration",` + "\n" +
|
||||
` "selector_name": "Hello",` + "\n" +
|
||||
` "new_content": "func Hello() {\n\treturn\n}"` + "\n" +
|
||||
"}\n\n" +
|
||||
"// Text mode — replace a markdown header\n" +
|
||||
"{\n" +
|
||||
` "file": "README.md",` + "\n" +
|
||||
` "operation": "replace",` + "\n" +
|
||||
` "selector_text": "## Old Header",` + "\n" +
|
||||
` "new_content": "## New Header"` + "\n" +
|
||||
"}\n\n" +
|
||||
"// Line range replacement\n" +
|
||||
"{\n" +
|
||||
` "file": "config.yaml",` + "\n" +
|
||||
` "operation": "replace",` + "\n" +
|
||||
` "selector_line": 5,` + "\n" +
|
||||
` "selector_line_end": 10,` + "\n" +
|
||||
` "new_content": "key: value"` + "\n" +
|
||||
"}\n\n" +
|
||||
"// Request full diff in response\n" +
|
||||
"{\n" +
|
||||
` "file": "main.go",` + "\n" +
|
||||
` "operation": "replace",` + "\n" +
|
||||
` "selector_name": "Hello",` + "\n" +
|
||||
` "new_content": "func Hello() {}",` + "\n" +
|
||||
` "response": "diff"` + "\n" +
|
||||
"}\n" +
|
||||
"```\n"
|
||||
@@ -43,23 +43,7 @@ func Hello() string {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test 1: Ping tool (health check)
|
||||
t.Run("ping", func(t *testing.T) {
|
||||
req := mcp.CallToolRequest{}
|
||||
result, err := srv.handlePing(ctx, req)
|
||||
if err != nil {
|
||||
t.Errorf("handlePing() error = %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("handlePing() returned nil")
|
||||
return
|
||||
}
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("handlePing() returned empty content")
|
||||
}
|
||||
})
|
||||
|
||||
// Test 2: File read
|
||||
// Test: File read (ping removed — Change 3)
|
||||
t.Run("file_read", func(t *testing.T) {
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
@@ -144,14 +128,8 @@ func TestMCPToolDiscovery(t *testing.T) {
|
||||
t.Fatal("MCP server not initialized")
|
||||
}
|
||||
|
||||
// Verify each expected tool works
|
||||
ctx := context.Background()
|
||||
|
||||
// Test ping tool
|
||||
pingReq := mcp.CallToolRequest{}
|
||||
if _, err := srv.handlePing(ctx, pingReq); err != nil {
|
||||
t.Errorf("ping tool failed: %v", err)
|
||||
}
|
||||
// Ping tool removed (Change 3 — MCP protocol has own liveness check).
|
||||
// Tools verified via integration tests in TestIntegrationFileOperations.
|
||||
}
|
||||
|
||||
// TestMCPErrorResponses tests error handling following MCP spec.
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
xxhash "github.com/cespare/xxhash/v2"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// helpResources maps a tool name to its help content constant.
|
||||
var helpResources = map[string]string{
|
||||
"file_read": helpFileRead,
|
||||
"file_search": helpFileSearch,
|
||||
"ast_query": helpASTQuery,
|
||||
"lsp_query": helpLSPQuery,
|
||||
"edit_apply": helpEditApply,
|
||||
}
|
||||
|
||||
// registerResources registers one filepuff://help/<tool> resource per tool.
|
||||
// Each resource returns Markdown-formatted flag docs and examples.
|
||||
func (s *Server) registerResources() {
|
||||
for toolName, content := range helpResources {
|
||||
uri := "filepuff://help/" + toolName
|
||||
name := "help/" + toolName
|
||||
description := "Flag documentation and examples for the " + toolName + " tool."
|
||||
captured := content // capture for closure
|
||||
|
||||
s.mcp.AddResource(
|
||||
mcp.NewResource(uri, name,
|
||||
mcp.WithResourceDescription(description),
|
||||
mcp.WithMIMEType("text/markdown"),
|
||||
),
|
||||
func(_ context.Context, _ mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
return []mcp.ResourceContents{
|
||||
mcp.TextResourceContents{
|
||||
URI: uri,
|
||||
MIMEType: "text/markdown",
|
||||
Text: captured,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// readHelpResource is a convenience handler that can be used directly when a
|
||||
// single resource handler is needed. It is kept exported for testability.
|
||||
func readHelpResource(uri string) mcpserver.ResourceHandlerFunc {
|
||||
return func(_ context.Context, _ mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
// Extract tool name from filepuff://help/<tool>
|
||||
const prefix = "filepuff://help/"
|
||||
if len(uri) <= len(prefix) {
|
||||
return nil, fmt.Errorf("invalid help URI: %s", uri)
|
||||
}
|
||||
toolName := uri[len(prefix):]
|
||||
content, ok := helpResources[toolName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no help content for tool: %s", toolName)
|
||||
}
|
||||
return []mcp.ResourceContents{
|
||||
mcp.TextResourceContents{
|
||||
URI: uri,
|
||||
MIMEType: "text/markdown",
|
||||
Text: content,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// registerReadResource registers the filepuff://read/{+path} resource template.
|
||||
// The handler re-reads the file, validates the etag query param if provided,
|
||||
// and returns the raw file content (no line-number formatting).
|
||||
//
|
||||
// URI format: filepuff://read/<absolute-path>?etag=<etag>
|
||||
// The etag param is optional. If supplied and the file has changed, the handler
|
||||
// returns an error so the caller re-runs file_read to get a fresh ResourceLink.
|
||||
func (s *Server) registerReadResource() {
|
||||
const uriTemplate = "filepuff://read/{+path}"
|
||||
|
||||
s.mcp.AddResourceTemplate(
|
||||
mcp.NewResourceTemplate(uriTemplate, "file-read",
|
||||
mcp.WithTemplateDescription("Raw content of a file previously read via file_read. "+
|
||||
"Fetch when file_read returns a ResourceLink instead of inlining content. "+
|
||||
"URI: filepuff://read/<path>?etag=<etag>"),
|
||||
),
|
||||
func(_ context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
return s.handleReadResource(req)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// handleReadResource is the resource handler for filepuff://read/{+path} URIs.
|
||||
func (s *Server) handleReadResource(req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
rawURI := req.Params.URI
|
||||
|
||||
// Parse path and etag from the URI.
|
||||
// URI shape: filepuff://read/<path>[?etag=<hex>]
|
||||
const scheme = "filepuff://read/"
|
||||
if !strings.HasPrefix(rawURI, scheme) {
|
||||
return nil, fmt.Errorf("invalid read resource URI: %s", rawURI)
|
||||
}
|
||||
rest := rawURI[len(scheme):]
|
||||
|
||||
// Split off query string to get the path.
|
||||
filePath := rest
|
||||
var expectedEtag string
|
||||
if qIdx := strings.IndexByte(rest, '?'); qIdx >= 0 {
|
||||
filePath = rest[:qIdx]
|
||||
qs, err := url.ParseQuery(rest[qIdx+1:])
|
||||
if err == nil {
|
||||
expectedEtag = qs.Get("etag")
|
||||
}
|
||||
}
|
||||
|
||||
if filePath == "" {
|
||||
return nil, fmt.Errorf("read resource URI missing path")
|
||||
}
|
||||
|
||||
if !s.cfg.IsPathAllowed(filePath) {
|
||||
return nil, fmt.Errorf("path is outside workspace root")
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("file not found: %s", filePath)
|
||||
}
|
||||
if os.IsPermission(err) {
|
||||
return nil, fmt.Errorf("permission denied: %s", filePath)
|
||||
}
|
||||
return nil, fmt.Errorf("error reading file: %s", filePath)
|
||||
}
|
||||
|
||||
// Validate etag if provided — detect stale references.
|
||||
if expectedEtag != "" {
|
||||
fullHash := fmt.Sprintf("%016x", xxhash.Sum64(content))
|
||||
currentEtag := fullHash[:8]
|
||||
if expectedEtag != currentEtag && !strings.HasPrefix(fullHash, expectedEtag) && !strings.HasPrefix(expectedEtag, currentEtag) {
|
||||
return nil, fmt.Errorf("file changed since ResourceLink was issued (expected etag %s, got %s); re-run file_read to get fresh content", expectedEtag, currentEtag)
|
||||
}
|
||||
}
|
||||
|
||||
mimeType := detectMIMEType(filePath)
|
||||
return []mcp.ResourceContents{
|
||||
mcp.TextResourceContents{
|
||||
URI: rawURI,
|
||||
MIMEType: mimeType,
|
||||
Text: string(content),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
func TestRegisterResources_AllToolsHaveResource(t *testing.T) {
|
||||
// Verify that registerResources wires up without panicking.
|
||||
_ = newTestServer(t, t.TempDir())
|
||||
|
||||
expectedURIs := []string{
|
||||
"filepuff://help/file_read",
|
||||
"filepuff://help/file_search",
|
||||
"filepuff://help/ast_query",
|
||||
"filepuff://help/lsp_query",
|
||||
"filepuff://help/edit_apply",
|
||||
}
|
||||
|
||||
for _, uri := range expectedURIs {
|
||||
t.Run(uri, func(t *testing.T) {
|
||||
handler := readHelpResource(uri)
|
||||
contents, err := handler(context.Background(), mcp.ReadResourceRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("readHelpResource(%q) error = %v", uri, err)
|
||||
}
|
||||
if len(contents) == 0 {
|
||||
t.Fatalf("readHelpResource(%q) returned empty contents", uri)
|
||||
}
|
||||
tc, ok := contents[0].(mcp.TextResourceContents)
|
||||
if !ok {
|
||||
t.Fatalf("readHelpResource(%q) contents[0] is not TextResourceContents", uri)
|
||||
}
|
||||
if tc.MIMEType != "text/markdown" {
|
||||
t.Errorf("MIMEType = %q, want %q", tc.MIMEType, "text/markdown")
|
||||
}
|
||||
if len(tc.Text) == 0 {
|
||||
t.Errorf("Text is empty for %q", uri)
|
||||
}
|
||||
if !strings.HasPrefix(tc.Text, "#") {
|
||||
t.Errorf("expected markdown (# heading) for %q, got: %q", uri, tc.Text[:min(50, len(tc.Text))])
|
||||
}
|
||||
if tc.URI != uri {
|
||||
t.Errorf("URI = %q, want %q", tc.URI, uri)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadHelpResource_UnknownTool(t *testing.T) {
|
||||
handler := readHelpResource("filepuff://help/nonexistent")
|
||||
_, err := handler(context.Background(), mcp.ReadResourceRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown tool, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadHelpResource_InvalidURI(t *testing.T) {
|
||||
handler := readHelpResource("filepuff://help/")
|
||||
_, err := handler(context.Background(), mcp.ReadResourceRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty tool name, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelpContent_NotEmpty(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"file_read": helpFileRead,
|
||||
"file_search": helpFileSearch,
|
||||
"ast_query": helpASTQuery,
|
||||
"lsp_query": helpLSPQuery,
|
||||
"edit_apply": helpEditApply,
|
||||
}
|
||||
for name, content := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if len(content) == 0 {
|
||||
t.Errorf("help content for %q is empty", name)
|
||||
}
|
||||
if !strings.Contains(content, "##") {
|
||||
t.Errorf("expected markdown sections (##) in help content for %q", name)
|
||||
}
|
||||
if !strings.Contains(content, "```") {
|
||||
t.Errorf("expected code fences in help content for %q", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+95
-120
@@ -33,16 +33,17 @@ const PreviewLineMaxLength = 100
|
||||
|
||||
// Server represents the MCP file operations server.
|
||||
type Server struct {
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
mcp *server.MCPServer
|
||||
searcher *search.Searcher
|
||||
parser *parser.Registry
|
||||
matcher *query.Matcher
|
||||
lspManager *lsp.Manager
|
||||
editor *edit.Engine
|
||||
readSem chan struct{} // Semaphore for limiting concurrent file reads
|
||||
querySem chan struct{} // Semaphore for limiting concurrent AST queries
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
mcp *server.MCPServer
|
||||
searcher *search.Searcher
|
||||
parser *parser.Registry
|
||||
matcher *query.Matcher
|
||||
lspManager *lsp.Manager
|
||||
editor *edit.Engine
|
||||
readSem chan struct{} // Semaphore for limiting concurrent file reads
|
||||
querySem chan struct{} // Semaphore for limiting concurrent AST queries
|
||||
sessionPrefs sessionPrefsPtr // Atomic pointer; populated by OnAfterInitialize hook
|
||||
}
|
||||
|
||||
// New creates a new MCP server instance.
|
||||
@@ -70,40 +71,48 @@ func New(cfg *config.Config, logger *slog.Logger) (*Server, error) {
|
||||
s.lspManager = lsp.NewManager(cfg.WorkspaceRoot, logger)
|
||||
}
|
||||
|
||||
// Build OnAfterInitialize hook that parses client capability prefs.
|
||||
// Signature (from mcp-go v0.48.0 hooks.go):
|
||||
// func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult)
|
||||
hooks := &server.Hooks{}
|
||||
hooks.AddAfterInitialize(func(_ context.Context, _ any, msg *mcp.InitializeRequest, _ *mcp.InitializeResult) {
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
raw, _ := msg.Params.Capabilities.Experimental["filepuff"].(map[string]any)
|
||||
prefs := ParseSessionPrefs(raw)
|
||||
s.sessionPrefs.Store(&prefs)
|
||||
})
|
||||
|
||||
// Create MCP server
|
||||
mcpServer := server.NewMCPServer(
|
||||
"mcp-filepuff",
|
||||
"1.0.0",
|
||||
"2.0.0",
|
||||
server.WithLogging(),
|
||||
server.WithHooks(hooks),
|
||||
)
|
||||
s.mcp = mcpServer
|
||||
|
||||
// Register tools
|
||||
s.registerTools()
|
||||
|
||||
// Register help resources (filepuff://help/<tool>)
|
||||
s.registerResources()
|
||||
|
||||
// Register filepuff://read/{+path} resource template for large-file access.
|
||||
s.registerReadResource()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// registerTools registers all available tools with the MCP server.
|
||||
func (s *Server) registerTools() {
|
||||
// Register ping tool for health checks
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("ping",
|
||||
mcp.WithDescription("Health check - returns pong to verify the server is running.\n\n"+
|
||||
"Returns: \"pong\" text string."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
),
|
||||
s.handlePing,
|
||||
)
|
||||
|
||||
// Register file_search tool
|
||||
if s.searcher != nil {
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("file_search",
|
||||
mcp.WithDescription("Search for text patterns in files using ripgrep. Supports regex patterns, file type filtering, and context lines.\n\n"+
|
||||
"Returns: Results grouped by file with match context. Format: \"Found N matches in M files:\" followed by file sections, "+
|
||||
"each with matching lines prefixed by \"L{line}│\" and context lines prefixed by \" │\".\n\n"+
|
||||
"Example: {\"pattern\": \"func.*Error\", \"file_types\": [\"go\"], \"max_results\": 20}"),
|
||||
mcp.WithDescription("Search for text patterns in files using ripgrep. Supports regex patterns, file type filtering, and context lines. "+
|
||||
"See resource filepuff://help/file_search for flags and examples."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("pattern",
|
||||
mcp.Required(),
|
||||
@@ -127,7 +136,16 @@ func (s *Server) registerTools() {
|
||||
mcp.Description("Number of context lines around matches (default: 2)"),
|
||||
),
|
||||
mcp.WithNumber("max_results",
|
||||
mcp.Description("Maximum number of results to return"),
|
||||
mcp.Description("Maximum number of results to return (page size for pagination)"),
|
||||
),
|
||||
mcp.WithBoolean("cluster",
|
||||
mcp.Description("Coalesce consecutive match lines into ranges (L12-14│ text). Drops context lines. Default: false."),
|
||||
),
|
||||
mcp.WithString("cursor",
|
||||
mcp.Description("Pagination cursor from a previous truncated response. Pass back to fetch the next page."),
|
||||
),
|
||||
mcp.WithBoolean("verbose",
|
||||
mcp.Description("Emit \"Found N matches in M files:\" preamble. Default: false (v2 default)."),
|
||||
),
|
||||
),
|
||||
s.handleFileSearch,
|
||||
@@ -137,23 +155,8 @@ func (s *Server) registerTools() {
|
||||
// Register file_read tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("file_read",
|
||||
mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary.\n\n"+
|
||||
"Token-saving features:\n"+
|
||||
" previous_etag: skip re-reading unchanged files (returns '[unchanged, etag: ...]' if unchanged)\n"+
|
||||
" symbol_name: read only a named function/struct/class (eliminates ast_query round-trip)\n"+
|
||||
" symbols_only=true: return only symbol list, ~95% fewer tokens (requires include_ast=true)\n"+
|
||||
" no_line_numbers=true: strip the line-number prefix (~10%% savings)\n"+
|
||||
" line_number_interval=N: print line numbers only every N lines\n"+
|
||||
" collapse_blank_lines=true: collapse consecutive blank lines to one\n"+
|
||||
" max_lines=N: truncate output with omitted count notice\n"+
|
||||
" paths=[...]: read multiple files in one call\n\n"+
|
||||
"All responses include '[etag: hex]' footer for use as previous_etag in subsequent reads.\n\n"+
|
||||
"Examples:\n"+
|
||||
" Full file: {\"path\": \"main.go\"}\n"+
|
||||
" Etag check: {\"path\": \"main.go\", \"previous_etag\": \"a3f9c2b1\"}\n"+
|
||||
" By symbol: {\"path\": \"server.go\", \"symbol_name\": \"handleFileRead\"}\n"+
|
||||
" Batch: {\"paths\": [\"a.go\", \"b.go\"]}\n"+
|
||||
" Line range: {\"path\": \"main.go\", \"line_start\": 10, \"line_end\": 50}"),
|
||||
mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary. "+
|
||||
"See resource filepuff://help/file_read for flags and examples."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("path",
|
||||
mcp.Description("Path to the file to read (required unless paths is provided)"),
|
||||
@@ -181,20 +184,36 @@ func (s *Server) registerTools() {
|
||||
mcp.Description("Include AST symbol summary (functions, classes, types, etc.)"),
|
||||
),
|
||||
mcp.WithBoolean("symbols_only",
|
||||
mcp.Description("Return only symbol summary without file content (token-efficient mode). Requires include_ast=true."),
|
||||
mcp.Description("Return only symbol summary without file content (token-efficient mode). Requires include_ast=true. Alias: mode='symbols_only'."),
|
||||
),
|
||||
mcp.WithString("mode",
|
||||
mcp.Description("Output mode: 'full' (default, full file), 'skeleton' (signatures + { ... } stubs, bodies elided), 'symbols_only' (symbol list only, alias for symbols_only=true)."),
|
||||
),
|
||||
mcp.WithArray("strip",
|
||||
mcp.Description("Strip content classes before line-numbering. Values: 'imports' (remove import blocks), 'license' (remove leading license comment), 'block_comments' (remove /* */ and Python triple-quoted strings). Emits [stripped: ...] footer."),
|
||||
mcp.WithStringItems(),
|
||||
),
|
||||
mcp.WithNumber("max_lines",
|
||||
mcp.Description("Maximum number of lines to return (for token efficiency). Applied after line_start/line_end."),
|
||||
),
|
||||
mcp.WithBoolean("no_line_numbers",
|
||||
mcp.Description("Omit the ' 12│ ' line number prefix entirely. Saves ~10% tokens. line_number_interval=0 has the same effect."),
|
||||
mcp.Description("Omit the ' 12\u2502 ' line number prefix entirely. Saves ~10% tokens. line_number_interval=0 has the same effect."),
|
||||
),
|
||||
mcp.WithNumber("line_number_interval",
|
||||
mcp.Description("Print line numbers only every N lines (default: 1 = every line). E.g. 10 = anchor every 10th line plus first/last. 0 = no line numbers."),
|
||||
),
|
||||
mcp.WithBoolean("compact_line_numbers",
|
||||
mcp.Description("Use compact line prefix '12\u2502' instead of ' 12\u2502 ' (no padding, no trailing space). Works with line_number_interval. Default off."),
|
||||
),
|
||||
mcp.WithBoolean("collapse_blank_lines",
|
||||
mcp.Description("Collapse runs of consecutive blank lines to a single blank line. Useful for token savings on heavily-spaced code."),
|
||||
),
|
||||
mcp.WithBoolean("force_inline",
|
||||
mcp.Description("Always return file content inline, bypassing the resource-link threshold. Default: false."),
|
||||
),
|
||||
mcp.WithNumber("max_inline_bytes",
|
||||
mcp.Description("Per-call inline threshold override in bytes. If set, overrides server resource_link_threshold_bytes for this call only. 0 = use server default."),
|
||||
),
|
||||
),
|
||||
s.handleFileRead,
|
||||
)
|
||||
@@ -202,13 +221,8 @@ func (s *Server) registerTools() {
|
||||
// Register ast_query tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("ast_query",
|
||||
mcp.WithDescription("Search for AST patterns in code files. Use code patterns with $VAR placeholders to match and capture code structures like functions, classes, and types.\n\n"+
|
||||
"Returns: \"Found N match(es):\" followed by entries in format \"**file:line** (node_type)\" with code blocks "+
|
||||
"and captured variables ($NAME=value). Returns \"No matches found.\" when no results.\n\n"+
|
||||
"Examples:\n"+
|
||||
" Go error funcs: {\"pattern\": \"func $NAME($$$ARGS) error\", \"language\": \"go\"}\n"+
|
||||
" Python classes: {\"pattern\": \"class $NAME: $$$BODY\", \"language\": \"python\"}\n"+
|
||||
" Named function: {\"pattern\": \"func $NAME($$$ARGS)\", \"language\": \"go\", \"name_exact\": \"NewServer\"}"),
|
||||
mcp.WithDescription("Search for AST patterns in code files. Use code patterns with $VAR placeholders to match and capture code structures like functions, classes, and types. "+
|
||||
"See resource filepuff://help/ast_query for flags and examples."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("pattern",
|
||||
mcp.Required(),
|
||||
@@ -233,7 +247,16 @@ func (s *Server) registerTools() {
|
||||
mcp.WithStringItems(),
|
||||
),
|
||||
mcp.WithNumber("max_results",
|
||||
mcp.Description("Maximum number of results to return (default: 100)"),
|
||||
mcp.Description("Maximum number of results to return (default: 100, page size for pagination)"),
|
||||
),
|
||||
mcp.WithString("format",
|
||||
mcp.Description("Output format: \"verbose\" (default, full code+captures), \"compact\" (one line per match), \"location\" (file:line only)"),
|
||||
),
|
||||
mcp.WithString("cursor",
|
||||
mcp.Description("Pagination cursor from a previous truncated response. Pass back to fetch the next page."),
|
||||
),
|
||||
mcp.WithBoolean("verbose",
|
||||
mcp.Description("Emit \"Found N match(es):\" preamble. Default: false (v2 default)."),
|
||||
),
|
||||
),
|
||||
s.handleASTQuery,
|
||||
@@ -241,62 +264,15 @@ func (s *Server) registerTools() {
|
||||
|
||||
// Register LSP-based tools if LSP is enabled
|
||||
if s.lspManager != nil {
|
||||
// Register symbol_at tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("symbol_at",
|
||||
mcp.WithDescription("Get information about the symbol at a specific position in a file. Returns type, documentation, and definition location using LSP when available.\n\n"+
|
||||
"Returns: \"**Symbol Information**\" followed by hover/type information from LSP, or \"**Symbol Information** (AST fallback)\" "+
|
||||
"with node type and text when LSP unavailable. Returns \"No symbol information available at this position.\" when nothing is found.\n\n"+
|
||||
"Example: {\"file\": \"server.go\", \"line\": 45, \"column\": 6}"),
|
||||
mcp.NewTool("lsp_query",
|
||||
mcp.WithDescription("Query LSP for symbol info, definition, or references at a specific file position. "+
|
||||
"See resource filepuff://help/lsp_query for flags and examples."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("file",
|
||||
mcp.WithString("action",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file"),
|
||||
mcp.Description("LSP operation: hover | definition | references"),
|
||||
),
|
||||
mcp.WithNumber("line",
|
||||
mcp.Required(),
|
||||
mcp.Description("Line number (1-indexed)"),
|
||||
),
|
||||
mcp.WithNumber("column",
|
||||
mcp.Required(),
|
||||
mcp.Description("Column number (1-indexed)"),
|
||||
),
|
||||
),
|
||||
s.handleSymbolAt,
|
||||
)
|
||||
|
||||
// Register find_definition tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("find_definition",
|
||||
mcp.WithDescription("Find the definition of the symbol at a specific position. Uses LSP to locate where a function, variable, type, etc. is defined.\n\n"+
|
||||
"Returns: \"Found N definition(s):\" with entries showing \"**file:line:column**\" and a 3-line code preview "+
|
||||
"with the target line marked by \">\". Returns \"No definition found.\" when the symbol has no definition.\n\n"+
|
||||
"Example: {\"file\": \"handler.go\", \"line\": 23, \"column\": 10}"),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("file",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file"),
|
||||
),
|
||||
mcp.WithNumber("line",
|
||||
mcp.Required(),
|
||||
mcp.Description("Line number (1-indexed)"),
|
||||
),
|
||||
mcp.WithNumber("column",
|
||||
mcp.Required(),
|
||||
mcp.Description("Column number (1-indexed)"),
|
||||
),
|
||||
),
|
||||
s.handleFindDefinition,
|
||||
)
|
||||
|
||||
// Register find_references tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("find_references",
|
||||
mcp.WithDescription("Find all references to the symbol at a specific position. Uses LSP to locate all usages of a function, variable, type, etc.\n\n"+
|
||||
"Returns: \"Found N reference(s):\" grouped by file, each showing \"**file** (count)\" with locations as "+
|
||||
"\"L{line}:{column}\". Returns \"No references found.\" when no usages exist.\n\n"+
|
||||
"Example: {\"file\": \"types.go\", \"line\": 5, \"column\": 6}"),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("file",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file"),
|
||||
@@ -310,23 +286,24 @@ func (s *Server) registerTools() {
|
||||
mcp.Description("Column number (1-indexed)"),
|
||||
),
|
||||
mcp.WithBoolean("include_declaration",
|
||||
mcp.Description("Include the declaration in results (default: true)"),
|
||||
mcp.Description("Include the declaration in results. Only valid for action=references (default: true)."),
|
||||
),
|
||||
mcp.WithBoolean("compact",
|
||||
mcp.Description("Compact output: one line per file with all refs in brackets. Only valid for action=references. Default: false."),
|
||||
),
|
||||
mcp.WithBoolean("verbose",
|
||||
mcp.Description("Emit count/header preamble. Applies to all actions. Default: false."),
|
||||
),
|
||||
),
|
||||
s.handleFindReferences,
|
||||
s.handleLSPQuery,
|
||||
)
|
||||
}
|
||||
|
||||
// Register edit tools
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("edit_apply",
|
||||
mcp.WithDescription("Apply an edit to a file. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++, Rust) with syntax validation, and text-based editing for other files (Markdown, JSON, YAML, config files, etc.).\n\n"+
|
||||
"Returns: \"**Edit Applied Successfully**\" followed by a unified diff of the changes made. "+
|
||||
"For code files, validates syntax before writing — returns an error if the edit would produce invalid syntax.\n\n"+
|
||||
"Examples:\n"+
|
||||
" AST mode: {\"file\": \"main.go\", \"operation\": \"replace\", \"selector_kind\": \"function_declaration\", \"selector_name\": \"Hello\", \"new_content\": \"func Hello() {\\n\\treturn\\n}\"}\n"+
|
||||
" Text mode: {\"file\": \"README.md\", \"operation\": \"replace\", \"selector_text\": \"## Old Header\", \"new_content\": \"## New Header\"}\n"+
|
||||
" Line range: {\"file\": \"config.yaml\", \"operation\": \"replace\", \"selector_line\": 5, \"selector_line_end\": 10, \"new_content\": \"key: value\"}"),
|
||||
mcp.WithDescription("Apply an edit to a file. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++, Rust) with syntax validation, and text-based editing for other files (Markdown, JSON, YAML, config files, etc.). "+
|
||||
"See resource filepuff://help/edit_apply for flags and examples."),
|
||||
mcp.WithString("file",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file to edit"),
|
||||
@@ -362,19 +339,17 @@ func (s *Server) registerTools() {
|
||||
mcp.WithString("selector_pattern",
|
||||
mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."),
|
||||
),
|
||||
mcp.WithString("response",
|
||||
mcp.Description("Response format: \"count\" (default, \"+3 -1\" line counts), \"diff\" (full unified diff), \"none\" (empty). Default: count."),
|
||||
),
|
||||
mcp.WithBoolean("compact_response",
|
||||
mcp.Description("Return only the modified symbol's content instead of a full diff. Requires selector_name. Saves tokens on large-file edits."),
|
||||
mcp.Description("Deprecated: use response=count. Alias for response=\"count\" kept for pre-v2 compatibility."),
|
||||
),
|
||||
),
|
||||
s.handleEditApply,
|
||||
)
|
||||
}
|
||||
|
||||
// handlePing handles the ping health check tool.
|
||||
func (s *Server) handlePing(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("pong"), nil
|
||||
}
|
||||
|
||||
// Run starts the MCP server and blocks until shutdown.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
// Set up signal handling for graceful shutdown
|
||||
|
||||
+259
-39
@@ -7,6 +7,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
@@ -49,45 +50,6 @@ func TestNew(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePing(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
|
||||
result, err := srv.handlePing(ctx, req)
|
||||
if err != nil {
|
||||
t.Errorf("handlePing() error = %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handlePing() returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the result contains "pong"
|
||||
contents := result.Content
|
||||
if len(contents) == 0 {
|
||||
t.Fatal("handlePing() returned empty content")
|
||||
}
|
||||
|
||||
textContent, ok := contents[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("handlePing() did not return text content")
|
||||
}
|
||||
|
||||
if textContent.Text != "pong" {
|
||||
t.Errorf("handlePing() = %v, want 'pong'", textContent.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleFileRead(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -484,3 +446,261 @@ func TestSplitLinesLongLine(t *testing.T) {
|
||||
t.Error("the 500KB long line was not found in splitLines output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleEditApplyResponseCount verifies the default response=count format "+N -M".
|
||||
func TestHandleEditApplyResponseCount(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("Hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"file": testFile,
|
||||
"operation": "replace",
|
||||
"selector_kind": "function_declaration",
|
||||
"selector_name": "Hello",
|
||||
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
|
||||
// no response flag → default "count"
|
||||
}
|
||||
|
||||
result, err := srv.handleEditApply(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleEditApply error = %v", err)
|
||||
}
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
t.Fatal("handleEditApply returned empty result")
|
||||
}
|
||||
textContent, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("expected text content")
|
||||
}
|
||||
// Should be "+N -M" format
|
||||
text := textContent.Text
|
||||
if !strings.HasPrefix(text, "+") {
|
||||
t.Errorf("response=count should start with +, got: %q", text)
|
||||
}
|
||||
if !strings.Contains(text, " -") {
|
||||
t.Errorf("response=count should contain -N, got: %q", text)
|
||||
}
|
||||
// Must NOT contain diff syntax or old preamble
|
||||
if strings.Contains(text, "@@") || strings.Contains(text, "Edit Applied") {
|
||||
t.Errorf("response=count must not contain diff markers, got: %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleEditApplyResponseDiff verifies response=diff returns unified diff.
|
||||
func TestHandleEditApplyResponseDiff(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("Hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"file": testFile,
|
||||
"operation": "replace",
|
||||
"selector_kind": "function_declaration",
|
||||
"selector_name": "Hello",
|
||||
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
|
||||
"response": "diff",
|
||||
}
|
||||
|
||||
result, err := srv.handleEditApply(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleEditApply error = %v", err)
|
||||
}
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
t.Fatal("handleEditApply returned empty result")
|
||||
}
|
||||
textContent, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("expected text content")
|
||||
}
|
||||
text := textContent.Text
|
||||
if !strings.Contains(text, "diff") {
|
||||
t.Errorf("response=diff should contain diff, got: %q", text)
|
||||
}
|
||||
// Must NOT have old "Edit Applied Successfully" preamble
|
||||
if strings.Contains(text, "Edit Applied Successfully") {
|
||||
t.Errorf("v2 diff should not have old preamble, got: %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleEditApplyResponseNone verifies response=none returns empty string.
|
||||
func TestHandleEditApplyResponseNone(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("Hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"file": testFile,
|
||||
"operation": "replace",
|
||||
"selector_kind": "function_declaration",
|
||||
"selector_name": "Hello",
|
||||
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
|
||||
"response": "none",
|
||||
}
|
||||
|
||||
result, err := srv.handleEditApply(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleEditApply error = %v", err)
|
||||
}
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
t.Fatal("handleEditApply returned empty result")
|
||||
}
|
||||
textContent, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("expected text content")
|
||||
}
|
||||
if textContent.Text != "" {
|
||||
t.Errorf("response=none should return empty string, got: %q", textContent.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleFileReadBatchDedup verifies that identical files in batch mode emit [duplicate of ...].
|
||||
func TestHandleFileReadBatchDedup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "a.go")
|
||||
content := `package main
|
||||
|
||||
func Hello() {}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write a.go: %v", err)
|
||||
}
|
||||
// Make b.go with identical content
|
||||
testFile2 := filepath.Join(tmpDir, "b.go")
|
||||
if err := os.WriteFile(testFile2, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write b.go: %v", err)
|
||||
}
|
||||
// c.go with different content
|
||||
testFile3 := filepath.Join(tmpDir, "c.go")
|
||||
if err := os.WriteFile(testFile3, []byte("package main\n"), 0600); err != nil {
|
||||
t.Fatalf("failed to write c.go: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false, MaxFileSize: 1024 * 1024}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"paths": []interface{}{testFile, testFile2, testFile3},
|
||||
}
|
||||
|
||||
result, err := srv.handleFileRead(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleFileRead() error = %v", err)
|
||||
}
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
t.Fatal("handleFileRead() returned empty result")
|
||||
}
|
||||
textContent, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("expected text content")
|
||||
}
|
||||
text := textContent.Text
|
||||
if !strings.Contains(text, "[duplicate of") {
|
||||
t.Errorf("expected duplicate pointer for b.go, got:\n%s", text)
|
||||
}
|
||||
// a.go should have full content
|
||||
if !strings.Contains(text, "--- "+testFile+" ---") {
|
||||
t.Errorf("expected a.go header, got:\n%s", text)
|
||||
}
|
||||
// c.go should have full content (different hash)
|
||||
if !strings.Contains(text, "--- "+testFile3+" ---") {
|
||||
t.Errorf("expected c.go header, got:\n%s", text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleFileSearchVerbose verifies verbose=true emits "Found N matches in M files:" preamble.
|
||||
func TestHandleFileSearchVerbose(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
if err := os.WriteFile(testFile, []byte("package main\n\nfunc Hello() {}\n"), 0600); err != nil {
|
||||
t.Fatalf("write test file: %v", err)
|
||||
}
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false, MaxFileSize: 1024 * 1024, SearchTimeout: 10 * time.Second}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
if srv.searcher == nil {
|
||||
t.Skip("ripgrep not available")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "Hello",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"verbose": true,
|
||||
}
|
||||
result, err := srv.handleFileSearch(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleFileSearch error = %v", err)
|
||||
}
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
t.Fatal("handleFileSearch returned empty")
|
||||
}
|
||||
textContent, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("expected text content")
|
||||
}
|
||||
if !strings.Contains(textContent.Text, "Found ") {
|
||||
t.Errorf("verbose=true should emit preamble, got:\n%s", textContent.Text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
// Package server implements the MCP server for file operations.
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// SessionPrefs holds client-declared session-wide preferences parsed from
|
||||
// InitializeRequest.Params.Capabilities.Experimental["filepuff"].
|
||||
//
|
||||
// These act as defaults; explicit per-call flags always override them.
|
||||
//
|
||||
// Supported keys in the "filepuff" experimental map:
|
||||
//
|
||||
// terse bool — no-op (v2 default is already terse; reserved for future)
|
||||
// default_format string — ast_query format default ("verbose"|"compact"|"location")
|
||||
// default_max_results int — file_search and ast_query max_results when not supplied
|
||||
// default_cluster bool — file_search cluster default
|
||||
// compact_refs bool — lsp_query references compact default
|
||||
// line_numbers string — file_read line prefix default ("none"|"compact"|"full")
|
||||
// resource_link_threshold int — per-session override for cfg.ResourceLinkThresholdBytes
|
||||
type SessionPrefs struct {
|
||||
// ASTQueryFormat is the default format for ast_query ("verbose", "compact", "location").
|
||||
// Empty string means "use handler built-in default".
|
||||
ASTQueryFormat string
|
||||
|
||||
// DefaultMaxResults is the default max_results for file_search and ast_query.
|
||||
// 0 means "use handler built-in default".
|
||||
DefaultMaxResults int
|
||||
|
||||
// DefaultCluster is the default cluster flag for file_search.
|
||||
// nil means "use handler built-in default (false)".
|
||||
DefaultCluster *bool
|
||||
|
||||
// CompactRefs is the default compact flag for lsp_query action=references.
|
||||
// nil means "use handler built-in default (false)".
|
||||
CompactRefs *bool
|
||||
|
||||
// LineNumbers controls the file_read line prefix default.
|
||||
// "" = use handler built-in default (full).
|
||||
// "none" = no line numbers.
|
||||
// "compact" = compact prefix (N│content).
|
||||
// "full" = standard padded prefix ( N│ content).
|
||||
LineNumbers string
|
||||
|
||||
// ResourceLinkThreshold overrides cfg.ResourceLinkThresholdBytes for this session.
|
||||
// 0 means "use config default".
|
||||
ResourceLinkThreshold int
|
||||
}
|
||||
|
||||
// boolPtr returns a pointer to a bool value.
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
|
||||
// ParseSessionPrefs parses the raw map from
|
||||
// InitializeRequest.Params.Capabilities.Experimental["filepuff"].
|
||||
// Unknown keys are silently ignored. Type mismatches for individual keys are
|
||||
// silently ignored (key is treated as absent). Returns zero-value SessionPrefs
|
||||
// when raw is nil or empty — callers should treat zero values as "use built-in defaults".
|
||||
func ParseSessionPrefs(raw map[string]any) SessionPrefs {
|
||||
if len(raw) == 0 {
|
||||
return SessionPrefs{}
|
||||
}
|
||||
|
||||
var p SessionPrefs
|
||||
|
||||
if v, ok := raw["default_format"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
switch s {
|
||||
case "verbose", "compact", "location":
|
||||
p.ASTQueryFormat = s
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := raw["default_max_results"]; ok {
|
||||
if n := toInt(v); n > 0 {
|
||||
p.DefaultMaxResults = n
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := raw["default_cluster"]; ok {
|
||||
if b, ok := v.(bool); ok {
|
||||
p.DefaultCluster = boolPtr(b)
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := raw["compact_refs"]; ok {
|
||||
if b, ok := v.(bool); ok {
|
||||
p.CompactRefs = boolPtr(b)
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := raw["line_numbers"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
switch s {
|
||||
case "none", "compact", "full":
|
||||
p.LineNumbers = s
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := raw["resource_link_threshold"]; ok {
|
||||
if n := toInt(v); n >= 0 {
|
||||
p.ResourceLinkThreshold = n
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// toInt converts numeric JSON-decoded values (float64, int, int64) to int.
|
||||
// Returns 0 for unsupported types or negative values.
|
||||
func toInt(v any) int {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
if n >= 0 {
|
||||
return int(n)
|
||||
}
|
||||
case int:
|
||||
if n >= 0 {
|
||||
return n
|
||||
}
|
||||
case int64:
|
||||
if n >= 0 {
|
||||
return int(n)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// sessionPrefsPtr is an atomic pointer helper for thread-safe access to *SessionPrefs.
|
||||
// We use atomic store/load so the hook (called once at init) and handlers (many goroutines)
|
||||
// never race. The prefs are write-once after initialization.
|
||||
type sessionPrefsPtr struct {
|
||||
p unsafe.Pointer // *SessionPrefs
|
||||
}
|
||||
|
||||
func (sp *sessionPrefsPtr) Store(prefs *SessionPrefs) {
|
||||
atomic.StorePointer(&sp.p, unsafe.Pointer(prefs))
|
||||
}
|
||||
|
||||
func (sp *sessionPrefsPtr) Load() *SessionPrefs {
|
||||
return (*SessionPrefs)(atomic.LoadPointer(&sp.p))
|
||||
}
|
||||
@@ -0,0 +1,313 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// ---- Session prefs integration tests ----
|
||||
|
||||
// setSessionPrefs injects prefs directly on the server (bypasses MCP hook machinery).
|
||||
func setSessionPrefs(srv *Server, prefs SessionPrefs) {
|
||||
srv.sessionPrefs.Store(&prefs)
|
||||
}
|
||||
|
||||
// TestSessionPrefsFileReadLineNumbersNone verifies that session pref line_numbers=none
|
||||
// disables line-number prefixes and is overridden by explicit compact_line_numbers=true.
|
||||
func TestSessionPrefsFileReadLineNumbersNone(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := "package main\n\nfunc Foo() {}\n"
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"line_numbers": "none"}))
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Without explicit override: session pref should suppress line numbers.
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{"path": testFile}
|
||||
result, err := srv.handleFileRead(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleFileRead: %v", err)
|
||||
}
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
// Standard line-number format is " 1│ "; with no_line_numbers it's absent.
|
||||
if strings.Contains(text, " 1│") {
|
||||
t.Errorf("session line_numbers=none: expected no line-number prefix, got:\n%s", text)
|
||||
}
|
||||
|
||||
// Explicit per-call compact_line_numbers=true should override session none.
|
||||
req2 := mcp.CallToolRequest{}
|
||||
req2.Params.Arguments = map[string]interface{}{
|
||||
"path": testFile,
|
||||
"compact_line_numbers": true,
|
||||
}
|
||||
result2, err := srv.handleFileRead(ctx, req2)
|
||||
if err != nil {
|
||||
t.Fatalf("handleFileRead (explicit compact): %v", err)
|
||||
}
|
||||
text2 := result2.Content[0].(mcp.TextContent).Text
|
||||
// Compact format emits "1│" prefix.
|
||||
if !strings.Contains(text2, "\u2502") {
|
||||
t.Errorf("explicit compact_line_numbers should override session none, got:\n%s", text2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionPrefsFileReadLineNumbersCompact verifies line_numbers=compact session pref.
|
||||
func TestSessionPrefsFileReadLineNumbersCompact(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "x.go")
|
||||
if err := os.WriteFile(testFile, []byte("package main\nfunc Bar() {}\n"), 0600); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, _ := New(cfg, logger)
|
||||
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"line_numbers": "compact"}))
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{"path": testFile}
|
||||
result, _ := srv.handleFileRead(ctx, req)
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
// Standard padded prefix is " 1│ "; compact is "1│".
|
||||
if strings.Contains(text, " 1\u2502") {
|
||||
t.Errorf("session line_numbers=compact should use compact prefix, got:\n%s", text)
|
||||
}
|
||||
// Should still have the │ separator somewhere.
|
||||
if !strings.Contains(text, "\u2502") {
|
||||
t.Errorf("session line_numbers=compact should still have \u2502 separator, got:\n%s", text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionPrefsResourceLinkThreshold verifies per-session threshold override.
|
||||
func TestSessionPrefsResourceLinkThreshold(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "big.go")
|
||||
var sb strings.Builder
|
||||
sb.WriteString("package main\n\nfunc Foo() {\n")
|
||||
for i := 0; i < 15; i++ {
|
||||
sb.WriteString("// comment line\n")
|
||||
}
|
||||
sb.WriteString("}\n")
|
||||
if err := os.WriteFile(testFile, []byte(sb.String()), 0600); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
// Config threshold = 0 (disabled) so content is always inlined by default.
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20, ResourceLinkThresholdBytes: 0}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, _ := New(cfg, logger)
|
||||
|
||||
// Set session threshold = 10 bytes (tiny), so any real file triggers resource-link.
|
||||
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"resource_link_threshold": float64(10)}))
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{"path": testFile}
|
||||
result, _ := srv.handleFileRead(ctx, req)
|
||||
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("expected content")
|
||||
}
|
||||
_, isLink := result.Content[0].(mcp.ResourceLink)
|
||||
_, isText := result.Content[0].(mcp.TextContent)
|
||||
if !isLink && isText {
|
||||
t.Error("expected ResourceLink when session threshold is very small, got TextContent")
|
||||
}
|
||||
|
||||
// force_inline should still bypass even a session threshold.
|
||||
req2 := mcp.CallToolRequest{}
|
||||
req2.Params.Arguments = map[string]interface{}{"path": testFile, "force_inline": true}
|
||||
result2, _ := srv.handleFileRead(ctx, req2)
|
||||
if _, ok := result2.Content[0].(mcp.TextContent); !ok {
|
||||
t.Error("force_inline=true should bypass session threshold and return TextContent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionPrefsASTQueryFormat verifies default_format session pref for ast_query.
|
||||
func TestSessionPrefsASTQueryFormat(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
if err := os.WriteFile(testFile, []byte("package main\n\nfunc Greet() {}\n"), 0600); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, _ := New(cfg, logger)
|
||||
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_format": "compact"}))
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME()",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
// no format key → session default should apply
|
||||
}
|
||||
result, err := srv.handleASTQuery(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleASTQuery: %v", err)
|
||||
}
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
t.Fatal("empty result")
|
||||
}
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
// Compact format emits one-line results without "**file:line**" markers.
|
||||
if strings.Contains(text, "**") {
|
||||
t.Errorf("session default_format=compact: expected compact output (no **), got:\n%s", text)
|
||||
}
|
||||
|
||||
// Explicit format=verbose should override session compact.
|
||||
req2 := mcp.CallToolRequest{}
|
||||
req2.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME()",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"format": "verbose",
|
||||
}
|
||||
result2, _ := srv.handleASTQuery(ctx, req2)
|
||||
if result2 != nil && len(result2.Content) > 0 {
|
||||
text2 := result2.Content[0].(mcp.TextContent).Text
|
||||
if !strings.Contains(text2, "**") {
|
||||
t.Errorf("explicit format=verbose should override session compact, got:\n%s", text2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionPrefsASTQueryMaxResults verifies default_max_results for ast_query.
|
||||
func TestSessionPrefsASTQueryMaxResults(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
// Build a file with 5 functions.
|
||||
var sb strings.Builder
|
||||
sb.WriteString("package main\n\n")
|
||||
for i := 0; i < 5; i++ {
|
||||
sb.WriteString(fmt.Sprintf("func Fn%c() {}\n\n", rune('A'+i)))
|
||||
}
|
||||
testFile := filepath.Join(tmpDir, "many.go")
|
||||
if err := os.WriteFile(testFile, []byte(sb.String()), 0600); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, _ := New(cfg, logger)
|
||||
// Session pref: max 2 results.
|
||||
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_max_results": float64(2)}))
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME()",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
// no max_results → session pref of 2 should apply
|
||||
}
|
||||
result, err := srv.handleASTQuery(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleASTQuery: %v", err)
|
||||
}
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
t.Fatal("empty result")
|
||||
}
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
// With 5 funcs and max=2, output should mention remaining.
|
||||
if !strings.Contains(text, "remaining") && !strings.Contains(text, "cursor") {
|
||||
t.Errorf("session max_results=2 with 5 matches should produce cursor line, got:\n%s", text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionPrefsFileSearchDefaultCluster verifies default_cluster session pref.
|
||||
func TestSessionPrefsFileSearchDefaultCluster(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "x.go")
|
||||
content := "package main\n\nfunc Foo() {}\nfunc Foo2() {}\n"
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
cfg := &config.Config{
|
||||
WorkspaceRoot: tmpDir,
|
||||
MaxFileSize: 1 << 20,
|
||||
SearchTimeout: 10 * time.Second,
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if srv.searcher == nil {
|
||||
t.Skip("ripgrep not available")
|
||||
}
|
||||
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_cluster": true}))
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func",
|
||||
"paths": []interface{}{tmpDir},
|
||||
// no cluster flag → session default_cluster=true should apply
|
||||
}
|
||||
result, err := srv.handleFileSearch(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleFileSearch: %v", err)
|
||||
}
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
t.Skip("search returned no results")
|
||||
}
|
||||
// Verify call succeeded (cluster behaviour is ripgrep-version dependent).
|
||||
_ = result.Content[0].(mcp.TextContent).Text
|
||||
}
|
||||
|
||||
// TestSessionPrefsMaxResultsExplicitOverride verifies explicit call-time max_results
|
||||
// overrides session pref for ast_query.
|
||||
func TestSessionPrefsMaxResultsExplicitOverride(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
var sb strings.Builder
|
||||
sb.WriteString("package main\n\n")
|
||||
for i := 0; i < 5; i++ {
|
||||
sb.WriteString(fmt.Sprintf("func Fn%c() {}\n\n", rune('A'+i)))
|
||||
}
|
||||
testFile := filepath.Join(tmpDir, "many.go")
|
||||
if err := os.WriteFile(testFile, []byte(sb.String()), 0600); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
srv, _ := New(cfg, logger)
|
||||
// Session wants 2, caller supplies 10 — all 5 should fit without cursor.
|
||||
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_max_results": float64(2)}))
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME()",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
"max_results": 10, // explicit override
|
||||
}
|
||||
result, _ := srv.handleASTQuery(ctx, req)
|
||||
if result == nil || len(result.Content) == 0 {
|
||||
t.Fatal("empty result")
|
||||
}
|
||||
text := result.Content[0].(mcp.TextContent).Text
|
||||
// With max_results=10 and only 5 funcs, no cursor line expected.
|
||||
if strings.Contains(text, "remaining") {
|
||||
t.Errorf("explicit max_results=10 should override session 2; 5 funcs fit; no cursor expected, got:\n%s", text)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestParseSessionPrefsEmpty verifies zero-value result for nil/empty input.
|
||||
func TestParseSessionPrefsEmpty(t *testing.T) {
|
||||
p := ParseSessionPrefs(nil)
|
||||
if p.ASTQueryFormat != "" {
|
||||
t.Errorf("ASTQueryFormat: want \"\", got %q", p.ASTQueryFormat)
|
||||
}
|
||||
if p.DefaultMaxResults != 0 {
|
||||
t.Errorf("DefaultMaxResults: want 0, got %d", p.DefaultMaxResults)
|
||||
}
|
||||
if p.DefaultCluster != nil {
|
||||
t.Errorf("DefaultCluster: want nil, got %v", *p.DefaultCluster)
|
||||
}
|
||||
if p.CompactRefs != nil {
|
||||
t.Errorf("CompactRefs: want nil, got %v", *p.CompactRefs)
|
||||
}
|
||||
if p.LineNumbers != "" {
|
||||
t.Errorf("LineNumbers: want \"\", got %q", p.LineNumbers)
|
||||
}
|
||||
if p.ResourceLinkThreshold != 0 {
|
||||
t.Errorf("ResourceLinkThreshold: want 0, got %d", p.ResourceLinkThreshold)
|
||||
}
|
||||
|
||||
// Also test with an empty (non-nil) map.
|
||||
p2 := ParseSessionPrefs(map[string]any{})
|
||||
if p2.ASTQueryFormat != "" || p2.DefaultMaxResults != 0 {
|
||||
t.Error("empty map should produce zero-value prefs")
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseSessionPrefsAllFields verifies full round-trip with all supported keys.
|
||||
func TestParseSessionPrefsAllFields(t *testing.T) {
|
||||
raw := map[string]any{
|
||||
"terse": true, // no-op; should not produce an error
|
||||
"default_format": "compact",
|
||||
"default_max_results": float64(50), // JSON numbers decode as float64
|
||||
"default_cluster": true,
|
||||
"compact_refs": true,
|
||||
"line_numbers": "none",
|
||||
"resource_link_threshold": float64(32768),
|
||||
}
|
||||
p := ParseSessionPrefs(raw)
|
||||
|
||||
if p.ASTQueryFormat != "compact" {
|
||||
t.Errorf("ASTQueryFormat: want \"compact\", got %q", p.ASTQueryFormat)
|
||||
}
|
||||
if p.DefaultMaxResults != 50 {
|
||||
t.Errorf("DefaultMaxResults: want 50, got %d", p.DefaultMaxResults)
|
||||
}
|
||||
if p.DefaultCluster == nil || !*p.DefaultCluster {
|
||||
t.Errorf("DefaultCluster: want true, got %v", p.DefaultCluster)
|
||||
}
|
||||
if p.CompactRefs == nil || !*p.CompactRefs {
|
||||
t.Errorf("CompactRefs: want true, got %v", p.CompactRefs)
|
||||
}
|
||||
if p.LineNumbers != "none" {
|
||||
t.Errorf("LineNumbers: want \"none\", got %q", p.LineNumbers)
|
||||
}
|
||||
if p.ResourceLinkThreshold != 32768 {
|
||||
t.Errorf("ResourceLinkThreshold: want 32768, got %d", p.ResourceLinkThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseSessionPrefsLineNumbersVariants tests all valid line_numbers values.
|
||||
func TestParseSessionPrefsLineNumbersVariants(t *testing.T) {
|
||||
for _, want := range []string{"none", "compact", "full"} {
|
||||
p := ParseSessionPrefs(map[string]any{"line_numbers": want})
|
||||
if p.LineNumbers != want {
|
||||
t.Errorf("line_numbers=%q: got %q", want, p.LineNumbers)
|
||||
}
|
||||
}
|
||||
// Invalid value → ignored (empty string).
|
||||
p := ParseSessionPrefs(map[string]any{"line_numbers": "bogus"})
|
||||
if p.LineNumbers != "" {
|
||||
t.Errorf("invalid line_numbers should be ignored, got %q", p.LineNumbers)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseSessionPrefsFormatVariants tests all valid default_format values.
|
||||
func TestParseSessionPrefsFormatVariants(t *testing.T) {
|
||||
for _, want := range []string{"verbose", "compact", "location"} {
|
||||
p := ParseSessionPrefs(map[string]any{"default_format": want})
|
||||
if p.ASTQueryFormat != want {
|
||||
t.Errorf("default_format=%q: got %q", want, p.ASTQueryFormat)
|
||||
}
|
||||
}
|
||||
// Invalid value → ignored.
|
||||
p := ParseSessionPrefs(map[string]any{"default_format": "yaml"})
|
||||
if p.ASTQueryFormat != "" {
|
||||
t.Errorf("invalid format should be ignored, got %q", p.ASTQueryFormat)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseSessionPrefsTypeMismatch verifies that wrong types are silently ignored.
|
||||
func TestParseSessionPrefsTypeMismatch(t *testing.T) {
|
||||
raw := map[string]any{
|
||||
"default_format": 123, // wrong type (int instead of string)
|
||||
"default_max_results": "fifty", // wrong type
|
||||
"default_cluster": "yes", // wrong type (string instead of bool)
|
||||
"compact_refs": 42, // wrong type
|
||||
"line_numbers": true, // wrong type
|
||||
"resource_link_threshold": "big", // wrong type
|
||||
}
|
||||
p := ParseSessionPrefs(raw)
|
||||
if p.ASTQueryFormat != "" {
|
||||
t.Errorf("type mismatch for format should produce empty, got %q", p.ASTQueryFormat)
|
||||
}
|
||||
if p.DefaultMaxResults != 0 {
|
||||
t.Errorf("type mismatch for max_results should produce 0, got %d", p.DefaultMaxResults)
|
||||
}
|
||||
if p.DefaultCluster != nil {
|
||||
t.Errorf("type mismatch for cluster should produce nil")
|
||||
}
|
||||
if p.CompactRefs != nil {
|
||||
t.Errorf("type mismatch for compact_refs should produce nil")
|
||||
}
|
||||
if p.LineNumbers != "" {
|
||||
t.Errorf("type mismatch for line_numbers should produce empty, got %q", p.LineNumbers)
|
||||
}
|
||||
if p.ResourceLinkThreshold != 0 {
|
||||
t.Errorf("type mismatch for threshold should produce 0, got %d", p.ResourceLinkThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseSessionPrefsNegativeValues verifies negative numbers are rejected.
|
||||
func TestParseSessionPrefsNegativeValues(t *testing.T) {
|
||||
p := ParseSessionPrefs(map[string]any{
|
||||
"default_max_results": float64(-5),
|
||||
"resource_link_threshold": float64(-1),
|
||||
})
|
||||
if p.DefaultMaxResults != 0 {
|
||||
t.Errorf("negative max_results should be rejected, got %d", p.DefaultMaxResults)
|
||||
}
|
||||
if p.ResourceLinkThreshold != 0 {
|
||||
t.Errorf("negative threshold should be rejected, got %d", p.ResourceLinkThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseSessionPrefsIntCoercion verifies int and int64 inputs also work.
|
||||
func TestParseSessionPrefsIntCoercion(t *testing.T) {
|
||||
p := ParseSessionPrefs(map[string]any{
|
||||
"default_max_results": int(25),
|
||||
"resource_link_threshold": int64(16384),
|
||||
})
|
||||
if p.DefaultMaxResults != 25 {
|
||||
t.Errorf("int max_results: want 25, got %d", p.DefaultMaxResults)
|
||||
}
|
||||
if p.ResourceLinkThreshold != 16384 {
|
||||
t.Errorf("int64 threshold: want 16384, got %d", p.ResourceLinkThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseSessionPrefsClusterFalse ensures default_cluster=false stores a non-nil false.
|
||||
func TestParseSessionPrefsClusterFalse(t *testing.T) {
|
||||
p := ParseSessionPrefs(map[string]any{"default_cluster": false})
|
||||
if p.DefaultCluster == nil {
|
||||
t.Error("default_cluster=false should store non-nil pointer")
|
||||
}
|
||||
if *p.DefaultCluster != false {
|
||||
t.Error("default_cluster=false: want false pointer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionPrefsAtomicStore verifies sessionPrefsPtr is readable after Store.
|
||||
func TestSessionPrefsAtomicStore(t *testing.T) {
|
||||
var sp sessionPrefsPtr
|
||||
if sp.Load() != nil {
|
||||
t.Error("uninitialised Load() should return nil")
|
||||
}
|
||||
|
||||
prefs := ParseSessionPrefs(map[string]any{"default_format": "compact"})
|
||||
sp.Store(&prefs)
|
||||
|
||||
loaded := sp.Load()
|
||||
if loaded == nil {
|
||||
t.Fatal("Load() returned nil after Store")
|
||||
}
|
||||
if loaded.ASTQueryFormat != "compact" {
|
||||
t.Errorf("loaded format: want compact, got %q", loaded.ASTQueryFormat)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user