V2/token optimization (#11)

* v2.0: token-optimization overhaul

Additive (backward-compatible flags):
- file_read: skeleton mode, strip (imports/license/block_comments),
  compact_line_numbers, 8-char etag with prefix-match compat
- ast_query: format=verbose|compact|location, pagination cursor
- file_search: cluster mode, pagination cursor
- lsp_query (references): compact output

Breaking (v2):
- Preambles removed; opt-in verbose=true restores
- edit_apply: response=count|diff|none, default count
- ping tool removed
- symbol_at/find_definition/find_references merged into lsp_query
- Tool descriptions trimmed -83%, help moved to filepuff://help/<tool>
- Batch file_read dedups by etag

Protocol:
- ResourceLink returned for file_read >64 KiB (force_inline override)
- OnAfterInitialize hook reads capabilities.experimental.filepuff
  for session defaults (default_format, default_max_results,
  default_cluster, compact_refs, line_numbers,
  resource_link_threshold)

* fix: drop --max-total-count from ripgrep args

The flag does not exist in stable ripgrep (confirmed up to 15.1.0 --
"unrecognized flag --max-total-count, similar flags that are
available: --max-count"). Every file_search call failed on hosts with
stock rg. --max-count is per-file, not a drop-in replacement, so rely
on the in-process truncation in parseOutput that was already the
documented safety net.
This commit is contained in:
2026-04-19 19:56:49 +01:00
committed by GitHub
parent b131c1edd3
commit 5ad975ee7a
26 changed files with 4909 additions and 507 deletions
+31 -28
View File
@@ -13,43 +13,46 @@ import (
// Config holds all configuration options for the MCP server.
type Config struct {
Formatters map[string]string `json:"formatters"`
WorkspaceRoot string `json:"workspace_root"`
LSPTimeout time.Duration `json:"lsp_timeout"`
SearchTimeout time.Duration `json:"search_timeout"`
MaxFileSize int64 `json:"max_file_size"`
MaxParseSize int64 `json:"max_parse_size"`
MaxSearchResults int `json:"max_search_results"`
MaxEditSize int64 `json:"max_edit_size"`
EnableLSP bool `json:"enable_lsp"`
FollowSymlinks bool `json:"follow_symlinks"`
RespectGitignore bool `json:"respect_gitignore"`
Formatters map[string]string `json:"formatters"`
WorkspaceRoot string `json:"workspace_root"`
LSPTimeout time.Duration `json:"lsp_timeout"`
SearchTimeout time.Duration `json:"search_timeout"`
MaxFileSize int64 `json:"max_file_size"`
MaxParseSize int64 `json:"max_parse_size"`
MaxSearchResults int `json:"max_search_results"`
MaxEditSize int64 `json:"max_edit_size"`
EnableLSP bool `json:"enable_lsp"`
FollowSymlinks bool `json:"follow_symlinks"`
RespectGitignore bool `json:"respect_gitignore"`
ResourceLinkThresholdBytes int `json:"resource_link_threshold_bytes"`
}
// Default values for configuration.
const (
DefaultLSPTimeout = 5 * time.Minute
DefaultSearchTimeout = 30 * time.Second
DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB
DefaultMaxParseSize = 10 * 1024 * 1024 // 10 MB
DefaultMaxSearchResults = 1000
DefaultMaxEditSize = 100 * 1024 // 100 KB
DefaultLSPTimeout = 5 * time.Minute
DefaultSearchTimeout = 30 * time.Second
DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB
DefaultMaxParseSize = 10 * 1024 * 1024 // 10 MB
DefaultMaxSearchResults = 1000
DefaultMaxEditSize = 100 * 1024 // 100 KB
DefaultResourceLinkThresholdBytes = 64 * 1024 // 64 KiB
)
// Default returns a Config with default values.
func Default() *Config {
return &Config{
WorkspaceRoot: ".",
LSPTimeout: DefaultLSPTimeout,
SearchTimeout: DefaultSearchTimeout,
MaxFileSize: DefaultMaxFileSize,
MaxParseSize: DefaultMaxParseSize,
MaxSearchResults: DefaultMaxSearchResults,
MaxEditSize: DefaultMaxEditSize,
EnableLSP: true,
Formatters: make(map[string]string),
FollowSymlinks: true,
RespectGitignore: true,
WorkspaceRoot: ".",
LSPTimeout: DefaultLSPTimeout,
SearchTimeout: DefaultSearchTimeout,
MaxFileSize: DefaultMaxFileSize,
MaxParseSize: DefaultMaxParseSize,
MaxSearchResults: DefaultMaxSearchResults,
MaxEditSize: DefaultMaxEditSize,
EnableLSP: true,
Formatters: make(map[string]string),
FollowSymlinks: true,
RespectGitignore: true,
ResourceLinkThresholdBytes: DefaultResourceLinkThresholdBytes,
}
}
+65
View File
@@ -0,0 +1,65 @@
// Package cursor implements opaque pagination cursors for MCP tools.
// A cursor encodes an offset into a result stream plus a query hash so stale
// cursors from different queries fail cleanly.
//
// Encoding: base64url(json({"offset":N,"query_hash":"hex"}))
// The query_hash is a hex-encoded sha256 over the deterministic query params.
package cursor
import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"sort"
"strings"
json "github.com/goccy/go-json"
)
// payload is the JSON structure inside a cursor.
type payload struct {
Offset int `json:"offset"`
QueryHash string `json:"query_hash"`
}
// Encode creates an opaque cursor string from an offset and query hash.
func Encode(offset int, queryHash string) string {
p := payload{Offset: offset, QueryHash: queryHash}
b, _ := json.Marshal(p)
return base64.RawURLEncoding.EncodeToString(b)
}
// Decode parses a cursor string. Returns offset, queryHash, error.
func Decode(cursor string) (int, string, error) {
b, err := base64.RawURLEncoding.DecodeString(cursor)
if err != nil {
return 0, "", fmt.Errorf("invalid cursor encoding: %w", err)
}
var p payload
if err := json.Unmarshal(b, &p); err != nil {
return 0, "", fmt.Errorf("invalid cursor payload: %w", err)
}
return p.Offset, p.QueryHash, nil
}
// HashParams computes a deterministic query hash from a set of key=value params.
// Keys are sorted before hashing so order doesn't matter.
func HashParams(params map[string]string) string {
keys := make([]string, 0, len(params))
for k := range params {
keys = append(keys, k)
}
sort.Strings(keys)
var sb strings.Builder
for _, k := range keys {
sb.WriteString(k)
sb.WriteByte('=')
sb.WriteString(params[k])
sb.WriteByte('\n')
}
sum := sha256.Sum256([]byte(sb.String()))
return hex.EncodeToString(sum[:])
}
+57
View File
@@ -0,0 +1,57 @@
package cursor
import (
"testing"
)
func TestEncodeDecodRoundTrip(t *testing.T) {
hash := HashParams(map[string]string{"a": "1", "b": "2"})
encoded := Encode(42, hash)
if encoded == "" {
t.Fatal("Encode returned empty string")
}
offset, gotHash, err := Decode(encoded)
if err != nil {
t.Fatalf("Decode error: %v", err)
}
if offset != 42 {
t.Errorf("offset: got %d, want 42", offset)
}
if gotHash != hash {
t.Errorf("hash mismatch: got %s, want %s", gotHash, hash)
}
}
func TestDecodeInvalid(t *testing.T) {
_, _, err := Decode("!!!notbase64!!!")
if err == nil {
t.Error("expected error for invalid base64, got nil")
}
}
func TestDecodeCorruptPayload(t *testing.T) {
import64 := "bm90anNvbg" // "notjson" in base64
_, _, err := Decode(import64)
if err == nil {
t.Error("expected error for corrupt payload, got nil")
}
}
func TestHashParamsDeterministic(t *testing.T) {
// Same params regardless of insertion order
h1 := HashParams(map[string]string{"z": "last", "a": "first"})
h2 := HashParams(map[string]string{"a": "first", "z": "last"})
if h1 != h2 {
t.Errorf("hash not deterministic: %s != %s", h1, h2)
}
}
func TestHashParamsDifferentForDifferentQueries(t *testing.T) {
h1 := HashParams(map[string]string{"pattern": "foo"})
h2 := HashParams(map[string]string{"pattern": "bar"})
if h1 == h2 {
t.Error("different queries should produce different hashes")
}
}
+345
View File
@@ -0,0 +1,345 @@
// Package parser provides skeleton rendering for source files.
package parser
import (
"context"
"fmt"
"strings"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// SkeletonFile returns a skeleton representation of the file: top-level declarations
// with signatures and doc-comments intact, but function/method bodies replaced with
// a language-appropriate placeholder.
//
// Supported: go, typescript, javascript, python, rust.
// Other languages fall back to symbols_only (AST summary text).
// Returns (skeletonText, isFullSkeleton, error).
func SkeletonFile(ctx context.Context, reg *Registry, filename string, content []byte) (string, bool, error) {
result, err := reg.Parse(ctx, filename, content)
if err != nil {
return "", false, err
}
lang := protocol.DetectLanguage(filename)
switch lang {
case protocol.LangGo:
return skeletonGo(result.Tree, content), true, nil
case protocol.LangTypeScript, protocol.LangJavaScript:
return skeletonTS(result.Tree, content), true, nil
case protocol.LangPython:
return skeletonPython(result.Tree, content), true, nil
case protocol.LangRust:
return skeletonRust(result.Tree, content), true, nil
default:
// TODO: skeleton for c, cpp, elixir, html, vue — fall back to symbols_only
syms := ExtractSymbols(result.Tree, content, lang, filename)
return renderSymbolsOnly(syms, filename, lang, content), false, nil
}
}
// renderSymbolsOnly renders a simple symbol list (fallback for unsupported languages).
func renderSymbolsOnly(syms []protocol.Symbol, _ string, lang protocol.Language, content []byte) string {
lines := strings.Split(string(content), "\n")
var sb strings.Builder
sb.WriteString(fmt.Sprintf("// skeleton unavailable for %s — symbol list only\n", lang))
for _, s := range syms {
sb.WriteString(fmt.Sprintf("// %s %s (line %d)\n", s.Kind, s.Name, s.Location.Line))
if s.Doc != "" {
sb.WriteString(fmt.Sprintf("// doc: %s\n", s.Doc))
}
if s.Location.Line >= 1 && s.Location.Line <= len(lines) {
sb.WriteString(lines[s.Location.Line-1] + " { ... }\n")
}
}
return sb.String()
}
// ---- Go skeleton ----
// skeletonGoBodyNodes lists Go node types that have a body field to replace.
var skeletonGoBodyNodes = map[string]string{
"function_declaration": "body",
"method_declaration": "body",
}
func skeletonGo(tree *sitter.Tree, content []byte) string {
if tree == nil {
return string(content)
}
root := tree.RootNode()
var sb strings.Builder
skeletonGoNode(root, content, &sb)
return sb.String()
}
func skeletonGoNode(node *sitter.Node, content []byte, sb *strings.Builder) {
if node == nil {
return
}
nodeType := node.Type()
if bodyField, ok := skeletonGoBodyNodes[nodeType]; ok {
body := node.ChildByFieldName(bodyField)
if body != nil {
sig := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
sb.WriteString(sig)
sb.WriteString("{ ... }\n\n")
return
}
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
sb.WriteString("\n")
return
}
switch nodeType {
case "source_file":
for i := 0; i < int(node.ChildCount()); i++ {
child := node.Child(i)
if child == nil {
continue
}
skeletonGoNode(child, content, sb)
}
return
case "comment":
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
sb.WriteString("\n")
return
default:
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
sb.WriteString("\n")
}
}
// ---- TypeScript / JavaScript skeleton ----
func skeletonTS(tree *sitter.Tree, content []byte) string {
if tree == nil {
return string(content)
}
root := tree.RootNode()
var sb strings.Builder
skeletonTSNode(root, content, &sb)
return sb.String()
}
func skeletonTSNode(node *sitter.Node, content []byte, sb *strings.Builder) {
if node == nil {
return
}
nodeType := node.Type()
switch nodeType {
case "program":
for i := 0; i < int(node.ChildCount()); i++ {
child := node.Child(i)
if child == nil {
continue
}
skeletonTSNode(child, content, sb)
}
return
case "function_declaration":
body := node.ChildByFieldName("body")
if body != nil {
sig := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
sb.WriteString(sig)
sb.WriteString("{ ... }\n\n")
return
}
case "class_declaration":
body := node.ChildByFieldName("body")
if body != nil {
header := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
sb.WriteString(header)
sb.WriteString("{\n")
for i := 0; i < int(body.ChildCount()); i++ {
child := body.Child(i)
if child == nil {
continue
}
if child.Type() == "method_definition" {
methBody := child.ChildByFieldName("body")
if methBody != nil {
methSig := strings.TrimRight(string(content[child.StartByte():methBody.StartByte()]), " \t")
sb.WriteString(" ")
sb.WriteString(methSig)
sb.WriteString("{ ... }\n")
continue
}
}
sb.WriteString(" ")
sb.WriteString(string(content[child.StartByte():child.EndByte()]))
sb.WriteString("\n")
}
sb.WriteString("}\n\n")
return
}
case "comment":
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
sb.WriteString("\n")
return
}
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
sb.WriteString("\n")
}
// ---- Python skeleton ----
func skeletonPython(tree *sitter.Tree, content []byte) string {
if tree == nil {
return string(content)
}
root := tree.RootNode()
var sb strings.Builder
skeletonPythonNode(root, content, &sb, "")
return sb.String()
}
func skeletonPythonNode(node *sitter.Node, content []byte, sb *strings.Builder, indent string) {
if node == nil {
return
}
nodeType := node.Type()
switch nodeType {
case "module":
for i := 0; i < int(node.ChildCount()); i++ {
child := node.Child(i)
if child == nil {
continue
}
skeletonPythonNode(child, content, sb, indent)
}
return
case "function_definition", "decorated_definition":
nodeText := string(content[node.StartByte():node.EndByte()])
lines := strings.SplitN(nodeText, "\n", 2)
sb.WriteString(indent)
sb.WriteString(lines[0])
sb.WriteString("\n")
sb.WriteString(indent)
sb.WriteString(" ...\n\n")
return
case "class_definition":
body := node.ChildByFieldName("body")
if body != nil {
header := string(content[node.StartByte():body.StartByte()])
firstLine := strings.SplitN(header, "\n", 2)[0]
sb.WriteString(indent)
sb.WriteString(firstLine)
sb.WriteString("\n")
for i := 0; i < int(body.ChildCount()); i++ {
child := body.Child(i)
if child == nil {
continue
}
if child.Type() == "function_definition" || child.Type() == "decorated_definition" {
childText := string(content[child.StartByte():child.EndByte()])
childLines := strings.SplitN(childText, "\n", 2)
sb.WriteString(indent + " ")
sb.WriteString(childLines[0])
sb.WriteString("\n")
sb.WriteString(indent + " ...")
sb.WriteString("\n")
continue
}
if child.Type() == "expression_statement" {
sb.WriteString(indent + " ")
sb.WriteString(string(content[child.StartByte():child.EndByte()]))
sb.WriteString("\n")
continue
}
}
sb.WriteString("\n")
return
}
case "comment":
sb.WriteString(indent)
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
sb.WriteString("\n")
return
}
sb.WriteString(indent)
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
sb.WriteString("\n")
}
// ---- Rust skeleton ----
func skeletonRust(tree *sitter.Tree, content []byte) string {
if tree == nil {
return string(content)
}
root := tree.RootNode()
var sb strings.Builder
skeletonRustNode(root, content, &sb)
return sb.String()
}
func skeletonRustNode(node *sitter.Node, content []byte, sb *strings.Builder) {
if node == nil {
return
}
nodeType := node.Type()
switch nodeType {
case "source_file":
for i := 0; i < int(node.ChildCount()); i++ {
child := node.Child(i)
if child == nil {
continue
}
skeletonRustNode(child, content, sb)
}
return
case "function_item":
body := node.ChildByFieldName("body")
if body != nil {
sig := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
sb.WriteString(sig)
sb.WriteString("{ ... }\n\n")
return
}
case "impl_item":
body := node.ChildByFieldName("body")
if body != nil {
header := strings.TrimRight(string(content[node.StartByte():body.StartByte()]), " \t")
sb.WriteString(header)
sb.WriteString("{\n")
for i := 0; i < int(body.ChildCount()); i++ {
child := body.Child(i)
if child == nil {
continue
}
if child.Type() == "function_item" {
methBody := child.ChildByFieldName("body")
if methBody != nil {
methSig := strings.TrimRight(string(content[child.StartByte():methBody.StartByte()]), " \t")
sb.WriteString(" ")
sb.WriteString(methSig)
sb.WriteString("{ ... }\n")
continue
}
}
sb.WriteString(" ")
sb.WriteString(string(content[child.StartByte():child.EndByte()]))
sb.WriteString("\n")
}
sb.WriteString("}\n\n")
return
}
case "line_comment", "block_comment":
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
sb.WriteString("\n")
return
}
sb.WriteString(string(content[node.StartByte():node.EndByte()]))
sb.WriteString("\n")
}
+299
View File
@@ -0,0 +1,299 @@
package parser
import (
"strings"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// StripFlag names the categories of content to remove.
type StripFlag string
const (
StripImports StripFlag = "imports"
StripLicense StripFlag = "license"
StripBlockComments StripFlag = "block_comments"
)
// StripResult holds the stripped content and which flags actually removed content.
type StripResult struct {
Content string
Stripped []StripFlag
}
// StripContent applies requested strip operations to content, in order:
// license → imports → block_comments.
// lang is used to pick language-specific heuristics.
func StripContent(content string, flags []StripFlag, lang protocol.Language) StripResult {
flagSet := make(map[StripFlag]bool, len(flags))
for _, f := range flags {
flagSet[f] = true
}
var stripped []StripFlag
if flagSet[StripLicense] {
next, removed := stripLicense(content)
if removed {
content = next
stripped = append(stripped, StripLicense)
}
}
if flagSet[StripImports] {
next, removed := stripImports(content, lang)
if removed {
content = next
stripped = append(stripped, StripImports)
}
}
if flagSet[StripBlockComments] {
next, removed := stripBlockComments(content, lang)
if removed {
content = next
stripped = append(stripped, StripBlockComments)
}
}
return StripResult{Content: content, Stripped: stripped}
}
// stripLicense removes a leading block comment that looks like a license header.
// A comment qualifies if it contains "copyright", "license", or "spdx-license-identifier" (case-insensitive).
func stripLicense(content string) (string, bool) {
trimmed := strings.TrimLeft(content, " \t\n\r")
// C-style block comment at top
if strings.HasPrefix(trimmed, "/*") {
end := strings.Index(trimmed, "*/")
if end >= 0 {
candidate := trimmed[:end+2]
lower := strings.ToLower(candidate)
if strings.Contains(lower, "copyright") ||
strings.Contains(lower, "license") ||
strings.Contains(lower, "spdx-license-identifier") {
rest := trimmed[end+2:]
// Consume trailing newline(s)
rest = strings.TrimLeft(rest, "\r\n")
return rest, true
}
}
}
// Python/hash-style leading comment block
if strings.HasPrefix(trimmed, "#") {
lines := strings.Split(trimmed, "\n")
var commentLines []string
var rest []string
inComment := true
for i, l := range lines {
if inComment && (strings.HasPrefix(l, "#") || strings.TrimSpace(l) == "") {
commentLines = append(commentLines, l)
} else {
rest = lines[i:]
break
}
}
block := strings.Join(commentLines, "\n")
lower := strings.ToLower(block)
if strings.Contains(lower, "copyright") ||
strings.Contains(lower, "license") ||
strings.Contains(lower, "spdx-license-identifier") {
return strings.Join(rest, "\n"), true
}
}
return content, false
}
// stripImports removes top-of-file import blocks, language-specific.
func stripImports(content string, lang protocol.Language) (string, bool) {
switch lang {
case protocol.LangGo:
return stripGoImports(content)
case protocol.LangTypeScript, protocol.LangJavaScript:
return stripTSImports(content)
case protocol.LangPython:
return stripPythonImports(content)
case protocol.LangRust:
return stripRustImports(content)
default:
return content, false
}
}
// stripGoImports removes Go import(...) or single import "..." declarations.
func stripGoImports(content string) (string, bool) {
lines := strings.Split(content, "\n")
var out []string
removed := false
i := 0
for i < len(lines) {
trimLine := strings.TrimSpace(lines[i])
if strings.HasPrefix(trimLine, "import (") || trimLine == "import (" {
// multi-line import block
removed = true
i++ // skip "import ("
for i < len(lines) {
if strings.TrimSpace(lines[i]) == ")" {
i++ // skip closing ")"
break
}
i++
}
// skip one blank line after
if i < len(lines) && strings.TrimSpace(lines[i]) == "" {
i++
}
continue
}
if strings.HasPrefix(trimLine, `import "`) || strings.HasPrefix(trimLine, "import `") {
removed = true
i++
continue
}
out = append(out, lines[i])
i++
}
if !removed {
return content, false
}
return strings.Join(out, "\n"), true
}
// stripTSImports removes TypeScript/JavaScript "import ... from ..." and "require(...)" lines.
func stripTSImports(content string) (string, bool) {
lines := strings.Split(content, "\n")
var out []string
removed := false
for _, l := range lines {
trimLine := strings.TrimSpace(l)
if strings.HasPrefix(trimLine, "import ") || strings.HasPrefix(trimLine, "const {") && strings.Contains(trimLine, "require(") {
removed = true
continue
}
out = append(out, l)
}
if !removed {
return content, false
}
return strings.Join(out, "\n"), true
}
// stripPythonImports removes Python "import ..." and "from ... import ..." lines.
func stripPythonImports(content string) (string, bool) {
lines := strings.Split(content, "\n")
var out []string
removed := false
for _, l := range lines {
trimLine := strings.TrimSpace(l)
if strings.HasPrefix(trimLine, "import ") || strings.HasPrefix(trimLine, "from ") {
removed = true
continue
}
out = append(out, l)
}
if !removed {
return content, false
}
return strings.Join(out, "\n"), true
}
// stripRustImports removes Rust "use ..." declarations.
func stripRustImports(content string) (string, bool) {
lines := strings.Split(content, "\n")
var out []string
removed := false
inMulti := false
for _, l := range lines {
trimLine := strings.TrimSpace(l)
if inMulti {
// look for semicolon terminating multi-line use
if strings.Contains(trimLine, ";") {
inMulti = false
}
removed = true
continue
}
if strings.HasPrefix(trimLine, "use ") {
removed = true
if !strings.HasSuffix(trimLine, ";") {
inMulti = true
}
continue
}
out = append(out, l)
}
if !removed {
return content, false
}
return strings.Join(out, "\n"), true
}
// stripBlockComments removes /* ... */ block comments (Go/TS/C/Rust)
// and Python triple-quoted docstrings.
func stripBlockComments(content string, lang protocol.Language) (string, bool) {
if lang == protocol.LangPython {
return stripPythonDocstrings(content)
}
return stripCStyleBlockComments(content)
}
// stripCStyleBlockComments removes /* ... */ from content.
func stripCStyleBlockComments(content string) (string, bool) {
removed := false
var sb strings.Builder
i := 0
for i < len(content) {
if i+1 < len(content) && content[i] == '/' && content[i+1] == '*' {
// find closing */
end := strings.Index(content[i+2:], "*/")
if end >= 0 {
removed = true
// advance past */
i = i + 2 + end + 2
// consume trailing newline
if i < len(content) && content[i] == '\n' {
i++
}
continue
}
}
sb.WriteByte(content[i])
i++
}
if !removed {
return content, false
}
return sb.String(), true
}
// stripPythonDocstrings removes triple-quoted strings (""" and ”').
func stripPythonDocstrings(content string) (string, bool) {
removed := false
var sb strings.Builder
i := 0
for i < len(content) {
if i+2 < len(content) {
triple := content[i : i+3]
if triple == `"""` || triple == `'''` {
end := strings.Index(content[i+3:], triple)
if end >= 0 {
removed = true
i = i + 3 + end + 3
if i < len(content) && content[i] == '\n' {
i++
}
continue
}
}
}
sb.WriteByte(content[i])
i++
}
if !removed {
return content, false
}
return sb.String(), true
}
+186
View File
@@ -0,0 +1,186 @@
package query
import (
"strings"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// makeResults builds N dummy MatchResults.
func makeResults(n int) []MatchResult {
out := make([]MatchResult, n)
for i := range out {
out[i] = MatchResult{
File: "file.go",
Location: protocol.Location{
Line: i + 1,
Column: 1,
},
Text: "func Foo() {}\nline2",
}
}
return out
}
// TestFormatResultsVerboseDefault verifies verbose format includes code blocks (no preamble by default).
func TestFormatResultsVerboseDefault(t *testing.T) {
results := makeResults(2)
out := FormatResultsWithOptions(results, 0, "verbose", 0)
// v2 default: no preamble
if strings.Contains(out, "Found ") {
t.Errorf("v2 default should NOT emit preamble, got:\n%s", out)
}
if !strings.Contains(out, "```") {
t.Error("verbose mode should include code blocks")
}
}
// TestFormatResultsVerbosePreamble verifies verbose=true restores the preamble.
func TestFormatResultsVerbosePreamble(t *testing.T) {
results := makeResults(2)
out := FormatResultsWithOptions(results, 0, "verbose", 0, true)
if !strings.Contains(out, "Found 2 match(es):") {
t.Errorf("expected preamble with verbose=true, got:\n%s", out)
}
}
func TestFormatResultsCompact(t *testing.T) {
results := makeResults(3)
out := FormatResultsWithOptions(results, 0, "compact", 0)
// v2 default: no preamble in compact mode
// Should NOT have code blocks
if strings.Contains(out, "```") {
t.Error("compact mode should not have code blocks")
}
// Should have one line per match (beyond header)
lines := strings.Split(strings.TrimSpace(out), "\n")
// First two lines are "Found..." and blank, then 3 match lines
matchLines := 0
for _, l := range lines {
if strings.Contains(l, "file.go:") {
matchLines++
}
}
if matchLines != 3 {
t.Errorf("expected 3 match lines in compact mode, got %d\nOutput:\n%s", matchLines, out)
}
}
func TestFormatResultsLocation(t *testing.T) {
results := makeResults(3)
out := FormatResultsWithOptions(results, 0, "location", 0)
if strings.Contains(out, "```") {
t.Error("location mode should not have code blocks")
}
// Should be file:line only
for i := 1; i <= 3; i++ {
expected := "file.go:" + itoa(i)
if !strings.Contains(out, expected) {
t.Errorf("location output missing %s", expected)
}
}
}
func TestFormatResultsMaxResults(t *testing.T) {
results := makeResults(5)
out := FormatResultsWithOptions(results, 3, "verbose", 0)
// v2 default: no preamble — check that exactly 3 code blocks are present
codeBlockCount := strings.Count(out, "```")
if codeBlockCount != 6 { // 3 opening + 3 closing = 6
t.Errorf("expected 3 matches (6 backtick markers), got %d in:\n%s", codeBlockCount, out)
}
if !strings.Contains(out, "[remaining: 2]") {
t.Errorf("expected [remaining: 2] footer, got:\n%s", out)
}
}
func TestFormatResultsOffset(t *testing.T) {
results := makeResults(5)
// Skip first 2, show all remaining
out := FormatResultsWithOptions(results, 0, "verbose", 2)
// offset=2 from 5 results → 3 results; check 3 code blocks
codeBlockCount := strings.Count(out, "```")
if codeBlockCount != 6 {
t.Errorf("expected 3 matches (6 backtick markers) after offset=2, got %d in:\n%s", codeBlockCount, out)
}
}
func TestFormatResultsOffsetBeyondEnd(t *testing.T) {
results := makeResults(3)
out := FormatResultsWithOptions(results, 0, "verbose", 10)
if out != "No matches found." {
t.Errorf("expected 'No matches found.' for offset beyond end, got: %s", out)
}
}
func TestFormatResultsPaginationCursor(t *testing.T) {
// Offset=2, maxResults=2, 5 total → show items 3&4, remaining=1
results := makeResults(5)
out := FormatResultsWithOptions(results, 2, "verbose", 2)
// offset=2, maxResults=2 → items 3&4; check 2 code blocks
codeBlockCount := strings.Count(out, "```")
if codeBlockCount != 4 {
t.Errorf("expected 2 matches (4 backtick markers), got %d in:\n%s", codeBlockCount, out)
}
if !strings.Contains(out, "[remaining: 1]") {
t.Errorf("expected [remaining: 1], got:\n%s", out)
}
}
func TestFormatResultsEmpty(t *testing.T) {
out := FormatResultsWithOptions(nil, 0, "verbose", 0)
if out != "No matches found." {
t.Errorf("expected 'No matches found.', got: %s", out)
}
}
func TestFormatResultsBackwardCompat(t *testing.T) {
// FormatResults wrapper should produce same output as FormatResultsWithOptions with verbose=false (default).
results := makeResults(2)
a := FormatResults(results, 0)
b := FormatResultsWithOptions(results, 0, "verbose", 0)
if a != b {
t.Error("FormatResults and FormatResultsWithOptions(verbose,0) should be identical")
}
// Both should have no preamble.
if strings.Contains(a, "Found ") {
t.Error("FormatResults should not emit preamble by default")
}
}
func TestFirstLineOf(t *testing.T) {
cases := []struct {
input string
maxLen int
want string
}{
{"hello world", 20, "hello world"},
{"line1\nline2", 20, "line1"},
{"\n\nfoo", 20, "foo"},
{"abcdefghij", 5, "abcd…"},
}
for _, c := range cases {
got := firstLineOf(c.input, c.maxLen)
if got != c.want {
t.Errorf("firstLineOf(%q, %d) = %q, want %q", c.input, c.maxLen, got, c.want)
}
}
}
func itoa(n int) string {
if n < 10 {
return string(rune('0' + n))
}
return strings.TrimRight(strings.TrimRight(
func() string {
buf := make([]byte, 20)
pos := 20
for n > 0 {
pos--
buf[pos] = byte('0' + n%10)
n /= 10
}
return string(buf[pos:])
}(), ""), "")
}
+99 -40
View File
@@ -451,62 +451,121 @@ func passesFilters(node *sitter.Node, filters QueryFilters, content []byte) bool
return true
}
// FormatResults formats match results for display.
// FormatResults formats match results for display (backward-compat wrapper, verbose mode).
func FormatResults(results []MatchResult, maxResults int) string {
return FormatResultsWithOptions(results, maxResults, "verbose", 0)
}
// FormatResultsWithOptions formats match results with configurable output format.
// format: "verbose" (default) | "compact" | "location"
// offset: skip this many results before rendering (used for cursor pagination).
// verbose: opt-in variadic — pass true to restore "Found N match(es):" preamble (v1 behaviour).
func FormatResultsWithOptions(results []MatchResult, maxResults int, format string, offset int, verbose ...bool) string {
if len(results) == 0 {
return "No matches found."
}
var sb strings.Builder
sb.WriteString(fmt.Sprintf("Found %d match(es):\n\n", len(results)))
displayCount := len(results)
truncated := false
if maxResults > 0 && displayCount > maxResults {
displayCount = maxResults
truncated = true
// Apply offset (pagination skip).
if offset > 0 {
if offset >= len(results) {
return "No matches found."
}
results = results[offset:]
}
for i := 0; i < displayCount; i++ {
r := results[i]
nodeType := "unknown"
if r.Node != nil {
nodeType = r.Node.Type()
}
sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType))
// Determine how many to render and whether more remain.
renderCount := len(results)
remaining := 0
if maxResults > 0 && renderCount > maxResults {
remaining = renderCount - maxResults
renderCount = maxResults
}
// Truncate very long text
text := r.Text
if len(text) > 500 {
text = text[:500] + "..."
}
sb.WriteString("```\n")
sb.WriteString(text)
sb.WriteString("\n```\n")
var sb strings.Builder
// Show captures
if len(r.Captures) > 0 {
sb.WriteString("Captures: ")
first := true
for name, cap := range r.Captures {
if !first {
sb.WriteString(", ")
// Emit preamble only when verbose=true is explicitly passed (opt-in, default off).
wantVerbose := len(verbose) > 0 && verbose[0]
if wantVerbose {
sb.WriteString(fmt.Sprintf("Found %d match(es):\n", renderCount))
}
switch format {
case "compact":
for i := 0; i < renderCount; i++ {
r := results[i]
nodeType := "unknown"
if r.Node != nil {
nodeType = r.Node.Type()
}
firstLine := firstLineOf(r.Text, 80)
sb.WriteString(fmt.Sprintf("%s:%d (%s) %s\n", r.File, r.Location.Line, nodeType, firstLine))
}
case "location":
for i := 0; i < renderCount; i++ {
r := results[i]
sb.WriteString(fmt.Sprintf("%s:%d\n", r.File, r.Location.Line))
}
default: // "verbose"
for i := 0; i < renderCount; i++ {
r := results[i]
nodeType := "unknown"
if r.Node != nil {
nodeType = r.Node.Type()
}
sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType))
// Truncate very long text
text := r.Text
if len(text) > 500 {
text = text[:500] + "..."
}
sb.WriteString("```\n")
sb.WriteString(text)
sb.WriteString("\n```\n")
// Show captures
if len(r.Captures) > 0 {
sb.WriteString("Captures: ")
first := true
for name, cap := range r.Captures {
if !first {
sb.WriteString(", ")
}
first = false
capText := cap.Text
if len(capText) > 50 {
capText = capText[:50] + "..."
}
sb.WriteString(fmt.Sprintf("$%s=%s", name, capText))
}
first = false
capText := cap.Text
if len(capText) > 50 {
capText = capText[:50] + "..."
}
sb.WriteString(fmt.Sprintf("$%s=%s", name, capText))
sb.WriteString("\n")
}
sb.WriteString("\n")
}
sb.WriteString("\n")
}
if truncated {
sb.WriteString(fmt.Sprintf("... and %d more matches (truncated)\n", len(results)-maxResults))
if remaining > 0 {
// Caller must embed the cursor token; we just append the remaining count hint.
// The actual [cursor: ...] line is written by the handler after calling MakeCursor.
sb.WriteString(fmt.Sprintf("[remaining: %d]\n", remaining))
}
return sb.String()
}
// firstLineOf returns the first non-empty line of s, trimmed and capped at maxLen chars.
func firstLineOf(s string, maxLen int) string {
for _, line := range strings.Split(s, "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
if len(line) > maxLen {
return line[:maxLen-1] + "…"
}
return line
}
return ""
}
+131
View File
@@ -0,0 +1,131 @@
package search
import (
"strings"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"log/slog"
"os"
)
func newTestSearcher(t *testing.T) *Searcher {
t.Helper()
cfg := &config.Config{
WorkspaceRoot: t.TempDir(),
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
// Create a Searcher directly without requiring rg
return &Searcher{cfg: cfg, logger: logger, rgPath: "rg"}
}
func makeSearchResults(lines []int) *SearchResults {
results := make([]Result, len(lines))
for i, l := range lines {
results[i] = Result{
File: "/workspace/foo.go",
Line: l,
MatchText: "match at line " + itoa2(l),
}
}
return &SearchResults{Results: results}
}
// itoa2 simple int to string for test
func itoa2(n int) string {
if n == 0 {
return "0"
}
digits := []byte{}
for n > 0 {
digits = append([]byte{byte('0' + n%10)}, digits...)
n /= 10
}
return string(digits)
}
// TestFormatResultsVerboseBackwardCompat verifies that Verbose=true restores the v1 preamble.
func TestFormatResultsVerboseBackwardCompat(t *testing.T) {
s := newTestSearcher(t)
sr := makeSearchResults([]int{1, 5, 10})
// With Verbose=true, the preamble is emitted (v1 behaviour).
out := s.FormatResultsWithOptions(sr, FormatOptions{Verbose: true})
if !strings.Contains(out, "Found 3 matches in 1 files") {
t.Errorf("expected header with Verbose=true, got:\n%s", out)
}
if !strings.Contains(out, "L1│") {
t.Errorf("expected L1 match line, got:\n%s", out)
}
}
// TestFormatResultsDefaultNoPreamble verifies the v2 default has no preamble.
func TestFormatResultsDefaultNoPreamble(t *testing.T) {
s := newTestSearcher(t)
sr := makeSearchResults([]int{1, 5, 10})
out := s.FormatResults(sr)
if strings.Contains(out, "Found ") {
t.Errorf("v2 default should NOT emit preamble, got:\n%s", out)
}
if !strings.Contains(out, "L1│") {
t.Errorf("expected L1 match line, got:\n%s", out)
}
}
func TestFormatResultsClusterSingleLine(t *testing.T) {
s := newTestSearcher(t)
// Non-consecutive lines → each should appear separately
sr := makeSearchResults([]int{1, 5, 10})
out := s.FormatResultsWithOptions(sr, FormatOptions{Cluster: true})
if !strings.Contains(out, "L1│") {
t.Errorf("expected L1 cluster, got:\n%s", out)
}
if !strings.Contains(out, "L5│") {
t.Errorf("expected L5 cluster, got:\n%s", out)
}
// Should NOT have " │" context lines in cluster mode
if strings.Contains(out, " │") {
t.Errorf("cluster mode should not have context-line decoration, got:\n%s", out)
}
}
func TestFormatResultsClusterConsecutive(t *testing.T) {
s := newTestSearcher(t)
// Lines 3,4,5 are consecutive → should be clustered as L3-5
sr := makeSearchResults([]int{3, 4, 5, 10})
out := s.FormatResultsWithOptions(sr, FormatOptions{Cluster: true})
if !strings.Contains(out, "L3-5│") {
t.Errorf("expected L3-5 cluster, got:\n%s", out)
}
if !strings.Contains(out, "L10│") {
t.Errorf("expected L10 separate cluster, got:\n%s", out)
}
}
func TestFormatResultsClusterAdjacentMerge(t *testing.T) {
s := newTestSearcher(t)
// Lines 7 and 8 differ by 1 → merge
sr := makeSearchResults([]int{7, 8})
out := s.FormatResultsWithOptions(sr, FormatOptions{Cluster: true})
if !strings.Contains(out, "L7-8│") {
t.Errorf("expected L7-8 cluster, got:\n%s", out)
}
}
func TestFormatResultsCursorLine(t *testing.T) {
s := newTestSearcher(t)
sr := makeSearchResults([]int{1})
cursorText := "[cursor: abc123, remaining: 5]"
out := s.FormatResultsWithOptions(sr, FormatOptions{CursorLine: cursorText})
if !strings.Contains(out, cursorText) {
t.Errorf("expected cursor footer in output, got:\n%s", out)
}
}
func TestFormatResultsNoMatchesVerbose(t *testing.T) {
s := newTestSearcher(t)
sr := &SearchResults{Results: nil}
out := s.FormatResults(sr)
if out != "No matches found." {
t.Errorf("expected 'No matches found.', got: %s", out)
}
}
+97 -26
View File
@@ -219,11 +219,9 @@ func (s *Searcher) buildArgs(req *Request) []string {
args = append(args, "--no-ignore")
}
// Global result cap — --max-total-count stops rg early across all files.
// Requires ripgrep >= 13.0. In-process truncation in parseOutput is kept as a safety net.
if req.MaxResults > 0 {
args = append(args, fmt.Sprintf("--max-total-count=%d", req.MaxResults))
}
// Result cap enforced in-process by parseOutput. rg has no cross-file
// total-count flag in stable releases, so we don't pass one; --max-count is
// per-file and would miss results unevenly.
// Add pattern
args = append(args, "--", req.Pattern)
@@ -356,8 +354,21 @@ func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchRes
return results, nil
}
// FormatResults formats search results for display.
// FormatOptions controls how search results are rendered.
type FormatOptions struct {
Cluster bool // coalesce consecutive matches into line-range blocks
CursorLine string // if non-empty, appended as a footer line
Verbose bool // if true, emit "Found N matches in M files:" preamble (opt-in)
}
// FormatResults formats search results for display (backward-compat wrapper).
func (s *Searcher) FormatResults(results *SearchResults) string {
return s.FormatResultsWithOptions(results, FormatOptions{})
}
// FormatResultsWithOptions formats search results with configurable output.
// By default the "Found N matches in M files:" preamble is omitted; set opts.Verbose=true to restore it.
func (s *Searcher) FormatResultsWithOptions(results *SearchResults, opts FormatOptions) string {
if len(results.Results) == 0 {
return "No matches found."
}
@@ -374,14 +385,18 @@ func (s *Searcher) FormatResults(results *SearchResults) string {
fileResults[r.File] = append(fileResults[r.File], r)
}
// Write summary
totalMatches := len(results.Results)
fileCount := len(fileResults)
sb.WriteString(fmt.Sprintf("Found %d matches in %d files", totalMatches, fileCount))
if results.Truncated {
sb.WriteString(fmt.Sprintf(" (truncated, total: %d)", results.Total))
// Write preamble only when Verbose is requested.
if opts.Verbose {
totalMatches := len(results.Results)
fileCount := len(fileResults)
sb.WriteString(fmt.Sprintf("Found %d matches in %d files", totalMatches, fileCount))
if results.Truncated {
sb.WriteString(fmt.Sprintf(" (truncated, total: %d)", results.Total))
}
sb.WriteString(":\n\n")
} else if results.Truncated {
sb.WriteString(fmt.Sprintf("(truncated, showing subset of %d total matches)\n\n", results.Total))
}
sb.WriteString(":\n\n")
// Write results grouped by file
for _, file := range fileOrder {
@@ -395,26 +410,82 @@ func (s *Searcher) FormatResults(results *SearchResults) string {
sb.WriteString(fmt.Sprintf("**%s**\n", relPath))
for _, r := range fileResults[file] {
// Write context before
for _, ctx := range r.Context.Before {
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
}
// Write match line
sb.WriteString(fmt.Sprintf("L%d│ %s\n", r.Line, truncateLine(r.MatchText, 200)))
// Write context after
for _, ctx := range r.Context.After {
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
}
if opts.Cluster {
writeClusteredResults(&sb, fileResults[file])
} else {
writeVerboseResults(&sb, fileResults[file])
}
sb.WriteString("\n")
}
if opts.CursorLine != "" {
sb.WriteString(opts.CursorLine)
sb.WriteString("\n")
}
return sb.String()
}
// writeVerboseResults writes results in the standard verbose format.
func writeVerboseResults(sb *strings.Builder, results []Result) {
for _, r := range results {
// Write context before
for _, ctx := range r.Context.Before {
fmt.Fprintf(sb, " │ %s\n", truncateLine(ctx, 200))
}
// Write match line
fmt.Fprintf(sb, "L%d│ %s\n", r.Line, truncateLine(r.MatchText, 200))
// Write context after
for _, ctx := range r.Context.After {
fmt.Fprintf(sb, " │ %s\n", truncateLine(ctx, 200))
}
}
}
// writeClusteredResults coalesces consecutive or adjacent match lines into
// a single "L12-14│ <first-match-text>" entry. Context lines are dropped
// in cluster mode to maximise information density.
func writeClusteredResults(sb *strings.Builder, results []Result) {
if len(results) == 0 {
return
}
type clusterEntry struct {
startLine int
endLine int
firstText string
}
var clusters []clusterEntry
cur := clusterEntry{
startLine: results[0].Line,
endLine: results[0].Line,
firstText: results[0].MatchText,
}
for _, r := range results[1:] {
// Merge if adjacent (within 1 line gap)
if r.Line <= cur.endLine+1 {
if r.Line > cur.endLine {
cur.endLine = r.Line
}
} else {
clusters = append(clusters, cur)
cur = clusterEntry{startLine: r.Line, endLine: r.Line, firstText: r.MatchText}
}
}
clusters = append(clusters, cur)
for _, c := range clusters {
text := truncateLine(c.firstText, 200)
if c.startLine == c.endLine {
fmt.Fprintf(sb, "L%d│ %s\n", c.startLine, text)
} else {
fmt.Fprintf(sb, "L%d-%d│ %s\n", c.startLine, c.endLine, text)
}
}
}
// truncateLine truncates a line if it exceeds maxLen.
func truncateLine(s string, maxLen int) string {
if len(s) <= maxLen {
+2 -2
View File
@@ -94,8 +94,8 @@ func TestBuildArgs(t *testing.T) {
MaxResults: 10,
Regex: true,
},
expected: []string{"--json", "--max-total-count=10", "--", "test", "."},
notExpected: []string{"--ignore-case", "--fixed-strings"},
expected: []string{"--json", "--", "test", "."},
notExpected: []string{"--ignore-case", "--fixed-strings", "--max-total-count=10", "--max-count=10"},
},
}
+344
View File
@@ -0,0 +1,344 @@
package server
import (
"context"
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"github.com/mark3labs/mcp-go/mcp"
)
// newTestServer creates a server with a temp workspace containing Go files.
func newFeaturesServer(t *testing.T) (*Server, string) {
t.Helper()
tmpDir := t.TempDir()
// Write a few Go files for AST queries
goFile1 := filepath.Join(tmpDir, "a.go")
if err := os.WriteFile(goFile1, []byte(`package main
func Alpha() string { return "alpha" }
func Beta() string { return "beta" }
func Gamma() string { return "gamma" }
`), 0o600); err != nil {
t.Fatal(err)
}
goFile2 := filepath.Join(tmpDir, "b.go")
if err := os.WriteFile(goFile2, []byte(`package main
func Delta() int { return 1 }
func Epsilon() int { return 2 }
`), 0o600); err != nil {
t.Fatal(err)
}
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
MaxFileSize: config.DefaultMaxFileSize,
MaxParseSize: config.DefaultMaxParseSize,
SearchTimeout: 30 * time.Second,
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error: %v", err)
}
return srv, tmpDir
}
// ---- Feature 1: ast_query format flag ----
func TestASTQueryFormatVerboseDefault(t *testing.T) {
srv, tmpDir := newFeaturesServer(t)
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME() string",
"language": "go",
"paths": []interface{}{tmpDir},
}
res, err := srv.handleASTQuery(ctx, req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if res == nil || len(res.Content) == 0 {
t.Fatal("nil/empty result")
}
text := res.Content[0].(mcp.TextContent).Text
// verbose mode has code blocks
if !strings.Contains(text, "```") {
t.Errorf("verbose mode should have code blocks, got:\n%s", text)
}
}
func TestASTQueryFormatCompact(t *testing.T) {
srv, tmpDir := newFeaturesServer(t)
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME() string",
"language": "go",
"paths": []interface{}{tmpDir},
"format": "compact",
}
res, err := srv.handleASTQuery(ctx, req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
text := res.Content[0].(mcp.TextContent).Text
if strings.Contains(text, "```") {
t.Errorf("compact mode should NOT have code blocks, got:\n%s", text)
}
// Each line should contain file:line (kind) text
for _, line := range strings.Split(strings.TrimSpace(text), "\n") {
if line == "" || strings.HasPrefix(line, "Found") {
continue
}
if !strings.Contains(line, ":") {
t.Errorf("compact line missing ':' separator: %q", line)
}
}
}
func TestASTQueryFormatLocation(t *testing.T) {
srv, tmpDir := newFeaturesServer(t)
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME() string",
"language": "go",
"paths": []interface{}{tmpDir},
"format": "location",
}
res, err := srv.handleASTQuery(ctx, req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
text := res.Content[0].(mcp.TextContent).Text
if strings.Contains(text, "```") {
t.Errorf("location mode should NOT have code blocks, got:\n%s", text)
}
// Lines should be file:linenum only (no parentheses with kind)
for _, line := range strings.Split(strings.TrimSpace(text), "\n") {
if line == "" || strings.HasPrefix(line, "Found") {
continue
}
if strings.Contains(line, "(") {
t.Errorf("location mode should not have node type in parens: %q", line)
}
}
}
// ---- Feature 3 (ast_query): pagination cursor ----
func TestASTQueryPaginationCursor(t *testing.T) {
srv, tmpDir := newFeaturesServer(t)
ctx := context.Background()
// Page 1: max_results=2
req1 := mcp.CallToolRequest{}
req1.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME() $RET",
"language": "go",
"paths": []interface{}{tmpDir},
"max_results": float64(2),
}
res1, err := srv.handleASTQuery(ctx, req1)
if err != nil {
t.Fatalf("page1 error: %v", err)
}
text1 := res1.Content[0].(mcp.TextContent).Text
// Should contain cursor footer if there are more results
if !strings.Contains(text1, "[cursor:") {
// Might have fewer than 2 total results — skip cursor test
t.Logf("no cursor footer (fewer than 2 total matches), skipping pagination round-trip")
return
}
// Extract cursor token
var cursorToken string
for _, line := range strings.Split(text1, "\n") {
if strings.HasPrefix(line, "[cursor:") {
// [cursor: <token>, remaining: N]
parts := strings.Split(line, " ")
if len(parts) >= 2 {
cursorToken = strings.TrimSuffix(parts[1], ",")
}
break
}
}
if cursorToken == "" {
t.Fatal("could not extract cursor token from output")
}
// Page 2: pass cursor back
req2 := mcp.CallToolRequest{}
req2.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME() $RET",
"language": "go",
"paths": []interface{}{tmpDir},
"max_results": float64(2),
"cursor": cursorToken,
}
res2, err := srv.handleASTQuery(ctx, req2)
if err != nil {
t.Fatalf("page2 error: %v", err)
}
text2 := res2.Content[0].(mcp.TextContent).Text
if strings.Contains(text2, "cursor is for a different query") {
t.Error("cursor was rejected as mismatched query")
}
// Page 2 should have results
if strings.Contains(text2, "No matches found.") {
t.Error("page2 should have some results")
}
}
func TestASTQueryCursorStaleMismatch(t *testing.T) {
srv, tmpDir := newFeaturesServer(t)
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME() string",
"language": "go",
"paths": []interface{}{tmpDir},
"cursor": "eyJvZmZzZXQiOjIsInF1ZXJ5X2hhc2giOiJkZWFkYmVlZiJ9", // offset=2, hash=deadbeef
}
res, err := srv.handleASTQuery(ctx, req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
text := res.Content[0].(mcp.TextContent).Text
if !strings.Contains(text, "cursor is for a different query") {
t.Errorf("expected stale cursor error, got:\n%s", text)
}
}
// ---- Feature 2: file_search cluster ----
func TestFileSearchClusterFlag(t *testing.T) {
srv, tmpDir := newFeaturesServer(t)
if srv.searcher == nil {
t.Skip("rg not available")
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func",
"paths": []interface{}{tmpDir},
"cluster": true,
}
res, err := srv.handleFileSearch(ctx, req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if res == nil || len(res.Content) == 0 {
t.Fatal("nil/empty result")
}
text := res.Content[0].(mcp.TextContent).Text
if text == "No matches found." {
t.Skip("no matches found (unexpected)")
}
// In cluster mode, should NOT have " │" context decorations
if strings.Contains(text, " │") {
t.Errorf("cluster mode should not contain context-line decoration ' │', got:\n%s", text)
}
}
func TestFileSearchCursorStaleHash(t *testing.T) {
srv, tmpDir := newFeaturesServer(t)
if srv.searcher == nil {
t.Skip("rg not available")
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func",
"paths": []interface{}{tmpDir},
"cursor": "eyJvZmZzZXQiOjEsInF1ZXJ5X2hhc2giOiJiYWRoYXNoIn0", // offset=1, hash=badhash
}
res, err := srv.handleFileSearch(ctx, req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
text := res.Content[0].(mcp.TextContent).Text
if !strings.Contains(text, "cursor is for a different query") {
t.Errorf("expected stale cursor error, got:\n%s", text)
}
}
func TestFileSearchPaginationCursor(t *testing.T) {
srv, tmpDir := newFeaturesServer(t)
if srv.searcher == nil {
t.Skip("rg not available")
}
ctx := context.Background()
// Page 1: get 1 result
req1 := mcp.CallToolRequest{}
req1.Params.Arguments = map[string]interface{}{
"pattern": "func",
"paths": []interface{}{tmpDir},
"max_results": float64(1),
"context_lines": float64(0),
}
res1, err := srv.handleFileSearch(ctx, req1)
if err != nil {
t.Fatalf("page1 error: %v", err)
}
text1 := res1.Content[0].(mcp.TextContent).Text
if !strings.Contains(text1, "[cursor:") {
t.Logf("no cursor in page1 (only 1 total result), skipping round-trip:\n%s", text1)
return
}
// Extract cursor
var cursorToken string
for _, line := range strings.Split(text1, "\n") {
if strings.HasPrefix(line, "[cursor:") {
parts := strings.Split(line, " ")
if len(parts) >= 2 {
cursorToken = strings.TrimSuffix(parts[1], ",")
}
break
}
}
if cursorToken == "" {
t.Fatal("could not extract cursor from page1")
}
// Page 2
req2 := mcp.CallToolRequest{}
req2.Params.Arguments = map[string]interface{}{
"pattern": "func",
"paths": []interface{}{tmpDir},
"max_results": float64(1),
"context_lines": float64(0),
"cursor": cursorToken,
}
res2, err := srv.handleFileSearch(ctx, req2)
if err != nil {
t.Fatalf("page2 error: %v", err)
}
text2 := res2.Content[0].(mcp.TextContent).Text
if strings.Contains(text2, "cursor is for a different query") {
t.Error("cursor was rejected as mismatched")
}
if strings.Contains(text2, "No matches found.") {
t.Error("page2 should have matches")
}
}
+198 -105
View File
@@ -8,12 +8,193 @@ import (
"path/filepath"
"strings"
"github.com/lukaszraczylo/mcp-filepuff/internal/cursor"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
"github.com/mark3labs/mcp-go/mcp"
)
// astQueryParams holds resolved parameters for an AST query invocation.
type astQueryParams struct {
pattern string
language string
nameMatches string
nameExact string
kindIn []string
paths []string
maxResults int
format string
offset int
queryHash string
verbose bool
}
// resolveASTQueryParams parses the request and resolves session-pref defaults.
// Returns (params, errorResult, error); when errorResult is non-nil, the caller
// should return it directly.
func (s *Server) resolveASTQueryParams(request mcp.CallToolRequest) (*astQueryParams, *mcp.CallToolResult) {
pattern, err := request.RequireString("pattern")
if err != nil {
return nil, mcp.NewToolResultError("pattern is required")
}
language, err := request.RequireString("language")
if err != nil {
return nil, mcp.NewToolResultError("language is required")
}
p := &astQueryParams{
pattern: pattern,
language: language,
nameMatches: request.GetString("name_matches", ""),
nameExact: request.GetString("name_exact", ""),
kindIn: request.GetStringSlice("kind_in", nil),
paths: request.GetStringSlice("paths", nil),
verbose: request.GetBool("verbose", false),
}
sp := s.sessionPrefs.Load()
var prefsMaxResults int
var prefsFormat string
if sp != nil {
prefsMaxResults = sp.DefaultMaxResults
prefsFormat = sp.ASTQueryFormat
}
p.maxResults = effectiveInt(request, "max_results", prefsMaxResults, 100)
p.format = request.GetString("format", "")
if p.format == "" {
if prefsFormat != "" {
p.format = prefsFormat
} else {
p.format = "verbose"
}
}
p.queryHash = cursor.HashParams(map[string]string{
"pattern": p.pattern,
"language": p.language,
"name_matches": p.nameMatches,
"name_exact": p.nameExact,
"kind_in": strings.Join(p.kindIn, ","),
"paths": strings.Join(p.paths, ","),
})
if cursorStr := request.GetString("cursor", ""); cursorStr != "" {
off, hash, decErr := cursor.Decode(cursorStr)
if decErr != nil {
return nil, mcp.NewToolResultError(fmt.Sprintf("invalid cursor: %s", decErr))
}
if hash != p.queryHash {
return nil, mcp.NewToolResultError("cursor is for a different query, re-run without cursor")
}
p.offset = off
}
return p, nil
}
// runASTQueryWalk walks the configured paths and collects matches.
func (s *Server) runASTQueryWalk(ctx context.Context, p *astQueryParams, exts []string) []query.MatchResult {
astQuery := &query.ASTQuery{
Pattern: p.pattern,
Language: p.language,
Filters: query.QueryFilters{
NameMatches: p.nameMatches,
NameExact: p.nameExact,
KindIn: p.kindIn,
},
}
// Collect limit for early-exit: when paginating we need all results first.
collectLimit := p.maxResults
if p.offset > 0 {
collectLimit = 0
}
var allResults []query.MatchResult
for _, searchPath := range p.paths {
if !s.cfg.IsPathAllowed(searchPath) {
continue
}
walkErr := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err != nil {
return nil
}
if info.IsDir() {
if strings.HasPrefix(info.Name(), ".") {
return filepath.SkipDir
}
return nil
}
if !hasAnySuffix(path, exts) {
return nil
}
content, err := os.ReadFile(path)
if err != nil {
return nil
}
if int64(len(content)) > s.cfg.MaxFileSize {
return nil
}
result, err := s.parser.Parse(ctx, path, content)
if err != nil {
return nil
}
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
if err != nil {
return nil
}
allResults = append(allResults, matches...)
if collectLimit > 0 && len(allResults) >= collectLimit {
return filepath.SkipAll
}
return nil
})
if walkErr != nil {
s.logger.Warn("error walking path", "path", searchPath, "error", walkErr)
}
}
return allResults
}
// hasAnySuffix reports whether path ends with any of the given suffixes.
func hasAnySuffix(path string, suffixes []string) bool {
for _, ext := range suffixes {
if strings.HasSuffix(path, ext) {
return true
}
}
return false
}
// buildASTCursorFooter computes the cursor footer line for truncated results.
func buildASTCursorFooter(total, offset, maxResults int, queryHash string) string {
if offset > 0 && offset < total {
totalAfterOffset := total - offset
if maxResults > 0 && totalAfterOffset > maxResults {
remaining := totalAfterOffset - maxResults
nextOffset := offset + maxResults
nextCursor := cursor.Encode(nextOffset, queryHash)
return fmt.Sprintf("[cursor: %s, remaining: %d]", nextCursor, remaining)
}
} else if offset == 0 && maxResults > 0 && total > maxResults {
remaining := total - maxResults
nextCursor := cursor.Encode(maxResults, queryHash)
return fmt.Sprintf("[cursor: %s, remaining: %d]", nextCursor, remaining)
}
return ""
}
// handleASTQuery handles the ast_query tool.
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Acquire semaphore to limit concurrent queries (prevents CPU exhaustion)
@@ -24,121 +205,33 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest
return mcp.NewToolResultError("request cancelled"), nil
}
pattern, err := request.RequireString("pattern")
if err != nil {
return mcp.NewToolResultError("pattern is required"), nil
p, errResult := s.resolveASTQueryParams(request)
if errResult != nil {
return errResult, nil
}
language, err := request.RequireString("language")
if err != nil {
return mcp.NewToolResultError("language is required"), nil
if len(p.paths) == 0 {
p.paths = []string{s.cfg.WorkspaceRoot}
}
// Build query
astQuery := &query.ASTQuery{
Pattern: pattern,
Language: language,
Filters: query.QueryFilters{
NameMatches: request.GetString("name_matches", ""),
NameExact: request.GetString("name_exact", ""),
KindIn: request.GetStringSlice("kind_in", nil),
},
}
maxResults := request.GetInt("max_results", 100)
paths := request.GetStringSlice("paths", nil)
// Default to workspace root if no paths specified
if len(paths) == 0 {
paths = []string{s.cfg.WorkspaceRoot}
}
// Find files to search based on language
exts := languageToExtensions(language)
exts := languageToExtensions(p.language)
if len(exts) == 0 {
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir, rust)", language)), nil
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir, rust)", p.language)), nil
}
var allResults []query.MatchResult
allResults := s.runASTQueryWalk(ctx, p, exts)
cursorFooter := buildASTCursorFooter(len(allResults), p.offset, p.maxResults, p.queryHash)
// Walk through paths and find matching files
for _, searchPath := range paths {
// Validate path is within workspace
if !s.cfg.IsPathAllowed(searchPath) {
continue
}
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
// Check for context cancellation
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err != nil {
return nil // Skip files with errors
}
if info.IsDir() {
// Skip hidden directories
if strings.HasPrefix(info.Name(), ".") {
return filepath.SkipDir
}
return nil
}
// Check file extension matches language
matched := false
for _, ext := range exts {
if strings.HasSuffix(path, ext) {
matched = true
break
}
}
if !matched {
return nil
}
// Read and parse file
content, err := os.ReadFile(path)
if err != nil {
return nil // Skip unreadable files
}
// Check file size
if int64(len(content)) > s.cfg.MaxFileSize {
return nil // Skip large files
}
// Parse file
result, err := s.parser.Parse(ctx, path, content)
if err != nil {
return nil // Skip unparseable files
}
// Run query
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
if err != nil {
return nil // Skip on error
}
allResults = append(allResults, matches...)
// Stop if we have enough results
if maxResults > 0 && len(allResults) >= maxResults {
return filepath.SkipAll
}
return nil
})
if err != nil {
s.logger.Warn("error walking path", "path", searchPath, "error", err)
output := query.FormatResultsWithOptions(allResults, p.maxResults, p.format, p.offset, p.verbose)
if cursorFooter != "" {
// Replace the [remaining: N] placeholder emitted by FormatResultsWithOptions
// with the full [cursor: ..., remaining: N] line.
output = strings.ReplaceAll(output, fmt.Sprintf("[remaining: %d]\n", len(allResults)-p.offset-p.maxResults), cursorFooter+"\n")
// Fallback: if placeholder wasn't present (e.g. format=location), append footer.
if !strings.Contains(output, cursorFooter) {
output = strings.TrimRight(output, "\n") + "\n" + cursorFooter + "\n"
}
}
// Format and return results
output := query.FormatResults(allResults, maxResults)
return mcp.NewToolResultText(output), nil
}
+49 -22
View File
@@ -4,12 +4,10 @@ package server
import (
"context"
"fmt"
"os"
"strings"
"github.com/lukaszraczylo/mcp-filepuff/internal/edit"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
"github.com/mark3labs/mcp-go/mcp"
)
@@ -82,27 +80,56 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest) (*
return mcp.NewToolResultError(result.Error), nil
}
// compact_response: return just the modified symbol instead of the full diff
if request.GetBool("compact_response", false) && selectorName != "" {
if content, readErr := os.ReadFile(file); readErr == nil {
if start, end, found := s.resolveSymbolLines(ctx, file, content, selectorName, protocol.SymbolKind("")); found {
lines := splitLines(string(content))
var sb strings.Builder
sb.WriteString(fmt.Sprintf("**Edit Applied** — %s (L%d-L%d):\n\n", selectorName, start, end))
for i := start - 1; i < end && i < len(lines); i++ {
sb.WriteString(fmt.Sprintf("%4d| %s\n", i+1, lines[i]))
}
return mcp.NewToolResultText(sb.String()), nil
}
}
// fall through to diff if symbol lookup fails
// Determine response mode.
// compact_response is a deprecated alias for response="count".
respMode := request.GetString("response", "count")
if request.GetBool("compact_response", false) {
// Deprecated: use response=count
respMode = "count"
}
var output strings.Builder
output.WriteString("**Edit Applied Successfully**\n\n")
output.WriteString("Diff:\n```diff\n")
output.WriteString(result.Diff)
output.WriteString("```\n")
switch respMode {
case "none":
return mcp.NewToolResultText(""), nil
return mcp.NewToolResultText(output.String()), nil
case "count":
// Compute +added/-removed line counts from the unified diff.
added, removed := countDiffLines(result.Diff)
return mcp.NewToolResultText(fmt.Sprintf("+%d -%d", added, removed)), nil
case "diff":
var output strings.Builder
output.WriteString("Diff:\n```diff\n")
output.WriteString(result.Diff)
output.WriteString("```\n")
return mcp.NewToolResultText(output.String()), nil
default:
// Fallback: treat unknown values as "diff" for safety.
var output strings.Builder
output.WriteString("Diff:\n```diff\n")
output.WriteString(result.Diff)
output.WriteString("```\n")
return mcp.NewToolResultText(output.String()), nil
}
}
// countDiffLines counts added (+) and removed (-) lines in a unified diff string.
func countDiffLines(diff string) (added, removed int) {
for _, line := range strings.Split(diff, "\n") {
if len(line) == 0 {
continue
}
switch line[0] {
case '+':
if !strings.HasPrefix(line, "+++") {
added++
}
case '-':
if !strings.HasPrefix(line, "---") {
removed++
}
}
}
return
}
+355 -32
View File
@@ -6,10 +6,12 @@ import (
"context"
"fmt"
"os"
"strconv"
"strings"
"time"
xxhash "github.com/cespare/xxhash/v2"
"github.com/lukaszraczylo/mcp-filepuff/internal/cursor"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/internal/search"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
@@ -35,14 +37,63 @@ func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolReque
return mcp.NewToolResultError("pattern is required"), nil
}
paths := request.GetStringSlice("paths", nil)
fileTypes := request.GetStringSlice("file_types", nil)
ignoreCase := request.GetBool("ignore_case", false)
regex := request.GetBool("regex", true)
contextLines := request.GetInt("context_lines", 2)
// Consult session prefs for max_results and cluster when not explicitly supplied.
prefs := s.sessionPrefs.Load()
var prefsMaxResults int
var prefsCluster *bool
if prefs != nil {
prefsMaxResults = prefs.DefaultMaxResults
prefsCluster = prefs.DefaultCluster
}
maxResults := effectiveInt(request, "max_results", prefsMaxResults, 0)
cluster := effectiveBool(request, "cluster", prefsCluster, false)
cursorStr := request.GetString("cursor", "")
// Compute query hash for cursor validation.
queryHash := cursor.HashParams(map[string]string{
"pattern": pattern,
"paths": strings.Join(paths, ","),
"file_types": strings.Join(fileTypes, ","),
"ignore_case": strconv.FormatBool(ignoreCase),
"regex": strconv.FormatBool(regex),
"context_lines": strconv.Itoa(contextLines),
})
// Resolve cursor offset.
offset := 0
if cursorStr != "" {
off, hash, decErr := cursor.Decode(cursorStr)
if decErr != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid cursor: %s", decErr)), nil
}
if hash != queryHash {
return mcp.NewToolResultError("cursor is for a different query, re-run without cursor"), nil
}
offset = off
}
// When paginating with a cursor, fetch all results (no rg-level cap) so we
// can apply the offset in-process. Without a cursor, let rg cap at maxResults.
rgMaxResults := maxResults
if offset > 0 {
rgMaxResults = 0 // fetch all, apply cap after skipping
}
req := &search.Request{
Pattern: pattern,
Paths: request.GetStringSlice("paths", nil),
FileTypes: request.GetStringSlice("file_types", nil),
IgnoreCase: request.GetBool("ignore_case", false),
Regex: request.GetBool("regex", true),
ContextLines: request.GetInt("context_lines", 2),
MaxResults: request.GetInt("max_results", 0),
Paths: paths,
FileTypes: fileTypes,
IgnoreCase: ignoreCase,
Regex: regex,
ContextLines: contextLines,
MaxResults: rgMaxResults,
}
results, err := s.searcher.Search(ctx, req)
@@ -51,16 +102,66 @@ func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolReque
return mcp.NewToolResultError(fmt.Sprintf("search error: %s", errors.SanitizeError(err))), nil
}
// Apply cursor offset.
if offset > 0 && offset < len(results.Results) {
results.Results = results.Results[offset:]
results.Truncated = false // will re-evaluate below
} else if offset > 0 {
results.Results = nil
results.Truncated = false
}
// Apply in-process max_results cap and compute cursor footer.
var cursorLine string
if maxResults > 0 && len(results.Results) > maxResults {
remaining := len(results.Results) - maxResults
results.Results = results.Results[:maxResults]
results.Truncated = true
nextOffset := offset + maxResults
nextCursor := cursor.Encode(nextOffset, queryHash)
cursorLine = fmt.Sprintf("[cursor: %s, remaining: %d]", nextCursor, remaining)
}
s.logger.Info("search completed",
"pattern", pattern,
"results_count", len(results.Results),
"truncated", results.Truncated,
)
output := s.searcher.FormatResults(results)
verbose := request.GetBool("verbose", false)
opts := search.FormatOptions{
Cluster: cluster,
CursorLine: cursorLine,
Verbose: verbose,
}
output := s.searcher.FormatResultsWithOptions(results, opts)
return mcp.NewToolResultText(output), nil
}
// effectiveInt returns the per-call value if the key is explicitly present in the
// request arguments, otherwise falls back to sessionDefault (if > 0), then builtIn.
func effectiveInt(request mcp.CallToolRequest, key string, sessionDefault, builtIn int) int {
if _, explicit := request.GetArguments()[key]; explicit {
return request.GetInt(key, builtIn)
}
if sessionDefault > 0 {
return sessionDefault
}
return builtIn
}
// effectiveBool returns the per-call value if the key is explicitly present in the
// request arguments, otherwise falls back to sessionDefault (if non-nil), then builtIn.
func effectiveBool(request mcp.CallToolRequest, key string, sessionDefault *bool, builtIn bool) bool {
if _, explicit := request.GetArguments()[key]; explicit {
return request.GetBool(key, builtIn)
}
if sessionDefault != nil {
return *sessionDefault
}
return builtIn
}
// handleFileRead handles the file_read tool.
func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
select {
@@ -70,9 +171,13 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
return mcp.NewToolResultError("request cancelled"), nil
}
// Batch mode: paths[] takes precedence over path
// Batch mode: paths[] takes precedence over path.
// NOTE: batch reads are always inlined — mixing dedup + resource_links is
// too complex and the savings are unclear for multi-file calls.
if paths := request.GetStringSlice("paths", nil); len(paths) > 0 {
var output strings.Builder
// Dedup: track etag -> first path that produced it.
seenEtag := make(map[string]string) // etag -> first path
for i, p := range paths {
if i > 0 {
output.WriteString("\n")
@@ -82,6 +187,16 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
output.WriteString(fmt.Sprintf("--- %s ---\n[error: %s]\n", p, errors.SanitizeError(err)))
continue
}
// Extract etag from result footer for dedup check.
etag := extractEtag(result)
if etag != "" {
if firstPath, seen := seenEtag[etag]; seen {
// Duplicate content: emit pointer instead of full content.
output.WriteString(fmt.Sprintf("--- %s ---\n[duplicate of %s, etag: %s]\n", p, firstPath, etag))
continue
}
seenEtag[etag] = p
}
output.WriteString(fmt.Sprintf("--- %s ---\n%s", p, result))
}
return mcp.NewToolResultText(output.String()), nil
@@ -96,59 +211,223 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
if err != nil {
return mcp.NewToolResultError(errors.SanitizeError(err)), nil
}
// Resource-link threshold check: for single-file reads, if the result
// exceeds the configured threshold, return a ResourceLink instead of
// inlining the content. The client can fetch the resource on demand.
// Bypassed when:
// - force_inline=true
// - max_inline_bytes is set and result fits within it
// - threshold is 0 (disabled)
// - result is already small (skeleton/symbols_only/line-range paths
// produce small output; threshold is on result bytes, not file bytes)
// Determine resource-link threshold: session pref overrides cfg, per-call overrides session.
threshold := s.cfg.ResourceLinkThresholdBytes
if sp := s.sessionPrefs.Load(); sp != nil && sp.ResourceLinkThreshold > 0 {
threshold = sp.ResourceLinkThreshold
}
forceInline := request.GetBool("force_inline", false)
maxInlineBytes := request.GetInt("max_inline_bytes", 0)
if maxInlineBytes > 0 {
threshold = maxInlineBytes
}
if !forceInline && threshold > 0 && len(result) > threshold {
etag := extractEtag(result)
uri := buildReadResourceURI(path, etag)
lineCount := strings.Count(result, "\n")
desc := fmt.Sprintf("etag=%s, size=%d bytes, lines=%d", etag, len(result), lineCount)
mimeType := detectMIMEType(path)
link := mcp.NewResourceLink(uri, path, desc, mimeType)
return &mcp.CallToolResult{
Content: []mcp.Content{link},
}, nil
}
return mcp.NewToolResultText(result), nil
}
// readOneFile reads a single file applying all formatting options from the request.
func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, path string) (string, error) {
// buildReadResourceURI constructs the filepuff://read URI for a file + etag pair.
func buildReadResourceURI(path, etag string) string {
if etag == "" {
return "filepuff://read/" + path
}
return "filepuff://read/" + path + "?etag=" + etag
}
// detectMIMEType returns a best-effort MIME type for the given file path.
func detectMIMEType(path string) string {
ext := strings.ToLower(path)
switch {
case strings.HasSuffix(ext, ".go"):
return "text/x-go"
case strings.HasSuffix(ext, ".ts"), strings.HasSuffix(ext, ".tsx"):
return "text/typescript"
case strings.HasSuffix(ext, ".js"), strings.HasSuffix(ext, ".jsx"):
return "text/javascript"
case strings.HasSuffix(ext, ".py"):
return "text/x-python"
case strings.HasSuffix(ext, ".rs"):
return "text/x-rust"
case strings.HasSuffix(ext, ".md"):
return "text/markdown"
case strings.HasSuffix(ext, ".json"):
return "application/json"
case strings.HasSuffix(ext, ".yaml"), strings.HasSuffix(ext, ".yml"):
return "text/yaml"
case strings.HasSuffix(ext, ".toml"):
return "text/toml"
case strings.HasSuffix(ext, ".html"), strings.HasSuffix(ext, ".htm"):
return "text/html"
case strings.HasSuffix(ext, ".css"):
return "text/css"
case strings.HasSuffix(ext, ".sh"):
return "text/x-sh"
case strings.HasSuffix(ext, ".c"), strings.HasSuffix(ext, ".h"):
return "text/x-c"
case strings.HasSuffix(ext, ".cpp"), strings.HasSuffix(ext, ".cc"), strings.HasSuffix(ext, ".cxx"):
return "text/x-c++"
default:
return "text/plain"
}
}
// lineNumberOpts holds resolved line-numbering preferences for readOneFile.
type lineNumberOpts struct {
noLineNumbers bool
compactLineNums bool
lineInterval int
}
// resolveLineNumberOpts resolves per-call vs session-pref line-number options.
func (s *Server) resolveLineNumberOpts(request mcp.CallToolRequest) lineNumberOpts {
opts := lineNumberOpts{
noLineNumbers: request.GetBool("no_line_numbers", false),
lineInterval: request.GetInt("line_number_interval", 1),
compactLineNums: request.GetBool("compact_line_numbers", false),
}
// Apply session line_numbers pref when no explicit per-call override was supplied.
if sp := s.sessionPrefs.Load(); sp != nil && sp.LineNumbers != "" {
_, hasNoLN := request.GetArguments()["no_line_numbers"]
_, hasCompact := request.GetArguments()["compact_line_numbers"]
_, hasInterval := request.GetArguments()["line_number_interval"]
if !hasNoLN && !hasCompact && !hasInterval {
switch sp.LineNumbers {
case "none":
opts.noLineNumbers = true
opts.compactLineNums = false
case "compact":
opts.noLineNumbers = false
opts.compactLineNums = true
case "full":
opts.noLineNumbers = false
opts.compactLineNums = false
}
}
}
if opts.lineInterval == 0 {
opts.noLineNumbers = true
}
return opts
}
// applyStrip applies strip flags to the selected line range and returns the
// possibly-rewritten lines, new bounds, and a stripped-footer annotation.
func applyStrip(lines []string, lineStart, lineEnd int, stripFlags []parser.StripFlag, path string) (newLines []string, newStart, newEnd int, footer string) {
if len(stripFlags) == 0 {
return lines, lineStart, lineEnd, ""
}
selectedContent := strings.Join(lines[lineStart-1:lineEnd], "\n")
lang := protocol.DetectLanguage(path)
stripped := parser.StripContent(selectedContent, stripFlags, lang)
if len(stripped.Stripped) > 0 {
names := make([]string, len(stripped.Stripped))
for i, f := range stripped.Stripped {
names[i] = string(f)
}
footer = "[stripped: " + strings.Join(names, ", ") + "]\n"
}
newLines = splitLines(stripped.Content)
return newLines, 1, len(newLines), footer
}
// loadFileForRead performs workspace, stat, size, and read checks for a path.
func (s *Server) loadFileForRead(path string) ([]byte, error) {
if !s.cfg.IsPathAllowed(path) {
return "", fmt.Errorf("path is outside workspace root")
return nil, fmt.Errorf("path is outside workspace root")
}
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return "", fmt.Errorf("file not found: %s", path)
return nil, fmt.Errorf("file not found: %s", path)
}
if os.IsPermission(err) {
return "", fmt.Errorf("permission denied: %s", path)
return nil, fmt.Errorf("permission denied: %s", path)
}
s.logger.Warn("file stat error", "path", path, "error", err)
return "", fmt.Errorf("error accessing file")
return nil, fmt.Errorf("error accessing file")
}
if info.Size() > s.cfg.MaxFileSize {
return "", fmt.Errorf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize)
return nil, fmt.Errorf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize)
}
content, err := os.ReadFile(path)
if err != nil {
if os.IsPermission(err) {
return "", fmt.Errorf("permission denied: %s", path)
return nil, fmt.Errorf("permission denied: %s", path)
}
s.logger.Warn("file read error", "path", path, "error", err)
return "", fmt.Errorf("error reading file")
return nil, fmt.Errorf("error reading file")
}
return content, nil
}
// readOneFile reads a single file applying all formatting options from the request.
func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, path string) (string, error) {
content, err := s.loadFileForRead(path)
if err != nil {
return "", err
}
// Compute etag from content hash
etag := fmt.Sprintf("%016x", xxhash.Sum64(content))
// Feature 3: short etag — 8 hex chars (32-bit).
// Accept previous_etag by prefix match so old 16-char etags keep working.
fullHash := fmt.Sprintf("%016x", xxhash.Sum64(content))
etag := fullHash[:8]
// Short-circuit if caller has the current version
if prev := request.GetString("previous_etag", ""); prev != "" && prev == etag {
return fmt.Sprintf("[unchanged, etag: %s]\n", etag), nil
if prev := request.GetString("previous_etag", ""); prev != "" {
// Match: exact 8-char match, old client sent full 16-char etag, or new client sent 8-char prefix of old.
if prev == etag || strings.HasPrefix(fullHash, prev) || strings.HasPrefix(prev, etag) {
return fmt.Sprintf("[unchanged, etag: %s]\n", etag), nil
}
}
// Parse request options
includeAST := request.GetBool("include_ast", false)
symbolsOnly := request.GetBool("symbols_only", false)
symbolName := request.GetString("symbol_name", "")
noLineNumbers := request.GetBool("no_line_numbers", false)
lineInterval := request.GetInt("line_number_interval", 1)
collapseBlank := request.GetBool("collapse_blank_lines", false)
maxLines := request.GetInt("max_lines", 0)
if lineInterval == 0 {
noLineNumbers = true
lnOpts := s.resolveLineNumberOpts(request)
// Feature 1: mode flag — "full" (default) | "skeleton" | "symbols_only".
// symbols_only mode is an alias for include_ast+symbols_only.
mode := request.GetString("mode", "full")
if mode == "symbols_only" {
symbolsOnly = true
includeAST = true
}
// Feature 2: strip — remove selected content classes before line-numbering.
stripRaw := request.GetStringSlice("strip", nil)
var stripFlags []parser.StripFlag
for _, sf := range stripRaw {
stripFlags = append(stripFlags, parser.StripFlag(sf))
}
if symbolsOnly && !includeAST {
return "", fmt.Errorf("symbols_only requires include_ast=true")
}
@@ -157,7 +436,7 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p
lineStart := request.GetInt("line_start", 1)
lineEnd := request.GetInt("line_end", len(lines))
// Symbol-based line range: find the symbol and use its exact bounds
// Symbol-based line range: find the symbol and use its exact bounds.
if symbolName != "" {
symbolKind := protocol.SymbolKind(request.GetString("symbol_kind", ""))
start, end, found := s.resolveSymbolLines(ctx, path, content, symbolName, symbolKind)
@@ -168,7 +447,7 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p
lineEnd = end
}
// Clamp to valid range
// Clamp to valid range.
if lineStart < 1 {
lineStart = 1
}
@@ -181,6 +460,19 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p
var output strings.Builder
// Feature 1: skeleton mode — replace function bodies with { ... }.
if mode == "skeleton" {
skText, _, skErr := parser.SkeletonFile(ctx, s.parser, path, content)
if skErr != nil {
// Fall back to full mode on parse error.
s.logger.Warn("skeleton mode failed, falling back to full", "path", path, "error", skErr)
} else {
output.WriteString(skText)
fmt.Fprintf(&output, "[etag: %s]\n", etag)
return output.String(), nil
}
}
if includeAST {
if summary := s.generateASTSummary(ctx, path, content); summary != "" {
output.WriteString(summary)
@@ -191,13 +483,19 @@ func (s *Server) readOneFile(ctx context.Context, request mcp.CallToolRequest, p
}
if symbolsOnly {
output.WriteString(fmt.Sprintf("[etag: %s]\n", etag))
fmt.Fprintf(&output, "[etag: %s]\n", etag)
return output.String(), nil
}
writeLines(&output, lines, lineStart, lineEnd, maxLines, noLineNumbers, lineInterval, collapseBlank)
// Feature 2: apply strip AFTER line-range selection, BEFORE line numbering.
lines, lineStart, lineEnd, strippedFooter := applyStrip(lines, lineStart, lineEnd, stripFlags, path)
output.WriteString(fmt.Sprintf("[etag: %s]\n", etag))
writeLines(&output, lines, lineStart, lineEnd, maxLines, lnOpts.noLineNumbers, lnOpts.lineInterval, collapseBlank, lnOpts.compactLineNums)
if strippedFooter != "" {
output.WriteString(strippedFooter)
}
fmt.Fprintf(&output, "[etag: %s]\n", etag)
return output.String(), nil
}
@@ -212,7 +510,8 @@ func (s *Server) resolveSymbolLines(ctx context.Context, path string, content []
}
// writeLines writes the selected line range into output, applying all formatting options.
func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, maxLines int, noLineNumbers bool, lineInterval int, collapseBlank bool) {
// compactLineNums=true emits "12│" instead of " 12│ " (no padding, no trailing space).
func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, maxLines int, noLineNumbers bool, lineInterval int, collapseBlank bool, compactLineNums bool) {
effectiveEnd := lineEnd
truncatedCount := 0
if maxLines > 0 && (lineEnd-lineStart+1) > maxLines {
@@ -233,6 +532,13 @@ func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, max
switch {
case noLineNumbers:
output.WriteString(line + "\n")
case compactLineNums:
// Feature 4: compact prefix — "12│content"
if lineInterval <= 1 || lineNum%lineInterval == 0 || i == lineStart-1 || i == effectiveEnd-1 {
fmt.Fprintf(output, "%d│%s\n", lineNum, line)
} else {
fmt.Fprintf(output, "│%s\n", line)
}
case lineInterval <= 1 || lineNum%lineInterval == 0 || i == lineStart-1 || i == effectiveEnd-1:
fmt.Fprintf(output, "%4d│ %s\n", lineNum, line)
default:
@@ -245,6 +551,23 @@ func writeLines(output *strings.Builder, lines []string, lineStart, lineEnd, max
}
}
// extractEtag extracts the etag value from a readOneFile result string.
// Returns empty string if not found.
func extractEtag(result string) string {
// Look for "[etag: XXXXXXXX]" at end of result.
const prefix = "[etag: "
idx := strings.LastIndex(result, prefix)
if idx < 0 {
return ""
}
rest := result[idx+len(prefix):]
close := strings.Index(rest, "]")
if close < 0 {
return ""
}
return rest[:close]
}
// splitLines splits a string into lines.
// For large files (> 1MB), uses bufio.Scanner which is more memory efficient.
// For smaller files, uses simple string split which is faster.
+909
View File
@@ -0,0 +1,909 @@
package server
import (
"context"
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"github.com/mark3labs/mcp-go/mcp"
)
// newTestServer creates a minimal server pointing at tmpDir.
func newTestServer(t *testing.T, tmpDir string) *Server {
t.Helper()
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
MaxFileSize: 10 * 1024 * 1024, // 10 MB — required for file reads to succeed
MaxParseSize: 10 * 1024 * 1024,
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
return srv
}
// callRead calls handleFileRead with the given args map and returns the text content.
func callRead(t *testing.T, srv *Server, args map[string]interface{}) string {
t.Helper()
req := mcp.CallToolRequest{}
req.Params.Arguments = args
result, err := srv.handleFileRead(context.Background(), req)
if err != nil {
t.Fatalf("handleFileRead error: %v", err)
}
if result == nil {
t.Fatal("handleFileRead returned nil")
}
if len(result.Content) == 0 {
t.Fatal("handleFileRead returned empty content")
}
tc, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatal("handleFileRead did not return TextContent")
}
return tc.Text
}
// callReadResult calls handleFileRead and returns the raw CallToolResult (not just text).
func callReadResult(t *testing.T, srv *Server, args map[string]interface{}) *mcp.CallToolResult {
t.Helper()
req := mcp.CallToolRequest{}
req.Params.Arguments = args
result, err := srv.handleFileRead(context.Background(), req)
if err != nil {
t.Fatalf("handleFileRead error: %v", err)
}
if result == nil {
t.Fatal("handleFileRead returned nil")
}
return result
}
// newTestServerWithThreshold creates a server with a custom ResourceLinkThresholdBytes.
func newTestServerWithThreshold(t *testing.T, tmpDir string, thresholdBytes int) *Server {
t.Helper()
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
MaxFileSize: 10 * 1024 * 1024,
MaxParseSize: 10 * 1024 * 1024,
ResourceLinkThresholdBytes: thresholdBytes,
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
return srv
}
// writeFile writes content to a file in tmpDir and returns its absolute path.
func writeFile(t *testing.T, dir, name, content string) string {
t.Helper()
p := filepath.Join(dir, name)
if err := os.WriteFile(p, []byte(content), 0600); err != nil {
t.Fatalf("WriteFile(%s): %v", name, err)
}
return p
}
// ---- Feature 1: skeleton mode ----
func TestSkeletonModeGo(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := `package main
// Hello says hello
func Hello() {
println("Hello, World!")
println("more body")
}
func Add(a, b int) int {
return a + b
}
`
f := writeFile(t, tmpDir, "test.go", goSrc)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"mode": "skeleton",
})
// Should contain function signatures
if !strings.Contains(out, "func Hello()") {
t.Errorf("skeleton output missing Hello signature, got:\n%s", out)
}
if !strings.Contains(out, "func Add(") {
t.Errorf("skeleton output missing Add signature, got:\n%s", out)
}
// Should NOT contain body contents
if strings.Contains(out, `println("more body")`) {
t.Errorf("skeleton output should not contain body contents, got:\n%s", out)
}
// Should contain placeholder
if !strings.Contains(out, "{ ... }") {
t.Errorf("skeleton output missing { ... } placeholder, got:\n%s", out)
}
// Should contain etag footer
if !strings.Contains(out, "[etag:") {
t.Errorf("skeleton output missing etag footer, got:\n%s", out)
}
}
func TestSkeletonModeFullFlagAlias(t *testing.T) {
// mode="full" should behave identically to not specifying mode
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := "package main\nfunc F() { println(1) }\n"
f := writeFile(t, tmpDir, "test.go", goSrc)
outFull := callRead(t, srv, map[string]interface{}{"path": f, "mode": "full"})
outDefault := callRead(t, srv, map[string]interface{}{"path": f})
// Both should have same content (etag will be same, line content same)
if outFull != outDefault {
t.Errorf("mode=full differs from default\nfull: %q\ndefault: %q", outFull, outDefault)
}
}
func TestSkeletonModeSymbolsOnlyAlias(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := "package main\nfunc F() { println(1) }\n"
f := writeFile(t, tmpDir, "test.go", goSrc)
// mode=symbols_only should return symbols summary (needs include_ast implicitly)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"mode": "symbols_only",
})
// Should contain etag but NOT the function body
if !strings.Contains(out, "[etag:") {
t.Errorf("symbols_only output missing etag, got:\n%s", out)
}
if strings.Contains(out, "println") {
t.Errorf("symbols_only should not contain body, got:\n%s", out)
}
}
func TestSkeletonModeTypeScript(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
tsSrc := `// a function
function greet(name: string): string {
return "Hello " + name;
}
class Greeter {
greet(name: string) {
return "hi " + name;
}
}
`
f := writeFile(t, tmpDir, "test.ts", tsSrc)
out := callRead(t, srv, map[string]interface{}{"path": f, "mode": "skeleton"})
if !strings.Contains(out, "function greet") {
t.Errorf("TS skeleton missing function signature, got:\n%s", out)
}
if !strings.Contains(out, "{ ... }") {
t.Errorf("TS skeleton missing placeholder, got:\n%s", out)
}
if strings.Contains(out, `"Hello " + name`) {
t.Errorf("TS skeleton should not contain body, got:\n%s", out)
}
}
func TestSkeletonModePython(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
pySrc := `def greet(name):
print("Hello " + name)
print("extra line")
class Foo:
def bar(self):
return 42
`
f := writeFile(t, tmpDir, "test.py", pySrc)
out := callRead(t, srv, map[string]interface{}{"path": f, "mode": "skeleton"})
if !strings.Contains(out, "def greet") {
t.Errorf("Python skeleton missing greet signature, got:\n%s", out)
}
if strings.Contains(out, "extra line") {
t.Errorf("Python skeleton should not contain body, got:\n%s", out)
}
// Python uses "..." as placeholder
if !strings.Contains(out, "...") {
t.Errorf("Python skeleton missing ... placeholder, got:\n%s", out)
}
}
func TestSkeletonModeRust(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
rsSrc := `fn add(a: i32, b: i32) -> i32 {
let result = a + b;
result
}
struct Foo {
x: i32,
}
`
f := writeFile(t, tmpDir, "test.rs", rsSrc)
out := callRead(t, srv, map[string]interface{}{"path": f, "mode": "skeleton"})
if !strings.Contains(out, "fn add(") {
t.Errorf("Rust skeleton missing fn signature, got:\n%s", out)
}
if !strings.Contains(out, "{ ... }") {
t.Errorf("Rust skeleton missing placeholder, got:\n%s", out)
}
if strings.Contains(out, "let result") {
t.Errorf("Rust skeleton should not contain body, got:\n%s", out)
}
}
// ---- Feature 2: strip flag ----
func TestStripImportsGo(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := `package main
import (
"fmt"
"os"
)
func main() {
fmt.Println("hello")
}
`
f := writeFile(t, tmpDir, "test.go", goSrc)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"strip": []interface{}{"imports"},
})
if strings.Contains(out, `"fmt"`) {
t.Errorf("strip=imports should remove import block, got:\n%s", out)
}
if !strings.Contains(out, "func main") {
t.Errorf("strip=imports should keep function, got:\n%s", out)
}
if !strings.Contains(out, "[stripped: imports]") {
t.Errorf("strip footer missing, got:\n%s", out)
}
}
func TestStripLicenseGoBlockComment(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := `/* Copyright 2024 Acme Corp. All rights reserved.
License: MIT
*/
package main
func main() {}
`
f := writeFile(t, tmpDir, "main.go", goSrc)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"strip": []interface{}{"license"},
})
if strings.Contains(out, "Copyright") {
t.Errorf("strip=license should remove license comment, got:\n%s", out)
}
if !strings.Contains(out, "func main") {
t.Errorf("strip=license should keep code, got:\n%s", out)
}
if !strings.Contains(out, "[stripped: license]") {
t.Errorf("license strip footer missing, got:\n%s", out)
}
}
func TestStripBlockCommentsGo(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := `package main
/* This is a block comment
spanning multiple lines */
func main() {
/* inline block */
println("hi")
}
`
f := writeFile(t, tmpDir, "test.go", goSrc)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"strip": []interface{}{"block_comments"},
})
if strings.Contains(out, "This is a block comment") {
t.Errorf("strip=block_comments should remove block comments, got:\n%s", out)
}
if strings.Contains(out, "inline block") {
t.Errorf("strip=block_comments should remove inline block comment, got:\n%s", out)
}
if !strings.Contains(out, "func main") {
t.Errorf("strip=block_comments should keep code, got:\n%s", out)
}
if !strings.Contains(out, "[stripped: block_comments]") {
t.Errorf("block_comments strip footer missing, got:\n%s", out)
}
}
func TestStripMultipleFlags(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := `/* Copyright 2024. License: MIT */
package main
import "fmt"
func main() {
fmt.Println("hello")
}
`
f := writeFile(t, tmpDir, "main.go", goSrc)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"strip": []interface{}{"license", "imports"},
})
if strings.Contains(out, "Copyright") {
t.Errorf("license not stripped, got:\n%s", out)
}
if strings.Contains(out, `"fmt"`) {
t.Errorf("imports not stripped, got:\n%s", out)
}
if !strings.Contains(out, "[stripped:") {
t.Errorf("strip footer missing, got:\n%s", out)
}
}
func TestStripNoRemovalProducesNoFooter(t *testing.T) {
// A file with no imports: strip=imports should not add footer
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := "package main\nfunc main() {}\n"
f := writeFile(t, tmpDir, "test.go", goSrc)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"strip": []interface{}{"imports"},
})
if strings.Contains(out, "[stripped:") {
t.Errorf("should not have stripped footer when nothing removed, got:\n%s", out)
}
}
func TestStripImportsPython(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
pySrc := `import os
from sys import argv
def main():
print(argv[0])
`
f := writeFile(t, tmpDir, "test.py", pySrc)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"strip": []interface{}{"imports"},
})
if strings.Contains(out, "import os") {
t.Errorf("Python imports not stripped, got:\n%s", out)
}
if strings.Contains(out, "from sys") {
t.Errorf("Python from-import not stripped, got:\n%s", out)
}
if !strings.Contains(out, "def main") {
t.Errorf("Python function missing after strip, got:\n%s", out)
}
}
// ---- Feature 3: short etag (8 hex chars) ----
func TestEtagIs8Chars(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
f := writeFile(t, tmpDir, "test.go", "package main\n")
out := callRead(t, srv, map[string]interface{}{"path": f})
// Find "[etag: XXXXXXXX]"
idx := strings.Index(out, "[etag: ")
if idx < 0 {
t.Fatalf("no etag in output: %q", out)
}
rest := out[idx+7:]
end := strings.Index(rest, "]")
if end < 0 {
t.Fatalf("malformed etag in output: %q", out)
}
etagVal := rest[:end]
if len(etagVal) != 8 {
t.Errorf("etag should be 8 hex chars, got %d chars: %q", len(etagVal), etagVal)
}
// Validate hex
for _, c := range etagVal {
isDigit := c >= '0' && c <= '9'
isHexLower := c >= 'a' && c <= 'f'
if !isDigit && !isHexLower {
t.Errorf("etag contains non-hex char %q in %q", c, etagVal)
}
}
}
func TestEtagPreviousEtagShortCircuit(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
f := writeFile(t, tmpDir, "test.go", "package main\n")
// First read: get the etag
out1 := callRead(t, srv, map[string]interface{}{"path": f})
idx := strings.Index(out1, "[etag: ")
etag := out1[idx+7 : idx+7+8]
// Second read with same etag: should short-circuit
out2 := callRead(t, srv, map[string]interface{}{
"path": f,
"previous_etag": etag,
})
if !strings.Contains(out2, "[unchanged, etag:") {
t.Errorf("expected [unchanged, etag:] for same etag, got: %q", out2)
}
}
func TestEtagOldLongEtagStillWorks(t *testing.T) {
// Simulate old client sending 16-char etag: should still short-circuit via prefix match.
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
f := writeFile(t, tmpDir, "test.go", "package main\n")
// Get the 8-char etag
out1 := callRead(t, srv, map[string]interface{}{"path": f})
idx := strings.Index(out1, "[etag: ")
shortEtag := out1[idx+7 : idx+7+8]
// Construct a fake 16-char etag that starts with the short one
fakeOldEtag := shortEtag + "00000000"
out2 := callRead(t, srv, map[string]interface{}{
"path": f,
"previous_etag": fakeOldEtag,
})
if !strings.Contains(out2, "[unchanged, etag:") {
t.Errorf("old 16-char etag should still short-circuit, got: %q", out2)
}
}
func TestEtagDifferentFileChanges(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
f := writeFile(t, tmpDir, "test.go", "package main\n")
out1 := callRead(t, srv, map[string]interface{}{"path": f})
idx := strings.Index(out1, "[etag: ")
etag1 := out1[idx+7 : idx+7+8]
// Modify the file
if err := os.WriteFile(f, []byte("package main\n// changed\n"), 0600); err != nil {
t.Fatal(err)
}
// Read again with old etag: should NOT short-circuit
out2 := callRead(t, srv, map[string]interface{}{
"path": f,
"previous_etag": etag1,
})
if strings.Contains(out2, "[unchanged") {
t.Errorf("modified file should not return unchanged, got: %q", out2)
}
}
// ---- Feature 4: compact_line_numbers ----
func TestCompactLineNumbers(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
content := "line one\nline two\nline three\n"
f := writeFile(t, tmpDir, "test.txt", content)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"compact_line_numbers": true,
})
// Should have "1│" not " 1│ "
if !strings.Contains(out, "1│line one") {
t.Errorf("compact prefix not found, got:\n%s", out)
}
// Should NOT have padded format
if strings.Contains(out, " 1│ line one") {
t.Errorf("compact should not have padded prefix, got:\n%s", out)
}
}
func TestCompactLineNumbersOffByDefault(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
content := "line one\nline two\n"
f := writeFile(t, tmpDir, "test.txt", content)
// Default: no compact_line_numbers
out := callRead(t, srv, map[string]interface{}{"path": f})
// Should have padded format
if !strings.Contains(out, " 1│ line one") {
t.Errorf("default should have padded format, got:\n%s", out)
}
}
func TestCompactLineNumbersWithInterval(t *testing.T) {
// compact_line_numbers + line_number_interval should work together
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
var sb strings.Builder
for i := 1; i <= 10; i++ {
sb.WriteString("line\n")
}
f := writeFile(t, tmpDir, "test.txt", sb.String())
out := callRead(t, srv, map[string]interface{}{
"path": f,
"compact_line_numbers": true,
"line_number_interval": 5,
})
// Line 5 should have number prefix
if !strings.Contains(out, "5│line") {
t.Errorf("compact+interval: line 5 should have number, got:\n%s", out)
}
// Line 1 (first) should have number
if !strings.Contains(out, "1│line") {
t.Errorf("compact+interval: line 1 should have number, got:\n%s", out)
}
// Line 10 (last) should have number
if !strings.Contains(out, "10│line") {
t.Errorf("compact+interval: line 10 should have number, got:\n%s", out)
}
// Non-interval line should have bare │ prefix (no number)
if !strings.Contains(out, "│line") {
t.Errorf("compact+interval: non-interval lines should have bare │, got:\n%s", out)
}
}
func TestCompactLineNumbersWithNoLineNumbers(t *testing.T) {
// compact_line_numbers + no_line_numbers: no_line_numbers wins
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
f := writeFile(t, tmpDir, "test.txt", "line one\n")
out := callRead(t, srv, map[string]interface{}{
"path": f,
"compact_line_numbers": true,
"no_line_numbers": true,
})
// Should have NO prefix at all
if strings.Contains(out, "│") {
t.Errorf("no_line_numbers should suppress all prefixes, got:\n%s", out)
}
}
// ---- Backward compatibility: existing behavior unchanged ----
func TestDefaultBehaviorUnchanged(t *testing.T) {
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := `package main
func Hello() {
println("hello")
}
`
f := writeFile(t, tmpDir, "test.go", goSrc)
out := callRead(t, srv, map[string]interface{}{"path": f})
// All lines present
if !strings.Contains(out, `println("hello")`) {
t.Errorf("full mode missing body, got:\n%s", out)
}
// Padded line numbers
if !strings.Contains(out, " 1│ package main") {
t.Errorf("default should have padded line numbers, got:\n%s", out)
}
// etag present
if !strings.Contains(out, "[etag:") {
t.Errorf("missing etag footer, got:\n%s", out)
}
}
func TestSymbolsOnlyFlagStillWorks(t *testing.T) {
// Old symbols_only=true + include_ast=true should still work
tmpDir := t.TempDir()
srv := newTestServer(t, tmpDir)
goSrc := "package main\nfunc Hello() { println(1) }\n"
f := writeFile(t, tmpDir, "test.go", goSrc)
out := callRead(t, srv, map[string]interface{}{
"path": f,
"include_ast": true,
"symbols_only": true,
})
if strings.Contains(out, "println") {
t.Errorf("symbols_only should suppress body, got:\n%s", out)
}
if !strings.Contains(out, "[etag:") {
t.Errorf("symbols_only missing etag, got:\n%s", out)
}
}
// ---- Feature: resource_link for big reads ----
func TestResourceLinkThresholdTrip(t *testing.T) {
// When result bytes > threshold, handleFileRead returns a ResourceLink content block.
tmpDir := t.TempDir()
// Low threshold (10 bytes) guarantees even a tiny file trips it.
srv := newTestServerWithThreshold(t, tmpDir, 10)
f := writeFile(t, tmpDir, "big.txt", strings.Repeat("x", 200))
result := callReadResult(t, srv, map[string]interface{}{"path": f})
if len(result.Content) == 0 {
t.Fatal("expected content, got none")
}
link, ok := result.Content[0].(mcp.ResourceLink)
if !ok {
t.Fatalf("expected ResourceLink content, got %T", result.Content[0])
}
if !strings.HasPrefix(link.URI, "filepuff://read/") {
t.Errorf("ResourceLink URI should start with filepuff://read/, got: %q", link.URI)
}
if link.Name != f {
t.Errorf("ResourceLink Name should be file path %q, got %q", f, link.Name)
}
if !strings.Contains(link.Description, "etag=") {
t.Errorf("ResourceLink Description should contain etag=, got %q", link.Description)
}
if !strings.Contains(link.Description, "size=") {
t.Errorf("ResourceLink Description should contain size=, got %q", link.Description)
}
if !strings.Contains(link.Description, "lines=") {
t.Errorf("ResourceLink Description should contain lines=, got %q", link.Description)
}
}
func TestResourceLinkForceInlineBypass(t *testing.T) {
// force_inline=true must always return TextContent regardless of threshold.
tmpDir := t.TempDir()
srv := newTestServerWithThreshold(t, tmpDir, 1) // threshold = 1 byte, always trips
f := writeFile(t, tmpDir, "test.txt", "hello world")
result := callReadResult(t, srv, map[string]interface{}{
"path": f,
"force_inline": true,
})
if len(result.Content) == 0 {
t.Fatal("expected content, got none")
}
tc, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatalf("force_inline=true should return TextContent, got %T", result.Content[0])
}
if !strings.Contains(tc.Text, "hello world") {
t.Errorf("force_inline result should contain file content, got: %q", tc.Text)
}
}
func TestResourceLinkMaxInlineBytesOverride(t *testing.T) {
// max_inline_bytes overrides server threshold per-call.
tmpDir := t.TempDir()
// Server threshold = 1 byte (always trips), but max_inline_bytes allows bigger.
srv := newTestServerWithThreshold(t, tmpDir, 1)
content := strings.Repeat("a", 50) // 50 bytes
f := writeFile(t, tmpDir, "test.txt", content)
// max_inline_bytes=100 > 50 bytes result → should inline
result := callReadResult(t, srv, map[string]interface{}{
"path": f,
"max_inline_bytes": 100,
})
if len(result.Content) == 0 {
t.Fatal("expected content, got none")
}
_, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatalf("max_inline_bytes=100 with 50-byte file should return TextContent, got %T", result.Content[0])
}
}
func TestResourceLinkStaleEtagRejection(t *testing.T) {
// handleReadResource should reject fetch when file has changed since link was issued.
tmpDir := t.TempDir()
srv := newTestServerWithThreshold(t, tmpDir, 1) // always trips
f := writeFile(t, tmpDir, "stale.txt", "original content")
// Get a ResourceLink — captures etag of "original content"
result := callReadResult(t, srv, map[string]interface{}{"path": f})
link, ok := result.Content[0].(mcp.ResourceLink)
if !ok {
t.Fatalf("expected ResourceLink, got %T", result.Content[0])
}
// URI contains ?etag=<hash-of-original>
if !strings.Contains(link.URI, "?etag=") {
t.Fatalf("expected etag in URI, got %q", link.URI)
}
// Overwrite the file with new content.
if err := os.WriteFile(f, []byte("modified content — different"), 0600); err != nil {
t.Fatal(err)
}
// Fetch the resource using the stale URI — should error.
req := mcp.ReadResourceRequest{}
req.Params.URI = link.URI
_, err := srv.handleReadResource(req)
if err == nil {
t.Fatal("expected error for stale etag, got nil")
}
if !strings.Contains(err.Error(), "file changed") {
t.Errorf("error should mention 'file changed', got: %v", err)
}
}
func TestResourceLinkBelowThresholdInlines(t *testing.T) {
// When result is small (below threshold), always inline regardless of threshold setting.
tmpDir := t.TempDir()
// Large threshold — small file should be inlined.
srv := newTestServerWithThreshold(t, tmpDir, 64*1024)
f := writeFile(t, tmpDir, "small.txt", "tiny")
result := callReadResult(t, srv, map[string]interface{}{"path": f})
if len(result.Content) == 0 {
t.Fatal("expected content, got none")
}
_, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatalf("small file should return TextContent, got %T", result.Content[0])
}
}
func TestResourceLinkThresholdZeroDisabled(t *testing.T) {
// threshold=0 disables the feature entirely — always inline.
tmpDir := t.TempDir()
srv := newTestServerWithThreshold(t, tmpDir, 0)
f := writeFile(t, tmpDir, "test.txt", strings.Repeat("z", 10000))
result := callReadResult(t, srv, map[string]interface{}{"path": f})
if len(result.Content) == 0 {
t.Fatal("expected content")
}
_, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatalf("threshold=0 should always inline, got %T", result.Content[0])
}
}
func TestHandleReadResource_ValidFetch(t *testing.T) {
// handleReadResource fetches file content when etag matches.
tmpDir := t.TempDir()
srv := newTestServerWithThreshold(t, tmpDir, 1)
f := writeFile(t, tmpDir, "fetch.txt", "fetch me please")
// Trigger a ResourceLink to get a valid URI with correct etag.
toolResult := callReadResult(t, srv, map[string]interface{}{"path": f})
link, ok := toolResult.Content[0].(mcp.ResourceLink)
if !ok {
t.Fatalf("expected ResourceLink, got %T", toolResult.Content[0])
}
req := mcp.ReadResourceRequest{}
req.Params.URI = link.URI
contents, err := srv.handleReadResource(req)
if err != nil {
t.Fatalf("handleReadResource error: %v", err)
}
if len(contents) == 0 {
t.Fatal("expected resource contents, got none")
}
tc, ok := contents[0].(mcp.TextResourceContents)
if !ok {
t.Fatalf("expected TextResourceContents, got %T", contents[0])
}
if !strings.Contains(tc.Text, "fetch me please") {
t.Errorf("resource contents should include file content, got: %q", tc.Text)
}
}
func TestResourceLinkMIMEType(t *testing.T) {
// Verify MIME types for common extensions.
tmpDir := t.TempDir()
srv := newTestServerWithThreshold(t, tmpDir, 1)
cases := []struct {
name string
content string
wantMIME string
}{
{"test.go", "package main\n", "text/x-go"},
{"test.py", "# py\n", "text/x-python"},
{"test.ts", "// ts\n", "text/typescript"},
{"test.json", "{}\n", "application/json"},
{"test.md", "# hi\n", "text/markdown"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
f := writeFile(t, tmpDir, c.name, c.content)
result := callReadResult(t, srv, map[string]interface{}{"path": f})
link, ok := result.Content[0].(mcp.ResourceLink)
if !ok {
t.Fatalf("expected ResourceLink for %s, got %T", c.name, result.Content[0])
}
if link.MIMEType != c.wantMIME {
t.Errorf("%s: MIMEType = %q, want %q", c.name, link.MIMEType, c.wantMIME)
}
})
}
}
+131 -68
View File
@@ -13,8 +13,14 @@ import (
"github.com/mark3labs/mcp-go/mcp"
)
// handleSymbolAt handles the symbol_at tool.
func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// handleLSPQuery is the unified dispatcher for all LSP operations.
// action must be one of: "hover", "definition", "references".
func (s *Server) handleLSPQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
action, err := request.RequireString("action")
if err != nil {
return mcp.NewToolResultError("action is required (hover | definition | references)"), nil
}
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
@@ -30,16 +36,51 @@ func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest
return mcp.NewToolResultError("column must be positive"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
// Try LSP hover
verbose := request.GetBool("verbose", false)
switch action {
case "hover":
if _, ok := request.GetArguments()["include_declaration"]; ok {
return mcp.NewToolResultError("include_declaration is only valid for action=references"), nil
}
if _, ok := request.GetArguments()["compact"]; ok {
return mcp.NewToolResultError("compact is only valid for action=references"), nil
}
return s.lspHover(ctx, file, line, col, verbose)
case "definition":
if _, ok := request.GetArguments()["include_declaration"]; ok {
return mcp.NewToolResultError("include_declaration is only valid for action=references"), nil
}
if _, ok := request.GetArguments()["compact"]; ok {
return mcp.NewToolResultError("compact is only valid for action=references"), nil
}
return s.lspDefinition(ctx, file, line, col, verbose)
case "references":
includeDecl := request.GetBool("include_declaration", true)
// compact: explicit call-time > session compact_refs pref > false
var prefsCompact *bool
if sp := s.sessionPrefs.Load(); sp != nil {
prefsCompact = sp.CompactRefs
}
compact := effectiveBool(request, "compact", prefsCompact, false)
return s.lspReferences(ctx, file, line, col, includeDecl, compact, verbose)
default:
return mcp.NewToolResultError(fmt.Sprintf("unknown action %q: must be hover | definition | references", action)), nil
}
}
// lspHover performs hover (symbol info) for the given position.
func (s *Server) lspHover(ctx context.Context, file string, line, col int, verbose bool) (*mcp.CallToolResult, error) {
hover, err := s.lspManager.Hover(ctx, file, line, col)
if err != nil {
// Fall back to AST-based info
return s.handleSymbolAtFallback(ctx, file, line, col)
return s.handleSymbolAtFallback(ctx, file, line, col, verbose)
}
if hover == nil {
@@ -47,14 +88,16 @@ func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest
}
var output strings.Builder
output.WriteString("**Symbol Information**\n\n")
if verbose {
output.WriteString("**Symbol Information**\n\n")
}
output.WriteString(hover.Contents.Value)
return mcp.NewToolResultText(output.String()), nil
}
// handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable.
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) {
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int, verbose bool) (*mcp.CallToolResult, error) {
content, err := os.ReadFile(file)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %s", errors.SanitizeError(err))), nil
@@ -71,35 +114,17 @@ func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line,
}
var output strings.Builder
output.WriteString("**Symbol Information** (AST fallback)\n\n")
if verbose {
output.WriteString("**Symbol Information** (AST fallback)\n\n")
}
output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type()))
output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content)))
return mcp.NewToolResultText(output.String()), nil
}
// handleFindDefinition handles the find_definition tool.
func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
// lspDefinition finds the definition of the symbol at the given position.
func (s *Server) lspDefinition(ctx context.Context, file string, line, col int, verbose bool) (*mcp.CallToolResult, error) {
locations, err := s.lspManager.Definition(ctx, file, line, col)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %s", errors.SanitizeError(err))), nil
@@ -110,13 +135,14 @@ func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolR
}
var output strings.Builder
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
if verbose {
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
}
for _, loc := range locations {
filePath := lsp.URIToFile(loc.URI)
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
// Try to read a preview snippet
preview := s.readFilePreview(filePath, loc.Range.Start.Line+1, 3)
if preview != "" {
output.WriteString("```\n")
@@ -129,30 +155,8 @@ func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolR
return mcp.NewToolResultText(output.String()), nil
}
// handleFindReferences handles the find_references tool.
func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
includeDecl := request.GetBool("include_declaration", true)
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
// lspReferences finds all references to the symbol at the given position.
func (s *Server) lspReferences(ctx context.Context, file string, line, col int, includeDecl, compact, verbose bool) (*mcp.CallToolResult, error) {
locations, err := s.lspManager.References(ctx, file, line, col, includeDecl)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %s", errors.SanitizeError(err))), nil
@@ -162,25 +166,84 @@ func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolR
return mcp.NewToolResultText("No references found."), nil
}
var output strings.Builder
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations)))
// Group by file
// Group by file, preserving encounter order.
fileGroups := make(map[string][]lsp.Location)
fileOrder := make([]string, 0)
for _, loc := range locations {
filePath := lsp.URIToFile(loc.URI)
if _, seen := fileGroups[filePath]; !seen {
fileOrder = append(fileOrder, filePath)
}
fileGroups[filePath] = append(fileGroups[filePath], loc)
}
for filePath, locs := range fileGroups {
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
for _, loc := range locs {
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
}
output.WriteString("\n")
return mcp.NewToolResultText(formatReferences(fileGroups, fileOrder, len(locations), compact, verbose)), nil
}
// formatReferences formats grouped reference locations as a string.
// compact=false: verbose multi-line format with L{line}:{col} per entry.
// compact=true: one line per file — file:[line:col, ...] (N), with same-line
// columns collapsed to line:{col1,col2,...}.
func formatReferences(fileGroups map[string][]lsp.Location, fileOrder []string, total int, compact bool, verbose bool) string {
var output strings.Builder
if verbose {
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", total))
}
return mcp.NewToolResultText(output.String()), nil
for _, filePath := range fileOrder {
locs := fileGroups[filePath]
if compact {
output.WriteString(formatReferencesCompact(filePath, locs))
} else {
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
for _, loc := range locs {
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
}
output.WriteString("\n")
}
}
return output.String()
}
// formatReferencesCompact formats one file's references as a single compact line.
// Same-line references are collapsed: 12:{5,8} instead of 12:5, 12:8.
func formatReferencesCompact(filePath string, locs []lsp.Location) string {
// Build ordered line->col map preserving encounter order per line.
type lineEntry struct {
lineNum int
cols []int
}
lineMap := make(map[int]*lineEntry)
lineOrder := make([]int, 0, len(locs))
for _, loc := range locs {
ln := loc.Range.Start.Line + 1
col := loc.Range.Start.Character + 1
if e, ok := lineMap[ln]; ok {
e.cols = append(e.cols, col)
} else {
lineMap[ln] = &lineEntry{lineNum: ln, cols: []int{col}}
lineOrder = append(lineOrder, ln)
}
}
// Build the bracket contents.
parts := make([]string, 0, len(lineOrder))
for _, ln := range lineOrder {
e := lineMap[ln]
if len(e.cols) == 1 {
parts = append(parts, fmt.Sprintf("%d:%d", ln, e.cols[0]))
} else {
colStrs := make([]string, len(e.cols))
for i, c := range e.cols {
colStrs[i] = fmt.Sprintf("%d", c)
}
parts = append(parts, fmt.Sprintf("%d:{%s}", ln, strings.Join(colStrs, ",")))
}
}
return fmt.Sprintf("%s:[%s] (%d)\n", filePath, strings.Join(parts, ", "), len(locs))
}
// readFilePreview reads a few lines from a file around the given line.
+181
View File
@@ -0,0 +1,181 @@
package server
import (
"strings"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/lsp"
)
// makeLocation builds an lsp.Location with a file:// URI.
func makeLocation(file string, line, col int) lsp.Location {
return lsp.Location{
URI: "file://" + file,
Range: lsp.Range{
Start: lsp.Position{Line: line - 1, Character: col - 1},
End: lsp.Position{Line: line - 1, Character: col - 1},
},
}
}
// groupLocations is a helper that mirrors the grouping logic in lspReferences.
func groupLocations(locations []lsp.Location) (map[string][]lsp.Location, []string) {
fileGroups := make(map[string][]lsp.Location)
fileOrder := make([]string, 0)
for _, loc := range locations {
filePath := lsp.URIToFile(loc.URI)
if _, seen := fileGroups[filePath]; !seen {
fileOrder = append(fileOrder, filePath)
}
fileGroups[filePath] = append(fileGroups[filePath], loc)
}
return fileGroups, fileOrder
}
// TestFormatReferencesVerbose verifies the default (non-compact) format is unchanged.
func TestFormatReferencesVerbose(t *testing.T) {
locs := []lsp.Location{
makeLocation("/a/foo.go", 12, 5),
makeLocation("/a/foo.go", 13, 8),
makeLocation("/a/bar.go", 15, 1),
}
groups, order := groupLocations(locs)
out := formatReferences(groups, order, len(locs), false, true)
if !strings.Contains(out, "Found 3 reference(s):") {
t.Errorf("missing header, got:\n%s", out)
}
if !strings.Contains(out, "**") {
t.Errorf("verbose format should use **file** markers, got:\n%s", out)
}
if !strings.Contains(out, "L12:5") {
t.Errorf("missing L12:5 in verbose output, got:\n%s", out)
}
if !strings.Contains(out, "L13:8") {
t.Errorf("missing L13:8 in verbose output, got:\n%s", out)
}
if !strings.Contains(out, "L15:1") {
t.Errorf("missing L15:1 in verbose output, got:\n%s", out)
}
}
// TestFormatReferencesCompactBasic verifies compact output for distinct lines.
func TestFormatReferencesCompactBasic(t *testing.T) {
locs := []lsp.Location{
makeLocation("/a/foo.go", 12, 5),
makeLocation("/a/foo.go", 13, 8),
makeLocation("/a/foo.go", 15, 12),
}
groups, order := groupLocations(locs)
out := formatReferences(groups, order, len(locs), true, true)
if !strings.Contains(out, "Found 3 reference(s):") {
t.Errorf("missing header, got:\n%s", out)
}
// Should contain "foo.go:[12:5, 13:8, 15:12] (3)"
if !strings.Contains(out, "12:5") {
t.Errorf("missing 12:5 in compact output, got:\n%s", out)
}
if !strings.Contains(out, "13:8") {
t.Errorf("missing 13:8 in compact output, got:\n%s", out)
}
if !strings.Contains(out, "15:12") {
t.Errorf("missing 15:12 in compact output, got:\n%s", out)
}
if !strings.Contains(out, "(3)") {
t.Errorf("missing (3) count, got:\n%s", out)
}
// Compact format must NOT use ** markers
if strings.Contains(out, "**") {
t.Errorf("compact format must not use ** markers, got:\n%s", out)
}
// Must not have L prefix
if strings.Contains(out, "L12") {
t.Errorf("compact format must not have L prefix, got:\n%s", out)
}
}
// TestFormatReferencesCompactSameLineCollapse verifies same-line columns collapse.
func TestFormatReferencesCompactSameLineCollapse(t *testing.T) {
locs := []lsp.Location{
makeLocation("/a/foo.go", 12, 5),
makeLocation("/a/foo.go", 12, 8),
makeLocation("/a/foo.go", 15, 3),
}
groups, order := groupLocations(locs)
out := formatReferences(groups, order, len(locs), true, false)
// Line 12 has two refs → should be collapsed: 12:{5,8}
if !strings.Contains(out, "12:{5,8}") {
t.Errorf("same-line refs should collapse to 12:{5,8}, got:\n%s", out)
}
// Line 15 is single → 15:3
if !strings.Contains(out, "15:3") {
t.Errorf("single ref on line 15 should be 15:3, got:\n%s", out)
}
}
// TestFormatReferencesCompactMultiFile verifies compact output across multiple files.
func TestFormatReferencesCompactMultiFile(t *testing.T) {
locs := []lsp.Location{
makeLocation("/a/foo.go", 5, 1),
makeLocation("/b/bar.go", 10, 3),
makeLocation("/b/bar.go", 10, 7),
}
groups, order := groupLocations(locs)
out := formatReferences(groups, order, len(locs), true, true)
if !strings.Contains(out, "Found 3 reference(s):") {
t.Errorf("missing header, got:\n%s", out)
}
// foo.go: one ref
if !strings.Contains(out, "5:1") {
t.Errorf("missing 5:1 for foo.go, got:\n%s", out)
}
// bar.go: two refs on same line → collapsed
if !strings.Contains(out, "10:{3,7}") {
t.Errorf("missing 10:{3,7} collapse for bar.go, got:\n%s", out)
}
}
// TestFormatReferencesCompactSingleRef verifies single-reference compact output.
func TestFormatReferencesCompactSingleRef(t *testing.T) {
locs := []lsp.Location{
makeLocation("/a/only.go", 7, 2),
}
groups, order := groupLocations(locs)
out := formatReferences(groups, order, len(locs), true, false)
if !strings.Contains(out, "7:2") {
t.Errorf("missing 7:2, got:\n%s", out)
}
if !strings.Contains(out, "(1)") {
t.Errorf("missing (1), got:\n%s", out)
}
}
// TestFormatReferencesCompactNoLPrefix verifies the L prefix is absent in compact mode.
func TestFormatReferencesCompactNoLPrefix(t *testing.T) {
locs := []lsp.Location{
makeLocation("/a/x.go", 3, 4),
}
groups, order := groupLocations(locs)
out := formatReferences(groups, order, len(locs), true, false)
if strings.Contains(out, "L3") {
t.Errorf("compact output must not contain L prefix, got:\n%s", out)
}
}
// TestFormatReferencesVerboseNoChange verifies compact=false preserves old L-prefix format.
func TestFormatReferencesVerbosePreservesLPrefix(t *testing.T) {
locs := []lsp.Location{
makeLocation("/a/x.go", 3, 4),
}
groups, order := groupLocations(locs)
out := formatReferences(groups, order, len(locs), false, true)
if !strings.Contains(out, "L3:4") {
t.Errorf("verbose output must contain L3:4, got:\n%s", out)
}
}
+183
View File
@@ -0,0 +1,183 @@
// Package server implements the MCP server for file operations.
package server
// helpFileRead is the full flag documentation and examples for the file_read tool,
// served at filepuff://help/file_read.
const helpFileRead = "# file_read — flags and examples\n\n" +
"## Token-saving features\n\n" +
"| Flag | Effect |\n" +
"|------|--------|\n" +
"| `previous_etag` | Skip re-reading unchanged files. Returns `[unchanged, etag: ...]` if file is unchanged. |\n" +
"| `symbol_name` | Read only a named function/struct/class — eliminates an ast_query round-trip. |\n" +
"| `symbols_only=true` | Return only symbol list (~95% fewer tokens). Requires `include_ast=true`. Alias: `mode='symbols_only'`. |\n" +
"| `mode` | `full` (default) \\| `skeleton` (signatures + `{ ... }` stubs, bodies elided) \\| `symbols_only` |\n" +
"| `strip` | Remove content classes before line-numbering: `imports`, `license`, `block_comments`. Emits `[stripped: ...]` footer. |\n" +
"| `no_line_numbers=true` | Omit the ` 12│ ` line-number prefix (~10% savings). `line_number_interval=0` has the same effect. |\n" +
"| `line_number_interval=N` | Print line numbers only every N lines. |\n" +
"| `compact_line_numbers=true` | Use compact `12│` prefix instead of ` 12│ ` (no padding, no trailing space). |\n" +
"| `collapse_blank_lines=true` | Collapse consecutive blank lines to one. |\n" +
"| `max_lines=N` | Truncate output with omitted count notice. Applied after `line_start`/`line_end`. |\n" +
"| `paths=[...]` | Read multiple files in one call. Each file gets a `--- path ---` header. |\n\n" +
"All responses include `[etag: hex]` footer (8 hex chars) for use as `previous_etag` in subsequent reads.\n\n" +
"## Examples\n\n" +
"```json\n" +
"// Full file\n" +
`{"path": "main.go"}` + "\n\n" +
"// Etag check — returns unchanged notice if file hasn't changed\n" +
`{"path": "main.go", "previous_etag": "a3f9c2b1"}` + "\n\n" +
"// Read only one named symbol\n" +
`{"path": "server.go", "symbol_name": "handleFileRead"}` + "\n\n" +
"// Skeleton mode — signatures only, bodies elided\n" +
`{"path": "server.go", "mode": "skeleton"}` + "\n\n" +
"// Strip imports and license header\n" +
`{"path": "main.go", "strip": ["imports", "license"]}` + "\n\n" +
"// Batch read multiple files\n" +
`{"paths": ["a.go", "b.go"]}` + "\n\n" +
"// Specific line range\n" +
`{"path": "main.go", "line_start": 10, "line_end": 50}` + "\n" +
"```\n"
// helpFileSearch is the full flag documentation and examples for the file_search tool,
// served at filepuff://help/file_search.
const helpFileSearch = "# file_search — flags and examples\n\n" +
"## Output format\n\n" +
"Matches grouped by file. Each file section has matching lines prefixed by `L{line}│` and context lines prefixed by ` │`. Zero matches: `No matches found.`\n\n" +
"## Flags\n\n" +
"| Flag | Effect |\n" +
"|------|--------|\n" +
"| `verbose=true` | Emit `Found N matches in M files:` preamble (v1 behaviour). Default: false. |\n" +
"| `cluster=true` | Coalesce consecutive match lines into ranges (`L12-14│ text`). Drops context lines for density. |\n" +
"| `cursor` | Opaque pagination token from a previous truncated response — fetches next page. |\n" +
"| `max_results` | Page size for pagination. Re-run with `cursor` to get next page. |\n" +
"| `context_lines` | Number of context lines around matches (default: 2). |\n" +
"| `ignore_case` | Case-insensitive search. |\n" +
"| `regex` | Treat pattern as regex (default: true). |\n" +
"| `file_types` | Restrict to file extensions, e.g. `[\"go\", \"ts\"]`. |\n" +
"| `paths` | Paths to search in (defaults to workspace root). |\n\n" +
"## Examples\n\n" +
"```json\n" +
"// Search for error-returning functions in Go files\n" +
`{"pattern": "func.*Error", "file_types": ["go"], "max_results": 20}` + "\n\n" +
"// Case-insensitive literal search\n" +
`{"pattern": "TODO", "ignore_case": true}` + "\n\n" +
"// Paginated search — fetch next page\n" +
`{"pattern": "import", "max_results": 50, "cursor": "<token from previous response>"}` + "\n\n" +
"// Clustered — dense view of many matches\n" +
`{"pattern": "return err", "file_types": ["go"], "cluster": true}` + "\n" +
"```\n"
// helpASTQuery is the full flag documentation and examples for the ast_query tool,
// served at filepuff://help/ast_query.
const helpASTQuery = "# ast_query — flags and examples\n\n" +
"## Output format\n\n" +
"Entries in format `**file:line** (node_type)` with code blocks and captured variables (`$NAME=value`). Zero matches: `No matches found.`\n\n" +
"## Flags\n\n" +
"| Flag | Effect |\n" +
"|------|--------|\n" +
"| `verbose=true` | Emit `Found N match(es):` preamble (v1 behaviour). Default: false. |\n" +
"| `format` | `verbose` (default, full code+captures) \\| `compact` (one line per match) \\| `location` (file:line only) |\n" +
"| `cursor` | Opaque pagination token from a previous truncated response — fetches next page. |\n" +
"| `max_results` | Page size (default: 100). |\n" +
"| `name_exact` | Exact symbol name to match. |\n" +
"| `name_matches` | Regex pattern to filter by name. |\n" +
"| `kind_in` | Node types to match (e.g. `function_declaration`, `class_declaration`). |\n" +
"| `paths` | Paths to search in (defaults to workspace root). |\n\n" +
"## Pattern placeholders\n\n" +
"| Placeholder | Meaning |\n" +
"|-------------|----------|\n" +
"| `$NAME` | Matches a single node, captures as `$NAME` |\n" +
"| `$$$ARGS` | Matches zero or more nodes (variadic capture) |\n" +
"| `$_` | Wildcard — matches any single node, no capture |\n\n" +
"## Examples\n\n" +
"```json\n" +
"// All Go functions returning error\n" +
`{"pattern": "func $NAME($$$ARGS) error", "language": "go"}` + "\n\n" +
"// Python classes\n" +
`{"pattern": "class $NAME: $$$BODY", "language": "python"}` + "\n\n" +
"// Specific named function\n" +
`{"pattern": "func $NAME($$$ARGS)", "language": "go", "name_exact": "NewServer"}` + "\n\n" +
"// Compact output — one line per match\n" +
`{"pattern": "func $NAME($$$ARGS) error", "language": "go", "format": "compact"}` + "\n" +
"```\n"
// helpLSPQuery is the full flag documentation and examples for the lsp_query tool,
// served at filepuff://help/lsp_query.
const helpLSPQuery = "# lsp_query — flags and examples\n\n" +
"## Actions\n\n" +
"### hover\n" +
"Returns type/doc from LSP, falls back to AST node info. `verbose=true` adds `**Symbol Information**` header.\n\n" +
"### definition\n" +
"Returns `file:line:col` + 3-line code preview for each definition. `verbose=true` adds `Found N definition(s):` header.\n\n" +
"### references\n" +
"Returns references grouped by file. Flags:\n" +
"- `include_declaration` (default true) — include the declaration itself\n" +
"- `compact=true` — collapse to one line per file\n" +
"- `verbose=true` — add `Found N reference(s):` header\n\n" +
"Note: `include_declaration` and `compact` are errors when used with actions other than `references`.\n\n" +
"## Examples\n\n" +
"```json\n" +
"// Hover — type/doc at position\n" +
`{"action": "hover", "file": "server.go", "line": 45, "column": 6}` + "\n\n" +
"// Definition — where is this symbol defined?\n" +
`{"action": "definition", "file": "handler.go", "line": 23, "column": 10}` + "\n\n" +
"// References — all usages\n" +
`{"action": "references", "file": "types.go", "line": 5, "column": 6}` + "\n\n" +
"// References — compact (one line per file)\n" +
`{"action": "references", "file": "types.go", "line": 5, "column": 6, "compact": true}` + "\n" +
"```\n"
// helpEditApply is the full flag documentation and examples for the edit_apply tool,
// served at filepuff://help/edit_apply.
const helpEditApply = "# edit_apply — flags and examples\n\n" +
"## Response format (`response` flag)\n\n" +
"| Value | Output |\n" +
"|-------|--------|\n" +
"| `count` (default) | `+3 -1` added/removed line counts only |\n" +
"| `diff` | Full unified diff of changes made |\n" +
"| `none` | Empty response (silent success) |\n\n" +
"`compact_response=true` is a deprecated alias for `response=\"count\"` kept for pre-v2 compatibility.\n\n" +
"For code files (Go, TypeScript, JavaScript, Python, C, C++, Rust) syntax is validated before writing — the edit is rejected if it would produce invalid syntax.\n\n" +
"## Selector types\n\n" +
"### AST-mode selectors (code files)\n" +
"- `selector_kind` — AST node type (e.g. `function_declaration`, `class_declaration`)\n" +
"- `selector_name` — symbol name to match\n\n" +
"### Text-mode selectors (all files)\n" +
"- `selector_text` — exact text to match (must be unique, or use `selector_index`)\n" +
"- `selector_pattern` — regex pattern to match\n" +
"- `selector_line` / `selector_line_end` — line range\n\n" +
"### Shared\n" +
"- `selector_index` — index of match when multiple exist (default: 0)\n\n" +
"## Examples\n\n" +
"```json\n" +
"// AST mode — replace a named function\n" +
"{\n" +
` "file": "main.go",` + "\n" +
` "operation": "replace",` + "\n" +
` "selector_kind": "function_declaration",` + "\n" +
` "selector_name": "Hello",` + "\n" +
` "new_content": "func Hello() {\n\treturn\n}"` + "\n" +
"}\n\n" +
"// Text mode — replace a markdown header\n" +
"{\n" +
` "file": "README.md",` + "\n" +
` "operation": "replace",` + "\n" +
` "selector_text": "## Old Header",` + "\n" +
` "new_content": "## New Header"` + "\n" +
"}\n\n" +
"// Line range replacement\n" +
"{\n" +
` "file": "config.yaml",` + "\n" +
` "operation": "replace",` + "\n" +
` "selector_line": 5,` + "\n" +
` "selector_line_end": 10,` + "\n" +
` "new_content": "key: value"` + "\n" +
"}\n\n" +
"// Request full diff in response\n" +
"{\n" +
` "file": "main.go",` + "\n" +
` "operation": "replace",` + "\n" +
` "selector_name": "Hello",` + "\n" +
` "new_content": "func Hello() {}",` + "\n" +
` "response": "diff"` + "\n" +
"}\n" +
"```\n"
+3 -25
View File
@@ -43,23 +43,7 @@ func Hello() string {
ctx := context.Background()
// Test 1: Ping tool (health check)
t.Run("ping", func(t *testing.T) {
req := mcp.CallToolRequest{}
result, err := srv.handlePing(ctx, req)
if err != nil {
t.Errorf("handlePing() error = %v", err)
}
if result == nil {
t.Fatal("handlePing() returned nil")
return
}
if len(result.Content) == 0 {
t.Fatal("handlePing() returned empty content")
}
})
// Test 2: File read
// Test: File read (ping removed — Change 3)
t.Run("file_read", func(t *testing.T) {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
@@ -144,14 +128,8 @@ func TestMCPToolDiscovery(t *testing.T) {
t.Fatal("MCP server not initialized")
}
// Verify each expected tool works
ctx := context.Background()
// Test ping tool
pingReq := mcp.CallToolRequest{}
if _, err := srv.handlePing(ctx, pingReq); err != nil {
t.Errorf("ping tool failed: %v", err)
}
// Ping tool removed (Change 3 — MCP protocol has own liveness check).
// Tools verified via integration tests in TestIntegrationFileOperations.
}
// TestMCPErrorResponses tests error handling following MCP spec.
+156
View File
@@ -0,0 +1,156 @@
package server
import (
"context"
"fmt"
"net/url"
"os"
"strings"
xxhash "github.com/cespare/xxhash/v2"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
)
// helpResources maps a tool name to its help content constant.
var helpResources = map[string]string{
"file_read": helpFileRead,
"file_search": helpFileSearch,
"ast_query": helpASTQuery,
"lsp_query": helpLSPQuery,
"edit_apply": helpEditApply,
}
// registerResources registers one filepuff://help/<tool> resource per tool.
// Each resource returns Markdown-formatted flag docs and examples.
func (s *Server) registerResources() {
for toolName, content := range helpResources {
uri := "filepuff://help/" + toolName
name := "help/" + toolName
description := "Flag documentation and examples for the " + toolName + " tool."
captured := content // capture for closure
s.mcp.AddResource(
mcp.NewResource(uri, name,
mcp.WithResourceDescription(description),
mcp.WithMIMEType("text/markdown"),
),
func(_ context.Context, _ mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: uri,
MIMEType: "text/markdown",
Text: captured,
},
}, nil
},
)
}
}
// readHelpResource is a convenience handler that can be used directly when a
// single resource handler is needed. It is kept exported for testability.
func readHelpResource(uri string) mcpserver.ResourceHandlerFunc {
return func(_ context.Context, _ mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
// Extract tool name from filepuff://help/<tool>
const prefix = "filepuff://help/"
if len(uri) <= len(prefix) {
return nil, fmt.Errorf("invalid help URI: %s", uri)
}
toolName := uri[len(prefix):]
content, ok := helpResources[toolName]
if !ok {
return nil, fmt.Errorf("no help content for tool: %s", toolName)
}
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: uri,
MIMEType: "text/markdown",
Text: content,
},
}, nil
}
}
// registerReadResource registers the filepuff://read/{+path} resource template.
// The handler re-reads the file, validates the etag query param if provided,
// and returns the raw file content (no line-number formatting).
//
// URI format: filepuff://read/<absolute-path>?etag=<etag>
// The etag param is optional. If supplied and the file has changed, the handler
// returns an error so the caller re-runs file_read to get a fresh ResourceLink.
func (s *Server) registerReadResource() {
const uriTemplate = "filepuff://read/{+path}"
s.mcp.AddResourceTemplate(
mcp.NewResourceTemplate(uriTemplate, "file-read",
mcp.WithTemplateDescription("Raw content of a file previously read via file_read. "+
"Fetch when file_read returns a ResourceLink instead of inlining content. "+
"URI: filepuff://read/<path>?etag=<etag>"),
),
func(_ context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
return s.handleReadResource(req)
},
)
}
// handleReadResource is the resource handler for filepuff://read/{+path} URIs.
func (s *Server) handleReadResource(req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
rawURI := req.Params.URI
// Parse path and etag from the URI.
// URI shape: filepuff://read/<path>[?etag=<hex>]
const scheme = "filepuff://read/"
if !strings.HasPrefix(rawURI, scheme) {
return nil, fmt.Errorf("invalid read resource URI: %s", rawURI)
}
rest := rawURI[len(scheme):]
// Split off query string to get the path.
filePath := rest
var expectedEtag string
if qIdx := strings.IndexByte(rest, '?'); qIdx >= 0 {
filePath = rest[:qIdx]
qs, err := url.ParseQuery(rest[qIdx+1:])
if err == nil {
expectedEtag = qs.Get("etag")
}
}
if filePath == "" {
return nil, fmt.Errorf("read resource URI missing path")
}
if !s.cfg.IsPathAllowed(filePath) {
return nil, fmt.Errorf("path is outside workspace root")
}
content, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("file not found: %s", filePath)
}
if os.IsPermission(err) {
return nil, fmt.Errorf("permission denied: %s", filePath)
}
return nil, fmt.Errorf("error reading file: %s", filePath)
}
// Validate etag if provided — detect stale references.
if expectedEtag != "" {
fullHash := fmt.Sprintf("%016x", xxhash.Sum64(content))
currentEtag := fullHash[:8]
if expectedEtag != currentEtag && !strings.HasPrefix(fullHash, expectedEtag) && !strings.HasPrefix(expectedEtag, currentEtag) {
return nil, fmt.Errorf("file changed since ResourceLink was issued (expected etag %s, got %s); re-run file_read to get fresh content", expectedEtag, currentEtag)
}
}
mimeType := detectMIMEType(filePath)
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: rawURI,
MIMEType: mimeType,
Text: string(content),
},
}, nil
}
+90
View File
@@ -0,0 +1,90 @@
package server
import (
"context"
"strings"
"testing"
"github.com/mark3labs/mcp-go/mcp"
)
func TestRegisterResources_AllToolsHaveResource(t *testing.T) {
// Verify that registerResources wires up without panicking.
_ = newTestServer(t, t.TempDir())
expectedURIs := []string{
"filepuff://help/file_read",
"filepuff://help/file_search",
"filepuff://help/ast_query",
"filepuff://help/lsp_query",
"filepuff://help/edit_apply",
}
for _, uri := range expectedURIs {
t.Run(uri, func(t *testing.T) {
handler := readHelpResource(uri)
contents, err := handler(context.Background(), mcp.ReadResourceRequest{})
if err != nil {
t.Fatalf("readHelpResource(%q) error = %v", uri, err)
}
if len(contents) == 0 {
t.Fatalf("readHelpResource(%q) returned empty contents", uri)
}
tc, ok := contents[0].(mcp.TextResourceContents)
if !ok {
t.Fatalf("readHelpResource(%q) contents[0] is not TextResourceContents", uri)
}
if tc.MIMEType != "text/markdown" {
t.Errorf("MIMEType = %q, want %q", tc.MIMEType, "text/markdown")
}
if len(tc.Text) == 0 {
t.Errorf("Text is empty for %q", uri)
}
if !strings.HasPrefix(tc.Text, "#") {
t.Errorf("expected markdown (# heading) for %q, got: %q", uri, tc.Text[:min(50, len(tc.Text))])
}
if tc.URI != uri {
t.Errorf("URI = %q, want %q", tc.URI, uri)
}
})
}
}
func TestReadHelpResource_UnknownTool(t *testing.T) {
handler := readHelpResource("filepuff://help/nonexistent")
_, err := handler(context.Background(), mcp.ReadResourceRequest{})
if err == nil {
t.Fatal("expected error for unknown tool, got nil")
}
}
func TestReadHelpResource_InvalidURI(t *testing.T) {
handler := readHelpResource("filepuff://help/")
_, err := handler(context.Background(), mcp.ReadResourceRequest{})
if err == nil {
t.Fatal("expected error for empty tool name, got nil")
}
}
func TestHelpContent_NotEmpty(t *testing.T) {
cases := map[string]string{
"file_read": helpFileRead,
"file_search": helpFileSearch,
"ast_query": helpASTQuery,
"lsp_query": helpLSPQuery,
"edit_apply": helpEditApply,
}
for name, content := range cases {
t.Run(name, func(t *testing.T) {
if len(content) == 0 {
t.Errorf("help content for %q is empty", name)
}
if !strings.Contains(content, "##") {
t.Errorf("expected markdown sections (##) in help content for %q", name)
}
if !strings.Contains(content, "```") {
t.Errorf("expected code fences in help content for %q", name)
}
})
}
}
+95 -120
View File
@@ -33,16 +33,17 @@ const PreviewLineMaxLength = 100
// Server represents the MCP file operations server.
type Server struct {
cfg *config.Config
logger *slog.Logger
mcp *server.MCPServer
searcher *search.Searcher
parser *parser.Registry
matcher *query.Matcher
lspManager *lsp.Manager
editor *edit.Engine
readSem chan struct{} // Semaphore for limiting concurrent file reads
querySem chan struct{} // Semaphore for limiting concurrent AST queries
cfg *config.Config
logger *slog.Logger
mcp *server.MCPServer
searcher *search.Searcher
parser *parser.Registry
matcher *query.Matcher
lspManager *lsp.Manager
editor *edit.Engine
readSem chan struct{} // Semaphore for limiting concurrent file reads
querySem chan struct{} // Semaphore for limiting concurrent AST queries
sessionPrefs sessionPrefsPtr // Atomic pointer; populated by OnAfterInitialize hook
}
// New creates a new MCP server instance.
@@ -70,40 +71,48 @@ func New(cfg *config.Config, logger *slog.Logger) (*Server, error) {
s.lspManager = lsp.NewManager(cfg.WorkspaceRoot, logger)
}
// Build OnAfterInitialize hook that parses client capability prefs.
// Signature (from mcp-go v0.48.0 hooks.go):
// func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult)
hooks := &server.Hooks{}
hooks.AddAfterInitialize(func(_ context.Context, _ any, msg *mcp.InitializeRequest, _ *mcp.InitializeResult) {
if msg == nil {
return
}
raw, _ := msg.Params.Capabilities.Experimental["filepuff"].(map[string]any)
prefs := ParseSessionPrefs(raw)
s.sessionPrefs.Store(&prefs)
})
// Create MCP server
mcpServer := server.NewMCPServer(
"mcp-filepuff",
"1.0.0",
"2.0.0",
server.WithLogging(),
server.WithHooks(hooks),
)
s.mcp = mcpServer
// Register tools
s.registerTools()
// Register help resources (filepuff://help/<tool>)
s.registerResources()
// Register filepuff://read/{+path} resource template for large-file access.
s.registerReadResource()
return s, nil
}
// registerTools registers all available tools with the MCP server.
func (s *Server) registerTools() {
// Register ping tool for health checks
s.mcp.AddTool(
mcp.NewTool("ping",
mcp.WithDescription("Health check - returns pong to verify the server is running.\n\n"+
"Returns: \"pong\" text string."),
mcp.WithReadOnlyHintAnnotation(true),
),
s.handlePing,
)
// Register file_search tool
if s.searcher != nil {
s.mcp.AddTool(
mcp.NewTool("file_search",
mcp.WithDescription("Search for text patterns in files using ripgrep. Supports regex patterns, file type filtering, and context lines.\n\n"+
"Returns: Results grouped by file with match context. Format: \"Found N matches in M files:\" followed by file sections, "+
"each with matching lines prefixed by \"L{line}│\" and context lines prefixed by \" │\".\n\n"+
"Example: {\"pattern\": \"func.*Error\", \"file_types\": [\"go\"], \"max_results\": 20}"),
mcp.WithDescription("Search for text patterns in files using ripgrep. Supports regex patterns, file type filtering, and context lines. "+
"See resource filepuff://help/file_search for flags and examples."),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("pattern",
mcp.Required(),
@@ -127,7 +136,16 @@ func (s *Server) registerTools() {
mcp.Description("Number of context lines around matches (default: 2)"),
),
mcp.WithNumber("max_results",
mcp.Description("Maximum number of results to return"),
mcp.Description("Maximum number of results to return (page size for pagination)"),
),
mcp.WithBoolean("cluster",
mcp.Description("Coalesce consecutive match lines into ranges (L12-14│ text). Drops context lines. Default: false."),
),
mcp.WithString("cursor",
mcp.Description("Pagination cursor from a previous truncated response. Pass back to fetch the next page."),
),
mcp.WithBoolean("verbose",
mcp.Description("Emit \"Found N matches in M files:\" preamble. Default: false (v2 default)."),
),
),
s.handleFileSearch,
@@ -137,23 +155,8 @@ func (s *Server) registerTools() {
// Register file_read tool
s.mcp.AddTool(
mcp.NewTool("file_read",
mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary.\n\n"+
"Token-saving features:\n"+
" previous_etag: skip re-reading unchanged files (returns '[unchanged, etag: ...]' if unchanged)\n"+
" symbol_name: read only a named function/struct/class (eliminates ast_query round-trip)\n"+
" symbols_only=true: return only symbol list, ~95% fewer tokens (requires include_ast=true)\n"+
" no_line_numbers=true: strip the line-number prefix (~10%% savings)\n"+
" line_number_interval=N: print line numbers only every N lines\n"+
" collapse_blank_lines=true: collapse consecutive blank lines to one\n"+
" max_lines=N: truncate output with omitted count notice\n"+
" paths=[...]: read multiple files in one call\n\n"+
"All responses include '[etag: hex]' footer for use as previous_etag in subsequent reads.\n\n"+
"Examples:\n"+
" Full file: {\"path\": \"main.go\"}\n"+
" Etag check: {\"path\": \"main.go\", \"previous_etag\": \"a3f9c2b1\"}\n"+
" By symbol: {\"path\": \"server.go\", \"symbol_name\": \"handleFileRead\"}\n"+
" Batch: {\"paths\": [\"a.go\", \"b.go\"]}\n"+
" Line range: {\"path\": \"main.go\", \"line_start\": 10, \"line_end\": 50}"),
mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary. "+
"See resource filepuff://help/file_read for flags and examples."),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("path",
mcp.Description("Path to the file to read (required unless paths is provided)"),
@@ -181,20 +184,36 @@ func (s *Server) registerTools() {
mcp.Description("Include AST symbol summary (functions, classes, types, etc.)"),
),
mcp.WithBoolean("symbols_only",
mcp.Description("Return only symbol summary without file content (token-efficient mode). Requires include_ast=true."),
mcp.Description("Return only symbol summary without file content (token-efficient mode). Requires include_ast=true. Alias: mode='symbols_only'."),
),
mcp.WithString("mode",
mcp.Description("Output mode: 'full' (default, full file), 'skeleton' (signatures + { ... } stubs, bodies elided), 'symbols_only' (symbol list only, alias for symbols_only=true)."),
),
mcp.WithArray("strip",
mcp.Description("Strip content classes before line-numbering. Values: 'imports' (remove import blocks), 'license' (remove leading license comment), 'block_comments' (remove /* */ and Python triple-quoted strings). Emits [stripped: ...] footer."),
mcp.WithStringItems(),
),
mcp.WithNumber("max_lines",
mcp.Description("Maximum number of lines to return (for token efficiency). Applied after line_start/line_end."),
),
mcp.WithBoolean("no_line_numbers",
mcp.Description("Omit the ' 12 ' line number prefix entirely. Saves ~10% tokens. line_number_interval=0 has the same effect."),
mcp.Description("Omit the ' 12\u2502 ' line number prefix entirely. Saves ~10% tokens. line_number_interval=0 has the same effect."),
),
mcp.WithNumber("line_number_interval",
mcp.Description("Print line numbers only every N lines (default: 1 = every line). E.g. 10 = anchor every 10th line plus first/last. 0 = no line numbers."),
),
mcp.WithBoolean("compact_line_numbers",
mcp.Description("Use compact line prefix '12\u2502' instead of ' 12\u2502 ' (no padding, no trailing space). Works with line_number_interval. Default off."),
),
mcp.WithBoolean("collapse_blank_lines",
mcp.Description("Collapse runs of consecutive blank lines to a single blank line. Useful for token savings on heavily-spaced code."),
),
mcp.WithBoolean("force_inline",
mcp.Description("Always return file content inline, bypassing the resource-link threshold. Default: false."),
),
mcp.WithNumber("max_inline_bytes",
mcp.Description("Per-call inline threshold override in bytes. If set, overrides server resource_link_threshold_bytes for this call only. 0 = use server default."),
),
),
s.handleFileRead,
)
@@ -202,13 +221,8 @@ func (s *Server) registerTools() {
// Register ast_query tool
s.mcp.AddTool(
mcp.NewTool("ast_query",
mcp.WithDescription("Search for AST patterns in code files. Use code patterns with $VAR placeholders to match and capture code structures like functions, classes, and types.\n\n"+
"Returns: \"Found N match(es):\" followed by entries in format \"**file:line** (node_type)\" with code blocks "+
"and captured variables ($NAME=value). Returns \"No matches found.\" when no results.\n\n"+
"Examples:\n"+
" Go error funcs: {\"pattern\": \"func $NAME($$$ARGS) error\", \"language\": \"go\"}\n"+
" Python classes: {\"pattern\": \"class $NAME: $$$BODY\", \"language\": \"python\"}\n"+
" Named function: {\"pattern\": \"func $NAME($$$ARGS)\", \"language\": \"go\", \"name_exact\": \"NewServer\"}"),
mcp.WithDescription("Search for AST patterns in code files. Use code patterns with $VAR placeholders to match and capture code structures like functions, classes, and types. "+
"See resource filepuff://help/ast_query for flags and examples."),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("pattern",
mcp.Required(),
@@ -233,7 +247,16 @@ func (s *Server) registerTools() {
mcp.WithStringItems(),
),
mcp.WithNumber("max_results",
mcp.Description("Maximum number of results to return (default: 100)"),
mcp.Description("Maximum number of results to return (default: 100, page size for pagination)"),
),
mcp.WithString("format",
mcp.Description("Output format: \"verbose\" (default, full code+captures), \"compact\" (one line per match), \"location\" (file:line only)"),
),
mcp.WithString("cursor",
mcp.Description("Pagination cursor from a previous truncated response. Pass back to fetch the next page."),
),
mcp.WithBoolean("verbose",
mcp.Description("Emit \"Found N match(es):\" preamble. Default: false (v2 default)."),
),
),
s.handleASTQuery,
@@ -241,62 +264,15 @@ func (s *Server) registerTools() {
// Register LSP-based tools if LSP is enabled
if s.lspManager != nil {
// Register symbol_at tool
s.mcp.AddTool(
mcp.NewTool("symbol_at",
mcp.WithDescription("Get information about the symbol at a specific position in a file. Returns type, documentation, and definition location using LSP when available.\n\n"+
"Returns: \"**Symbol Information**\" followed by hover/type information from LSP, or \"**Symbol Information** (AST fallback)\" "+
"with node type and text when LSP unavailable. Returns \"No symbol information available at this position.\" when nothing is found.\n\n"+
"Example: {\"file\": \"server.go\", \"line\": 45, \"column\": 6}"),
mcp.NewTool("lsp_query",
mcp.WithDescription("Query LSP for symbol info, definition, or references at a specific file position. "+
"See resource filepuff://help/lsp_query for flags and examples."),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("file",
mcp.WithString("action",
mcp.Required(),
mcp.Description("Path to the file"),
mcp.Description("LSP operation: hover | definition | references"),
),
mcp.WithNumber("line",
mcp.Required(),
mcp.Description("Line number (1-indexed)"),
),
mcp.WithNumber("column",
mcp.Required(),
mcp.Description("Column number (1-indexed)"),
),
),
s.handleSymbolAt,
)
// Register find_definition tool
s.mcp.AddTool(
mcp.NewTool("find_definition",
mcp.WithDescription("Find the definition of the symbol at a specific position. Uses LSP to locate where a function, variable, type, etc. is defined.\n\n"+
"Returns: \"Found N definition(s):\" with entries showing \"**file:line:column**\" and a 3-line code preview "+
"with the target line marked by \">\". Returns \"No definition found.\" when the symbol has no definition.\n\n"+
"Example: {\"file\": \"handler.go\", \"line\": 23, \"column\": 10}"),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("file",
mcp.Required(),
mcp.Description("Path to the file"),
),
mcp.WithNumber("line",
mcp.Required(),
mcp.Description("Line number (1-indexed)"),
),
mcp.WithNumber("column",
mcp.Required(),
mcp.Description("Column number (1-indexed)"),
),
),
s.handleFindDefinition,
)
// Register find_references tool
s.mcp.AddTool(
mcp.NewTool("find_references",
mcp.WithDescription("Find all references to the symbol at a specific position. Uses LSP to locate all usages of a function, variable, type, etc.\n\n"+
"Returns: \"Found N reference(s):\" grouped by file, each showing \"**file** (count)\" with locations as "+
"\"L{line}:{column}\". Returns \"No references found.\" when no usages exist.\n\n"+
"Example: {\"file\": \"types.go\", \"line\": 5, \"column\": 6}"),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("file",
mcp.Required(),
mcp.Description("Path to the file"),
@@ -310,23 +286,24 @@ func (s *Server) registerTools() {
mcp.Description("Column number (1-indexed)"),
),
mcp.WithBoolean("include_declaration",
mcp.Description("Include the declaration in results (default: true)"),
mcp.Description("Include the declaration in results. Only valid for action=references (default: true)."),
),
mcp.WithBoolean("compact",
mcp.Description("Compact output: one line per file with all refs in brackets. Only valid for action=references. Default: false."),
),
mcp.WithBoolean("verbose",
mcp.Description("Emit count/header preamble. Applies to all actions. Default: false."),
),
),
s.handleFindReferences,
s.handleLSPQuery,
)
}
// Register edit tools
s.mcp.AddTool(
mcp.NewTool("edit_apply",
mcp.WithDescription("Apply an edit to a file. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++, Rust) with syntax validation, and text-based editing for other files (Markdown, JSON, YAML, config files, etc.).\n\n"+
"Returns: \"**Edit Applied Successfully**\" followed by a unified diff of the changes made. "+
"For code files, validates syntax before writing — returns an error if the edit would produce invalid syntax.\n\n"+
"Examples:\n"+
" AST mode: {\"file\": \"main.go\", \"operation\": \"replace\", \"selector_kind\": \"function_declaration\", \"selector_name\": \"Hello\", \"new_content\": \"func Hello() {\\n\\treturn\\n}\"}\n"+
" Text mode: {\"file\": \"README.md\", \"operation\": \"replace\", \"selector_text\": \"## Old Header\", \"new_content\": \"## New Header\"}\n"+
" Line range: {\"file\": \"config.yaml\", \"operation\": \"replace\", \"selector_line\": 5, \"selector_line_end\": 10, \"new_content\": \"key: value\"}"),
mcp.WithDescription("Apply an edit to a file. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++, Rust) with syntax validation, and text-based editing for other files (Markdown, JSON, YAML, config files, etc.). "+
"See resource filepuff://help/edit_apply for flags and examples."),
mcp.WithString("file",
mcp.Required(),
mcp.Description("Path to the file to edit"),
@@ -362,19 +339,17 @@ func (s *Server) registerTools() {
mcp.WithString("selector_pattern",
mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."),
),
mcp.WithString("response",
mcp.Description("Response format: \"count\" (default, \"+3 -1\" line counts), \"diff\" (full unified diff), \"none\" (empty). Default: count."),
),
mcp.WithBoolean("compact_response",
mcp.Description("Return only the modified symbol's content instead of a full diff. Requires selector_name. Saves tokens on large-file edits."),
mcp.Description("Deprecated: use response=count. Alias for response=\"count\" kept for pre-v2 compatibility."),
),
),
s.handleEditApply,
)
}
// handlePing handles the ping health check tool.
func (s *Server) handlePing(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText("pong"), nil
}
// Run starts the MCP server and blocks until shutdown.
func (s *Server) Run(ctx context.Context) error {
// Set up signal handling for graceful shutdown
+259 -39
View File
@@ -7,6 +7,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"github.com/mark3labs/mcp-go/mcp"
@@ -49,45 +50,6 @@ func TestNew(t *testing.T) {
}
}
func TestHandlePing(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
result, err := srv.handlePing(ctx, req)
if err != nil {
t.Errorf("handlePing() error = %v", err)
}
if result == nil {
t.Fatal("handlePing() returned nil result")
return
}
// Check that the result contains "pong"
contents := result.Content
if len(contents) == 0 {
t.Fatal("handlePing() returned empty content")
}
textContent, ok := contents[0].(mcp.TextContent)
if !ok {
t.Fatal("handlePing() did not return text content")
}
if textContent.Text != "pong" {
t.Errorf("handlePing() = %v, want 'pong'", textContent.Text)
}
}
func TestHandleFileRead(t *testing.T) {
tmpDir := t.TempDir()
@@ -484,3 +446,261 @@ func TestSplitLinesLongLine(t *testing.T) {
t.Error("the 500KB long line was not found in splitLines output")
}
}
// TestHandleEditApplyResponseCount verifies the default response=count format "+N -M".
func TestHandleEditApplyResponseCount(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
func Hello() {
println("Hello")
}
`
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"file": testFile,
"operation": "replace",
"selector_kind": "function_declaration",
"selector_name": "Hello",
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
// no response flag → default "count"
}
result, err := srv.handleEditApply(ctx, req)
if err != nil {
t.Fatalf("handleEditApply error = %v", err)
}
if result == nil || len(result.Content) == 0 {
t.Fatal("handleEditApply returned empty result")
}
textContent, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatal("expected text content")
}
// Should be "+N -M" format
text := textContent.Text
if !strings.HasPrefix(text, "+") {
t.Errorf("response=count should start with +, got: %q", text)
}
if !strings.Contains(text, " -") {
t.Errorf("response=count should contain -N, got: %q", text)
}
// Must NOT contain diff syntax or old preamble
if strings.Contains(text, "@@") || strings.Contains(text, "Edit Applied") {
t.Errorf("response=count must not contain diff markers, got: %q", text)
}
}
// TestHandleEditApplyResponseDiff verifies response=diff returns unified diff.
func TestHandleEditApplyResponseDiff(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
func Hello() {
println("Hello")
}
`
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"file": testFile,
"operation": "replace",
"selector_kind": "function_declaration",
"selector_name": "Hello",
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
"response": "diff",
}
result, err := srv.handleEditApply(ctx, req)
if err != nil {
t.Fatalf("handleEditApply error = %v", err)
}
if result == nil || len(result.Content) == 0 {
t.Fatal("handleEditApply returned empty result")
}
textContent, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatal("expected text content")
}
text := textContent.Text
if !strings.Contains(text, "diff") {
t.Errorf("response=diff should contain diff, got: %q", text)
}
// Must NOT have old "Edit Applied Successfully" preamble
if strings.Contains(text, "Edit Applied Successfully") {
t.Errorf("v2 diff should not have old preamble, got: %q", text)
}
}
// TestHandleEditApplyResponseNone verifies response=none returns empty string.
func TestHandleEditApplyResponseNone(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
func Hello() {
println("Hello")
}
`
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"file": testFile,
"operation": "replace",
"selector_kind": "function_declaration",
"selector_name": "Hello",
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
"response": "none",
}
result, err := srv.handleEditApply(ctx, req)
if err != nil {
t.Fatalf("handleEditApply error = %v", err)
}
if result == nil || len(result.Content) == 0 {
t.Fatal("handleEditApply returned empty result")
}
textContent, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatal("expected text content")
}
if textContent.Text != "" {
t.Errorf("response=none should return empty string, got: %q", textContent.Text)
}
}
// TestHandleFileReadBatchDedup verifies that identical files in batch mode emit [duplicate of ...].
func TestHandleFileReadBatchDedup(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "a.go")
content := `package main
func Hello() {}
`
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write a.go: %v", err)
}
// Make b.go with identical content
testFile2 := filepath.Join(tmpDir, "b.go")
if err := os.WriteFile(testFile2, []byte(content), 0600); err != nil {
t.Fatalf("failed to write b.go: %v", err)
}
// c.go with different content
testFile3 := filepath.Join(tmpDir, "c.go")
if err := os.WriteFile(testFile3, []byte("package main\n"), 0600); err != nil {
t.Fatalf("failed to write c.go: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false, MaxFileSize: 1024 * 1024}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"paths": []interface{}{testFile, testFile2, testFile3},
}
result, err := srv.handleFileRead(ctx, req)
if err != nil {
t.Fatalf("handleFileRead() error = %v", err)
}
if result == nil || len(result.Content) == 0 {
t.Fatal("handleFileRead() returned empty result")
}
textContent, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatal("expected text content")
}
text := textContent.Text
if !strings.Contains(text, "[duplicate of") {
t.Errorf("expected duplicate pointer for b.go, got:\n%s", text)
}
// a.go should have full content
if !strings.Contains(text, "--- "+testFile+" ---") {
t.Errorf("expected a.go header, got:\n%s", text)
}
// c.go should have full content (different hash)
if !strings.Contains(text, "--- "+testFile3+" ---") {
t.Errorf("expected c.go header, got:\n%s", text)
}
}
// TestHandleFileSearchVerbose verifies verbose=true emits "Found N matches in M files:" preamble.
func TestHandleFileSearchVerbose(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.go")
if err := os.WriteFile(testFile, []byte("package main\n\nfunc Hello() {}\n"), 0600); err != nil {
t.Fatalf("write test file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false, MaxFileSize: 1024 * 1024, SearchTimeout: 10 * time.Second}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
if srv.searcher == nil {
t.Skip("ripgrep not available")
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "Hello",
"paths": []interface{}{tmpDir},
"verbose": true,
}
result, err := srv.handleFileSearch(ctx, req)
if err != nil {
t.Fatalf("handleFileSearch error = %v", err)
}
if result == nil || len(result.Content) == 0 {
t.Fatal("handleFileSearch returned empty")
}
textContent, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatal("expected text content")
}
if !strings.Contains(textContent.Text, "Found ") {
t.Errorf("verbose=true should emit preamble, got:\n%s", textContent.Text)
}
}
+145
View File
@@ -0,0 +1,145 @@
// Package server implements the MCP server for file operations.
package server
import (
"sync/atomic"
"unsafe"
)
// SessionPrefs holds client-declared session-wide preferences parsed from
// InitializeRequest.Params.Capabilities.Experimental["filepuff"].
//
// These act as defaults; explicit per-call flags always override them.
//
// Supported keys in the "filepuff" experimental map:
//
// terse bool — no-op (v2 default is already terse; reserved for future)
// default_format string — ast_query format default ("verbose"|"compact"|"location")
// default_max_results int — file_search and ast_query max_results when not supplied
// default_cluster bool — file_search cluster default
// compact_refs bool — lsp_query references compact default
// line_numbers string — file_read line prefix default ("none"|"compact"|"full")
// resource_link_threshold int — per-session override for cfg.ResourceLinkThresholdBytes
type SessionPrefs struct {
// ASTQueryFormat is the default format for ast_query ("verbose", "compact", "location").
// Empty string means "use handler built-in default".
ASTQueryFormat string
// DefaultMaxResults is the default max_results for file_search and ast_query.
// 0 means "use handler built-in default".
DefaultMaxResults int
// DefaultCluster is the default cluster flag for file_search.
// nil means "use handler built-in default (false)".
DefaultCluster *bool
// CompactRefs is the default compact flag for lsp_query action=references.
// nil means "use handler built-in default (false)".
CompactRefs *bool
// LineNumbers controls the file_read line prefix default.
// "" = use handler built-in default (full).
// "none" = no line numbers.
// "compact" = compact prefix (N│content).
// "full" = standard padded prefix ( N│ content).
LineNumbers string
// ResourceLinkThreshold overrides cfg.ResourceLinkThresholdBytes for this session.
// 0 means "use config default".
ResourceLinkThreshold int
}
// boolPtr returns a pointer to a bool value.
func boolPtr(b bool) *bool { return &b }
// ParseSessionPrefs parses the raw map from
// InitializeRequest.Params.Capabilities.Experimental["filepuff"].
// Unknown keys are silently ignored. Type mismatches for individual keys are
// silently ignored (key is treated as absent). Returns zero-value SessionPrefs
// when raw is nil or empty — callers should treat zero values as "use built-in defaults".
func ParseSessionPrefs(raw map[string]any) SessionPrefs {
if len(raw) == 0 {
return SessionPrefs{}
}
var p SessionPrefs
if v, ok := raw["default_format"]; ok {
if s, ok := v.(string); ok {
switch s {
case "verbose", "compact", "location":
p.ASTQueryFormat = s
}
}
}
if v, ok := raw["default_max_results"]; ok {
if n := toInt(v); n > 0 {
p.DefaultMaxResults = n
}
}
if v, ok := raw["default_cluster"]; ok {
if b, ok := v.(bool); ok {
p.DefaultCluster = boolPtr(b)
}
}
if v, ok := raw["compact_refs"]; ok {
if b, ok := v.(bool); ok {
p.CompactRefs = boolPtr(b)
}
}
if v, ok := raw["line_numbers"]; ok {
if s, ok := v.(string); ok {
switch s {
case "none", "compact", "full":
p.LineNumbers = s
}
}
}
if v, ok := raw["resource_link_threshold"]; ok {
if n := toInt(v); n >= 0 {
p.ResourceLinkThreshold = n
}
}
return p
}
// toInt converts numeric JSON-decoded values (float64, int, int64) to int.
// Returns 0 for unsupported types or negative values.
func toInt(v any) int {
switch n := v.(type) {
case float64:
if n >= 0 {
return int(n)
}
case int:
if n >= 0 {
return n
}
case int64:
if n >= 0 {
return int(n)
}
}
return 0
}
// sessionPrefsPtr is an atomic pointer helper for thread-safe access to *SessionPrefs.
// We use atomic store/load so the hook (called once at init) and handlers (many goroutines)
// never race. The prefs are write-once after initialization.
type sessionPrefsPtr struct {
p unsafe.Pointer // *SessionPrefs
}
func (sp *sessionPrefsPtr) Store(prefs *SessionPrefs) {
atomic.StorePointer(&sp.p, unsafe.Pointer(prefs))
}
func (sp *sessionPrefsPtr) Load() *SessionPrefs {
return (*SessionPrefs)(atomic.LoadPointer(&sp.p))
}
+313
View File
@@ -0,0 +1,313 @@
package server
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"github.com/mark3labs/mcp-go/mcp"
)
// ---- Session prefs integration tests ----
// setSessionPrefs injects prefs directly on the server (bypasses MCP hook machinery).
func setSessionPrefs(srv *Server, prefs SessionPrefs) {
srv.sessionPrefs.Store(&prefs)
}
// TestSessionPrefsFileReadLineNumbersNone verifies that session pref line_numbers=none
// disables line-number prefixes and is overridden by explicit compact_line_numbers=true.
func TestSessionPrefsFileReadLineNumbersNone(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.go")
content := "package main\n\nfunc Foo() {}\n"
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("write file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New: %v", err)
}
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"line_numbers": "none"}))
ctx := context.Background()
// Without explicit override: session pref should suppress line numbers.
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{"path": testFile}
result, err := srv.handleFileRead(ctx, req)
if err != nil {
t.Fatalf("handleFileRead: %v", err)
}
text := result.Content[0].(mcp.TextContent).Text
// Standard line-number format is " 1│ "; with no_line_numbers it's absent.
if strings.Contains(text, " 1│") {
t.Errorf("session line_numbers=none: expected no line-number prefix, got:\n%s", text)
}
// Explicit per-call compact_line_numbers=true should override session none.
req2 := mcp.CallToolRequest{}
req2.Params.Arguments = map[string]interface{}{
"path": testFile,
"compact_line_numbers": true,
}
result2, err := srv.handleFileRead(ctx, req2)
if err != nil {
t.Fatalf("handleFileRead (explicit compact): %v", err)
}
text2 := result2.Content[0].(mcp.TextContent).Text
// Compact format emits "1│" prefix.
if !strings.Contains(text2, "\u2502") {
t.Errorf("explicit compact_line_numbers should override session none, got:\n%s", text2)
}
}
// TestSessionPrefsFileReadLineNumbersCompact verifies line_numbers=compact session pref.
func TestSessionPrefsFileReadLineNumbersCompact(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "x.go")
if err := os.WriteFile(testFile, []byte("package main\nfunc Bar() {}\n"), 0600); err != nil {
t.Fatalf("write: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, _ := New(cfg, logger)
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"line_numbers": "compact"}))
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{"path": testFile}
result, _ := srv.handleFileRead(ctx, req)
text := result.Content[0].(mcp.TextContent).Text
// Standard padded prefix is " 1│ "; compact is "1│".
if strings.Contains(text, " 1\u2502") {
t.Errorf("session line_numbers=compact should use compact prefix, got:\n%s", text)
}
// Should still have the │ separator somewhere.
if !strings.Contains(text, "\u2502") {
t.Errorf("session line_numbers=compact should still have \u2502 separator, got:\n%s", text)
}
}
// TestSessionPrefsResourceLinkThreshold verifies per-session threshold override.
func TestSessionPrefsResourceLinkThreshold(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "big.go")
var sb strings.Builder
sb.WriteString("package main\n\nfunc Foo() {\n")
for i := 0; i < 15; i++ {
sb.WriteString("// comment line\n")
}
sb.WriteString("}\n")
if err := os.WriteFile(testFile, []byte(sb.String()), 0600); err != nil {
t.Fatalf("write: %v", err)
}
// Config threshold = 0 (disabled) so content is always inlined by default.
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20, ResourceLinkThresholdBytes: 0}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, _ := New(cfg, logger)
// Set session threshold = 10 bytes (tiny), so any real file triggers resource-link.
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"resource_link_threshold": float64(10)}))
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{"path": testFile}
result, _ := srv.handleFileRead(ctx, req)
if len(result.Content) == 0 {
t.Fatal("expected content")
}
_, isLink := result.Content[0].(mcp.ResourceLink)
_, isText := result.Content[0].(mcp.TextContent)
if !isLink && isText {
t.Error("expected ResourceLink when session threshold is very small, got TextContent")
}
// force_inline should still bypass even a session threshold.
req2 := mcp.CallToolRequest{}
req2.Params.Arguments = map[string]interface{}{"path": testFile, "force_inline": true}
result2, _ := srv.handleFileRead(ctx, req2)
if _, ok := result2.Content[0].(mcp.TextContent); !ok {
t.Error("force_inline=true should bypass session threshold and return TextContent")
}
}
// TestSessionPrefsASTQueryFormat verifies default_format session pref for ast_query.
func TestSessionPrefsASTQueryFormat(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.go")
if err := os.WriteFile(testFile, []byte("package main\n\nfunc Greet() {}\n"), 0600); err != nil {
t.Fatalf("write: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, _ := New(cfg, logger)
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_format": "compact"}))
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME()",
"language": "go",
"paths": []interface{}{tmpDir},
// no format key → session default should apply
}
result, err := srv.handleASTQuery(ctx, req)
if err != nil {
t.Fatalf("handleASTQuery: %v", err)
}
if result == nil || len(result.Content) == 0 {
t.Fatal("empty result")
}
text := result.Content[0].(mcp.TextContent).Text
// Compact format emits one-line results without "**file:line**" markers.
if strings.Contains(text, "**") {
t.Errorf("session default_format=compact: expected compact output (no **), got:\n%s", text)
}
// Explicit format=verbose should override session compact.
req2 := mcp.CallToolRequest{}
req2.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME()",
"language": "go",
"paths": []interface{}{tmpDir},
"format": "verbose",
}
result2, _ := srv.handleASTQuery(ctx, req2)
if result2 != nil && len(result2.Content) > 0 {
text2 := result2.Content[0].(mcp.TextContent).Text
if !strings.Contains(text2, "**") {
t.Errorf("explicit format=verbose should override session compact, got:\n%s", text2)
}
}
}
// TestSessionPrefsASTQueryMaxResults verifies default_max_results for ast_query.
func TestSessionPrefsASTQueryMaxResults(t *testing.T) {
tmpDir := t.TempDir()
// Build a file with 5 functions.
var sb strings.Builder
sb.WriteString("package main\n\n")
for i := 0; i < 5; i++ {
sb.WriteString(fmt.Sprintf("func Fn%c() {}\n\n", rune('A'+i)))
}
testFile := filepath.Join(tmpDir, "many.go")
if err := os.WriteFile(testFile, []byte(sb.String()), 0600); err != nil {
t.Fatalf("write: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, _ := New(cfg, logger)
// Session pref: max 2 results.
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_max_results": float64(2)}))
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME()",
"language": "go",
"paths": []interface{}{tmpDir},
// no max_results → session pref of 2 should apply
}
result, err := srv.handleASTQuery(ctx, req)
if err != nil {
t.Fatalf("handleASTQuery: %v", err)
}
if result == nil || len(result.Content) == 0 {
t.Fatal("empty result")
}
text := result.Content[0].(mcp.TextContent).Text
// With 5 funcs and max=2, output should mention remaining.
if !strings.Contains(text, "remaining") && !strings.Contains(text, "cursor") {
t.Errorf("session max_results=2 with 5 matches should produce cursor line, got:\n%s", text)
}
}
// TestSessionPrefsFileSearchDefaultCluster verifies default_cluster session pref.
func TestSessionPrefsFileSearchDefaultCluster(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "x.go")
content := "package main\n\nfunc Foo() {}\nfunc Foo2() {}\n"
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("write: %v", err)
}
cfg := &config.Config{
WorkspaceRoot: tmpDir,
MaxFileSize: 1 << 20,
SearchTimeout: 10 * time.Second,
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New: %v", err)
}
if srv.searcher == nil {
t.Skip("ripgrep not available")
}
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_cluster": true}))
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func",
"paths": []interface{}{tmpDir},
// no cluster flag → session default_cluster=true should apply
}
result, err := srv.handleFileSearch(ctx, req)
if err != nil {
t.Fatalf("handleFileSearch: %v", err)
}
if result == nil || len(result.Content) == 0 {
t.Skip("search returned no results")
}
// Verify call succeeded (cluster behaviour is ripgrep-version dependent).
_ = result.Content[0].(mcp.TextContent).Text
}
// TestSessionPrefsMaxResultsExplicitOverride verifies explicit call-time max_results
// overrides session pref for ast_query.
func TestSessionPrefsMaxResultsExplicitOverride(t *testing.T) {
tmpDir := t.TempDir()
var sb strings.Builder
sb.WriteString("package main\n\n")
for i := 0; i < 5; i++ {
sb.WriteString(fmt.Sprintf("func Fn%c() {}\n\n", rune('A'+i)))
}
testFile := filepath.Join(tmpDir, "many.go")
if err := os.WriteFile(testFile, []byte(sb.String()), 0600); err != nil {
t.Fatalf("write: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, MaxFileSize: 1 << 20}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, _ := New(cfg, logger)
// Session wants 2, caller supplies 10 — all 5 should fit without cursor.
setSessionPrefs(srv, ParseSessionPrefs(map[string]any{"default_max_results": float64(2)}))
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME()",
"language": "go",
"paths": []interface{}{tmpDir},
"max_results": 10, // explicit override
}
result, _ := srv.handleASTQuery(ctx, req)
if result == nil || len(result.Content) == 0 {
t.Fatal("empty result")
}
text := result.Content[0].(mcp.TextContent).Text
// With max_results=10 and only 5 funcs, no cursor line expected.
if strings.Contains(text, "remaining") {
t.Errorf("explicit max_results=10 should override session 2; 5 funcs fit; no cursor expected, got:\n%s", text)
}
}
+186
View File
@@ -0,0 +1,186 @@
package server
import (
"testing"
)
// TestParseSessionPrefsEmpty verifies zero-value result for nil/empty input.
func TestParseSessionPrefsEmpty(t *testing.T) {
p := ParseSessionPrefs(nil)
if p.ASTQueryFormat != "" {
t.Errorf("ASTQueryFormat: want \"\", got %q", p.ASTQueryFormat)
}
if p.DefaultMaxResults != 0 {
t.Errorf("DefaultMaxResults: want 0, got %d", p.DefaultMaxResults)
}
if p.DefaultCluster != nil {
t.Errorf("DefaultCluster: want nil, got %v", *p.DefaultCluster)
}
if p.CompactRefs != nil {
t.Errorf("CompactRefs: want nil, got %v", *p.CompactRefs)
}
if p.LineNumbers != "" {
t.Errorf("LineNumbers: want \"\", got %q", p.LineNumbers)
}
if p.ResourceLinkThreshold != 0 {
t.Errorf("ResourceLinkThreshold: want 0, got %d", p.ResourceLinkThreshold)
}
// Also test with an empty (non-nil) map.
p2 := ParseSessionPrefs(map[string]any{})
if p2.ASTQueryFormat != "" || p2.DefaultMaxResults != 0 {
t.Error("empty map should produce zero-value prefs")
}
}
// TestParseSessionPrefsAllFields verifies full round-trip with all supported keys.
func TestParseSessionPrefsAllFields(t *testing.T) {
raw := map[string]any{
"terse": true, // no-op; should not produce an error
"default_format": "compact",
"default_max_results": float64(50), // JSON numbers decode as float64
"default_cluster": true,
"compact_refs": true,
"line_numbers": "none",
"resource_link_threshold": float64(32768),
}
p := ParseSessionPrefs(raw)
if p.ASTQueryFormat != "compact" {
t.Errorf("ASTQueryFormat: want \"compact\", got %q", p.ASTQueryFormat)
}
if p.DefaultMaxResults != 50 {
t.Errorf("DefaultMaxResults: want 50, got %d", p.DefaultMaxResults)
}
if p.DefaultCluster == nil || !*p.DefaultCluster {
t.Errorf("DefaultCluster: want true, got %v", p.DefaultCluster)
}
if p.CompactRefs == nil || !*p.CompactRefs {
t.Errorf("CompactRefs: want true, got %v", p.CompactRefs)
}
if p.LineNumbers != "none" {
t.Errorf("LineNumbers: want \"none\", got %q", p.LineNumbers)
}
if p.ResourceLinkThreshold != 32768 {
t.Errorf("ResourceLinkThreshold: want 32768, got %d", p.ResourceLinkThreshold)
}
}
// TestParseSessionPrefsLineNumbersVariants tests all valid line_numbers values.
func TestParseSessionPrefsLineNumbersVariants(t *testing.T) {
for _, want := range []string{"none", "compact", "full"} {
p := ParseSessionPrefs(map[string]any{"line_numbers": want})
if p.LineNumbers != want {
t.Errorf("line_numbers=%q: got %q", want, p.LineNumbers)
}
}
// Invalid value → ignored (empty string).
p := ParseSessionPrefs(map[string]any{"line_numbers": "bogus"})
if p.LineNumbers != "" {
t.Errorf("invalid line_numbers should be ignored, got %q", p.LineNumbers)
}
}
// TestParseSessionPrefsFormatVariants tests all valid default_format values.
func TestParseSessionPrefsFormatVariants(t *testing.T) {
for _, want := range []string{"verbose", "compact", "location"} {
p := ParseSessionPrefs(map[string]any{"default_format": want})
if p.ASTQueryFormat != want {
t.Errorf("default_format=%q: got %q", want, p.ASTQueryFormat)
}
}
// Invalid value → ignored.
p := ParseSessionPrefs(map[string]any{"default_format": "yaml"})
if p.ASTQueryFormat != "" {
t.Errorf("invalid format should be ignored, got %q", p.ASTQueryFormat)
}
}
// TestParseSessionPrefsTypeMismatch verifies that wrong types are silently ignored.
func TestParseSessionPrefsTypeMismatch(t *testing.T) {
raw := map[string]any{
"default_format": 123, // wrong type (int instead of string)
"default_max_results": "fifty", // wrong type
"default_cluster": "yes", // wrong type (string instead of bool)
"compact_refs": 42, // wrong type
"line_numbers": true, // wrong type
"resource_link_threshold": "big", // wrong type
}
p := ParseSessionPrefs(raw)
if p.ASTQueryFormat != "" {
t.Errorf("type mismatch for format should produce empty, got %q", p.ASTQueryFormat)
}
if p.DefaultMaxResults != 0 {
t.Errorf("type mismatch for max_results should produce 0, got %d", p.DefaultMaxResults)
}
if p.DefaultCluster != nil {
t.Errorf("type mismatch for cluster should produce nil")
}
if p.CompactRefs != nil {
t.Errorf("type mismatch for compact_refs should produce nil")
}
if p.LineNumbers != "" {
t.Errorf("type mismatch for line_numbers should produce empty, got %q", p.LineNumbers)
}
if p.ResourceLinkThreshold != 0 {
t.Errorf("type mismatch for threshold should produce 0, got %d", p.ResourceLinkThreshold)
}
}
// TestParseSessionPrefsNegativeValues verifies negative numbers are rejected.
func TestParseSessionPrefsNegativeValues(t *testing.T) {
p := ParseSessionPrefs(map[string]any{
"default_max_results": float64(-5),
"resource_link_threshold": float64(-1),
})
if p.DefaultMaxResults != 0 {
t.Errorf("negative max_results should be rejected, got %d", p.DefaultMaxResults)
}
if p.ResourceLinkThreshold != 0 {
t.Errorf("negative threshold should be rejected, got %d", p.ResourceLinkThreshold)
}
}
// TestParseSessionPrefsIntCoercion verifies int and int64 inputs also work.
func TestParseSessionPrefsIntCoercion(t *testing.T) {
p := ParseSessionPrefs(map[string]any{
"default_max_results": int(25),
"resource_link_threshold": int64(16384),
})
if p.DefaultMaxResults != 25 {
t.Errorf("int max_results: want 25, got %d", p.DefaultMaxResults)
}
if p.ResourceLinkThreshold != 16384 {
t.Errorf("int64 threshold: want 16384, got %d", p.ResourceLinkThreshold)
}
}
// TestParseSessionPrefsClusterFalse ensures default_cluster=false stores a non-nil false.
func TestParseSessionPrefsClusterFalse(t *testing.T) {
p := ParseSessionPrefs(map[string]any{"default_cluster": false})
if p.DefaultCluster == nil {
t.Error("default_cluster=false should store non-nil pointer")
}
if *p.DefaultCluster != false {
t.Error("default_cluster=false: want false pointer")
}
}
// TestSessionPrefsAtomicStore verifies sessionPrefsPtr is readable after Store.
func TestSessionPrefsAtomicStore(t *testing.T) {
var sp sessionPrefsPtr
if sp.Load() != nil {
t.Error("uninitialised Load() should return nil")
}
prefs := ParseSessionPrefs(map[string]any{"default_format": "compact"})
sp.Store(&prefs)
loaded := sp.Load()
if loaded == nil {
t.Fatal("Load() returned nil after Store")
}
if loaded.ASTQueryFormat != "compact" {
t.Errorf("loaded format: want compact, got %q", loaded.ASTQueryFormat)
}
}