mirror of
https://github.com/lukaszraczylo/filepuff-mcp.git
synced 2026-06-09 22:53:44 +00:00
fixup! Update, bugfixes on diff and edit handling
This commit is contained in:
@@ -69,13 +69,26 @@ func Load(workspaceRoot string) (*Config, error) {
|
||||
cfg.WorkspaceRoot = cwd
|
||||
}
|
||||
|
||||
// Try to load from config file in workspace root
|
||||
// Try to load from config file in workspace root.
|
||||
// Save WorkspaceRoot before loading config file so it cannot be overridden.
|
||||
savedRoot := cfg.WorkspaceRoot
|
||||
configPath := filepath.Join(cfg.WorkspaceRoot, ".mcp-filepuff.json")
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := json.Unmarshal(data, cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Restore WorkspaceRoot — config file must not override path guards.
|
||||
cfg.WorkspaceRoot = savedRoot
|
||||
|
||||
// Clamp size limits to prevent config file from requesting excessive memory.
|
||||
const maxAllowedSize int64 = 100 * 1024 * 1024 // 100 MB
|
||||
if cfg.MaxFileSize > maxAllowedSize {
|
||||
cfg.MaxFileSize = maxAllowedSize
|
||||
}
|
||||
if cfg.MaxParseSize > maxAllowedSize {
|
||||
cfg.MaxParseSize = maxAllowedSize
|
||||
}
|
||||
|
||||
// Override from environment variables
|
||||
cfg.loadFromEnv()
|
||||
@@ -173,8 +186,8 @@ func (c *Config) IsPathAllowed(path string) bool {
|
||||
|
||||
// Check if the path is within workspace (doesn't start with ..)
|
||||
// This prevents both "../" attacks and symlink bypasses
|
||||
// Also reject empty relative path (which means it's the workspace root itself)
|
||||
return rel != "." && !strings.HasPrefix(rel, "..")
|
||||
// The workspace root itself (rel == ".") is a valid, allowed path
|
||||
return !strings.HasPrefix(rel, "..")
|
||||
}
|
||||
|
||||
// Validate validates the configuration and returns an error if invalid.
|
||||
|
||||
@@ -389,8 +389,8 @@ func TestIsPathAllowedEdgeCases(t *testing.T) {
|
||||
{
|
||||
name: "workspace_root_itself",
|
||||
path: tmpDir,
|
||||
allowed: false,
|
||||
desc: "workspace root itself should not be allowed",
|
||||
allowed: true,
|
||||
desc: "workspace root itself should be allowed",
|
||||
},
|
||||
{
|
||||
name: "dot_relative",
|
||||
@@ -546,6 +546,59 @@ func TestConfigFileLoadingErrors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPathAllowed_SymlinkOutsideWorkspace verifies that symlinks pointing
|
||||
// outside the workspace are rejected (T-01).
|
||||
func TestIsPathAllowed_SymlinkOutsideWorkspace(t *testing.T) {
|
||||
// Create two separate temp dirs: one as workspace, one as outside target
|
||||
workspace, err := os.MkdirTemp("", "mcp-workspace-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create workspace dir: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.RemoveAll(workspace) })
|
||||
|
||||
outside, err := os.MkdirTemp("", "mcp-outside-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create outside dir: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.RemoveAll(outside) })
|
||||
|
||||
// Create a file outside the workspace
|
||||
outsideFile := filepath.Join(outside, "secret.txt")
|
||||
if err := os.WriteFile(outsideFile, []byte("secret"), 0o600); err != nil {
|
||||
t.Fatalf("failed to write outside file: %v", err)
|
||||
}
|
||||
|
||||
// Create a symlink inside the workspace pointing outside
|
||||
symlinkPath := filepath.Join(workspace, "escape-link")
|
||||
if err := os.Symlink(outsideFile, symlinkPath); err != nil {
|
||||
t.Skip("symlink creation not supported on this system")
|
||||
}
|
||||
|
||||
cfg := Default()
|
||||
cfg.WorkspaceRoot = workspace
|
||||
|
||||
// The symlink resolves to a file outside workspace — must be rejected
|
||||
if cfg.IsPathAllowed(symlinkPath) {
|
||||
t.Error("symlink pointing outside workspace should NOT be allowed")
|
||||
}
|
||||
|
||||
// Direct access to the outside file should also be rejected
|
||||
if cfg.IsPathAllowed(outsideFile) {
|
||||
t.Error("file outside workspace should NOT be allowed")
|
||||
}
|
||||
|
||||
// File inside workspace should still be allowed
|
||||
insideFile := filepath.Join(workspace, "safe.txt")
|
||||
if !cfg.IsPathAllowed(insideFile) {
|
||||
t.Error("file inside workspace should be allowed")
|
||||
}
|
||||
|
||||
// Workspace root itself should be allowed (C-08 fix)
|
||||
if !cfg.IsPathAllowed(workspace) {
|
||||
t.Error("workspace root itself should be allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring.
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||
|
||||
@@ -126,7 +126,7 @@ func TestIsPathAllowed_BasicCases(t *testing.T) {
|
||||
{
|
||||
name: "workspace root itself",
|
||||
path: tmpDir,
|
||||
expected: false, // Empty relative path
|
||||
expected: true, // Workspace root is a valid, allowed path (needed for ast_query)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+19
-11
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -60,6 +61,7 @@ type EditResult struct {
|
||||
// Engine performs AST-aware edits.
|
||||
type Engine struct {
|
||||
registry *parser.Registry
|
||||
dmp *diffmatchpatch.DiffMatchPatch
|
||||
fileLocks sync.Map // map[string]*sync.Mutex for per-file locking
|
||||
}
|
||||
|
||||
@@ -67,6 +69,7 @@ type Engine struct {
|
||||
func NewEngine(registry *parser.Registry) *Engine {
|
||||
return &Engine{
|
||||
registry: registry,
|
||||
dmp: diffmatchpatch.New(),
|
||||
fileLocks: sync.Map{},
|
||||
}
|
||||
}
|
||||
@@ -166,7 +169,7 @@ func (e *Engine) performASTEdit(ctx context.Context, edit *ASTEdit, apply bool)
|
||||
}
|
||||
|
||||
// Generate diff
|
||||
diff := generateDiff(string(content), string(newContent), edit.File)
|
||||
diff := e.generateDiff(string(content), string(newContent), edit.File)
|
||||
|
||||
result := &EditResult{
|
||||
Success: true,
|
||||
@@ -225,7 +228,7 @@ func (e *Engine) performTextEdit(_ context.Context, edit *ASTEdit, apply bool) (
|
||||
}
|
||||
|
||||
// Generate diff
|
||||
diff := generateDiff(string(content), string(newContent), edit.File)
|
||||
diff := e.generateDiff(string(content), string(newContent), edit.File)
|
||||
|
||||
result := &EditResult{
|
||||
Success: true,
|
||||
@@ -369,17 +372,18 @@ func sortBySpecificity(nodes []*sitter.Node) []*sitter.Node {
|
||||
return nodes
|
||||
}
|
||||
|
||||
// Sort by specificity: named nodes first, then by size (smallest first)
|
||||
result := make([]*sitter.Node, len(nodes))
|
||||
copy(result, nodes)
|
||||
|
||||
for i := 0; i < len(result)-1; i++ {
|
||||
for j := i + 1; j < len(result); j++ {
|
||||
if shouldPrefer(result[j], result[i]) {
|
||||
result[i], result[j] = result[j], result[i]
|
||||
}
|
||||
slices.SortFunc(result, func(a, b *sitter.Node) int {
|
||||
if shouldPrefer(a, b) {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
if shouldPrefer(b, a) {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -566,8 +570,8 @@ func indentContent(content string, indent string) string {
|
||||
|
||||
// generateDiff creates a unified diff between original and modified content.
|
||||
// Uses line-level Myers diff algorithm for accurate and readable diffs.
|
||||
func generateDiff(original, modified, filename string) string {
|
||||
dmp := diffmatchpatch.New()
|
||||
func (e *Engine) generateDiff(original, modified, filename string) string {
|
||||
dmp := e.dmp
|
||||
|
||||
// Use line-level diffing: encode each line as a single character,
|
||||
// diff the encoded strings, then decode back to real lines.
|
||||
@@ -692,6 +696,10 @@ func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, e
|
||||
}
|
||||
|
||||
lines := bytes.Split(content, []byte("\n"))
|
||||
// Trim phantom empty element from trailing newline
|
||||
if len(lines) > 0 && len(lines[len(lines)-1]) == 0 {
|
||||
lines = lines[:len(lines)-1]
|
||||
}
|
||||
totalLines := len(lines)
|
||||
|
||||
// Convert to 0-indexed
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
@@ -394,7 +395,7 @@ func TestGenerateDiff(t *testing.T) {
|
||||
modified := "line1\nmodified\nline3"
|
||||
filename := "test.txt"
|
||||
|
||||
diff := generateDiff(original, modified, filename)
|
||||
diff := (&Engine{dmp: diffmatchpatch.New()}).generateDiff(original, modified, filename)
|
||||
|
||||
if !strings.Contains(diff, "---") {
|
||||
t.Error("diff should contain --- header")
|
||||
@@ -420,7 +421,7 @@ func TestGenerateDiffLineLevelAccuracy(t *testing.T) {
|
||||
original := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello\")\n}\n"
|
||||
modified := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello world\")\n}\n"
|
||||
|
||||
diff := generateDiff(original, modified, "test.go")
|
||||
diff := (&Engine{dmp: diffmatchpatch.New()}).generateDiff(original, modified, "test.go")
|
||||
|
||||
// The diff must show whole-line removals and additions
|
||||
if !strings.Contains(diff, "-\tfmt.Println(\"hello\")\n") {
|
||||
@@ -453,7 +454,7 @@ func TestGenerateDiffNoPhantomChanges(t *testing.T) {
|
||||
original := "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\n"
|
||||
modified := "line1\nREPLACED\nline3\nline4\nline5\nline6\nline7\nline8\n"
|
||||
|
||||
diff := generateDiff(original, modified, "test.txt")
|
||||
diff := (&Engine{dmp: diffmatchpatch.New()}).generateDiff(original, modified, "test.txt")
|
||||
|
||||
// Count changed lines (excluding headers)
|
||||
addCount := 0
|
||||
|
||||
@@ -268,7 +268,11 @@ func (c *Client) send(msg interface{}) error {
|
||||
}
|
||||
|
||||
// readLoop reads and dispatches messages from the server.
|
||||
// On exit (for any reason), it drains all pending Call waiters with a
|
||||
// synthetic error so that goroutines blocked in Call are unblocked.
|
||||
func (c *Client) readLoop() {
|
||||
defer c.drainPending()
|
||||
|
||||
reader := bufio.NewReader(c.stdout)
|
||||
|
||||
for {
|
||||
@@ -329,6 +333,25 @@ func (c *Client) readLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// drainPending sends a synthetic error response to every pending Call waiter
|
||||
// so that goroutines blocked in Call are unblocked when readLoop exits.
|
||||
func (c *Client) drainPending() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for id, ch := range c.pending {
|
||||
ch <- &Response{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Error: &ResponseError{
|
||||
Code: -32603, // InternalError
|
||||
Message: "LSP client readLoop terminated",
|
||||
},
|
||||
}
|
||||
delete(c.pending, id)
|
||||
}
|
||||
}
|
||||
|
||||
// IsRunning returns whether the client is running.
|
||||
func (c *Client) IsRunning() bool {
|
||||
c.runningMu.RLock()
|
||||
|
||||
+43
-26
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -28,6 +29,9 @@ const (
|
||||
)
|
||||
|
||||
// Manager manages LSP servers for different languages.
|
||||
//
|
||||
// Lock ordering: m.mu must always be acquired before srv.mu.
|
||||
// Never acquire m.mu while holding srv.mu.
|
||||
type Manager struct {
|
||||
servers map[protocol.Language]*ManagedServer
|
||||
logger *slog.Logger
|
||||
@@ -36,6 +40,7 @@ type Manager struct {
|
||||
timeout time.Duration
|
||||
idleTimeout time.Duration
|
||||
mu sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
stopped bool
|
||||
}
|
||||
|
||||
@@ -163,9 +168,10 @@ func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*Manag
|
||||
WithRemediation("Only whitelisted LSP server binaries are allowed for security reasons")
|
||||
}
|
||||
|
||||
// Create command
|
||||
// Create command — use exec.Command (not CommandContext) so the LSP
|
||||
// subprocess is not killed when the request-scoped context expires.
|
||||
args := append(config.Command[1:], config.Args...)
|
||||
cmd := exec.CommandContext(ctx, cmdPath, args...)
|
||||
cmd := exec.Command(cmdPath, args...)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Dir = m.workspaceRoot
|
||||
// Create client
|
||||
@@ -412,9 +418,13 @@ func (m *Manager) References(ctx context.Context, file string, line, col int, in
|
||||
}
|
||||
|
||||
// ensureDocumentOpen opens a document if not already open.
|
||||
func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, file string) error {
|
||||
// It reads the file content outside the lock (to avoid holding the lock during I/O),
|
||||
// then holds srv.mu for the entire check-and-send sequence to prevent duplicate didOpen
|
||||
// notifications from concurrent goroutines.
|
||||
func (m *Manager) ensureDocumentOpen(_ context.Context, srv *ManagedServer, file string) error {
|
||||
uri := fileToURI(file)
|
||||
|
||||
// Quick check under lock — common fast path.
|
||||
srv.mu.Lock()
|
||||
if _, ok := srv.openDocs[uri]; ok {
|
||||
srv.mu.Unlock()
|
||||
@@ -422,15 +432,23 @@ func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, fi
|
||||
}
|
||||
srv.mu.Unlock()
|
||||
|
||||
// Read file content
|
||||
// Read file content outside the lock to avoid holding it during I/O.
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
// Get language ID
|
||||
langID := languageToLSPID(srv.language)
|
||||
|
||||
// Re-acquire lock and re-check to prevent TOCTOU race: two goroutines could
|
||||
// both pass the fast-path check above and both try to send didOpen.
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
if _, ok := srv.openDocs[uri]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := DidOpenTextDocumentParams{
|
||||
TextDocument: TextDocumentItem{
|
||||
URI: uri,
|
||||
@@ -444,10 +462,7 @@ func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, fi
|
||||
return fmt.Errorf("didOpen failed: %w", err)
|
||||
}
|
||||
|
||||
srv.mu.Lock()
|
||||
srv.openDocs[uri] = 1
|
||||
srv.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -519,26 +534,28 @@ func (m *Manager) reapIdleServers() {
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down all LSP servers.
|
||||
// Close shuts down all LSP servers. It is safe to call multiple times.
|
||||
func (m *Manager) Close() error {
|
||||
close(m.stopReaper)
|
||||
m.closeOnce.Do(func() {
|
||||
close(m.stopReaper)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.stopped = true
|
||||
m.stopped = true
|
||||
|
||||
for lang, srv := range m.servers {
|
||||
m.logger.Info("shutting down LSP server", "language", lang)
|
||||
// Try graceful shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout)
|
||||
_, _ = srv.client.Call(ctx, "shutdown", nil)
|
||||
cancel()
|
||||
_ = srv.client.Notify("exit", nil)
|
||||
_ = srv.client.Close()
|
||||
}
|
||||
for lang, srv := range m.servers {
|
||||
m.logger.Info("shutting down LSP server", "language", lang)
|
||||
// Try graceful shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout)
|
||||
_, _ = srv.client.Call(ctx, "shutdown", nil)
|
||||
cancel()
|
||||
_ = srv.client.Notify("exit", nil)
|
||||
_ = srv.client.Close()
|
||||
}
|
||||
|
||||
m.servers = make(map[protocol.Language]*ManagedServer)
|
||||
m.servers = make(map[protocol.Language]*ManagedServer)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -553,13 +570,13 @@ func (m *Manager) IsAvailable(lang protocol.Language) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// fileToURI converts a file path to a properly percent-encoded file URI.
//
// The path is made absolute when possible; if that fails, the path is used as
// given. Separators are normalized to forward slashes and a leading slash is
// enforced so the first path segment (for example a Windows drive letter such
// as "C:") can never be rendered in the URI's authority position. For ordinary
// absolute Unix paths the result is unchanged: "file://" followed by the
// escaped path.
func fileToURI(file string) string {
	absPath, err := filepath.Abs(file)
	if err != nil {
		// Best effort: fall back to the caller's path rather than failing.
		absPath = file
	}

	uriPath := filepath.ToSlash(absPath)
	if uriPath != "" && uriPath[0] != '/' {
		uriPath = "/" + uriPath
	}

	return (&url.URL{Scheme: "file", Path: uriPath}).String()
}
|
||||
|
||||
// URIToFile converts a file URI to a file path.
|
||||
|
||||
@@ -189,10 +189,30 @@ func TestManagerGracefulShutdown(t *testing.T) {
|
||||
if !manager.stopped {
|
||||
t.Error("manager should be marked as stopped after Close()")
|
||||
}
|
||||
}
|
||||
|
||||
// Note: We don't test multiple Close() calls because the implementation
|
||||
// closes the stopReaper channel which can't be closed twice.
|
||||
// In production, Close() should only be called once during shutdown.
|
||||
// TestManagerDoubleClose verifies that calling Close() twice does not panic (T-05, C-02).
|
||||
func TestManagerDoubleClose(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
manager := NewManager(tmpDir, logger)
|
||||
|
||||
// First close should succeed
|
||||
err := manager.Close()
|
||||
if err != nil {
|
||||
t.Errorf("first Close() returned error: %v", err)
|
||||
}
|
||||
|
||||
// Second close must not panic (C-02 fix wraps close in sync.Once)
|
||||
err = manager.Close()
|
||||
if err != nil {
|
||||
t.Errorf("second Close() returned error: %v", err)
|
||||
}
|
||||
|
||||
if !manager.stopped {
|
||||
t.Error("manager should be marked as stopped after double Close()")
|
||||
}
|
||||
}
|
||||
|
||||
// TestManagerIdleReaper tests the idle server cleanup mechanism.
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
)
|
||||
|
||||
// TestLRUCacheEviction tests that the LRU cache properly evicts old entries.
|
||||
@@ -82,8 +84,8 @@ func TestContentHashCollisionResistance(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hash1 := contentHash(tc.content1)
|
||||
hash2 := contentHash(tc.content2)
|
||||
hash1 := fmt.Sprintf("%016x", xxhash.Sum64(tc.content1))
|
||||
hash2 := fmt.Sprintf("%016x", xxhash.Sum64(tc.content2))
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Errorf("Hash collision: %s == %s for different content", hash1, hash2)
|
||||
@@ -96,9 +98,9 @@ func TestContentHashCollisionResistance(t *testing.T) {
|
||||
func TestContentHashConsistency(t *testing.T) {
|
||||
content := []byte("package main\n\nfunc test() {}\n")
|
||||
|
||||
hash1 := contentHash(content)
|
||||
hash2 := contentHash(content)
|
||||
hash3 := contentHash(content)
|
||||
hash1 := fmt.Sprintf("%016x", xxhash.Sum64(content))
|
||||
hash2 := fmt.Sprintf("%016x", xxhash.Sum64(content))
|
||||
hash3 := fmt.Sprintf("%016x", xxhash.Sum64(content))
|
||||
|
||||
if hash1 != hash2 || hash2 != hash3 {
|
||||
t.Errorf("Hash inconsistency: %s, %s, %s", hash1, hash2, hash3)
|
||||
@@ -115,7 +117,7 @@ func BenchmarkContentHash_xxHash(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = contentHash(content)
|
||||
_ = fmt.Sprintf("%016x", xxhash.Sum64(content))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,9 +24,8 @@ import (
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
// MaxFileSize is the default maximum file size we'll parse (10MB).
|
||||
// Deprecated: Use Registry.maxParseSize instead.
|
||||
const MaxFileSize = 10 * 1024 * 1024
|
||||
// maxFileSize is the default maximum file size we'll parse (10MB).
|
||||
const maxFileSize = 10 * 1024 * 1024
|
||||
|
||||
// Registry manages Tree-sitter parsers for different languages.
|
||||
type Registry struct {
|
||||
@@ -69,18 +68,6 @@ type SyntaxError struct {
|
||||
Location protocol.Location
|
||||
}
|
||||
|
||||
// CacheStatsResult contains cache statistics.
|
||||
type CacheStatsResult struct {
|
||||
Hits int64 `json:"hits"`
|
||||
Misses int64 `json:"misses"`
|
||||
HitRate float64 `json:"hit_rate"`
|
||||
Size int `json:"size"`
|
||||
TotalParseTime int64 `json:"total_parse_time_ns"`
|
||||
ParseCount int64 `json:"parse_count"`
|
||||
AvgParseTime int64 `json:"avg_parse_time_ns"`
|
||||
LastParseTime int64 `json:"last_parse_time_ns"`
|
||||
}
|
||||
|
||||
// NewRegistry creates a new parser registry with the default max parse size.
|
||||
// For custom max parse size, use NewRegistryWithSize.
|
||||
func NewRegistry() *Registry {
|
||||
@@ -98,7 +85,7 @@ func NewRegistryWithSize(maxParseSize int64) *Registry {
|
||||
}
|
||||
|
||||
if maxParseSize <= 0 {
|
||||
maxParseSize = MaxFileSize
|
||||
maxParseSize = maxFileSize
|
||||
}
|
||||
|
||||
return &Registry{
|
||||
@@ -266,50 +253,6 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CacheStats returns cache hit/miss statistics.
|
||||
func (r *Registry) CacheStats() (hits, misses int64) {
|
||||
return r.cacheHits.Load(), r.cacheMisses.Load()
|
||||
}
|
||||
|
||||
// CacheStatsDetailed returns detailed cache and parse statistics.
|
||||
func (r *Registry) CacheStatsDetailed() CacheStatsResult {
|
||||
hits := r.cacheHits.Load()
|
||||
misses := r.cacheMisses.Load()
|
||||
totalParseTime := r.totalParseTime.Load()
|
||||
parseCount := r.parseCount.Load()
|
||||
|
||||
var hitRate float64
|
||||
total := hits + misses
|
||||
if total > 0 {
|
||||
hitRate = float64(hits) / float64(total)
|
||||
}
|
||||
|
||||
var avgParseTime int64
|
||||
if parseCount > 0 {
|
||||
avgParseTime = totalParseTime / parseCount
|
||||
}
|
||||
|
||||
return CacheStatsResult{
|
||||
Hits: hits,
|
||||
Misses: misses,
|
||||
HitRate: hitRate,
|
||||
Size: r.cache.Len(),
|
||||
TotalParseTime: totalParseTime,
|
||||
ParseCount: parseCount,
|
||||
AvgParseTime: avgParseTime,
|
||||
LastParseTime: r.lastParseDuration.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// ResetStats resets all cache and parse statistics.
|
||||
func (r *Registry) ResetStats() {
|
||||
r.cacheHits.Store(0)
|
||||
r.cacheMisses.Store(0)
|
||||
r.totalParseTime.Store(0)
|
||||
r.parseCount.Store(0)
|
||||
r.lastParseDuration.Store(0)
|
||||
}
|
||||
|
||||
// extractErrors finds all error nodes in the tree.
|
||||
func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
|
||||
var errors []SyntaxError
|
||||
@@ -346,13 +289,6 @@ func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
|
||||
return errors
|
||||
}
|
||||
|
||||
// contentHash returns a fast hash of the content for caching.
|
||||
// Uses xxHash which is 5-10x faster than SHA256 for non-cryptographic purposes.
|
||||
func contentHash(content []byte) string {
|
||||
h := xxhash.Sum64(content)
|
||||
return fmt.Sprintf("%016x", h)
|
||||
}
|
||||
|
||||
// isBinary checks if content appears to be binary.
|
||||
func isBinary(content []byte) bool {
|
||||
// Check first 8000 bytes for null bytes
|
||||
|
||||
@@ -2,8 +2,11 @@ package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
)
|
||||
|
||||
// BenchmarkParse benchmarks parsing files of various sizes.
|
||||
@@ -194,7 +197,7 @@ func BenchmarkContentHash(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = contentHash(content)
|
||||
_ = fmt.Sprintf("%016x", xxhash.Sum64(content))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -31,8 +31,8 @@ type JSONNode struct {
|
||||
// ParseYAML parses YAML content and returns a tree-sitter-compatible result
|
||||
func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||
// Check file size
|
||||
if len(content) > MaxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||
if len(content) > maxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), maxFileSize)
|
||||
}
|
||||
|
||||
// Parse YAML
|
||||
@@ -57,8 +57,8 @@ func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byt
|
||||
// ParseJSON parses JSON content and returns a tree-sitter-compatible result
|
||||
func (r *Registry) ParseJSON(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||
// Check file size
|
||||
if len(content) > MaxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||
if len(content) > maxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), maxFileSize)
|
||||
}
|
||||
|
||||
// Parse JSON to validate syntax
|
||||
|
||||
@@ -22,13 +22,11 @@ type ASTQuery struct {
|
||||
|
||||
// QueryFilters provide additional filtering criteria.
|
||||
type QueryFilters struct {
|
||||
HasChild *ASTQuery `json:"has_child,omitempty"`
|
||||
HasParent *ASTQuery `json:"has_parent,omitempty"`
|
||||
NameMatches string `json:"name_matches,omitempty"`
|
||||
NameExact string `json:"name_exact,omitempty"`
|
||||
InFile string `json:"in_file,omitempty"`
|
||||
NotInFile string `json:"not_in_file,omitempty"`
|
||||
KindIn []string `json:"kind_in,omitempty"`
|
||||
NameMatches string `json:"name_matches,omitempty"`
|
||||
NameExact string `json:"name_exact,omitempty"`
|
||||
InFile string `json:"in_file,omitempty"`
|
||||
NotInFile string `json:"not_in_file,omitempty"`
|
||||
KindIn []string `json:"kind_in,omitempty"`
|
||||
}
|
||||
|
||||
// MatchResult represents a single match from a query.
|
||||
@@ -259,7 +257,7 @@ func matchPatternHeuristic(node *sitter.Node, pattern *ParsedPattern, content []
|
||||
}
|
||||
|
||||
// Match struct patterns (Go, C, C++)
|
||||
if strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ") && strings.Contains(patternLower, "struct") {
|
||||
if (strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ")) && strings.Contains(patternLower, "struct") {
|
||||
if nodeType != "type_declaration" && nodeType != "struct_specifier" {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -499,6 +499,155 @@ func TestFormatResults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatchStructOperatorPrecedence pins the behavior of the struct-pattern
// branch after the C-07 parenthesization change. The old condition parsed as
// A || (B && C) and the new one is (A || B) && C, where A is "contains
// \"struct \"", B is "contains \"type \"" and C is "contains \"struct\"".
// Because A implies C, both forms accept exactly the same patterns, so this
// test guards against regressions rather than demonstrating a changed result.
|
||||
func TestMatchStructOperatorPrecedence(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `package main
|
||||
|
||||
type Server struct {
|
||||
Port int
|
||||
}
|
||||
|
||||
func main() {}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
wantMatches int
|
||||
}{
|
||||
{
|
||||
name: "type struct pattern should match",
|
||||
pattern: "type $NAME struct { $$$FIELDS }",
|
||||
wantMatches: 1, // Server
|
||||
},
|
||||
{
|
||||
name: "struct keyword alone should match",
|
||||
pattern: "struct $NAME { $$$FIELDS }",
|
||||
wantMatches: 1, // Server
|
||||
},
|
||||
{
|
||||
name: "func pattern should not match struct branch",
|
||||
pattern: "func $NAME() {}",
|
||||
wantMatches: 1, // main (matches function branch, not struct)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
query := &ASTQuery{
|
||||
Pattern: tt.pattern,
|
||||
Language: "go",
|
||||
}
|
||||
results, err := matcher.Match(ctx, query, result.Tree, []byte(content), "test.go")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
if len(results) != tt.wantMatches {
|
||||
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||
for i, r := range results {
|
||||
t.Logf("match %d: type=%s, text=%q", i, r.Node.Type(), truncateForLog(r.Text, 80))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// truncateForLog shortens s to at most limit bytes for log output and appends
// "..." when anything was cut off. Strings that already fit are returned
// unchanged.
//
// A negative limit is treated as zero instead of panicking on the slice
// expression. The cut point is moved back to the nearest UTF-8 rune boundary
// so a multi-byte character is never split, which would otherwise put invalid
// UTF-8 into the log.
func truncateForLog(s string, limit int) string {
	if limit < 0 {
		limit = 0
	}
	if len(s) <= limit {
		return s
	}

	// limit < len(s) here, so s[cut] is always in range. Bytes of the form
	// 10xxxxxx are UTF-8 continuation bytes; step back past them.
	cut := limit
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
|
||||
|
||||
// TestPassesFilters_AllBranches tests passesFilters for each filter type.
|
||||
func TestPassesFilters_AllBranches(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
content := `package main
|
||||
|
||||
func Alpha() {}
|
||||
func Beta() {}
|
||||
func Gamma() {}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filters QueryFilters
|
||||
wantMatches int
|
||||
}{
|
||||
{
|
||||
name: "no filters matches all",
|
||||
filters: QueryFilters{},
|
||||
wantMatches: 3,
|
||||
},
|
||||
{
|
||||
name: "name_exact filter",
|
||||
filters: QueryFilters{NameExact: "Alpha"},
|
||||
wantMatches: 1,
|
||||
},
|
||||
{
|
||||
name: "name_matches regex filter",
|
||||
filters: QueryFilters{NameMatches: "^[AB]"},
|
||||
wantMatches: 2,
|
||||
},
|
||||
{
|
||||
name: "kind_in filter",
|
||||
filters: QueryFilters{KindIn: []string{"function_declaration"}},
|
||||
wantMatches: 3,
|
||||
},
|
||||
{
|
||||
name: "kind_in filter excludes non-matching kinds",
|
||||
filters: QueryFilters{KindIn: []string{"class_declaration"}},
|
||||
wantMatches: 0,
|
||||
},
|
||||
{
|
||||
name: "combined name_exact and kind_in",
|
||||
filters: QueryFilters{NameExact: "Beta", KindIn: []string{"function_declaration"}},
|
||||
wantMatches: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
query := &ASTQuery{
|
||||
Pattern: "func $NAME() {}",
|
||||
Language: "go",
|
||||
Filters: tt.filters,
|
||||
}
|
||||
results, err := matcher.Match(ctx, query, result.Tree, []byte(content), "test.go")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
if len(results) != tt.wantMatches {
|
||||
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryValidation(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
@@ -131,6 +131,11 @@ func (s *Searcher) Search(ctx context.Context, req *Request) (*SearchResults, er
|
||||
WithRemediation("Provide a non-empty search pattern")
|
||||
}
|
||||
|
||||
// Validate that at least one provided path is allowed
|
||||
if err := s.validatePaths(req.Paths); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build ripgrep command
|
||||
args := s.buildArgs(req)
|
||||
|
||||
@@ -238,6 +243,22 @@ func (s *Searcher) buildArgs(req *Request) []string {
|
||||
return args
|
||||
}
|
||||
|
||||
// validatePaths checks that at least one caller-provided path is allowed.
|
||||
// Returns an error if paths were provided but none passed IsPathAllowed.
|
||||
func (s *Searcher) validatePaths(paths []string) error {
|
||||
if len(paths) == 0 {
|
||||
return nil // no explicit paths — will default to workspace root
|
||||
}
|
||||
for _, p := range paths {
|
||||
if s.cfg.IsPathAllowed(p) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New(errors.ErrPathNotAllowed, "all provided search paths are outside the workspace root").
|
||||
WithContext("paths", fmt.Sprintf("%v", paths)).
|
||||
WithRemediation("Provide paths within the workspace root")
|
||||
}
|
||||
|
||||
// parseOutput parses ripgrep JSON output.
|
||||
func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchResults, error) {
|
||||
results := &SearchResults{
|
||||
|
||||
@@ -43,9 +43,10 @@ func TestBuildArgs(t *testing.T) {
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *Request
|
||||
expected []string
|
||||
name string
|
||||
req *Request
|
||||
expected []string
|
||||
notExpected []string // T-06: verify absence of unexpected flags
|
||||
}{
|
||||
{
|
||||
name: "basic search",
|
||||
@@ -54,7 +55,8 @@ func TestBuildArgs(t *testing.T) {
|
||||
ContextLines: 2,
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--context=2", "--", "test", "."},
|
||||
expected: []string{"--json", "--context=2", "--", "test", "."},
|
||||
notExpected: []string{"--ignore-case", "--fixed-strings", "--max-total-count=0"},
|
||||
},
|
||||
{
|
||||
name: "ignore case",
|
||||
@@ -63,7 +65,8 @@ func TestBuildArgs(t *testing.T) {
|
||||
IgnoreCase: true,
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--ignore-case", "--", "test", "."},
|
||||
expected: []string{"--json", "--ignore-case", "--", "test", "."},
|
||||
notExpected: []string{"--fixed-strings"},
|
||||
},
|
||||
{
|
||||
name: "fixed strings",
|
||||
@@ -71,7 +74,8 @@ func TestBuildArgs(t *testing.T) {
|
||||
Pattern: "test",
|
||||
Regex: false,
|
||||
},
|
||||
expected: []string{"--json", "--fixed-strings", "--", "test", "."},
|
||||
expected: []string{"--json", "--fixed-strings", "--", "test", "."},
|
||||
notExpected: []string{"--ignore-case"},
|
||||
},
|
||||
{
|
||||
name: "with file types",
|
||||
@@ -80,7 +84,8 @@ func TestBuildArgs(t *testing.T) {
|
||||
FileTypes: []string{"go", "ts"},
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--type", "go", "--type", "ts", "--", "test", "."},
|
||||
expected: []string{"--json", "--type", "go", "--type", "ts", "--", "test", "."},
|
||||
notExpected: []string{"--ignore-case", "--fixed-strings"},
|
||||
},
|
||||
{
|
||||
name: "with max results",
|
||||
@@ -89,7 +94,8 @@ func TestBuildArgs(t *testing.T) {
|
||||
MaxResults: 10,
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--max-total-count=10", "--", "test", "."},
|
||||
expected: []string{"--json", "--max-total-count=10", "--", "test", "."},
|
||||
notExpected: []string{"--ignore-case", "--fixed-strings"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -110,6 +116,16 @@ func TestBuildArgs(t *testing.T) {
|
||||
t.Errorf("expected arg %q not found in %v", exp, args)
|
||||
}
|
||||
}
|
||||
|
||||
// T-06: Check that unexpected args are absent
|
||||
for _, notExp := range tt.notExpected {
|
||||
for _, arg := range args {
|
||||
if arg == notExp {
|
||||
t.Errorf("unexpected arg %q found in %v", notExp, args)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,9 +54,9 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest
|
||||
}
|
||||
|
||||
// Find files to search based on language
|
||||
ext := languageToExtension(language)
|
||||
if ext == "" {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil
|
||||
exts := languageToExtensions(language)
|
||||
if len(exts) == 0 {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir)", language)), nil
|
||||
}
|
||||
|
||||
var allResults []query.MatchResult
|
||||
@@ -89,7 +89,14 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest
|
||||
}
|
||||
|
||||
// Check file extension matches language
|
||||
if !strings.HasSuffix(path, ext) {
|
||||
matched := false
|
||||
for _, ext := range exts {
|
||||
if strings.HasSuffix(path, ext) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -203,24 +210,28 @@ func symbolKindIcon(kind protocol.SymbolKind) string {
|
||||
}
|
||||
}
|
||||
|
||||
// languageToExtension maps language names to file extensions.
|
||||
func languageToExtension(language string) string {
|
||||
// languageToExtensions maps language names to file extensions.
|
||||
func languageToExtensions(language string) []string {
|
||||
switch strings.ToLower(language) {
|
||||
case "go":
|
||||
return ".go"
|
||||
return []string{".go"}
|
||||
case "typescript":
|
||||
return ".ts"
|
||||
return []string{".ts"}
|
||||
case "javascript":
|
||||
return ".js"
|
||||
return []string{".js"}
|
||||
case "python":
|
||||
return ".py"
|
||||
return []string{".py"}
|
||||
case "c":
|
||||
return ".c"
|
||||
return []string{".c"}
|
||||
case "cpp", "c++":
|
||||
return ".cpp"
|
||||
return []string{".cpp"}
|
||||
case "html":
|
||||
return []string{".html", ".htm"}
|
||||
case "vue":
|
||||
return []string{".vue"}
|
||||
case "elixir":
|
||||
return ".ex"
|
||||
return []string{".ex", ".exs"}
|
||||
default:
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,16 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, ap
|
||||
return mcp.NewToolResultError("operation is required"), nil
|
||||
}
|
||||
|
||||
// Validate operation against known values
|
||||
switch edit.EditOperation(operation) {
|
||||
case edit.EditReplace, edit.EditInsertBefore, edit.EditInsertAfter, edit.EditDelete:
|
||||
// valid
|
||||
default:
|
||||
return mcp.NewToolResultError(fmt.Sprintf(
|
||||
"invalid operation %q: must be one of: replace, insert_before, insert_after, delete", operation,
|
||||
)), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
|
||||
@@ -81,8 +81,8 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
|
||||
return mcp.NewToolResultError("path is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Read file
|
||||
content, err := os.ReadFile(path)
|
||||
// Check file size before reading to avoid loading huge files into memory
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil
|
||||
@@ -90,13 +90,21 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest
|
||||
if os.IsPermission(err) {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
|
||||
}
|
||||
s.logger.Warn("file read error", "path", path, "error", err)
|
||||
return mcp.NewToolResultError("error reading file"), nil
|
||||
s.logger.Warn("file stat error", "path", path, "error", err)
|
||||
return mcp.NewToolResultError("error accessing file"), nil
|
||||
}
|
||||
if info.Size() > s.cfg.MaxFileSize {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize)), nil
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil
|
||||
// Read file
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsPermission(err) {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
|
||||
}
|
||||
s.logger.Warn("file read error", "path", path, "error", err)
|
||||
return mcp.NewToolResultError("error reading file"), nil
|
||||
}
|
||||
|
||||
// Handle line range
|
||||
@@ -167,13 +175,18 @@ func splitLines(s string) []string {
|
||||
const largeSizeThreshold = 1024 * 1024 // 1MB
|
||||
|
||||
if len(s) > largeSizeThreshold {
|
||||
// Use scanner for large files
|
||||
// Use scanner for large files with increased buffer for long lines
|
||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), 1024*1024) // up to 1MB per line
|
||||
var lines []string
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
// Handle potential error and add empty line if string ended with newline
|
||||
if err := scanner.Err(); err != nil {
|
||||
// If scanning fails (e.g. line exceeds buffer), fall back to strings.Split
|
||||
return strings.Split(s, "\n")
|
||||
}
|
||||
// Add empty line if string ended with newline
|
||||
if len(s) > 0 && s[len(s)-1] == '\n' {
|
||||
lines = append(lines, "")
|
||||
}
|
||||
|
||||
@@ -117,7 +117,7 @@ func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolR
|
||||
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
|
||||
// Try to read a preview snippet
|
||||
preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3)
|
||||
preview := s.readFilePreview(filePath, loc.Range.Start.Line+1, 3)
|
||||
if preview != "" {
|
||||
output.WriteString("```\n")
|
||||
output.WriteString(preview)
|
||||
@@ -184,7 +184,13 @@ func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolR
|
||||
}
|
||||
|
||||
// readFilePreview reads a few lines from a file around the given line.
|
||||
func readFilePreview(file string, line, contextLines int) string {
|
||||
// It validates that the file path is within the allowed workspace before reading.
|
||||
func (s *Server) readFilePreview(file string, line, contextLines int) string {
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
s.logger.Warn("readFilePreview: path not allowed", "path", file)
|
||||
return ""
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return ""
|
||||
|
||||
@@ -169,7 +169,7 @@ func (s *Server) registerTools() {
|
||||
),
|
||||
mcp.WithString("language",
|
||||
mcp.Required(),
|
||||
mcp.Description("Target language: go, typescript, javascript, python, c, cpp"),
|
||||
mcp.Description("Target language: go, typescript, javascript, python, c, cpp, html, vue, elixir"),
|
||||
),
|
||||
mcp.WithArray("paths",
|
||||
mcp.Description("Paths to search in (defaults to workspace root)"),
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
@@ -381,3 +382,155 @@ func Hello() {
|
||||
t.Error("handleEdit(apply) should modify the file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleFileReadMaxFileSize verifies that handleFileRead rejects files
|
||||
// exceeding MaxFileSize via os.Stat before loading them into memory (T-03, S-01).
|
||||
func TestHandleFileReadMaxFileSize(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a test file
|
||||
testFile := filepath.Join(tmpDir, "big.txt")
|
||||
content := make([]byte, 1024) // 1KB file
|
||||
for i := range content {
|
||||
content[i] = 'A'
|
||||
}
|
||||
if err := os.WriteFile(testFile, content, 0600); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
// Set MaxFileSize smaller than the file
|
||||
cfg := &config.Config{
|
||||
WorkspaceRoot: tmpDir,
|
||||
EnableLSP: false,
|
||||
MaxFileSize: 512, // 512 bytes — smaller than our 1KB file
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"path": testFile,
|
||||
}
|
||||
|
||||
result, err := srv.handleFileRead(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("handleFileRead() returned Go error: %v", err)
|
||||
}
|
||||
|
||||
// The result should indicate an error (file too large)
|
||||
if result == nil {
|
||||
t.Fatal("handleFileRead() returned nil result")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("expected IsError=true for file exceeding MaxFileSize")
|
||||
}
|
||||
contents := result.Content
|
||||
if len(contents) == 0 {
|
||||
t.Fatal("expected non-empty content with error message")
|
||||
}
|
||||
textContent, ok := contents[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("expected text content")
|
||||
}
|
||||
if !strings.Contains(textContent.Text, "too large") {
|
||||
t.Errorf("expected 'too large' error message, got: %s", textContent.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSplitLinesLargeFile tests the splitLines function with a large file (>1MB)
|
||||
// to exercise the bufio.Scanner path including the scanner.Err() check (T-07, C-05).
|
||||
func TestSplitLinesLargeFile(t *testing.T) {
|
||||
// Build a string >1MB with known line count
|
||||
lineCount := 20000
|
||||
var sb strings.Builder
|
||||
for i := 0; i < lineCount; i++ {
|
||||
sb.WriteString(strings.Repeat("x", 60))
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
largeContent := sb.String()
|
||||
|
||||
// Verify it's large enough to trigger the scanner path
|
||||
if len(largeContent) <= 1024*1024 {
|
||||
t.Fatalf("test content too small: %d bytes, need >1MB", len(largeContent))
|
||||
}
|
||||
|
||||
lines := splitLines(largeContent)
|
||||
|
||||
// String ends with \n, so splitLines adds an empty trailing element
|
||||
// (matching the behavior of strings.Split for the small-file path)
|
||||
expectedLines := lineCount + 1 // lineCount lines + 1 trailing empty
|
||||
if len(lines) != expectedLines {
|
||||
t.Errorf("splitLines returned %d lines, expected %d", len(lines), expectedLines)
|
||||
}
|
||||
|
||||
// Check first and last actual lines
|
||||
if lines[0] != strings.Repeat("x", 60) {
|
||||
t.Errorf("first line mismatch: got %q", lines[0][:20])
|
||||
}
|
||||
if lines[lineCount-1] != strings.Repeat("x", 60) {
|
||||
t.Errorf("last content line mismatch: got %q", lines[lineCount-1][:20])
|
||||
}
|
||||
if lines[lineCount] != "" {
|
||||
t.Errorf("expected empty trailing line, got %q", lines[lineCount])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSplitLinesLargeFileNoTrailingNewline verifies splitLines for large files
|
||||
// without a trailing newline.
|
||||
func TestSplitLinesLargeFileNoTrailingNewline(t *testing.T) {
|
||||
lineCount := 20000
|
||||
var sb strings.Builder
|
||||
for i := 0; i < lineCount; i++ {
|
||||
if i > 0 {
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
sb.WriteString(strings.Repeat("y", 60))
|
||||
}
|
||||
largeContent := sb.String()
|
||||
|
||||
if len(largeContent) <= 1024*1024 {
|
||||
t.Fatalf("test content too small: %d bytes", len(largeContent))
|
||||
}
|
||||
|
||||
lines := splitLines(largeContent)
|
||||
if len(lines) != lineCount {
|
||||
t.Errorf("splitLines returned %d lines, expected %d", len(lines), lineCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSplitLinesLongLine verifies the scanner gracefully handles very long lines
|
||||
// (up to the 1MB buffer limit set in C-05 fix).
|
||||
func TestSplitLinesLongLine(t *testing.T) {
|
||||
// Create content with one very long line (500KB) embedded in a >1MB file
|
||||
shortLines := strings.Repeat("short line content\n", 60000) // ~60KB * ~1 = ~1.08MB
|
||||
longLine := strings.Repeat("L", 500*1024) // 500KB line
|
||||
largeContent := shortLines + longLine + "\n"
|
||||
|
||||
if len(largeContent) <= 1024*1024 {
|
||||
t.Fatalf("test content too small: %d bytes", len(largeContent))
|
||||
}
|
||||
|
||||
lines := splitLines(largeContent)
|
||||
|
||||
// Should not crash and should return some lines
|
||||
if len(lines) == 0 {
|
||||
t.Fatal("splitLines returned no lines for valid content")
|
||||
}
|
||||
|
||||
// The long line should be present somewhere in the output
|
||||
foundLong := false
|
||||
for _, line := range lines {
|
||||
if len(line) >= 500*1024 {
|
||||
foundLong = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundLong {
|
||||
t.Error("the 500KB long line was not found in splitLines output")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,10 +18,11 @@ const (
|
||||
)
|
||||
|
||||
// regexCache is a global thread-safe cache for compiled regular expressions.
|
||||
// Caching regex compilation provides 10-50x speedup for repeated patterns.
|
||||
// Uses sync.RWMutex with a regular map so that ClearRegexCache can atomically
|
||||
// clear the map and reset the count in a single lock acquisition.
|
||||
var (
|
||||
regexCache sync.Map // string -> *regexp.Regexp
|
||||
cacheSize atomic.Int64
|
||||
cacheMu sync.RWMutex
|
||||
regexCache = make(map[string]*regexp.Regexp)
|
||||
)
|
||||
|
||||
// RegexError represents an error during regex compilation or validation.
|
||||
@@ -62,7 +62,7 @@ func ValidatePattern(pattern string) error {
|
||||
}
|
||||
|
||||
// CompileRegex compiles a regex pattern with caching and validation for security.
|
||||
// Thread-safe: uses LoadOrStore to prevent race conditions.
|
||||
// Thread-safe: uses RWMutex to prevent race conditions.
|
||||
// Returns the compiled regex or an error if the pattern is invalid or unsafe.
|
||||
func CompileRegex(pattern string) (*regexp.Regexp, error) {
|
||||
// Validate pattern first
|
||||
@@ -70,12 +70,15 @@ func CompileRegex(pattern string) (*regexp.Regexp, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
if cached, ok := regexCache.Load(pattern); ok {
|
||||
return cached.(*regexp.Regexp), nil
|
||||
// Check cache first (read lock)
|
||||
cacheMu.RLock()
|
||||
if cached, ok := regexCache[pattern]; ok {
|
||||
cacheMu.RUnlock()
|
||||
return cached, nil
|
||||
}
|
||||
cacheMu.RUnlock()
|
||||
|
||||
// Compile regex
|
||||
// Compile regex outside the lock to avoid holding it during compilation
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, &RegexError{
|
||||
@@ -85,18 +88,22 @@ func CompileRegex(pattern string) (*regexp.Regexp, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache size and clear if too large
|
||||
if cacheSize.Load() >= MaxCacheSize {
|
||||
ClearRegexCache()
|
||||
// Write lock to store in cache
|
||||
cacheMu.Lock()
|
||||
// Re-check in case another goroutine stored it while we were compiling
|
||||
if cached, ok := regexCache[pattern]; ok {
|
||||
cacheMu.Unlock()
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// Try to store - if another goroutine already stored it, use theirs
|
||||
// This prevents race conditions where multiple goroutines compile the same pattern
|
||||
actual, loaded := regexCache.LoadOrStore(pattern, re)
|
||||
if !loaded {
|
||||
cacheSize.Add(1)
|
||||
// Check cache size and clear if too large
|
||||
if len(regexCache) >= MaxCacheSize {
|
||||
regexCache = make(map[string]*regexp.Regexp)
|
||||
}
|
||||
return actual.(*regexp.Regexp), nil
|
||||
|
||||
regexCache[pattern] = re
|
||||
cacheMu.Unlock()
|
||||
return re, nil
|
||||
}
|
||||
|
||||
// CompileRegexUncached compiles a regex pattern without caching.
|
||||
@@ -118,18 +125,19 @@ func CompileRegexUncached(pattern string) (*regexp.Regexp, error) {
|
||||
}
|
||||
|
||||
// ClearRegexCache clears all cached compiled regular expressions.
|
||||
// Useful for testing or when memory usage needs to be reduced.
|
||||
// Atomically replaces the map under a single write lock.
|
||||
func ClearRegexCache() {
|
||||
regexCache.Range(func(key, _ interface{}) bool {
|
||||
regexCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
cacheSize.Store(0)
|
||||
cacheMu.Lock()
|
||||
regexCache = make(map[string]*regexp.Regexp)
|
||||
cacheMu.Unlock()
|
||||
}
|
||||
|
||||
// CacheStats returns the current number of cached patterns.
|
||||
func CacheStats() int64 {
|
||||
return cacheSize.Load()
|
||||
cacheMu.RLock()
|
||||
n := int64(len(regexCache))
|
||||
cacheMu.RUnlock()
|
||||
return n
|
||||
}
|
||||
|
||||
// truncatePattern truncates a pattern for display in error messages.
|
||||
|
||||
Reference in New Issue
Block a user