fixup! Update, bugfixes on diff and edit handling

This commit is contained in:
2026-02-22 14:03:54 +00:00
parent 6980d3b294
commit 982c2c8b44
23 changed files with 655 additions and 194 deletions
+16 -3
View File
@@ -69,13 +69,26 @@ func Load(workspaceRoot string) (*Config, error) {
cfg.WorkspaceRoot = cwd
}
// Try to load from config file in workspace root
// Try to load from config file in workspace root.
// Save WorkspaceRoot before loading config file so it cannot be overridden.
savedRoot := cfg.WorkspaceRoot
configPath := filepath.Join(cfg.WorkspaceRoot, ".mcp-filepuff.json")
if data, err := os.ReadFile(configPath); err == nil {
if err := json.Unmarshal(data, cfg); err != nil {
return nil, err
}
}
// Restore WorkspaceRoot — config file must not override path guards.
cfg.WorkspaceRoot = savedRoot
// Clamp size limits to prevent config file from requesting excessive memory.
const maxAllowedSize int64 = 100 * 1024 * 1024 // 100 MB
if cfg.MaxFileSize > maxAllowedSize {
cfg.MaxFileSize = maxAllowedSize
}
if cfg.MaxParseSize > maxAllowedSize {
cfg.MaxParseSize = maxAllowedSize
}
// Override from environment variables
cfg.loadFromEnv()
@@ -173,8 +186,8 @@ func (c *Config) IsPathAllowed(path string) bool {
// Check if the path is within workspace (doesn't start with ..)
// This prevents both "../" attacks and symlink bypasses
// Also reject empty relative path (which means it's the workspace root itself)
return rel != "." && !strings.HasPrefix(rel, "..")
// The workspace root itself (rel == ".") is a valid, allowed path
return !strings.HasPrefix(rel, "..")
}
// Validate validates the configuration and returns an error if invalid.
+55 -2
View File
@@ -389,8 +389,8 @@ func TestIsPathAllowedEdgeCases(t *testing.T) {
{
name: "workspace_root_itself",
path: tmpDir,
allowed: false,
desc: "workspace root itself should not be allowed",
allowed: true,
desc: "workspace root itself should be allowed",
},
{
name: "dot_relative",
@@ -546,6 +546,59 @@ func TestConfigFileLoadingErrors(t *testing.T) {
}
}
// TestIsPathAllowed_SymlinkOutsideWorkspace verifies that symlinks pointing
// outside the workspace are rejected (T-01), while a path inside the workspace
// and the workspace root itself are still accepted.
func TestIsPathAllowed_SymlinkOutsideWorkspace(t *testing.T) {
	// Two independent temp dirs: one acts as the workspace, the other as the
	// outside target. t.TempDir registers its own cleanup, so no manual
	// RemoveAll bookkeeping is needed.
	workspace := t.TempDir()
	outside := t.TempDir()

	// Create a file outside the workspace.
	outsideFile := filepath.Join(outside, "secret.txt")
	if err := os.WriteFile(outsideFile, []byte("secret"), 0o600); err != nil {
		t.Fatalf("failed to write outside file: %v", err)
	}

	// Create a symlink inside the workspace pointing outside. Keep the
	// underlying error in the skip message so an unexpected failure is not
	// silently reported as "unsupported".
	symlinkPath := filepath.Join(workspace, "escape-link")
	if err := os.Symlink(outsideFile, symlinkPath); err != nil {
		t.Skipf("symlink creation not supported on this system: %v", err)
	}

	cfg := Default()
	cfg.WorkspaceRoot = workspace

	// The symlink resolves to a file outside workspace — must be rejected.
	if cfg.IsPathAllowed(symlinkPath) {
		t.Error("symlink pointing outside workspace should NOT be allowed")
	}

	// Direct access to the outside file should also be rejected.
	if cfg.IsPathAllowed(outsideFile) {
		t.Error("file outside workspace should NOT be allowed")
	}

	// File inside workspace should still be allowed. The file is never
	// created; only the path is checked.
	insideFile := filepath.Join(workspace, "safe.txt")
	if !cfg.IsPathAllowed(insideFile) {
		t.Error("file inside workspace should be allowed")
	}

	// Workspace root itself should be allowed (C-08 fix).
	if !cfg.IsPathAllowed(workspace) {
		t.Error("workspace root itself should be allowed")
	}
}
// Helper function to check if a string contains a substring.
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
+1 -1
View File
@@ -126,7 +126,7 @@ func TestIsPathAllowed_BasicCases(t *testing.T) {
{
name: "workspace root itself",
path: tmpDir,
expected: false, // Empty relative path
expected: true, // Workspace root is a valid, allowed path (needed for ast_query)
},
}
+19 -11
View File
@@ -6,6 +6,7 @@ import (
"context"
"fmt"
"os"
"slices"
"strings"
"sync"
@@ -60,6 +61,7 @@ type EditResult struct {
// Engine performs AST-aware edits.
type Engine struct {
// registry is the parser registry handed to NewEngine.
registry *parser.Registry
// dmp is the diff-match-patch instance created once in NewEngine and
// reused by generateDiff for every diff.
dmp *diffmatchpatch.DiffMatchPatch
fileLocks sync.Map // map[string]*sync.Mutex for per-file locking
}
@@ -67,6 +69,7 @@ type Engine struct {
func NewEngine(registry *parser.Registry) *Engine {
	// fileLocks is intentionally left untouched: the zero value of sync.Map
	// is an empty map that is ready for use.
	engine := new(Engine)
	engine.registry = registry
	engine.dmp = diffmatchpatch.New()
	return engine
}
@@ -166,7 +169,7 @@ func (e *Engine) performASTEdit(ctx context.Context, edit *ASTEdit, apply bool)
}
// Generate diff
diff := generateDiff(string(content), string(newContent), edit.File)
diff := e.generateDiff(string(content), string(newContent), edit.File)
result := &EditResult{
Success: true,
@@ -225,7 +228,7 @@ func (e *Engine) performTextEdit(_ context.Context, edit *ASTEdit, apply bool) (
}
// Generate diff
diff := generateDiff(string(content), string(newContent), edit.File)
diff := e.generateDiff(string(content), string(newContent), edit.File)
result := &EditResult{
Success: true,
@@ -369,17 +372,18 @@ func sortBySpecificity(nodes []*sitter.Node) []*sitter.Node {
return nodes
}
// Sort by specificity: named nodes first, then by size (smallest first)
result := make([]*sitter.Node, len(nodes))
copy(result, nodes)
for i := 0; i < len(result)-1; i++ {
for j := i + 1; j < len(result); j++ {
if shouldPrefer(result[j], result[i]) {
result[i], result[j] = result[j], result[i]
}
slices.SortFunc(result, func(a, b *sitter.Node) int {
if shouldPrefer(a, b) {
return -1
}
}
if shouldPrefer(b, a) {
return 1
}
return 0
})
return result
}
@@ -566,8 +570,8 @@ func indentContent(content string, indent string) string {
// generateDiff creates a unified diff between original and modified content.
// Uses line-level Myers diff algorithm for accurate and readable diffs.
func generateDiff(original, modified, filename string) string {
dmp := diffmatchpatch.New()
func (e *Engine) generateDiff(original, modified, filename string) string {
dmp := e.dmp
// Use line-level diffing: encode each line as a single character,
// diff the encoded strings, then decode back to real lines.
@@ -692,6 +696,10 @@ func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, e
}
lines := bytes.Split(content, []byte("\n"))
// Trim phantom empty element from trailing newline
if len(lines) > 0 && len(lines[len(lines)-1]) == 0 {
lines = lines[:len(lines)-1]
}
totalLines := len(lines)
// Convert to 0-indexed
+4 -3
View File
@@ -8,6 +8,7 @@ import (
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/sergi/go-diff/diffmatchpatch"
sitter "github.com/smacker/go-tree-sitter"
)
@@ -394,7 +395,7 @@ func TestGenerateDiff(t *testing.T) {
modified := "line1\nmodified\nline3"
filename := "test.txt"
diff := generateDiff(original, modified, filename)
diff := (&Engine{dmp: diffmatchpatch.New()}).generateDiff(original, modified, filename)
if !strings.Contains(diff, "---") {
t.Error("diff should contain --- header")
@@ -420,7 +421,7 @@ func TestGenerateDiffLineLevelAccuracy(t *testing.T) {
original := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello\")\n}\n"
modified := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello world\")\n}\n"
diff := generateDiff(original, modified, "test.go")
diff := (&Engine{dmp: diffmatchpatch.New()}).generateDiff(original, modified, "test.go")
// The diff must show whole-line removals and additions
if !strings.Contains(diff, "-\tfmt.Println(\"hello\")\n") {
@@ -453,7 +454,7 @@ func TestGenerateDiffNoPhantomChanges(t *testing.T) {
original := "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\n"
modified := "line1\nREPLACED\nline3\nline4\nline5\nline6\nline7\nline8\n"
diff := generateDiff(original, modified, "test.txt")
diff := (&Engine{dmp: diffmatchpatch.New()}).generateDiff(original, modified, "test.txt")
// Count changed lines (excluding headers)
addCount := 0
+23
View File
@@ -268,7 +268,11 @@ func (c *Client) send(msg interface{}) error {
}
// readLoop reads and dispatches messages from the server.
// On exit (for any reason), it drains all pending Call waiters with a
// synthetic error so that goroutines blocked in Call are unblocked.
func (c *Client) readLoop() {
defer c.drainPending()
reader := bufio.NewReader(c.stdout)
for {
@@ -329,6 +333,25 @@ func (c *Client) readLoop() {
}
}
// drainPending sends a synthetic error response to every pending Call waiter
// so that goroutines blocked in Call are unblocked when readLoop exits.
//
// The pending map is emptied while c.mu is held, but the channel sends are
// performed only after the mutex has been released. A channel send can block
// until the waiter is ready to receive; doing that with c.mu locked would
// stall every other goroutine that needs the mutex for as long as the send
// is stuck.
//
// NOTE(review): how the waiter channels are created (buffered or not) is not
// visible here — confirm against Call.
func (c *Client) drainPending() {
	// Deferred deliveries, one per waiter, executed after unlocking.
	var deliveries []func()

	c.mu.Lock()
	for id, ch := range c.pending {
		id, ch := id, ch // per-iteration copies so each closure keeps its own pair
		deliveries = append(deliveries, func() {
			ch <- &Response{
				JSONRPC: "2.0",
				ID:      id,
				Error: &ResponseError{
					Code:    -32603, // InternalError
					Message: "LSP client readLoop terminated",
				},
			}
		})
		delete(c.pending, id)
	}
	c.mu.Unlock()

	for _, deliver := range deliveries {
		deliver()
	}
}
// IsRunning returns whether the client is running.
func (c *Client) IsRunning() bool {
c.runningMu.RLock()
+43 -26
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"net/url"
"os"
"os/exec"
"path/filepath"
@@ -28,6 +29,9 @@ const (
)
// Manager manages LSP servers for different languages.
//
// Lock ordering: m.mu must always be acquired before srv.mu.
// Never acquire m.mu while holding srv.mu.
type Manager struct {
servers map[protocol.Language]*ManagedServer
logger *slog.Logger
@@ -36,6 +40,7 @@ type Manager struct {
timeout time.Duration
idleTimeout time.Duration
mu sync.RWMutex
closeOnce sync.Once
stopped bool
}
@@ -163,9 +168,10 @@ func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*Manag
WithRemediation("Only whitelisted LSP server binaries are allowed for security reasons")
}
// Create command
// Create command — use exec.Command (not CommandContext) so the LSP
// subprocess is not killed when the request-scoped context expires.
args := append(config.Command[1:], config.Args...)
cmd := exec.CommandContext(ctx, cmdPath, args...)
cmd := exec.Command(cmdPath, args...)
cmd.Env = os.Environ()
cmd.Dir = m.workspaceRoot
// Create client
@@ -412,9 +418,13 @@ func (m *Manager) References(ctx context.Context, file string, line, col int, in
}
// ensureDocumentOpen opens a document if not already open.
func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, file string) error {
// It reads the file content outside the lock (to avoid holding the lock during I/O),
// then holds srv.mu for the entire check-and-send sequence to prevent duplicate didOpen
// notifications from concurrent goroutines.
func (m *Manager) ensureDocumentOpen(_ context.Context, srv *ManagedServer, file string) error {
uri := fileToURI(file)
// Quick check under lock — common fast path.
srv.mu.Lock()
if _, ok := srv.openDocs[uri]; ok {
srv.mu.Unlock()
@@ -422,15 +432,23 @@ func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, fi
}
srv.mu.Unlock()
// Read file content
// Read file content outside the lock to avoid holding it during I/O.
content, err := os.ReadFile(file)
if err != nil {
return fmt.Errorf("failed to read file: %w", err)
}
// Get language ID
langID := languageToLSPID(srv.language)
// Re-acquire lock and re-check to prevent TOCTOU race: two goroutines could
// both pass the fast-path check above and both try to send didOpen.
srv.mu.Lock()
defer srv.mu.Unlock()
if _, ok := srv.openDocs[uri]; ok {
return nil
}
params := DidOpenTextDocumentParams{
TextDocument: TextDocumentItem{
URI: uri,
@@ -444,10 +462,7 @@ func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, fi
return fmt.Errorf("didOpen failed: %w", err)
}
srv.mu.Lock()
srv.openDocs[uri] = 1
srv.mu.Unlock()
return nil
}
@@ -519,26 +534,28 @@ func (m *Manager) reapIdleServers() {
}
}
// Close shuts down all LSP servers.
// Close shuts down all LSP servers. It is safe to call multiple times.
func (m *Manager) Close() error {
close(m.stopReaper)
m.closeOnce.Do(func() {
close(m.stopReaper)
m.mu.Lock()
defer m.mu.Unlock()
m.mu.Lock()
defer m.mu.Unlock()
m.stopped = true
m.stopped = true
for lang, srv := range m.servers {
m.logger.Info("shutting down LSP server", "language", lang)
// Try graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout)
_, _ = srv.client.Call(ctx, "shutdown", nil)
cancel()
_ = srv.client.Notify("exit", nil)
_ = srv.client.Close()
}
for lang, srv := range m.servers {
m.logger.Info("shutting down LSP server", "language", lang)
// Try graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout)
_, _ = srv.client.Call(ctx, "shutdown", nil)
cancel()
_ = srv.client.Notify("exit", nil)
_ = srv.client.Close()
}
m.servers = make(map[protocol.Language]*ManagedServer)
m.servers = make(map[protocol.Language]*ManagedServer)
})
return nil
}
@@ -553,13 +570,13 @@ func (m *Manager) IsAvailable(lang protocol.Language) bool {
return err == nil
}
// fileToURI converts a file path to a file URI.
// fileToURI converts a file path to a properly percent-encoded file URI.
func fileToURI(file string) string {
absPath, err := filepath.Abs(file)
if err != nil {
return "file://" + file
absPath = file
}
return "file://" + absPath
return (&url.URL{Scheme: "file", Path: absPath}).String()
}
// URIToFile converts a file URI to a file path.
+23 -3
View File
@@ -189,10 +189,30 @@ func TestManagerGracefulShutdown(t *testing.T) {
if !manager.stopped {
t.Error("manager should be marked as stopped after Close()")
}
}
// Note: We don't test multiple Close() calls because the implementation
// closes the stopReaper channel which can't be closed twice.
// In production, Close() should only be called once during shutdown.
// TestManagerDoubleClose checks that Close() can be invoked twice in a row
// without panicking or returning an error, and that the manager is marked as
// stopped afterwards (T-05, C-02).
func TestManagerDoubleClose(t *testing.T) {
	workspace := t.TempDir()
	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})
	manager := NewManager(workspace, slog.New(handler))

	// Initial shutdown.
	if err := manager.Close(); err != nil {
		t.Errorf("first Close() returned error: %v", err)
	}

	// Repeated shutdown: must not panic (C-02 wraps the close in sync.Once).
	if err := manager.Close(); err != nil {
		t.Errorf("second Close() returned error: %v", err)
	}

	if stopped := manager.stopped; !stopped {
		t.Error("manager should be marked as stopped after double Close()")
	}
}
// TestManagerIdleReaper tests the idle server cleanup mechanism.
+8 -6
View File
@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"testing"
"github.com/cespare/xxhash/v2"
)
// TestLRUCacheEviction tests that the LRU cache properly evicts old entries.
@@ -82,8 +84,8 @@ func TestContentHashCollisionResistance(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hash1 := contentHash(tc.content1)
hash2 := contentHash(tc.content2)
hash1 := fmt.Sprintf("%016x", xxhash.Sum64(tc.content1))
hash2 := fmt.Sprintf("%016x", xxhash.Sum64(tc.content2))
if hash1 == hash2 {
t.Errorf("Hash collision: %s == %s for different content", hash1, hash2)
@@ -96,9 +98,9 @@ func TestContentHashCollisionResistance(t *testing.T) {
func TestContentHashConsistency(t *testing.T) {
content := []byte("package main\n\nfunc test() {}\n")
hash1 := contentHash(content)
hash2 := contentHash(content)
hash3 := contentHash(content)
hash1 := fmt.Sprintf("%016x", xxhash.Sum64(content))
hash2 := fmt.Sprintf("%016x", xxhash.Sum64(content))
hash3 := fmt.Sprintf("%016x", xxhash.Sum64(content))
if hash1 != hash2 || hash2 != hash3 {
t.Errorf("Hash inconsistency: %s, %s, %s", hash1, hash2, hash3)
@@ -115,7 +117,7 @@ func BenchmarkContentHash_xxHash(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = contentHash(content)
_ = fmt.Sprintf("%016x", xxhash.Sum64(content))
}
}
+3 -67
View File
@@ -24,9 +24,8 @@ import (
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// MaxFileSize is the default maximum file size we'll parse (10MB).
// Deprecated: Use Registry.maxParseSize instead.
const MaxFileSize = 10 * 1024 * 1024
// maxFileSize is the default maximum file size we'll parse (10MB).
const maxFileSize = 10 * 1024 * 1024
// Registry manages Tree-sitter parsers for different languages.
type Registry struct {
@@ -69,18 +68,6 @@ type SyntaxError struct {
Location protocol.Location
}
// CacheStatsResult contains cache statistics.
type CacheStatsResult struct {
Hits int64 `json:"hits"`
Misses int64 `json:"misses"`
HitRate float64 `json:"hit_rate"`
Size int `json:"size"`
TotalParseTime int64 `json:"total_parse_time_ns"`
ParseCount int64 `json:"parse_count"`
AvgParseTime int64 `json:"avg_parse_time_ns"`
LastParseTime int64 `json:"last_parse_time_ns"`
}
// NewRegistry creates a new parser registry with the default max parse size.
// For custom max parse size, use NewRegistryWithSize.
func NewRegistry() *Registry {
@@ -98,7 +85,7 @@ func NewRegistryWithSize(maxParseSize int64) *Registry {
}
if maxParseSize <= 0 {
maxParseSize = MaxFileSize
maxParseSize = maxFileSize
}
return &Registry{
@@ -266,50 +253,6 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
}, nil
}
// CacheStats returns cache hit/miss statistics.
func (r *Registry) CacheStats() (hits, misses int64) {
return r.cacheHits.Load(), r.cacheMisses.Load()
}
// CacheStatsDetailed returns detailed cache and parse statistics.
func (r *Registry) CacheStatsDetailed() CacheStatsResult {
hits := r.cacheHits.Load()
misses := r.cacheMisses.Load()
totalParseTime := r.totalParseTime.Load()
parseCount := r.parseCount.Load()
var hitRate float64
total := hits + misses
if total > 0 {
hitRate = float64(hits) / float64(total)
}
var avgParseTime int64
if parseCount > 0 {
avgParseTime = totalParseTime / parseCount
}
return CacheStatsResult{
Hits: hits,
Misses: misses,
HitRate: hitRate,
Size: r.cache.Len(),
TotalParseTime: totalParseTime,
ParseCount: parseCount,
AvgParseTime: avgParseTime,
LastParseTime: r.lastParseDuration.Load(),
}
}
// ResetStats resets all cache and parse statistics.
func (r *Registry) ResetStats() {
r.cacheHits.Store(0)
r.cacheMisses.Store(0)
r.totalParseTime.Store(0)
r.parseCount.Store(0)
r.lastParseDuration.Store(0)
}
// extractErrors finds all error nodes in the tree.
func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
var errors []SyntaxError
@@ -346,13 +289,6 @@ func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
return errors
}
// contentHash returns a fast hash of the content for caching.
// Uses xxHash which is 5-10x faster than SHA256 for non-cryptographic purposes.
func contentHash(content []byte) string {
h := xxhash.Sum64(content)
return fmt.Sprintf("%016x", h)
}
// isBinary checks if content appears to be binary.
func isBinary(content []byte) bool {
// Check first 8000 bytes for null bytes
+4 -1
View File
@@ -2,8 +2,11 @@ package parser
import (
"context"
"fmt"
"strings"
"testing"
"github.com/cespare/xxhash/v2"
)
// BenchmarkParse benchmarks parsing files of various sizes.
@@ -194,7 +197,7 @@ func BenchmarkContentHash(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = contentHash(content)
_ = fmt.Sprintf("%016x", xxhash.Sum64(content))
}
})
}
+4 -4
View File
@@ -31,8 +31,8 @@ type JSONNode struct {
// ParseYAML parses YAML content and returns a tree-sitter-compatible result
func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
// Check file size
if len(content) > MaxFileSize {
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
if len(content) > maxFileSize {
return nil, errors.NewFileTooLarge(filename, int64(len(content)), maxFileSize)
}
// Parse YAML
@@ -57,8 +57,8 @@ func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byt
// ParseJSON parses JSON content and returns a tree-sitter-compatible result
func (r *Registry) ParseJSON(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
// Check file size
if len(content) > MaxFileSize {
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
if len(content) > maxFileSize {
return nil, errors.NewFileTooLarge(filename, int64(len(content)), maxFileSize)
}
// Parse JSON to validate syntax
+6 -8
View File
@@ -22,13 +22,11 @@ type ASTQuery struct {
// QueryFilters provide additional filtering criteria.
type QueryFilters struct {
HasChild *ASTQuery `json:"has_child,omitempty"`
HasParent *ASTQuery `json:"has_parent,omitempty"`
NameMatches string `json:"name_matches,omitempty"`
NameExact string `json:"name_exact,omitempty"`
InFile string `json:"in_file,omitempty"`
NotInFile string `json:"not_in_file,omitempty"`
KindIn []string `json:"kind_in,omitempty"`
NameMatches string `json:"name_matches,omitempty"`
NameExact string `json:"name_exact,omitempty"`
InFile string `json:"in_file,omitempty"`
NotInFile string `json:"not_in_file,omitempty"`
KindIn []string `json:"kind_in,omitempty"`
}
// MatchResult represents a single match from a query.
@@ -259,7 +257,7 @@ func matchPatternHeuristic(node *sitter.Node, pattern *ParsedPattern, content []
}
// Match struct patterns (Go, C, C++)
if strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ") && strings.Contains(patternLower, "struct") {
if (strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ")) && strings.Contains(patternLower, "struct") {
if nodeType != "type_declaration" && nodeType != "struct_specifier" {
return false
}
+149
View File
@@ -499,6 +499,155 @@ func TestFormatResults(t *testing.T) {
}
}
// TestMatchStructOperatorPrecedence pins the behavior of the struct branch of
// the pattern heuristic around the C-07 change, which added explicit
// parentheses to the mixed ||/&& condition.
//
// It parses a small Go file with one struct and one function and checks that:
//   - a "type ... struct" pattern finds the struct,
//   - a bare "struct ..." pattern also finds the struct,
//   - a func pattern finds only the function.
//
// NOTE(review): any pattern containing "struct " necessarily contains
// "struct", so the parenthesized condition looks logically equivalent to the
// previous one; this test guards the behavior rather than demonstrating a
// change — confirm.
func TestMatchStructOperatorPrecedence(t *testing.T) {
	registry := parser.NewRegistry()
	defer registry.Close()
	m := NewMatcher(registry)

	content := `package main
type Server struct {
Port int
}
func main() {}
`
	ctx := context.Background()
	parsed, err := registry.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	type testCase struct {
		name        string
		pattern     string
		wantMatches int
	}
	cases := []testCase{
		// Expected match: Server.
		{name: "type struct pattern should match", pattern: "type $NAME struct { $$$FIELDS }", wantMatches: 1},
		// Expected match: Server.
		{name: "struct keyword alone should match", pattern: "struct $NAME { $$$FIELDS }", wantMatches: 1},
		// Expected match: main (function branch, not the struct branch).
		{name: "func pattern should not match struct branch", pattern: "func $NAME() {}", wantMatches: 1},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			q := &ASTQuery{Pattern: tc.pattern, Language: "go"}
			matches, err := m.Match(ctx, q, parsed.Tree, []byte(content), "test.go")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}
			if len(matches) == tc.wantMatches {
				return
			}
			// On a count mismatch, dump every match to make the failure diagnosable.
			t.Errorf("expected %d matches, got %d", tc.wantMatches, len(matches))
			for idx := range matches {
				t.Logf("match %d: type=%s, text=%q", idx, matches[idx].Node.Type(), truncateForLog(matches[idx].Text, 80))
			}
		})
	}
}
// truncateForLog shortens s to at most max bytes for log output and appends
// "..." when anything was cut off. Strings that already fit are returned
// unchanged.
//
// The cut position is moved back to the start of a UTF-8 sequence so that a
// multi-byte character is never split, which would put invalid UTF-8 into the
// log line. A negative max is treated as zero instead of panicking.
func truncateForLog(s string, max int) string {
	if len(s) <= max {
		return s
	}
	if max < 0 {
		max = 0
	}
	cut := max
	// 0b10xxxxxx marks a UTF-8 continuation byte; step back until the cut
	// lands on a leading byte (cut < len(s) here, so indexing is safe).
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
// TestPassesFilters_AllBranches runs the same func pattern against a file
// with three functions and varies only the query filters. It covers
// name_exact, name_matches, kind_in (matching and non-matching) and a
// name_exact + kind_in combination; in_file and not_in_file are not
// exercised here.
func TestPassesFilters_AllBranches(t *testing.T) {
	registry := parser.NewRegistry()
	defer registry.Close()

	content := `package main
func Alpha() {}
func Beta() {}
func Gamma() {}
`
	ctx := context.Background()
	parsed, err := registry.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	m := NewMatcher(registry)

	type testCase struct {
		name        string
		filters     QueryFilters
		wantMatches int
	}
	funcKind := []string{"function_declaration"}
	cases := []testCase{
		{name: "no filters matches all", filters: QueryFilters{}, wantMatches: 3},
		{name: "name_exact filter", filters: QueryFilters{NameExact: "Alpha"}, wantMatches: 1},
		{name: "name_matches regex filter", filters: QueryFilters{NameMatches: "^[AB]"}, wantMatches: 2},
		{name: "kind_in filter", filters: QueryFilters{KindIn: funcKind}, wantMatches: 3},
		{name: "kind_in filter excludes non-matching kinds", filters: QueryFilters{KindIn: []string{"class_declaration"}}, wantMatches: 0},
		{name: "combined name_exact and kind_in", filters: QueryFilters{NameExact: "Beta", KindIn: funcKind}, wantMatches: 1},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			q := &ASTQuery{
				Pattern:  "func $NAME() {}",
				Language: "go",
				Filters:  tc.filters,
			}
			matches, err := m.Match(ctx, q, parsed.Tree, []byte(content), "test.go")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}
			if got := len(matches); got != tc.wantMatches {
				t.Errorf("expected %d matches, got %d", tc.wantMatches, got)
			}
		})
	}
}
func TestQueryValidation(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
+21
View File
@@ -131,6 +131,11 @@ func (s *Searcher) Search(ctx context.Context, req *Request) (*SearchResults, er
WithRemediation("Provide a non-empty search pattern")
}
// Validate that at least one provided path is allowed
if err := s.validatePaths(req.Paths); err != nil {
return nil, err
}
// Build ripgrep command
args := s.buildArgs(req)
@@ -238,6 +243,22 @@ func (s *Searcher) buildArgs(req *Request) []string {
return args
}
// validatePaths reports whether the caller-supplied search paths are usable.
// It succeeds when no paths were given (the search then defaults to the
// workspace root) or when at least one path passes IsPathAllowed. If paths
// were supplied and every one of them is rejected, it returns an
// ErrPathNotAllowed error listing the offending paths.
func (s *Searcher) validatePaths(paths []string) error {
	// An empty list is acceptable as-is; otherwise scan until the first
	// allowed path is found.
	acceptable := len(paths) == 0
	for i := 0; !acceptable && i < len(paths); i++ {
		acceptable = s.cfg.IsPathAllowed(paths[i])
	}
	if acceptable {
		return nil
	}

	return errors.New(errors.ErrPathNotAllowed, "all provided search paths are outside the workspace root").
		WithContext("paths", fmt.Sprintf("%v", paths)).
		WithRemediation("Provide paths within the workspace root")
}
// parseOutput parses ripgrep JSON output.
func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchResults, error) {
results := &SearchResults{
+24 -8
View File
@@ -43,9 +43,10 @@ func TestBuildArgs(t *testing.T) {
}
tests := []struct {
name string
req *Request
expected []string
name string
req *Request
expected []string
notExpected []string // T-06: verify absence of unexpected flags
}{
{
name: "basic search",
@@ -54,7 +55,8 @@ func TestBuildArgs(t *testing.T) {
ContextLines: 2,
Regex: true,
},
expected: []string{"--json", "--context=2", "--", "test", "."},
expected: []string{"--json", "--context=2", "--", "test", "."},
notExpected: []string{"--ignore-case", "--fixed-strings", "--max-total-count=0"},
},
{
name: "ignore case",
@@ -63,7 +65,8 @@ func TestBuildArgs(t *testing.T) {
IgnoreCase: true,
Regex: true,
},
expected: []string{"--json", "--ignore-case", "--", "test", "."},
expected: []string{"--json", "--ignore-case", "--", "test", "."},
notExpected: []string{"--fixed-strings"},
},
{
name: "fixed strings",
@@ -71,7 +74,8 @@ func TestBuildArgs(t *testing.T) {
Pattern: "test",
Regex: false,
},
expected: []string{"--json", "--fixed-strings", "--", "test", "."},
expected: []string{"--json", "--fixed-strings", "--", "test", "."},
notExpected: []string{"--ignore-case"},
},
{
name: "with file types",
@@ -80,7 +84,8 @@ func TestBuildArgs(t *testing.T) {
FileTypes: []string{"go", "ts"},
Regex: true,
},
expected: []string{"--json", "--type", "go", "--type", "ts", "--", "test", "."},
expected: []string{"--json", "--type", "go", "--type", "ts", "--", "test", "."},
notExpected: []string{"--ignore-case", "--fixed-strings"},
},
{
name: "with max results",
@@ -89,7 +94,8 @@ func TestBuildArgs(t *testing.T) {
MaxResults: 10,
Regex: true,
},
expected: []string{"--json", "--max-total-count=10", "--", "test", "."},
expected: []string{"--json", "--max-total-count=10", "--", "test", "."},
notExpected: []string{"--ignore-case", "--fixed-strings"},
},
}
@@ -110,6 +116,16 @@ func TestBuildArgs(t *testing.T) {
t.Errorf("expected arg %q not found in %v", exp, args)
}
}
// T-06: Check that unexpected args are absent
for _, notExp := range tt.notExpected {
for _, arg := range args {
if arg == notExp {
t.Errorf("unexpected arg %q found in %v", notExp, args)
break
}
}
}
})
}
}
+25 -14
View File
@@ -54,9 +54,9 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest
}
// Find files to search based on language
ext := languageToExtension(language)
if ext == "" {
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil
exts := languageToExtensions(language)
if len(exts) == 0 {
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir)", language)), nil
}
var allResults []query.MatchResult
@@ -89,7 +89,14 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest
}
// Check file extension matches language
if !strings.HasSuffix(path, ext) {
matched := false
for _, ext := range exts {
if strings.HasSuffix(path, ext) {
matched = true
break
}
}
if !matched {
return nil
}
@@ -203,24 +210,28 @@ func symbolKindIcon(kind protocol.SymbolKind) string {
}
}
// languageToExtension maps language names to file extensions.
func languageToExtension(language string) string {
// languageToExtensions maps language names to file extensions.
func languageToExtensions(language string) []string {
switch strings.ToLower(language) {
case "go":
return ".go"
return []string{".go"}
case "typescript":
return ".ts"
return []string{".ts"}
case "javascript":
return ".js"
return []string{".js"}
case "python":
return ".py"
return []string{".py"}
case "c":
return ".c"
return []string{".c"}
case "cpp", "c++":
return ".cpp"
return []string{".cpp"}
case "html":
return []string{".html", ".htm"}
case "vue":
return []string{".vue"}
case "elixir":
return ".ex"
return []string{".ex", ".exs"}
default:
return ""
return nil
}
}
+10
View File
@@ -33,6 +33,16 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, ap
return mcp.NewToolResultError("operation is required"), nil
}
// Validate operation against known values
switch edit.EditOperation(operation) {
case edit.EditReplace, edit.EditInsertBefore, edit.EditInsertAfter, edit.EditDelete:
// valid
default:
return mcp.NewToolResultError(fmt.Sprintf(
"invalid operation %q: must be one of: replace, insert_before, insert_after, delete", operation,
)), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
+22 -9
View File
@@ -81,8 +81,8 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
return mcp.NewToolResultError("path is outside workspace root"), nil
}
// Read file
content, err := os.ReadFile(path)
// Check file size before reading to avoid loading huge files into memory
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil
@@ -90,13 +90,21 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
if os.IsPermission(err) {
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
}
s.logger.Warn("file read error", "path", path, "error", err)
return mcp.NewToolResultError("error reading file"), nil
s.logger.Warn("file stat error", "path", path, "error", err)
return mcp.NewToolResultError("error accessing file"), nil
}
if info.Size() > s.cfg.MaxFileSize {
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize)), nil
}
// Check file size
if int64(len(content)) > s.cfg.MaxFileSize {
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil
// Read file
content, err := os.ReadFile(path)
if err != nil {
if os.IsPermission(err) {
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
}
s.logger.Warn("file read error", "path", path, "error", err)
return mcp.NewToolResultError("error reading file"), nil
}
// Handle line range
@@ -167,13 +175,18 @@ func splitLines(s string) []string {
const largeSizeThreshold = 1024 * 1024 // 1MB
if len(s) > largeSizeThreshold {
// Use scanner for large files
// Use scanner for large files with increased buffer for long lines
scanner := bufio.NewScanner(strings.NewReader(s))
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), 1024*1024) // up to 1MB per line
var lines []string
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
// Handle potential error and add empty line if string ended with newline
if err := scanner.Err(); err != nil {
// If scanning fails (e.g. line exceeds buffer), fall back to strings.Split
return strings.Split(s, "\n")
}
// Add empty line if string ended with newline
if len(s) > 0 && s[len(s)-1] == '\n' {
lines = append(lines, "")
}
+8 -2
View File
@@ -117,7 +117,7 @@ func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolR
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
// Try to read a preview snippet
preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3)
preview := s.readFilePreview(filePath, loc.Range.Start.Line+1, 3)
if preview != "" {
output.WriteString("```\n")
output.WriteString(preview)
@@ -184,7 +184,13 @@ func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolR
}
// readFilePreview reads a few lines from a file around the given line.
func readFilePreview(file string, line, contextLines int) string {
// It validates that the file path is within the allowed workspace before reading.
func (s *Server) readFilePreview(file string, line, contextLines int) string {
if !s.cfg.IsPathAllowed(file) {
s.logger.Warn("readFilePreview: path not allowed", "path", file)
return ""
}
content, err := os.ReadFile(file)
if err != nil {
return ""
+1 -1
View File
@@ -169,7 +169,7 @@ func (s *Server) registerTools() {
),
mcp.WithString("language",
mcp.Required(),
mcp.Description("Target language: go, typescript, javascript, python, c, cpp"),
mcp.Description("Target language: go, typescript, javascript, python, c, cpp, html, vue, elixir"),
),
mcp.WithArray("paths",
mcp.Description("Paths to search in (defaults to workspace root)"),
+153
View File
@@ -5,6 +5,7 @@ import (
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
@@ -381,3 +382,155 @@ func Hello() {
t.Error("handleEdit(apply) should modify the file")
}
}
// TestHandleFileReadMaxFileSize verifies that handleFileRead rejects files
// exceeding MaxFileSize via os.Stat before loading them into memory (T-03, S-01).
func TestHandleFileReadMaxFileSize(t *testing.T) {
tmpDir := t.TempDir()
// Create a test file
testFile := filepath.Join(tmpDir, "big.txt")
content := make([]byte, 1024) // 1KB file
for i := range content {
content[i] = 'A'
}
if err := os.WriteFile(testFile, content, 0600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
// Set MaxFileSize smaller than the file
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
MaxFileSize: 512, // 512 bytes — smaller than our 1KB file
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"path": testFile,
}
result, err := srv.handleFileRead(ctx, req)
if err != nil {
t.Fatalf("handleFileRead() returned Go error: %v", err)
}
// The result should indicate an error (file too large)
if result == nil {
t.Fatal("handleFileRead() returned nil result")
}
if !result.IsError {
t.Error("expected IsError=true for file exceeding MaxFileSize")
}
contents := result.Content
if len(contents) == 0 {
t.Fatal("expected non-empty content with error message")
}
textContent, ok := contents[0].(mcp.TextContent)
if !ok {
t.Fatal("expected text content")
}
if !strings.Contains(textContent.Text, "too large") {
t.Errorf("expected 'too large' error message, got: %s", textContent.Text)
}
}
// TestSplitLinesLargeFile feeds splitLines an input larger than 1MB so that the
// bufio.Scanner branch, including its scanner.Err() check, is exercised
// (T-07, C-05).
//
// The input is lineCount lines of 60 'x' bytes, each terminated by '\n'. Because
// the input ends with a newline, splitLines is expected to return one extra
// empty trailing element, matching what strings.Split yields on the small-input
// path.
func TestSplitLinesLargeFile(t *testing.T) {
	const lineCount = 20000
	wantLine := strings.Repeat("x", 60)

	// preview bounds the text echoed in failure messages. It never slices past
	// the end of s, so a short or empty mismatching line cannot panic the test.
	preview := func(s string) string {
		const limit = 20
		if len(s) > limit {
			return s[:limit]
		}
		return s
	}

	// Build a string >1MB with a known line count.
	var sb strings.Builder
	sb.Grow(lineCount * (len(wantLine) + 1))
	for i := 0; i < lineCount; i++ {
		sb.WriteString(wantLine)
		sb.WriteByte('\n')
	}
	largeContent := sb.String()

	// Verify it's large enough to trigger the scanner path.
	if len(largeContent) <= 1024*1024 {
		t.Fatalf("test content too small: %d bytes, need >1MB", len(largeContent))
	}

	lines := splitLines(largeContent)

	// lineCount content lines + 1 trailing empty element. This must be fatal:
	// the index expressions below are only in range when the count is right.
	expectedLines := lineCount + 1
	if len(lines) != expectedLines {
		t.Fatalf("splitLines returned %d lines, expected %d", len(lines), expectedLines)
	}

	// Check the first and last content lines and the trailing empty element.
	if lines[0] != wantLine {
		t.Errorf("first line mismatch: got %q", preview(lines[0]))
	}
	if lines[lineCount-1] != wantLine {
		t.Errorf("last content line mismatch: got %q", preview(lines[lineCount-1]))
	}
	if lines[lineCount] != "" {
		t.Errorf("expected empty trailing line, got %q", lines[lineCount])
	}
}
// TestSplitLinesLargeFileNoTrailingNewline checks the >1MB path of splitLines
// when the input does not end with '\n': the result must contain exactly one
// element per line and no extra trailing element.
func TestSplitLinesLargeFileNoTrailingNewline(t *testing.T) {
	const lineCount = 20000
	row := strings.Repeat("y", 60)

	// lineCount rows separated by (not terminated with) newlines.
	largeContent := strings.Repeat(row+"\n", lineCount-1) + row

	if size := len(largeContent); size <= 1024*1024 {
		t.Fatalf("test content too small: %d bytes", size)
	}

	if got := len(splitLines(largeContent)); got != lineCount {
		t.Errorf("splitLines returned %d lines, expected %d", got, lineCount)
	}
}
// TestSplitLinesLongLine checks that splitLines copes with a very long line
// inside a >1MB input (the scanner buffer limit was raised to 1MB per line in
// the C-05 fix) and that the long line is returned intact.
func TestSplitLinesLongLine(t *testing.T) {
	const longLineLen = 500 * 1024 // 500KB

	// 60000 repetitions of a 19-byte line is about 1.14MB, enough on its own to
	// cross the 1MB threshold; the 500KB line is appended after them.
	padding := strings.Repeat("short line content\n", 60000)
	giant := strings.Repeat("L", longLineLen)
	largeContent := padding + giant + "\n"

	if len(largeContent) <= 1024*1024 {
		t.Fatalf("test content too small: %d bytes", len(largeContent))
	}

	lines := splitLines(largeContent)

	// Should not crash and should return some lines.
	if len(lines) == 0 {
		t.Fatal("splitLines returned no lines for valid content")
	}

	// Scan until a line at least as long as the giant one turns up.
	sawGiant := false
	for i := 0; i < len(lines) && !sawGiant; i++ {
		sawGiant = len(lines[i]) >= longLineLen
	}
	if !sawGiant {
		t.Error("the 500KB long line was not found in splitLines output")
	}
}
+33 -25
View File
@@ -5,7 +5,6 @@ import (
"fmt"
"regexp"
"sync"
"sync/atomic"
)
const (
@@ -19,10 +18,11 @@ const (
)
// regexCache is a global thread-safe cache for compiled regular expressions.
// Caching regex compilation provides 10-50x speedup for repeated patterns.
// Uses sync.RWMutex with a regular map so that ClearRegexCache can atomically
// clear the map and reset the count in a single lock acquisition.
var (
regexCache sync.Map // string -> *regexp.Regexp
cacheSize atomic.Int64
cacheMu sync.RWMutex
regexCache = make(map[string]*regexp.Regexp)
)
// RegexError represents an error during regex compilation or validation.
@@ -62,7 +62,7 @@ func ValidatePattern(pattern string) error {
}
// CompileRegex compiles a regex pattern with caching and validation for security.
// Thread-safe: uses LoadOrStore to prevent race conditions.
// Thread-safe: uses RWMutex to prevent race conditions.
// Returns the compiled regex or an error if the pattern is invalid or unsafe.
func CompileRegex(pattern string) (*regexp.Regexp, error) {
// Validate pattern first
@@ -70,12 +70,15 @@ func CompileRegex(pattern string) (*regexp.Regexp, error) {
return nil, err
}
// Check cache first
if cached, ok := regexCache.Load(pattern); ok {
return cached.(*regexp.Regexp), nil
// Check cache first (read lock)
cacheMu.RLock()
if cached, ok := regexCache[pattern]; ok {
cacheMu.RUnlock()
return cached, nil
}
cacheMu.RUnlock()
// Compile regex
// Compile regex outside the lock to avoid holding it during compilation
re, err := regexp.Compile(pattern)
if err != nil {
return nil, &RegexError{
@@ -85,18 +88,22 @@ func CompileRegex(pattern string) (*regexp.Regexp, error) {
}
}
// Check cache size and clear if too large
if cacheSize.Load() >= MaxCacheSize {
ClearRegexCache()
// Write lock to store in cache
cacheMu.Lock()
// Re-check in case another goroutine stored it while we were compiling
if cached, ok := regexCache[pattern]; ok {
cacheMu.Unlock()
return cached, nil
}
// Try to store - if another goroutine already stored it, use theirs
// This prevents race conditions where multiple goroutines compile the same pattern
actual, loaded := regexCache.LoadOrStore(pattern, re)
if !loaded {
cacheSize.Add(1)
// Check cache size and clear if too large
if len(regexCache) >= MaxCacheSize {
regexCache = make(map[string]*regexp.Regexp)
}
return actual.(*regexp.Regexp), nil
regexCache[pattern] = re
cacheMu.Unlock()
return re, nil
}
// CompileRegexUncached compiles a regex pattern without caching.
@@ -118,18 +125,19 @@ func CompileRegexUncached(pattern string) (*regexp.Regexp, error) {
}
// ClearRegexCache clears all cached compiled regular expressions.
// Useful for testing or when memory usage needs to be reduced.
// Atomically replaces the map under a single write lock.
func ClearRegexCache() {
regexCache.Range(func(key, _ interface{}) bool {
regexCache.Delete(key)
return true
})
cacheSize.Store(0)
cacheMu.Lock()
regexCache = make(map[string]*regexp.Regexp)
cacheMu.Unlock()
}
// CacheStats returns the current number of cached patterns.
func CacheStats() int64 {
return cacheSize.Load()
cacheMu.RLock()
n := int64(len(regexCache))
cacheMu.RUnlock()
return n
}
// truncatePattern truncates a pattern for display in error messages.