From 5ad975ee7ae9ba88ebe0d87824af36a717c79881 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sun, 19 Apr 2026 19:56:49 +0100 Subject: [PATCH] V2/token optimization (#11) * v2.0: token-optimization overhaul Additive (backward-compatible flags): - file_read: skeleton mode, strip (imports/license/block_comments), compact_line_numbers, 8-char etag with prefix-match compat - ast_query: format=verbose|compact|location, pagination cursor - file_search: cluster mode, pagination cursor - lsp_query (references): compact output Breaking (v2): - Preambles removed; opt-in verbose=true restores - edit_apply: response=count|diff|none, default count - ping tool removed - symbol_at/find_definition/find_references merged into lsp_query - Tool descriptions trimmed -83%, help moved to filepuff://help/ - Batch file_read dedups by etag Protocol: - ResourceLink returned for file_read >64 KiB (force_inline override) - OnAfterInitialize hook reads capabilities.experimental.filepuff for session defaults (default_format, default_max_results, default_cluster, compact_refs, line_numbers, resource_link_threshold) * fix: drop --max-total-count from ripgrep args The flag does not exist in stable ripgrep (confirmed up to 15.1.0 -- "unrecognized flag --max-total-count, similar flags that are available: --max-count"). Every file_search call failed on hosts with stock rg. --max-count is per-file, not a drop-in replacement, so rely on the in-process truncation in parseOutput that was already the documented safety net. --- internal/config/config.go | 59 +- internal/cursor/cursor.go | 65 ++ internal/cursor/cursor_test.go | 57 ++ internal/parser/skeleton.go | 345 ++++++++ internal/parser/strip.go | 299 +++++++ internal/query/format_test.go | 186 ++++ internal/query/query.go | 139 ++- internal/search/format_test.go | 131 +++ internal/search/search.go | 123 ++- internal/search/search_test.go | 4 +- internal/server/features_test.go | 344 ++++++++ internal/server/handlers_ast.go | 303 ++++--- internal/server/handlers_edit.go | 71 +- internal/server/handlers_file.go | 387 ++++++++- internal/server/handlers_file_test.go | 909 ++++++++++++++++++++ internal/server/handlers_lsp.go | 199 +++-- internal/server/handlers_lsp_test.go | 181 ++++ internal/server/help_content.go | 183 ++++ internal/server/integration_test.go | 28 +- internal/server/resources.go | 156 ++++ internal/server/resources_test.go | 90 ++ internal/server/server.go | 215 ++--- internal/server/server_test.go | 298 ++++++- internal/server/session.go | 145 ++++ internal/server/session_integration_test.go | 313 +++++++ internal/server/session_test.go | 186 ++++ 26 files changed, 4909 insertions(+), 507 deletions(-) create mode 100644 internal/cursor/cursor.go create mode 100644 internal/cursor/cursor_test.go create mode 100644 internal/parser/skeleton.go create mode 100644 internal/parser/strip.go create mode 100644 internal/query/format_test.go create mode 100644 internal/search/format_test.go create mode 100644 internal/server/features_test.go create mode 100644 internal/server/handlers_file_test.go create mode 100644 internal/server/handlers_lsp_test.go create mode 100644 internal/server/help_content.go create mode 100644 internal/server/resources.go create mode 100644 internal/server/resources_test.go create mode 100644 internal/server/session.go create mode 100644 internal/server/session_integration_test.go create mode 100644 internal/server/session_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 38f27da..a5963ed 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,43 +13,46 @@ import ( // Config holds all configuration options for the MCP server. type Config struct { - Formatters map[string]string `json:"formatters"` - WorkspaceRoot string `json:"workspace_root"` - LSPTimeout time.Duration `json:"lsp_timeout"` - SearchTimeout time.Duration `json:"search_timeout"` - MaxFileSize int64 `json:"max_file_size"` - MaxParseSize int64 `json:"max_parse_size"` - MaxSearchResults int `json:"max_search_results"` - MaxEditSize int64 `json:"max_edit_size"` - EnableLSP bool `json:"enable_lsp"` - FollowSymlinks bool `json:"follow_symlinks"` - RespectGitignore bool `json:"respect_gitignore"` + Formatters map[string]string `json:"formatters"` + WorkspaceRoot string `json:"workspace_root"` + LSPTimeout time.Duration `json:"lsp_timeout"` + SearchTimeout time.Duration `json:"search_timeout"` + MaxFileSize int64 `json:"max_file_size"` + MaxParseSize int64 `json:"max_parse_size"` + MaxSearchResults int `json:"max_search_results"` + MaxEditSize int64 `json:"max_edit_size"` + EnableLSP bool `json:"enable_lsp"` + FollowSymlinks bool `json:"follow_symlinks"` + RespectGitignore bool `json:"respect_gitignore"` + ResourceLinkThresholdBytes int `json:"resource_link_threshold_bytes"` } // Default values for configuration. const ( - DefaultLSPTimeout = 5 * time.Minute - DefaultSearchTimeout = 30 * time.Second - DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB - DefaultMaxParseSize = 10 * 1024 * 1024 // 10 MB - DefaultMaxSearchResults = 1000 - DefaultMaxEditSize = 100 * 1024 // 100 KB + DefaultLSPTimeout = 5 * time.Minute + DefaultSearchTimeout = 30 * time.Second + DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB + DefaultMaxParseSize = 10 * 1024 * 1024 // 10 MB + DefaultMaxSearchResults = 1000 + DefaultMaxEditSize = 100 * 1024 // 100 KB + DefaultResourceLinkThresholdBytes = 64 * 1024 // 64 KiB ) // Default returns a Config with default values. func Default() *Config { return &Config{ - WorkspaceRoot: ".", - LSPTimeout: DefaultLSPTimeout, - SearchTimeout: DefaultSearchTimeout, - MaxFileSize: DefaultMaxFileSize, - MaxParseSize: DefaultMaxParseSize, - MaxSearchResults: DefaultMaxSearchResults, - MaxEditSize: DefaultMaxEditSize, - EnableLSP: true, - Formatters: make(map[string]string), - FollowSymlinks: true, - RespectGitignore: true, + WorkspaceRoot: ".", + LSPTimeout: DefaultLSPTimeout, + SearchTimeout: DefaultSearchTimeout, + MaxFileSize: DefaultMaxFileSize, + MaxParseSize: DefaultMaxParseSize, + MaxSearchResults: DefaultMaxSearchResults, + MaxEditSize: DefaultMaxEditSize, + EnableLSP: true, + Formatters: make(map[string]string), + FollowSymlinks: true, + RespectGitignore: true, + ResourceLinkThresholdBytes: DefaultResourceLinkThresholdBytes, } } diff --git a/internal/cursor/cursor.go b/internal/cursor/cursor.go new file mode 100644 index 0000000..fddc90b --- /dev/null +++ b/internal/cursor/cursor.go @@ -0,0 +1,65 @@ +// Package cursor implements opaque pagination cursors for MCP tools. +// A cursor encodes an offset into a result stream plus a query hash so stale +// cursors from different queries fail cleanly. +// +// Encoding: base64url(json({"offset":N,"query_hash":"hex"})) +// The query_hash is a hex-encoded sha256 over the deterministic query params. +package cursor + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "sort" + "strings" + + json "github.com/goccy/go-json" +) + +// payload is the JSON structure inside a cursor. +type payload struct { + Offset int `json:"offset"` + QueryHash string `json:"query_hash"` +} + +// Encode creates an opaque cursor string from an offset and query hash. +func Encode(offset int, queryHash string) string { + p := payload{Offset: offset, QueryHash: queryHash} + b, _ := json.Marshal(p) + return base64.RawURLEncoding.EncodeToString(b) +} + +// Decode parses a cursor string. Returns offset, queryHash, error. +func Decode(cursor string) (int, string, error) { + b, err := base64.RawURLEncoding.DecodeString(cursor) + if err != nil { + return 0, "", fmt.Errorf("invalid cursor encoding: %w", err) + } + var p payload + if err := json.Unmarshal(b, &p); err != nil { + return 0, "", fmt.Errorf("invalid cursor payload: %w", err) + } + return p.Offset, p.QueryHash, nil +} + +// HashParams computes a deterministic query hash from a set of key=value params. +// Keys are sorted before hashing so order doesn't matter. +func HashParams(params map[string]string) string { + keys := make([]string, 0, len(params)) + for k := range params { + keys = append(keys, k) + } + sort.Strings(keys) + + var sb strings.Builder + for _, k := range keys { + sb.WriteString(k) + sb.WriteByte('=') + sb.WriteString(params[k]) + sb.WriteByte('\n') + } + + sum := sha256.Sum256([]byte(sb.String())) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/cursor/cursor_test.go b/internal/cursor/cursor_test.go new file mode 100644 index 0000000..9e85f1d --- /dev/null +++ b/internal/cursor/cursor_test.go @@ -0,0 +1,57 @@ +package cursor + +import ( + "testing" +) + +func TestEncodeDecodRoundTrip(t *testing.T) { + hash := HashParams(map[string]string{"a": "1", "b": "2"}) + + encoded := Encode(42, hash) + if encoded == "" { + t.Fatal("Encode returned empty string") + } + + offset, gotHash, err := Decode(encoded) + if err != nil { + t.Fatalf("Decode error: %v", err) + } + if offset != 42 { + t.Errorf("offset: got %d, want 42", offset) + } + if gotHash != hash { + t.Errorf("hash mismatch: got %s, want %s", gotHash, hash) + } +} + +func TestDecodeInvalid(t *testing.T) { + _, _, err := Decode("!!!notbase64!!!") + if err == nil { + t.Error("expected error for invalid base64, got nil") + } +} + +func TestDecodeCorruptPayload(t *testing.T) { + import64 := "bm90anNvbg" // "notjson" in base64 + _, _, err := Decode(import64) + if err == nil { + t.Error("expected error for corrupt payload, got nil") + } +} + +func TestHashParamsDeterministic(t *testing.T) { + // Same params regardless of insertion order + h1 := HashParams(map[string]string{"z": "last", "a": "first"}) + h2 := HashParams(map[string]string{"a": "first", "z": "last"}) + if h1 != h2 { + t.Errorf("hash not deterministic: %s != %s", h1, h2) + } +} + +func TestHashParamsDifferentForDifferentQueries(t *testing.T) { + h1 := HashParams(map[string]string{"pattern": "foo"}) + h2 := HashParams(map[string]string{"pattern": "bar"}) + if h1 == h2 { + t.Error("different queries should produce different hashes") + } +} diff --git a/internal/parser/skeleton.go b/internal/parser/skeleton.go new file mode 100644 index 0000000..d7342d7 --- /dev/null +++ b/internal/parser/skeleton.go @@ -0,0 +1,345 @@ +// Package parser provides skeleton rendering for source files. +package parser + +import ( + "context" + "fmt" + "strings" + + "github.com/lukaszraczylo/mcp-filepuff/pkg/protocol" + sitter "github.com/smacker/go-tree-sitter" +) + +// SkeletonFile returns a skeleton representation of the file: top-level declarations +// with signatures and doc-comments intact, but function/method bodies replaced with +// a language-appropriate placeholder. +// +// Supported: go, typescript, javascript, python, rust. +// Other languages fall back to symbols_only (AST summary text). +// Returns (skeletonText, isFullSkeleton, error). +func SkeletonFile(ctx context.Context, reg *Registry, filename string, content []byte) (string, bool, error) { + result, err := reg.Parse(ctx, filename, content) + if err != nil { + return "", false, err + } + lang := protocol.DetectLanguage(filename) + + switch lang { + case protocol.LangGo: + return skeletonGo(result.Tree, content), true, nil + case protocol.LangTypeScript, protocol.LangJavaScript: + return skeletonTS(result.Tree, content), true, nil + case protocol.LangPython: + return skeletonPython(result.Tree, content), true, nil + case protocol.LangRust: + return skeletonRust(result.Tree, content), true, nil + default: + // TODO: skeleton for c, cpp, elixir, html, vue — fall back to symbols_only + syms := ExtractSymbols(result.Tree, content, lang, filename) + return renderSymbolsOnly(syms, filename, lang, content), false, nil + } +} + +// renderSymbolsOnly renders a simple symbol list (fallback for unsupported languages). +func renderSymbolsOnly(syms []protocol.Symbol, _ string, lang protocol.Language, content []byte) string { + lines := strings.Split(string(content), "\n") + var sb strings.Builder + sb.WriteString(fmt.Sprintf("// skeleton unavailable for %s — symbol list only\n", lang)) + for _, s := range syms { + sb.WriteString(fmt.Sprintf("// %s %s (line %d)\n", s.Kind, s.Name, s.Location.Line)) + if s.Doc != "" { + sb.WriteString(fmt.Sprintf("// doc: %s\n", s.Doc)) + } + if s.Location.Line >= 1 && s.Location.Line <= len(lines) { + sb.WriteString(lines[s.Location.Line-1] + " { ... }\n") + } + } + return sb.String() +} + +// ---- Go skeleton ---- + +// skeletonGoBodyNodes lists Go node types that have a body field to replace. +var skeletonGoBodyNodes = map[string]string{ + "function_declaration": "body", + "method_declaration": "body", +} + +func skeletonGo(tree *sitter.Tree, content []byte) string { + if tree == nil { + return string(content) + } + root := tree.RootNode() + var sb strings.Builder + skeletonGoNode(root, content, &sb) + return sb.String() +} + +func skeletonGoNode(node *sitter.Node, content []byte, sb *strings.Builder) { + if node == nil { + return + } + nodeType := node.Type() + + if bodyField, ok := skeletonGoBodyNodes[nodeType]; ok { + body := node.ChildByFieldName(bodyField) + if body != nil { + sig := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t") + sb.WriteString(sig) + sb.WriteString("{ ... }\n\n") + return + } + sb.WriteString(string(content[node.StartByte():node.EndByte()])) + sb.WriteString("\n") + return + } + + switch nodeType { + case "source_file": + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child == nil { + continue + } + skeletonGoNode(child, content, sb) + } + return + case "comment": + sb.WriteString(string(content[node.StartByte():node.EndByte()])) + sb.WriteString("\n") + return + default: + sb.WriteString(string(content[node.StartByte():node.EndByte()])) + sb.WriteString("\n") + } +} + +// ---- TypeScript / JavaScript skeleton ---- + +func skeletonTS(tree *sitter.Tree, content []byte) string { + if tree == nil { + return string(content) + } + root := tree.RootNode() + var sb strings.Builder + skeletonTSNode(root, content, &sb) + return sb.String() +} + +func skeletonTSNode(node *sitter.Node, content []byte, sb *strings.Builder) { + if node == nil { + return + } + nodeType := node.Type() + + switch nodeType { + case "program": + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child == nil { + continue + } + skeletonTSNode(child, content, sb) + } + return + case "function_declaration": + body := node.ChildByFieldName("body") + if body != nil { + sig := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t") + sb.WriteString(sig) + sb.WriteString("{ ... }\n\n") + return + } + case "class_declaration": + body := node.ChildByFieldName("body") + if body != nil { + header := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t") + sb.WriteString(header) + sb.WriteString("{\n") + for i := 0; i < int(body.ChildCount()); i++ { + child := body.Child(i) + if child == nil { + continue + } + if child.Type() == "method_definition" { + methBody := child.ChildByFieldName("body") + if methBody != nil { + methSig := strings.TrimRight(string(content[child.StartByte():methBody.StartByte()]), " \t") + sb.WriteString(" ") + sb.WriteString(methSig) + sb.WriteString("{ ... }\n") + continue + } + } + sb.WriteString(" ") + sb.WriteString(string(content[child.StartByte():child.EndByte()])) + sb.WriteString("\n") + } + sb.WriteString("}\n\n") + return + } + case "comment": + sb.WriteString(string(content[node.StartByte():node.EndByte()])) + sb.WriteString("\n") + return + } + + sb.WriteString(string(content[node.StartByte():node.EndByte()])) + sb.WriteString("\n") +} + +// ---- Python skeleton ---- + +func skeletonPython(tree *sitter.Tree, content []byte) string { + if tree == nil { + return string(content) + } + root := tree.RootNode() + var sb strings.Builder + skeletonPythonNode(root, content, &sb, "") + return sb.String() +} + +func skeletonPythonNode(node *sitter.Node, content []byte, sb *strings.Builder, indent string) { + if node == nil { + return + } + nodeType := node.Type() + + switch nodeType { + case "module": + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child == nil { + continue + } + skeletonPythonNode(child, content, sb, indent) + } + return + case "function_definition", "decorated_definition": + nodeText := string(content[node.StartByte():node.EndByte()]) + lines := strings.SplitN(nodeText, "\n", 2) + sb.WriteString(indent) + sb.WriteString(lines[0]) + sb.WriteString("\n") + sb.WriteString(indent) + sb.WriteString(" ...\n\n") + return + case "class_definition": + body := node.ChildByFieldName("body") + if body != nil { + header := string(content[node.StartByte():body.StartByte()]) + firstLine := strings.SplitN(header, "\n", 2)[0] + sb.WriteString(indent) + sb.WriteString(firstLine) + sb.WriteString("\n") + for i := 0; i < int(body.ChildCount()); i++ { + child := body.Child(i) + if child == nil { + continue + } + if child.Type() == "function_definition" || child.Type() == "decorated_definition" { + childText := string(content[child.StartByte():child.EndByte()]) + childLines := strings.SplitN(childText, "\n", 2) + sb.WriteString(indent + " ") + sb.WriteString(childLines[0]) + sb.WriteString("\n") + sb.WriteString(indent + " ...") + sb.WriteString("\n") + continue + } + if child.Type() == "expression_statement" { + sb.WriteString(indent + " ") + sb.WriteString(string(content[child.StartByte():child.EndByte()])) + sb.WriteString("\n") + continue + } + } + sb.WriteString("\n") + return + } + case "comment": + sb.WriteString(indent) + sb.WriteString(string(content[node.StartByte():node.EndByte()])) + sb.WriteString("\n") + return + } + + sb.WriteString(indent) + sb.WriteString(string(content[node.StartByte():node.EndByte()])) + sb.WriteString("\n") +} + +// ---- Rust skeleton ---- + +func skeletonRust(tree *sitter.Tree, content []byte) string { + if tree == nil { + return string(content) + } + root := tree.RootNode() + var sb strings.Builder + skeletonRustNode(root, content, &sb) + return sb.String() +} + +func skeletonRustNode(node *sitter.Node, content []byte, sb *strings.Builder) { + if node == nil { + return + } + nodeType := node.Type() + + switch nodeType { + case "source_file": + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child == nil { + continue + } + skeletonRustNode(child, content, sb) + } + return + case "function_item": + body := node.ChildByFieldName("body") + if body != nil { + sig := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t") + sb.WriteString(sig) + sb.WriteString("{ ... }\n\n") + return + } + case "impl_item": + body := node.ChildByFieldName("body") + if body != nil { + header := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t") + sb.WriteString(header) + sb.WriteString("{\n") + for i := 0; i < int(body.ChildCount()); i++ { + child := body.Child(i) + if child == nil { + continue + } + if child.Type() == "function_item" { + methBody := child.ChildByFieldName("body") + if methBody != nil { + methSig := strings.TrimRight(string(content[child.StartByte():methBody.StartByte()]), " \t") + sb.WriteString(" ") + sb.WriteString(methSig) + sb.WriteString("{ ... }\n") + continue + } + } + sb.WriteString(" ") + sb.WriteString(string(content[child.StartByte():child.EndByte()])) + sb.WriteString("\n") + } + sb.WriteString("}\n\n") + return + } + case "line_comment", "block_comment": + sb.WriteString(string(content[node.StartByte():node.EndByte()])) + sb.WriteString("\n") + return + } + + sb.WriteString(string(content[node.StartByte():node.EndByte()])) + sb.WriteString("\n") +} diff --git a/internal/parser/strip.go b/internal/parser/strip.go new file mode 100644 index 0000000..c4cc243 --- /dev/null +++ b/internal/parser/strip.go @@ -0,0 +1,299 @@ +package parser + +import ( + "strings" + + "github.com/lukaszraczylo/mcp-filepuff/pkg/protocol" +) + +// StripFlag names the categories of content to remove. +type StripFlag string + +const ( + StripImports StripFlag = "imports" + StripLicense StripFlag = "license" + StripBlockComments StripFlag = "block_comments" +) + +// StripResult holds the stripped content and which flags actually removed content. +type StripResult struct { + Content string + Stripped []StripFlag +} + +// StripContent applies requested strip operations to content, in order: +// license → imports → block_comments. +// lang is used to pick language-specific heuristics. +func StripContent(content string, flags []StripFlag, lang protocol.Language) StripResult { + flagSet := make(map[StripFlag]bool, len(flags)) + for _, f := range flags { + flagSet[f] = true + } + + var stripped []StripFlag + + if flagSet[StripLicense] { + next, removed := stripLicense(content) + if removed { + content = next + stripped = append(stripped, StripLicense) + } + } + + if flagSet[StripImports] { + next, removed := stripImports(content, lang) + if removed { + content = next + stripped = append(stripped, StripImports) + } + } + + if flagSet[StripBlockComments] { + next, removed := stripBlockComments(content, lang) + if removed { + content = next + stripped = append(stripped, StripBlockComments) + } + } + + return StripResult{Content: content, Stripped: stripped} +} + +// stripLicense removes a leading block comment that looks like a license header. +// A comment qualifies if it contains "copyright", "license", or "spdx-license-identifier" (case-insensitive). +func stripLicense(content string) (string, bool) { + trimmed := strings.TrimLeft(content, " \t\n\r") + + // C-style block comment at top + if strings.HasPrefix(trimmed, "/*") { + end := strings.Index(trimmed, "*/") + if end >= 0 { + candidate := trimmed[:end+2] + lower := strings.ToLower(candidate) + if strings.Contains(lower, "copyright") || + strings.Contains(lower, "license") || + strings.Contains(lower, "spdx-license-identifier") { + rest := trimmed[end+2:] + // Consume trailing newline(s) + rest = strings.TrimLeft(rest, "\r\n") + return rest, true + } + } + } + + // Python/hash-style leading comment block + if strings.HasPrefix(trimmed, "#") { + lines := strings.Split(trimmed, "\n") + var commentLines []string + var rest []string + inComment := true + for i, l := range lines { + if inComment && (strings.HasPrefix(l, "#") || strings.TrimSpace(l) == "") { + commentLines = append(commentLines, l) + } else { + rest = lines[i:] + break + } + } + block := strings.Join(commentLines, "\n") + lower := strings.ToLower(block) + if strings.Contains(lower, "copyright") || + strings.Contains(lower, "license") || + strings.Contains(lower, "spdx-license-identifier") { + return strings.Join(rest, "\n"), true + } + } + + return content, false +} + +// stripImports removes top-of-file import blocks, language-specific. +func stripImports(content string, lang protocol.Language) (string, bool) { + switch lang { + case protocol.LangGo: + return stripGoImports(content) + case protocol.LangTypeScript, protocol.LangJavaScript: + return stripTSImports(content) + case protocol.LangPython: + return stripPythonImports(content) + case protocol.LangRust: + return stripRustImports(content) + default: + return content, false + } +} + +// stripGoImports removes Go import(...) or single import "..." declarations. +func stripGoImports(content string) (string, bool) { + lines := strings.Split(content, "\n") + var out []string + removed := false + i := 0 + for i < len(lines) { + trimLine := strings.TrimSpace(lines[i]) + if strings.HasPrefix(trimLine, "import (") || trimLine == "import (" { + // multi-line import block + removed = true + i++ // skip "import (" + for i < len(lines) { + if strings.TrimSpace(lines[i]) == ")" { + i++ // skip closing ")" + break + } + i++ + } + // skip one blank line after + if i < len(lines) && strings.TrimSpace(lines[i]) == "" { + i++ + } + continue + } + if strings.HasPrefix(trimLine, `import "`) || strings.HasPrefix(trimLine, "import `") { + removed = true + i++ + continue + } + out = append(out, lines[i]) + i++ + } + if !removed { + return content, false + } + return strings.Join(out, "\n"), true +} + +// stripTSImports removes TypeScript/JavaScript "import ... from ..." and "require(...)" lines. +func stripTSImports(content string) (string, bool) { + lines := strings.Split(content, "\n") + var out []string + removed := false + for _, l := range lines { + trimLine := strings.TrimSpace(l) + if strings.HasPrefix(trimLine, "import ") || strings.HasPrefix(trimLine, "const {") && strings.Contains(trimLine, "require(") { + removed = true + continue + } + out = append(out, l) + } + if !removed { + return content, false + } + return strings.Join(out, "\n"), true +} + +// stripPythonImports removes Python "import ..." and "from ... import ..." lines. +func stripPythonImports(content string) (string, bool) { + lines := strings.Split(content, "\n") + var out []string + removed := false + for _, l := range lines { + trimLine := strings.TrimSpace(l) + if strings.HasPrefix(trimLine, "import ") || strings.HasPrefix(trimLine, "from ") { + removed = true + continue + } + out = append(out, l) + } + if !removed { + return content, false + } + return strings.Join(out, "\n"), true +} + +// stripRustImports removes Rust "use ..." declarations. +func stripRustImports(content string) (string, bool) { + lines := strings.Split(content, "\n") + var out []string + removed := false + inMulti := false + for _, l := range lines { + trimLine := strings.TrimSpace(l) + if inMulti { + // look for semicolon terminating multi-line use + if strings.Contains(trimLine, ";") { + inMulti = false + } + removed = true + continue + } + if strings.HasPrefix(trimLine, "use ") { + removed = true + if !strings.HasSuffix(trimLine, ";") { + inMulti = true + } + continue + } + out = append(out, l) + } + if !removed { + return content, false + } + return strings.Join(out, "\n"), true +} + +// stripBlockComments removes /* ... */ block comments (Go/TS/C/Rust) +// and Python triple-quoted docstrings. +func stripBlockComments(content string, lang protocol.Language) (string, bool) { + if lang == protocol.LangPython { + return stripPythonDocstrings(content) + } + return stripCStyleBlockComments(content) +} + +// stripCStyleBlockComments removes /* ... */ from content. +func stripCStyleBlockComments(content string) (string, bool) { + removed := false + var sb strings.Builder + i := 0 + for i < len(content) { + if i+1 < len(content) && content[i] == '/' && content[i+1] == '*' { + // find closing */ + end := strings.Index(content[i+2:], "*/") + if end >= 0 { + removed = true + // advance past */ + i = i + 2 + end + 2 + // consume trailing newline + if i < len(content) && content[i] == '\n' { + i++ + } + continue + } + } + sb.WriteByte(content[i]) + i++ + } + if !removed { + return content, false + } + return sb.String(), true +} + +// stripPythonDocstrings removes triple-quoted strings (""" and ”'). +func stripPythonDocstrings(content string) (string, bool) { + removed := false + var sb strings.Builder + i := 0 + for i < len(content) { + if i+2 < len(content) { + triple := content[i : i+3] + if triple == `"""` || triple == `'''` { + end := strings.Index(content[i+3:], triple) + if end >= 0 { + removed = true + i = i + 3 + end + 3 + if i < len(content) && content[i] == '\n' { + i++ + } + continue + } + } + } + sb.WriteByte(content[i]) + i++ + } + if !removed { + return content, false + } + return sb.String(), true +} diff --git a/internal/query/format_test.go b/internal/query/format_test.go new file mode 100644 index 0000000..5d3288b --- /dev/null +++ b/internal/query/format_test.go @@ -0,0 +1,186 @@ +package query + +import ( + "strings" + "testing" + + "github.com/lukaszraczylo/mcp-filepuff/pkg/protocol" +) + +// makeResults builds N dummy MatchResults. +func makeResults(n int) []MatchResult { + out := make([]MatchResult, n) + for i := range out { + out[i] = MatchResult{ + File: "file.go", + Location: protocol.Location{ + Line: i + 1, + Column: 1, + }, + Text: "func Foo() {}\nline2", + } + } + return out +} + +// TestFormatResultsVerboseDefault verifies verbose format includes code blocks (no preamble by default). +func TestFormatResultsVerboseDefault(t *testing.T) { + results := makeResults(2) + out := FormatResultsWithOptions(results, 0, "verbose", 0) + // v2 default: no preamble + if strings.Contains(out, "Found ") { + t.Errorf("v2 default should NOT emit preamble, got:\n%s", out) + } + if !strings.Contains(out, "```") { + t.Error("verbose mode should include code blocks") + } +} + +// TestFormatResultsVerbosePreamble verifies verbose=true restores the preamble. +func TestFormatResultsVerbosePreamble(t *testing.T) { + results := makeResults(2) + out := FormatResultsWithOptions(results, 0, "verbose", 0, true) + if !strings.Contains(out, "Found 2 match(es):") { + t.Errorf("expected preamble with verbose=true, got:\n%s", out) + } +} + +func TestFormatResultsCompact(t *testing.T) { + results := makeResults(3) + out := FormatResultsWithOptions(results, 0, "compact", 0) + // v2 default: no preamble in compact mode + // Should NOT have code blocks + if strings.Contains(out, "```") { + t.Error("compact mode should not have code blocks") + } + // Should have one line per match (beyond header) + lines := strings.Split(strings.TrimSpace(out), "\n") + // First two lines are "Found..." and blank, then 3 match lines + matchLines := 0 + for _, l := range lines { + if strings.Contains(l, "file.go:") { + matchLines++ + } + } + if matchLines != 3 { + t.Errorf("expected 3 match lines in compact mode, got %d\nOutput:\n%s", matchLines, out) + } +} + +func TestFormatResultsLocation(t *testing.T) { + results := makeResults(3) + out := FormatResultsWithOptions(results, 0, "location", 0) + if strings.Contains(out, "```") { + t.Error("location mode should not have code blocks") + } + // Should be file:line only + for i := 1; i <= 3; i++ { + expected := "file.go:" + itoa(i) + if !strings.Contains(out, expected) { + t.Errorf("location output missing %s", expected) + } + } +} + +func TestFormatResultsMaxResults(t *testing.T) { + results := makeResults(5) + out := FormatResultsWithOptions(results, 3, "verbose", 0) + // v2 default: no preamble — check that exactly 3 code blocks are present + codeBlockCount := strings.Count(out, "```") + if codeBlockCount != 6 { // 3 opening + 3 closing = 6 + t.Errorf("expected 3 matches (6 backtick markers), got %d in:\n%s", codeBlockCount, out) + } + if !strings.Contains(out, "[remaining: 2]") { + t.Errorf("expected [remaining: 2] footer, got:\n%s", out) + } +} + +func TestFormatResultsOffset(t *testing.T) { + results := makeResults(5) + // Skip first 2, show all remaining + out := FormatResultsWithOptions(results, 0, "verbose", 2) + // offset=2 from 5 results → 3 results; check 3 code blocks + codeBlockCount := strings.Count(out, "```") + if codeBlockCount != 6 { + t.Errorf("expected 3 matches (6 backtick markers) after offset=2, got %d in:\n%s", codeBlockCount, out) + } +} + +func TestFormatResultsOffsetBeyondEnd(t *testing.T) { + results := makeResults(3) + out := FormatResultsWithOptions(results, 0, "verbose", 10) + if out != "No matches found." { + t.Errorf("expected 'No matches found.' for offset beyond end, got: %s", out) + } +} + +func TestFormatResultsPaginationCursor(t *testing.T) { + // Offset=2, maxResults=2, 5 total → show items 3&4, remaining=1 + results := makeResults(5) + out := FormatResultsWithOptions(results, 2, "verbose", 2) + // offset=2, maxResults=2 → items 3&4; check 2 code blocks + codeBlockCount := strings.Count(out, "```") + if codeBlockCount != 4 { + t.Errorf("expected 2 matches (4 backtick markers), got %d in:\n%s", codeBlockCount, out) + } + if !strings.Contains(out, "[remaining: 1]") { + t.Errorf("expected [remaining: 1], got:\n%s", out) + } +} + +func TestFormatResultsEmpty(t *testing.T) { + out := FormatResultsWithOptions(nil, 0, "verbose", 0) + if out != "No matches found." { + t.Errorf("expected 'No matches found.', got: %s", out) + } +} + +func TestFormatResultsBackwardCompat(t *testing.T) { + // FormatResults wrapper should produce same output as FormatResultsWithOptions with verbose=false (default). + results := makeResults(2) + a := FormatResults(results, 0) + b := FormatResultsWithOptions(results, 0, "verbose", 0) + if a != b { + t.Error("FormatResults and FormatResultsWithOptions(verbose,0) should be identical") + } + // Both should have no preamble. + if strings.Contains(a, "Found ") { + t.Error("FormatResults should not emit preamble by default") + } +} + +func TestFirstLineOf(t *testing.T) { + cases := []struct { + input string + maxLen int + want string + }{ + {"hello world", 20, "hello world"}, + {"line1\nline2", 20, "line1"}, + {"\n\nfoo", 20, "foo"}, + {"abcdefghij", 5, "abcd…"}, + } + for _, c := range cases { + got := firstLineOf(c.input, c.maxLen) + if got != c.want { + t.Errorf("firstLineOf(%q, %d) = %q, want %q", c.input, c.maxLen, got, c.want) + } + } +} + +func itoa(n int) string { + if n < 10 { + return string(rune('0' + n)) + } + return strings.TrimRight(strings.TrimRight( + func() string { + buf := make([]byte, 20) + pos := 20 + for n > 0 { + pos-- + buf[pos] = byte('0' + n%10) + n /= 10 + } + return string(buf[pos:]) + }(), ""), "") +} diff --git a/internal/query/query.go b/internal/query/query.go index dc931df..15ee659 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -451,62 +451,121 @@ func passesFilters(node *sitter.Node, filters QueryFilters, content []byte) bool return true } -// FormatResults formats match results for display. +// FormatResults formats match results for display (backward-compat wrapper, verbose mode). func FormatResults(results []MatchResult, maxResults int) string { + return FormatResultsWithOptions(results, maxResults, "verbose", 0) +} + +// FormatResultsWithOptions formats match results with configurable output format. +// format: "verbose" (default) | "compact" | "location" +// offset: skip this many results before rendering (used for cursor pagination). +// verbose: opt-in variadic — pass true to restore "Found N match(es):" preamble (v1 behaviour). +func FormatResultsWithOptions(results []MatchResult, maxResults int, format string, offset int, verbose ...bool) string { if len(results) == 0 { return "No matches found." } - var sb strings.Builder - sb.WriteString(fmt.Sprintf("Found %d match(es):\n\n", len(results))) - - displayCount := len(results) - truncated := false - if maxResults > 0 && displayCount > maxResults { - displayCount = maxResults - truncated = true + // Apply offset (pagination skip). + if offset > 0 { + if offset >= len(results) { + return "No matches found." + } + results = results[offset:] } - for i := 0; i < displayCount; i++ { - r := results[i] - nodeType := "unknown" - if r.Node != nil { - nodeType = r.Node.Type() - } - sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType)) + // Determine how many to render and whether more remain. + renderCount := len(results) + remaining := 0 + if maxResults > 0 && renderCount > maxResults { + remaining = renderCount - maxResults + renderCount = maxResults + } - // Truncate very long text - text := r.Text - if len(text) > 500 { - text = text[:500] + "..." - } - sb.WriteString("```\n") - sb.WriteString(text) - sb.WriteString("\n```\n") + var sb strings.Builder - // Show captures - if len(r.Captures) > 0 { - sb.WriteString("Captures: ") - first := true - for name, cap := range r.Captures { - if !first { - sb.WriteString(", ") + // Emit preamble only when verbose=true is explicitly passed (opt-in, default off). + wantVerbose := len(verbose) > 0 && verbose[0] + if wantVerbose { + sb.WriteString(fmt.Sprintf("Found %d match(es):\n", renderCount)) + } + + switch format { + case "compact": + for i := 0; i < renderCount; i++ { + r := results[i] + nodeType := "unknown" + if r.Node != nil { + nodeType = r.Node.Type() + } + firstLine := firstLineOf(r.Text, 80) + sb.WriteString(fmt.Sprintf("%s:%d (%s) %s\n", r.File, r.Location.Line, nodeType, firstLine)) + } + + case "location": + for i := 0; i < renderCount; i++ { + r := results[i] + sb.WriteString(fmt.Sprintf("%s:%d\n", r.File, r.Location.Line)) + } + + default: // "verbose" + for i := 0; i < renderCount; i++ { + r := results[i] + nodeType := "unknown" + if r.Node != nil { + nodeType = r.Node.Type() + } + sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType)) + + // Truncate very long text + text := r.Text + if len(text) > 500 { + text = text[:500] + "..." + } + sb.WriteString("```\n") + sb.WriteString(text) + sb.WriteString("\n```\n") + + // Show captures + if len(r.Captures) > 0 { + sb.WriteString("Captures: ") + first := true + for name, cap := range r.Captures { + if !first { + sb.WriteString(", ") + } + first = false + capText := cap.Text + if len(capText) > 50 { + capText = capText[:50] + "..." + } + sb.WriteString(fmt.Sprintf("$%s=%s", name, capText)) } - first = false - capText := cap.Text - if len(capText) > 50 { - capText = capText[:50] + "..." - } - sb.WriteString(fmt.Sprintf("$%s=%s", name, capText)) + sb.WriteString("\n") } sb.WriteString("\n") } - sb.WriteString("\n") } - if truncated { - sb.WriteString(fmt.Sprintf("... and %d more matches (truncated)\n", len(results)-maxResults)) + if remaining > 0 { + // Caller must embed the cursor token; we just append the remaining count hint. + // The actual [cursor: ...] line is written by the handler after calling MakeCursor. + sb.WriteString(fmt.Sprintf("[remaining: %d]\n", remaining)) } return sb.String() } + +// firstLineOf returns the first non-empty line of s, trimmed and capped at maxLen chars. +func firstLineOf(s string, maxLen int) string { + for _, line := range strings.Split(s, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if len(line) > maxLen { + return line[:maxLen-1] + "…" + } + return line + } + return "" +} diff --git a/internal/search/format_test.go b/internal/search/format_test.go new file mode 100644 index 0000000..66eaeb0 --- /dev/null +++ b/internal/search/format_test.go @@ -0,0 +1,131 @@ +package search + +import ( + "strings" + "testing" + + "github.com/lukaszraczylo/mcp-filepuff/internal/config" + "log/slog" + "os" +) + +func newTestSearcher(t *testing.T) *Searcher { + t.Helper() + cfg := &config.Config{ + WorkspaceRoot: t.TempDir(), + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + // Create a Searcher directly without requiring rg + return &Searcher{cfg: cfg, logger: logger, rgPath: "rg"} +} + +func makeSearchResults(lines []int) *SearchResults { + results := make([]Result, len(lines)) + for i, l := range lines { + results[i] = Result{ + File: "/workspace/foo.go", + Line: l, + MatchText: "match at line " + itoa2(l), + } + } + return &SearchResults{Results: results} +} + +// itoa2 simple int to string for test +func itoa2(n int) string { + if n == 0 { + return "0" + } + digits := []byte{} + for n > 0 { + digits = append([]byte{byte('0' + n%10)}, digits...) + n /= 10 + } + return string(digits) +} + +// TestFormatResultsVerboseBackwardCompat verifies that Verbose=true restores the v1 preamble. +func TestFormatResultsVerboseBackwardCompat(t *testing.T) { + s := newTestSearcher(t) + sr := makeSearchResults([]int{1, 5, 10}) + // With Verbose=true, the preamble is emitted (v1 behaviour). + out := s.FormatResultsWithOptions(sr, FormatOptions{Verbose: true}) + if !strings.Contains(out, "Found 3 matches in 1 files") { + t.Errorf("expected header with Verbose=true, got:\n%s", out) + } + if !strings.Contains(out, "L1│") { + t.Errorf("expected L1 match line, got:\n%s", out) + } +} + +// TestFormatResultsDefaultNoPreamble verifies the v2 default has no preamble. +func TestFormatResultsDefaultNoPreamble(t *testing.T) { + s := newTestSearcher(t) + sr := makeSearchResults([]int{1, 5, 10}) + out := s.FormatResults(sr) + if strings.Contains(out, "Found ") { + t.Errorf("v2 default should NOT emit preamble, got:\n%s", out) + } + if !strings.Contains(out, "L1│") { + t.Errorf("expected L1 match line, got:\n%s", out) + } +} + +func TestFormatResultsClusterSingleLine(t *testing.T) { + s := newTestSearcher(t) + // Non-consecutive lines → each should appear separately + sr := makeSearchResults([]int{1, 5, 10}) + out := s.FormatResultsWithOptions(sr, FormatOptions{Cluster: true}) + if !strings.Contains(out, "L1│") { + t.Errorf("expected L1 cluster, got:\n%s", out) + } + if !strings.Contains(out, "L5│") { + t.Errorf("expected L5 cluster, got:\n%s", out) + } + // Should NOT have " │" context lines in cluster mode + if strings.Contains(out, " │") { + t.Errorf("cluster mode should not have context-line decoration, got:\n%s", out) + } +} + +func TestFormatResultsClusterConsecutive(t *testing.T) { + s := newTestSearcher(t) + // Lines 3,4,5 are consecutive → should be clustered as L3-5 + sr := makeSearchResults([]int{3, 4, 5, 10}) + out := s.FormatResultsWithOptions(sr, FormatOptions{Cluster: true}) + if !strings.Contains(out, "L3-5│") { + t.Errorf("expected L3-5 cluster, got:\n%s", out) + } + if !strings.Contains(out, "L10│") { + t.Errorf("expected L10 separate cluster, got:\n%s", out) + } +} + +func TestFormatResultsClusterAdjacentMerge(t *testing.T) { + s := newTestSearcher(t) + // Lines 7 and 8 differ by 1 → merge + sr := makeSearchResults([]int{7, 8}) + out := s.FormatResultsWithOptions(sr, FormatOptions{Cluster: true}) + if !strings.Contains(out, "L7-8│") { + t.Errorf("expected L7-8 cluster, got:\n%s", out) + } +} + +func TestFormatResultsCursorLine(t *testing.T) { + s := newTestSearcher(t) + sr := makeSearchResults([]int{1}) + cursorText := "[cursor: abc123, remaining: 5]" + out := s.FormatResultsWithOptions(sr, FormatOptions{CursorLine: cursorText}) + if !strings.Contains(out, cursorText) { + t.Errorf("expected cursor footer in output, got:\n%s", out) + } +} + +func TestFormatResultsNoMatchesVerbose(t *testing.T) { + s := newTestSearcher(t) + sr := &SearchResults{Results: nil} + out := s.FormatResults(sr) + if out != "No matches found." { + t.Errorf("expected 'No matches found.', got: %s", out) + } +} diff --git a/internal/search/search.go b/internal/search/search.go index 903681b..11c5832 100644 --- a/internal/search/search.go +++ b/internal/search/search.go @@ -219,11 +219,9 @@ func (s *Searcher) buildArgs(req *Request) []string { args = append(args, "--no-ignore") } - // Global result cap — --max-total-count stops rg early across all files. - // Requires ripgrep >= 13.0. In-process truncation in parseOutput is kept as a safety net. - if req.MaxResults > 0 { - args = append(args, fmt.Sprintf("--max-total-count=%d", req.MaxResults)) - } + // Result cap enforced in-process by parseOutput. rg has no cross-file + // total-count flag in stable releases, so we don't pass one; --max-count is + // per-file and would miss results unevenly. // Add pattern args = append(args, "--", req.Pattern) @@ -356,8 +354,21 @@ func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchRes return results, nil } -// FormatResults formats search results for display. +// FormatOptions controls how search results are rendered. +type FormatOptions struct { + Cluster bool // coalesce consecutive matches into line-range blocks + CursorLine string // if non-empty, appended as a footer line + Verbose bool // if true, emit "Found N matches in M files:" preamble (opt-in) +} + +// FormatResults formats search results for display (backward-compat wrapper). func (s *Searcher) FormatResults(results *SearchResults) string { + return s.FormatResultsWithOptions(results, FormatOptions{}) +} + +// FormatResultsWithOptions formats search results with configurable output. +// By default the "Found N matches in M files:" preamble is omitted; set opts.Verbose=true to restore it. +func (s *Searcher) FormatResultsWithOptions(results *SearchResults, opts FormatOptions) string { if len(results.Results) == 0 { return "No matches found." } @@ -374,14 +385,18 @@ func (s *Searcher) FormatResults(results *SearchResults) string { fileResults[r.File] = append(fileResults[r.File], r) } - // Write summary - totalMatches := len(results.Results) - fileCount := len(fileResults) - sb.WriteString(fmt.Sprintf("Found %d matches in %d files", totalMatches, fileCount)) - if results.Truncated { - sb.WriteString(fmt.Sprintf(" (truncated, total: %d)", results.Total)) + // Write preamble only when Verbose is requested. + if opts.Verbose { + totalMatches := len(results.Results) + fileCount := len(fileResults) + sb.WriteString(fmt.Sprintf("Found %d matches in %d files", totalMatches, fileCount)) + if results.Truncated { + sb.WriteString(fmt.Sprintf(" (truncated, total: %d)", results.Total)) + } + sb.WriteString(":\n\n") + } else if results.Truncated { + sb.WriteString(fmt.Sprintf("(truncated, showing subset of %d total matches)\n\n", results.Total)) } - sb.WriteString(":\n\n") // Write results grouped by file for _, file := range fileOrder { @@ -395,26 +410,82 @@ func (s *Searcher) FormatResults(results *SearchResults) string { sb.WriteString(fmt.Sprintf("**%s**\n", relPath)) - for _, r := range fileResults[file] { - // Write context before - for _, ctx := range r.Context.Before { - sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200))) - } - - // Write match line - sb.WriteString(fmt.Sprintf("L%d│ %s\n", r.Line, truncateLine(r.MatchText, 200))) - - // Write context after - for _, ctx := range r.Context.After { - sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200))) - } + if opts.Cluster { + writeClusteredResults(&sb, fileResults[file]) + } else { + writeVerboseResults(&sb, fileResults[file]) } sb.WriteString("\n") } + if opts.CursorLine != "" { + sb.WriteString(opts.CursorLine) + sb.WriteString("\n") + } + return sb.String() } +// writeVerboseResults writes results in the standard verbose format. +func writeVerboseResults(sb *strings.Builder, results []Result) { + for _, r := range results { + // Write context before + for _, ctx := range r.Context.Before { + fmt.Fprintf(sb, " │ %s\n", truncateLine(ctx, 200)) + } + // Write match line + fmt.Fprintf(sb, "L%d│ %s\n", r.Line, truncateLine(r.MatchText, 200)) + // Write context after + for _, ctx := range r.Context.After { + fmt.Fprintf(sb, " │ %s\n", truncateLine(ctx, 200)) + } + } +} + +// writeClusteredResults coalesces consecutive or adjacent match lines into +// a single "L12-14│ " entry. Context lines are dropped +// in cluster mode to maximise information density. +func writeClusteredResults(sb *strings.Builder, results []Result) { + if len(results) == 0 { + return + } + + type clusterEntry struct { + startLine int + endLine int + firstText string + } + + var clusters []clusterEntry + cur := clusterEntry{ + startLine: results[0].Line, + endLine: results[0].Line, + firstText: results[0].MatchText, + } + + for _, r := range results[1:] { + // Merge if adjacent (within 1 line gap) + if r.Line <= cur.endLine+1 { + if r.Line > cur.endLine { + cur.endLine = r.Line + } + } else { + clusters = append(clusters, cur) + cur = clusterEntry{startLine: r.Line, endLine: r.Line, firstText: r.MatchText} + } + } + clusters = append(clusters, cur) + + for _, c := range clusters { + text := truncateLine(c.firstText, 200) + if c.startLine == c.endLine { + fmt.Fprintf(sb, "L%d│ %s\n", c.startLine, text) + } else { + fmt.Fprintf(sb, "L%d-%d│ %s\n", c.startLine, c.endLine, text) + } + } +} + // truncateLine truncates a line if it exceeds maxLen. func truncateLine(s string, maxLen int) string { if len(s) <= maxLen { diff --git a/internal/search/search_test.go b/internal/search/search_test.go index 73f2a06..ee64cb5 100644 --- a/internal/search/search_test.go +++ b/internal/search/search_test.go @@ -94,8 +94,8 @@ func TestBuildArgs(t *testing.T) { MaxResults: 10, Regex: true, }, - expected: []string{"--json", "--max-total-count=10", "--", "test", "."}, - notExpected: []string{"--ignore-case", "--fixed-strings"}, + expected: []string{"--json", "--", "test", "."}, + notExpected: []string{"--ignore-case", "--fixed-strings", "--max-total-count=10", "--max-count=10"}, }, } diff --git a/internal/server/features_test.go b/internal/server/features_test.go new file mode 100644 index 0000000..5641e2b --- /dev/null +++ b/internal/server/features_test.go @@ -0,0 +1,344 @@ +package server + +import ( + "context" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/lukaszraczylo/mcp-filepuff/internal/config" + "github.com/mark3labs/mcp-go/mcp" +) + +// newTestServer creates a server with a temp workspace containing Go files. +func newFeaturesServer(t *testing.T) (*Server, string) { + t.Helper() + tmpDir := t.TempDir() + + // Write a few Go files for AST queries + goFile1 := filepath.Join(tmpDir, "a.go") + if err := os.WriteFile(goFile1, []byte(`package main + +func Alpha() string { return "alpha" } +func Beta() string { return "beta" } +func Gamma() string { return "gamma" } +`), 0o600); err != nil { + t.Fatal(err) + } + + goFile2 := filepath.Join(tmpDir, "b.go") + if err := os.WriteFile(goFile2, []byte(`package main + +func Delta() int { return 1 } +func Epsilon() int { return 2 } +`), 0o600); err != nil { + t.Fatal(err) + } + + cfg := &config.Config{ + WorkspaceRoot: tmpDir, + EnableLSP: false, + MaxFileSize: config.DefaultMaxFileSize, + MaxParseSize: config.DefaultMaxParseSize, + SearchTimeout: 30 * time.Second, + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New() error: %v", err) + } + return srv, tmpDir +} + +// ---- Feature 1: ast_query format flag ---- + +func TestASTQueryFormatVerboseDefault(t *testing.T) { + srv, tmpDir := newFeaturesServer(t) + ctx := context.Background() + + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME() string", + "language": "go", + "paths": []interface{}{tmpDir}, + } + res, err := srv.handleASTQuery(ctx, req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res == nil || len(res.Content) == 0 { + t.Fatal("nil/empty result") + } + text := res.Content[0].(mcp.TextContent).Text + // verbose mode has code blocks + if !strings.Contains(text, "```") { + t.Errorf("verbose mode should have code blocks, got:\n%s", text) + } +} + +func TestASTQueryFormatCompact(t *testing.T) { + srv, tmpDir := newFeaturesServer(t) + ctx := context.Background() + + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME() string", + "language": "go", + "paths": []interface{}{tmpDir}, + "format": "compact", + } + res, err := srv.handleASTQuery(ctx, req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + text := res.Content[0].(mcp.TextContent).Text + if strings.Contains(text, "```") { + t.Errorf("compact mode should NOT have code blocks, got:\n%s", text) + } + // Each line should contain file:line (kind) text + for _, line := range strings.Split(strings.TrimSpace(text), "\n") { + if line == "" || strings.HasPrefix(line, "Found") { + continue + } + if !strings.Contains(line, ":") { + t.Errorf("compact line missing ':' separator: %q", line) + } + } +} + +func TestASTQueryFormatLocation(t *testing.T) { + srv, tmpDir := newFeaturesServer(t) + ctx := context.Background() + + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME() string", + "language": "go", + "paths": []interface{}{tmpDir}, + "format": "location", + } + res, err := srv.handleASTQuery(ctx, req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + text := res.Content[0].(mcp.TextContent).Text + if strings.Contains(text, "```") { + t.Errorf("location mode should NOT have code blocks, got:\n%s", text) + } + // Lines should be file:linenum only (no parentheses with kind) + for _, line := range strings.Split(strings.TrimSpace(text), "\n") { + if line == "" || strings.HasPrefix(line, "Found") { + continue + } + if strings.Contains(line, "(") { + t.Errorf("location mode should not have node type in parens: %q", line) + } + } +} + +// ---- Feature 3 (ast_query): pagination cursor ---- + +func TestASTQueryPaginationCursor(t *testing.T) { + srv, tmpDir := newFeaturesServer(t) + ctx := context.Background() + + // Page 1: max_results=2 + req1 := mcp.CallToolRequest{} + req1.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME() $RET", + "language": "go", + "paths": []interface{}{tmpDir}, + "max_results": float64(2), + } + res1, err := srv.handleASTQuery(ctx, req1) + if err != nil { + t.Fatalf("page1 error: %v", err) + } + text1 := res1.Content[0].(mcp.TextContent).Text + + // Should contain cursor footer if there are more results + if !strings.Contains(text1, "[cursor:") { + // Might have fewer than 2 total results — skip cursor test + t.Logf("no cursor footer (fewer than 2 total matches), skipping pagination round-trip") + return + } + + // Extract cursor token + var cursorToken string + for _, line := range strings.Split(text1, "\n") { + if strings.HasPrefix(line, "[cursor:") { + // [cursor: , remaining: N] + parts := strings.Split(line, " ") + if len(parts) >= 2 { + cursorToken = strings.TrimSuffix(parts[1], ",") + } + break + } + } + if cursorToken == "" { + t.Fatal("could not extract cursor token from output") + } + + // Page 2: pass cursor back + req2 := mcp.CallToolRequest{} + req2.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME() $RET", + "language": "go", + "paths": []interface{}{tmpDir}, + "max_results": float64(2), + "cursor": cursorToken, + } + res2, err := srv.handleASTQuery(ctx, req2) + if err != nil { + t.Fatalf("page2 error: %v", err) + } + text2 := res2.Content[0].(mcp.TextContent).Text + if strings.Contains(text2, "cursor is for a different query") { + t.Error("cursor was rejected as mismatched query") + } + // Page 2 should have results + if strings.Contains(text2, "No matches found.") { + t.Error("page2 should have some results") + } +} + +func TestASTQueryCursorStaleMismatch(t *testing.T) { + srv, tmpDir := newFeaturesServer(t) + ctx := context.Background() + + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME() string", + "language": "go", + "paths": []interface{}{tmpDir}, + "cursor": "eyJvZmZzZXQiOjIsInF1ZXJ5X2hhc2giOiJkZWFkYmVlZiJ9", // offset=2, hash=deadbeef + } + res, err := srv.handleASTQuery(ctx, req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + text := res.Content[0].(mcp.TextContent).Text + if !strings.Contains(text, "cursor is for a different query") { + t.Errorf("expected stale cursor error, got:\n%s", text) + } +} + +// ---- Feature 2: file_search cluster ---- + +func TestFileSearchClusterFlag(t *testing.T) { + srv, tmpDir := newFeaturesServer(t) + if srv.searcher == nil { + t.Skip("rg not available") + } + ctx := context.Background() + + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func", + "paths": []interface{}{tmpDir}, + "cluster": true, + } + res, err := srv.handleFileSearch(ctx, req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res == nil || len(res.Content) == 0 { + t.Fatal("nil/empty result") + } + text := res.Content[0].(mcp.TextContent).Text + if text == "No matches found." { + t.Skip("no matches found (unexpected)") + } + // In cluster mode, should NOT have " │" context decorations + if strings.Contains(text, " │") { + t.Errorf("cluster mode should not contain context-line decoration ' │', got:\n%s", text) + } +} + +func TestFileSearchCursorStaleHash(t *testing.T) { + srv, tmpDir := newFeaturesServer(t) + if srv.searcher == nil { + t.Skip("rg not available") + } + ctx := context.Background() + + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func", + "paths": []interface{}{tmpDir}, + "cursor": "eyJvZmZzZXQiOjEsInF1ZXJ5X2hhc2giOiJiYWRoYXNoIn0", // offset=1, hash=badhash + } + res, err := srv.handleFileSearch(ctx, req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + text := res.Content[0].(mcp.TextContent).Text + if !strings.Contains(text, "cursor is for a different query") { + t.Errorf("expected stale cursor error, got:\n%s", text) + } +} + +func TestFileSearchPaginationCursor(t *testing.T) { + srv, tmpDir := newFeaturesServer(t) + if srv.searcher == nil { + t.Skip("rg not available") + } + ctx := context.Background() + + // Page 1: get 1 result + req1 := mcp.CallToolRequest{} + req1.Params.Arguments = map[string]interface{}{ + "pattern": "func", + "paths": []interface{}{tmpDir}, + "max_results": float64(1), + "context_lines": float64(0), + } + res1, err := srv.handleFileSearch(ctx, req1) + if err != nil { + t.Fatalf("page1 error: %v", err) + } + text1 := res1.Content[0].(mcp.TextContent).Text + if !strings.Contains(text1, "[cursor:") { + t.Logf("no cursor in page1 (only 1 total result), skipping round-trip:\n%s", text1) + return + } + + // Extract cursor + var cursorToken string + for _, line := range strings.Split(text1, "\n") { + if strings.HasPrefix(line, "[cursor:") { + parts := strings.Split(line, " ") + if len(parts) >= 2 { + cursorToken = strings.TrimSuffix(parts[1], ",") + } + break + } + } + if cursorToken == "" { + t.Fatal("could not extract cursor from page1") + } + + // Page 2 + req2 := mcp.CallToolRequest{} + req2.Params.Arguments = map[string]interface{}{ + "pattern": "func", + "paths": []interface{}{tmpDir}, + "max_results": float64(1), + "context_lines": float64(0), + "cursor": cursorToken, + } + res2, err := srv.handleFileSearch(ctx, req2) + if err != nil { + t.Fatalf("page2 error: %v", err) + } + text2 := res2.Content[0].(mcp.TextContent).Text + if strings.Contains(text2, "cursor is for a different query") { + t.Error("cursor was rejected as mismatched") + } + if strings.Contains(text2, "No matches found.") { + t.Error("page2 should have matches") + } +} diff --git a/internal/server/handlers_ast.go b/internal/server/handlers_ast.go index f9feb03..c630b98 100644 --- a/internal/server/handlers_ast.go +++ b/internal/server/handlers_ast.go @@ -8,12 +8,193 @@ import ( "path/filepath" "strings" + "github.com/lukaszraczylo/mcp-filepuff/internal/cursor" "github.com/lukaszraczylo/mcp-filepuff/internal/parser" "github.com/lukaszraczylo/mcp-filepuff/internal/query" "github.com/lukaszraczylo/mcp-filepuff/pkg/protocol" "github.com/mark3labs/mcp-go/mcp" ) +// astQueryParams holds resolved parameters for an AST query invocation. +type astQueryParams struct { + pattern string + language string + nameMatches string + nameExact string + kindIn []string + paths []string + maxResults int + format string + offset int + queryHash string + verbose bool +} + +// resolveASTQueryParams parses the request and resolves session-pref defaults. +// Returns (params, errorResult, error); when errorResult is non-nil, the caller +// should return it directly. +func (s *Server) resolveASTQueryParams(request mcp.CallToolRequest) (*astQueryParams, *mcp.CallToolResult) { + pattern, err := request.RequireString("pattern") + if err != nil { + return nil, mcp.NewToolResultError("pattern is required") + } + language, err := request.RequireString("language") + if err != nil { + return nil, mcp.NewToolResultError("language is required") + } + + p := &astQueryParams{ + pattern: pattern, + language: language, + nameMatches: request.GetString("name_matches", ""), + nameExact: request.GetString("name_exact", ""), + kindIn: request.GetStringSlice("kind_in", nil), + paths: request.GetStringSlice("paths", nil), + verbose: request.GetBool("verbose", false), + } + + sp := s.sessionPrefs.Load() + var prefsMaxResults int + var prefsFormat string + if sp != nil { + prefsMaxResults = sp.DefaultMaxResults + prefsFormat = sp.ASTQueryFormat + } + p.maxResults = effectiveInt(request, "max_results", prefsMaxResults, 100) + + p.format = request.GetString("format", "") + if p.format == "" { + if prefsFormat != "" { + p.format = prefsFormat + } else { + p.format = "verbose" + } + } + + p.queryHash = cursor.HashParams(map[string]string{ + "pattern": p.pattern, + "language": p.language, + "name_matches": p.nameMatches, + "name_exact": p.nameExact, + "kind_in": strings.Join(p.kindIn, ","), + "paths": strings.Join(p.paths, ","), + }) + + if cursorStr := request.GetString("cursor", ""); cursorStr != "" { + off, hash, decErr := cursor.Decode(cursorStr) + if decErr != nil { + return nil, mcp.NewToolResultError(fmt.Sprintf("invalid cursor: %s", decErr)) + } + if hash != p.queryHash { + return nil, mcp.NewToolResultError("cursor is for a different query, re-run without cursor") + } + p.offset = off + } + + return p, nil +} + +// runASTQueryWalk walks the configured paths and collects matches. +func (s *Server) runASTQueryWalk(ctx context.Context, p *astQueryParams, exts []string) []query.MatchResult { + astQuery := &query.ASTQuery{ + Pattern: p.pattern, + Language: p.language, + Filters: query.QueryFilters{ + NameMatches: p.nameMatches, + NameExact: p.nameExact, + KindIn: p.kindIn, + }, + } + + // Collect limit for early-exit: when paginating we need all results first. + collectLimit := p.maxResults + if p.offset > 0 { + collectLimit = 0 + } + + var allResults []query.MatchResult + for _, searchPath := range p.paths { + if !s.cfg.IsPathAllowed(searchPath) { + continue + } + + walkErr := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if err != nil { + return nil + } + if info.IsDir() { + if strings.HasPrefix(info.Name(), ".") { + return filepath.SkipDir + } + return nil + } + if !hasAnySuffix(path, exts) { + return nil + } + + content, err := os.ReadFile(path) + if err != nil { + return nil + } + if int64(len(content)) > s.cfg.MaxFileSize { + return nil + } + + result, err := s.parser.Parse(ctx, path, content) + if err != nil { + return nil + } + matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path) + if err != nil { + return nil + } + allResults = append(allResults, matches...) + + if collectLimit > 0 && len(allResults) >= collectLimit { + return filepath.SkipAll + } + return nil + }) + if walkErr != nil { + s.logger.Warn("error walking path", "path", searchPath, "error", walkErr) + } + } + return allResults +} + +// hasAnySuffix reports whether path ends with any of the given suffixes. +func hasAnySuffix(path string, suffixes []string) bool { + for _, ext := range suffixes { + if strings.HasSuffix(path, ext) { + return true + } + } + return false +} + +// buildASTCursorFooter computes the cursor footer line for truncated results. +func buildASTCursorFooter(total, offset, maxResults int, queryHash string) string { + if offset > 0 && offset < total { + totalAfterOffset := total - offset + if maxResults > 0 && totalAfterOffset > maxResults { + remaining := totalAfterOffset - maxResults + nextOffset := offset + maxResults + nextCursor := cursor.Encode(nextOffset, queryHash) + return fmt.Sprintf("[cursor: %s, remaining: %d]", nextCursor, remaining) + } + } else if offset == 0 && maxResults > 0 && total > maxResults { + remaining := total - maxResults + nextCursor := cursor.Encode(maxResults, queryHash) + return fmt.Sprintf("[cursor: %s, remaining: %d]", nextCursor, remaining) + } + return "" +} + // handleASTQuery handles the ast_query tool. func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Acquire semaphore to limit concurrent queries (prevents CPU exhaustion) @@ -24,121 +205,33 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest return mcp.NewToolResultError("request cancelled"), nil } - pattern, err := request.RequireString("pattern") - if err != nil { - return mcp.NewToolResultError("pattern is required"), nil + p, errResult := s.resolveASTQueryParams(request) + if errResult != nil { + return errResult, nil } - language, err := request.RequireString("language") - if err != nil { - return mcp.NewToolResultError("language is required"), nil + if len(p.paths) == 0 { + p.paths = []string{s.cfg.WorkspaceRoot} } - // Build query - astQuery := &query.ASTQuery{ - Pattern: pattern, - Language: language, - Filters: query.QueryFilters{ - NameMatches: request.GetString("name_matches", ""), - NameExact: request.GetString("name_exact", ""), - KindIn: request.GetStringSlice("kind_in", nil), - }, - } - - maxResults := request.GetInt("max_results", 100) - paths := request.GetStringSlice("paths", nil) - - // Default to workspace root if no paths specified - if len(paths) == 0 { - paths = []string{s.cfg.WorkspaceRoot} - } - - // Find files to search based on language - exts := languageToExtensions(language) + exts := languageToExtensions(p.language) if len(exts) == 0 { - return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir, rust)", language)), nil + return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir, rust)", p.language)), nil } - var allResults []query.MatchResult + allResults := s.runASTQueryWalk(ctx, p, exts) + cursorFooter := buildASTCursorFooter(len(allResults), p.offset, p.maxResults, p.queryHash) - // Walk through paths and find matching files - for _, searchPath := range paths { - // Validate path is within workspace - if !s.cfg.IsPathAllowed(searchPath) { - continue - } - - err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error { - // Check for context cancellation - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - if err != nil { - return nil // Skip files with errors - } - - if info.IsDir() { - // Skip hidden directories - if strings.HasPrefix(info.Name(), ".") { - return filepath.SkipDir - } - return nil - } - - // Check file extension matches language - matched := false - for _, ext := range exts { - if strings.HasSuffix(path, ext) { - matched = true - break - } - } - if !matched { - return nil - } - - // Read and parse file - content, err := os.ReadFile(path) - if err != nil { - return nil // Skip unreadable files - } - - // Check file size - if int64(len(content)) > s.cfg.MaxFileSize { - return nil // Skip large files - } - - // Parse file - result, err := s.parser.Parse(ctx, path, content) - if err != nil { - return nil // Skip unparseable files - } - - // Run query - matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path) - if err != nil { - return nil // Skip on error - } - - allResults = append(allResults, matches...) - - // Stop if we have enough results - if maxResults > 0 && len(allResults) >= maxResults { - return filepath.SkipAll - } - - return nil - }) - if err != nil { - s.logger.Warn("error walking path", "path", searchPath, "error", err) + output := query.FormatResultsWithOptions(allResults, p.maxResults, p.format, p.offset, p.verbose) + if cursorFooter != "" { + // Replace the [remaining: N] placeholder emitted by FormatResultsWithOptions + // with the full [cursor: ..., remaining: N] line. + output = strings.ReplaceAll(output, fmt.Sprintf("[remaining: %d]\n", len(allResults)-p.offset-p.maxResults), cursorFooter+"\n") + // Fallback: if placeholder wasn't present (e.g. format=location), append footer. + if !strings.Contains(output, cursorFooter) { + output = strings.TrimRight(output, "\n") + "\n" + cursorFooter + "\n" } } - - // Format and return results - output := query.FormatResults(allResults, maxResults) return mcp.NewToolResultText(output), nil } diff --git a/internal/server/handlers_edit.go b/internal/server/handlers_edit.go index d067915..3148d63 100644 --- a/internal/server/handlers_edit.go +++ b/internal/server/handlers_edit.go @@ -4,12 +4,10 @@ package server import ( "context" "fmt" - "os" "strings" "github.com/lukaszraczylo/mcp-filepuff/internal/edit" "github.com/lukaszraczylo/mcp-filepuff/pkg/errors" - "github.com/lukaszraczylo/mcp-filepuff/pkg/protocol" "github.com/mark3labs/mcp-go/mcp" ) @@ -82,27 +80,56 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest) (* return mcp.NewToolResultError(result.Error), nil } - // compact_response: return just the modified symbol instead of the full diff - if request.GetBool("compact_response", false) && selectorName != "" { - if content, readErr := os.ReadFile(file); readErr == nil { - if start, end, found := s.resolveSymbolLines(ctx, file, content, selectorName, protocol.SymbolKind("")); found { - lines := splitLines(string(content)) - var sb strings.Builder - sb.WriteString(fmt.Sprintf("**Edit Applied** — %s (L%d-L%d):\n\n", selectorName, start, end)) - for i := start - 1; i < end && i < len(lines); i++ { - sb.WriteString(fmt.Sprintf("%4d| %s\n", i+1, lines[i])) - } - return mcp.NewToolResultText(sb.String()), nil - } - } - // fall through to diff if symbol lookup fails + // Determine response mode. + // compact_response is a deprecated alias for response="count". + respMode := request.GetString("response", "count") + if request.GetBool("compact_response", false) { + // Deprecated: use response=count + respMode = "count" } - var output strings.Builder - output.WriteString("**Edit Applied Successfully**\n\n") - output.WriteString("Diff:\n```diff\n") - output.WriteString(result.Diff) - output.WriteString("```\n") + switch respMode { + case "none": + return mcp.NewToolResultText(""), nil - return mcp.NewToolResultText(output.String()), nil + case "count": + // Compute +added/-removed line counts from the unified diff. + added, removed := countDiffLines(result.Diff) + return mcp.NewToolResultText(fmt.Sprintf("+%d -%d", added, removed)), nil + + case "diff": + var output strings.Builder + output.WriteString("Diff:\n```diff\n") + output.WriteString(result.Diff) + output.WriteString("```\n") + return mcp.NewToolResultText(output.String()), nil + + default: + // Fallback: treat unknown values as "diff" for safety. + var output strings.Builder + output.WriteString("Diff:\n```diff\n") + output.WriteString(result.Diff) + output.WriteString("```\n") + return mcp.NewToolResultText(output.String()), nil + } +} + +// countDiffLines counts added (+) and removed (-) lines in a unified diff string. +func countDiffLines(diff string) (added, removed int) { + for _, line := range strings.Split(diff, "\n") { + if len(line) == 0 { + continue + } + switch line[0] { + case '+': + if !strings.HasPrefix(line, "+++") { + added++ + } + case '-': + if !strings.HasPrefix(line, "---") { + removed++ + } + } + } + return } diff --git a/internal/server/handlers_file.go b/internal/server/handlers_file.go index a11c8c3..7c19ed8 100644 --- a/internal/server/handlers_file.go +++ b/internal/server/handlers_file.go @@ -6,10 +6,12 @@ import ( "context" "fmt" "os" + "strconv" "strings" "time" xxhash "github.com/cespare/xxhash/v2" + "github.com/lukaszraczylo/mcp-filepuff/internal/cursor" "github.com/lukaszraczylo/mcp-filepuff/internal/parser" "github.com/lukaszraczylo/mcp-filepuff/internal/search" "github.com/lukaszraczylo/mcp-filepuff/pkg/errors" @@ -35,14 +37,63 @@ func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolReque return mcp.NewToolResultError("pattern is required"), nil } + paths := request.GetStringSlice("paths", nil) + fileTypes := request.GetStringSlice("file_types", nil) + ignoreCase := request.GetBool("ignore_case", false) + regex := request.GetBool("regex", true) + contextLines := request.GetInt("context_lines", 2) + + // Consult session prefs for max_results and cluster when not explicitly supplied. + prefs := s.sessionPrefs.Load() + var prefsMaxResults int + var prefsCluster *bool + if prefs != nil { + prefsMaxResults = prefs.DefaultMaxResults + prefsCluster = prefs.DefaultCluster + } + maxResults := effectiveInt(request, "max_results", prefsMaxResults, 0) + cluster := effectiveBool(request, "cluster", prefsCluster, false) + + cursorStr := request.GetString("cursor", "") + + // Compute query hash for cursor validation. + queryHash := cursor.HashParams(map[string]string{ + "pattern": pattern, + "paths": strings.Join(paths, ","), + "file_types": strings.Join(fileTypes, ","), + "ignore_case": strconv.FormatBool(ignoreCase), + "regex": strconv.FormatBool(regex), + "context_lines": strconv.Itoa(contextLines), + }) + + // Resolve cursor offset. + offset := 0 + if cursorStr != "" { + off, hash, decErr := cursor.Decode(cursorStr) + if decErr != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid cursor: %s", decErr)), nil + } + if hash != queryHash { + return mcp.NewToolResultError("cursor is for a different query, re-run without cursor"), nil + } + offset = off + } + + // When paginating with a cursor, fetch all results (no rg-level cap) so we + // can apply the offset in-process. Without a cursor, let rg cap at maxResults. + rgMaxResults := maxResults + if offset > 0 { + rgMaxResults = 0 // fetch all, apply cap after skipping + } + req := &search.Request{ Pattern: pattern, - Paths: request.GetStringSlice("paths", nil), - FileTypes: request.GetStringSlice("file_types", nil), - IgnoreCase: request.GetBool("ignore_case", false), - Regex: request.GetBool("regex", true), - ContextLines: request.GetInt("context_lines", 2), - MaxResults: request.GetInt("max_results", 0), + Paths: paths, + FileTypes: fileTypes, + IgnoreCase: ignoreCase, + Regex: regex, + ContextLines: contextLines, + MaxResults: rgMaxResults, } results, err := s.searcher.Search(ctx, req) @@ -51,16 +102,66 @@ func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolReque return mcp.NewToolResultError(fmt.Sprintf("search error: %s", errors.SanitizeError(err))), nil } + // Apply cursor offset. + if offset > 0 && offset < len(results.Results) { + results.Results = results.Results[offset:] + results.Truncated = false // will re-evaluate below + } else if offset > 0 { + results.Results = nil + results.Truncated = false + } + + // Apply in-process max_results cap and compute cursor footer. + var cursorLine string + if maxResults > 0 && len(results.Results) > maxResults { + remaining := len(results.Results) - maxResults + results.Results = results.Results[:maxResults] + results.Truncated = true + nextOffset := offset + maxResults + nextCursor := cursor.Encode(nextOffset, queryHash) + cursorLine = fmt.Sprintf("[cursor: %s, remaining: %d]", nextCursor, remaining) + } + s.logger.Info("search completed", "pattern", pattern, "results_count", len(results.Results), "truncated", results.Truncated, ) - output := s.searcher.FormatResults(results) + verbose := request.GetBool("verbose", false) + opts := search.FormatOptions{ + Cluster: cluster, + CursorLine: cursorLine, + Verbose: verbose, + } + output := s.searcher.FormatResultsWithOptions(results, opts) return mcp.NewToolResultText(output), nil } +// effectiveInt returns the per-call value if the key is explicitly present in the +// request arguments, otherwise falls back to sessionDefault (if > 0), then builtIn. +func effectiveInt(request mcp.CallToolRequest, key string, sessionDefault, builtIn int) int { + if _, explicit := request.GetArguments()[key]; explicit { + return request.GetInt(key, builtIn) + } + if sessionDefault > 0 { + return sessionDefault + } + return builtIn +} + +// effectiveBool returns the per-call value if the key is explicitly present in the +// request arguments, otherwise falls back to sessionDefault (if non-nil), then builtIn. +func effectiveBool(request mcp.CallToolRequest, key string, sessionDefault *bool, builtIn bool) bool { + if _, explicit := request.GetArguments()[key]; explicit { + return request.GetBool(key, builtIn) + } + if sessionDefault != nil { + return *sessionDefault + } + return builtIn +} + // handleFileRead handles the file_read tool. func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { select { @@ -70,9 +171,13 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest return mcp.NewToolResultError("request cancelled"), nil } - // Batch mode: paths[] takes precedence over path + // Batch mode: paths[] takes precedence over path. + // NOTE: batch reads are always inlined — mixing dedup + resource_links is + // too complex and the savings are unclear for multi-file calls. if paths := request.GetStringSlice("paths", nil); len(paths) > 0 { var output strings.Builder + // Dedup: track etag -> first path that produced it. + seenEtag := make(map[string]string) // etag -> first path for i, p := range paths { if i > 0 { output.WriteString("\n") @@ -82,6 +187,16 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest output.WriteString(fmt.Sprintf("--- %s ---\n[error: %s]\n", p, errors.SanitizeError(err))) continue } + // Extract etag from result footer for dedup check. + etag := extractEtag(result) + if etag != "" { + if firstPath, seen := seenEtag[etag]; seen { + // Duplicate content: emit pointer instead of full content. + output.WriteString(fmt.Sprintf("--- %s ---\n[duplicate of %s, etag: %s]\n", p, firstPath, etag)) + continue + } + seenEtag[etag] = p + } output.WriteString(fmt.Sprintf("--- %s ---\n%s", p, result)) } return mcp.NewToolResultText(output.String()), nil @@ -96,59 +211,223 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest if err != nil { return mcp.NewToolResultError(errors.SanitizeError(err)), nil } + + // Resource-link threshold check: for single-file reads, if the result + // exceeds the configured threshold, return a ResourceLink instead of + // inlining the content. The client can fetch the resource on demand. + // Bypassed when: + // - force_inline=true + // - max_inline_bytes is set and result fits within it + // - threshold is 0 (disabled) + // - result is already small (skeleton/symbols_only/line-range paths + // produce small output; threshold is on result bytes, not file bytes) + // Determine resource-link threshold: session pref overrides cfg, per-call overrides session. + threshold := s.cfg.ResourceLinkThresholdBytes + if sp := s.sessionPrefs.Load(); sp != nil && sp.ResourceLinkThreshold > 0 { + threshold = sp.ResourceLinkThreshold + } + forceInline := request.GetBool("force_inline", false) + maxInlineBytes := request.GetInt("max_inline_bytes", 0) + if maxInlineBytes > 0 { + threshold = maxInlineBytes + } + + if !forceInline && threshold > 0 && len(result) > threshold { + etag := extractEtag(result) + uri := buildReadResourceURI(path, etag) + lineCount := strings.Count(result, "\n") + desc := fmt.Sprintf("etag=%s, size=%d bytes, lines=%d", etag, len(result), lineCount) + mimeType := detectMIMEType(path) + link := mcp.NewResourceLink(uri, path, desc, mimeType) + return &mcp.CallToolResult{ + Content: []mcp.Content{link}, + }, nil + } + return mcp.NewToolResultText(result), nil } -// readOneFile reads a single file applying all formatting options from the request. -func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, path string) (string, error) { +// buildReadResourceURI constructs the filepuff://read URI for a file + etag pair. +func buildReadResourceURI(path, etag string) string { + if etag == "" { + return "filepuff://read/" + path + } + return "filepuff://read/" + path + "?etag=" + etag +} + +// detectMIMEType returns a best-effort MIME type for the given file path. +func detectMIMEType(path string) string { + ext := strings.ToLower(path) + switch { + case strings.HasSuffix(ext, ".go"): + return "text/x-go" + case strings.HasSuffix(ext, ".ts"), strings.HasSuffix(ext, ".tsx"): + return "text/typescript" + case strings.HasSuffix(ext, ".js"), strings.HasSuffix(ext, ".jsx"): + return "text/javascript" + case strings.HasSuffix(ext, ".py"): + return "text/x-python" + case strings.HasSuffix(ext, ".rs"): + return "text/x-rust" + case strings.HasSuffix(ext, ".md"): + return "text/markdown" + case strings.HasSuffix(ext, ".json"): + return "application/json" + case strings.HasSuffix(ext, ".yaml"), strings.HasSuffix(ext, ".yml"): + return "text/yaml" + case strings.HasSuffix(ext, ".toml"): + return "text/toml" + case strings.HasSuffix(ext, ".html"), strings.HasSuffix(ext, ".htm"): + return "text/html" + case strings.HasSuffix(ext, ".css"): + return "text/css" + case strings.HasSuffix(ext, ".sh"): + return "text/x-sh" + case strings.HasSuffix(ext, ".c"), strings.HasSuffix(ext, ".h"): + return "text/x-c" + case strings.HasSuffix(ext, ".cpp"), strings.HasSuffix(ext, ".cc"), strings.HasSuffix(ext, ".cxx"): + return "text/x-c++" + default: + return "text/plain" + } +} + +// lineNumberOpts holds resolved line-numbering preferences for readOneFile. +type lineNumberOpts struct { + noLineNumbers bool + compactLineNums bool + lineInterval int +} + +// resolveLineNumberOpts resolves per-call vs session-pref line-number options. +func (s *Server) resolveLineNumberOpts(request mcp.CallToolRequest) lineNumberOpts { + opts := lineNumberOpts{ + noLineNumbers: request.GetBool("no_line_numbers", false), + lineInterval: request.GetInt("line_number_interval", 1), + compactLineNums: request.GetBool("compact_line_numbers", false), + } + + // Apply session line_numbers pref when no explicit per-call override was supplied. + if sp := s.sessionPrefs.Load(); sp != nil && sp.LineNumbers != "" { + _, hasNoLN := request.GetArguments()["no_line_numbers"] + _, hasCompact := request.GetArguments()["compact_line_numbers"] + _, hasInterval := request.GetArguments()["line_number_interval"] + if !hasNoLN && !hasCompact && !hasInterval { + switch sp.LineNumbers { + case "none": + opts.noLineNumbers = true + opts.compactLineNums = false + case "compact": + opts.noLineNumbers = false + opts.compactLineNums = true + case "full": + opts.noLineNumbers = false + opts.compactLineNums = false + } + } + } + + if opts.lineInterval == 0 { + opts.noLineNumbers = true + } + return opts +} + +// applyStrip applies strip flags to the selected line range and returns the +// possibly-rewritten lines, new bounds, and a stripped-footer annotation. +func applyStrip(lines []string, lineStart, lineEnd int, stripFlags []parser.StripFlag, path string) (newLines []string, newStart, newEnd int, footer string) { + if len(stripFlags) == 0 { + return lines, lineStart, lineEnd, "" + } + selectedContent := strings.Join(lines[lineStart-1:lineEnd], "\n") + lang := protocol.DetectLanguage(path) + stripped := parser.StripContent(selectedContent, stripFlags, lang) + if len(stripped.Stripped) > 0 { + names := make([]string, len(stripped.Stripped)) + for i, f := range stripped.Stripped { + names[i] = string(f) + } + footer = "[stripped: " + strings.Join(names, ", ") + "]\n" + } + newLines = splitLines(stripped.Content) + return newLines, 1, len(newLines), footer +} + +// loadFileForRead performs workspace, stat, size, and read checks for a path. +func (s *Server) loadFileForRead(path string) ([]byte, error) { if !s.cfg.IsPathAllowed(path) { - return "", fmt.Errorf("path is outside workspace root") + return nil, fmt.Errorf("path is outside workspace root") } info, err := os.Stat(path) if err != nil { if os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", path) + return nil, fmt.Errorf("file not found: %s", path) } if os.IsPermission(err) { - return "", fmt.Errorf("permission denied: %s", path) + return nil, fmt.Errorf("permission denied: %s", path) } s.logger.Warn("file stat error", "path", path, "error", err) - return "", fmt.Errorf("error accessing file") + return nil, fmt.Errorf("error accessing file") } if info.Size() > s.cfg.MaxFileSize { - return "", fmt.Errorf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize) + return nil, fmt.Errorf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize) } content, err := os.ReadFile(path) if err != nil { if os.IsPermission(err) { - return "", fmt.Errorf("permission denied: %s", path) + return nil, fmt.Errorf("permission denied: %s", path) } s.logger.Warn("file read error", "path", path, "error", err) - return "", fmt.Errorf("error reading file") + return nil, fmt.Errorf("error reading file") + } + return content, nil +} + +// readOneFile reads a single file applying all formatting options from the request. +func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, path string) (string, error) { + content, err := s.loadFileForRead(path) + if err != nil { + return "", err } - // Compute etag from content hash - etag := fmt.Sprintf("%016x", xxhash.Sum64(content)) + // Feature 3: short etag — 8 hex chars (32-bit). + // Accept previous_etag by prefix match so old 16-char etags keep working. + fullHash := fmt.Sprintf("%016x", xxhash.Sum64(content)) + etag := fullHash[:8] - // Short-circuit if caller has the current version - if prev := request.GetString("previous_etag", ""); prev != "" && prev == etag { - return fmt.Sprintf("[unchanged, etag: %s]\n", etag), nil + if prev := request.GetString("previous_etag", ""); prev != "" { + // Match: exact 8-char match, old client sent full 16-char etag, or new client sent 8-char prefix of old. + if prev == etag || strings.HasPrefix(fullHash, prev) || strings.HasPrefix(prev, etag) { + return fmt.Sprintf("[unchanged, etag: %s]\n", etag), nil + } } // Parse request options includeAST := request.GetBool("include_ast", false) symbolsOnly := request.GetBool("symbols_only", false) symbolName := request.GetString("symbol_name", "") - noLineNumbers := request.GetBool("no_line_numbers", false) - lineInterval := request.GetInt("line_number_interval", 1) collapseBlank := request.GetBool("collapse_blank_lines", false) maxLines := request.GetInt("max_lines", 0) - if lineInterval == 0 { - noLineNumbers = true + lnOpts := s.resolveLineNumberOpts(request) + + // Feature 1: mode flag — "full" (default) | "skeleton" | "symbols_only". + // symbols_only mode is an alias for include_ast+symbols_only. + mode := request.GetString("mode", "full") + if mode == "symbols_only" { + symbolsOnly = true + includeAST = true } + + // Feature 2: strip — remove selected content classes before line-numbering. + stripRaw := request.GetStringSlice("strip", nil) + var stripFlags []parser.StripFlag + for _, sf := range stripRaw { + stripFlags = append(stripFlags, parser.StripFlag(sf)) + } + if symbolsOnly && !includeAST { return "", fmt.Errorf("symbols_only requires include_ast=true") } @@ -157,7 +436,7 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p lineStart := request.GetInt("line_start", 1) lineEnd := request.GetInt("line_end", len(lines)) - // Symbol-based line range: find the symbol and use its exact bounds + // Symbol-based line range: find the symbol and use its exact bounds. if symbolName != "" { symbolKind := protocol.SymbolKind(request.GetString("symbol_kind", "")) start, end, found := s.resolveSymbolLines(ctx, path, content, symbolName, symbolKind) @@ -168,7 +447,7 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p lineEnd = end } - // Clamp to valid range + // Clamp to valid range. if lineStart < 1 { lineStart = 1 } @@ -181,6 +460,19 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p var output strings.Builder + // Feature 1: skeleton mode — replace function bodies with { ... }. + if mode == "skeleton" { + skText, _, skErr := parser.SkeletonFile(ctx, s.parser, path, content) + if skErr != nil { + // Fall back to full mode on parse error. + s.logger.Warn("skeleton mode failed, falling back to full", "path", path, "error", skErr) + } else { + output.WriteString(skText) + fmt.Fprintf(&output, "[etag: %s]\n", etag) + return output.String(), nil + } + } + if includeAST { if summary := s.generateASTSummary(ctx, path, content); summary != "" { output.WriteString(summary) @@ -191,13 +483,19 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p } if symbolsOnly { - output.WriteString(fmt.Sprintf("[etag: %s]\n", etag)) + fmt.Fprintf(&output, "[etag: %s]\n", etag) return output.String(), nil } - writeLines(&output, lines, lineStart, lineEnd, maxLines, noLineNumbers, lineInterval, collapseBlank) + // Feature 2: apply strip AFTER line-range selection, BEFORE line numbering. + lines, lineStart, lineEnd, strippedFooter := applyStrip(lines, lineStart, lineEnd, stripFlags, path) - output.WriteString(fmt.Sprintf("[etag: %s]\n", etag)) + writeLines(&output, lines, lineStart, lineEnd, maxLines, lnOpts.noLineNumbers, lnOpts.lineInterval, collapseBlank, lnOpts.compactLineNums) + + if strippedFooter != "" { + output.WriteString(strippedFooter) + } + fmt.Fprintf(&output, "[etag: %s]\n", etag) return output.String(), nil } @@ -212,7 +510,8 @@ func (s *Server) resolveSymbolLines(ctx context.Context, path string, content [] } // writeLines writes the selected line range into output, applying all formatting options. -func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, maxLines int, noLineNumbers bool, lineInterval int, collapseBlank bool) { +// compactLineNums=true emits "12│" instead of " 12│ " (no padding, no trailing space). +func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, maxLines int, noLineNumbers bool, lineInterval int, collapseBlank bool, compactLineNums bool) { effectiveEnd := lineEnd truncatedCount := 0 if maxLines > 0 && (lineEnd-lineStart+1) > maxLines { @@ -233,6 +532,13 @@ func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, max switch { case noLineNumbers: output.WriteString(line + "\n") + case compactLineNums: + // Feature 4: compact prefix — "12│content" + if lineInterval <= 1 || lineNum%lineInterval == 0 || i == lineStart-1 || i == effectiveEnd-1 { + fmt.Fprintf(output, "%d│%s\n", lineNum, line) + } else { + fmt.Fprintf(output, "│%s\n", line) + } case lineInterval <= 1 || lineNum%lineInterval == 0 || i == lineStart-1 || i == effectiveEnd-1: fmt.Fprintf(output, "%4d│ %s\n", lineNum, line) default: @@ -245,6 +551,23 @@ func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, max } } +// extractEtag extracts the etag value from a readOneFile result string. +// Returns empty string if not found. +func extractEtag(result string) string { + // Look for "[etag: XXXXXXXX]" at end of result. + const prefix = "[etag: " + idx := strings.LastIndex(result, prefix) + if idx < 0 { + return "" + } + rest := result[idx+len(prefix):] + close := strings.Index(rest, "]") + if close < 0 { + return "" + } + return rest[:close] +} + // splitLines splits a string into lines. // For large files (> 1MB), uses bufio.Scanner which is more memory efficient. // For smaller files, uses simple string split which is faster. diff --git a/internal/server/handlers_file_test.go b/internal/server/handlers_file_test.go new file mode 100644 index 0000000..da8c583 --- /dev/null +++ b/internal/server/handlers_file_test.go @@ -0,0 +1,909 @@ +package server + +import ( + "context" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/lukaszraczylo/mcp-filepuff/internal/config" + "github.com/mark3labs/mcp-go/mcp" +) + +// newTestServer creates a minimal server pointing at tmpDir. +func newTestServer(t *testing.T, tmpDir string) *Server { + t.Helper() + cfg := &config.Config{ + WorkspaceRoot: tmpDir, + EnableLSP: false, + MaxFileSize: 10 * 1024 * 1024, // 10 MB — required for file reads to succeed + MaxParseSize: 10 * 1024 * 1024, + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New() error = %v", err) + } + return srv +} + +// callRead calls handleFileRead with the given args map and returns the text content. +func callRead(t *testing.T, srv *Server, args map[string]interface{}) string { + t.Helper() + req := mcp.CallToolRequest{} + req.Params.Arguments = args + result, err := srv.handleFileRead(context.Background(), req) + if err != nil { + t.Fatalf("handleFileRead error: %v", err) + } + if result == nil { + t.Fatal("handleFileRead returned nil") + } + if len(result.Content) == 0 { + t.Fatal("handleFileRead returned empty content") + } + tc, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatal("handleFileRead did not return TextContent") + } + return tc.Text +} + +// callReadResult calls handleFileRead and returns the raw CallToolResult (not just text). +func callReadResult(t *testing.T, srv *Server, args map[string]interface{}) *mcp.CallToolResult { + t.Helper() + req := mcp.CallToolRequest{} + req.Params.Arguments = args + result, err := srv.handleFileRead(context.Background(), req) + if err != nil { + t.Fatalf("handleFileRead error: %v", err) + } + if result == nil { + t.Fatal("handleFileRead returned nil") + } + return result +} + +// newTestServerWithThreshold creates a server with a custom ResourceLinkThresholdBytes. +func newTestServerWithThreshold(t *testing.T, tmpDir string, thresholdBytes int) *Server { + t.Helper() + cfg := &config.Config{ + WorkspaceRoot: tmpDir, + EnableLSP: false, + MaxFileSize: 10 * 1024 * 1024, + MaxParseSize: 10 * 1024 * 1024, + ResourceLinkThresholdBytes: thresholdBytes, + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New() error = %v", err) + } + return srv +} + +// writeFile writes content to a file in tmpDir and returns its absolute path. +func writeFile(t *testing.T, dir, name, content string) string { + t.Helper() + p := filepath.Join(dir, name) + if err := os.WriteFile(p, []byte(content), 0600); err != nil { + t.Fatalf("WriteFile(%s): %v", name, err) + } + return p +} + +// ---- Feature 1: skeleton mode ---- + +func TestSkeletonModeGo(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := `package main + +// Hello says hello +func Hello() { + println("Hello, World!") + println("more body") +} + +func Add(a, b int) int { + return a + b +} +` + f := writeFile(t, tmpDir, "test.go", goSrc) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "mode": "skeleton", + }) + + // Should contain function signatures + if !strings.Contains(out, "func Hello()") { + t.Errorf("skeleton output missing Hello signature, got:\n%s", out) + } + if !strings.Contains(out, "func Add(") { + t.Errorf("skeleton output missing Add signature, got:\n%s", out) + } + // Should NOT contain body contents + if strings.Contains(out, `println("more body")`) { + t.Errorf("skeleton output should not contain body contents, got:\n%s", out) + } + // Should contain placeholder + if !strings.Contains(out, "{ ... }") { + t.Errorf("skeleton output missing { ... } placeholder, got:\n%s", out) + } + // Should contain etag footer + if !strings.Contains(out, "[etag:") { + t.Errorf("skeleton output missing etag footer, got:\n%s", out) + } +} + +func TestSkeletonModeFullFlagAlias(t *testing.T) { + // mode="full" should behave identically to not specifying mode + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := "package main\nfunc F() { println(1) }\n" + f := writeFile(t, tmpDir, "test.go", goSrc) + + outFull := callRead(t, srv, map[string]interface{}{"path": f, "mode": "full"}) + outDefault := callRead(t, srv, map[string]interface{}{"path": f}) + + // Both should have same content (etag will be same, line content same) + if outFull != outDefault { + t.Errorf("mode=full differs from default\nfull: %q\ndefault: %q", outFull, outDefault) + } +} + +func TestSkeletonModeSymbolsOnlyAlias(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := "package main\nfunc F() { println(1) }\n" + f := writeFile(t, tmpDir, "test.go", goSrc) + + // mode=symbols_only should return symbols summary (needs include_ast implicitly) + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "mode": "symbols_only", + }) + // Should contain etag but NOT the function body + if !strings.Contains(out, "[etag:") { + t.Errorf("symbols_only output missing etag, got:\n%s", out) + } + if strings.Contains(out, "println") { + t.Errorf("symbols_only should not contain body, got:\n%s", out) + } +} + +func TestSkeletonModeTypeScript(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + tsSrc := `// a function +function greet(name: string): string { + return "Hello " + name; +} + +class Greeter { + greet(name: string) { + return "hi " + name; + } +} +` + f := writeFile(t, tmpDir, "test.ts", tsSrc) + + out := callRead(t, srv, map[string]interface{}{"path": f, "mode": "skeleton"}) + + if !strings.Contains(out, "function greet") { + t.Errorf("TS skeleton missing function signature, got:\n%s", out) + } + if !strings.Contains(out, "{ ... }") { + t.Errorf("TS skeleton missing placeholder, got:\n%s", out) + } + if strings.Contains(out, `"Hello " + name`) { + t.Errorf("TS skeleton should not contain body, got:\n%s", out) + } +} + +func TestSkeletonModePython(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + pySrc := `def greet(name): + print("Hello " + name) + print("extra line") + +class Foo: + def bar(self): + return 42 +` + f := writeFile(t, tmpDir, "test.py", pySrc) + + out := callRead(t, srv, map[string]interface{}{"path": f, "mode": "skeleton"}) + + if !strings.Contains(out, "def greet") { + t.Errorf("Python skeleton missing greet signature, got:\n%s", out) + } + if strings.Contains(out, "extra line") { + t.Errorf("Python skeleton should not contain body, got:\n%s", out) + } + // Python uses "..." as placeholder + if !strings.Contains(out, "...") { + t.Errorf("Python skeleton missing ... placeholder, got:\n%s", out) + } +} + +func TestSkeletonModeRust(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + rsSrc := `fn add(a: i32, b: i32) -> i32 { + let result = a + b; + result +} + +struct Foo { + x: i32, +} +` + f := writeFile(t, tmpDir, "test.rs", rsSrc) + + out := callRead(t, srv, map[string]interface{}{"path": f, "mode": "skeleton"}) + + if !strings.Contains(out, "fn add(") { + t.Errorf("Rust skeleton missing fn signature, got:\n%s", out) + } + if !strings.Contains(out, "{ ... }") { + t.Errorf("Rust skeleton missing placeholder, got:\n%s", out) + } + if strings.Contains(out, "let result") { + t.Errorf("Rust skeleton should not contain body, got:\n%s", out) + } +} + +// ---- Feature 2: strip flag ---- + +func TestStripImportsGo(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := `package main + +import ( + "fmt" + "os" +) + +func main() { + fmt.Println("hello") +} +` + f := writeFile(t, tmpDir, "test.go", goSrc) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "strip": []interface{}{"imports"}, + }) + + if strings.Contains(out, `"fmt"`) { + t.Errorf("strip=imports should remove import block, got:\n%s", out) + } + if !strings.Contains(out, "func main") { + t.Errorf("strip=imports should keep function, got:\n%s", out) + } + if !strings.Contains(out, "[stripped: imports]") { + t.Errorf("strip footer missing, got:\n%s", out) + } +} + +func TestStripLicenseGoBlockComment(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := `/* Copyright 2024 Acme Corp. All rights reserved. + License: MIT +*/ +package main + +func main() {} +` + f := writeFile(t, tmpDir, "main.go", goSrc) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "strip": []interface{}{"license"}, + }) + + if strings.Contains(out, "Copyright") { + t.Errorf("strip=license should remove license comment, got:\n%s", out) + } + if !strings.Contains(out, "func main") { + t.Errorf("strip=license should keep code, got:\n%s", out) + } + if !strings.Contains(out, "[stripped: license]") { + t.Errorf("license strip footer missing, got:\n%s", out) + } +} + +func TestStripBlockCommentsGo(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := `package main + +/* This is a block comment + spanning multiple lines */ +func main() { + /* inline block */ + println("hi") +} +` + f := writeFile(t, tmpDir, "test.go", goSrc) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "strip": []interface{}{"block_comments"}, + }) + + if strings.Contains(out, "This is a block comment") { + t.Errorf("strip=block_comments should remove block comments, got:\n%s", out) + } + if strings.Contains(out, "inline block") { + t.Errorf("strip=block_comments should remove inline block comment, got:\n%s", out) + } + if !strings.Contains(out, "func main") { + t.Errorf("strip=block_comments should keep code, got:\n%s", out) + } + if !strings.Contains(out, "[stripped: block_comments]") { + t.Errorf("block_comments strip footer missing, got:\n%s", out) + } +} + +func TestStripMultipleFlags(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := `/* Copyright 2024. License: MIT */ +package main + +import "fmt" + +func main() { + fmt.Println("hello") +} +` + f := writeFile(t, tmpDir, "main.go", goSrc) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "strip": []interface{}{"license", "imports"}, + }) + + if strings.Contains(out, "Copyright") { + t.Errorf("license not stripped, got:\n%s", out) + } + if strings.Contains(out, `"fmt"`) { + t.Errorf("imports not stripped, got:\n%s", out) + } + if !strings.Contains(out, "[stripped:") { + t.Errorf("strip footer missing, got:\n%s", out) + } +} + +func TestStripNoRemovalProducesNoFooter(t *testing.T) { + // A file with no imports: strip=imports should not add footer + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := "package main\nfunc main() {}\n" + f := writeFile(t, tmpDir, "test.go", goSrc) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "strip": []interface{}{"imports"}, + }) + + if strings.Contains(out, "[stripped:") { + t.Errorf("should not have stripped footer when nothing removed, got:\n%s", out) + } +} + +func TestStripImportsPython(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + pySrc := `import os +from sys import argv + +def main(): + print(argv[0]) +` + f := writeFile(t, tmpDir, "test.py", pySrc) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "strip": []interface{}{"imports"}, + }) + + if strings.Contains(out, "import os") { + t.Errorf("Python imports not stripped, got:\n%s", out) + } + if strings.Contains(out, "from sys") { + t.Errorf("Python from-import not stripped, got:\n%s", out) + } + if !strings.Contains(out, "def main") { + t.Errorf("Python function missing after strip, got:\n%s", out) + } +} + +// ---- Feature 3: short etag (8 hex chars) ---- + +func TestEtagIs8Chars(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + f := writeFile(t, tmpDir, "test.go", "package main\n") + + out := callRead(t, srv, map[string]interface{}{"path": f}) + + // Find "[etag: XXXXXXXX]" + idx := strings.Index(out, "[etag: ") + if idx < 0 { + t.Fatalf("no etag in output: %q", out) + } + rest := out[idx+7:] + end := strings.Index(rest, "]") + if end < 0 { + t.Fatalf("malformed etag in output: %q", out) + } + etagVal := rest[:end] + if len(etagVal) != 8 { + t.Errorf("etag should be 8 hex chars, got %d chars: %q", len(etagVal), etagVal) + } + // Validate hex + for _, c := range etagVal { + isDigit := c >= '0' && c <= '9' + isHexLower := c >= 'a' && c <= 'f' + if !isDigit && !isHexLower { + t.Errorf("etag contains non-hex char %q in %q", c, etagVal) + } + } +} + +func TestEtagPreviousEtagShortCircuit(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + f := writeFile(t, tmpDir, "test.go", "package main\n") + + // First read: get the etag + out1 := callRead(t, srv, map[string]interface{}{"path": f}) + idx := strings.Index(out1, "[etag: ") + etag := out1[idx+7 : idx+7+8] + + // Second read with same etag: should short-circuit + out2 := callRead(t, srv, map[string]interface{}{ + "path": f, + "previous_etag": etag, + }) + if !strings.Contains(out2, "[unchanged, etag:") { + t.Errorf("expected [unchanged, etag:] for same etag, got: %q", out2) + } +} + +func TestEtagOldLongEtagStillWorks(t *testing.T) { + // Simulate old client sending 16-char etag: should still short-circuit via prefix match. + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + f := writeFile(t, tmpDir, "test.go", "package main\n") + + // Get the 8-char etag + out1 := callRead(t, srv, map[string]interface{}{"path": f}) + idx := strings.Index(out1, "[etag: ") + shortEtag := out1[idx+7 : idx+7+8] + + // Construct a fake 16-char etag that starts with the short one + fakeOldEtag := shortEtag + "00000000" + + out2 := callRead(t, srv, map[string]interface{}{ + "path": f, + "previous_etag": fakeOldEtag, + }) + if !strings.Contains(out2, "[unchanged, etag:") { + t.Errorf("old 16-char etag should still short-circuit, got: %q", out2) + } +} + +func TestEtagDifferentFileChanges(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + f := writeFile(t, tmpDir, "test.go", "package main\n") + + out1 := callRead(t, srv, map[string]interface{}{"path": f}) + idx := strings.Index(out1, "[etag: ") + etag1 := out1[idx+7 : idx+7+8] + + // Modify the file + if err := os.WriteFile(f, []byte("package main\n// changed\n"), 0600); err != nil { + t.Fatal(err) + } + + // Read again with old etag: should NOT short-circuit + out2 := callRead(t, srv, map[string]interface{}{ + "path": f, + "previous_etag": etag1, + }) + if strings.Contains(out2, "[unchanged") { + t.Errorf("modified file should not return unchanged, got: %q", out2) + } +} + +// ---- Feature 4: compact_line_numbers ---- + +func TestCompactLineNumbers(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + content := "line one\nline two\nline three\n" + f := writeFile(t, tmpDir, "test.txt", content) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "compact_line_numbers": true, + }) + + // Should have "1│" not " 1│ " + if !strings.Contains(out, "1│line one") { + t.Errorf("compact prefix not found, got:\n%s", out) + } + // Should NOT have padded format + if strings.Contains(out, " 1│ line one") { + t.Errorf("compact should not have padded prefix, got:\n%s", out) + } +} + +func TestCompactLineNumbersOffByDefault(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + content := "line one\nline two\n" + f := writeFile(t, tmpDir, "test.txt", content) + + // Default: no compact_line_numbers + out := callRead(t, srv, map[string]interface{}{"path": f}) + + // Should have padded format + if !strings.Contains(out, " 1│ line one") { + t.Errorf("default should have padded format, got:\n%s", out) + } +} + +func TestCompactLineNumbersWithInterval(t *testing.T) { + // compact_line_numbers + line_number_interval should work together + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + var sb strings.Builder + for i := 1; i <= 10; i++ { + sb.WriteString("line\n") + } + f := writeFile(t, tmpDir, "test.txt", sb.String()) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "compact_line_numbers": true, + "line_number_interval": 5, + }) + + // Line 5 should have number prefix + if !strings.Contains(out, "5│line") { + t.Errorf("compact+interval: line 5 should have number, got:\n%s", out) + } + // Line 1 (first) should have number + if !strings.Contains(out, "1│line") { + t.Errorf("compact+interval: line 1 should have number, got:\n%s", out) + } + // Line 10 (last) should have number + if !strings.Contains(out, "10│line") { + t.Errorf("compact+interval: line 10 should have number, got:\n%s", out) + } + // Non-interval line should have bare │ prefix (no number) + if !strings.Contains(out, "│line") { + t.Errorf("compact+interval: non-interval lines should have bare │, got:\n%s", out) + } +} + +func TestCompactLineNumbersWithNoLineNumbers(t *testing.T) { + // compact_line_numbers + no_line_numbers: no_line_numbers wins + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + f := writeFile(t, tmpDir, "test.txt", "line one\n") + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "compact_line_numbers": true, + "no_line_numbers": true, + }) + + // Should have NO prefix at all + if strings.Contains(out, "│") { + t.Errorf("no_line_numbers should suppress all prefixes, got:\n%s", out) + } +} + +// ---- Backward compatibility: existing behavior unchanged ---- + +func TestDefaultBehaviorUnchanged(t *testing.T) { + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := `package main + +func Hello() { + println("hello") +} +` + f := writeFile(t, tmpDir, "test.go", goSrc) + + out := callRead(t, srv, map[string]interface{}{"path": f}) + + // All lines present + if !strings.Contains(out, `println("hello")`) { + t.Errorf("full mode missing body, got:\n%s", out) + } + // Padded line numbers + if !strings.Contains(out, " 1│ package main") { + t.Errorf("default should have padded line numbers, got:\n%s", out) + } + // etag present + if !strings.Contains(out, "[etag:") { + t.Errorf("missing etag footer, got:\n%s", out) + } +} + +func TestSymbolsOnlyFlagStillWorks(t *testing.T) { + // Old symbols_only=true + include_ast=true should still work + tmpDir := t.TempDir() + srv := newTestServer(t, tmpDir) + + goSrc := "package main\nfunc Hello() { println(1) }\n" + f := writeFile(t, tmpDir, "test.go", goSrc) + + out := callRead(t, srv, map[string]interface{}{ + "path": f, + "include_ast": true, + "symbols_only": true, + }) + + if strings.Contains(out, "println") { + t.Errorf("symbols_only should suppress body, got:\n%s", out) + } + if !strings.Contains(out, "[etag:") { + t.Errorf("symbols_only missing etag, got:\n%s", out) + } +} + +// ---- Feature: resource_link for big reads ---- + +func TestResourceLinkThresholdTrip(t *testing.T) { + // When result bytes > threshold, handleFileRead returns a ResourceLink content block. + tmpDir := t.TempDir() + // Low threshold (10 bytes) guarantees even a tiny file trips it. + srv := newTestServerWithThreshold(t, tmpDir, 10) + + f := writeFile(t, tmpDir, "big.txt", strings.Repeat("x", 200)) + + result := callReadResult(t, srv, map[string]interface{}{"path": f}) + + if len(result.Content) == 0 { + t.Fatal("expected content, got none") + } + link, ok := result.Content[0].(mcp.ResourceLink) + if !ok { + t.Fatalf("expected ResourceLink content, got %T", result.Content[0]) + } + if !strings.HasPrefix(link.URI, "filepuff://read/") { + t.Errorf("ResourceLink URI should start with filepuff://read/, got: %q", link.URI) + } + if link.Name != f { + t.Errorf("ResourceLink Name should be file path %q, got %q", f, link.Name) + } + if !strings.Contains(link.Description, "etag=") { + t.Errorf("ResourceLink Description should contain etag=, got %q", link.Description) + } + if !strings.Contains(link.Description, "size=") { + t.Errorf("ResourceLink Description should contain size=, got %q", link.Description) + } + if !strings.Contains(link.Description, "lines=") { + t.Errorf("ResourceLink Description should contain lines=, got %q", link.Description) + } +} + +func TestResourceLinkForceInlineBypass(t *testing.T) { + // force_inline=true must always return TextContent regardless of threshold. + tmpDir := t.TempDir() + srv := newTestServerWithThreshold(t, tmpDir, 1) // threshold = 1 byte, always trips + + f := writeFile(t, tmpDir, "test.txt", "hello world") + + result := callReadResult(t, srv, map[string]interface{}{ + "path": f, + "force_inline": true, + }) + + if len(result.Content) == 0 { + t.Fatal("expected content, got none") + } + tc, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("force_inline=true should return TextContent, got %T", result.Content[0]) + } + if !strings.Contains(tc.Text, "hello world") { + t.Errorf("force_inline result should contain file content, got: %q", tc.Text) + } +} + +func TestResourceLinkMaxInlineBytesOverride(t *testing.T) { + // max_inline_bytes overrides server threshold per-call. + tmpDir := t.TempDir() + // Server threshold = 1 byte (always trips), but max_inline_bytes allows bigger. + srv := newTestServerWithThreshold(t, tmpDir, 1) + + content := strings.Repeat("a", 50) // 50 bytes + f := writeFile(t, tmpDir, "test.txt", content) + + // max_inline_bytes=100 > 50 bytes result → should inline + result := callReadResult(t, srv, map[string]interface{}{ + "path": f, + "max_inline_bytes": 100, + }) + + if len(result.Content) == 0 { + t.Fatal("expected content, got none") + } + _, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("max_inline_bytes=100 with 50-byte file should return TextContent, got %T", result.Content[0]) + } +} + +func TestResourceLinkStaleEtagRejection(t *testing.T) { + // handleReadResource should reject fetch when file has changed since link was issued. + tmpDir := t.TempDir() + srv := newTestServerWithThreshold(t, tmpDir, 1) // always trips + + f := writeFile(t, tmpDir, "stale.txt", "original content") + + // Get a ResourceLink — captures etag of "original content" + result := callReadResult(t, srv, map[string]interface{}{"path": f}) + link, ok := result.Content[0].(mcp.ResourceLink) + if !ok { + t.Fatalf("expected ResourceLink, got %T", result.Content[0]) + } + // URI contains ?etag= + if !strings.Contains(link.URI, "?etag=") { + t.Fatalf("expected etag in URI, got %q", link.URI) + } + + // Overwrite the file with new content. + if err := os.WriteFile(f, []byte("modified content — different"), 0600); err != nil { + t.Fatal(err) + } + + // Fetch the resource using the stale URI — should error. + req := mcp.ReadResourceRequest{} + req.Params.URI = link.URI + _, err := srv.handleReadResource(req) + if err == nil { + t.Fatal("expected error for stale etag, got nil") + } + if !strings.Contains(err.Error(), "file changed") { + t.Errorf("error should mention 'file changed', got: %v", err) + } +} + +func TestResourceLinkBelowThresholdInlines(t *testing.T) { + // When result is small (below threshold), always inline regardless of threshold setting. + tmpDir := t.TempDir() + // Large threshold — small file should be inlined. + srv := newTestServerWithThreshold(t, tmpDir, 64*1024) + + f := writeFile(t, tmpDir, "small.txt", "tiny") + + result := callReadResult(t, srv, map[string]interface{}{"path": f}) + + if len(result.Content) == 0 { + t.Fatal("expected content, got none") + } + _, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("small file should return TextContent, got %T", result.Content[0]) + } +} + +func TestResourceLinkThresholdZeroDisabled(t *testing.T) { + // threshold=0 disables the feature entirely — always inline. + tmpDir := t.TempDir() + srv := newTestServerWithThreshold(t, tmpDir, 0) + + f := writeFile(t, tmpDir, "test.txt", strings.Repeat("z", 10000)) + + result := callReadResult(t, srv, map[string]interface{}{"path": f}) + + if len(result.Content) == 0 { + t.Fatal("expected content") + } + _, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("threshold=0 should always inline, got %T", result.Content[0]) + } +} + +func TestHandleReadResource_ValidFetch(t *testing.T) { + // handleReadResource fetches file content when etag matches. + tmpDir := t.TempDir() + srv := newTestServerWithThreshold(t, tmpDir, 1) + + f := writeFile(t, tmpDir, "fetch.txt", "fetch me please") + + // Trigger a ResourceLink to get a valid URI with correct etag. + toolResult := callReadResult(t, srv, map[string]interface{}{"path": f}) + link, ok := toolResult.Content[0].(mcp.ResourceLink) + if !ok { + t.Fatalf("expected ResourceLink, got %T", toolResult.Content[0]) + } + + req := mcp.ReadResourceRequest{} + req.Params.URI = link.URI + contents, err := srv.handleReadResource(req) + if err != nil { + t.Fatalf("handleReadResource error: %v", err) + } + if len(contents) == 0 { + t.Fatal("expected resource contents, got none") + } + tc, ok := contents[0].(mcp.TextResourceContents) + if !ok { + t.Fatalf("expected TextResourceContents, got %T", contents[0]) + } + if !strings.Contains(tc.Text, "fetch me please") { + t.Errorf("resource contents should include file content, got: %q", tc.Text) + } +} + +func TestResourceLinkMIMEType(t *testing.T) { + // Verify MIME types for common extensions. + tmpDir := t.TempDir() + srv := newTestServerWithThreshold(t, tmpDir, 1) + + cases := []struct { + name string + content string + wantMIME string + }{ + {"test.go", "package main\n", "text/x-go"}, + {"test.py", "# py\n", "text/x-python"}, + {"test.ts", "// ts\n", "text/typescript"}, + {"test.json", "{}\n", "application/json"}, + {"test.md", "# hi\n", "text/markdown"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + f := writeFile(t, tmpDir, c.name, c.content) + result := callReadResult(t, srv, map[string]interface{}{"path": f}) + link, ok := result.Content[0].(mcp.ResourceLink) + if !ok { + t.Fatalf("expected ResourceLink for %s, got %T", c.name, result.Content[0]) + } + if link.MIMEType != c.wantMIME { + t.Errorf("%s: MIMEType = %q, want %q", c.name, link.MIMEType, c.wantMIME) + } + }) + } +} diff --git a/internal/server/handlers_lsp.go b/internal/server/handlers_lsp.go index 4067d08..e9c5861 100644 --- a/internal/server/handlers_lsp.go +++ b/internal/server/handlers_lsp.go @@ -13,8 +13,14 @@ import ( "github.com/mark3labs/mcp-go/mcp" ) -// handleSymbolAt handles the symbol_at tool. -func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +// handleLSPQuery is the unified dispatcher for all LSP operations. +// action must be one of: "hover", "definition", "references". +func (s *Server) handleLSPQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + action, err := request.RequireString("action") + if err != nil { + return mcp.NewToolResultError("action is required (hover | definition | references)"), nil + } + file, err := request.RequireString("file") if err != nil { return mcp.NewToolResultError("file is required"), nil @@ -30,16 +36,51 @@ func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest return mcp.NewToolResultError("column must be positive"), nil } - // Validate path if !s.cfg.IsPathAllowed(file) { return mcp.NewToolResultError("file is outside workspace root"), nil } - // Try LSP hover + verbose := request.GetBool("verbose", false) + + switch action { + case "hover": + if _, ok := request.GetArguments()["include_declaration"]; ok { + return mcp.NewToolResultError("include_declaration is only valid for action=references"), nil + } + if _, ok := request.GetArguments()["compact"]; ok { + return mcp.NewToolResultError("compact is only valid for action=references"), nil + } + return s.lspHover(ctx, file, line, col, verbose) + + case "definition": + if _, ok := request.GetArguments()["include_declaration"]; ok { + return mcp.NewToolResultError("include_declaration is only valid for action=references"), nil + } + if _, ok := request.GetArguments()["compact"]; ok { + return mcp.NewToolResultError("compact is only valid for action=references"), nil + } + return s.lspDefinition(ctx, file, line, col, verbose) + + case "references": + includeDecl := request.GetBool("include_declaration", true) + // compact: explicit call-time > session compact_refs pref > false + var prefsCompact *bool + if sp := s.sessionPrefs.Load(); sp != nil { + prefsCompact = sp.CompactRefs + } + compact := effectiveBool(request, "compact", prefsCompact, false) + return s.lspReferences(ctx, file, line, col, includeDecl, compact, verbose) + + default: + return mcp.NewToolResultError(fmt.Sprintf("unknown action %q: must be hover | definition | references", action)), nil + } +} + +// lspHover performs hover (symbol info) for the given position. +func (s *Server) lspHover(ctx context.Context, file string, line, col int, verbose bool) (*mcp.CallToolResult, error) { hover, err := s.lspManager.Hover(ctx, file, line, col) if err != nil { - // Fall back to AST-based info - return s.handleSymbolAtFallback(ctx, file, line, col) + return s.handleSymbolAtFallback(ctx, file, line, col, verbose) } if hover == nil { @@ -47,14 +88,16 @@ func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest } var output strings.Builder - output.WriteString("**Symbol Information**\n\n") + if verbose { + output.WriteString("**Symbol Information**\n\n") + } output.WriteString(hover.Contents.Value) return mcp.NewToolResultText(output.String()), nil } // handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable. -func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) { +func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int, verbose bool) (*mcp.CallToolResult, error) { content, err := os.ReadFile(file) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %s", errors.SanitizeError(err))), nil @@ -71,35 +114,17 @@ func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, } var output strings.Builder - output.WriteString("**Symbol Information** (AST fallback)\n\n") + if verbose { + output.WriteString("**Symbol Information** (AST fallback)\n\n") + } output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type())) output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content))) return mcp.NewToolResultText(output.String()), nil } -// handleFindDefinition handles the find_definition tool. -func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - file, err := request.RequireString("file") - if err != nil { - return mcp.NewToolResultError("file is required"), nil - } - - line := request.GetInt("line", 0) - if line <= 0 { - return mcp.NewToolResultError("line must be positive"), nil - } - - col := request.GetInt("column", 0) - if col <= 0 { - return mcp.NewToolResultError("column must be positive"), nil - } - - // Validate path - if !s.cfg.IsPathAllowed(file) { - return mcp.NewToolResultError("file is outside workspace root"), nil - } - +// lspDefinition finds the definition of the symbol at the given position. +func (s *Server) lspDefinition(ctx context.Context, file string, line, col int, verbose bool) (*mcp.CallToolResult, error) { locations, err := s.lspManager.Definition(ctx, file, line, col) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %s", errors.SanitizeError(err))), nil @@ -110,13 +135,14 @@ func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolR } var output strings.Builder - output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations))) + if verbose { + output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations))) + } for _, loc := range locations { filePath := lsp.URIToFile(loc.URI) output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1)) - // Try to read a preview snippet preview := s.readFilePreview(filePath, loc.Range.Start.Line+1, 3) if preview != "" { output.WriteString("```\n") @@ -129,30 +155,8 @@ func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolR return mcp.NewToolResultText(output.String()), nil } -// handleFindReferences handles the find_references tool. -func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - file, err := request.RequireString("file") - if err != nil { - return mcp.NewToolResultError("file is required"), nil - } - - line := request.GetInt("line", 0) - if line <= 0 { - return mcp.NewToolResultError("line must be positive"), nil - } - - col := request.GetInt("column", 0) - if col <= 0 { - return mcp.NewToolResultError("column must be positive"), nil - } - - includeDecl := request.GetBool("include_declaration", true) - - // Validate path - if !s.cfg.IsPathAllowed(file) { - return mcp.NewToolResultError("file is outside workspace root"), nil - } - +// lspReferences finds all references to the symbol at the given position. +func (s *Server) lspReferences(ctx context.Context, file string, line, col int, includeDecl, compact, verbose bool) (*mcp.CallToolResult, error) { locations, err := s.lspManager.References(ctx, file, line, col, includeDecl) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %s", errors.SanitizeError(err))), nil @@ -162,25 +166,84 @@ func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolR return mcp.NewToolResultText("No references found."), nil } - var output strings.Builder - output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations))) - - // Group by file + // Group by file, preserving encounter order. fileGroups := make(map[string][]lsp.Location) + fileOrder := make([]string, 0) for _, loc := range locations { filePath := lsp.URIToFile(loc.URI) + if _, seen := fileGroups[filePath]; !seen { + fileOrder = append(fileOrder, filePath) + } fileGroups[filePath] = append(fileGroups[filePath], loc) } - for filePath, locs := range fileGroups { - output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs))) - for _, loc := range locs { - output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1)) - } - output.WriteString("\n") + return mcp.NewToolResultText(formatReferences(fileGroups, fileOrder, len(locations), compact, verbose)), nil +} + +// formatReferences formats grouped reference locations as a string. +// compact=false: verbose multi-line format with L{line}:{col} per entry. +// compact=true: one line per file — file:[line:col, ...] (N), with same-line +// columns collapsed to line:{col1,col2,...}. +func formatReferences(fileGroups map[string][]lsp.Location, fileOrder []string, total int, compact bool, verbose bool) string { + var output strings.Builder + if verbose { + output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", total)) } - return mcp.NewToolResultText(output.String()), nil + for _, filePath := range fileOrder { + locs := fileGroups[filePath] + if compact { + output.WriteString(formatReferencesCompact(filePath, locs)) + } else { + output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs))) + for _, loc := range locs { + output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1)) + } + output.WriteString("\n") + } + } + + return output.String() +} + +// formatReferencesCompact formats one file's references as a single compact line. +// Same-line references are collapsed: 12:{5,8} instead of 12:5, 12:8. +func formatReferencesCompact(filePath string, locs []lsp.Location) string { + // Build ordered line->col map preserving encounter order per line. + type lineEntry struct { + lineNum int + cols []int + } + lineMap := make(map[int]*lineEntry) + lineOrder := make([]int, 0, len(locs)) + + for _, loc := range locs { + ln := loc.Range.Start.Line + 1 + col := loc.Range.Start.Character + 1 + if e, ok := lineMap[ln]; ok { + e.cols = append(e.cols, col) + } else { + lineMap[ln] = &lineEntry{lineNum: ln, cols: []int{col}} + lineOrder = append(lineOrder, ln) + } + } + + // Build the bracket contents. + parts := make([]string, 0, len(lineOrder)) + for _, ln := range lineOrder { + e := lineMap[ln] + if len(e.cols) == 1 { + parts = append(parts, fmt.Sprintf("%d:%d", ln, e.cols[0])) + } else { + colStrs := make([]string, len(e.cols)) + for i, c := range e.cols { + colStrs[i] = fmt.Sprintf("%d", c) + } + parts = append(parts, fmt.Sprintf("%d:{%s}", ln, strings.Join(colStrs, ","))) + } + } + + return fmt.Sprintf("%s:[%s] (%d)\n", filePath, strings.Join(parts, ", "), len(locs)) } // readFilePreview reads a few lines from a file around the given line. diff --git a/internal/server/handlers_lsp_test.go b/internal/server/handlers_lsp_test.go new file mode 100644 index 0000000..57ff408 --- /dev/null +++ b/internal/server/handlers_lsp_test.go @@ -0,0 +1,181 @@ +package server + +import ( + "strings" + "testing" + + "github.com/lukaszraczylo/mcp-filepuff/internal/lsp" +) + +// makeLocation builds an lsp.Location with a file:// URI. +func makeLocation(file string, line, col int) lsp.Location { + return lsp.Location{ + URI: "file://" + file, + Range: lsp.Range{ + Start: lsp.Position{Line: line - 1, Character: col - 1}, + End: lsp.Position{Line: line - 1, Character: col - 1}, + }, + } +} + +// groupLocations is a helper that mirrors the grouping logic in lspReferences. +func groupLocations(locations []lsp.Location) (map[string][]lsp.Location, []string) { + fileGroups := make(map[string][]lsp.Location) + fileOrder := make([]string, 0) + for _, loc := range locations { + filePath := lsp.URIToFile(loc.URI) + if _, seen := fileGroups[filePath]; !seen { + fileOrder = append(fileOrder, filePath) + } + fileGroups[filePath] = append(fileGroups[filePath], loc) + } + return fileGroups, fileOrder +} + +// TestFormatReferencesVerbose verifies the default (non-compact) format is unchanged. +func TestFormatReferencesVerbose(t *testing.T) { + locs := []lsp.Location{ + makeLocation("/a/foo.go", 12, 5), + makeLocation("/a/foo.go", 13, 8), + makeLocation("/a/bar.go", 15, 1), + } + groups, order := groupLocations(locs) + out := formatReferences(groups, order, len(locs), false, true) + + if !strings.Contains(out, "Found 3 reference(s):") { + t.Errorf("missing header, got:\n%s", out) + } + if !strings.Contains(out, "**") { + t.Errorf("verbose format should use **file** markers, got:\n%s", out) + } + if !strings.Contains(out, "L12:5") { + t.Errorf("missing L12:5 in verbose output, got:\n%s", out) + } + if !strings.Contains(out, "L13:8") { + t.Errorf("missing L13:8 in verbose output, got:\n%s", out) + } + if !strings.Contains(out, "L15:1") { + t.Errorf("missing L15:1 in verbose output, got:\n%s", out) + } +} + +// TestFormatReferencesCompactBasic verifies compact output for distinct lines. +func TestFormatReferencesCompactBasic(t *testing.T) { + locs := []lsp.Location{ + makeLocation("/a/foo.go", 12, 5), + makeLocation("/a/foo.go", 13, 8), + makeLocation("/a/foo.go", 15, 12), + } + groups, order := groupLocations(locs) + out := formatReferences(groups, order, len(locs), true, true) + + if !strings.Contains(out, "Found 3 reference(s):") { + t.Errorf("missing header, got:\n%s", out) + } + // Should contain "foo.go:[12:5, 13:8, 15:12] (3)" + if !strings.Contains(out, "12:5") { + t.Errorf("missing 12:5 in compact output, got:\n%s", out) + } + if !strings.Contains(out, "13:8") { + t.Errorf("missing 13:8 in compact output, got:\n%s", out) + } + if !strings.Contains(out, "15:12") { + t.Errorf("missing 15:12 in compact output, got:\n%s", out) + } + if !strings.Contains(out, "(3)") { + t.Errorf("missing (3) count, got:\n%s", out) + } + // Compact format must NOT use ** markers + if strings.Contains(out, "**") { + t.Errorf("compact format must not use ** markers, got:\n%s", out) + } + // Must not have L prefix + if strings.Contains(out, "L12") { + t.Errorf("compact format must not have L prefix, got:\n%s", out) + } +} + +// TestFormatReferencesCompactSameLineCollapse verifies same-line columns collapse. +func TestFormatReferencesCompactSameLineCollapse(t *testing.T) { + locs := []lsp.Location{ + makeLocation("/a/foo.go", 12, 5), + makeLocation("/a/foo.go", 12, 8), + makeLocation("/a/foo.go", 15, 3), + } + groups, order := groupLocations(locs) + out := formatReferences(groups, order, len(locs), true, false) + + // Line 12 has two refs → should be collapsed: 12:{5,8} + if !strings.Contains(out, "12:{5,8}") { + t.Errorf("same-line refs should collapse to 12:{5,8}, got:\n%s", out) + } + // Line 15 is single → 15:3 + if !strings.Contains(out, "15:3") { + t.Errorf("single ref on line 15 should be 15:3, got:\n%s", out) + } +} + +// TestFormatReferencesCompactMultiFile verifies compact output across multiple files. +func TestFormatReferencesCompactMultiFile(t *testing.T) { + locs := []lsp.Location{ + makeLocation("/a/foo.go", 5, 1), + makeLocation("/b/bar.go", 10, 3), + makeLocation("/b/bar.go", 10, 7), + } + groups, order := groupLocations(locs) + out := formatReferences(groups, order, len(locs), true, true) + + if !strings.Contains(out, "Found 3 reference(s):") { + t.Errorf("missing header, got:\n%s", out) + } + // foo.go: one ref + if !strings.Contains(out, "5:1") { + t.Errorf("missing 5:1 for foo.go, got:\n%s", out) + } + // bar.go: two refs on same line → collapsed + if !strings.Contains(out, "10:{3,7}") { + t.Errorf("missing 10:{3,7} collapse for bar.go, got:\n%s", out) + } +} + +// TestFormatReferencesCompactSingleRef verifies single-reference compact output. +func TestFormatReferencesCompactSingleRef(t *testing.T) { + locs := []lsp.Location{ + makeLocation("/a/only.go", 7, 2), + } + groups, order := groupLocations(locs) + out := formatReferences(groups, order, len(locs), true, false) + + if !strings.Contains(out, "7:2") { + t.Errorf("missing 7:2, got:\n%s", out) + } + if !strings.Contains(out, "(1)") { + t.Errorf("missing (1), got:\n%s", out) + } +} + +// TestFormatReferencesCompactNoLPrefix verifies the L prefix is absent in compact mode. +func TestFormatReferencesCompactNoLPrefix(t *testing.T) { + locs := []lsp.Location{ + makeLocation("/a/x.go", 3, 4), + } + groups, order := groupLocations(locs) + out := formatReferences(groups, order, len(locs), true, false) + + if strings.Contains(out, "L3") { + t.Errorf("compact output must not contain L prefix, got:\n%s", out) + } +} + +// TestFormatReferencesVerboseNoChange verifies compact=false preserves old L-prefix format. +func TestFormatReferencesVerbosePreservesLPrefix(t *testing.T) { + locs := []lsp.Location{ + makeLocation("/a/x.go", 3, 4), + } + groups, order := groupLocations(locs) + out := formatReferences(groups, order, len(locs), false, true) + + if !strings.Contains(out, "L3:4") { + t.Errorf("verbose output must contain L3:4, got:\n%s", out) + } +} diff --git a/internal/server/help_content.go b/internal/server/help_content.go new file mode 100644 index 0000000..f17f2fb --- /dev/null +++ b/internal/server/help_content.go @@ -0,0 +1,183 @@ +// Package server implements the MCP server for file operations. +package server + +// helpFileRead is the full flag documentation and examples for the file_read tool, +// served at filepuff://help/file_read. +const helpFileRead = "# file_read — flags and examples\n\n" + + "## Token-saving features\n\n" + + "| Flag | Effect |\n" + + "|------|--------|\n" + + "| `previous_etag` | Skip re-reading unchanged files. Returns `[unchanged, etag: ...]` if file is unchanged. |\n" + + "| `symbol_name` | Read only a named function/struct/class — eliminates an ast_query round-trip. |\n" + + "| `symbols_only=true` | Return only symbol list (~95% fewer tokens). Requires `include_ast=true`. Alias: `mode='symbols_only'`. |\n" + + "| `mode` | `full` (default) \\| `skeleton` (signatures + `{ ... }` stubs, bodies elided) \\| `symbols_only` |\n" + + "| `strip` | Remove content classes before line-numbering: `imports`, `license`, `block_comments`. Emits `[stripped: ...]` footer. |\n" + + "| `no_line_numbers=true` | Omit the ` 12│ ` line-number prefix (~10% savings). `line_number_interval=0` has the same effect. |\n" + + "| `line_number_interval=N` | Print line numbers only every N lines. |\n" + + "| `compact_line_numbers=true` | Use compact `12│` prefix instead of ` 12│ ` (no padding, no trailing space). |\n" + + "| `collapse_blank_lines=true` | Collapse consecutive blank lines to one. |\n" + + "| `max_lines=N` | Truncate output with omitted count notice. Applied after `line_start`/`line_end`. |\n" + + "| `paths=[...]` | Read multiple files in one call. Each file gets a `--- path ---` header. |\n\n" + + "All responses include `[etag: hex]` footer (8 hex chars) for use as `previous_etag` in subsequent reads.\n\n" + + "## Examples\n\n" + + "```json\n" + + "// Full file\n" + + `{"path": "main.go"}` + "\n\n" + + "// Etag check — returns unchanged notice if file hasn't changed\n" + + `{"path": "main.go", "previous_etag": "a3f9c2b1"}` + "\n\n" + + "// Read only one named symbol\n" + + `{"path": "server.go", "symbol_name": "handleFileRead"}` + "\n\n" + + "// Skeleton mode — signatures only, bodies elided\n" + + `{"path": "server.go", "mode": "skeleton"}` + "\n\n" + + "// Strip imports and license header\n" + + `{"path": "main.go", "strip": ["imports", "license"]}` + "\n\n" + + "// Batch read multiple files\n" + + `{"paths": ["a.go", "b.go"]}` + "\n\n" + + "// Specific line range\n" + + `{"path": "main.go", "line_start": 10, "line_end": 50}` + "\n" + + "```\n" + +// helpFileSearch is the full flag documentation and examples for the file_search tool, +// served at filepuff://help/file_search. +const helpFileSearch = "# file_search — flags and examples\n\n" + + "## Output format\n\n" + + "Matches grouped by file. Each file section has matching lines prefixed by `L{line}│` and context lines prefixed by ` │`. Zero matches: `No matches found.`\n\n" + + "## Flags\n\n" + + "| Flag | Effect |\n" + + "|------|--------|\n" + + "| `verbose=true` | Emit `Found N matches in M files:` preamble (v1 behaviour). Default: false. |\n" + + "| `cluster=true` | Coalesce consecutive match lines into ranges (`L12-14│ text`). Drops context lines for density. |\n" + + "| `cursor` | Opaque pagination token from a previous truncated response — fetches next page. |\n" + + "| `max_results` | Page size for pagination. Re-run with `cursor` to get next page. |\n" + + "| `context_lines` | Number of context lines around matches (default: 2). |\n" + + "| `ignore_case` | Case-insensitive search. |\n" + + "| `regex` | Treat pattern as regex (default: true). |\n" + + "| `file_types` | Restrict to file extensions, e.g. `[\"go\", \"ts\"]`. |\n" + + "| `paths` | Paths to search in (defaults to workspace root). |\n\n" + + "## Examples\n\n" + + "```json\n" + + "// Search for error-returning functions in Go files\n" + + `{"pattern": "func.*Error", "file_types": ["go"], "max_results": 20}` + "\n\n" + + "// Case-insensitive literal search\n" + + `{"pattern": "TODO", "ignore_case": true}` + "\n\n" + + "// Paginated search — fetch next page\n" + + `{"pattern": "import", "max_results": 50, "cursor": ""}` + "\n\n" + + "// Clustered — dense view of many matches\n" + + `{"pattern": "return err", "file_types": ["go"], "cluster": true}` + "\n" + + "```\n" + +// helpASTQuery is the full flag documentation and examples for the ast_query tool, +// served at filepuff://help/ast_query. +const helpASTQuery = "# ast_query — flags and examples\n\n" + + "## Output format\n\n" + + "Entries in format `**file:line** (node_type)` with code blocks and captured variables (`$NAME=value`). Zero matches: `No matches found.`\n\n" + + "## Flags\n\n" + + "| Flag | Effect |\n" + + "|------|--------|\n" + + "| `verbose=true` | Emit `Found N match(es):` preamble (v1 behaviour). Default: false. |\n" + + "| `format` | `verbose` (default, full code+captures) \\| `compact` (one line per match) \\| `location` (file:line only) |\n" + + "| `cursor` | Opaque pagination token from a previous truncated response — fetches next page. |\n" + + "| `max_results` | Page size (default: 100). |\n" + + "| `name_exact` | Exact symbol name to match. |\n" + + "| `name_matches` | Regex pattern to filter by name. |\n" + + "| `kind_in` | Node types to match (e.g. `function_declaration`, `class_declaration`). |\n" + + "| `paths` | Paths to search in (defaults to workspace root). |\n\n" + + "## Pattern placeholders\n\n" + + "| Placeholder | Meaning |\n" + + "|-------------|----------|\n" + + "| `$NAME` | Matches a single node, captures as `$NAME` |\n" + + "| `$$$ARGS` | Matches zero or more nodes (variadic capture) |\n" + + "| `$_` | Wildcard — matches any single node, no capture |\n\n" + + "## Examples\n\n" + + "```json\n" + + "// All Go functions returning error\n" + + `{"pattern": "func $NAME($$$ARGS) error", "language": "go"}` + "\n\n" + + "// Python classes\n" + + `{"pattern": "class $NAME: $$$BODY", "language": "python"}` + "\n\n" + + "// Specific named function\n" + + `{"pattern": "func $NAME($$$ARGS)", "language": "go", "name_exact": "NewServer"}` + "\n\n" + + "// Compact output — one line per match\n" + + `{"pattern": "func $NAME($$$ARGS) error", "language": "go", "format": "compact"}` + "\n" + + "```\n" + +// helpLSPQuery is the full flag documentation and examples for the lsp_query tool, +// served at filepuff://help/lsp_query. +const helpLSPQuery = "# lsp_query — flags and examples\n\n" + + "## Actions\n\n" + + "### hover\n" + + "Returns type/doc from LSP, falls back to AST node info. `verbose=true` adds `**Symbol Information**` header.\n\n" + + "### definition\n" + + "Returns `file:line:col` + 3-line code preview for each definition. `verbose=true` adds `Found N definition(s):` header.\n\n" + + "### references\n" + + "Returns references grouped by file. Flags:\n" + + "- `include_declaration` (default true) — include the declaration itself\n" + + "- `compact=true` — collapse to one line per file\n" + + "- `verbose=true` — add `Found N reference(s):` header\n\n" + + "Note: `include_declaration` and `compact` are errors when used with actions other than `references`.\n\n" + + "## Examples\n\n" + + "```json\n" + + "// Hover — type/doc at position\n" + + `{"action": "hover", "file": "server.go", "line": 45, "column": 6}` + "\n\n" + + "// Definition — where is this symbol defined?\n" + + `{"action": "definition", "file": "handler.go", "line": 23, "column": 10}` + "\n\n" + + "// References — all usages\n" + + `{"action": "references", "file": "types.go", "line": 5, "column": 6}` + "\n\n" + + "// References — compact (one line per file)\n" + + `{"action": "references", "file": "types.go", "line": 5, "column": 6, "compact": true}` + "\n" + + "```\n" + +// helpEditApply is the full flag documentation and examples for the edit_apply tool, +// served at filepuff://help/edit_apply. +const helpEditApply = "# edit_apply — flags and examples\n\n" + + "## Response format (`response` flag)\n\n" + + "| Value | Output |\n" + + "|-------|--------|\n" + + "| `count` (default) | `+3 -1` added/removed line counts only |\n" + + "| `diff` | Full unified diff of changes made |\n" + + "| `none` | Empty response (silent success) |\n\n" + + "`compact_response=true` is a deprecated alias for `response=\"count\"` kept for pre-v2 compatibility.\n\n" + + "For code files (Go, TypeScript, JavaScript, Python, C, C++, Rust) syntax is validated before writing — the edit is rejected if it would produce invalid syntax.\n\n" + + "## Selector types\n\n" + + "### AST-mode selectors (code files)\n" + + "- `selector_kind` — AST node type (e.g. `function_declaration`, `class_declaration`)\n" + + "- `selector_name` — symbol name to match\n\n" + + "### Text-mode selectors (all files)\n" + + "- `selector_text` — exact text to match (must be unique, or use `selector_index`)\n" + + "- `selector_pattern` — regex pattern to match\n" + + "- `selector_line` / `selector_line_end` — line range\n\n" + + "### Shared\n" + + "- `selector_index` — index of match when multiple exist (default: 0)\n\n" + + "## Examples\n\n" + + "```json\n" + + "// AST mode — replace a named function\n" + + "{\n" + + ` "file": "main.go",` + "\n" + + ` "operation": "replace",` + "\n" + + ` "selector_kind": "function_declaration",` + "\n" + + ` "selector_name": "Hello",` + "\n" + + ` "new_content": "func Hello() {\n\treturn\n}"` + "\n" + + "}\n\n" + + "// Text mode — replace a markdown header\n" + + "{\n" + + ` "file": "README.md",` + "\n" + + ` "operation": "replace",` + "\n" + + ` "selector_text": "## Old Header",` + "\n" + + ` "new_content": "## New Header"` + "\n" + + "}\n\n" + + "// Line range replacement\n" + + "{\n" + + ` "file": "config.yaml",` + "\n" + + ` "operation": "replace",` + "\n" + + ` "selector_line": 5,` + "\n" + + ` "selector_line_end": 10,` + "\n" + + ` "new_content": "key: value"` + "\n" + + "}\n\n" + + "// Request full diff in response\n" + + "{\n" + + ` "file": "main.go",` + "\n" + + ` "operation": "replace",` + "\n" + + ` "selector_name": "Hello",` + "\n" + + ` "new_content": "func Hello() {}",` + "\n" + + ` "response": "diff"` + "\n" + + "}\n" + + "```\n" diff --git a/internal/server/integration_test.go b/internal/server/integration_test.go index 9ca7156..8e58a15 100644 --- a/internal/server/integration_test.go +++ b/internal/server/integration_test.go @@ -43,23 +43,7 @@ func Hello() string { ctx := context.Background() - // Test 1: Ping tool (health check) - t.Run("ping", func(t *testing.T) { - req := mcp.CallToolRequest{} - result, err := srv.handlePing(ctx, req) - if err != nil { - t.Errorf("handlePing() error = %v", err) - } - if result == nil { - t.Fatal("handlePing() returned nil") - return - } - if len(result.Content) == 0 { - t.Fatal("handlePing() returned empty content") - } - }) - - // Test 2: File read + // Test: File read (ping removed — Change 3) t.Run("file_read", func(t *testing.T) { req := mcp.CallToolRequest{} req.Params.Arguments = map[string]interface{}{ @@ -144,14 +128,8 @@ func TestMCPToolDiscovery(t *testing.T) { t.Fatal("MCP server not initialized") } - // Verify each expected tool works - ctx := context.Background() - - // Test ping tool - pingReq := mcp.CallToolRequest{} - if _, err := srv.handlePing(ctx, pingReq); err != nil { - t.Errorf("ping tool failed: %v", err) - } + // Ping tool removed (Change 3 — MCP protocol has own liveness check). + // Tools verified via integration tests in TestIntegrationFileOperations. } // TestMCPErrorResponses tests error handling following MCP spec. diff --git a/internal/server/resources.go b/internal/server/resources.go new file mode 100644 index 0000000..e11ce69 --- /dev/null +++ b/internal/server/resources.go @@ -0,0 +1,156 @@ +package server + +import ( + "context" + "fmt" + "net/url" + "os" + "strings" + + xxhash "github.com/cespare/xxhash/v2" + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" +) + +// helpResources maps a tool name to its help content constant. +var helpResources = map[string]string{ + "file_read": helpFileRead, + "file_search": helpFileSearch, + "ast_query": helpASTQuery, + "lsp_query": helpLSPQuery, + "edit_apply": helpEditApply, +} + +// registerResources registers one filepuff://help/ resource per tool. +// Each resource returns Markdown-formatted flag docs and examples. +func (s *Server) registerResources() { + for toolName, content := range helpResources { + uri := "filepuff://help/" + toolName + name := "help/" + toolName + description := "Flag documentation and examples for the " + toolName + " tool." + captured := content // capture for closure + + s.mcp.AddResource( + mcp.NewResource(uri, name, + mcp.WithResourceDescription(description), + mcp.WithMIMEType("text/markdown"), + ), + func(_ context.Context, _ mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: uri, + MIMEType: "text/markdown", + Text: captured, + }, + }, nil + }, + ) + } +} + +// readHelpResource is a convenience handler that can be used directly when a +// single resource handler is needed. It is kept exported for testability. +func readHelpResource(uri string) mcpserver.ResourceHandlerFunc { + return func(_ context.Context, _ mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + // Extract tool name from filepuff://help/ + const prefix = "filepuff://help/" + if len(uri) <= len(prefix) { + return nil, fmt.Errorf("invalid help URI: %s", uri) + } + toolName := uri[len(prefix):] + content, ok := helpResources[toolName] + if !ok { + return nil, fmt.Errorf("no help content for tool: %s", toolName) + } + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: uri, + MIMEType: "text/markdown", + Text: content, + }, + }, nil + } +} + +// registerReadResource registers the filepuff://read/{+path} resource template. +// The handler re-reads the file, validates the etag query param if provided, +// and returns the raw file content (no line-number formatting). +// +// URI format: filepuff://read/?etag= +// The etag param is optional. If supplied and the file has changed, the handler +// returns an error so the caller re-runs file_read to get a fresh ResourceLink. +func (s *Server) registerReadResource() { + const uriTemplate = "filepuff://read/{+path}" + + s.mcp.AddResourceTemplate( + mcp.NewResourceTemplate(uriTemplate, "file-read", + mcp.WithTemplateDescription("Raw content of a file previously read via file_read. "+ + "Fetch when file_read returns a ResourceLink instead of inlining content. "+ + "URI: filepuff://read/?etag="), + ), + func(_ context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return s.handleReadResource(req) + }, + ) +} + +// handleReadResource is the resource handler for filepuff://read/{+path} URIs. +func (s *Server) handleReadResource(req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + rawURI := req.Params.URI + + // Parse path and etag from the URI. + // URI shape: filepuff://read/[?etag=] + const scheme = "filepuff://read/" + if !strings.HasPrefix(rawURI, scheme) { + return nil, fmt.Errorf("invalid read resource URI: %s", rawURI) + } + rest := rawURI[len(scheme):] + + // Split off query string to get the path. + filePath := rest + var expectedEtag string + if qIdx := strings.IndexByte(rest, '?'); qIdx >= 0 { + filePath = rest[:qIdx] + qs, err := url.ParseQuery(rest[qIdx+1:]) + if err == nil { + expectedEtag = qs.Get("etag") + } + } + + if filePath == "" { + return nil, fmt.Errorf("read resource URI missing path") + } + + if !s.cfg.IsPathAllowed(filePath) { + return nil, fmt.Errorf("path is outside workspace root") + } + + content, err := os.ReadFile(filePath) + if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("file not found: %s", filePath) + } + if os.IsPermission(err) { + return nil, fmt.Errorf("permission denied: %s", filePath) + } + return nil, fmt.Errorf("error reading file: %s", filePath) + } + + // Validate etag if provided — detect stale references. + if expectedEtag != "" { + fullHash := fmt.Sprintf("%016x", xxhash.Sum64(content)) + currentEtag := fullHash[:8] + if expectedEtag != currentEtag && !strings.HasPrefix(fullHash, expectedEtag) && !strings.HasPrefix(expectedEtag, currentEtag) { + return nil, fmt.Errorf("file changed since ResourceLink was issued (expected etag %s, got %s); re-run file_read to get fresh content", expectedEtag, currentEtag) + } + } + + mimeType := detectMIMEType(filePath) + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: rawURI, + MIMEType: mimeType, + Text: string(content), + }, + }, nil +} diff --git a/internal/server/resources_test.go b/internal/server/resources_test.go new file mode 100644 index 0000000..2c52341 --- /dev/null +++ b/internal/server/resources_test.go @@ -0,0 +1,90 @@ +package server + +import ( + "context" + "strings" + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestRegisterResources_AllToolsHaveResource(t *testing.T) { + // Verify that registerResources wires up without panicking. + _ = newTestServer(t, t.TempDir()) + + expectedURIs := []string{ + "filepuff://help/file_read", + "filepuff://help/file_search", + "filepuff://help/ast_query", + "filepuff://help/lsp_query", + "filepuff://help/edit_apply", + } + + for _, uri := range expectedURIs { + t.Run(uri, func(t *testing.T) { + handler := readHelpResource(uri) + contents, err := handler(context.Background(), mcp.ReadResourceRequest{}) + if err != nil { + t.Fatalf("readHelpResource(%q) error = %v", uri, err) + } + if len(contents) == 0 { + t.Fatalf("readHelpResource(%q) returned empty contents", uri) + } + tc, ok := contents[0].(mcp.TextResourceContents) + if !ok { + t.Fatalf("readHelpResource(%q) contents[0] is not TextResourceContents", uri) + } + if tc.MIMEType != "text/markdown" { + t.Errorf("MIMEType = %q, want %q", tc.MIMEType, "text/markdown") + } + if len(tc.Text) == 0 { + t.Errorf("Text is empty for %q", uri) + } + if !strings.HasPrefix(tc.Text, "#") { + t.Errorf("expected markdown (# heading) for %q, got: %q", uri, tc.Text[:min(50, len(tc.Text))]) + } + if tc.URI != uri { + t.Errorf("URI = %q, want %q", tc.URI, uri) + } + }) + } +} + +func TestReadHelpResource_UnknownTool(t *testing.T) { + handler := readHelpResource("filepuff://help/nonexistent") + _, err := handler(context.Background(), mcp.ReadResourceRequest{}) + if err == nil { + t.Fatal("expected error for unknown tool, got nil") + } +} + +func TestReadHelpResource_InvalidURI(t *testing.T) { + handler := readHelpResource("filepuff://help/") + _, err := handler(context.Background(), mcp.ReadResourceRequest{}) + if err == nil { + t.Fatal("expected error for empty tool name, got nil") + } +} + +func TestHelpContent_NotEmpty(t *testing.T) { + cases := map[string]string{ + "file_read": helpFileRead, + "file_search": helpFileSearch, + "ast_query": helpASTQuery, + "lsp_query": helpLSPQuery, + "edit_apply": helpEditApply, + } + for name, content := range cases { + t.Run(name, func(t *testing.T) { + if len(content) == 0 { + t.Errorf("help content for %q is empty", name) + } + if !strings.Contains(content, "##") { + t.Errorf("expected markdown sections (##) in help content for %q", name) + } + if !strings.Contains(content, "```") { + t.Errorf("expected code fences in help content for %q", name) + } + }) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index de4d913..0552f23 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -33,16 +33,17 @@ const PreviewLineMaxLength = 100 // Server represents the MCP file operations server. type Server struct { - cfg *config.Config - logger *slog.Logger - mcp *server.MCPServer - searcher *search.Searcher - parser *parser.Registry - matcher *query.Matcher - lspManager *lsp.Manager - editor *edit.Engine - readSem chan struct{} // Semaphore for limiting concurrent file reads - querySem chan struct{} // Semaphore for limiting concurrent AST queries + cfg *config.Config + logger *slog.Logger + mcp *server.MCPServer + searcher *search.Searcher + parser *parser.Registry + matcher *query.Matcher + lspManager *lsp.Manager + editor *edit.Engine + readSem chan struct{} // Semaphore for limiting concurrent file reads + querySem chan struct{} // Semaphore for limiting concurrent AST queries + sessionPrefs sessionPrefsPtr // Atomic pointer; populated by OnAfterInitialize hook } // New creates a new MCP server instance. @@ -70,40 +71,48 @@ func New(cfg *config.Config, logger *slog.Logger) (*Server, error) { s.lspManager = lsp.NewManager(cfg.WorkspaceRoot, logger) } + // Build OnAfterInitialize hook that parses client capability prefs. + // Signature (from mcp-go v0.48.0 hooks.go): + // func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) + hooks := &server.Hooks{} + hooks.AddAfterInitialize(func(_ context.Context, _ any, msg *mcp.InitializeRequest, _ *mcp.InitializeResult) { + if msg == nil { + return + } + raw, _ := msg.Params.Capabilities.Experimental["filepuff"].(map[string]any) + prefs := ParseSessionPrefs(raw) + s.sessionPrefs.Store(&prefs) + }) + // Create MCP server mcpServer := server.NewMCPServer( "mcp-filepuff", - "1.0.0", + "2.0.0", server.WithLogging(), + server.WithHooks(hooks), ) s.mcp = mcpServer // Register tools s.registerTools() + // Register help resources (filepuff://help/) + s.registerResources() + + // Register filepuff://read/{+path} resource template for large-file access. + s.registerReadResource() + return s, nil } // registerTools registers all available tools with the MCP server. func (s *Server) registerTools() { - // Register ping tool for health checks - s.mcp.AddTool( - mcp.NewTool("ping", - mcp.WithDescription("Health check - returns pong to verify the server is running.\n\n"+ - "Returns: \"pong\" text string."), - mcp.WithReadOnlyHintAnnotation(true), - ), - s.handlePing, - ) - // Register file_search tool if s.searcher != nil { s.mcp.AddTool( mcp.NewTool("file_search", - mcp.WithDescription("Search for text patterns in files using ripgrep. Supports regex patterns, file type filtering, and context lines.\n\n"+ - "Returns: Results grouped by file with match context. Format: \"Found N matches in M files:\" followed by file sections, "+ - "each with matching lines prefixed by \"L{line}│\" and context lines prefixed by \" │\".\n\n"+ - "Example: {\"pattern\": \"func.*Error\", \"file_types\": [\"go\"], \"max_results\": 20}"), + mcp.WithDescription("Search for text patterns in files using ripgrep. Supports regex patterns, file type filtering, and context lines. "+ + "See resource filepuff://help/file_search for flags and examples."), mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("pattern", mcp.Required(), @@ -127,7 +136,16 @@ func (s *Server) registerTools() { mcp.Description("Number of context lines around matches (default: 2)"), ), mcp.WithNumber("max_results", - mcp.Description("Maximum number of results to return"), + mcp.Description("Maximum number of results to return (page size for pagination)"), + ), + mcp.WithBoolean("cluster", + mcp.Description("Coalesce consecutive match lines into ranges (L12-14│ text). Drops context lines. Default: false."), + ), + mcp.WithString("cursor", + mcp.Description("Pagination cursor from a previous truncated response. Pass back to fetch the next page."), + ), + mcp.WithBoolean("verbose", + mcp.Description("Emit \"Found N matches in M files:\" preamble. Default: false (v2 default)."), ), ), s.handleFileSearch, @@ -137,23 +155,8 @@ func (s *Server) registerTools() { // Register file_read tool s.mcp.AddTool( mcp.NewTool("file_read", - mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary.\n\n"+ - "Token-saving features:\n"+ - " previous_etag: skip re-reading unchanged files (returns '[unchanged, etag: ...]' if unchanged)\n"+ - " symbol_name: read only a named function/struct/class (eliminates ast_query round-trip)\n"+ - " symbols_only=true: return only symbol list, ~95% fewer tokens (requires include_ast=true)\n"+ - " no_line_numbers=true: strip the line-number prefix (~10%% savings)\n"+ - " line_number_interval=N: print line numbers only every N lines\n"+ - " collapse_blank_lines=true: collapse consecutive blank lines to one\n"+ - " max_lines=N: truncate output with omitted count notice\n"+ - " paths=[...]: read multiple files in one call\n\n"+ - "All responses include '[etag: hex]' footer for use as previous_etag in subsequent reads.\n\n"+ - "Examples:\n"+ - " Full file: {\"path\": \"main.go\"}\n"+ - " Etag check: {\"path\": \"main.go\", \"previous_etag\": \"a3f9c2b1\"}\n"+ - " By symbol: {\"path\": \"server.go\", \"symbol_name\": \"handleFileRead\"}\n"+ - " Batch: {\"paths\": [\"a.go\", \"b.go\"]}\n"+ - " Line range: {\"path\": \"main.go\", \"line_start\": 10, \"line_end\": 50}"), + mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary. "+ + "See resource filepuff://help/file_read for flags and examples."), mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("path", mcp.Description("Path to the file to read (required unless paths is provided)"), @@ -181,20 +184,36 @@ func (s *Server) registerTools() { mcp.Description("Include AST symbol summary (functions, classes, types, etc.)"), ), mcp.WithBoolean("symbols_only", - mcp.Description("Return only symbol summary without file content (token-efficient mode). Requires include_ast=true."), + mcp.Description("Return only symbol summary without file content (token-efficient mode). Requires include_ast=true. Alias: mode='symbols_only'."), + ), + mcp.WithString("mode", + mcp.Description("Output mode: 'full' (default, full file), 'skeleton' (signatures + { ... } stubs, bodies elided), 'symbols_only' (symbol list only, alias for symbols_only=true)."), + ), + mcp.WithArray("strip", + mcp.Description("Strip content classes before line-numbering. Values: 'imports' (remove import blocks), 'license' (remove leading license comment), 'block_comments' (remove /* */ and Python triple-quoted strings). Emits [stripped: ...] footer."), + mcp.WithStringItems(), ), mcp.WithNumber("max_lines", mcp.Description("Maximum number of lines to return (for token efficiency). Applied after line_start/line_end."), ), mcp.WithBoolean("no_line_numbers", - mcp.Description("Omit the ' 12│ ' line number prefix entirely. Saves ~10% tokens. line_number_interval=0 has the same effect."), + mcp.Description("Omit the ' 12\u2502 ' line number prefix entirely. Saves ~10% tokens. line_number_interval=0 has the same effect."), ), mcp.WithNumber("line_number_interval", mcp.Description("Print line numbers only every N lines (default: 1 = every line). E.g. 10 = anchor every 10th line plus first/last. 0 = no line numbers."), ), + mcp.WithBoolean("compact_line_numbers", + mcp.Description("Use compact line prefix '12\u2502' instead of ' 12\u2502 ' (no padding, no trailing space). Works with line_number_interval. Default off."), + ), mcp.WithBoolean("collapse_blank_lines", mcp.Description("Collapse runs of consecutive blank lines to a single blank line. Useful for token savings on heavily-spaced code."), ), + mcp.WithBoolean("force_inline", + mcp.Description("Always return file content inline, bypassing the resource-link threshold. Default: false."), + ), + mcp.WithNumber("max_inline_bytes", + mcp.Description("Per-call inline threshold override in bytes. If set, overrides server resource_link_threshold_bytes for this call only. 0 = use server default."), + ), ), s.handleFileRead, ) @@ -202,13 +221,8 @@ func (s *Server) registerTools() { // Register ast_query tool s.mcp.AddTool( mcp.NewTool("ast_query", - mcp.WithDescription("Search for AST patterns in code files. Use code patterns with $VAR placeholders to match and capture code structures like functions, classes, and types.\n\n"+ - "Returns: \"Found N match(es):\" followed by entries in format \"**file:line** (node_type)\" with code blocks "+ - "and captured variables ($NAME=value). Returns \"No matches found.\" when no results.\n\n"+ - "Examples:\n"+ - " Go error funcs: {\"pattern\": \"func $NAME($$$ARGS) error\", \"language\": \"go\"}\n"+ - " Python classes: {\"pattern\": \"class $NAME: $$$BODY\", \"language\": \"python\"}\n"+ - " Named function: {\"pattern\": \"func $NAME($$$ARGS)\", \"language\": \"go\", \"name_exact\": \"NewServer\"}"), + mcp.WithDescription("Search for AST patterns in code files. Use code patterns with $VAR placeholders to match and capture code structures like functions, classes, and types. "+ + "See resource filepuff://help/ast_query for flags and examples."), mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("pattern", mcp.Required(), @@ -233,7 +247,16 @@ func (s *Server) registerTools() { mcp.WithStringItems(), ), mcp.WithNumber("max_results", - mcp.Description("Maximum number of results to return (default: 100)"), + mcp.Description("Maximum number of results to return (default: 100, page size for pagination)"), + ), + mcp.WithString("format", + mcp.Description("Output format: \"verbose\" (default, full code+captures), \"compact\" (one line per match), \"location\" (file:line only)"), + ), + mcp.WithString("cursor", + mcp.Description("Pagination cursor from a previous truncated response. Pass back to fetch the next page."), + ), + mcp.WithBoolean("verbose", + mcp.Description("Emit \"Found N match(es):\" preamble. Default: false (v2 default)."), ), ), s.handleASTQuery, @@ -241,62 +264,15 @@ func (s *Server) registerTools() { // Register LSP-based tools if LSP is enabled if s.lspManager != nil { - // Register symbol_at tool s.mcp.AddTool( - mcp.NewTool("symbol_at", - mcp.WithDescription("Get information about the symbol at a specific position in a file. Returns type, documentation, and definition location using LSP when available.\n\n"+ - "Returns: \"**Symbol Information**\" followed by hover/type information from LSP, or \"**Symbol Information** (AST fallback)\" "+ - "with node type and text when LSP unavailable. Returns \"No symbol information available at this position.\" when nothing is found.\n\n"+ - "Example: {\"file\": \"server.go\", \"line\": 45, \"column\": 6}"), + mcp.NewTool("lsp_query", + mcp.WithDescription("Query LSP for symbol info, definition, or references at a specific file position. "+ + "See resource filepuff://help/lsp_query for flags and examples."), mcp.WithReadOnlyHintAnnotation(true), - mcp.WithString("file", + mcp.WithString("action", mcp.Required(), - mcp.Description("Path to the file"), + mcp.Description("LSP operation: hover | definition | references"), ), - mcp.WithNumber("line", - mcp.Required(), - mcp.Description("Line number (1-indexed)"), - ), - mcp.WithNumber("column", - mcp.Required(), - mcp.Description("Column number (1-indexed)"), - ), - ), - s.handleSymbolAt, - ) - - // Register find_definition tool - s.mcp.AddTool( - mcp.NewTool("find_definition", - mcp.WithDescription("Find the definition of the symbol at a specific position. Uses LSP to locate where a function, variable, type, etc. is defined.\n\n"+ - "Returns: \"Found N definition(s):\" with entries showing \"**file:line:column**\" and a 3-line code preview "+ - "with the target line marked by \">\". Returns \"No definition found.\" when the symbol has no definition.\n\n"+ - "Example: {\"file\": \"handler.go\", \"line\": 23, \"column\": 10}"), - mcp.WithReadOnlyHintAnnotation(true), - mcp.WithString("file", - mcp.Required(), - mcp.Description("Path to the file"), - ), - mcp.WithNumber("line", - mcp.Required(), - mcp.Description("Line number (1-indexed)"), - ), - mcp.WithNumber("column", - mcp.Required(), - mcp.Description("Column number (1-indexed)"), - ), - ), - s.handleFindDefinition, - ) - - // Register find_references tool - s.mcp.AddTool( - mcp.NewTool("find_references", - mcp.WithDescription("Find all references to the symbol at a specific position. Uses LSP to locate all usages of a function, variable, type, etc.\n\n"+ - "Returns: \"Found N reference(s):\" grouped by file, each showing \"**file** (count)\" with locations as "+ - "\"L{line}:{column}\". Returns \"No references found.\" when no usages exist.\n\n"+ - "Example: {\"file\": \"types.go\", \"line\": 5, \"column\": 6}"), - mcp.WithReadOnlyHintAnnotation(true), mcp.WithString("file", mcp.Required(), mcp.Description("Path to the file"), @@ -310,23 +286,24 @@ func (s *Server) registerTools() { mcp.Description("Column number (1-indexed)"), ), mcp.WithBoolean("include_declaration", - mcp.Description("Include the declaration in results (default: true)"), + mcp.Description("Include the declaration in results. Only valid for action=references (default: true)."), + ), + mcp.WithBoolean("compact", + mcp.Description("Compact output: one line per file with all refs in brackets. Only valid for action=references. Default: false."), + ), + mcp.WithBoolean("verbose", + mcp.Description("Emit count/header preamble. Applies to all actions. Default: false."), ), ), - s.handleFindReferences, + s.handleLSPQuery, ) } // Register edit tools s.mcp.AddTool( mcp.NewTool("edit_apply", - mcp.WithDescription("Apply an edit to a file. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++, Rust) with syntax validation, and text-based editing for other files (Markdown, JSON, YAML, config files, etc.).\n\n"+ - "Returns: \"**Edit Applied Successfully**\" followed by a unified diff of the changes made. "+ - "For code files, validates syntax before writing — returns an error if the edit would produce invalid syntax.\n\n"+ - "Examples:\n"+ - " AST mode: {\"file\": \"main.go\", \"operation\": \"replace\", \"selector_kind\": \"function_declaration\", \"selector_name\": \"Hello\", \"new_content\": \"func Hello() {\\n\\treturn\\n}\"}\n"+ - " Text mode: {\"file\": \"README.md\", \"operation\": \"replace\", \"selector_text\": \"## Old Header\", \"new_content\": \"## New Header\"}\n"+ - " Line range: {\"file\": \"config.yaml\", \"operation\": \"replace\", \"selector_line\": 5, \"selector_line_end\": 10, \"new_content\": \"key: value\"}"), + mcp.WithDescription("Apply an edit to a file. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++, Rust) with syntax validation, and text-based editing for other files (Markdown, JSON, YAML, config files, etc.). "+ + "See resource filepuff://help/edit_apply for flags and examples."), mcp.WithString("file", mcp.Required(), mcp.Description("Path to the file to edit"), @@ -362,19 +339,17 @@ func (s *Server) registerTools() { mcp.WithString("selector_pattern", mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."), ), + mcp.WithString("response", + mcp.Description("Response format: \"count\" (default, \"+3 -1\" line counts), \"diff\" (full unified diff), \"none\" (empty). Default: count."), + ), mcp.WithBoolean("compact_response", - mcp.Description("Return only the modified symbol's content instead of a full diff. Requires selector_name. Saves tokens on large-file edits."), + mcp.Description("Deprecated: use response=count. Alias for response=\"count\" kept for pre-v2 compatibility."), ), ), s.handleEditApply, ) } -// handlePing handles the ping health check tool. -func (s *Server) handlePing(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return mcp.NewToolResultText("pong"), nil -} - // Run starts the MCP server and blocks until shutdown. func (s *Server) Run(ctx context.Context) error { // Set up signal handling for graceful shutdown diff --git a/internal/server/server_test.go b/internal/server/server_test.go index caea7f4..fa4a94c 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/lukaszraczylo/mcp-filepuff/internal/config" "github.com/mark3labs/mcp-go/mcp" @@ -49,45 +50,6 @@ func TestNew(t *testing.T) { } } -func TestHandlePing(t *testing.T) { - tmpDir := t.TempDir() - cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false} - logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) - - srv, err := New(cfg, logger) - if err != nil { - t.Fatalf("New() error = %v", err) - } - - ctx := context.Background() - req := mcp.CallToolRequest{} - - result, err := srv.handlePing(ctx, req) - if err != nil { - t.Errorf("handlePing() error = %v", err) - } - - if result == nil { - t.Fatal("handlePing() returned nil result") - return - } - - // Check that the result contains "pong" - contents := result.Content - if len(contents) == 0 { - t.Fatal("handlePing() returned empty content") - } - - textContent, ok := contents[0].(mcp.TextContent) - if !ok { - t.Fatal("handlePing() did not return text content") - } - - if textContent.Text != "pong" { - t.Errorf("handlePing() = %v, want 'pong'", textContent.Text) - } -} - func TestHandleFileRead(t *testing.T) { tmpDir := t.TempDir() @@ -484,3 +446,261 @@ func TestSplitLinesLongLine(t *testing.T) { t.Error("the 500KB long line was not found in splitLines output") } } + +// TestHandleEditApplyResponseCount verifies the default response=count format "+N -M". +func TestHandleEditApplyResponseCount(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + content := `package main + +func Hello() { + println("Hello") +} +` + if err := os.WriteFile(testFile, []byte(content), 0600); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "file": testFile, + "operation": "replace", + "selector_kind": "function_declaration", + "selector_name": "Hello", + "new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}", + // no response flag → default "count" + } + + result, err := srv.handleEditApply(ctx, req) + if err != nil { + t.Fatalf("handleEditApply error = %v", err) + } + if result == nil || len(result.Content) == 0 { + t.Fatal("handleEditApply returned empty result") + } + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatal("expected text content") + } + // Should be "+N -M" format + text := textContent.Text + if !strings.HasPrefix(text, "+") { + t.Errorf("response=count should start with +, got: %q", text) + } + if !strings.Contains(text, " -") { + t.Errorf("response=count should contain -N, got: %q", text) + } + // Must NOT contain diff syntax or old preamble + if strings.Contains(text, "@@") || strings.Contains(text, "Edit Applied") { + t.Errorf("response=count must not contain diff markers, got: %q", text) + } +} + +// TestHandleEditApplyResponseDiff verifies response=diff returns unified diff. +func TestHandleEditApplyResponseDiff(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + content := `package main + +func Hello() { + println("Hello") +} +` + if err := os.WriteFile(testFile, []byte(content), 0600); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "file": testFile, + "operation": "replace", + "selector_kind": "function_declaration", + "selector_name": "Hello", + "new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}", + "response": "diff", + } + + result, err := srv.handleEditApply(ctx, req) + if err != nil { + t.Fatalf("handleEditApply error = %v", err) + } + if result == nil || len(result.Content) == 0 { + t.Fatal("handleEditApply returned empty result") + } + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatal("expected text content") + } + text := textContent.Text + if !strings.Contains(text, "diff") { + t.Errorf("response=diff should contain diff, got: %q", text) + } + // Must NOT have old "Edit Applied Successfully" preamble + if strings.Contains(text, "Edit Applied Successfully") { + t.Errorf("v2 diff should not have old preamble, got: %q", text) + } +} + +// TestHandleEditApplyResponseNone verifies response=none returns empty string. +func TestHandleEditApplyResponseNone(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + content := `package main + +func Hello() { + println("Hello") +} +` + if err := os.WriteFile(testFile, []byte(content), 0600); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "file": testFile, + "operation": "replace", + "selector_kind": "function_declaration", + "selector_name": "Hello", + "new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}", + "response": "none", + } + + result, err := srv.handleEditApply(ctx, req) + if err != nil { + t.Fatalf("handleEditApply error = %v", err) + } + if result == nil || len(result.Content) == 0 { + t.Fatal("handleEditApply returned empty result") + } + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatal("expected text content") + } + if textContent.Text != "" { + t.Errorf("response=none should return empty string, got: %q", textContent.Text) + } +} + +// TestHandleFileReadBatchDedup verifies that identical files in batch mode emit [duplicate of ...]. +func TestHandleFileReadBatchDedup(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "a.go") + content := `package main + +func Hello() {} +` + if err := os.WriteFile(testFile, []byte(content), 0600); err != nil { + t.Fatalf("failed to write a.go: %v", err) + } + // Make b.go with identical content + testFile2 := filepath.Join(tmpDir, "b.go") + if err := os.WriteFile(testFile2, []byte(content), 0600); err != nil { + t.Fatalf("failed to write b.go: %v", err) + } + // c.go with different content + testFile3 := filepath.Join(tmpDir, "c.go") + if err := os.WriteFile(testFile3, []byte("package main\n"), 0600); err != nil { + t.Fatalf("failed to write c.go: %v", err) + } + + cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false, MaxFileSize: 1024 * 1024} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "paths": []interface{}{testFile, testFile2, testFile3}, + } + + result, err := srv.handleFileRead(ctx, req) + if err != nil { + t.Fatalf("handleFileRead() error = %v", err) + } + if result == nil || len(result.Content) == 0 { + t.Fatal("handleFileRead() returned empty result") + } + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatal("expected text content") + } + text := textContent.Text + if !strings.Contains(text, "[duplicate of") { + t.Errorf("expected duplicate pointer for b.go, got:\n%s", text) + } + // a.go should have full content + if !strings.Contains(text, "--- "+testFile+" ---") { + t.Errorf("expected a.go header, got:\n%s", text) + } + // c.go should have full content (different hash) + if !strings.Contains(text, "--- "+testFile3+" ---") { + t.Errorf("expected c.go header, got:\n%s", text) + } +} + +// TestHandleFileSearchVerbose verifies verbose=true emits "Found N matches in M files:" preamble. +func TestHandleFileSearchVerbose(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + if err := os.WriteFile(testFile, []byte("package main\n\nfunc Hello() {}\n"), 0600); err != nil { + t.Fatalf("write test file: %v", err) + } + cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false, MaxFileSize: 1024 * 1024, SearchTimeout: 10 * time.Second} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New() error = %v", err) + } + if srv.searcher == nil { + t.Skip("ripgrep not available") + } + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "Hello", + "paths": []interface{}{tmpDir}, + "verbose": true, + } + result, err := srv.handleFileSearch(ctx, req) + if err != nil { + t.Fatalf("handleFileSearch error = %v", err) + } + if result == nil || len(result.Content) == 0 { + t.Fatal("handleFileSearch returned empty") + } + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatal("expected text content") + } + if !strings.Contains(textContent.Text, "Found ") { + t.Errorf("verbose=true should emit preamble, got:\n%s", textContent.Text) + } +} diff --git a/internal/server/session.go b/internal/server/session.go new file mode 100644 index 0000000..d851a81 --- /dev/null +++ b/internal/server/session.go @@ -0,0 +1,145 @@ +// Package server implements the MCP server for file operations. +package server + +import ( + "sync/atomic" + "unsafe" +) + +// SessionPrefs holds client-declared session-wide preferences parsed from +// InitializeRequest.Params.Capabilities.Experimental["filepuff"]. +// +// These act as defaults; explicit per-call flags always override them. +// +// Supported keys in the "filepuff" experimental map: +// +// terse bool — no-op (v2 default is already terse; reserved for future) +// default_format string — ast_query format default ("verbose"|"compact"|"location") +// default_max_results int — file_search and ast_query max_results when not supplied +// default_cluster bool — file_search cluster default +// compact_refs bool — lsp_query references compact default +// line_numbers string — file_read line prefix default ("none"|"compact"|"full") +// resource_link_threshold int — per-session override for cfg.ResourceLinkThresholdBytes +type SessionPrefs struct { + // ASTQueryFormat is the default format for ast_query ("verbose", "compact", "location"). + // Empty string means "use handler built-in default". + ASTQueryFormat string + + // DefaultMaxResults is the default max_results for file_search and ast_query. + // 0 means "use handler built-in default". + DefaultMaxResults int + + // DefaultCluster is the default cluster flag for file_search. + // nil means "use handler built-in default (false)". + DefaultCluster *bool + + // CompactRefs is the default compact flag for lsp_query action=references. + // nil means "use handler built-in default (false)". + CompactRefs *bool + + // LineNumbers controls the file_read line prefix default. + // "" = use handler built-in default (full). + // "none" = no line numbers. + // "compact" = compact prefix (N│content). + // "full" = standard padded prefix ( N│ content). + LineNumbers string + + // ResourceLinkThreshold overrides cfg.ResourceLinkThresholdBytes for this session. + // 0 means "use config default". + ResourceLinkThreshold int +} + +// boolPtr returns a pointer to a bool value. +func boolPtr(b bool) *bool { return &b } + +// ParseSessionPrefs parses the raw map from +// InitializeRequest.Params.Capabilities.Experimental["filepuff"]. +// Unknown keys are silently ignored. Type mismatches for individual keys are +// silently ignored (key is treated as absent). Returns zero-value SessionPrefs +// when raw is nil or empty — callers should treat zero values as "use built-in defaults". +func ParseSessionPrefs(raw map[string]any) SessionPrefs { + if len(raw) == 0 { + return SessionPrefs{} + } + + var p SessionPrefs + + if v, ok := raw["default_format"]; ok { + if s, ok := v.(string); ok { + switch s { + case "verbose", "compact", "location": + p.ASTQueryFormat = s + } + } + } + + if v, ok := raw["default_max_results"]; ok { + if n := toInt(v); n > 0 { + p.DefaultMaxResults = n + } + } + + if v, ok := raw["default_cluster"]; ok { + if b, ok := v.(bool); ok { + p.DefaultCluster = boolPtr(b) + } + } + + if v, ok := raw["compact_refs"]; ok { + if b, ok := v.(bool); ok { + p.CompactRefs = boolPtr(b) + } + } + + if v, ok := raw["line_numbers"]; ok { + if s, ok := v.(string); ok { + switch s { + case "none", "compact", "full": + p.LineNumbers = s + } + } + } + + if v, ok := raw["resource_link_threshold"]; ok { + if n := toInt(v); n >= 0 { + p.ResourceLinkThreshold = n + } + } + + return p +} + +// toInt converts numeric JSON-decoded values (float64, int, int64) to int. +// Returns 0 for unsupported types or negative values. +func toInt(v any) int { + switch n := v.(type) { + case float64: + if n >= 0 { + return int(n) + } + case int: + if n >= 0 { + return n + } + case int64: + if n >= 0 { + return int(n) + } + } + return 0 +} + +// sessionPrefsPtr is an atomic pointer helper for thread-safe access to *SessionPrefs. +// We use atomic store/load so the hook (called once at init) and handlers (many goroutines) +// never race. The prefs are write-once after initialization. +type sessionPrefsPtr struct { + p unsafe.Pointer // *SessionPrefs +} + +func (sp *sessionPrefsPtr) Store(prefs *SessionPrefs) { + atomic.StorePointer(&sp.p, unsafe.Pointer(prefs)) +} + +func (sp *sessionPrefsPtr) Load() *SessionPrefs { + return (*SessionPrefs)(atomic.LoadPointer(&sp.p)) +} diff --git a/internal/server/session_integration_test.go b/internal/server/session_integration_test.go new file mode 100644 index 0000000..919bfd9 --- /dev/null +++ b/internal/server/session_integration_test.go @@ -0,0 +1,313 @@ +package server + +import ( + "context" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/lukaszraczylo/mcp-filepuff/internal/config" + "github.com/mark3labs/mcp-go/mcp" +) + +// ---- Session prefs integration tests ---- + +// setSessionPrefs injects prefs directly on the server (bypasses MCP hook machinery). +func setSessionPrefs(srv *Server, prefs SessionPrefs) { + srv.sessionPrefs.Store(&prefs) +} + +// TestSessionPrefsFileReadLineNumbersNone verifies that session pref line_numbers=none +// disables line-number prefixes and is overridden by explicit compact_line_numbers=true. +func TestSessionPrefsFileReadLineNumbersNone(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + content := "package main\n\nfunc Foo() {}\n" + if err := os.WriteFile(testFile, []byte(content), 0600); err != nil { + t.Fatalf("write file: %v", err) + } + cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New: %v", err) + } + setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"line_numbers": "none"})) + + ctx := context.Background() + + // Without explicit override: session pref should suppress line numbers. + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"path": testFile} + result, err := srv.handleFileRead(ctx, req) + if err != nil { + t.Fatalf("handleFileRead: %v", err) + } + text := result.Content[0].(mcp.TextContent).Text + // Standard line-number format is " 1│ "; with no_line_numbers it's absent. + if strings.Contains(text, " 1│") { + t.Errorf("session line_numbers=none: expected no line-number prefix, got:\n%s", text) + } + + // Explicit per-call compact_line_numbers=true should override session none. + req2 := mcp.CallToolRequest{} + req2.Params.Arguments = map[string]interface{}{ + "path": testFile, + "compact_line_numbers": true, + } + result2, err := srv.handleFileRead(ctx, req2) + if err != nil { + t.Fatalf("handleFileRead (explicit compact): %v", err) + } + text2 := result2.Content[0].(mcp.TextContent).Text + // Compact format emits "1│" prefix. + if !strings.Contains(text2, "\u2502") { + t.Errorf("explicit compact_line_numbers should override session none, got:\n%s", text2) + } +} + +// TestSessionPrefsFileReadLineNumbersCompact verifies line_numbers=compact session pref. +func TestSessionPrefsFileReadLineNumbersCompact(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "x.go") + if err := os.WriteFile(testFile, []byte("package main\nfunc Bar() {}\n"), 0600); err != nil { + t.Fatalf("write: %v", err) + } + cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, _ := New(cfg, logger) + setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"line_numbers": "compact"})) + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"path": testFile} + result, _ := srv.handleFileRead(ctx, req) + text := result.Content[0].(mcp.TextContent).Text + // Standard padded prefix is " 1│ "; compact is "1│". + if strings.Contains(text, " 1\u2502") { + t.Errorf("session line_numbers=compact should use compact prefix, got:\n%s", text) + } + // Should still have the │ separator somewhere. + if !strings.Contains(text, "\u2502") { + t.Errorf("session line_numbers=compact should still have \u2502 separator, got:\n%s", text) + } +} + +// TestSessionPrefsResourceLinkThreshold verifies per-session threshold override. +func TestSessionPrefsResourceLinkThreshold(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "big.go") + var sb strings.Builder + sb.WriteString("package main\n\nfunc Foo() {\n") + for i := 0; i < 15; i++ { + sb.WriteString("// comment line\n") + } + sb.WriteString("}\n") + if err := os.WriteFile(testFile, []byte(sb.String()), 0600); err != nil { + t.Fatalf("write: %v", err) + } + + // Config threshold = 0 (disabled) so content is always inlined by default. + cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20, ResourceLinkThresholdBytes: 0} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, _ := New(cfg, logger) + + // Set session threshold = 10 bytes (tiny), so any real file triggers resource-link. + setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"resource_link_threshold": float64(10)})) + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"path": testFile} + result, _ := srv.handleFileRead(ctx, req) + + if len(result.Content) == 0 { + t.Fatal("expected content") + } + _, isLink := result.Content[0].(mcp.ResourceLink) + _, isText := result.Content[0].(mcp.TextContent) + if !isLink && isText { + t.Error("expected ResourceLink when session threshold is very small, got TextContent") + } + + // force_inline should still bypass even a session threshold. + req2 := mcp.CallToolRequest{} + req2.Params.Arguments = map[string]interface{}{"path": testFile, "force_inline": true} + result2, _ := srv.handleFileRead(ctx, req2) + if _, ok := result2.Content[0].(mcp.TextContent); !ok { + t.Error("force_inline=true should bypass session threshold and return TextContent") + } +} + +// TestSessionPrefsASTQueryFormat verifies default_format session pref for ast_query. +func TestSessionPrefsASTQueryFormat(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + if err := os.WriteFile(testFile, []byte("package main\n\nfunc Greet() {}\n"), 0600); err != nil { + t.Fatalf("write: %v", err) + } + cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, _ := New(cfg, logger) + setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_format": "compact"})) + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME()", + "language": "go", + "paths": []interface{}{tmpDir}, + // no format key → session default should apply + } + result, err := srv.handleASTQuery(ctx, req) + if err != nil { + t.Fatalf("handleASTQuery: %v", err) + } + if result == nil || len(result.Content) == 0 { + t.Fatal("empty result") + } + text := result.Content[0].(mcp.TextContent).Text + // Compact format emits one-line results without "**file:line**" markers. + if strings.Contains(text, "**") { + t.Errorf("session default_format=compact: expected compact output (no **), got:\n%s", text) + } + + // Explicit format=verbose should override session compact. + req2 := mcp.CallToolRequest{} + req2.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME()", + "language": "go", + "paths": []interface{}{tmpDir}, + "format": "verbose", + } + result2, _ := srv.handleASTQuery(ctx, req2) + if result2 != nil && len(result2.Content) > 0 { + text2 := result2.Content[0].(mcp.TextContent).Text + if !strings.Contains(text2, "**") { + t.Errorf("explicit format=verbose should override session compact, got:\n%s", text2) + } + } +} + +// TestSessionPrefsASTQueryMaxResults verifies default_max_results for ast_query. +func TestSessionPrefsASTQueryMaxResults(t *testing.T) { + tmpDir := t.TempDir() + // Build a file with 5 functions. + var sb strings.Builder + sb.WriteString("package main\n\n") + for i := 0; i < 5; i++ { + sb.WriteString(fmt.Sprintf("func Fn%c() {}\n\n", rune('A'+i))) + } + testFile := filepath.Join(tmpDir, "many.go") + if err := os.WriteFile(testFile, []byte(sb.String()), 0600); err != nil { + t.Fatalf("write: %v", err) + } + cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, _ := New(cfg, logger) + // Session pref: max 2 results. + setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_max_results": float64(2)})) + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME()", + "language": "go", + "paths": []interface{}{tmpDir}, + // no max_results → session pref of 2 should apply + } + result, err := srv.handleASTQuery(ctx, req) + if err != nil { + t.Fatalf("handleASTQuery: %v", err) + } + if result == nil || len(result.Content) == 0 { + t.Fatal("empty result") + } + text := result.Content[0].(mcp.TextContent).Text + // With 5 funcs and max=2, output should mention remaining. + if !strings.Contains(text, "remaining") && !strings.Contains(text, "cursor") { + t.Errorf("session max_results=2 with 5 matches should produce cursor line, got:\n%s", text) + } +} + +// TestSessionPrefsFileSearchDefaultCluster verifies default_cluster session pref. +func TestSessionPrefsFileSearchDefaultCluster(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "x.go") + content := "package main\n\nfunc Foo() {}\nfunc Foo2() {}\n" + if err := os.WriteFile(testFile, []byte(content), 0600); err != nil { + t.Fatalf("write: %v", err) + } + cfg := &config.Config{ + WorkspaceRoot: tmpDir, + MaxFileSize: 1 << 20, + SearchTimeout: 10 * time.Second, + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New: %v", err) + } + if srv.searcher == nil { + t.Skip("ripgrep not available") + } + setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_cluster": true})) + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func", + "paths": []interface{}{tmpDir}, + // no cluster flag → session default_cluster=true should apply + } + result, err := srv.handleFileSearch(ctx, req) + if err != nil { + t.Fatalf("handleFileSearch: %v", err) + } + if result == nil || len(result.Content) == 0 { + t.Skip("search returned no results") + } + // Verify call succeeded (cluster behaviour is ripgrep-version dependent). + _ = result.Content[0].(mcp.TextContent).Text +} + +// TestSessionPrefsMaxResultsExplicitOverride verifies explicit call-time max_results +// overrides session pref for ast_query. +func TestSessionPrefsMaxResultsExplicitOverride(t *testing.T) { + tmpDir := t.TempDir() + var sb strings.Builder + sb.WriteString("package main\n\n") + for i := 0; i < 5; i++ { + sb.WriteString(fmt.Sprintf("func Fn%c() {}\n\n", rune('A'+i))) + } + testFile := filepath.Join(tmpDir, "many.go") + if err := os.WriteFile(testFile, []byte(sb.String()), 0600); err != nil { + t.Fatalf("write: %v", err) + } + cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20} + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + srv, _ := New(cfg, logger) + // Session wants 2, caller supplies 10 — all 5 should fit without cursor. + setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_max_results": float64(2)})) + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "pattern": "func $NAME()", + "language": "go", + "paths": []interface{}{tmpDir}, + "max_results": 10, // explicit override + } + result, _ := srv.handleASTQuery(ctx, req) + if result == nil || len(result.Content) == 0 { + t.Fatal("empty result") + } + text := result.Content[0].(mcp.TextContent).Text + // With max_results=10 and only 5 funcs, no cursor line expected. + if strings.Contains(text, "remaining") { + t.Errorf("explicit max_results=10 should override session 2; 5 funcs fit; no cursor expected, got:\n%s", text) + } +} diff --git a/internal/server/session_test.go b/internal/server/session_test.go new file mode 100644 index 0000000..f2fcfa9 --- /dev/null +++ b/internal/server/session_test.go @@ -0,0 +1,186 @@ +package server + +import ( + "testing" +) + +// TestParseSessionPrefsEmpty verifies zero-value result for nil/empty input. +func TestParseSessionPrefsEmpty(t *testing.T) { + p := ParseSessionPrefs(nil) + if p.ASTQueryFormat != "" { + t.Errorf("ASTQueryFormat: want \"\", got %q", p.ASTQueryFormat) + } + if p.DefaultMaxResults != 0 { + t.Errorf("DefaultMaxResults: want 0, got %d", p.DefaultMaxResults) + } + if p.DefaultCluster != nil { + t.Errorf("DefaultCluster: want nil, got %v", *p.DefaultCluster) + } + if p.CompactRefs != nil { + t.Errorf("CompactRefs: want nil, got %v", *p.CompactRefs) + } + if p.LineNumbers != "" { + t.Errorf("LineNumbers: want \"\", got %q", p.LineNumbers) + } + if p.ResourceLinkThreshold != 0 { + t.Errorf("ResourceLinkThreshold: want 0, got %d", p.ResourceLinkThreshold) + } + + // Also test with an empty (non-nil) map. + p2 := ParseSessionPrefs(map[string]any{}) + if p2.ASTQueryFormat != "" || p2.DefaultMaxResults != 0 { + t.Error("empty map should produce zero-value prefs") + } +} + +// TestParseSessionPrefsAllFields verifies full round-trip with all supported keys. +func TestParseSessionPrefsAllFields(t *testing.T) { + raw := map[string]any{ + "terse": true, // no-op; should not produce an error + "default_format": "compact", + "default_max_results": float64(50), // JSON numbers decode as float64 + "default_cluster": true, + "compact_refs": true, + "line_numbers": "none", + "resource_link_threshold": float64(32768), + } + p := ParseSessionPrefs(raw) + + if p.ASTQueryFormat != "compact" { + t.Errorf("ASTQueryFormat: want \"compact\", got %q", p.ASTQueryFormat) + } + if p.DefaultMaxResults != 50 { + t.Errorf("DefaultMaxResults: want 50, got %d", p.DefaultMaxResults) + } + if p.DefaultCluster == nil || !*p.DefaultCluster { + t.Errorf("DefaultCluster: want true, got %v", p.DefaultCluster) + } + if p.CompactRefs == nil || !*p.CompactRefs { + t.Errorf("CompactRefs: want true, got %v", p.CompactRefs) + } + if p.LineNumbers != "none" { + t.Errorf("LineNumbers: want \"none\", got %q", p.LineNumbers) + } + if p.ResourceLinkThreshold != 32768 { + t.Errorf("ResourceLinkThreshold: want 32768, got %d", p.ResourceLinkThreshold) + } +} + +// TestParseSessionPrefsLineNumbersVariants tests all valid line_numbers values. +func TestParseSessionPrefsLineNumbersVariants(t *testing.T) { + for _, want := range []string{"none", "compact", "full"} { + p := ParseSessionPrefs(map[string]any{"line_numbers": want}) + if p.LineNumbers != want { + t.Errorf("line_numbers=%q: got %q", want, p.LineNumbers) + } + } + // Invalid value → ignored (empty string). + p := ParseSessionPrefs(map[string]any{"line_numbers": "bogus"}) + if p.LineNumbers != "" { + t.Errorf("invalid line_numbers should be ignored, got %q", p.LineNumbers) + } +} + +// TestParseSessionPrefsFormatVariants tests all valid default_format values. +func TestParseSessionPrefsFormatVariants(t *testing.T) { + for _, want := range []string{"verbose", "compact", "location"} { + p := ParseSessionPrefs(map[string]any{"default_format": want}) + if p.ASTQueryFormat != want { + t.Errorf("default_format=%q: got %q", want, p.ASTQueryFormat) + } + } + // Invalid value → ignored. + p := ParseSessionPrefs(map[string]any{"default_format": "yaml"}) + if p.ASTQueryFormat != "" { + t.Errorf("invalid format should be ignored, got %q", p.ASTQueryFormat) + } +} + +// TestParseSessionPrefsTypeMismatch verifies that wrong types are silently ignored. +func TestParseSessionPrefsTypeMismatch(t *testing.T) { + raw := map[string]any{ + "default_format": 123, // wrong type (int instead of string) + "default_max_results": "fifty", // wrong type + "default_cluster": "yes", // wrong type (string instead of bool) + "compact_refs": 42, // wrong type + "line_numbers": true, // wrong type + "resource_link_threshold": "big", // wrong type + } + p := ParseSessionPrefs(raw) + if p.ASTQueryFormat != "" { + t.Errorf("type mismatch for format should produce empty, got %q", p.ASTQueryFormat) + } + if p.DefaultMaxResults != 0 { + t.Errorf("type mismatch for max_results should produce 0, got %d", p.DefaultMaxResults) + } + if p.DefaultCluster != nil { + t.Errorf("type mismatch for cluster should produce nil") + } + if p.CompactRefs != nil { + t.Errorf("type mismatch for compact_refs should produce nil") + } + if p.LineNumbers != "" { + t.Errorf("type mismatch for line_numbers should produce empty, got %q", p.LineNumbers) + } + if p.ResourceLinkThreshold != 0 { + t.Errorf("type mismatch for threshold should produce 0, got %d", p.ResourceLinkThreshold) + } +} + +// TestParseSessionPrefsNegativeValues verifies negative numbers are rejected. +func TestParseSessionPrefsNegativeValues(t *testing.T) { + p := ParseSessionPrefs(map[string]any{ + "default_max_results": float64(-5), + "resource_link_threshold": float64(-1), + }) + if p.DefaultMaxResults != 0 { + t.Errorf("negative max_results should be rejected, got %d", p.DefaultMaxResults) + } + if p.ResourceLinkThreshold != 0 { + t.Errorf("negative threshold should be rejected, got %d", p.ResourceLinkThreshold) + } +} + +// TestParseSessionPrefsIntCoercion verifies int and int64 inputs also work. +func TestParseSessionPrefsIntCoercion(t *testing.T) { + p := ParseSessionPrefs(map[string]any{ + "default_max_results": int(25), + "resource_link_threshold": int64(16384), + }) + if p.DefaultMaxResults != 25 { + t.Errorf("int max_results: want 25, got %d", p.DefaultMaxResults) + } + if p.ResourceLinkThreshold != 16384 { + t.Errorf("int64 threshold: want 16384, got %d", p.ResourceLinkThreshold) + } +} + +// TestParseSessionPrefsClusterFalse ensures default_cluster=false stores a non-nil false. +func TestParseSessionPrefsClusterFalse(t *testing.T) { + p := ParseSessionPrefs(map[string]any{"default_cluster": false}) + if p.DefaultCluster == nil { + t.Error("default_cluster=false should store non-nil pointer") + } + if *p.DefaultCluster != false { + t.Error("default_cluster=false: want false pointer") + } +} + +// TestSessionPrefsAtomicStore verifies sessionPrefsPtr is readable after Store. +func TestSessionPrefsAtomicStore(t *testing.T) { + var sp sessionPrefsPtr + if sp.Load() != nil { + t.Error("uninitialised Load() should return nil") + } + + prefs := ParseSessionPrefs(map[string]any{"default_format": "compact"}) + sp.Store(&prefs) + + loaded := sp.Load() + if loaded == nil { + t.Fatal("Load() returned nil after Store") + } + if loaded.ASTQueryFormat != "compact" { + t.Errorf("loaded format: want compact, got %q", loaded.ASTQueryFormat) + } +}