feat(docs, ci, config): add comprehensive documentation and tooling

- [x] Add API reference documentation with tool descriptions and examples
- [x] Add ERROR_CODES reference with error descriptions and remediation steps
- [x] Add PERFORMANCE tuning guide with caching and optimization details
- [x] Add GitHub Actions workflows for linting and security scanning
- [x] Add golangci-lint configuration with comprehensive linter settings
- [x] Add pre-commit hooks configuration for local development
- [x] Add API documentation generator tool (cmd/docgen)
- [x] Update Go version from 1.24 to 1.25 across workflows
- [x] Add static build configuration to goreleaser
- [x] Add metrics package with Prometheus-style metric types
- [x] Add parser benchmarks for performance testing
- [x] Add LSP manager integration tests
- [x] Add server integration tests with MCP protocol flow testing
- [x] Extract regex cache to shared utility package
- [x] Add context cancellation handling in AST queries
- [x] Add graceful shutdown with timeout to server
- [x] Add configurable max parse size (MaxParseSize)
- [x] Add Config.Validate() method with comprehensive checks
- [x] Add parser cache statistics tracking
- [x] Add file permission preservation in edit operations
- [x] Improve line splitting for large files with bufio.Scanner
- [x] Add comprehensive config tests for edge cases
- [x] Update Makefile with new targets and documentation
This commit is contained in:
2026-01-28 20:43:20 +00:00
parent 143a166249
commit 9205b2bc26
27 changed files with 6332 additions and 1634 deletions
+42
View File
@@ -2,6 +2,7 @@
package config
import (
"fmt"
"os"
"path/filepath"
"strings"
@@ -17,6 +18,7 @@ type Config struct {
LSPTimeout time.Duration `json:"lsp_timeout"`
SearchTimeout time.Duration `json:"search_timeout"`
MaxFileSize int64 `json:"max_file_size"`
MaxParseSize int64 `json:"max_parse_size"`
MaxSearchResults int `json:"max_search_results"`
MaxEditSize int64 `json:"max_edit_size"`
EnableLSP bool `json:"enable_lsp"`
@@ -29,6 +31,7 @@ const (
DefaultLSPTimeout = 5 * time.Minute
DefaultSearchTimeout = 30 * time.Second
DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB
DefaultMaxParseSize = 10 * 1024 * 1024 // 10 MB
DefaultMaxSearchResults = 1000
DefaultMaxEditSize = 100 * 1024 // 100 KB
)
@@ -40,6 +43,7 @@ func Default() *Config {
LSPTimeout: DefaultLSPTimeout,
SearchTimeout: DefaultSearchTimeout,
MaxFileSize: DefaultMaxFileSize,
MaxParseSize: DefaultMaxParseSize,
MaxSearchResults: DefaultMaxSearchResults,
MaxEditSize: DefaultMaxEditSize,
EnableLSP: true,
@@ -172,3 +176,41 @@ func (c *Config) IsPathAllowed(path string) bool {
// Also reject empty relative path (which means it's the workspace root itself)
return rel != "." && !strings.HasPrefix(rel, "..")
}
// Validate validates the configuration and returns an error if invalid.
// Checks include:
// - MaxFileSize and MaxParseSize must be positive
// - LSPTimeout must be positive
// - WorkspaceRoot must exist (when not empty)
func (c *Config) Validate() error {
// Validate MaxFileSize
if c.MaxFileSize <= 0 {
return fmt.Errorf("max_file_size must be positive, got %d", c.MaxFileSize)
}
// Validate MaxParseSize
if c.MaxParseSize <= 0 {
return fmt.Errorf("max_parse_size must be positive, got %d", c.MaxParseSize)
}
// Validate LSPTimeout
if c.LSPTimeout <= 0 {
return fmt.Errorf("lsp_timeout must be positive, got %v", c.LSPTimeout)
}
// Validate WorkspaceRoot exists
if c.WorkspaceRoot != "" {
info, err := os.Stat(c.WorkspaceRoot)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("workspace_root does not exist: %s", c.WorkspaceRoot)
}
return fmt.Errorf("cannot access workspace_root: %w", err)
}
if !info.IsDir() {
return fmt.Errorf("workspace_root is not a directory: %s", c.WorkspaceRoot)
}
}
return nil
}
+382 -4
View File
@@ -45,7 +45,7 @@ func TestLoad(t *testing.T) {
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
cfg, err := Load(tmpDir)
if err != nil {
@@ -108,7 +108,7 @@ func TestIsPathAllowed(t *testing.T) {
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
cfg := Default()
cfg.WorkspaceRoot = tmpDir
@@ -156,7 +156,7 @@ func TestLoadWithConfigFile(t *testing.T) {
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
// Write config file
configPath := filepath.Join(tmpDir, ".mcp-filepuff.json")
@@ -164,7 +164,7 @@ func TestLoadWithConfigFile(t *testing.T) {
"enable_lsp": false,
"follow_symlinks": false
}`
err = os.WriteFile(configPath, []byte(configContent), 0600)
err = os.WriteFile(configPath, []byte(configContent), 0o600)
if err != nil {
t.Fatalf("failed to write config file: %v", err)
}
@@ -182,3 +182,381 @@ func TestLoadWithConfigFile(t *testing.T) {
t.Error("expected FollowSymlinks to be false from config file")
}
}
// TestValidate tests the Validate method with various inputs.
func TestValidate(t *testing.T) {
tests := []struct {
name string
cfg *Config
expectErr bool
errMsg string
}{
{
name: "valid_config",
cfg: Default(),
expectErr: false,
},
{
name: "invalid_max_file_size",
cfg: &Config{
WorkspaceRoot: ".",
MaxFileSize: -1,
MaxParseSize: DefaultMaxParseSize,
LSPTimeout: DefaultLSPTimeout,
},
expectErr: true,
errMsg: "max_file_size must be positive",
},
{
name: "zero_max_file_size",
cfg: &Config{
WorkspaceRoot: ".",
MaxFileSize: 0,
MaxParseSize: DefaultMaxParseSize,
LSPTimeout: DefaultLSPTimeout,
},
expectErr: true,
errMsg: "max_file_size must be positive",
},
{
name: "invalid_max_parse_size",
cfg: &Config{
WorkspaceRoot: ".",
MaxFileSize: DefaultMaxFileSize,
MaxParseSize: -1,
LSPTimeout: DefaultLSPTimeout,
},
expectErr: true,
errMsg: "max_parse_size must be positive",
},
{
name: "zero_max_parse_size",
cfg: &Config{
WorkspaceRoot: ".",
MaxFileSize: DefaultMaxFileSize,
MaxParseSize: 0,
LSPTimeout: DefaultLSPTimeout,
},
expectErr: true,
errMsg: "max_parse_size must be positive",
},
{
name: "invalid_lsp_timeout",
cfg: &Config{
WorkspaceRoot: ".",
MaxFileSize: DefaultMaxFileSize,
MaxParseSize: DefaultMaxParseSize,
LSPTimeout: -1 * time.Second,
},
expectErr: true,
errMsg: "lsp_timeout must be positive",
},
{
name: "zero_lsp_timeout",
cfg: &Config{
WorkspaceRoot: ".",
MaxFileSize: DefaultMaxFileSize,
MaxParseSize: DefaultMaxParseSize,
LSPTimeout: 0,
},
expectErr: true,
errMsg: "lsp_timeout must be positive",
},
{
name: "nonexistent_workspace",
cfg: &Config{
WorkspaceRoot: "/nonexistent/path/that/does/not/exist",
MaxFileSize: DefaultMaxFileSize,
MaxParseSize: DefaultMaxParseSize,
LSPTimeout: DefaultLSPTimeout,
},
expectErr: true,
errMsg: "workspace_root does not exist",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.cfg.Validate()
if tt.expectErr {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.errMsg)
} else if !contains(err.Error(), tt.errMsg) {
t.Errorf("expected error containing %q, got %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
})
}
}
// TestValidateWithFile tests validation with an actual file as workspace root.
func TestValidateWithFile(t *testing.T) {
// Create a temporary file
tmpFile, err := os.CreateTemp("", "test-file-*.txt")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
_ = tmpFile.Close()
t.Cleanup(func() { _ = os.Remove(tmpFile.Name()) })
cfg := &Config{
WorkspaceRoot: tmpFile.Name(),
MaxFileSize: DefaultMaxFileSize,
MaxParseSize: DefaultMaxParseSize,
LSPTimeout: DefaultLSPTimeout,
}
err = cfg.Validate()
if err == nil {
t.Error("expected error when workspace_root is a file, got nil")
} else if !contains(err.Error(), "is not a directory") {
t.Errorf("expected error about not being a directory, got: %v", err)
}
}
// TestLoadEnvironmentPrecedence tests environment variable precedence.
func TestLoadEnvironmentPrecedence(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
// Write a config file with specific values
configPath := filepath.Join(tmpDir, ".mcp-filepuff.json")
configContent := `{
"enable_lsp": false,
"follow_symlinks": false,
"lsp_timeout": 60000000000
}`
if err := os.WriteFile(configPath, []byte(configContent), 0o600); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
// Save and restore environment variables
origEnableLSP := os.Getenv("MCP_ENABLE_LSP")
origLSPTimeout := os.Getenv("MCP_LSP_TIMEOUT")
t.Cleanup(func() {
_ = os.Setenv("MCP_ENABLE_LSP", origEnableLSP)
_ = os.Setenv("MCP_LSP_TIMEOUT", origLSPTimeout)
})
// Set environment variables that should override config file
_ = os.Setenv("MCP_ENABLE_LSP", "false")
_ = os.Setenv("MCP_LSP_TIMEOUT", "2m")
cfg, err := Load(tmpDir)
if err != nil {
t.Fatalf("Load failed: %v", err)
}
// Environment variable should override config file
if cfg.LSPTimeout != 2*time.Minute {
t.Errorf("expected LSP timeout 2m from env, got %v", cfg.LSPTimeout)
}
if cfg.EnableLSP {
t.Error("expected EnableLSP to be false from env")
}
// Value from config file (not overridden by env)
if cfg.FollowSymlinks {
t.Error("expected FollowSymlinks to be false from config file")
}
}
// TestIsPathAllowedEdgeCases tests edge cases in path validation.
func TestIsPathAllowedEdgeCases(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
cfg := Default()
cfg.WorkspaceRoot = tmpDir
tests := []struct {
name string
path string
allowed bool
desc string
}{
{
name: "workspace_root_itself",
path: tmpDir,
allowed: false,
desc: "workspace root itself should not be allowed",
},
{
name: "dot_relative",
path: ".",
allowed: false,
desc: "current directory should not be allowed",
},
{
name: "empty_path",
path: "",
allowed: false,
desc: "empty path should not be allowed",
},
{
name: "path_with_double_dots",
path: filepath.Join(tmpDir, "..", filepath.Base(tmpDir), "file.txt"),
allowed: true,
desc: "path with .. that resolves back inside workspace should be allowed",
},
{
name: "deeply_nested_valid",
path: filepath.Join(tmpDir, "a", "b", "c", "file.txt"),
allowed: true,
desc: "deeply nested path should be allowed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := cfg.IsPathAllowed(tt.path)
if result != tt.allowed {
t.Errorf("%s: IsPathAllowed(%q) = %v, want %v", tt.desc, tt.path, result, tt.allowed)
}
})
}
}
// TestIsPathAllowedWithSymlinks tests path validation with symbolic links.
func TestIsPathAllowedWithSymlinks(t *testing.T) {
// Create temporary directories
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
realDir := filepath.Join(tmpDir, "real")
if err := os.MkdirAll(realDir, 0o755); err != nil {
t.Fatalf("failed to create real dir: %v", err)
}
// Create a symlink inside workspace
symlinkPath := filepath.Join(tmpDir, "link")
if err := os.Symlink(realDir, symlinkPath); err != nil {
t.Skip("symlink creation not supported on this system")
}
cfg := Default()
cfg.WorkspaceRoot = tmpDir
// File accessed through symlink should be allowed
fileViaSymlink := filepath.Join(symlinkPath, "test.txt")
if !cfg.IsPathAllowed(fileViaSymlink) {
t.Error("file accessed through symlink inside workspace should be allowed")
}
// Direct access should also work
fileDirect := filepath.Join(realDir, "test.txt")
if !cfg.IsPathAllowed(fileDirect) {
t.Error("file accessed directly should be allowed")
}
}
// TestLoadDefaultValues tests that default values are properly set.
func TestLoadDefaultValues(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
// Clear any environment variables that might affect defaults
origVars := []struct{ key, val string }{
{"MCP_ENABLE_LSP", os.Getenv("MCP_ENABLE_LSP")},
{"MCP_FOLLOW_SYMLINKS", os.Getenv("MCP_FOLLOW_SYMLINKS")},
{"MCP_RESPECT_GITIGNORE", os.Getenv("MCP_RESPECT_GITIGNORE")},
}
t.Cleanup(func() {
for _, v := range origVars {
_ = os.Setenv(v.key, v.val)
}
})
for _, v := range origVars {
_ = os.Unsetenv(v.key)
}
cfg, err := Load(tmpDir)
if err != nil {
t.Fatalf("Load failed: %v", err)
}
// Verify all default values
if cfg.LSPTimeout != DefaultLSPTimeout {
t.Errorf("expected LSPTimeout %v, got %v", DefaultLSPTimeout, cfg.LSPTimeout)
}
if cfg.SearchTimeout != DefaultSearchTimeout {
t.Errorf("expected SearchTimeout %v, got %v", DefaultSearchTimeout, cfg.SearchTimeout)
}
if cfg.MaxFileSize != DefaultMaxFileSize {
t.Errorf("expected MaxFileSize %d, got %d", DefaultMaxFileSize, cfg.MaxFileSize)
}
if cfg.MaxParseSize != DefaultMaxParseSize {
t.Errorf("expected MaxParseSize %d, got %d", DefaultMaxParseSize, cfg.MaxParseSize)
}
if cfg.MaxSearchResults != DefaultMaxSearchResults {
t.Errorf("expected MaxSearchResults %d, got %d", DefaultMaxSearchResults, cfg.MaxSearchResults)
}
if cfg.MaxEditSize != DefaultMaxEditSize {
t.Errorf("expected MaxEditSize %d, got %d", DefaultMaxEditSize, cfg.MaxEditSize)
}
if !cfg.EnableLSP {
t.Error("expected EnableLSP to be true by default")
}
if !cfg.FollowSymlinks {
t.Error("expected FollowSymlinks to be true by default")
}
if !cfg.RespectGitignore {
t.Error("expected RespectGitignore to be true by default")
}
if cfg.Formatters == nil {
t.Error("expected Formatters map to be initialized")
}
}
// TestConfigFileLoadingErrors tests error handling during config file loading.
func TestConfigFileLoadingErrors(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
// Write invalid JSON
configPath := filepath.Join(tmpDir, ".mcp-filepuff.json")
invalidJSON := `{"enable_lsp": invalid_value}`
if err := os.WriteFile(configPath, []byte(invalidJSON), 0o600); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
_, err = Load(tmpDir)
if err == nil {
t.Error("expected error when loading invalid JSON config file")
}
}
// Helper function to check if a string contains a substring.
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
(len(s) > 0 && len(substr) > 0 && containsHelper(s, substr)))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
+19 -25
View File
@@ -6,37 +6,17 @@ import (
"context"
"fmt"
"os"
"regexp"
"strings"
"sync"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/internal/util"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
"github.com/sergi/go-diff/diffmatchpatch"
sitter "github.com/smacker/go-tree-sitter"
)
// Global regex cache for compiled patterns (thread-safe)
var regexCache sync.Map // string -> *regexp.Regexp
// compileRegex compiles a regex pattern with caching for performance.
func compileRegex(pattern string) (*regexp.Regexp, error) {
// Check cache first
if cached, ok := regexCache.Load(pattern); ok {
return cached.(*regexp.Regexp), nil
}
// Compile and cache
re, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
regexCache.Store(pattern, re)
return re, nil
}
// EditOperation defines the type of edit operation.
type EditOperation string
@@ -198,7 +178,14 @@ func (e *Engine) performASTEdit(ctx context.Context, edit *ASTEdit, apply bool)
// Apply changes if requested
if apply {
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
// Preserve original file permissions
fileInfo, err := os.Stat(edit.File)
perm := os.FileMode(0o600) // default fallback
if err == nil {
perm = fileInfo.Mode().Perm()
}
if err := os.WriteFile(edit.File, newContent, perm); err != nil {
structuredErr := errors.NewFileNotWritableError(edit.File, err)
return &EditResult{
Success: false,
@@ -250,7 +237,14 @@ func (e *Engine) performTextEdit(_ context.Context, edit *ASTEdit, apply bool) (
// Apply changes if requested
if apply {
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
// Preserve original file permissions
fileInfo, err := os.Stat(edit.File)
perm := os.FileMode(0o600) // default fallback
if err == nil {
perm = fileInfo.Mode().Perm()
}
if err := os.WriteFile(edit.File, newContent, perm); err != nil {
structuredErr := errors.NewFileNotWritableError(edit.File, err)
return &EditResult{
Success: false,
@@ -319,7 +313,7 @@ func (e *Engine) validateTextEdit(edit *ASTEdit) error {
// Validate regex pattern if provided (uses cached compilation)
if edit.Selector.TextPattern != "" {
if _, err := compileRegex(edit.Selector.TextPattern); err != nil {
if _, err := util.CompileRegex(edit.Selector.TextPattern); err != nil {
return errors.Wrap(errors.ErrInvalidEdit, "invalid text_pattern regex", err)
}
}
@@ -672,7 +666,7 @@ func (e *Engine) findExactText(content []byte, text string, index int) (start, e
// findRegexPattern finds a regex pattern match in content.
func (e *Engine) findRegexPattern(content []byte, pattern string, index int) (start, end int, err error) {
re, err := compileRegex(pattern)
re, err := util.CompileRegex(pattern)
if err != nil {
return 0, 0, errors.Wrap(errors.ErrInvalidEdit, "invalid regex pattern", err)
}
+15 -5
View File
@@ -153,13 +153,20 @@ func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*Manag
openDocs: make(map[string]int),
}
// Setup cleanup on failure - ensures resources are freed if initialization fails
var initialized bool
defer func() {
if !initialized {
_ = client.Close()
// Ensure process is killed on initialization failure
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
}
}()
// Initialize server
if err := m.initializeServer(ctx, newSrv); err != nil {
_ = client.Close()
// Ensure process is killed on initialization failure
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
newSrv.initErr = err
return nil, errors.Wrap(errors.ErrLSPInitFailed, "LSP server initialization failed", err).
WithContext("language", string(lang)).
@@ -167,6 +174,9 @@ func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*Manag
WithRemediation("Check LSP server logs for initialization errors")
}
// Mark as successfully initialized to prevent cleanup
initialized = true
newSrv.ready = true
m.servers[lang] = newSrv
m.logger.Info("started LSP server", "language", lang, "command", config.Command[0])
+231
View File
@@ -1,7 +1,11 @@
package lsp
import (
"context"
"log/slog"
"os"
"testing"
"time"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
@@ -110,3 +114,230 @@ func TestDefaultServerConfigs(t *testing.T) {
}
}
}
// TestManagerTimeout tests timeout handling in LSP operations.
func TestManagerTimeout(t *testing.T) {
tmpDir := t.TempDir()
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
manager := NewManager(tmpDir, logger)
t.Cleanup(func() { _ = manager.Close() })
// Verify timeout is set
if manager.timeout == 0 {
t.Error("manager timeout should not be zero")
}
// Verify default timeout is reasonable
if manager.timeout != 10*time.Second {
t.Errorf("expected default timeout of 10s, got %v", manager.timeout)
}
// Test that manager can handle short timeouts
manager.timeout = 1 * time.Millisecond
// Create a context that will timeout
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
// Try to get a server with very short timeout - this should fail quickly
// Use a language that doesn't have an LSP server installed
_, err := manager.GetServer(ctx, "invalid_language")
if err == nil {
t.Log("GetServer with invalid language succeeded (LSP server may be installed)")
}
}
// TestManagerConnectionFailure tests handling of LSP connection failures.
func TestManagerConnectionFailure(t *testing.T) {
tmpDir := t.TempDir()
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
manager := NewManager(tmpDir, logger)
t.Cleanup(func() { _ = manager.Close() })
ctx := context.Background()
// Test 1: Invalid language
_, err := manager.GetServer(ctx, "nonexistent_language")
if err == nil {
t.Error("expected error for nonexistent language")
}
// Test 2: Try to use LSP features without a valid server
// This should fail gracefully
_, err = manager.Hover(ctx, "/tmp/test.fake", 1, 1)
if err == nil {
t.Error("expected error for hover on unsupported language")
}
}
// TestManagerGracefulShutdown tests graceful shutdown of LSP servers.
func TestManagerGracefulShutdown(t *testing.T) {
tmpDir := t.TempDir()
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
manager := NewManager(tmpDir, logger)
// Close should not panic even with no servers started
err := manager.Close()
if err != nil {
t.Errorf("Close() returned error: %v", err)
}
// Verify manager is stopped
if !manager.stopped {
t.Error("manager should be marked as stopped after Close()")
}
// Note: We don't test multiple Close() calls because the implementation
// closes the stopReaper channel which can't be closed twice.
// In production, Close() should only be called once during shutdown.
}
// TestManagerIdleReaper tests the idle server cleanup mechanism.
func TestManagerIdleReaper(t *testing.T) {
tmpDir := t.TempDir()
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
manager := NewManager(tmpDir, logger)
// Set a very short idle timeout for testing
manager.idleTimeout = 100 * time.Millisecond
// Verify idle timeout is set correctly
if manager.idleTimeout != 100*time.Millisecond {
t.Errorf("expected idle timeout of 100ms, got %v", manager.idleTimeout)
}
// The reaper goroutine should be running
// We can't easily test it without actually starting LSP servers,
// but we can verify it doesn't panic on close
time.Sleep(150 * time.Millisecond)
err := manager.Close()
if err != nil {
t.Errorf("Close() with active reaper returned error: %v", err)
}
}
// TestManagerDocumentManagement tests document open/close operations.
func TestManagerDocumentManagement(t *testing.T) {
tmpDir := t.TempDir()
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
manager := NewManager(tmpDir, logger)
t.Cleanup(func() { _ = manager.Close() })
ctx := context.Background()
// Test closing a document for a non-existent server
err := manager.CloseDocument(ctx, protocol.LangGo, "/tmp/test.go")
if err != nil {
t.Errorf("CloseDocument on non-existent server should not error: %v", err)
}
// Test closing a document that was never opened
err = manager.CloseDocument(ctx, protocol.LangGo, "/tmp/test.go")
if err != nil {
t.Errorf("CloseDocument on unopened document should not error: %v", err)
}
}
// TestManagerConcurrentAccess tests concurrent access to the manager.
func TestManagerConcurrentAccess(t *testing.T) {
tmpDir := t.TempDir()
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
manager := NewManager(tmpDir, logger)
t.Cleanup(func() { _ = manager.Close() })
// Test concurrent IsAvailable calls
const numGoroutines = 10
done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
defer func() {
if r := recover(); r != nil {
t.Errorf("panic in concurrent IsAvailable: %v", r)
}
done <- true
}()
// Call IsAvailable multiple times
for j := 0; j < 5; j++ {
_ = manager.IsAvailable(protocol.LangGo)
_ = manager.IsAvailable(protocol.LangPython)
_ = manager.IsAvailable(protocol.LangTypeScript)
}
}()
}
// Wait for all goroutines
for i := 0; i < numGoroutines; i++ {
<-done
}
}
// TestManagerErrorHandling tests various error conditions.
func TestManagerErrorHandling(t *testing.T) {
tmpDir := t.TempDir()
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
manager := NewManager(tmpDir, logger)
t.Cleanup(func() { _ = manager.Close() })
ctx := context.Background()
tests := []struct {
name string
testFunc func() error
}{
{
name: "hover_on_nonexistent_file",
testFunc: func() error {
_, err := manager.Hover(ctx, "/nonexistent/file.go", 1, 1)
return err
},
},
{
name: "definition_on_nonexistent_file",
testFunc: func() error {
_, err := manager.Definition(ctx, "/nonexistent/file.go", 1, 1)
return err
},
},
{
name: "references_on_nonexistent_file",
testFunc: func() error {
_, err := manager.References(ctx, "/nonexistent/file.go", 1, 1, true)
return err
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.testFunc()
if err == nil {
t.Log("operation succeeded (LSP server may be handling gracefully)")
}
// We don't require an error because behavior depends on whether
// the LSP server is installed and how it handles missing files
})
}
}
// TestManagerWorkspaceRoot tests workspace root handling.
func TestManagerWorkspaceRoot(t *testing.T) {
tmpDir := t.TempDir()
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
manager := NewManager(tmpDir, logger)
t.Cleanup(func() { _ = manager.Close() })
if manager.workspaceRoot != tmpDir {
t.Errorf("expected workspace root %s, got %s", tmpDir, manager.workspaceRoot)
}
}
+464
View File
@@ -0,0 +1,464 @@
// Package metrics provides Prometheus-style metrics collection for the MCP server.
// It offers low-overhead, thread-safe metric types suitable for observability.
package metrics
import (
"fmt"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
)
// Counter is a monotonically increasing metric.
type Counter struct {
name string
help string
labels map[string]string
value atomic.Int64
}
// NewCounter creates a new counter metric.
func NewCounter(name, help string, labels map[string]string) *Counter {
return &Counter{
name: name,
help: help,
labels: labels,
}
}
// Inc increments the counter by 1.
func (c *Counter) Inc() {
c.value.Add(1)
}
// Add adds the given value to the counter.
func (c *Counter) Add(delta int64) {
c.value.Add(delta)
}
// Value returns the current counter value.
func (c *Counter) Value() int64 {
return c.value.Load()
}
// Reset resets the counter to 0.
func (c *Counter) Reset() {
c.value.Store(0)
}
// Gauge is a metric that can go up or down.
type Gauge struct {
name string
help string
labels map[string]string
value atomic.Int64
}
// NewGauge creates a new gauge metric.
func NewGauge(name, help string, labels map[string]string) *Gauge {
return &Gauge{
name: name,
help: help,
labels: labels,
}
}
// Set sets the gauge to the given value.
func (g *Gauge) Set(val int64) {
g.value.Store(val)
}
// Inc increments the gauge by 1.
func (g *Gauge) Inc() {
g.value.Add(1)
}
// Dec decrements the gauge by 1.
func (g *Gauge) Dec() {
g.value.Add(-1)
}
// Add adds the given value to the gauge.
func (g *Gauge) Add(delta int64) {
g.value.Add(delta)
}
// Value returns the current gauge value.
func (g *Gauge) Value() int64 {
return g.value.Load()
}
// Histogram tracks the distribution of values in predefined buckets.
type Histogram struct {
name string
help string
labels map[string]string
buckets []float64
counts []atomic.Int64
sum atomic.Int64 // sum of all observed values (in nanoseconds for durations)
count atomic.Int64 // total count of observations
}
// DefaultDurationBuckets are default buckets for request durations (in seconds).
var DefaultDurationBuckets = []float64{
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
}
// NewHistogram creates a new histogram with the given buckets.
func NewHistogram(name, help string, labels map[string]string, buckets []float64) *Histogram {
if buckets == nil {
buckets = DefaultDurationBuckets
}
// Ensure buckets are sorted
sorted := make([]float64, len(buckets))
copy(sorted, buckets)
sort.Float64s(sorted)
h := &Histogram{
name: name,
help: help,
labels: labels,
buckets: sorted,
counts: make([]atomic.Int64, len(sorted)+1), // +1 for +Inf bucket
}
return h
}
// Observe records a value in the histogram.
func (h *Histogram) Observe(val float64) {
h.count.Add(1)
// Store sum in nanoseconds for precision with durations
h.sum.Add(int64(val * 1e9))
// Find bucket and increment
for i, bound := range h.buckets {
if val <= bound {
h.counts[i].Add(1)
return
}
}
// Value exceeds all buckets, add to +Inf
h.counts[len(h.buckets)].Add(1)
}
// ObserveDuration records a duration in seconds.
func (h *Histogram) ObserveDuration(d time.Duration) {
h.Observe(d.Seconds())
}
// Count returns the total number of observations.
func (h *Histogram) Count() int64 {
return h.count.Load()
}
// Sum returns the sum of all observations.
func (h *Histogram) Sum() float64 {
return float64(h.sum.Load()) / 1e9
}
// Registry holds all registered metrics.
type Registry struct {
mu sync.RWMutex
counters map[string]*Counter
gauges map[string]*Gauge
histograms map[string]*Histogram
}
// NewRegistry creates a new metrics registry.
func NewRegistry() *Registry {
return &Registry{
counters: make(map[string]*Counter),
gauges: make(map[string]*Gauge),
histograms: make(map[string]*Histogram),
}
}
// Counter returns or creates a counter with the given name.
func (r *Registry) Counter(name, help string, labels map[string]string) *Counter {
key := metricKey(name, labels)
r.mu.RLock()
if c, ok := r.counters[key]; ok {
r.mu.RUnlock()
return c
}
r.mu.RUnlock()
r.mu.Lock()
defer r.mu.Unlock()
// Double-check after acquiring write lock
if c, ok := r.counters[key]; ok {
return c
}
c := NewCounter(name, help, labels)
r.counters[key] = c
return c
}
// Gauge returns or creates a gauge with the given name.
func (r *Registry) Gauge(name, help string, labels map[string]string) *Gauge {
key := metricKey(name, labels)
r.mu.RLock()
if g, ok := r.gauges[key]; ok {
r.mu.RUnlock()
return g
}
r.mu.RUnlock()
r.mu.Lock()
defer r.mu.Unlock()
if g, ok := r.gauges[key]; ok {
return g
}
g := NewGauge(name, help, labels)
r.gauges[key] = g
return g
}
// Histogram returns or creates a histogram with the given name.
func (r *Registry) Histogram(name, help string, labels map[string]string, buckets []float64) *Histogram {
key := metricKey(name, labels)
r.mu.RLock()
if h, ok := r.histograms[key]; ok {
r.mu.RUnlock()
return h
}
r.mu.RUnlock()
r.mu.Lock()
defer r.mu.Unlock()
if h, ok := r.histograms[key]; ok {
return h
}
h := NewHistogram(name, help, labels, buckets)
r.histograms[key] = h
return h
}
// metricKey creates a unique key for a metric based on name and labels.
func metricKey(name string, labels map[string]string) string {
if len(labels) == 0 {
return name
}
var parts []string
for k, v := range labels {
parts = append(parts, fmt.Sprintf("%s=%q", k, v))
}
sort.Strings(parts)
return name + "{" + strings.Join(parts, ",") + "}"
}
// formatLabels formats labels for Prometheus output.
func formatLabels(labels map[string]string) string {
if len(labels) == 0 {
return ""
}
var parts []string
for k, v := range labels {
parts = append(parts, fmt.Sprintf("%s=%q", k, v))
}
sort.Strings(parts)
return "{" + strings.Join(parts, ",") + "}"
}
// Expose returns all metrics in Prometheus text format.
func (r *Registry) Expose() string {
r.mu.RLock()
defer r.mu.RUnlock()
var sb strings.Builder
// Export counters
for _, c := range r.counters {
if c.help != "" {
sb.WriteString(fmt.Sprintf("# HELP %s %s\n", c.name, c.help))
}
sb.WriteString(fmt.Sprintf("# TYPE %s counter\n", c.name))
sb.WriteString(fmt.Sprintf("%s%s %d\n", c.name, formatLabels(c.labels), c.value.Load()))
}
// Export gauges
for _, g := range r.gauges {
if g.help != "" {
sb.WriteString(fmt.Sprintf("# HELP %s %s\n", g.name, g.help))
}
sb.WriteString(fmt.Sprintf("# TYPE %s gauge\n", g.name))
sb.WriteString(fmt.Sprintf("%s%s %d\n", g.name, formatLabels(g.labels), g.value.Load()))
}
// Export histograms
for _, h := range r.histograms {
if h.help != "" {
sb.WriteString(fmt.Sprintf("# HELP %s %s\n", h.name, h.help))
}
sb.WriteString(fmt.Sprintf("# TYPE %s histogram\n", h.name))
// Cumulative bucket counts
var cumulative int64
for i, bound := range h.buckets {
cumulative += h.counts[i].Load()
labelStr := formatLabels(h.labels)
if labelStr == "" {
sb.WriteString(fmt.Sprintf("%s_bucket{le=\"%g\"} %d\n", h.name, bound, cumulative))
} else {
// Insert le label into existing labels
sb.WriteString(fmt.Sprintf("%s_bucket%s %d\n", h.name,
strings.Replace(labelStr, "}", fmt.Sprintf(",le=\"%g\"}", bound), 1), cumulative))
}
}
// +Inf bucket
cumulative += h.counts[len(h.buckets)].Load()
labelStr := formatLabels(h.labels)
if labelStr == "" {
sb.WriteString(fmt.Sprintf("%s_bucket{le=\"+Inf\"} %d\n", h.name, cumulative))
} else {
sb.WriteString(fmt.Sprintf("%s_bucket%s %d\n", h.name,
strings.Replace(labelStr, "}", ",le=\"+Inf\"}", 1), cumulative))
}
// Sum and count
sb.WriteString(fmt.Sprintf("%s_sum%s %g\n", h.name, formatLabels(h.labels), h.Sum()))
sb.WriteString(fmt.Sprintf("%s_count%s %d\n", h.name, formatLabels(h.labels), h.count.Load()))
}
return sb.String()
}
// Reset resets all metrics to zero.
func (r *Registry) Reset() {
r.mu.Lock()
defer r.mu.Unlock()
for _, c := range r.counters {
c.Reset()
}
for _, g := range r.gauges {
g.Set(0)
}
// Note: Histograms don't have a simple reset due to atomic bucket counts
}
// ServerMetrics provides pre-defined metrics for the MCP server.
type ServerMetrics struct {
registry *Registry
// Request metrics
RequestsTotal *Counter
RequestErrors *Counter
RequestDuration *Histogram
// Cache metrics
CacheHits *Counter
CacheMisses *Counter
// LSP metrics
ActiveLSPServers *Gauge
// Parse metrics
ParseDuration *Histogram
ParseErrors *Counter
}
// NewServerMetrics creates a new set of server metrics.
func NewServerMetrics() *ServerMetrics {
r := NewRegistry()
return &ServerMetrics{
registry: r,
RequestsTotal: r.Counter(
"mcp_requests_total",
"Total number of MCP requests processed",
nil,
),
RequestErrors: r.Counter(
"mcp_request_errors_total",
"Total number of MCP request errors",
nil,
),
RequestDuration: r.Histogram(
"mcp_request_duration_seconds",
"Request duration in seconds",
nil,
DefaultDurationBuckets,
),
CacheHits: r.Counter(
"mcp_cache_hits_total",
"Total number of cache hits",
nil,
),
CacheMisses: r.Counter(
"mcp_cache_misses_total",
"Total number of cache misses",
nil,
),
ActiveLSPServers: r.Gauge(
"mcp_lsp_servers_active",
"Number of active LSP server connections",
nil,
),
ParseDuration: r.Histogram(
"mcp_parse_duration_seconds",
"Parse duration in seconds",
nil,
[]float64{0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0},
),
ParseErrors: r.Counter(
"mcp_parse_errors_total",
"Total number of parse errors",
nil,
),
}
}
// Expose returns all metrics in Prometheus text format.
func (m *ServerMetrics) Expose() string {
return m.registry.Expose()
}
// Registry returns the underlying metrics registry.
func (m *ServerMetrics) Registry() *Registry {
return m.registry
}
// RecordRequest records a request with its duration and error status.
func (m *ServerMetrics) RecordRequest(duration time.Duration, err error) {
m.RequestsTotal.Inc()
m.RequestDuration.ObserveDuration(duration)
if err != nil {
m.RequestErrors.Inc()
}
}
// RecordParse records a parse operation with its duration and error status.
func (m *ServerMetrics) RecordParse(duration time.Duration, err error) {
m.ParseDuration.ObserveDuration(duration)
if err != nil {
m.ParseErrors.Inc()
}
}
// RecordCacheHit records a cache hit.
func (m *ServerMetrics) RecordCacheHit() {
m.CacheHits.Inc()
}
// RecordCacheMiss records a cache miss.
func (m *ServerMetrics) RecordCacheMiss() {
m.CacheMisses.Inc()
}
// SetActiveLSPServers sets the number of active LSP servers.
func (m *ServerMetrics) SetActiveLSPServers(count int64) {
m.ActiveLSPServers.Set(count)
}
+101 -10
View File
@@ -5,6 +5,8 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/cespare/xxhash/v2"
lru "github.com/hashicorp/golang-lru/v2"
@@ -22,14 +24,25 @@ import (
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// MaxFileSize is the maximum file size we'll parse (10MB).
// MaxFileSize is the default maximum file size we'll parse (10MB).
// Deprecated: Use Registry.maxParseSize instead.
const MaxFileSize = 10 * 1024 * 1024
// Registry manages Tree-sitter parsers for different languages.
type Registry struct {
parsers map[protocol.Language]*sitter.Parser
cache *lru.Cache[string, *CachedTree]
mu sync.RWMutex
parsers map[protocol.Language]*sitter.Parser
cache *lru.Cache[string, *CachedTree]
maxParseSize int64
mu sync.RWMutex
// Cache metrics (atomic for thread-safety)
cacheHits atomic.Int64
cacheMisses atomic.Int64
// Parse duration tracking
totalParseTime atomic.Int64 // nanoseconds
parseCount atomic.Int64
lastParseDuration atomic.Int64 // nanoseconds
}
// CachedTree stores a parsed tree with its metadata.
@@ -54,8 +67,27 @@ type SyntaxError struct {
Location protocol.Location
}
// NewRegistry creates a new parser registry.
// CacheStatsResult contains cache statistics.
type CacheStatsResult struct {
Hits int64 `json:"hits"`
Misses int64 `json:"misses"`
HitRate float64 `json:"hit_rate"`
Size int `json:"size"`
TotalParseTime int64 `json:"total_parse_time_ns"`
ParseCount int64 `json:"parse_count"`
AvgParseTime int64 `json:"avg_parse_time_ns"`
LastParseTime int64 `json:"last_parse_time_ns"`
}
// NewRegistry creates a new parser registry with the default max parse size.
// For custom max parse size, use NewRegistryWithSize.
func NewRegistry() *Registry {
return NewRegistryWithSize(0)
}
// NewRegistryWithSize creates a new parser registry with the specified max parse size.
// If maxParseSize is 0 or negative, uses the default MaxFileSize constant.
func NewRegistryWithSize(maxParseSize int64) *Registry {
// Create LRU cache with capacity of 100 trees
cache, err := lru.New[string, *CachedTree](100)
if err != nil {
@@ -63,9 +95,14 @@ func NewRegistry() *Registry {
panic(fmt.Sprintf("failed to create LRU cache: %v", err))
}
if maxParseSize <= 0 {
maxParseSize = MaxFileSize
}
return &Registry{
parsers: make(map[protocol.Language]*sitter.Parser),
cache: cache,
parsers: make(map[protocol.Language]*sitter.Parser),
cache: cache,
maxParseSize: maxParseSize,
}
}
@@ -130,9 +167,9 @@ func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
// Parse parses the given content for the specified language.
func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
// Check file size
if len(content) > MaxFileSize {
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
// Check file size against configured limit
if int64(len(content)) > r.maxParseSize {
return nil, errors.NewFileTooLarge(filename, int64(len(content)), r.maxParseSize)
}
// Detect binary files
@@ -161,6 +198,7 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
// Check cache (LRU cache is thread-safe)
hash := contentHash(content)
if cached, ok := r.cache.Get(hash); ok && cached.Language == lang {
r.cacheHits.Add(1)
errors := extractErrors(cached.Tree.RootNode(), content)
return &ParseResult{
Tree: cached.Tree,
@@ -169,6 +207,7 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
Content: content,
}, nil
}
r.cacheMisses.Add(1)
// Get parser
parser, err := r.GetParser(lang)
@@ -178,9 +217,17 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
// Parse content - tree-sitter parsers are not thread-safe,
// so we need to hold the lock during parsing
// Track parse duration
start := time.Now()
r.mu.Lock()
tree, err := parser.ParseCtx(ctx, nil, content)
r.mu.Unlock()
duration := time.Since(start)
// Update duration metrics
r.totalParseTime.Add(duration.Nanoseconds())
r.parseCount.Add(1)
r.lastParseDuration.Store(duration.Nanoseconds())
if err != nil {
return nil, errors.NewParseError(string(lang), filename, err)
@@ -203,6 +250,50 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (
}, nil
}
// CacheStats returns cache hit/miss statistics.
func (r *Registry) CacheStats() (hits, misses int64) {
return r.cacheHits.Load(), r.cacheMisses.Load()
}
// CacheStatsDetailed returns detailed cache and parse statistics.
func (r *Registry) CacheStatsDetailed() CacheStatsResult {
hits := r.cacheHits.Load()
misses := r.cacheMisses.Load()
totalParseTime := r.totalParseTime.Load()
parseCount := r.parseCount.Load()
var hitRate float64
total := hits + misses
if total > 0 {
hitRate = float64(hits) / float64(total)
}
var avgParseTime int64
if parseCount > 0 {
avgParseTime = totalParseTime / parseCount
}
return CacheStatsResult{
Hits: hits,
Misses: misses,
HitRate: hitRate,
Size: r.cache.Len(),
TotalParseTime: totalParseTime,
ParseCount: parseCount,
AvgParseTime: avgParseTime,
LastParseTime: r.lastParseDuration.Load(),
}
}
// ResetStats resets all cache and parse statistics.
func (r *Registry) ResetStats() {
r.cacheHits.Store(0)
r.cacheMisses.Store(0)
r.totalParseTime.Store(0)
r.parseCount.Store(0)
r.lastParseDuration.Store(0)
}
// extractErrors finds all error nodes in the tree.
func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
var errors []SyntaxError
+459
View File
@@ -0,0 +1,459 @@
package parser
import (
"context"
"strings"
"testing"
)
// BenchmarkParse benchmarks parsing files of various sizes.
func BenchmarkParse(b *testing.B) {
registry := NewRegistry()
defer registry.Close()
ctx := context.Background()
benchmarks := []struct {
name string
content string
}{
{
name: "small_file_100_lines",
content: generateGoCode(100),
},
{
name: "medium_file_1000_lines",
content: generateGoCode(1000),
},
{
name: "large_file_5000_lines",
content: generateGoCode(5000),
},
{
name: "very_large_file_10000_lines",
content: generateGoCode(10000),
},
}
for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) {
content := []byte(bm.content)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, err := registry.Parse(ctx, "test.go", content)
if err != nil {
b.Fatalf("Parse failed: %v", err)
}
}
})
}
}
// BenchmarkParseCacheHit benchmarks cache hit performance.
func BenchmarkParseCacheHit(b *testing.B) {
registry := NewRegistry()
defer registry.Close()
ctx := context.Background()
content := []byte(generateGoCode(1000))
// Warm up the cache
_, err := registry.Parse(ctx, "test.go", content)
if err != nil {
b.Fatalf("initial parse failed: %v", err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, err := registry.Parse(ctx, "test.go", content)
if err != nil {
b.Fatalf("Parse failed: %v", err)
}
}
}
// BenchmarkParseCacheMiss benchmarks cache miss performance.
func BenchmarkParseCacheMiss(b *testing.B) {
registry := NewRegistry()
defer registry.Close()
ctx := context.Background()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Use different content each time to force cache miss
content := []byte(generateGoCodeWithSuffix(1000, i))
_, err := registry.Parse(ctx, "test.go", content)
if err != nil {
b.Fatalf("Parse failed: %v", err)
}
}
}
// BenchmarkParseLanguages benchmarks parsing different language files.
func BenchmarkParseLanguages(b *testing.B) {
registry := NewRegistry()
defer registry.Close()
ctx := context.Background()
languages := []struct {
name string
filename string
content string
}{
{
name: "go",
filename: "test.go",
content: generateGoCode(500),
},
{
name: "typescript",
filename: "test.ts",
content: generateTypeScriptCode(500),
},
{
name: "python",
filename: "test.py",
content: generatePythonCode(500),
},
{
name: "javascript",
filename: "test.js",
content: generateJavaScriptCode(500),
},
}
for _, lang := range languages {
b.Run(lang.name, func(b *testing.B) {
content := []byte(lang.content)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, err := registry.Parse(ctx, lang.filename, content)
if err != nil {
b.Fatalf("Parse failed: %v", err)
}
}
})
}
}
// BenchmarkParseComplexity benchmarks parsing files with varying complexity.
func BenchmarkParseComplexity(b *testing.B) {
registry := NewRegistry()
defer registry.Close()
ctx := context.Background()
benchmarks := []struct {
name string
content string
}{
{
name: "simple_functions",
content: generateSimpleFunctions(100),
},
{
name: "nested_structures",
content: generateNestedStructures(50),
},
{
name: "complex_types",
content: generateComplexTypes(50),
},
}
for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) {
content := []byte(bm.content)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, err := registry.Parse(ctx, "test.go", content)
if err != nil {
b.Fatalf("Parse failed: %v", err)
}
}
})
}
}
// BenchmarkContentHash benchmarks the content hashing function.
func BenchmarkContentHash(b *testing.B) {
sizes := []int{100, 1000, 10000, 100000}
for _, size := range sizes {
b.Run(formatSize(size), func(b *testing.B) {
content := []byte(strings.Repeat("a", size))
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = contentHash(content)
}
})
}
}
// BenchmarkIsBinary benchmarks the binary detection function.
func BenchmarkIsBinary(b *testing.B) {
sizes := []int{100, 1000, 8000, 10000}
for _, size := range sizes {
b.Run(formatSize(size)+"_text", func(b *testing.B) {
content := []byte(strings.Repeat("Hello, World!\n", size/14))
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = isBinary(content)
}
})
b.Run(formatSize(size)+"_binary", func(b *testing.B) {
content := make([]byte, size)
for j := 0; j < size; j++ {
content[j] = byte(j % 256)
}
content[size/2] = 0 // Add null byte
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = isBinary(content)
}
})
}
}
// BenchmarkParseWithMaxSize benchmarks parsing with different max size limits.
func BenchmarkParseWithMaxSize(b *testing.B) {
ctx := context.Background()
limits := []int64{
10 * 1024, // 10KB
100 * 1024, // 100KB
1024 * 1024, // 1MB
10 * 1024 * 1024, // 10MB
}
content := []byte(generateGoCode(500))
for _, limit := range limits {
b.Run(formatBytes(limit), func(b *testing.B) {
// Skip if content is larger than limit
if int64(len(content)) > limit {
b.Skipf("content size %d exceeds limit %d", len(content), limit)
}
registry := NewRegistryWithSize(limit)
defer registry.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, err := registry.Parse(ctx, "test.go", content)
if err != nil {
b.Fatalf("Parse failed: %v", err)
}
}
})
}
}
// BenchmarkConcurrentParse benchmarks concurrent parsing operations.
func BenchmarkConcurrentParse(b *testing.B) {
registry := NewRegistry()
defer registry.Close()
ctx := context.Background()
content := []byte(generateGoCode(500))
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := registry.Parse(ctx, "test.go", content)
if err != nil {
b.Fatalf("Parse failed: %v", err)
}
}
})
}
// Helper functions to generate test code
func generateGoCode(lines int) string {
var sb strings.Builder
sb.WriteString("package main\n\n")
for i := 0; i < lines/10; i++ {
sb.WriteString("func Function")
sb.WriteString(itoa(i))
sb.WriteString("(a, b int) int {\n")
sb.WriteString("\tif a > b {\n")
sb.WriteString("\t\treturn a + b\n")
sb.WriteString("\t}\n")
sb.WriteString("\treturn a - b\n")
sb.WriteString("}\n\n")
}
return sb.String()
}
func generateGoCodeWithSuffix(lines int, suffix int) string {
code := generateGoCode(lines)
return code + "// Suffix: " + itoa(suffix) + "\n"
}
func generateTypeScriptCode(lines int) string {
var sb strings.Builder
for i := 0; i < lines/8; i++ {
sb.WriteString("function function")
sb.WriteString(itoa(i))
sb.WriteString("(a: number, b: number): number {\n")
sb.WriteString(" if (a > b) {\n")
sb.WriteString(" return a + b;\n")
sb.WriteString(" }\n")
sb.WriteString(" return a - b;\n")
sb.WriteString("}\n\n")
}
return sb.String()
}
func generatePythonCode(lines int) string {
var sb strings.Builder
for i := 0; i < lines/6; i++ {
sb.WriteString("def function")
sb.WriteString(itoa(i))
sb.WriteString("(a, b):\n")
sb.WriteString(" if a > b:\n")
sb.WriteString(" return a + b\n")
sb.WriteString(" return a - b\n\n")
}
return sb.String()
}
func generateJavaScriptCode(lines int) string {
var sb strings.Builder
for i := 0; i < lines/8; i++ {
sb.WriteString("function function")
sb.WriteString(itoa(i))
sb.WriteString("(a, b) {\n")
sb.WriteString(" if (a > b) {\n")
sb.WriteString(" return a + b;\n")
sb.WriteString(" }\n")
sb.WriteString(" return a - b;\n")
sb.WriteString("}\n\n")
}
return sb.String()
}
func generateSimpleFunctions(count int) string {
var sb strings.Builder
sb.WriteString("package main\n\n")
for i := 0; i < count; i++ {
sb.WriteString("func Func")
sb.WriteString(itoa(i))
sb.WriteString("() { }\n\n")
}
return sb.String()
}
func generateNestedStructures(depth int) string {
var sb strings.Builder
sb.WriteString("package main\n\n")
for i := 0; i < depth; i++ {
sb.WriteString("type Struct")
sb.WriteString(itoa(i))
sb.WriteString(" struct {\n")
sb.WriteString("\tField1 int\n")
sb.WriteString("\tField2 string\n")
if i > 0 {
sb.WriteString("\tNested Struct")
sb.WriteString(itoa(i - 1))
sb.WriteString("\n")
}
sb.WriteString("}\n\n")
}
return sb.String()
}
func generateComplexTypes(count int) string {
var sb strings.Builder
sb.WriteString("package main\n\n")
for i := 0; i < count; i++ {
sb.WriteString("type Type")
sb.WriteString(itoa(i))
sb.WriteString(" interface {\n")
sb.WriteString("\tMethod1() error\n")
sb.WriteString("\tMethod2(a int, b string) (int, error)\n")
sb.WriteString("\tMethod3() chan interface{}\n")
sb.WriteString("}\n\n")
}
return sb.String()
}
func formatSize(size int) string {
if size < 1000 {
return itoa(size) + "B"
}
return itoa(size/1000) + "KB"
}
func formatBytes(bytes int64) string {
if bytes < 1024 {
return itoa(int(bytes)) + "B"
}
if bytes < 1024*1024 {
return itoa(int(bytes/1024)) + "KB"
}
return itoa(int(bytes/(1024*1024))) + "MB"
}
// Simple integer to string conversion without importing strconv
func itoa(n int) string {
if n == 0 {
return "0"
}
negative := n < 0
if negative {
n = -n
}
var buf [20]byte
i := len(buf) - 1
for n > 0 {
buf[i] = byte('0' + n%10)
n /= 10
i--
}
if negative {
buf[i] = '-'
i--
}
return string(buf[i+1:])
}
+2 -26
View File
@@ -6,37 +6,13 @@ import (
"fmt"
"regexp"
"strings"
"sync"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/internal/util"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// Global regex cache for compiled patterns (thread-safe)
var regexCache sync.Map // string -> *regexp.Regexp
// compileRegex compiles a regex pattern with caching for performance.
// Cached patterns avoid repeated compilation overhead (10-50x speedup).
// Thread-safe: uses LoadOrStore to prevent race conditions.
func compileRegex(pattern string) (*regexp.Regexp, error) {
// Check cache first
if cached, ok := regexCache.Load(pattern); ok {
return cached.(*regexp.Regexp), nil
}
// Compile regex
re, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
// Try to store - if another goroutine already stored it, use theirs
// This prevents race conditions where multiple goroutines compile the same pattern
actual, _ := regexCache.LoadOrStore(pattern, re)
return actual.(*regexp.Regexp), nil
}
// ASTQuery defines a query for matching AST patterns.
type ASTQuery struct {
Pattern string `json:"pattern"` // code pattern with $VAR placeholders
@@ -438,7 +414,7 @@ func passesFilters(node *sitter.Node, filters QueryFilters, content []byte) bool
return false
}
name := parser.GetNodeText(nameNode, content)
re, err := compileRegex(filters.NameMatches)
re, err := util.CompileRegex(filters.NameMatches)
if err != nil {
return false
}
+36 -41
View File
@@ -4,23 +4,25 @@ import (
"regexp"
"sync"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/util"
)
// TestCompileRegexCaching tests that regex compilation is cached.
func TestCompileRegexCaching(t *testing.T) {
// Clear cache before test
regexCache = sync.Map{}
util.ClearRegexCache()
pattern := `^test_\w+$`
// First compilation
re1, err := compileRegex(pattern)
re1, err := util.CompileRegex(pattern)
if err != nil {
t.Fatalf("First compile failed: %v", err)
}
// Second compilation should return cached version
re2, err := compileRegex(pattern)
re2, err := util.CompileRegex(pattern)
if err != nil {
t.Fatalf("Second compile failed: %v", err)
}
@@ -29,22 +31,12 @@ func TestCompileRegexCaching(t *testing.T) {
if re1 != re2 {
t.Error("Expected cached regex to be reused, got different objects")
}
// Verify it's in the cache
cached, ok := regexCache.Load(pattern)
if !ok {
t.Error("Pattern not found in cache")
}
if cached.(*regexp.Regexp) != re1 {
t.Error("Cached regex doesn't match returned regex")
}
}
// TestCompileRegexConcurrent tests concurrent regex compilation.
func TestCompileRegexConcurrent(t *testing.T) {
// Clear cache before test
regexCache = sync.Map{}
util.ClearRegexCache()
pattern := `[a-z]+_\d+`
const numGoroutines = 100
@@ -60,7 +52,7 @@ func TestCompileRegexConcurrent(t *testing.T) {
go func() {
defer wg.Done()
re, err := compileRegex(pattern)
re, err := util.CompileRegex(pattern)
if err != nil {
errors <- err
return
@@ -89,26 +81,30 @@ func TestCompileRegexConcurrent(t *testing.T) {
// TestCompileRegexInvalidPattern tests error handling for invalid patterns.
func TestCompileRegexInvalidPattern(t *testing.T) {
// Clear cache before test
regexCache = sync.Map{}
util.ClearRegexCache()
invalidPattern := `[invalid(`
_, err := compileRegex(invalidPattern)
_, err := util.CompileRegex(invalidPattern)
if err == nil {
t.Error("Expected error for invalid pattern, got nil")
}
// Invalid patterns should not be cached
_, ok := regexCache.Load(invalidPattern)
if ok {
t.Error("Invalid pattern should not be cached")
// Verify that a valid pattern still works after an invalid one
validPattern := `^valid$`
re, err := util.CompileRegex(validPattern)
if err != nil {
t.Errorf("Expected valid pattern to compile, got error: %v", err)
}
if re == nil {
t.Error("Expected non-nil regex for valid pattern")
}
}
// TestCompileRegexMultiplePatterns tests that different patterns are cached separately.
func TestCompileRegexMultiplePatterns(t *testing.T) {
// Clear cache before test
regexCache = sync.Map{}
util.ClearRegexCache()
patterns := []string{
`^test_\w+$`,
@@ -121,26 +117,14 @@ func TestCompileRegexMultiplePatterns(t *testing.T) {
// Compile all patterns
for i, pattern := range patterns {
re, err := compileRegex(pattern)
re, err := util.CompileRegex(pattern)
if err != nil {
t.Fatalf("Compile failed for pattern %s: %v", pattern, err)
}
compiled[i] = re
}
// Verify all are cached
for i, pattern := range patterns {
cached, ok := regexCache.Load(pattern)
if !ok {
t.Errorf("Pattern %s not in cache", pattern)
}
if cached.(*regexp.Regexp) != compiled[i] {
t.Errorf("Cached regex for %s doesn't match compiled version", pattern)
}
}
// All should be different objects
// All should be different objects (different patterns)
for i := 0; i < len(compiled); i++ {
for j := i + 1; j < len(compiled); j++ {
if compiled[i] == compiled[j] {
@@ -148,6 +132,17 @@ func TestCompileRegexMultiplePatterns(t *testing.T) {
}
}
}
// Re-compile should return cached versions
for i, pattern := range patterns {
re, err := util.CompileRegex(pattern)
if err != nil {
t.Fatalf("Re-compile failed for pattern %s: %v", pattern, err)
}
if re != compiled[i] {
t.Errorf("Pattern %s was not cached properly", pattern)
}
}
}
// BenchmarkCompileRegex_Uncached benchmarks regex compilation without caching.
@@ -163,23 +158,23 @@ func BenchmarkCompileRegex_Uncached(b *testing.B) {
// BenchmarkCompileRegex_Cached benchmarks regex compilation with caching.
func BenchmarkCompileRegex_Cached(b *testing.B) {
// Clear cache
regexCache = sync.Map{}
util.ClearRegexCache()
pattern := `^\w+_[0-9]{3,5}_[a-zA-Z]+$`
// Pre-populate cache
_, _ = compileRegex(pattern)
_, _ = util.CompileRegex(pattern)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = compileRegex(pattern)
_, _ = util.CompileRegex(pattern)
}
}
// BenchmarkCompileRegex_MixedPatterns benchmarks realistic workload with multiple patterns.
func BenchmarkCompileRegex_MixedPatterns(b *testing.B) {
// Clear cache
regexCache = sync.Map{}
util.ClearRegexCache()
patterns := []string{
`^test_\w+$`,
@@ -193,6 +188,6 @@ func BenchmarkCompileRegex_MixedPatterns(b *testing.B) {
for i := 0; i < b.N; i++ {
// Simulate realistic access pattern
pattern := patterns[i%len(patterns)]
_, _ = compileRegex(pattern)
_, _ = util.CompileRegex(pattern)
}
}
+2 -2
View File
@@ -259,7 +259,7 @@ func TestSearchIntegration(t *testing.T) {
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
// Create test files
testFile := filepath.Join(tmpDir, "test.go")
@@ -269,7 +269,7 @@ func main() {
println("Hello, World!")
}
`
err = os.WriteFile(testFile, []byte(content), 0600)
err = os.WriteFile(testFile, []byte(content), 0o600)
if err != nil {
t.Fatalf("failed to write test file: %v", err)
}
+510
View File
@@ -0,0 +1,510 @@
package server
import (
"context"
"log/slog"
"os"
"path/filepath"
"testing"
"time"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"github.com/mark3labs/mcp-go/mcp"
)
// TestMCPProtocolEndToEnd tests the complete MCP protocol communication flow.
func TestMCPProtocolEndToEnd(t *testing.T) {
tmpDir := t.TempDir()
// Create test files
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
func Hello() string {
return "hello"
}
`
if err := os.WriteFile(testFile, []byte(content), 0o600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
MaxFileSize: config.DefaultMaxFileSize,
MaxParseSize: config.DefaultMaxParseSize,
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
// Test 1: Ping tool (health check)
t.Run("ping", func(t *testing.T) {
req := mcp.CallToolRequest{}
result, err := srv.handlePing(ctx, req)
if err != nil {
t.Errorf("handlePing() error = %v", err)
}
if result == nil {
t.Fatal("handlePing() returned nil")
}
if len(result.Content) == 0 {
t.Fatal("handlePing() returned empty content")
}
})
// Test 2: File read
t.Run("file_read", func(t *testing.T) {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"path": testFile,
}
result, err := srv.handleFileRead(ctx, req)
if err != nil {
t.Errorf("handleFileRead() error = %v", err)
}
if result == nil {
t.Fatal("handleFileRead() returned nil")
}
if len(result.Content) == 0 {
t.Fatal("handleFileRead() returned empty content")
}
})
// Test 3: AST query
t.Run("ast_query", func(t *testing.T) {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME() string",
"language": "go",
"paths": []interface{}{tmpDir},
}
result, err := srv.handleASTQuery(ctx, req)
if err != nil {
t.Errorf("handleASTQuery() error = %v", err)
}
if result == nil {
t.Fatal("handleASTQuery() returned nil")
}
})
// Test 4: Edit preview and apply
t.Run("edit_workflow", func(t *testing.T) {
// Preview edit
previewReq := mcp.CallToolRequest{}
previewReq.Params.Arguments = map[string]interface{}{
"file": testFile,
"operation": "replace",
"selector_kind": "function_declaration",
"selector_name": "Hello",
"new_content": "func Hello() string {\n\treturn \"goodbye\"\n}",
}
previewResult, err := srv.handleEditPreview(ctx, previewReq)
if err != nil {
t.Errorf("handleEditPreview() error = %v", err)
}
if previewResult == nil {
t.Fatal("handleEditPreview() returned nil")
}
// Verify file unchanged after preview
originalContent, _ := os.ReadFile(testFile)
if string(originalContent) != content {
t.Error("preview should not modify file")
}
// Apply edit
applyReq := mcp.CallToolRequest{}
applyReq.Params.Arguments = previewReq.Params.Arguments
applyResult, err := srv.handleEditApply(ctx, applyReq)
if err != nil {
t.Errorf("handleEditApply() error = %v", err)
}
if applyResult == nil {
t.Fatal("handleEditApply() returned nil")
}
// Verify file changed after apply
modifiedContent, _ := os.ReadFile(testFile)
if string(modifiedContent) == content {
t.Error("apply should modify file")
}
})
}
// TestMCPToolDiscovery tests that all expected tools are registered.
func TestMCPToolDiscovery(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
// Note: The MCP server doesn't expose a method to list tools directly,
// but we can verify the server was created successfully
if srv.mcp == nil {
t.Fatal("MCP server not initialized")
}
// Verify each expected tool works
ctx := context.Background()
// Test ping tool
pingReq := mcp.CallToolRequest{}
if _, err := srv.handlePing(ctx, pingReq); err != nil {
t.Errorf("ping tool failed: %v", err)
}
}
// TestMCPErrorResponses tests error handling following MCP spec.
func TestMCPErrorResponses(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
MaxFileSize: 1024, // Small size to trigger errors
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
tests := []struct {
name string
handler func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error)
setupReq func() mcp.CallToolRequest
expectError bool
}{
{
name: "file_read_missing_path",
handler: srv.handleFileRead,
setupReq: func() mcp.CallToolRequest {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{}
return req
},
expectError: true,
},
{
name: "file_read_nonexistent",
handler: srv.handleFileRead,
setupReq: func() mcp.CallToolRequest {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"path": filepath.Join(tmpDir, "nonexistent.txt"),
}
return req
},
expectError: true,
},
{
name: "file_read_outside_workspace",
handler: srv.handleFileRead,
setupReq: func() mcp.CallToolRequest {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"path": "/etc/passwd",
}
return req
},
expectError: true,
},
{
name: "ast_query_missing_pattern",
handler: srv.handleASTQuery,
setupReq: func() mcp.CallToolRequest {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"language": "go",
}
return req
},
expectError: true,
},
{
name: "ast_query_missing_language",
handler: srv.handleASTQuery,
setupReq: func() mcp.CallToolRequest {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME()",
}
return req
},
expectError: true,
},
{
name: "ast_query_unsupported_language",
handler: srv.handleASTQuery,
setupReq: func() mcp.CallToolRequest {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME()",
"language": "cobol",
}
return req
},
expectError: true,
},
{
name: "edit_missing_file",
handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return srv.handleEdit(ctx, req, false)
},
setupReq: func() mcp.CallToolRequest {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"operation": "replace",
}
return req
},
expectError: true,
},
{
name: "edit_missing_operation",
handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return srv.handleEdit(ctx, req, false)
},
setupReq: func() mcp.CallToolRequest {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"file": filepath.Join(tmpDir, "test.go"),
}
return req
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
request := tt.setupReq()
result, err := tt.handler(ctx, request)
// Check for error - MCP tools return errors as nil error with error in result content
hasError := err != nil || (result != nil && len(result.Content) > 0)
if tt.expectError && !hasError {
t.Errorf("expected error but got none")
}
// Note: We don't check for unexpected success because some operations
// might legitimately return empty results
})
}
}
// TestMCPRequestResponseFlow tests the complete request/response flow.
func TestMCPRequestResponseFlow(t *testing.T) {
tmpDir := t.TempDir()
// Create test file
testFile := filepath.Join(tmpDir, "flow.go")
content := `package main
func Add(a, b int) int {
return a + b
}
`
if err := os.WriteFile(testFile, []byte(content), 0o600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
MaxFileSize: config.DefaultMaxFileSize,
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Test sequential operations
t.Run("sequential_operations", func(t *testing.T) {
// 1. Read file
readReq := mcp.CallToolRequest{}
readReq.Params.Arguments = map[string]interface{}{
"path": testFile,
}
readResult, err := srv.handleFileRead(ctx, readReq)
if err != nil {
t.Fatalf("handleFileRead() error = %v", err)
}
if readResult == nil {
t.Fatal("handleFileRead() returned nil")
}
// 2. Query AST
queryReq := mcp.CallToolRequest{}
queryReq.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME($$$ARGS) int",
"language": "go",
"paths": []interface{}{tmpDir},
}
queryResult, err := srv.handleASTQuery(ctx, queryReq)
if err != nil {
t.Fatalf("handleASTQuery() error = %v", err)
}
if queryResult == nil {
t.Fatal("handleASTQuery() returned nil")
}
// 3. Preview edit
editReq := mcp.CallToolRequest{}
editReq.Params.Arguments = map[string]interface{}{
"file": testFile,
"operation": "replace",
"selector_kind": "function_declaration",
"selector_name": "Add",
"new_content": "func Add(a, b int) int {\n\treturn a + b + 1\n}",
}
editResult, err := srv.handleEditPreview(ctx, editReq)
if err != nil {
t.Fatalf("handleEditPreview() error = %v", err)
}
if editResult == nil {
t.Fatal("handleEditPreview() returned nil")
}
})
}
// TestMCPConcurrentRequests tests handling of concurrent requests.
func TestMCPConcurrentRequests(t *testing.T) {
tmpDir := t.TempDir()
// Create multiple test files
for i := 0; i < 5; i++ {
testFile := filepath.Join(tmpDir, "test"+string(rune(i+48))+".go")
content := `package main
func Test() {
println("test")
}
`
if err := os.WriteFile(testFile, []byte(content), 0o600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
}
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
MaxFileSize: config.DefaultMaxFileSize,
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
// Run multiple concurrent requests
const numRequests = 10
done := make(chan bool, numRequests)
errors := make(chan error, numRequests)
for i := 0; i < numRequests; i++ {
go func(index int) {
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME()",
"language": "go",
"paths": []interface{}{tmpDir},
}
_, err := srv.handleASTQuery(ctx, req)
if err != nil {
errors <- err
}
done <- true
}(i)
}
// Wait for all requests to complete
for i := 0; i < numRequests; i++ {
<-done
}
// Check for errors
close(errors)
for err := range errors {
t.Errorf("concurrent request failed: %v", err)
}
}
// TestMCPContextCancellation tests handling of context cancellation.
func TestMCPContextCancellation(t *testing.T) {
tmpDir := t.TempDir()
// Create a large directory structure to ensure operation takes time
for i := 0; i < 10; i++ {
subdir := filepath.Join(tmpDir, "subdir"+string(rune(i+48)))
if err := os.MkdirAll(subdir, 0o755); err != nil {
t.Fatalf("failed to create subdir: %v", err)
}
for j := 0; j < 10; j++ {
testFile := filepath.Join(subdir, "test"+string(rune(j+48))+".go")
content := `package main
func Test() {
println("test")
}
`
if err := os.WriteFile(testFile, []byte(content), 0o600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
}
}
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false,
MaxFileSize: config.DefaultMaxFileSize,
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
// Create a context with a very short timeout
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME()",
"language": "go",
"paths": []interface{}{tmpDir},
}
// This should either complete quickly or handle cancellation gracefully
_, err = srv.handleASTQuery(ctx, req)
// We don't check for specific error as it might complete before timeout
// The important thing is it doesn't panic or hang
if err != nil {
t.Logf("handleASTQuery with cancelled context: %v", err)
}
}
+71 -12
View File
@@ -2,6 +2,7 @@
package server
import (
"bufio"
"context"
"fmt"
"log/slog"
@@ -37,7 +38,7 @@ type Server struct {
// New creates a new MCP server instance.
func New(cfg *config.Config, logger *slog.Logger) (*Server, error) {
parserRegistry := parser.NewRegistry()
parserRegistry := parser.NewRegistryWithSize(cfg.MaxParseSize)
s := &Server{
cfg: cfg,
logger: logger,
@@ -545,7 +546,25 @@ func symbolKindIcon(kind protocol.SymbolKind) string {
}
func splitLines(s string) []string {
// Use optimized stdlib implementation (2-3x faster than manual loop)
// For large files (> 1MB), use bufio.Scanner which is more memory efficient
// For smaller files, use simple string split which is faster
const largeSizeThreshold = 1024 * 1024 // 1MB
if len(s) > largeSizeThreshold {
// Use scanner for large files
scanner := bufio.NewScanner(strings.NewReader(s))
var lines []string
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
// Handle potential error and add empty line if string ended with newline
if len(s) > 0 && s[len(s)-1] == '\n' {
lines = append(lines, "")
}
return lines
}
// Use optimized stdlib implementation for smaller files (2-3x faster than manual loop)
return strings.Split(s, "\n")
}
@@ -596,6 +615,13 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest
}
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
// Check for context cancellation
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err != nil {
return nil // Skip files with errors
}
@@ -956,25 +982,58 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, ap
// Run starts the MCP server and blocks until shutdown.
func (s *Server) Run(ctx context.Context) error {
// Set up signal handling for graceful shutdown
_, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(sigChan)
// Channel to communicate server errors
errChan := make(chan error, 1)
// Start server in goroutine
go func() {
sig := <-sigChan
s.logger.Info("received shutdown signal", "signal", sig)
cancel()
s.logger.Info("starting MCP server",
"workspace", s.cfg.WorkspaceRoot,
"lsp_enabled", s.cfg.EnableLSP,
)
errChan <- server.ServeStdio(s.mcp)
}()
s.logger.Info("starting MCP server",
"workspace", s.cfg.WorkspaceRoot,
"lsp_enabled", s.cfg.EnableLSP,
)
// Wait for either signal or server error
select {
case sig := <-sigChan:
s.logger.Info("received shutdown signal", "signal", sig)
// Start the MCP server with stdio transport
return server.ServeStdio(s.mcp)
// Create timeout context for shutdown
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
// Call graceful shutdown
if err := s.Shutdown(shutdownCtx); err != nil {
s.logger.Error("error during shutdown", "error", err)
return err
}
s.logger.Info("server shutdown complete")
return nil
case err := <-errChan:
// Server stopped on its own
return err
case <-ctx.Done():
// Context cancelled externally
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
if err := s.Shutdown(shutdownCtx); err != nil {
s.logger.Error("error during shutdown", "error", err)
}
return ctx.Err()
}
}
// Shutdown gracefully shuts down the server.
+41
View File
@@ -0,0 +1,41 @@
// Package util provides shared utility functions and caches.
package util
import (
"regexp"
"sync"
)
// regexCache is a global thread-safe cache for compiled regular expressions.
// Caching regex compilation provides 10-50x speedup for repeated patterns.
var regexCache sync.Map // string -> *regexp.Regexp
// CompileRegex compiles a regex pattern with caching for performance.
// Thread-safe: uses LoadOrStore to prevent race conditions.
// Returns the compiled regex or an error if the pattern is invalid.
func CompileRegex(pattern string) (*regexp.Regexp, error) {
// Check cache first
if cached, ok := regexCache.Load(pattern); ok {
return cached.(*regexp.Regexp), nil
}
// Compile regex
re, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
// Try to store - if another goroutine already stored it, use theirs
// This prevents race conditions where multiple goroutines compile the same pattern
actual, _ := regexCache.LoadOrStore(pattern, re)
return actual.(*regexp.Regexp), nil
}
// ClearRegexCache clears all cached compiled regular expressions.
// Useful for testing or when memory usage needs to be reduced.
func ClearRegexCache() {
regexCache.Range(func(key, value interface{}) bool {
regexCache.Delete(key)
return true
})
}