This commit is contained in:
2026-01-18 18:40:26 +00:00
commit 185e73da47
51 changed files with 14073 additions and 0 deletions
+27
View File
@@ -0,0 +1,27 @@
name: Test, build, release
on:
workflow_dispatch:
push:
paths-ignore:
- '**.md'
- '**/release.yaml'
branches:
- main
push:
tags:
- 'v*'
permissions:
id-token: write
contents: write
packages: write
jobs:
release:
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
with:
go-version: "1.24"
docker-enabled: false
rolling-release-tag: "v1"
secrets: inherit
+3
View File
@@ -0,0 +1,3 @@
TODO.md
bin/mcp-filepuff
mcp-filepuff
+122
View File
@@ -0,0 +1,122 @@
# yaml-language-server: $schema=https://goreleaser.com/static/schema.json
# vim: set ts=2 sw=2 tw=0 fo=cnqoj
version: 2
project_name: mcp-filepuff
before:
hooks:
- go mod tidy
- go generate ./...
builds:
- id: mcp-filepuff
main: ./cmd/mcp-filepuff
binary: mcp-filepuff
env:
- CGO_ENABLED=0
flags:
- -trimpath
ldflags:
- -s -w
- -X main.version={{.Version}}
- -X main.commit={{.Commit}}
- -X main.date={{.Date}}
goos:
- linux
- darwin
- windows
goarch:
- amd64
- arm64
archives:
- id: default
formats:
- tar.gz
name_template: >-
{{ .ProjectName }}_
{{- .Version }}_
{{- .Os }}_
{{- .Arch }}
files:
- LICENSE
- README.md
format_overrides:
- goos: windows
formats:
- zip
checksum:
name_template: 'checksums.txt'
snapshot:
version_template: "{{ incpatch .Version }}-next"
changelog:
sort: asc
filters:
exclude:
- '^docs:'
- '^test:'
- '^chore:'
- Merge pull request
- Merge branch
release:
github:
owner: lukaszraczylo
name: filepuff-mcp
draft: false
prerelease: auto
name_template: "v{{.Version}}"
header: |
## MCP Filepuff v{{.Version}}
AST-aware file operations and LSP integration for Claude Code.
### Installation
```bash
curl -sSL https://raw.githubusercontent.com/lukaszraczylo/filepuff-mcp/main/scripts/install.sh | bash
```
dockers_v2:
- images:
- "ghcr.io/lukaszraczylo/filepuff-mcp"
tags:
- "{{ .Version }}"
- "latest"
- "v1"
platforms:
- linux/amd64
- linux/arm64
dockerfile: Dockerfile.goreleaser
build_flag_templates:
- "--pull"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.source=https://github.com/lukaszraczylo/filepuff-mcp"
signs:
- cmd: cosign
signature: "${artifact}.sigstore.json"
args:
- sign-blob
- "--bundle=${signature}"
- "${artifact}"
- "--yes"
artifacts: checksum
output: true
docker_signs:
- cmd: cosign
artifacts: manifests
output: true
args:
- sign
- "${artifact}@${digest}"
- "--yes"
+13
View File
@@ -0,0 +1,13 @@
FROM alpine:3.21
ARG TARGETPLATFORM
RUN apk add --no-cache ca-certificates tzdata git ripgrep
COPY ${TARGETPLATFORM}/mcp-filepuff /usr/local/bin/mcp-filepuff
RUN chmod +x /usr/local/bin/mcp-filepuff
WORKDIR /workspace
ENTRYPOINT ["/usr/local/bin/mcp-filepuff"]
CMD ["-workspace", "/workspace"]
+80
View File
@@ -0,0 +1,80 @@
.PHONY: build test lint clean install run deps
# Binary name
BINARY_NAME=mcp-filepuff
# Build directory
BUILD_DIR=bin
# Main package
MAIN_PKG=./cmd/mcp-filepuff
# Go parameters
GOCMD=go
GOBUILD=$(GOCMD) build
GOTEST=$(GOCMD) test
GOGET=$(GOCMD) get
GOMOD=$(GOCMD) mod
GOFMT=$(GOCMD) fmt
# Build flags
LDFLAGS=-ldflags "-s -w" -buildvcs=false
# Default target
all: deps test build
# Install dependencies
deps:
$(GOMOD) download
$(GOMOD) tidy
# Build the binary
build:
mkdir -p $(BUILD_DIR)
$(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME) $(MAIN_PKG)
# Build for all platforms
build-all:
mkdir -p $(BUILD_DIR)
GOOS=darwin GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-amd64 $(MAIN_PKG)
GOOS=darwin GOARCH=arm64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 $(MAIN_PKG)
GOOS=linux GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 $(MAIN_PKG)
GOOS=linux GOARCH=arm64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 $(MAIN_PKG)
GOOS=windows GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe $(MAIN_PKG)
# Run tests
test:
$(GOTEST) -v -race -coverprofile=coverage.out ./...
# Run tests with short flag
test-short:
$(GOTEST) -v -short ./...
# Clean build artifacts
clean:
rm -rf $(BUILD_DIR)
rm -f coverage.out
# Install binary to GOPATH/bin
install: build
cp $(BUILD_DIR)/$(BINARY_NAME) $(GOPATH)/bin/
# Run the server (for development)
run: build
./$(BUILD_DIR)/$(BINARY_NAME) -log-level debug
# Run with specific workspace
run-workspace: build
./$(BUILD_DIR)/$(BINARY_NAME) -workspace $(WORKSPACE) -log-level debug
# Show help
help:
@echo "Available targets:"
@echo " deps - Download and tidy dependencies"
@echo " build - Build the binary"
@echo " build-all - Build for all platforms"
@echo " test - Run tests with coverage"
@echo " test-short - Run short tests"
@echo " lint - Run linters"
@echo " clean - Clean build artifacts"
@echo " install - Install binary to GOPATH/bin"
@echo " run - Build and run the server"
@echo " run-workspace - Run with specific workspace (WORKSPACE=/path)"
+572
View File
@@ -0,0 +1,572 @@
# mcp-filepuff
A Go-based MCP (Model Context Protocol) server for Claude Code providing intelligent file operations with fast search, AST-aware querying, LSP integration, and safe editing capabilities.
## Features
- **Fast Text Search**: Powered by ripgrep for blazing-fast code search with regex support
- **AST-Aware File Reading**: Read files with symbol extraction using Tree-sitter
- **Code Pattern Matching**: Query code using patterns with capture placeholders
- **LSP Integration**: Go-to-definition, find references, and symbol info via language servers
- **Safe Editing**: AST-aware file editing with syntax validation and preview
- **Multi-Language Support**: Go, TypeScript, JavaScript, Python, C, C++, HTML, Vue, React
- **Token Efficient**: Optimized for minimal token usage with symbols-only mode and output limiting
## Installation
### Binary Releases (Recommended)
Download pre-built binaries from the [releases page](https://github.com/lukaszraczylo/filepuff-mcp/releases):
```bash
# macOS (ARM64)
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-darwin-arm64.tar.gz | tar xz
sudo mv mcp-filepuff /usr/local/bin/
# macOS (AMD64)
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-darwin-amd64.tar.gz | tar xz
sudo mv mcp-filepuff /usr/local/bin/
# Linux (ARM64)
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-linux-arm64.tar.gz | tar xz
sudo mv mcp-filepuff /usr/local/bin/
# Linux (AMD64)
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-linux-amd64.tar.gz | tar xz
sudo mv mcp-filepuff /usr/local/bin/
# Windows (PowerShell)
Invoke-WebRequest -Uri "https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-windows-amd64.zip" -OutFile mcp-filepuff.zip
Expand-Archive mcp-filepuff.zip -DestinationPath .
Move-Item mcp-filepuff.exe C:\Windows\System32\
```
### Prerequisites
- [ripgrep](https://github.com/BurntSushi/ripgrep) (`rg`) installed and in PATH
### Optional Dependencies (for LSP features)
- `gopls` - Go language server
- `typescript-language-server` - TypeScript/JavaScript language server
- `pylsp` - Python language server
- `clangd` - C/C++ language server
### Build from Source
```bash
git clone https://github.com/lukaszraczylo/filepuff-mcp.git
cd filepuff-mcp
make build
```
The binary will be available at `bin/mcp-filepuff`.
### Install via Claude Code
After downloading or building the binary, configure it in Claude Code:
1. **Create or edit `~/.config/claude-code/claude_desktop_config.json`**:
```json
{
"mcpServers": {
"filepuff": {
"command": "/usr/local/bin/mcp-filepuff",
"args": ["-workspace", "/path/to/your/workspace"],
"env": {
"MCP_LOG_LEVEL": "info"
}
}
}
}
```
2. **Restart Claude Code** to load the MCP server
3. **Verify** by asking Claude: "Can you ping the filepuff server?"
See the [Claude Code MCP documentation](https://code.claude.com/docs/en/mcp) for more details.
## Usage
### Running the Server (Standalone)
```bash
./bin/mcp-filepuff -workspace /path/to/workspace
```
### Command Line Options
- `-workspace string`: Workspace root directory (default: current directory)
- `-log-level string`: Log level - debug, info, warn, error (default: "info")
- `-log-file string`: Log file path (default: stderr)
### Configuration
The server can be configured via:
1. **Environment Variables**:
- `MCP_WORKSPACE_ROOT`: Workspace root directory
- `MCP_LSP_TIMEOUT`: LSP timeout duration (e.g., "10m")
- `MCP_SEARCH_TIMEOUT`: Search timeout duration (e.g., "1m")
- `MCP_ENABLE_LSP`: Enable LSP features ("true"/"false")
- `MCP_FOLLOW_SYMLINKS`: Follow symbolic links ("true"/"false")
- `MCP_RESPECT_GITIGNORE`: Respect .gitignore files ("true"/"false")
2. **Config File**: Create `.mcp-filepuff.json` in the workspace root:
```json
{
"enable_lsp": true,
"follow_symlinks": true,
"respect_gitignore": true
}
```
### Claude Code Integration
To use mcp-filepuff with Claude Code, add it to your MCP server configuration:
1. **Global Configuration** (`~/.config/claude-code/mcp_servers.json`):
```json
{
"mcpServers": {
"filepuff": {
"command": "/path/to/mcp-filepuff",
"args": ["-workspace", "/path/to/your/workspace"]
}
}
}
```
2. **Project-specific Configuration** (`.claude/mcp_servers.json` in your project):
```json
{
"mcpServers": {
"filepuff": {
"command": "mcp-filepuff",
"args": ["-workspace", "."]
}
}
}
```
After configuration, Claude Code will have access to all mcp-filepuff tools for enhanced file operations.
### Making Claude Code Prefer Filepuff Tools
By default, Claude Code uses its built-in file operation tools. To make it prefer filepuff's enhanced tools, add instructions to your `CLAUDE.md` file:
**Global Configuration** (`~/.claude/CLAUDE.md`):
```markdown
# MCP Tool Preferences
When performing file operations, prefer filepuff MCP tools over built-in equivalents:
| Operation | Use This | Instead Of |
|-----------|----------|------------|
| Read files | `mcp__filepuff__file_read` | Read |
| Search content | `mcp__filepuff__file_search` | Grep |
| AST pattern search | `mcp__filepuff__ast_query` | Grep/Glob |
| Edit files | `mcp__filepuff__edit_preview` + `mcp__filepuff__edit_apply` | Edit |
| Find definitions | `mcp__filepuff__find_definition` | Grep |
| Find references | `mcp__filepuff__find_references` | Grep |
| Symbol info | `mcp__filepuff__symbol_at` | - |
Benefits of filepuff tools:
- AST-aware operations that understand code structure
- LSP integration for accurate symbol navigation
- Syntax validation before applying edits
```
You can also place this in a project-specific `CLAUDE.md` or `.claude/CLAUDE.md` file.
**Optional: Restrict Built-in Tools**
To enforce filepuff usage, add permission restrictions in `.claude/settings.json`:
```json
{
"permissions": {
"deny": ["Read", "Edit", "Grep"]
}
}
```
## Available Tools
### `ping`
Health check tool to verify the server is running.
**Returns**: "pong"
---
### `file_search`
Search for text patterns in files using ripgrep.
**Parameters**:
- `pattern` (required): The search pattern (regex by default)
- `paths`: Paths to search in (defaults to workspace root)
- `file_types`: File types to search (e.g., ["go", "ts", "py"])
- `ignore_case`: Case insensitive search
- `regex`: Treat pattern as regex (default: true)
- `context_lines`: Number of context lines around matches (default: 2)
- `max_results`: Maximum number of results to return
---
### `file_read`
Read a file's contents with optional line range and AST symbol summary. Supports token-efficient modes for AI assistants.
**Parameters**:
- `path` (required): Path to the file to read
- `line_start`: Starting line number (1-indexed)
- `line_end`: Ending line number (inclusive)
- `include_ast`: Include AST symbol summary (functions, classes, types, etc.)
- `symbols_only`: **[Token Efficient]** Return only symbol summary without file content. Requires `include_ast=true`. Reduces token usage by ~90-98%.
- `max_lines`: **[Token Efficient]** Maximum number of lines to return. Useful for large files where you only need a preview.
**Example Output with AST**:
```
**server.go** (245 lines, go)
Symbols:
func NewServer L12
func (Server).Start L45
struct Server L5
type Config L150
---
12│ func NewServer(config Config) *Server {
13│ return &Server{config: config}
14│ }
```
**Token-Efficient Example (symbols_only)**:
```json
{"path": "server.go", "include_ast": true, "symbols_only": true}
```
Returns only the symbol summary (~500 tokens instead of ~8,000 tokens for the full file):
```
**server.go** (245 lines, go)
Symbols:
func NewServer L12
func (Server).Start L45
struct Server L5
type Config L150
```
**Token-Efficient Example (max_lines)**:
```json
{"path": "server.go", "max_lines": 50}
```
Returns first 50 lines with a truncation notice if the file is longer.
---
### `ast_query`
Search for AST patterns in code files using structural pattern matching.
**Parameters**:
- `pattern` (required): Code pattern with placeholders
- `$NAME` - capture single node
- `$$$ARGS` - capture multiple nodes
- `$_` - wildcard (match but don't capture)
- `language` (required): Target language (go, typescript, javascript, python, c, cpp)
- `paths`: Paths to search in
- `name_matches`: Regex pattern to filter by name
- `name_exact`: Exact name to match
- `kind_in`: Node types to match (e.g., function_declaration)
- `max_results`: Maximum number of results (default: 100)
**Examples**:
```json
// Find all Go functions returning error
{"pattern": "func $NAME($$$ARGS) error", "language": "go"}
// Find all Python classes
{"pattern": "class $NAME: $$$BODY", "language": "python"}
// Find React components (functions starting with uppercase)
{"pattern": "function $NAME($PROPS) { $$$BODY }", "language": "javascript", "name_matches": "^[A-Z]"}
```
---
### `symbol_at`
Get information about the symbol at a specific position. Uses LSP when available, falls back to AST.
**Parameters**:
- `file` (required): Path to the file
- `line` (required): Line number (1-indexed)
- `column` (required): Column number (1-indexed)
---
### `find_definition`
Find the definition of the symbol at a specific position.
**Parameters**:
- `file` (required): Path to the file
- `line` (required): Line number (1-indexed)
- `column` (required): Column number (1-indexed)
---
### `find_references`
Find all references to the symbol at a specific position.
**Parameters**:
- `file` (required): Path to the file
- `line` (required): Line number (1-indexed)
- `column` (required): Column number (1-indexed)
- `include_declaration`: Include the declaration in results (default: true)
---
### `edit_preview`
Preview an edit without applying it. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++), and text-based editing for other files (Markdown, JSON, YAML, config files, etc.).
**Parameters**:
- `file` (required): Path to the file to edit
- `operation` (required): Edit operation (replace, insert_before, insert_after, delete)
- `new_content`: New content (required for replace/insert operations)
**AST-mode selectors** (for code files):
- `selector_kind`: Node type to match (e.g., function_declaration)
- `selector_name`: Name of the symbol to match
**Shared selectors**:
- `selector_line`: Line number (1-indexed). For AST mode: narrows search. For text mode: start of line range.
- `selector_index`: Index of the match to use if multiple matches found (default: 0)
**Text-mode selectors** (for non-code files or explicit text matching):
- `selector_line_end`: End line number for range selection
- `selector_text`: Exact text to match (must be unique or use selector_index)
- `selector_pattern`: Regex pattern to match
---
### `edit_apply`
Apply an edit to a file. Uses AST-aware editing for code files with syntax validation, and text-based editing for other files.
**Parameters**: Same as `edit_preview`
**Example (AST mode - Go file)**:
```json
{
"file": "server.go",
"operation": "replace",
"selector_kind": "function_declaration",
"selector_name": "Hello",
"new_content": "func Hello() {\n\tprintln(\"New Hello\")\n}"
}
```
**Example (Text mode - Markdown file)**:
```json
{
"file": "README.md",
"operation": "replace",
"selector_text": "## Installation",
"new_content": "## Getting Started"
}
```
**Example (Text mode - JSON with regex)**:
```json
{
"file": "package.json",
"operation": "replace",
"selector_pattern": "\"version\":\\s*\"[^\"]+\"",
"new_content": "\"version\": \"2.0.0\""
}
```
**Example (Text mode - Line range)**:
```json
{
"file": "config.yaml",
"operation": "replace",
"selector_line": 5,
"selector_line_end": 10,
"new_content": "database:\n host: production.db.example.com\n port: 5432"
}
```
## Supported Languages
| Language | Extensions | Search | AST | LSP | Edit |
|----------|-----------|--------|-----|-----|------|
| Go | .go | Yes | Yes | gopls | Yes |
| TypeScript | .ts, .tsx | Yes | Yes | typescript-language-server | Yes |
| JavaScript | .js, .jsx, .mjs, .cjs | Yes | Yes | typescript-language-server | Yes |
| Python | .py, .pyw | Yes | Yes | pylsp | Yes |
| C | .c, .h | Yes | Yes | clangd | Yes |
| C++ | .cpp, .cc, .cxx, .hpp, .hxx | Yes | Yes | clangd | Yes |
| HTML | .html, .htm | Yes | Yes | - | Yes |
| Vue | .vue | Yes | Yes* | - | Yes |
| React | .jsx, .tsx | Yes | Yes | typescript-language-server | Yes |
\* Vue uses HTML parser for template sections
## Development
### Build
```bash
make build
```
### Run Tests
```bash
make test
```
### Lint
```bash
make lint
```
### Clean
```bash
make clean
```
## Project Structure
```
.
├── cmd/
│ └── mcp-filepuff/ # Main entry point
├── internal/
│ ├── config/ # Configuration management
│ ├── edit/ # AST-aware editing engine
│ ├── lsp/ # LSP client and manager
│ ├── parser/ # Tree-sitter integration
│ ├── query/ # AST pattern matching
│ ├── search/ # Ripgrep wrapper
│ └── server/ # MCP server implementation
├── pkg/
│ └── protocol/ # Shared types
├── .github/
│ └── workflows/ # CI configuration
├── Makefile # Build automation
├── .goreleaser.yaml # Release configuration
└── TODO.md # Implementation roadmap
```
## Architecture
```
┌─────────────────────────────────────────────────────────┐
│ MCP Server │
├─────────────────────────────────────────────────────────┤
│ Tools: file_search, file_read, ast_query, symbol_at, │
│ find_definition, find_references, │
│ edit_preview, edit_apply, ping │
├─────────────────────────────────────────────────────────┤
│ Core Engines │
├───────────┬─────────────┬────────────┬─────────────────┤
│ Search │ Parser │ LSP │ Edit │
│ (ripgrep) │(tree-sitter)│ Manager │ Engine │
└───────────┴─────────────┴────────────┴─────────────────┘
```
## Troubleshooting
### Common Issues
#### "ripgrep not found" Error
The `file_search` tool requires ripgrep (`rg`) to be installed and in your PATH.
**Solution**: Install ripgrep:
```bash
# macOS
brew install ripgrep
# Ubuntu/Debian
sudo apt install ripgrep
# Windows (with Chocolatey)
choco install ripgrep
```
#### LSP Features Not Working
LSP features (go-to-definition, find-references, symbol-at) require language servers to be installed.
**Solution**: Install the appropriate language server:
```bash
# Go
go install golang.org/x/tools/gopls@latest
# TypeScript/JavaScript
npm install -g typescript-language-server typescript
# Python
pip install python-lsp-server
# C/C++
# macOS: brew install llvm
# Ubuntu: sudo apt install clangd
```
#### AST Parsing Fails for Valid Code
If AST parsing fails for code that compiles correctly, it may be a Tree-sitter grammar limitation.
**Solution**:
- Ensure the file has the correct extension for its language
- Check for unusual syntax that may not be supported by the Tree-sitter grammar
- Try using the `file_search` tool instead for text-based operations
#### Edit Operations Fail with "syntax error"
The edit engine validates syntax before and after edits.
**Solution**:
- Ensure `new_content` is syntactically valid for the target language
- Use `edit_preview` first to see the proposed changes
- Check that the selector matches exactly one node
#### Timeout Errors
Long-running operations may timeout.
**Solution**: Configure timeout values via environment variables:
```bash
export MCP_LSP_TIMEOUT="10m" # LSP operations (default: 5m)
export MCP_SEARCH_TIMEOUT="2m" # Search operations (default: 30s)
```
#### Permission Denied Errors
The server needs read/write access to workspace files.
**Solution**:
- Ensure the user running the server has appropriate file permissions
- Check that the workspace path is correct and accessible
- On macOS, grant terminal/IDE full disk access if needed
### Debug Logging
Enable debug logging to troubleshoot issues:
```bash
./bin/mcp-filepuff -workspace /path/to/workspace -log-level debug -log-file /tmp/mcp-filepuff.log
```
### Verifying Installation
Use the `ping` tool to verify the server is running correctly:
```json
{"tool": "ping"}
```
Expected response: `"pong"`
## License
MIT License
+84
View File
@@ -0,0 +1,84 @@
// Package main is the entry point for the MCP file operations server.
package main
import (
"context"
"flag"
"log/slog"
"os"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"github.com/lukaszraczylo/mcp-filepuff/internal/server"
)
func main() {
// Parse command line flags
var (
workspaceRoot = flag.String("workspace", "", "Workspace root directory (default: current directory)")
logLevel = flag.String("log-level", "info", "Log level (debug, info, warn, error)")
logFile = flag.String("log-file", "", "Log file path (default: stderr)")
)
flag.Parse()
// Set up logging
logger := setupLogger(*logLevel, *logFile)
// Load configuration
cfg, err := config.Load(*workspaceRoot)
if err != nil {
logger.Error("failed to load configuration", "error", err)
os.Exit(1)
}
logger.Info("configuration loaded",
"workspace_root", cfg.WorkspaceRoot,
"lsp_enabled", cfg.EnableLSP,
)
// Create and run server
srv, err := server.New(cfg, logger)
if err != nil {
logger.Error("failed to create server", "error", err)
os.Exit(1)
}
ctx := context.Background()
if err := srv.Run(ctx); err != nil {
logger.Error("server error", "error", err)
os.Exit(1)
}
}
func setupLogger(level string, logFile string) *slog.Logger {
var logLevel slog.Level
switch level {
case "debug":
logLevel = slog.LevelDebug
case "warn":
logLevel = slog.LevelWarn
case "error":
logLevel = slog.LevelError
default:
logLevel = slog.LevelInfo
}
opts := &slog.HandlerOptions{
Level: logLevel,
}
var handler slog.Handler
if logFile != "" {
f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
// Fallback to stderr
handler = slog.NewJSONHandler(os.Stderr, opts)
} else {
handler = slog.NewJSONHandler(f, opts)
}
} else {
// Use stderr for MCP servers (stdout is for protocol messages)
handler = slog.NewJSONHandler(os.Stderr, opts)
}
return slog.New(handler)
}
+1479
View File
File diff suppressed because it is too large Load Diff
+24
View File
@@ -0,0 +1,24 @@
module github.com/lukaszraczylo/mcp-filepuff
go 1.25.5
require (
github.com/cespare/xxhash/v2 v2.3.0
github.com/goccy/go-json v0.10.5
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/mark3labs/mcp-go v0.43.2
github.com/sergi/go-diff v1.4.0
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/mailru/easyjson v0.9.1 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
)
+57
View File
@@ -0,0 +1,57 @@
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8=
github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I=
github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4=
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+174
View File
@@ -0,0 +1,174 @@
// Package config provides configuration management for the MCP file operations server.
package config
import (
"os"
"path/filepath"
"strings"
"time"
json "github.com/goccy/go-json"
)
// Config holds all configuration options for the MCP server.
type Config struct {
Formatters map[string]string `json:"formatters"`
WorkspaceRoot string `json:"workspace_root"`
LSPTimeout time.Duration `json:"lsp_timeout"`
SearchTimeout time.Duration `json:"search_timeout"`
MaxFileSize int64 `json:"max_file_size"`
MaxSearchResults int `json:"max_search_results"`
MaxEditSize int64 `json:"max_edit_size"`
EnableLSP bool `json:"enable_lsp"`
FollowSymlinks bool `json:"follow_symlinks"`
RespectGitignore bool `json:"respect_gitignore"`
}
// Default values for configuration.
const (
DefaultLSPTimeout = 5 * time.Minute
DefaultSearchTimeout = 30 * time.Second
DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB
DefaultMaxSearchResults = 1000
DefaultMaxEditSize = 100 * 1024 // 100 KB
)
// Default returns a Config with default values.
func Default() *Config {
return &Config{
WorkspaceRoot: ".",
LSPTimeout: DefaultLSPTimeout,
SearchTimeout: DefaultSearchTimeout,
MaxFileSize: DefaultMaxFileSize,
MaxSearchResults: DefaultMaxSearchResults,
MaxEditSize: DefaultMaxEditSize,
EnableLSP: true,
Formatters: make(map[string]string),
FollowSymlinks: true,
RespectGitignore: true,
}
}
// Load loads configuration from environment variables and optional config file.
// Priority: CLI flags > environment variables > config file > defaults.
func Load(workspaceRoot string) (*Config, error) {
cfg := Default()
// Set workspace root
if workspaceRoot != "" {
absPath, err := filepath.Abs(workspaceRoot)
if err != nil {
return nil, err
}
cfg.WorkspaceRoot = absPath
} else if cwd, err := os.Getwd(); err == nil {
cfg.WorkspaceRoot = cwd
}
// Try to load from config file in workspace root
configPath := filepath.Join(cfg.WorkspaceRoot, ".mcp-filepuff.json")
if data, err := os.ReadFile(configPath); err == nil {
if err := json.Unmarshal(data, cfg); err != nil {
return nil, err
}
}
// Override from environment variables
cfg.loadFromEnv()
return cfg, nil
}
func (c *Config) loadFromEnv() {
if v := os.Getenv("MCP_WORKSPACE_ROOT"); v != "" {
if absPath, err := filepath.Abs(v); err == nil {
c.WorkspaceRoot = absPath
}
}
if v := os.Getenv("MCP_LSP_TIMEOUT"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
c.LSPTimeout = d
}
}
if v := os.Getenv("MCP_SEARCH_TIMEOUT"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
c.SearchTimeout = d
}
}
if v := os.Getenv("MCP_ENABLE_LSP"); v == "false" || v == "0" {
c.EnableLSP = false
}
if v := os.Getenv("MCP_FOLLOW_SYMLINKS"); v == "false" || v == "0" {
c.FollowSymlinks = false
}
if v := os.Getenv("MCP_RESPECT_GITIGNORE"); v == "false" || v == "0" {
c.RespectGitignore = false
}
}
// IsPathAllowed checks if a path is within the workspace root.
// It resolves symlinks to prevent path traversal attacks.
func (c *Config) IsPathAllowed(path string) bool {
// Get absolute path of the target
absPath, err := filepath.Abs(path)
if err != nil {
return false
}
// Get absolute path of workspace root
absRoot, err := filepath.Abs(c.WorkspaceRoot)
if err != nil {
return false
}
// Always try to resolve workspace root symlinks for consistent comparison
evalRoot, evalErr := filepath.EvalSymlinks(absRoot)
if evalErr == nil {
absRoot = evalRoot
}
// For the target path, try to resolve symlinks
evalPath, evalErr := filepath.EvalSymlinks(absPath)
if evalErr == nil {
// File exists and was resolved
absPath = evalPath
} else {
// File doesn't exist - resolve parent directories to match workspace root resolution
// Walk up the tree until we find an existing directory
dir := filepath.Dir(absPath)
remaining := filepath.Base(absPath)
for dir != "." && dir != "/" && dir != absPath {
evalDir, evalErr := filepath.EvalSymlinks(dir)
if evalErr == nil {
// Found an existing directory, reconstruct the path
absPath = filepath.Join(evalDir, remaining)
break
}
// Move up one level
newDir := filepath.Dir(dir)
if newDir == dir {
// Reached the root without finding an existing directory
break
}
remaining = filepath.Join(filepath.Base(dir), remaining)
dir = newDir
}
}
// Compute relative path
rel, err := filepath.Rel(absRoot, absPath)
if err != nil {
return false
}
// Check if the path is within workspace (doesn't start with ..)
// This prevents both "../" attacks and symlink bypasses
// Also reject empty relative path (which means it's the workspace root itself)
return rel != "." && !strings.HasPrefix(rel, "..")
}
+184
View File
@@ -0,0 +1,184 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestDefault(t *testing.T) {
cfg := Default()
if cfg.WorkspaceRoot != "." {
t.Errorf("expected default workspace root '.', got %q", cfg.WorkspaceRoot)
}
if cfg.LSPTimeout != DefaultLSPTimeout {
t.Errorf("expected default LSP timeout %v, got %v", DefaultLSPTimeout, cfg.LSPTimeout)
}
if cfg.SearchTimeout != DefaultSearchTimeout {
t.Errorf("expected default search timeout %v, got %v", DefaultSearchTimeout, cfg.SearchTimeout)
}
if cfg.MaxFileSize != DefaultMaxFileSize {
t.Errorf("expected default max file size %d, got %d", DefaultMaxFileSize, cfg.MaxFileSize)
}
if cfg.MaxSearchResults != DefaultMaxSearchResults {
t.Errorf("expected default max search results %d, got %d", DefaultMaxSearchResults, cfg.MaxSearchResults)
}
if cfg.MaxEditSize != DefaultMaxEditSize {
t.Errorf("expected default max edit size %d, got %d", DefaultMaxEditSize, cfg.MaxEditSize)
}
if !cfg.EnableLSP {
t.Error("expected EnableLSP to be true by default")
}
if !cfg.FollowSymlinks {
t.Error("expected FollowSymlinks to be true by default")
}
if !cfg.RespectGitignore {
t.Error("expected RespectGitignore to be true by default")
}
}
func TestLoad(t *testing.T) {
// Create a temporary directory for workspace
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg, err := Load(tmpDir)
if err != nil {
t.Fatalf("Load failed: %v", err)
}
absPath, _ := filepath.Abs(tmpDir)
if cfg.WorkspaceRoot != absPath {
t.Errorf("expected workspace root %q, got %q", absPath, cfg.WorkspaceRoot)
}
}
func TestLoadFromEnv(t *testing.T) {
// Save original env values
origLSPTimeout := os.Getenv("MCP_LSP_TIMEOUT")
origSearchTimeout := os.Getenv("MCP_SEARCH_TIMEOUT")
origEnableLSP := os.Getenv("MCP_ENABLE_LSP")
origFollowSymlinks := os.Getenv("MCP_FOLLOW_SYMLINKS")
origRespectGitignore := os.Getenv("MCP_RESPECT_GITIGNORE")
// Restore env after test
t.Cleanup(func() {
_ = os.Setenv("MCP_LSP_TIMEOUT", origLSPTimeout)
_ = os.Setenv("MCP_SEARCH_TIMEOUT", origSearchTimeout)
_ = os.Setenv("MCP_ENABLE_LSP", origEnableLSP)
_ = os.Setenv("MCP_FOLLOW_SYMLINKS", origFollowSymlinks)
_ = os.Setenv("MCP_RESPECT_GITIGNORE", origRespectGitignore)
})
// Set test env values
_ = os.Setenv("MCP_LSP_TIMEOUT", "10m")
_ = os.Setenv("MCP_SEARCH_TIMEOUT", "1m")
_ = os.Setenv("MCP_ENABLE_LSP", "false")
_ = os.Setenv("MCP_FOLLOW_SYMLINKS", "0")
_ = os.Setenv("MCP_RESPECT_GITIGNORE", "false")
cfg := Default()
cfg.loadFromEnv()
if cfg.LSPTimeout != 10*time.Minute {
t.Errorf("expected LSP timeout 10m, got %v", cfg.LSPTimeout)
}
if cfg.SearchTimeout != 1*time.Minute {
t.Errorf("expected search timeout 1m, got %v", cfg.SearchTimeout)
}
if cfg.EnableLSP {
t.Error("expected EnableLSP to be false")
}
if cfg.FollowSymlinks {
t.Error("expected FollowSymlinks to be false")
}
if cfg.RespectGitignore {
t.Error("expected RespectGitignore to be false")
}
}
func TestIsPathAllowed(t *testing.T) {
// Create a temporary directory
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := Default()
cfg.WorkspaceRoot = tmpDir
tests := []struct {
name string
path string
allowed bool
}{
{
name: "file in workspace",
path: filepath.Join(tmpDir, "test.go"),
allowed: true,
},
{
name: "nested file in workspace",
path: filepath.Join(tmpDir, "subdir", "test.go"),
allowed: true,
},
{
name: "path outside workspace",
path: "/etc/passwd",
allowed: false,
},
{
name: "relative path traversal",
path: filepath.Join(tmpDir, "..", "outside.txt"),
allowed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := cfg.IsPathAllowed(tt.path)
if result != tt.allowed {
t.Errorf("IsPathAllowed(%q) = %v, want %v", tt.path, result, tt.allowed)
}
})
}
}
func TestLoadWithConfigFile(t *testing.T) {
// Create a temporary directory
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Write config file
configPath := filepath.Join(tmpDir, ".mcp-filepuff.json")
configContent := `{
"enable_lsp": false,
"follow_symlinks": false
}`
err = os.WriteFile(configPath, []byte(configContent), 0600)
if err != nil {
t.Fatalf("failed to write config file: %v", err)
}
var cfg *Config
cfg, err = Load(tmpDir)
if err != nil {
t.Fatalf("Load failed: %v", err)
}
if cfg.EnableLSP {
t.Error("expected EnableLSP to be false from config file")
}
if cfg.FollowSymlinks {
t.Error("expected FollowSymlinks to be false from config file")
}
}
+141
View File
@@ -0,0 +1,141 @@
package config
import (
"os"
"path/filepath"
"testing"
)
// TestIsPathAllowed_SymlinkSecurity tests the symlink security fix.
func TestIsPathAllowed_SymlinkSecurity(t *testing.T) {
// Create a temporary workspace
tmpDir := t.TempDir()
workspace := filepath.Join(tmpDir, "workspace")
outside := filepath.Join(tmpDir, "outside")
if err := os.MkdirAll(workspace, 0700); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(outside, 0700); err != nil {
t.Fatal(err)
}
// Create a file outside the workspace
outsideFile := filepath.Join(outside, "secret.txt")
if err := os.WriteFile(outsideFile, []byte("secret data"), 0600); err != nil {
t.Fatal(err)
}
cfg := &Config{
WorkspaceRoot: workspace,
}
tests := []struct {
setup func() string
name string
expected bool
}{
{
name: "regular file inside workspace",
setup: func() string {
return filepath.Join(workspace, "file.txt")
},
expected: true,
},
{
name: "file with parent directory traversal",
setup: func() string {
return filepath.Join(workspace, "../outside/secret.txt")
},
expected: false,
},
{
name: "symlink pointing outside workspace",
setup: func() string {
symlink := filepath.Join(workspace, "link.txt")
_ = os.Symlink(outsideFile, symlink)
return symlink
},
expected: false,
},
{
name: "symlink pointing inside workspace",
setup: func() string {
inside := filepath.Join(workspace, "inside.txt")
_ = os.WriteFile(inside, []byte("ok"), 0600)
symlink := filepath.Join(workspace, "link_inside.txt")
_ = os.Symlink(inside, symlink)
return symlink
},
expected: true,
},
{
name: "dotfile inside workspace",
setup: func() string {
return filepath.Join(workspace, ".gitignore")
},
expected: true,
},
{
name: "hidden directory inside workspace",
setup: func() string {
return filepath.Join(workspace, ".git/config")
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path := tt.setup()
result := cfg.IsPathAllowed(path)
if result != tt.expected {
t.Errorf("IsPathAllowed(%q) = %v, want %v", path, result, tt.expected)
}
})
}
}
// TestIsPathAllowed_BasicCases tests basic path validation.
func TestIsPathAllowed_BasicCases(t *testing.T) {
tmpDir := t.TempDir()
cfg := &Config{
WorkspaceRoot: tmpDir,
}
tests := []struct {
name string
path string
expected bool
}{
{
name: "path inside workspace",
path: filepath.Join(tmpDir, "file.txt"),
expected: true,
},
{
name: "path outside workspace",
path: "/etc/passwd",
expected: false,
},
{
name: "parent directory reference",
path: filepath.Join(tmpDir, "../../../etc/passwd"),
expected: false,
},
{
name: "workspace root itself",
path: tmpDir,
expected: false, // Empty relative path
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := cfg.IsPathAllowed(tt.path)
if result != tt.expected {
t.Errorf("IsPathAllowed(%q) = %v, want %v", tt.path, result, tt.expected)
}
})
}
}
+203
View File
@@ -0,0 +1,203 @@
package edit
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
)
// TestConcurrentEditLocking tests that concurrent edits to the same file are properly serialized.
func TestConcurrentEditLocking(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.go")
// Create initial file
initialContent := `package main
func main() {
println("hello")
}
`
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
t.Fatal(err)
}
registry := parser.NewRegistry()
engine := NewEngine(registry)
// Run 10 concurrent edits
const numEdits = 10
var wg sync.WaitGroup
wg.Add(numEdits)
errors := make(chan error, numEdits)
for i := 0; i < numEdits; i++ {
i := i
go func() {
defer wg.Done()
edit := &ASTEdit{
File: testFile,
Operation: EditReplace,
Selector: ASTSelector{
Kind: "function_declaration",
Name: "main",
},
NewContent: `func main() {
println("edit ` + string(rune(i)) + `")
}`,
}
ctx := context.Background()
result, err := engine.Apply(ctx, edit)
if err != nil {
errors <- err
return
}
if !result.Success {
errors <- err
return
}
}()
}
wg.Wait()
close(errors)
// Check for errors
for err := range errors {
if err != nil {
t.Errorf("Concurrent edit failed: %v", err)
}
}
// Verify file wasn't corrupted
finalContent, err := os.ReadFile(testFile)
if err != nil {
t.Fatal(err)
}
// Parse to ensure it's still valid Go
_, err = registry.Parse(context.Background(), testFile, finalContent)
if err != nil {
t.Errorf("File corrupted after concurrent edits: %v\nContent:\n%s", err, string(finalContent))
}
}
// TestConcurrentEditDifferentFiles tests that concurrent edits to different files don't block each other.
func TestConcurrentEditDifferentFiles(t *testing.T) {
tmpDir := t.TempDir()
registry := parser.NewRegistry()
engine := NewEngine(registry)
const numFiles = 5
var wg sync.WaitGroup
wg.Add(numFiles)
startBarrier := make(chan struct{})
for i := 0; i < numFiles; i++ {
i := i
testFile := filepath.Join(tmpDir, fmt.Sprintf("test%d.go", i))
// Create initial file
initialContent := `package main
func test() {
println("initial")
}
`
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
t.Fatal(err)
}
go func() {
defer wg.Done()
// Wait for all goroutines to be ready
<-startBarrier
edit := &ASTEdit{
File: testFile,
Operation: EditReplace,
Selector: ASTSelector{
Kind: "function_declaration",
Name: "test",
},
NewContent: `func test() {
println("modified")
}`,
}
ctx := context.Background()
result, err := engine.Apply(ctx, edit)
if err != nil {
t.Errorf("Edit failed for %s: %v", testFile, err)
return
}
if !result.Success {
t.Errorf("Edit unsuccessful for %s: %s", testFile, result.Error)
}
}()
}
// Release all goroutines simultaneously
close(startBarrier)
wg.Wait()
}
// TestFileLockRelease tests that file locks are properly released after edits.
func TestFileLockRelease(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.go")
initialContent := `package main
func test() {}
`
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
t.Fatal(err)
}
registry := parser.NewRegistry()
engine := NewEngine(registry)
edit := &ASTEdit{
File: testFile,
Operation: EditReplace,
Selector: ASTSelector{
Kind: "function_declaration",
Name: "test",
},
NewContent: `func test() { println("updated") }`,
}
// First edit
ctx := context.Background()
result1, err := engine.Apply(ctx, edit)
if err != nil {
t.Fatal(err)
}
if !result1.Success {
t.Fatalf("First edit failed: %s", result1.Error)
}
// Second edit should succeed (lock was released)
result2, err := engine.Apply(ctx, edit)
if err != nil {
t.Fatal(err)
}
if !result2.Success {
t.Fatalf("Second edit failed (lock not released?): %s", result2.Error)
}
}
+757
View File
@@ -0,0 +1,757 @@
// Package edit provides AST-aware file editing capabilities.
package edit
import (
"bytes"
"context"
"fmt"
"os"
"regexp"
"strings"
"sync"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
"github.com/sergi/go-diff/diffmatchpatch"
sitter "github.com/smacker/go-tree-sitter"
)
// Global regex cache for compiled patterns (thread-safe)
var regexCache sync.Map // string -> *regexp.Regexp
// compileRegex compiles a regex pattern with caching for performance.
func compileRegex(pattern string) (*regexp.Regexp, error) {
// Check cache first
if cached, ok := regexCache.Load(pattern); ok {
return cached.(*regexp.Regexp), nil
}
// Compile and cache
re, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
regexCache.Store(pattern, re)
return re, nil
}
// EditOperation defines the type of edit operation.
type EditOperation string
const (
EditReplace EditOperation = "replace"
EditInsertBefore EditOperation = "insert_before"
EditInsertAfter EditOperation = "insert_after"
EditDelete EditOperation = "delete"
)
// ASTEdit represents an AST-aware edit request.
type ASTEdit struct {
File string `json:"file"`
Operation EditOperation `json:"operation"`
NewContent string `json:"new_content,omitempty"`
Selector ASTSelector `json:"selector"`
}
// ASTSelector specifies how to find the target node.
type ASTSelector struct {
Kind string `json:"kind,omitempty"`
Name string `json:"name,omitempty"`
Pattern string `json:"pattern,omitempty"`
Text string `json:"text,omitempty"`
TextPattern string `json:"text_pattern,omitempty"`
AtLine int `json:"at_line,omitempty"`
Index int `json:"index,omitempty"`
LineEnd int `json:"line_end,omitempty"`
}
// EditResult contains the result of an edit operation.
type EditResult struct {
Diff string `json:"diff,omitempty"`
OriginalContent string `json:"original_content,omitempty"`
NewContent string `json:"new_content,omitempty"`
Error string `json:"error,omitempty"`
Success bool `json:"success"`
Applied bool `json:"applied"`
}
// Engine performs AST-aware edits.
type Engine struct {
registry *parser.Registry
fileLocks sync.Map // map[string]*sync.Mutex for per-file locking
}
// NewEngine creates a new edit engine.
func NewEngine(registry *parser.Registry) *Engine {
return &Engine{
registry: registry,
fileLocks: sync.Map{},
}
}
// lockFile acquires a lock for the specified file and returns an unlock function.
// This prevents concurrent edits to the same file which could cause corruption.
func (e *Engine) lockFile(filePath string) func() {
// Get or create mutex for this file
actual, _ := e.fileLocks.LoadOrStore(filePath, &sync.Mutex{})
mu := actual.(*sync.Mutex)
mu.Lock()
return mu.Unlock
}
// Preview generates a preview of an edit without applying it.
func (e *Engine) Preview(ctx context.Context, edit *ASTEdit) (*EditResult, error) {
return e.performEdit(ctx, edit, false)
}
// Apply performs an edit and writes the result to disk.
// Uses file locking to prevent concurrent edits to the same file.
func (e *Engine) Apply(ctx context.Context, edit *ASTEdit) (*EditResult, error) {
unlock := e.lockFile(edit.File)
defer unlock()
return e.performEdit(ctx, edit, true)
}
// performEdit executes an edit operation.
func (e *Engine) performEdit(ctx context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
// Determine if we should use text mode
useTextMode := e.shouldUseTextMode(edit)
if useTextMode {
return e.performTextEdit(ctx, edit, apply)
}
return e.performASTEdit(ctx, edit, apply)
}
// shouldUseTextMode determines if text-based editing should be used.
func (e *Engine) shouldUseTextMode(edit *ASTEdit) bool {
// Use text mode if text-specific selectors are provided
if edit.Selector.Text != "" || edit.Selector.TextPattern != "" {
return true
}
// Use text mode if line range is specified without AST selectors
if edit.Selector.AtLine > 0 && edit.Selector.LineEnd > 0 &&
edit.Selector.Kind == "" && edit.Selector.Name == "" && edit.Selector.Pattern == "" {
return true
}
// Use text mode if language is not supported for AST
lang := protocol.DetectLanguage(edit.File)
return lang == protocol.LangUnknown
}
// performASTEdit executes an AST-aware edit operation.
func (e *Engine) performASTEdit(ctx context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
// Validate operation
if err := e.validateASTEdit(edit); err != nil {
return &EditResult{Success: false, Error: err.Error()}, nil
}
// Read file
content, err := os.ReadFile(edit.File)
if err != nil {
structuredErr := errors.NewFileNotReadableError(edit.File, err)
return &EditResult{Success: false, Error: structuredErr.Error()}, nil
}
// Parse file
parseResult, err := e.registry.Parse(ctx, edit.File, content)
if err != nil {
return &EditResult{Success: false, Error: err.Error()}, nil
}
// Find target node
node, err := e.resolveSelector(edit.Selector, parseResult.Tree, content)
if err != nil {
return &EditResult{Success: false, Error: err.Error()}, nil
}
// Apply edit
newContent, err := e.applyEdit(edit, node, content)
if err != nil {
return &EditResult{Success: false, Error: err.Error()}, nil
}
// Validate new content (re-parse)
_, err = e.registry.Parse(ctx, edit.File, newContent)
if err != nil {
structuredErr := errors.NewEditValidationError(edit.File, err)
return &EditResult{
Success: false,
Error: structuredErr.Error(),
}, nil
}
// Generate diff
diff := generateDiff(string(content), string(newContent), edit.File)
result := &EditResult{
Success: true,
Diff: diff,
OriginalContent: string(content),
NewContent: string(newContent),
Applied: false,
}
// Apply changes if requested
if apply {
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
structuredErr := errors.NewFileNotWritableError(edit.File, err)
return &EditResult{
Success: false,
Error: structuredErr.Error(),
}, nil
}
result.Applied = true
}
return result, nil
}
// performTextEdit executes a text-based edit operation for non-AST files.
func (e *Engine) performTextEdit(_ context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
// Validate operation
if err := e.validateTextEdit(edit); err != nil {
return &EditResult{Success: false, Error: err.Error()}, nil
}
// Read file
content, err := os.ReadFile(edit.File)
if err != nil {
structuredErr := errors.NewFileNotReadableError(edit.File, err)
return &EditResult{Success: false, Error: structuredErr.Error()}, nil
}
// Find the text selection (byte range)
start, end, err := e.resolveTextSelector(edit.Selector, content)
if err != nil {
return &EditResult{Success: false, Error: err.Error()}, nil
}
// Apply edit
newContent, err := e.applyTextEditOperation(edit.Operation, content, start, end, edit.NewContent)
if err != nil {
return &EditResult{Success: false, Error: err.Error()}, nil
}
// Generate diff
diff := generateDiff(string(content), string(newContent), edit.File)
result := &EditResult{
Success: true,
Diff: diff,
OriginalContent: string(content),
NewContent: string(newContent),
Applied: false,
}
// Apply changes if requested
if apply {
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
structuredErr := errors.NewFileNotWritableError(edit.File, err)
return &EditResult{
Success: false,
Error: structuredErr.Error(),
}, nil
}
result.Applied = true
}
return result, nil
}
// validateBaseEdit checks common edit request fields.
func (e *Engine) validateBaseEdit(edit *ASTEdit) error {
if edit.File == "" {
return errors.NewInvalidEditError("file is required")
}
if edit.Operation == "" {
return errors.NewInvalidEditError("operation is required")
}
// Validate operation type
switch edit.Operation {
case EditReplace, EditInsertBefore, EditInsertAfter:
if edit.NewContent == "" {
return errors.NewInvalidEditError(fmt.Sprintf("new_content is required for %s operation", edit.Operation))
}
case EditDelete:
// new_content not required
default:
return errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", edit.Operation))
}
return nil
}
// validateASTEdit checks if an AST edit request is valid.
func (e *Engine) validateASTEdit(edit *ASTEdit) error {
if err := e.validateBaseEdit(edit); err != nil {
return err
}
// Validate AST selector
if edit.Selector.Kind == "" && edit.Selector.Name == "" && edit.Selector.Pattern == "" && edit.Selector.AtLine == 0 {
return errors.NewInvalidEditError("AST selector must specify at least one of: kind, name, pattern, or at_line")
}
return nil
}
// validateTextEdit checks if a text edit request is valid.
func (e *Engine) validateTextEdit(edit *ASTEdit) error {
if err := e.validateBaseEdit(edit); err != nil {
return err
}
// Validate text selector - need at least one text selection method
hasTextSelector := edit.Selector.Text != "" ||
edit.Selector.TextPattern != "" ||
edit.Selector.AtLine > 0
if !hasTextSelector {
return errors.NewInvalidEditError("text selector must specify at least one of: text, text_pattern, or at_line")
}
// Validate regex pattern if provided (uses cached compilation)
if edit.Selector.TextPattern != "" {
if _, err := compileRegex(edit.Selector.TextPattern); err != nil {
return errors.Wrap(errors.ErrInvalidEdit, "invalid text_pattern regex", err)
}
}
return nil
}
// resolveSelector finds the target node based on the selector.
func (e *Engine) resolveSelector(sel ASTSelector, tree *sitter.Tree, content []byte) (*sitter.Node, error) {
if tree == nil {
return nil, errors.NewNodeNotFoundError("no AST tree available")
}
root := tree.RootNode()
if root == nil {
return nil, errors.NewNodeNotFoundError("empty AST tree")
}
var matches []*sitter.Node
parser.WalkTree(root, func(n *sitter.Node) bool {
if e.matchesSelector(sel, n, content) {
matches = append(matches, n)
}
return true
})
if len(matches) == 0 {
selectorDesc := fmt.Sprintf("kind=%s name=%s pattern=%s line=%d", sel.Kind, sel.Name, sel.Pattern, sel.AtLine)
return nil, errors.NewNodeNotFoundError(selectorDesc)
}
// Use index to select specific match
index := sel.Index
if index < 0 || index >= len(matches) {
return nil, errors.NewInvalidSelectionError(fmt.Sprintf("selector matched %d nodes, but index %d is out of range", len(matches), index))
}
return matches[index], nil
}
// matchesSelector checks if a node matches the selector criteria.
func (e *Engine) matchesSelector(sel ASTSelector, n *sitter.Node, content []byte) bool {
// Check kind
if sel.Kind != "" && n.Type() != sel.Kind {
return false
}
// Check name (look for identifier in the node)
if sel.Name != "" {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
// Also try to find an identifier child
found := false
for i := 0; i < int(n.NamedChildCount()); i++ {
child := n.NamedChild(i)
if child != nil && child.Type() == "identifier" {
if parser.GetNodeText(child, content) == sel.Name {
found = true
break
}
}
}
if !found {
return false
}
} else if parser.GetNodeText(nameNode, content) != sel.Name {
return false
}
}
// Check line
if sel.AtLine > 0 {
startLine := int(n.StartPoint().Row) + 1
endLine := int(n.EndPoint().Row) + 1
if sel.AtLine < startLine || sel.AtLine > endLine {
return false
}
}
// Pattern matching is handled separately (simplified here)
if sel.Pattern != "" {
nodeText := parser.GetNodeText(n, content)
if !strings.Contains(nodeText, sel.Pattern) {
return false
}
}
return true
}
// applyEdit applies the edit operation to the content.
func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]byte, error) {
startByte := node.StartByte()
endByte := node.EndByte()
// Detect and preserve indentation
indentation := detectIndentation(content, startByte)
newContent := indentContent(edit.NewContent, indentation)
var result []byte
switch edit.Operation {
case EditReplace:
result = append(result, content[:startByte]...)
result = append(result, []byte(newContent)...)
result = append(result, content[endByte:]...)
case EditInsertBefore:
result = append(result, content[:startByte]...)
result = append(result, []byte(newContent)...)
result = append(result, '\n')
result = append(result, content[startByte:]...)
case EditInsertAfter:
result = append(result, content[:endByte]...)
result = append(result, '\n')
result = append(result, []byte(newContent)...)
result = append(result, content[endByte:]...)
case EditDelete:
result = append(result, content[:startByte]...)
result = append(result, content[endByte:]...)
default:
return nil, errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", edit.Operation))
}
return result, nil
}
// detectIndentation detects the indentation at a given byte position.
func detectIndentation(content []byte, bytePos uint32) string {
// Find the start of the line
lineStart := int(bytePos)
for lineStart > 0 && content[lineStart-1] != '\n' {
lineStart--
}
// Extract leading whitespace
var indent strings.Builder
for i := lineStart; i < int(bytePos) && i < len(content); i++ {
c := content[i]
if c == ' ' || c == '\t' {
indent.WriteByte(c)
} else {
break
}
}
return indent.String()
}
// indentContent applies indentation to multi-line content.
func indentContent(content string, indent string) string {
if indent == "" {
return content
}
lines := strings.Split(content, "\n")
for i, line := range lines {
if i > 0 && line != "" {
lines[i] = indent + line
}
}
return strings.Join(lines, "\n")
}
// generateDiff creates a unified diff between original and modified content.
// Uses Myers diff algorithm for accurate and readable diffs.
func generateDiff(original, modified, filename string) string {
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(original, modified, false)
// Cleanup for readability
diffs = dmp.DiffCleanupSemantic(diffs)
// Convert to unified diff format
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("--- %s\n", filename))
buf.WriteString(fmt.Sprintf("+++ %s\n", filename))
// Group diffs into hunks
lineNum := 1
for _, diff := range diffs {
lines := strings.Split(diff.Text, "\n")
for i, line := range lines {
// Skip empty last line from split
if i == len(lines)-1 && line == "" {
continue
}
switch diff.Type {
case diffmatchpatch.DiffDelete:
buf.WriteString(fmt.Sprintf("-%s\n", line))
case diffmatchpatch.DiffInsert:
buf.WriteString(fmt.Sprintf("+%s\n", line))
case diffmatchpatch.DiffEqual:
buf.WriteString(fmt.Sprintf(" %s\n", line))
lineNum++
}
}
}
return buf.String()
}
// resolveTextSelector finds the byte range for a text-based selection.
func (e *Engine) resolveTextSelector(sel ASTSelector, content []byte) (start, end int, err error) {
switch {
case sel.Text != "":
return e.findExactText(content, sel.Text, sel.Index)
case sel.TextPattern != "":
return e.findRegexPattern(content, sel.TextPattern, sel.Index)
case sel.AtLine > 0:
return e.findLineRange(content, sel.AtLine, sel.LineEnd)
default:
return 0, 0, errors.NewInvalidEditError("text selector requires text, text_pattern, or at_line")
}
}
// findExactText finds an exact text match in content.
func (e *Engine) findExactText(content []byte, text string, index int) (start, end int, err error) {
if text == "" {
return 0, 0, errors.NewInvalidEditError("text selector cannot be empty")
}
textBytes := []byte(text)
type match struct{ start, end int }
var matches []match
offset := 0
for {
idx := bytes.Index(content[offset:], textBytes)
if idx == -1 {
break
}
matches = append(matches, match{
start: offset + idx,
end: offset + idx + len(textBytes),
})
offset += idx + 1
}
if len(matches) == 0 {
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text not found: %q", truncateString(text, 50)))
}
// If multiple matches and no index specified, require explicit selection
if len(matches) > 1 && index == 0 {
// Check if index was explicitly set to 0 or just defaulted
// Since we can't distinguish, we'll allow index 0 but warn about multiple matches
// Actually, let's be strict and require explicit index for multiple matches
locations := make([]string, 0, min(len(matches), 5))
for i, m := range matches {
if i >= 5 {
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
break
}
line := countLines(content[:m.start]) + 1
locations = append(locations, fmt.Sprintf("line %d", line))
}
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text matches %d locations (%s); use selector_index to specify which one (0-%d)",
len(matches), strings.Join(locations, ", "), len(matches)-1))
}
if index >= len(matches) {
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
}
return matches[index].start, matches[index].end, nil
}
// findRegexPattern finds a regex pattern match in content.
func (e *Engine) findRegexPattern(content []byte, pattern string, index int) (start, end int, err error) {
re, err := compileRegex(pattern)
if err != nil {
return 0, 0, errors.Wrap(errors.ErrInvalidEdit, "invalid regex pattern", err)
}
matches := re.FindAllIndex(content, -1)
if len(matches) == 0 {
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern not found: %q", truncateString(pattern, 50)))
}
// If multiple matches and index is 0 (default), show error with locations
if len(matches) > 1 && index == 0 {
locations := make([]string, 0, min(len(matches), 5))
for i, m := range matches {
if i >= 5 {
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
break
}
line := countLines(content[:m[0]]) + 1
locations = append(locations, fmt.Sprintf("line %d", line))
}
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern matches %d locations (%s); use selector_index to specify which one (0-%d)",
len(matches), strings.Join(locations, ", "), len(matches)-1))
}
if index >= len(matches) {
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
}
return matches[index][0], matches[index][1], nil
}
// findLineRange finds the byte range for a line range selection.
func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, end int, err error) {
if lineEnd == 0 {
lineEnd = lineStart
}
if lineStart < 1 {
return 0, 0, errors.NewInvalidEditError(fmt.Sprintf("line number must be >= 1, got %d", lineStart))
}
if lineEnd < lineStart {
return 0, 0, errors.NewInvalidEditError(fmt.Sprintf("line_end (%d) must be >= line (%d)", lineEnd, lineStart))
}
lines := bytes.Split(content, []byte("\n"))
totalLines := len(lines)
// Convert to 0-indexed
startIdx := lineStart - 1
endIdx := lineEnd - 1
if startIdx >= totalLines {
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("line %d out of range (file has %d lines)", lineStart, totalLines))
}
if endIdx >= totalLines {
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("line_end %d out of range (file has %d lines)", lineEnd, totalLines))
}
// Calculate byte positions
start = 0
for i := 0; i < startIdx; i++ {
start += len(lines[i]) + 1 // +1 for newline
}
end = start
for i := startIdx; i <= endIdx; i++ {
end += len(lines[i])
if i < totalLines-1 {
end += 1 // newline
}
}
return start, end, nil
}
// applyTextEditOperation applies a text edit operation.
func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start, end int, newContent string) ([]byte, error) {
// Detect indentation at the selection point
indentation := detectIndentationAtByte(content, start)
indentedContent := indentContent(newContent, indentation)
var result []byte
switch op {
case EditReplace:
result = append(result, content[:start]...)
result = append(result, []byte(indentedContent)...)
result = append(result, content[end:]...)
case EditInsertBefore:
result = append(result, content[:start]...)
result = append(result, []byte(indentedContent)...)
result = append(result, '\n')
result = append(result, content[start:]...)
case EditInsertAfter:
result = append(result, content[:end]...)
result = append(result, '\n')
result = append(result, []byte(indentedContent)...)
result = append(result, content[end:]...)
case EditDelete:
result = append(result, content[:start]...)
result = append(result, content[end:]...)
default:
return nil, errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", op))
}
return result, nil
}
// detectIndentationAtByte detects indentation at a byte position.
func detectIndentationAtByte(content []byte, bytePos int) string {
// Find the start of the line
lineStart := bytePos
for lineStart > 0 && content[lineStart-1] != '\n' {
lineStart--
}
// Extract leading whitespace
var indent strings.Builder
for i := lineStart; i < bytePos && i < len(content); i++ {
c := content[i]
if c == ' ' || c == '\t' {
indent.WriteByte(c)
} else {
break
}
}
return indent.String()
}
// truncateString truncates a string to maxLen with ellipsis.
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen-3] + "..."
}
// countLines counts the number of newlines in content.
func countLines(content []byte) int {
return bytes.Count(content, []byte("\n"))
}
// ValidateLanguage checks if AST editing is supported for a file.
// Returns nil for supported languages, error for unsupported.
// Note: Text-based editing is always available regardless of this check.
func ValidateLanguage(filename string) error {
lang := protocol.DetectLanguage(filename)
if lang == protocol.LangUnknown {
return fmt.Errorf("unsupported file type for AST editing: %s (text-based editing is available)", filename)
}
return nil
}
+836
View File
@@ -0,0 +1,836 @@
package edit
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
)
func TestValidateEdit(t *testing.T) {
e := NewEngine(parser.NewRegistry())
tests := []struct {
edit *ASTEdit
name string
wantErr bool
}{
{
name: "valid replace",
edit: &ASTEdit{
File: "test.go",
Operation: EditReplace,
Selector: ASTSelector{Kind: "function_declaration"},
NewContent: "func NewFunc() {}",
},
wantErr: false,
},
{
name: "valid delete",
edit: &ASTEdit{
File: "test.go",
Operation: EditDelete,
Selector: ASTSelector{Name: "oldFunc"},
},
wantErr: false,
},
{
name: "missing file",
edit: &ASTEdit{
Operation: EditReplace,
Selector: ASTSelector{Kind: "function_declaration"},
NewContent: "func NewFunc() {}",
},
wantErr: true,
},
{
name: "missing operation",
edit: &ASTEdit{
File: "test.go",
Selector: ASTSelector{Kind: "function_declaration"},
NewContent: "func NewFunc() {}",
},
wantErr: true,
},
{
name: "replace without content",
edit: &ASTEdit{
File: "test.go",
Operation: EditReplace,
Selector: ASTSelector{Kind: "function_declaration"},
},
wantErr: true,
},
{
name: "empty selector",
edit: &ASTEdit{
File: "test.go",
Operation: EditReplace,
Selector: ASTSelector{},
NewContent: "func NewFunc() {}",
},
wantErr: true,
},
{
name: "unknown operation",
edit: &ASTEdit{
File: "test.go",
Operation: "unknown",
Selector: ASTSelector{Kind: "function_declaration"},
NewContent: "func NewFunc() {}",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := e.validateASTEdit(tt.edit)
if tt.wantErr && err == nil {
t.Error("expected error")
}
if !tt.wantErr && err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestResolveSelector(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
content := []byte(`package main
func Hello() {
println("hello")
}
func Goodbye() {
println("goodbye")
}
`)
ctx := context.Background()
result, err := registry.Parse(ctx, "test.go", content)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
name string
sel ASTSelector
wantErr bool
}{
{
name: "by kind",
sel: ASTSelector{Kind: "function_declaration"},
wantErr: false,
},
{
name: "by name",
sel: ASTSelector{Name: "Hello"},
wantErr: false,
},
{
name: "by kind and name",
sel: ASTSelector{Kind: "function_declaration", Name: "Goodbye"},
wantErr: false,
},
{
name: "by line",
sel: ASTSelector{AtLine: 3},
wantErr: false,
},
{
name: "no match",
sel: ASTSelector{Name: "NonExistent"},
wantErr: true,
},
{
name: "index out of range",
sel: ASTSelector{Kind: "function_declaration", Index: 10},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
node, err := e.resolveSelector(tt.sel, result.Tree, content)
if tt.wantErr {
if err == nil {
t.Error("expected error")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if node == nil {
t.Error("expected node")
}
}
})
}
}
func TestApplyEdit(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
content := []byte(`package main
func Hello() {
println("hello")
}
`)
ctx := context.Background()
result, err := registry.Parse(ctx, "test.go", content)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
name string
operation EditOperation
newCode string
wantIn string // substring that should be in result
}{
{
name: "replace",
operation: EditReplace,
newCode: "func NewHello() {}",
wantIn: "NewHello",
},
{
name: "insert after",
operation: EditInsertAfter,
newCode: "func After() {}",
wantIn: "After",
},
{
name: "insert before",
operation: EditInsertBefore,
newCode: "func Before() {}",
wantIn: "Before",
},
{
name: "delete",
operation: EditDelete,
newCode: "",
wantIn: "package main", // Should still have package declaration
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Find the function node
node, err := e.resolveSelector(ASTSelector{Kind: "function_declaration"}, result.Tree, content)
if err != nil {
t.Fatalf("resolve failed: %v", err)
}
edit := &ASTEdit{
File: "test.go",
Operation: tt.operation,
NewContent: tt.newCode,
}
newContent, err := e.applyEdit(edit, node, content)
if err != nil {
t.Fatalf("apply failed: %v", err)
}
if !strings.Contains(string(newContent), tt.wantIn) {
t.Errorf("result does not contain %q:\n%s", tt.wantIn, string(newContent))
}
})
}
}
func TestPreview(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
// Create a temp file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.go")
content := `package main
func Hello() {
println("hello")
}
`
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
ctx := context.Background()
edit := &ASTEdit{
File: tmpFile,
Operation: EditReplace,
Selector: ASTSelector{Kind: "function_declaration"},
NewContent: "func NewHello() {\n\tprintln(\"new hello\")\n}",
}
result, err := e.Preview(ctx, edit)
if err != nil {
t.Fatalf("preview failed: %v", err)
}
if !result.Success {
t.Fatalf("preview was not successful: %s", result.Error)
}
if result.Applied {
t.Error("preview should not apply changes")
}
if result.Diff == "" {
t.Error("expected diff in result")
}
// Verify original file is unchanged
fileContent, _ := os.ReadFile(tmpFile)
if string(fileContent) != content {
t.Error("original file was modified during preview")
}
}
func TestApplyToFile(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
// Create a temp file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.go")
content := `package main
func Hello() {
println("hello")
}
`
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
ctx := context.Background()
edit := &ASTEdit{
File: tmpFile,
Operation: EditReplace,
Selector: ASTSelector{Kind: "function_declaration"},
NewContent: "func NewHello() {\n\tprintln(\"new hello\")\n}",
}
result, err := e.Apply(ctx, edit)
if err != nil {
t.Fatalf("apply failed: %v", err)
}
if !result.Success {
t.Fatalf("apply was not successful: %s", result.Error)
}
if !result.Applied {
t.Error("apply should set Applied=true")
}
// Verify file was modified
fileContent, _ := os.ReadFile(tmpFile)
if !strings.Contains(string(fileContent), "NewHello") {
t.Error("file was not modified")
}
}
func TestDetectIndentation(t *testing.T) {
tests := []struct {
name string
content string
want string
pos uint32
}{
{
name: "no indent",
content: "func main() {}",
pos: 0,
want: "",
},
{
name: "tab indent",
content: "func main() {\n\tprintln(\"hello\")\n}",
pos: 15,
want: "\t",
},
{
name: "space indent",
content: "func main() {\n println(\"hello\")\n}",
pos: 18,
want: " ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := detectIndentation([]byte(tt.content), tt.pos)
if got != tt.want {
t.Errorf("detectIndentation() = %q, want %q", got, tt.want)
}
})
}
}
func TestGenerateDiff(t *testing.T) {
original := "line1\nline2\nline3"
modified := "line1\nmodified\nline3"
filename := "test.txt"
diff := generateDiff(original, modified, filename)
if !strings.Contains(diff, "---") {
t.Error("diff should contain --- header")
}
if !strings.Contains(diff, "+++") {
t.Error("diff should contain +++ header")
}
if !strings.Contains(diff, "-line2") {
t.Error("diff should show removed line")
}
if !strings.Contains(diff, "+modified") {
t.Error("diff should show added line")
}
}
// ==================== Text-based editing tests ====================
func TestTextEditWithExactText(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
// Create a temp markdown file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "README.md")
content := `# My Project
## Installation
Run the following command:
## Usage
See the docs.
`
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
ctx := context.Background()
edit := &ASTEdit{
File: tmpFile,
Operation: EditReplace,
Selector: ASTSelector{Text: "## Installation"},
NewContent: "## Getting Started",
}
result, err := e.Apply(ctx, edit)
if err != nil {
t.Fatalf("apply failed: %v", err)
}
if !result.Success {
t.Fatalf("apply was not successful: %s", result.Error)
}
// Verify file was modified
fileContent, _ := os.ReadFile(tmpFile)
if !strings.Contains(string(fileContent), "## Getting Started") {
t.Error("file was not modified correctly")
}
if strings.Contains(string(fileContent), "## Installation") {
t.Error("old text should be replaced")
}
}
func TestTextEditWithLineRange(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
// Create a temp config file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "config.yaml")
content := `name: myapp
version: 1.0.0
database:
host: localhost
port: 5432
logging:
level: debug
`
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
ctx := context.Background()
edit := &ASTEdit{
File: tmpFile,
Operation: EditReplace,
Selector: ASTSelector{
AtLine: 3,
LineEnd: 5,
},
NewContent: "database:\n host: production.db.example.com\n port: 5433",
}
result, err := e.Apply(ctx, edit)
if err != nil {
t.Fatalf("apply failed: %v", err)
}
if !result.Success {
t.Fatalf("apply was not successful: %s", result.Error)
}
// Verify file was modified
fileContent, _ := os.ReadFile(tmpFile)
if !strings.Contains(string(fileContent), "production.db.example.com") {
t.Error("file was not modified correctly")
}
}
func TestTextEditWithRegexPattern(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
// Create a temp JSON file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "package.json")
content := `{
"name": "my-package",
"version": "1.0.0",
"description": "A test package"
}
`
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
ctx := context.Background()
edit := &ASTEdit{
File: tmpFile,
Operation: EditReplace,
Selector: ASTSelector{TextPattern: `"version":\s*"[^"]+"`},
NewContent: `"version": "2.0.0"`,
}
result, err := e.Apply(ctx, edit)
if err != nil {
t.Fatalf("apply failed: %v", err)
}
if !result.Success {
t.Fatalf("apply was not successful: %s", result.Error)
}
// Verify file was modified
fileContent, _ := os.ReadFile(tmpFile)
if !strings.Contains(string(fileContent), `"version": "2.0.0"`) {
t.Error("file was not modified correctly")
}
}
func TestTextEditInsertAfter(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
// Create a temp env file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, ".env")
content := `DATABASE_URL=postgres://localhost/mydb
SECRET_KEY=abc123
`
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
ctx := context.Background()
edit := &ASTEdit{
File: tmpFile,
Operation: EditInsertAfter,
Selector: ASTSelector{Text: "DATABASE_URL=postgres://localhost/mydb"},
NewContent: "REDIS_URL=redis://localhost:6379",
}
result, err := e.Apply(ctx, edit)
if err != nil {
t.Fatalf("apply failed: %v", err)
}
if !result.Success {
t.Fatalf("apply was not successful: %s", result.Error)
}
// Verify file was modified
fileContent, _ := os.ReadFile(tmpFile)
if !strings.Contains(string(fileContent), "REDIS_URL=redis://localhost:6379") {
t.Error("file was not modified correctly")
}
}
func TestTextEditMultipleMatchesError(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
// Create a temp file with repeated text
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := `TODO: fix this
some code here
TODO: also fix this
more code
`
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
ctx := context.Background()
edit := &ASTEdit{
File: tmpFile,
Operation: EditReplace,
Selector: ASTSelector{Text: "TODO"},
NewContent: "DONE",
}
result, err := e.Apply(ctx, edit)
if err != nil {
t.Fatalf("apply failed: %v", err)
}
// Should fail because of multiple matches
if result.Success {
t.Error("expected error for multiple matches without index")
}
if !strings.Contains(result.Error, "matches") {
t.Errorf("error should mention multiple matches: %s", result.Error)
}
}
func TestTextEditWithIndex(t *testing.T) {
registry := parser.NewRegistry()
defer registry.Close()
e := NewEngine(registry)
// Create a temp file with repeated text
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := `TODO: fix this
some code here
TODO: also fix this
more code
`
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
ctx := context.Background()
edit := &ASTEdit{
File: tmpFile,
Operation: EditReplace,
Selector: ASTSelector{
Text: "TODO",
Index: 1, // Select second match
},
NewContent: "DONE",
}
result, err := e.Apply(ctx, edit)
if err != nil {
t.Fatalf("apply failed: %v", err)
}
if !result.Success {
t.Fatalf("apply was not successful: %s", result.Error)
}
// Verify only second TODO was replaced
fileContent, _ := os.ReadFile(tmpFile)
contentStr := string(fileContent)
if !strings.Contains(contentStr, "TODO: fix this") {
t.Error("first TODO should not be replaced")
}
if !strings.Contains(contentStr, "DONE: also fix this") {
t.Error("second TODO should be replaced")
}
}
func TestValidateTextEdit(t *testing.T) {
e := NewEngine(parser.NewRegistry())
tests := []struct {
edit *ASTEdit
name string
wantErr bool
}{
{
name: "valid text selector",
edit: &ASTEdit{
File: "test.md",
Operation: EditReplace,
Selector: ASTSelector{Text: "some text"},
NewContent: "new text",
},
wantErr: false,
},
{
name: "valid pattern selector",
edit: &ASTEdit{
File: "test.md",
Operation: EditReplace,
Selector: ASTSelector{TextPattern: "\\d+"},
NewContent: "replaced",
},
wantErr: false,
},
{
name: "valid line selector",
edit: &ASTEdit{
File: "test.md",
Operation: EditReplace,
Selector: ASTSelector{AtLine: 5},
NewContent: "new line",
},
wantErr: false,
},
{
name: "empty selector",
edit: &ASTEdit{
File: "test.md",
Operation: EditReplace,
Selector: ASTSelector{},
NewContent: "new text",
},
wantErr: true,
},
{
name: "invalid regex pattern",
edit: &ASTEdit{
File: "test.md",
Operation: EditReplace,
Selector: ASTSelector{TextPattern: "[invalid"},
NewContent: "new text",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := e.validateTextEdit(tt.edit)
if tt.wantErr && err == nil {
t.Error("expected error")
}
if !tt.wantErr && err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestFindLineRange(t *testing.T) {
e := NewEngine(parser.NewRegistry())
content := []byte("line1\nline2\nline3\nline4\nline5")
// Content: "line1\nline2\nline3\nline4\nline5" (no trailing newline)
// Positions: line1=0-5, \n=5, line2=6-10, \n=11, line3=12-16, \n=17, line4=18-22, \n=23, line5=24-28
tests := []struct {
name string
lineStart int
lineEnd int
wantStart int
wantEnd int
wantErr bool
}{
{
name: "single line",
lineStart: 2,
lineEnd: 0, // defaults to lineStart
wantStart: 6,
wantEnd: 12, // includes trailing newline
wantErr: false,
},
{
name: "range of lines",
lineStart: 2,
lineEnd: 4,
wantStart: 6,
wantEnd: 24, // through end of line4 including newline
wantErr: false,
},
{
name: "first line",
lineStart: 1,
lineEnd: 1,
wantStart: 0,
wantEnd: 6, // includes trailing newline
wantErr: false,
},
{
name: "line out of range",
lineStart: 10,
lineEnd: 10,
wantErr: true,
},
{
name: "invalid line number",
lineStart: 0,
lineEnd: 1,
wantErr: true,
},
{
name: "end before start",
lineStart: 3,
lineEnd: 2,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
start, end, err := e.findLineRange(content, tt.lineStart, tt.lineEnd)
if tt.wantErr {
if err == nil {
t.Error("expected error")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if start != tt.wantStart {
t.Errorf("start = %d, want %d", start, tt.wantStart)
}
if end != tt.wantEnd {
t.Errorf("end = %d, want %d", end, tt.wantEnd)
}
})
}
}
+310
View File
@@ -0,0 +1,310 @@
// Package lsp provides a generic LSP client implementation.
package lsp
import (
"bufio"
"context"
"fmt"
"io"
"os/exec"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
)
// Client represents an LSP client connection.
type Client struct {
stdin io.WriteCloser
stdout io.ReadCloser
stderr io.ReadCloser
cmd *exec.Cmd
pending map[int64]chan *Response
done chan struct{}
notifications chan *Notification
requestID atomic.Int64
runningMu sync.RWMutex
stopOnce sync.Once
mu sync.Mutex
running bool
}
// Request represents a JSON-RPC request.
type Request struct {
Params interface{} `json:"params,omitempty"`
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
ID int64 `json:"id"`
}
// Response represents a JSON-RPC response.
type Response struct {
Error *ResponseError `json:"error,omitempty"`
JSONRPC string `json:"jsonrpc"`
Result json.RawMessage `json:"result,omitempty"`
ID int64 `json:"id"`
}
// ResponseError represents a JSON-RPC error.
type ResponseError struct {
Data interface{} `json:"data,omitempty"`
Message string `json:"message"`
Code int `json:"code"`
}
func (e *ResponseError) Error() string {
return fmt.Sprintf("LSP error %d: %s", e.Code, e.Message)
}
// Notification represents a JSON-RPC notification.
type Notification struct {
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params json.RawMessage `json:"params,omitempty"`
}
// NewClient creates a new LSP client from a command.
func NewClient(cmd *exec.Cmd) (*Client, error) {
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("failed to get stdin pipe: %w", err)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
_ = stdin.Close()
return nil, fmt.Errorf("failed to get stdout pipe: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
_ = stdin.Close()
_ = stdout.Close()
return nil, fmt.Errorf("failed to get stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
_ = stdin.Close()
_ = stdout.Close()
_ = stderr.Close()
return nil, fmt.Errorf("failed to start LSP server: %w", err)
}
c := &Client{
cmd: cmd,
stdin: stdin,
stdout: stdout,
stderr: stderr,
pending: make(map[int64]chan *Response),
done: make(chan struct{}),
running: true,
notifications: make(chan *Notification, 100),
}
// Start reader goroutine
go c.readLoop()
return c, nil
}
// Call sends a request and waits for a response.
func (c *Client) Call(ctx context.Context, method string, params interface{}) (*Response, error) {
c.runningMu.RLock()
if !c.running {
c.runningMu.RUnlock()
return nil, fmt.Errorf("client is not running")
}
c.runningMu.RUnlock()
id := c.requestID.Add(1)
req := &Request{
JSONRPC: "2.0",
ID: id,
Method: method,
Params: params,
}
// Create response channel
respChan := make(chan *Response, 1)
c.mu.Lock()
c.pending[id] = respChan
c.mu.Unlock()
defer func() {
c.mu.Lock()
delete(c.pending, id)
c.mu.Unlock()
}()
// Send request
if err := c.send(req); err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
// Wait for response
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-c.done:
return nil, fmt.Errorf("client closed")
case resp := <-respChan:
if resp.Error != nil {
return nil, resp.Error
}
return resp, nil
}
}
// Notify sends a notification (no response expected).
func (c *Client) Notify(method string, params interface{}) error {
c.runningMu.RLock()
if !c.running {
c.runningMu.RUnlock()
return fmt.Errorf("client is not running")
}
c.runningMu.RUnlock()
notif := struct {
Params interface{} `json:"params,omitempty"`
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
}{
JSONRPC: "2.0",
Method: method,
Params: params,
}
return c.send(notif)
}
// Notifications returns a channel for receiving server notifications.
func (c *Client) Notifications() <-chan *Notification {
return c.notifications
}
// Close shuts down the client and the LSP server.
func (c *Client) Close() error {
var err error
c.stopOnce.Do(func() {
c.runningMu.Lock()
c.running = false
c.runningMu.Unlock()
close(c.done)
// Close stdin to signal the server
_ = c.stdin.Close()
// Wait for process to exit with timeout
done := make(chan struct{})
go func() {
_ = c.cmd.Wait()
close(done)
}()
select {
case <-done:
// Clean exit
case <-time.After(5 * time.Second):
// Force kill
_ = c.cmd.Process.Kill()
}
close(c.notifications)
})
return err
}
// send writes a JSON-RPC message to the server.
func (c *Client) send(msg interface{}) error {
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
// Format with Content-Length header
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data))
_, err = c.stdin.Write([]byte(header))
if err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
_, err = c.stdin.Write(data)
if err != nil {
return fmt.Errorf("failed to write body: %w", err)
}
return nil
}
// readLoop reads and dispatches messages from the server.
func (c *Client) readLoop() {
reader := bufio.NewReader(c.stdout)
for {
select {
case <-c.done:
return
default:
}
// Read headers
contentLength := -1
for {
line, err := reader.ReadString('\n')
if err != nil {
return
}
line = strings.TrimSpace(line)
if line == "" {
break
}
if strings.HasPrefix(line, "Content-Length:") {
lengthStr := strings.TrimSpace(strings.TrimPrefix(line, "Content-Length:"))
contentLength, _ = strconv.Atoi(lengthStr)
}
}
if contentLength <= 0 {
continue
}
// Read body
body := make([]byte, contentLength)
_, err := io.ReadFull(reader, body)
if err != nil {
return
}
// Try to parse as response first
var resp Response
if err := json.Unmarshal(body, &resp); err == nil && resp.ID != 0 {
c.mu.Lock()
if ch, ok := c.pending[resp.ID]; ok {
ch <- &resp
}
c.mu.Unlock()
continue
}
// Try to parse as notification
var notif Notification
if err := json.Unmarshal(body, &notif); err == nil && notif.Method != "" {
select {
case c.notifications <- &notif:
default:
// Drop notification if channel is full
}
}
}
}
// IsRunning returns whether the client is running.
func (c *Client) IsRunning() bool {
c.runningMu.RLock()
defer c.runningMu.RUnlock()
return c.running
}
+535
View File
@@ -0,0 +1,535 @@
package lsp
import (
"context"
"fmt"
"log/slog"
"os"
"os/exec"
"path/filepath"
"sync"
"time"
json "github.com/goccy/go-json"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// Manager manages LSP servers for different languages.
type Manager struct {
servers map[protocol.Language]*ManagedServer
logger *slog.Logger
stopReaper chan struct{}
workspaceRoot string
timeout time.Duration
idleTimeout time.Duration
mu sync.RWMutex
stopped bool
}
// ManagedServer represents a managed LSP server instance.
type ManagedServer struct {
lastUsed time.Time
initErr error
client *Client
openDocs map[string]int
language protocol.Language
capabilities ServerCapabilities
mu sync.Mutex
ready bool
}
// ServerConfig contains the configuration for an LSP server.
type ServerConfig struct {
Command []string
Args []string
}
// DefaultServerConfigs contains default configurations for LSP servers.
var DefaultServerConfigs = map[protocol.Language]ServerConfig{
protocol.LangGo: {
Command: []string{"gopls"},
Args: []string{"serve"},
},
protocol.LangTypeScript: {
Command: []string{"typescript-language-server"},
Args: []string{"--stdio"},
},
protocol.LangJavaScript: {
Command: []string{"typescript-language-server"},
Args: []string{"--stdio"},
},
protocol.LangPython: {
Command: []string{"pylsp"},
},
protocol.LangC: {
Command: []string{"clangd"},
},
protocol.LangCpp: {
Command: []string{"clangd"},
},
}
// NewManager creates a new LSP manager.
func NewManager(workspaceRoot string, logger *slog.Logger) *Manager {
m := &Manager{
servers: make(map[protocol.Language]*ManagedServer),
timeout: 10 * time.Second,
idleTimeout: 5 * time.Minute,
workspaceRoot: workspaceRoot,
logger: logger,
stopReaper: make(chan struct{}),
}
// Start idle reaper
go m.reapIdleServers()
return m
}
// GetServer returns or creates an LSP server for the given language.
func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*ManagedServer, error) {
m.mu.RLock()
srv, exists := m.servers[lang]
m.mu.RUnlock()
if exists && srv.ready {
// Update lastUsed with server's own lock to avoid race condition
srv.mu.Lock()
srv.lastUsed = time.Now()
srv.mu.Unlock()
return srv, nil
}
// Create new server
m.mu.Lock()
defer m.mu.Unlock()
// Double-check after acquiring write lock
if srv, ok := m.servers[lang]; ok && srv.ready {
srv.mu.Lock()
srv.lastUsed = time.Now()
srv.mu.Unlock()
return srv, nil
}
// Check if server config exists
config, ok := DefaultServerConfigs[lang]
if !ok {
return nil, errors.New(errors.ErrLSPServerNotFound, fmt.Sprintf("no LSP server configured for language: %s", lang)).
WithContext("language", string(lang)).
WithRemediation("Configure an LSP server for this language or use a supported language")
}
// Check if command is available
cmdPath, err := exec.LookPath(config.Command[0])
if err != nil {
return nil, errors.NewLSPServerNotFound(string(lang), config.Command[0])
}
// Create command
args := append(config.Command[1:], config.Args...)
cmd := exec.CommandContext(ctx, cmdPath, args...)
cmd.Env = os.Environ()
cmd.Dir = m.workspaceRoot
// Create client
client, err := NewClient(cmd)
if err != nil {
// Ensure process is killed if client creation fails
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
return nil, errors.Wrap(errors.ErrLSPCommunication, "failed to create LSP client", err).
WithContext("language", string(lang)).
WithContext("command", config.Command[0]).
WithRemediation("Ensure the LSP server binary is executable and compatible with your system")
}
newSrv := &ManagedServer{
client: client,
language: lang,
lastUsed: time.Now(),
openDocs: make(map[string]int),
}
// Initialize server
if err := m.initializeServer(ctx, newSrv); err != nil {
_ = client.Close()
// Ensure process is killed on initialization failure
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
newSrv.initErr = err
return nil, errors.Wrap(errors.ErrLSPInitFailed, "LSP server initialization failed", err).
WithContext("language", string(lang)).
WithContext("command", config.Command[0]).
WithRemediation("Check LSP server logs for initialization errors")
}
newSrv.ready = true
m.servers[lang] = newSrv
m.logger.Info("started LSP server", "language", lang, "command", config.Command[0])
return newSrv, nil
}
// initializeServer performs the LSP initialization handshake.
func (m *Manager) initializeServer(ctx context.Context, srv *ManagedServer) error {
// Create context with timeout
ctx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel()
// Build root URI
rootURI := "file://" + m.workspaceRoot
// Send initialize request
params := InitializeParams{
ProcessID: os.Getpid(),
RootURI: rootURI,
Capabilities: Capabilities{
TextDocument: TextDocumentClientCapabilities{
Hover: HoverCapability{
ContentFormat: []string{"markdown", "plaintext"},
},
Definition: DefinitionCapability{
LinkSupport: true,
},
References: ReferencesCapability{},
},
},
}
resp, err := srv.client.Call(ctx, "initialize", params)
if err != nil {
return fmt.Errorf("initialize failed: %w", err)
}
// Parse capabilities
var result InitializeResult
if err := json.Unmarshal(resp.Result, &result); err != nil {
return fmt.Errorf("failed to parse initialize result: %w", err)
}
srv.capabilities = result.Capabilities
// Send initialized notification
if err := srv.client.Notify("initialized", struct{}{}); err != nil {
return fmt.Errorf("initialized notification failed: %w", err)
}
return nil
}
// Hover performs a hover request at the given position.
func (m *Manager) Hover(ctx context.Context, file string, line, col int) (*HoverResult, error) {
lang := protocol.DetectLanguage(file)
srv, err := m.GetServer(ctx, lang)
if err != nil {
return nil, err
}
// Ensure document is open
err = m.ensureDocumentOpen(ctx, srv, file)
if err != nil {
return nil, err
}
params := HoverParams{
TextDocumentPositionParams: TextDocumentPositionParams{
TextDocument: TextDocumentIdentifier{
URI: fileToURI(file),
},
Position: Position{
Line: line - 1, // Convert to 0-indexed
Character: col - 1,
},
},
}
ctx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel()
resp, err := srv.client.Call(ctx, "textDocument/hover", params)
if err != nil {
return nil, fmt.Errorf("hover request failed: %w", err)
}
if resp.Result == nil || string(resp.Result) == "null" {
return nil, nil // No hover info
}
var result HoverResult
if err := json.Unmarshal(resp.Result, &result); err != nil {
return nil, fmt.Errorf("failed to parse hover result: %w", err)
}
return &result, nil
}
// Definition finds the definition of the symbol at the given position.
func (m *Manager) Definition(ctx context.Context, file string, line, col int) ([]Location, error) {
lang := protocol.DetectLanguage(file)
srv, err := m.GetServer(ctx, lang)
if err != nil {
return nil, err
}
// Ensure document is open
err = m.ensureDocumentOpen(ctx, srv, file)
if err != nil {
return nil, err
}
params := DefinitionParams{
TextDocumentPositionParams: TextDocumentPositionParams{
TextDocument: TextDocumentIdentifier{
URI: fileToURI(file),
},
Position: Position{
Line: line - 1,
Character: col - 1,
},
},
}
ctx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel()
resp, err := srv.client.Call(ctx, "textDocument/definition", params)
if err != nil {
return nil, fmt.Errorf("definition request failed: %w", err)
}
if resp.Result == nil || string(resp.Result) == "null" {
return nil, nil
}
// Result can be Location, []Location, or []LocationLink
var locations []Location
if err := json.Unmarshal(resp.Result, &locations); err != nil {
// Try single location
var single Location
if err := json.Unmarshal(resp.Result, &single); err == nil {
locations = []Location{single}
}
}
return locations, nil
}
// References finds all references to the symbol at the given position.
func (m *Manager) References(ctx context.Context, file string, line, col int, includeDeclaration bool) ([]Location, error) {
lang := protocol.DetectLanguage(file)
srv, err := m.GetServer(ctx, lang)
if err != nil {
return nil, err
}
// Ensure document is open
err = m.ensureDocumentOpen(ctx, srv, file)
if err != nil {
return nil, err
}
params := ReferenceParams{
TextDocumentPositionParams: TextDocumentPositionParams{
TextDocument: TextDocumentIdentifier{
URI: fileToURI(file),
},
Position: Position{
Line: line - 1,
Character: col - 1,
},
},
Context: ReferenceContext{
IncludeDeclaration: includeDeclaration,
},
}
ctx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel()
resp, err := srv.client.Call(ctx, "textDocument/references", params)
if err != nil {
return nil, fmt.Errorf("references request failed: %w", err)
}
if resp.Result == nil || string(resp.Result) == "null" {
return nil, nil
}
var locations []Location
if err := json.Unmarshal(resp.Result, &locations); err != nil {
return nil, fmt.Errorf("failed to parse references result: %w", err)
}
return locations, nil
}
// ensureDocumentOpen opens a document if not already open.
func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, file string) error {
uri := fileToURI(file)
srv.mu.Lock()
if _, ok := srv.openDocs[uri]; ok {
srv.mu.Unlock()
return nil
}
srv.mu.Unlock()
// Read file content
content, err := os.ReadFile(file)
if err != nil {
return fmt.Errorf("failed to read file: %w", err)
}
// Get language ID
langID := languageToLSPID(srv.language)
params := DidOpenTextDocumentParams{
TextDocument: TextDocumentItem{
URI: uri,
LanguageID: langID,
Version: 1,
Text: string(content),
},
}
if err := srv.client.Notify("textDocument/didOpen", params); err != nil {
return fmt.Errorf("didOpen failed: %w", err)
}
srv.mu.Lock()
srv.openDocs[uri] = 1
srv.mu.Unlock()
return nil
}
// CloseDocument closes a document in the server.
func (m *Manager) CloseDocument(_ context.Context, lang protocol.Language, file string) error {
m.mu.RLock()
srv, ok := m.servers[lang]
m.mu.RUnlock()
if !ok || !srv.ready {
return nil
}
uri := fileToURI(file)
srv.mu.Lock()
if _, ok := srv.openDocs[uri]; !ok {
srv.mu.Unlock()
return nil
}
delete(srv.openDocs, uri)
srv.mu.Unlock()
params := DidCloseTextDocumentParams{
TextDocument: TextDocumentIdentifier{
URI: uri,
},
}
return srv.client.Notify("textDocument/didClose", params)
}
// reapIdleServers periodically closes idle servers.
func (m *Manager) reapIdleServers() {
ticker := time.NewTicker(60 * time.Second)
defer ticker.Stop()
for {
select {
case <-m.stopReaper:
return
case <-ticker.C:
m.mu.Lock()
for lang, srv := range m.servers {
// Check lastUsed with server's lock to avoid race condition
srv.mu.Lock()
idle := time.Since(srv.lastUsed) > m.idleTimeout
srv.mu.Unlock()
if idle {
m.logger.Info("closing idle LSP server", "language", lang)
_ = srv.client.Close()
delete(m.servers, lang)
}
}
m.mu.Unlock()
}
}
}
// Close shuts down all LSP servers.
func (m *Manager) Close() error {
close(m.stopReaper)
m.mu.Lock()
defer m.mu.Unlock()
m.stopped = true
for lang, srv := range m.servers {
m.logger.Info("shutting down LSP server", "language", lang)
// Try graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
_, _ = srv.client.Call(ctx, "shutdown", nil)
cancel()
_ = srv.client.Notify("exit", nil)
_ = srv.client.Close()
}
m.servers = make(map[protocol.Language]*ManagedServer)
return nil
}
// IsAvailable checks if an LSP server is available for the given language.
func (m *Manager) IsAvailable(lang protocol.Language) bool {
config, ok := DefaultServerConfigs[lang]
if !ok {
return false
}
_, err := exec.LookPath(config.Command[0])
return err == nil
}
// fileToURI converts a file path to a file URI.
func fileToURI(file string) string {
absPath, err := filepath.Abs(file)
if err != nil {
return "file://" + file
}
return "file://" + absPath
}
// URIToFile converts a file URI to a file path.
func URIToFile(uri string) string {
if len(uri) > 7 && uri[:7] == "file://" {
return uri[7:]
}
return uri
}
// languageToLSPID converts a language to LSP language ID.
func languageToLSPID(lang protocol.Language) string {
switch lang {
case protocol.LangGo:
return "go"
case protocol.LangTypeScript:
return "typescript"
case protocol.LangJavaScript:
return "javascript"
case protocol.LangPython:
return "python"
case protocol.LangC:
return "c"
case protocol.LangCpp:
return "cpp"
default:
return string(lang)
}
}
+112
View File
@@ -0,0 +1,112 @@
package lsp
import (
"testing"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
func TestFileToURI(t *testing.T) {
tests := []struct {
name string
file string
want string
}{
{
name: "absolute path",
file: "/Users/test/file.go",
want: "file:///Users/test/file.go",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := fileToURI(tt.file)
if got != tt.want {
t.Errorf("fileToURI() = %v, want %v", got, tt.want)
}
})
}
}
func TestURIToFile(t *testing.T) {
tests := []struct {
name string
uri string
want string
}{
{
name: "file uri",
uri: "file:///Users/test/file.go",
want: "/Users/test/file.go",
},
{
name: "not a file uri",
uri: "/Users/test/file.go",
want: "/Users/test/file.go",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := URIToFile(tt.uri)
if got != tt.want {
t.Errorf("URIToFile() = %v, want %v", got, tt.want)
}
})
}
}
func TestLanguageToLSPID(t *testing.T) {
tests := []struct {
lang protocol.Language
want string
}{
{protocol.LangGo, "go"},
{protocol.LangTypeScript, "typescript"},
{protocol.LangJavaScript, "javascript"},
{protocol.LangPython, "python"},
{protocol.LangC, "c"},
{protocol.LangCpp, "cpp"},
{protocol.LangUnknown, "unknown"},
}
for _, tt := range tests {
t.Run(string(tt.lang), func(t *testing.T) {
got := languageToLSPID(tt.lang)
if got != tt.want {
t.Errorf("languageToLSPID() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsAvailable(t *testing.T) {
// This tests the structure of the manager without actually spawning servers
// which requires the actual LSP servers to be installed
// Just verify the DefaultServerConfigs structure
expectedLanguages := []protocol.Language{
protocol.LangGo,
protocol.LangTypeScript,
protocol.LangJavaScript,
protocol.LangPython,
protocol.LangC,
protocol.LangCpp,
}
for _, lang := range expectedLanguages {
if _, ok := DefaultServerConfigs[lang]; !ok {
t.Errorf("missing server config for language: %s", lang)
}
}
}
func TestDefaultServerConfigs(t *testing.T) {
// Verify the command structure
for lang, config := range DefaultServerConfigs {
if len(config.Command) == 0 {
t.Errorf("language %s has empty command", lang)
}
}
}
+150
View File
@@ -0,0 +1,150 @@
package lsp
// InitializeParams are the parameters for the initialize request.
type InitializeParams struct {
RootURI string `json:"rootUri"`
Capabilities Capabilities `json:"capabilities"`
ProcessID int `json:"processId"`
}
// Capabilities represents client capabilities.
type Capabilities struct {
TextDocument TextDocumentClientCapabilities `json:"textDocument"`
}
// TextDocumentClientCapabilities represents text document capabilities.
type TextDocumentClientCapabilities struct {
Hover HoverCapability `json:"hover,omitempty"`
Definition DefinitionCapability `json:"definition,omitempty"`
References ReferencesCapability `json:"references,omitempty"`
}
// HoverCapability represents hover capabilities.
type HoverCapability struct {
ContentFormat []string `json:"contentFormat,omitempty"`
}
// DefinitionCapability represents definition capabilities.
type DefinitionCapability struct {
LinkSupport bool `json:"linkSupport,omitempty"`
}
// ReferencesCapability represents references capabilities.
type ReferencesCapability struct{}
// InitializeResult is the result of the initialize request.
type InitializeResult struct {
Capabilities ServerCapabilities `json:"capabilities"`
}
// ServerCapabilities represents server capabilities.
type ServerCapabilities struct {
HoverProvider bool `json:"hoverProvider,omitempty"`
DefinitionProvider bool `json:"definitionProvider,omitempty"`
ReferencesProvider bool `json:"referencesProvider,omitempty"`
DocumentSymbolProvider bool `json:"documentSymbolProvider,omitempty"`
TextDocumentSync int `json:"textDocumentSync,omitempty"`
}
// Position represents a position in a document.
type Position struct {
Line int `json:"line"` // 0-indexed
Character int `json:"character"` // 0-indexed
}
// Range represents a range in a document.
type Range struct {
Start Position `json:"start"`
End Position `json:"end"`
}
// Location represents a location in a document.
type Location struct {
URI string `json:"uri"`
Range Range `json:"range"`
}
// TextDocumentIdentifier identifies a text document.
type TextDocumentIdentifier struct {
URI string `json:"uri"`
}
// TextDocumentPositionParams represents position parameters.
type TextDocumentPositionParams struct {
TextDocument TextDocumentIdentifier `json:"textDocument"`
Position Position `json:"position"`
}
// HoverParams are the parameters for the hover request.
type HoverParams struct {
TextDocumentPositionParams
}
// HoverResult is the result of the hover request.
type HoverResult struct {
Range *Range `json:"range,omitempty"`
Contents MarkupContent `json:"contents"`
}
// MarkupContent represents markup content.
type MarkupContent struct {
Kind string `json:"kind"` // "plaintext" or "markdown"
Value string `json:"value"`
}
// DefinitionParams are the parameters for the definition request.
type DefinitionParams struct {
TextDocumentPositionParams
}
// ReferenceParams are the parameters for the references request.
type ReferenceParams struct {
TextDocumentPositionParams
Context ReferenceContext `json:"context"`
}
// ReferenceContext represents reference context.
type ReferenceContext struct {
IncludeDeclaration bool `json:"includeDeclaration"`
}
// TextDocumentItem represents a text document.
type TextDocumentItem struct {
URI string `json:"uri"`
LanguageID string `json:"languageId"`
Text string `json:"text"`
Version int `json:"version"`
}
// DidOpenTextDocumentParams are the parameters for didOpen.
type DidOpenTextDocumentParams struct {
TextDocument TextDocumentItem `json:"textDocument"`
}
// DidCloseTextDocumentParams are the parameters for didClose.
type DidCloseTextDocumentParams struct {
TextDocument TextDocumentIdentifier `json:"textDocument"`
}
// DocumentSymbol represents a symbol in a document.
type DocumentSymbol struct {
Name string `json:"name"`
Detail string `json:"detail,omitempty"`
Children []DocumentSymbol `json:"children,omitempty"`
Range Range `json:"range"`
SelectionRange Range `json:"selectionRange"`
Kind int `json:"kind"`
}
// SymbolInformation represents symbol information.
type SymbolInformation struct {
Name string `json:"name"`
ContainerName string `json:"containerName,omitempty"`
Location Location `json:"location"`
Kind int `json:"kind"`
}
// DocumentSymbolParams are the parameters for documentSymbol.
type DocumentSymbolParams struct {
TextDocument TextDocumentIdentifier `json:"textDocument"`
}
+190
View File
@@ -0,0 +1,190 @@
package parser
import (
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// FindNodeAtPosition finds the node at the given line and column.
func FindNodeAtPosition(tree *sitter.Tree, line, col int) *sitter.Node {
if tree == nil {
return nil
}
root := tree.RootNode()
if root == nil {
return nil
}
// Convert to 0-indexed
point := sitter.Point{
Row: uint32(line - 1), // #nosec G115 - line numbers are bounded by file size
Column: uint32(col - 1), // #nosec G115 - column numbers are bounded by line length
}
return findNodeAtPoint(root, point)
}
// findNodeAtPoint recursively finds the smallest node containing the point.
func findNodeAtPoint(node *sitter.Node, point sitter.Point) *sitter.Node {
if node == nil {
return nil
}
startPoint := node.StartPoint()
endPoint := node.EndPoint()
// Check if point is within this node
if !pointInRange(point, startPoint, endPoint) {
return nil
}
// Try to find a more specific child node
for i := 0; i < int(node.ChildCount()); i++ {
child := node.Child(i)
if child == nil {
continue
}
if result := findNodeAtPoint(child, point); result != nil {
return result
}
}
// No child contains the point, return this node
return node
}
// pointInRange checks if a point is within a range.
func pointInRange(point, start, end sitter.Point) bool {
// Before start?
if point.Row < start.Row || (point.Row == start.Row && point.Column < start.Column) {
return false
}
// After end?
if point.Row > end.Row || (point.Row == end.Row && point.Column >= end.Column) {
return false
}
return true
}
// FindParentOfKind finds the nearest ancestor of the given node type.
func FindParentOfKind(node *sitter.Node, kind string) *sitter.Node {
if node == nil {
return nil
}
current := node.Parent()
for current != nil {
if current.Type() == kind {
return current
}
current = current.Parent()
}
return nil
}
// GetNodeText returns the text content of a node.
func GetNodeText(node *sitter.Node, content []byte) string {
if node == nil {
return ""
}
start := node.StartByte()
end := node.EndByte()
if int(start) >= len(content) || int(end) > len(content) {
return ""
}
return string(content[start:end])
}
// WalkTree walks the tree calling fn for each node.
// If fn returns false, the walk stops.
func WalkTree(node *sitter.Node, fn func(*sitter.Node) bool) {
if node == nil {
return
}
if !fn(node) {
return
}
for i := 0; i < int(node.ChildCount()); i++ {
WalkTree(node.Child(i), fn)
}
}
// FindNodesByKind finds all nodes of a given kind.
func FindNodesByKind(root *sitter.Node, kind string) []*sitter.Node {
var nodes []*sitter.Node
WalkTree(root, func(n *sitter.Node) bool {
if n.Type() == kind {
nodes = append(nodes, n)
}
return true
})
return nodes
}
// FindNamedChildren returns all named (non-anonymous) children of a node.
func FindNamedChildren(node *sitter.Node) []*sitter.Node {
if node == nil {
return nil
}
var children []*sitter.Node
for i := 0; i < int(node.NamedChildCount()); i++ {
if child := node.NamedChild(i); child != nil {
children = append(children, child)
}
}
return children
}
// GetChildByFieldName returns the child node with the given field name.
func GetChildByFieldName(node *sitter.Node, fieldName string) *sitter.Node {
if node == nil {
return nil
}
return node.ChildByFieldName(fieldName)
}
// NodeLocation returns the location of a node.
func NodeLocation(node *sitter.Node, filename string) protocol.Location {
if node == nil {
return protocol.Location{}
}
startPoint := node.StartPoint()
return protocol.Location{
File: filename,
Line: int(startPoint.Row) + 1,
Column: int(startPoint.Column) + 1,
}
}
// NodeRange returns the range of a node.
func NodeRange(node *sitter.Node, filename string) protocol.Range {
if node == nil {
return protocol.Range{}
}
startPoint := node.StartPoint()
endPoint := node.EndPoint()
return protocol.Range{
Start: protocol.Location{
File: filename,
Line: int(startPoint.Row) + 1,
Column: int(startPoint.Column) + 1,
},
End: protocol.Location{
File: filename,
Line: int(endPoint.Row) + 1,
Column: int(endPoint.Column) + 1,
},
}
}
+140
View File
@@ -0,0 +1,140 @@
package parser
import (
"context"
"fmt"
"testing"
)
// TestLRUCacheEviction tests that the LRU cache properly evicts old entries.
func TestLRUCacheEviction(t *testing.T) {
registry := NewRegistry()
ctx := context.Background()
// Create 101 unique Go files (cache size is 100)
for i := 0; i < 101; i++ {
content := []byte(fmt.Sprintf("package main\n\nfunc test%d() {}\n", i))
filename := "test.go"
_, err := registry.Parse(ctx, filename, content)
if err != nil {
t.Fatalf("Parse failed for iteration %d: %v", i, err)
}
}
// The LRU cache should have evicted the oldest entry
// Verify cache size is capped at 100
cacheLen := registry.cache.Len()
if cacheLen > 100 {
t.Errorf("Cache size %d exceeds max size 100", cacheLen)
}
}
// TestCacheHit tests that repeated parsing of the same content uses cache.
func TestCacheHit(t *testing.T) {
registry := NewRegistry()
ctx := context.Background()
content := []byte("package main\n\nfunc test() {}\n")
filename := "test.go"
// First parse
result1, err := registry.Parse(ctx, filename, content)
if err != nil {
t.Fatalf("First parse failed: %v", err)
}
// Second parse should use cache
result2, err := registry.Parse(ctx, filename, content)
if err != nil {
t.Fatalf("Second parse failed: %v", err)
}
// The tree should be the same object (cached)
if result1.Tree != result2.Tree {
t.Error("Expected cached tree to be reused, but got different tree objects")
}
}
// TestContentHashCollisionResistance tests that different content produces different hashes.
func TestContentHashCollisionResistance(t *testing.T) {
testCases := []struct {
name string
content1 []byte
content2 []byte
}{
{
name: "different content",
content1: []byte("package main"),
content2: []byte("package test"),
},
{
name: "same prefix different suffix",
content1: []byte("package main\nfunc a() {}"),
content2: []byte("package main\nfunc b() {}"),
},
{
name: "different length",
content1: []byte("short"),
content2: []byte("much longer content here"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hash1 := contentHash(tc.content1)
hash2 := contentHash(tc.content2)
if hash1 == hash2 {
t.Errorf("Hash collision: %s == %s for different content", hash1, hash2)
}
})
}
}
// TestContentHashConsistency tests that the same content always produces the same hash.
func TestContentHashConsistency(t *testing.T) {
content := []byte("package main\n\nfunc test() {}\n")
hash1 := contentHash(content)
hash2 := contentHash(content)
hash3 := contentHash(content)
if hash1 != hash2 || hash2 != hash3 {
t.Errorf("Hash inconsistency: %s, %s, %s", hash1, hash2, hash3)
}
}
// BenchmarkContentHash_xxHash benchmarks the xxHash implementation.
func BenchmarkContentHash_xxHash(b *testing.B) {
// Typical file content size (10KB)
content := make([]byte, 10*1024)
for i := range content {
content[i] = byte(i % 256)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = contentHash(content)
}
}
// BenchmarkCacheHitRate benchmarks cache performance with realistic workload.
func BenchmarkCacheHitRate(b *testing.B) {
registry := NewRegistry()
ctx := context.Background()
// Create a set of common files that get parsed repeatedly
files := [][]byte{
[]byte("package main\n\nfunc main() {}\n"),
[]byte("package test\n\nimport \"testing\"\n"),
[]byte("package util\n\nfunc helper() string { return \"\" }\n"),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Simulate realistic access pattern with cache hits
content := files[i%len(files)]
_, _ = registry.Parse(ctx, "test.go", content)
}
}
+550
View File
@@ -0,0 +1,550 @@
// Package parser provides documentation extraction for multiple languages.
package parser
import (
"regexp"
"strings"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// DocComment represents an extracted documentation comment.
type DocComment struct {
Tags map[string]string
Text string
Raw string
Style CommentStyle
StartLine int
EndLine int
}
// CommentStyle indicates the type of comment.
type CommentStyle string
const (
CommentStyleLine CommentStyle = "line" // // comment
CommentStyleBlock CommentStyle = "block" // /* comment */
CommentStyleJSDoc CommentStyle = "jsdoc" // /** comment */
CommentStyleDoxygen CommentStyle = "doxygen" // /** comment */ or /// comment
CommentStyleDocstring CommentStyle = "docstring" // """comment""" or '''comment'''
CommentStyleHash CommentStyle = "hash" // # comment (Python)
)
// ExtractDocComment extracts the documentation comment for a node.
func ExtractDocComment(n *sitter.Node, content []byte, lang protocol.Language) *DocComment {
if n == nil {
return nil
}
switch lang {
case protocol.LangGo:
return extractGoDocComment(n, content)
case protocol.LangTypeScript, protocol.LangJavaScript:
return extractJSDocComment(n, content)
case protocol.LangPython:
return extractPythonDocComment(n, content)
case protocol.LangC, protocol.LangCpp:
return extractCDocComment(n, content)
default:
return nil
}
}
// extractGoDocComment extracts Go documentation comments.
// Go uses // or /* */ comments immediately preceding a declaration.
func extractGoDocComment(n *sitter.Node, content []byte) *DocComment {
comments := collectPrecedingComments(n, content, []string{"comment"})
if len(comments) == 0 {
return nil
}
var parts []string
var raw []string
startLine := -1
endLine := -1
for _, c := range comments {
text := GetNodeText(c, content)
raw = append(raw, text)
if startLine == -1 {
startLine = int(c.StartPoint().Row) + 1
}
endLine = int(c.EndPoint().Row) + 1
cleaned := cleanGoComment(text)
if cleaned != "" {
parts = append(parts, cleaned)
}
}
if len(parts) == 0 {
return nil
}
return &DocComment{
Text: strings.Join(parts, "\n"),
Raw: strings.Join(raw, "\n"),
Style: detectCommentStyle(raw[0]),
Tags: nil, // Go doesn't use JSDoc-style tags
StartLine: startLine,
EndLine: endLine,
}
}
// extractJSDocComment extracts JSDoc-style documentation comments.
func extractJSDocComment(n *sitter.Node, content []byte) *DocComment {
comments := collectPrecedingComments(n, content, []string{"comment"})
if len(comments) == 0 {
return nil
}
// JSDoc prefers the last comment block if it's a JSDoc comment
var jsDocComment *sitter.Node
for i := len(comments) - 1; i >= 0; i-- {
text := GetNodeText(comments[i], content)
if strings.HasPrefix(strings.TrimSpace(text), "/**") {
jsDocComment = comments[i]
break
}
}
if jsDocComment != nil {
text := GetNodeText(jsDocComment, content)
cleaned, tags := parseJSDoc(text)
return &DocComment{
Text: cleaned,
Raw: text,
Style: CommentStyleJSDoc,
Tags: tags,
StartLine: int(jsDocComment.StartPoint().Row) + 1,
EndLine: int(jsDocComment.EndPoint().Row) + 1,
}
}
// Fall back to regular comments
var parts []string
var raw []string
startLine := -1
endLine := -1
for _, c := range comments {
text := GetNodeText(c, content)
raw = append(raw, text)
if startLine == -1 {
startLine = int(c.StartPoint().Row) + 1
}
endLine = int(c.EndPoint().Row) + 1
cleaned := cleanJSComment(text)
if cleaned != "" {
parts = append(parts, cleaned)
}
}
if len(parts) == 0 {
return nil
}
return &DocComment{
Text: strings.Join(parts, "\n"),
Raw: strings.Join(raw, "\n"),
Style: CommentStyleLine,
Tags: nil,
StartLine: startLine,
EndLine: endLine,
}
}
// extractPythonDocComment extracts Python docstrings.
// Python docstrings are triple-quoted strings inside the function/class body.
func extractPythonDocComment(n *sitter.Node, content []byte) *DocComment {
// Python docstrings are inside the body, not before
body := n.ChildByFieldName("body")
if body == nil {
return nil
}
// First statement should be the docstring if present
if body.NamedChildCount() > 0 {
first := body.NamedChild(0)
if first != nil && first.Type() == "expression_statement" {
if first.NamedChildCount() > 0 {
expr := first.NamedChild(0)
if expr != nil && expr.Type() == "string" {
text := GetNodeText(expr, content)
cleaned := cleanPythonDocstring(text)
return &DocComment{
Text: cleaned,
Raw: text,
Style: CommentStyleDocstring,
Tags: nil,
StartLine: int(expr.StartPoint().Row) + 1,
EndLine: int(expr.EndPoint().Row) + 1,
}
}
}
}
}
// Also check for # comments before the definition
comments := collectPrecedingComments(n, content, []string{"comment"})
if len(comments) == 0 {
return nil
}
var parts []string
var raw []string
startLine := -1
endLine := -1
for _, c := range comments {
text := GetNodeText(c, content)
raw = append(raw, text)
if startLine == -1 {
startLine = int(c.StartPoint().Row) + 1
}
endLine = int(c.EndPoint().Row) + 1
// Clean # comment
cleaned := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(text), "#"))
if cleaned != "" {
parts = append(parts, cleaned)
}
}
if len(parts) == 0 {
return nil
}
return &DocComment{
Text: strings.Join(parts, "\n"),
Raw: strings.Join(raw, "\n"),
Style: CommentStyleHash,
Tags: nil,
StartLine: startLine,
EndLine: endLine,
}
}
// extractCDocComment extracts C/C++ documentation comments (Doxygen style).
func extractCDocComment(n *sitter.Node, content []byte) *DocComment {
comments := collectPrecedingComments(n, content, []string{"comment"})
if len(comments) == 0 {
return nil
}
// Look for Doxygen-style comment
var doxyComment *sitter.Node
for i := len(comments) - 1; i >= 0; i-- {
text := GetNodeText(comments[i], content)
trimmed := strings.TrimSpace(text)
if strings.HasPrefix(trimmed, "/**") || strings.HasPrefix(trimmed, "///") || strings.HasPrefix(trimmed, "//!") {
doxyComment = comments[i]
break
}
}
if doxyComment != nil {
text := GetNodeText(doxyComment, content)
cleaned, tags := parseDoxygen(text)
return &DocComment{
Text: cleaned,
Raw: text,
Style: CommentStyleDoxygen,
Tags: tags,
StartLine: int(doxyComment.StartPoint().Row) + 1,
EndLine: int(doxyComment.EndPoint().Row) + 1,
}
}
// Fall back to regular comments
var parts []string
var raw []string
startLine := -1
endLine := -1
for _, c := range comments {
text := GetNodeText(c, content)
raw = append(raw, text)
if startLine == -1 {
startLine = int(c.StartPoint().Row) + 1
}
endLine = int(c.EndPoint().Row) + 1
cleaned := cleanCComment(text)
if cleaned != "" {
parts = append(parts, cleaned)
}
}
if len(parts) == 0 {
return nil
}
return &DocComment{
Text: strings.Join(parts, "\n"),
Raw: strings.Join(raw, "\n"),
Style: detectCommentStyle(raw[0]),
Tags: nil,
StartLine: startLine,
EndLine: endLine,
}
}
// collectPrecedingComments collects all comment nodes immediately before a node.
func collectPrecedingComments(n *sitter.Node, _ []byte, commentTypes []string) []*sitter.Node {
var comments []*sitter.Node
// Walk backwards through siblings
prev := n.PrevSibling()
lastCommentLine := int(n.StartPoint().Row)
for prev != nil {
isComment := false
nodeType := prev.Type()
for _, ct := range commentTypes {
if nodeType == ct {
isComment = true
break
}
}
if !isComment {
break
}
commentEndLine := int(prev.EndPoint().Row)
// Check if there's a blank line gap
if lastCommentLine-commentEndLine > 1 {
break
}
comments = append([]*sitter.Node{prev}, comments...)
lastCommentLine = int(prev.StartPoint().Row)
prev = prev.PrevSibling()
}
return comments
}
// detectCommentStyle determines the style of a comment.
func detectCommentStyle(comment string) CommentStyle {
trimmed := strings.TrimSpace(comment)
if strings.HasPrefix(trimmed, "/**") {
return CommentStyleJSDoc
}
if strings.HasPrefix(trimmed, "///") || strings.HasPrefix(trimmed, "//!") {
return CommentStyleDoxygen
}
if strings.HasPrefix(trimmed, "/*") {
return CommentStyleBlock
}
if strings.HasPrefix(trimmed, "//") {
return CommentStyleLine
}
if strings.HasPrefix(trimmed, "#") {
return CommentStyleHash
}
if strings.HasPrefix(trimmed, `"""`) || strings.HasPrefix(trimmed, `'''`) {
return CommentStyleDocstring
}
return CommentStyleLine
}
// cleanGoComment cleans a Go comment.
func cleanGoComment(comment string) string {
comment = strings.TrimSpace(comment)
// Handle // comments
if after, found := strings.CutPrefix(comment, "//"); found {
return strings.TrimSpace(after)
}
// Handle /* */ comments
if strings.HasPrefix(comment, "/*") && strings.HasSuffix(comment, "*/") {
comment = strings.TrimPrefix(comment, "/*")
comment = strings.TrimSuffix(comment, "*/")
return cleanBlockComment(comment)
}
return strings.TrimSpace(comment)
}
// cleanJSComment cleans a JavaScript/TypeScript comment.
func cleanJSComment(comment string) string {
return cleanGoComment(comment) // Same rules
}
// cleanCComment cleans a C/C++ comment.
func cleanCComment(comment string) string {
return cleanGoComment(comment) // Same rules
}
// cleanBlockComment cleans the content of a block comment.
func cleanBlockComment(comment string) string {
lines := strings.Split(comment, "\n")
var cleaned []string
for _, line := range lines {
line = strings.TrimSpace(line)
// Remove leading * from each line (common in block comments)
line = strings.TrimPrefix(line, "*")
line = strings.TrimSpace(line)
cleaned = append(cleaned, line)
}
// Remove empty leading/trailing lines
for len(cleaned) > 0 && cleaned[0] == "" {
cleaned = cleaned[1:]
}
for len(cleaned) > 0 && cleaned[len(cleaned)-1] == "" {
cleaned = cleaned[:len(cleaned)-1]
}
return strings.Join(cleaned, "\n")
}
// parseJSDoc parses a JSDoc comment and extracts tags.
func parseJSDoc(comment string) (string, map[string]string) {
comment = strings.TrimSpace(comment)
// Remove /** and */
comment = strings.TrimPrefix(comment, "/**")
comment = strings.TrimSuffix(comment, "*/")
lines := strings.Split(comment, "\n")
var descLines []string
tags := make(map[string]string)
// Regex for JSDoc tags
tagPattern := regexp.MustCompile(`^\s*\*?\s*@(\w+)\s*(.*)$`)
for _, line := range lines {
line = strings.TrimSpace(line)
line = strings.TrimPrefix(line, "*")
line = strings.TrimSpace(line)
if matches := tagPattern.FindStringSubmatch(line); matches != nil {
tagName := matches[1]
tagValue := strings.TrimSpace(matches[2])
if existing, ok := tags[tagName]; ok {
tags[tagName] = existing + "\n" + tagValue
} else {
tags[tagName] = tagValue
}
} else if line != "" {
descLines = append(descLines, line)
}
}
return strings.Join(descLines, "\n"), tags
}
// parseDoxygen parses a Doxygen comment and extracts tags.
func parseDoxygen(comment string) (string, map[string]string) {
comment = strings.TrimSpace(comment)
// Handle /// and //! style comments
comment = strings.TrimPrefix(comment, "///")
comment = strings.TrimPrefix(comment, "//!")
// Handle /** */ style comments
comment = strings.TrimPrefix(comment, "/**")
comment = strings.TrimSuffix(comment, "*/")
lines := strings.Split(comment, "\n")
var descLines []string
tags := make(map[string]string)
// Regex for Doxygen tags (@param, @return, \param, \return, etc.)
tagPattern := regexp.MustCompile(`^\s*\*?\s*[@\\](\w+)\s*(.*)$`)
for _, line := range lines {
line = strings.TrimSpace(line)
line = strings.TrimPrefix(line, "*")
line = strings.TrimSpace(line)
if matches := tagPattern.FindStringSubmatch(line); matches != nil {
tagName := matches[1]
tagValue := strings.TrimSpace(matches[2])
if existing, ok := tags[tagName]; ok {
tags[tagName] = existing + "\n" + tagValue
} else {
tags[tagName] = tagValue
}
} else if line != "" {
descLines = append(descLines, line)
}
}
return strings.Join(descLines, "\n"), tags
}
// FormatDocComment formats a DocComment for display.
func FormatDocComment(doc *DocComment) string {
if doc == nil || doc.Text == "" {
return ""
}
var sb strings.Builder
sb.WriteString(doc.Text)
if len(doc.Tags) > 0 {
sb.WriteString("\n\n")
// Order: description, params, returns, other
paramOrder := []string{"param", "parameter", "arg", "argument"}
returnOrder := []string{"return", "returns", "retval"}
// Write params first
for _, tagName := range paramOrder {
if val, ok := doc.Tags[tagName]; ok {
for _, line := range strings.Split(val, "\n") {
sb.WriteString("@" + tagName + " " + line + "\n")
}
}
}
// Write returns
for _, tagName := range returnOrder {
if val, ok := doc.Tags[tagName]; ok {
sb.WriteString("@" + tagName + " " + val + "\n")
}
}
// Write remaining tags
written := make(map[string]bool)
for _, t := range paramOrder {
written[t] = true
}
for _, t := range returnOrder {
written[t] = true
}
for tagName, val := range doc.Tags {
if !written[tagName] {
sb.WriteString("@" + tagName + " " + val + "\n")
}
}
}
return strings.TrimSpace(sb.String())
}
// cleanPythonDocstring cleans a Python docstring.
func cleanPythonDocstring(doc string) string {
doc = strings.TrimSpace(doc)
// Remove triple quotes
doc = strings.TrimPrefix(doc, `"""`)
doc = strings.TrimSuffix(doc, `"""`)
doc = strings.TrimPrefix(doc, `'''`)
doc = strings.TrimSuffix(doc, `'''`)
return strings.TrimSpace(doc)
}
+630
View File
@@ -0,0 +1,630 @@
package parser
import (
"context"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
func TestExtractGoDocComment(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
tests := []struct {
name string
code string
nodeKind string
wantText string
wantStyle CommentStyle
}{
{
name: "single line comment",
code: `package main
// Hello says hello
func Hello() {}
`,
nodeKind: "function_declaration",
wantText: "Hello says hello",
wantStyle: CommentStyleLine,
},
{
name: "multi-line comments",
code: `package main
// This is a function
// that does something
// important
func DoSomething() {}
`,
nodeKind: "function_declaration",
wantText: "This is a function\nthat does something\nimportant",
wantStyle: CommentStyleLine,
},
{
name: "block comment",
code: `package main
/* This is a block comment
describing the function */
func BlockCommented() {}
`,
nodeKind: "function_declaration",
wantText: "This is a block comment\ndescribing the function",
wantStyle: CommentStyleBlock,
},
{
name: "doc comment with asterisks",
code: `package main
/*
* This is a properly formatted
* block comment with asterisks
*/
func FormattedBlock() {}
`,
nodeKind: "function_declaration",
wantText: "This is a properly formatted\nblock comment with asterisks",
wantStyle: CommentStyleBlock,
},
{
name: "no comment",
code: `package main
func NoComment() {}
`,
nodeKind: "function_declaration",
wantText: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.Parse(context.Background(), "test.go", []byte(tt.code))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
// Find the target node
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
if targetNode == nil {
t.Fatalf("could not find node of type %s", tt.nodeKind)
}
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangGo)
if tt.wantText == "" {
if doc != nil && doc.Text != "" {
t.Errorf("expected no doc, got %q", doc.Text)
}
return
}
if doc == nil {
t.Fatal("expected doc, got nil")
}
if doc.Text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
}
if doc.Style != tt.wantStyle {
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
}
})
}
}
func TestExtractJSDocComment(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
tests := []struct {
wantTags map[string]string
name string
code string
nodeKind string
wantText string
wantStyle CommentStyle
}{
{
name: "JSDoc comment",
code: `/**
* Adds two numbers together.
* @param a The first number
* @param b The second number
* @returns The sum of a and b
*/
function add(a, b) {
return a + b;
}
`,
nodeKind: "function_declaration",
wantText: "Adds two numbers together.",
wantStyle: CommentStyleJSDoc,
wantTags: map[string]string{
"param": "a The first number\nb The second number",
"returns": "The sum of a and b",
},
},
{
name: "simple line comment",
code: `// This is a simple function
function simple() {}
`,
nodeKind: "function_declaration",
wantText: "This is a simple function",
wantStyle: CommentStyleLine,
},
{
name: "JSDoc with types",
code: `/**
* @param {string} name - The name
* @returns {boolean} True if valid
*/
function validate(name) {}
`,
nodeKind: "function_declaration",
wantText: "",
wantStyle: CommentStyleJSDoc,
wantTags: map[string]string{
"param": "{string} name - The name",
"returns": "{boolean} True if valid",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.Parse(context.Background(), "test.js", []byte(tt.code))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
if targetNode == nil {
t.Fatalf("could not find node of type %s", tt.nodeKind)
}
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangJavaScript)
if doc == nil {
t.Fatal("expected doc, got nil")
}
if doc.Text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
}
if doc.Style != tt.wantStyle {
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
}
if tt.wantTags != nil {
for k, want := range tt.wantTags {
if got := doc.Tags[k]; got != want {
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
}
}
}
})
}
}
func TestExtractPythonDocComment(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
tests := []struct {
name string
code string
nodeKind string
wantText string
wantStyle CommentStyle
}{
{
name: "docstring",
code: `def greet(name):
"""Greet a person by name."""
print(f"Hello, {name}!")
`,
nodeKind: "function_definition",
wantText: "Greet a person by name.",
wantStyle: CommentStyleDocstring,
},
{
name: "multi-line docstring",
code: `def calculate(x, y):
"""
Calculate the sum of two numbers.
Args:
x: First number
y: Second number
Returns:
The sum of x and y
"""
return x + y
`,
nodeKind: "function_definition",
wantText: "Calculate the sum of two numbers.\n\n Args:\n x: First number\n y: Second number\n\n Returns:\n The sum of x and y",
wantStyle: CommentStyleDocstring,
},
{
name: "class docstring",
code: `class MyClass:
"""This is a class description."""
pass
`,
nodeKind: "class_definition",
wantText: "This is a class description.",
wantStyle: CommentStyleDocstring,
},
{
name: "single quote docstring",
code: `def func():
'''Single quote docstring'''
pass
`,
nodeKind: "function_definition",
wantText: "Single quote docstring",
wantStyle: CommentStyleDocstring,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.Parse(context.Background(), "test.py", []byte(tt.code))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
if targetNode == nil {
t.Fatalf("could not find node of type %s", tt.nodeKind)
}
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangPython)
if doc == nil {
t.Fatal("expected doc, got nil")
}
if doc.Text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
}
if doc.Style != tt.wantStyle {
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
}
})
}
}
func TestExtractCDocComment(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
tests := []struct {
wantTags map[string]string
name string
code string
nodeKind string
wantText string
wantStyle CommentStyle
}{
{
name: "Doxygen block comment",
code: `/**
* Adds two numbers.
* @param a First number
* @param b Second number
* @return Sum of a and b
*/
int add(int a, int b) {
return a + b;
}
`,
nodeKind: "function_definition",
wantText: "Adds two numbers.",
wantStyle: CommentStyleDoxygen,
wantTags: map[string]string{
"param": "a First number\nb Second number",
"return": "Sum of a and b",
},
},
{
name: "regular block comment",
code: `/* This is a regular comment */
int regular() { return 0; }
`,
nodeKind: "function_definition",
wantText: "This is a regular comment",
wantStyle: CommentStyleBlock,
},
{
name: "line comment",
code: `// Simple function
int simple() { return 1; }
`,
nodeKind: "function_definition",
wantText: "Simple function",
wantStyle: CommentStyleLine,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.Parse(context.Background(), "test.c", []byte(tt.code))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
if targetNode == nil {
t.Fatalf("could not find node of type %s", tt.nodeKind)
}
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangC)
if doc == nil {
t.Fatal("expected doc, got nil")
}
if doc.Text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
}
if doc.Style != tt.wantStyle {
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
}
if tt.wantTags != nil {
for k, want := range tt.wantTags {
if got := doc.Tags[k]; got != want {
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
}
}
}
})
}
}
func TestParseJSDoc(t *testing.T) {
tests := []struct {
wantTags map[string]string
name string
input string
wantText string
}{
{
name: "complete jsdoc",
input: `/**
* This is a description.
* Multiple lines.
* @param {string} name The name
* @returns {boolean} Result
*/`,
wantText: "This is a description.\nMultiple lines.",
wantTags: map[string]string{
"param": "{string} name The name",
"returns": "{boolean} Result",
},
},
{
name: "empty jsdoc",
input: `/** */`,
wantText: "",
wantTags: map[string]string{},
},
{
name: "only description",
input: `/** Simple description */`,
wantText: "Simple description",
wantTags: map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
text, tags := parseJSDoc(tt.input)
if text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
}
if len(tags) != len(tt.wantTags) {
t.Errorf("tag count mismatch: got %d, want %d", len(tags), len(tt.wantTags))
}
for k, want := range tt.wantTags {
if got := tags[k]; got != want {
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
}
}
})
}
}
func TestParseDoxygen(t *testing.T) {
tests := []struct {
wantTags map[string]string
name string
input string
wantText string
}{
{
name: "doxygen with @ tags",
input: `/**
* Brief description.
* @param x Value
* @return Result
*/`,
wantText: "Brief description.",
wantTags: map[string]string{
"param": "x Value",
"return": "Result",
},
},
{
name: "doxygen with backslash tags",
input: `/**
* Description.
* \param y Input
* \retval Output value
*/`,
wantText: "Description.",
wantTags: map[string]string{
"param": "y Input",
"retval": "Output value",
},
},
{
name: "triple slash",
input: `/// Simple description`,
wantText: "Simple description",
wantTags: map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
text, tags := parseDoxygen(tt.input)
if text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
}
for k, want := range tt.wantTags {
if got := tags[k]; got != want {
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
}
}
})
}
}
func TestFormatDocComment(t *testing.T) {
tests := []struct {
name string
doc *DocComment
want string
}{
{
name: "with tags",
doc: &DocComment{
Text: "This is a function.",
Tags: map[string]string{
"param": "x The value",
"returns": "The result",
},
},
want: "This is a function.\n\n@param x The value\n@returns The result",
},
{
name: "no tags",
doc: &DocComment{
Text: "Simple description.",
Tags: nil,
},
want: "Simple description.",
},
{
name: "nil doc",
doc: nil,
want: "",
},
{
name: "empty text",
doc: &DocComment{
Text: "",
Tags: nil,
},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FormatDocComment(tt.doc)
if got != tt.want {
t.Errorf("mismatch:\ngot: %q\nwant: %q", got, tt.want)
}
})
}
}
func TestDetectCommentStyle(t *testing.T) {
tests := []struct {
input string
want CommentStyle
}{
{"/** JSDoc */", CommentStyleJSDoc},
{"/// Doxygen", CommentStyleDoxygen},
{"//! Doxygen", CommentStyleDoxygen},
{"/* block */", CommentStyleBlock},
{"// line", CommentStyleLine},
{"# hash", CommentStyleHash},
{`"""docstring"""`, CommentStyleDocstring},
{`'''docstring'''`, CommentStyleDocstring},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := detectCommentStyle(tt.input)
if got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
// findNodeByKind finds the first node of the given kind.
func findNodeByKind(root *sitter.Node, nodeType string) *sitter.Node {
if root == nil {
return nil
}
var result *sitter.Node
WalkTree(root, func(n *sitter.Node) bool {
if n.Type() == nodeType {
result = n
return false // stop walking
}
return true
})
return result
}
func TestCleanBlockComment(t *testing.T) {
tests := []struct {
input string
want string
}{
{
input: "\n * Line 1\n * Line 2\n ",
want: "Line 1\nLine 2",
},
{
input: "Simple",
want: "Simple",
},
{
input: "\n\nWith blank lines\n\n",
want: "With blank lines",
},
}
for _, tt := range tests {
t.Run(tt.input[:min(10, len(tt.input))], func(t *testing.T) {
got := cleanBlockComment(tt.input)
if got != tt.want {
t.Errorf("got %q, want %q", got, tt.want)
}
})
}
}
+271
View File
@@ -0,0 +1,271 @@
// Package parser provides Tree-sitter based parsing for multiple languages.
package parser
import (
"context"
"fmt"
"sync"
"github.com/cespare/xxhash/v2"
lru "github.com/hashicorp/golang-lru/v2"
sitter "github.com/smacker/go-tree-sitter"
"github.com/smacker/go-tree-sitter/c"
"github.com/smacker/go-tree-sitter/cpp"
"github.com/smacker/go-tree-sitter/golang"
"github.com/smacker/go-tree-sitter/html"
"github.com/smacker/go-tree-sitter/javascript"
"github.com/smacker/go-tree-sitter/python"
"github.com/smacker/go-tree-sitter/typescript/typescript"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// MaxFileSize is the maximum file size we'll parse (10MB).
const MaxFileSize = 10 * 1024 * 1024
// Registry manages Tree-sitter parsers for different languages.
type Registry struct {
parsers map[protocol.Language]*sitter.Parser
cache *lru.Cache[string, *CachedTree]
mu sync.RWMutex
}
// CachedTree stores a parsed tree with its metadata.
// Content is not stored to reduce memory usage.
type CachedTree struct {
Tree *sitter.Tree
Language protocol.Language
}
// ParseResult contains the result of parsing a file.
type ParseResult struct {
Tree *sitter.Tree
Language protocol.Language
Errors []SyntaxError
Content []byte
}
// SyntaxError represents a syntax error found during parsing.
type SyntaxError struct {
Message string
NodeType string
Location protocol.Location
}
// NewRegistry creates a new parser registry.
func NewRegistry() *Registry {
// Create LRU cache with capacity of 100 trees
cache, err := lru.New[string, *CachedTree](100)
if err != nil {
// LRU.New only errors if size <= 0, which won't happen here
panic(fmt.Sprintf("failed to create LRU cache: %v", err))
}
return &Registry{
parsers: make(map[protocol.Language]*sitter.Parser),
cache: cache,
}
}
// getLanguage returns the Tree-sitter language for a given language.
func getLanguage(lang protocol.Language) (*sitter.Language, error) {
switch lang {
case protocol.LangGo:
return golang.GetLanguage(), nil
case protocol.LangTypeScript:
return typescript.GetLanguage(), nil
case protocol.LangJavaScript:
return javascript.GetLanguage(), nil
case protocol.LangPython:
return python.GetLanguage(), nil
case protocol.LangC:
return c.GetLanguage(), nil
case protocol.LangCpp:
return cpp.GetLanguage(), nil
case protocol.LangHTML:
return html.GetLanguage(), nil
case protocol.LangVue:
// Vue SFC files use HTML-like template syntax, so we use the HTML parser
return html.GetLanguage(), nil
default:
return nil, errors.New(errors.ErrInvalidLanguage, fmt.Sprintf("language %s is not supported", lang)).
WithContext("language", string(lang)).
WithRemediation("Supported languages: Go, TypeScript, JavaScript, Python, C, C++, HTML, Vue")
}
}
// GetParser returns a parser for the given language.
func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
r.mu.RLock()
if p, ok := r.parsers[lang]; ok {
r.mu.RUnlock()
return p, nil
}
r.mu.RUnlock()
// Create new parser
r.mu.Lock()
defer r.mu.Unlock()
// Double-check after acquiring write lock
if p, ok := r.parsers[lang]; ok {
return p, nil
}
sitterLang, err := getLanguage(lang)
if err != nil {
return nil, err
}
parser := sitter.NewParser()
parser.SetLanguage(sitterLang)
r.parsers[lang] = parser
return parser, nil
}
// Parse parses the given content for the specified language.
func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
// Check file size
if len(content) > MaxFileSize {
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
}
// Detect binary files
if isBinary(content) {
return nil, errors.New(errors.ErrParseFailed, "binary file detected").
WithContext("file", filename).
WithRemediation("This appears to be a binary file and cannot be parsed as source code")
}
// Detect language
lang := protocol.DetectLanguage(filename)
if lang == protocol.LangUnknown {
return nil, errors.New(errors.ErrInvalidLanguage, "could not detect language from filename").
WithContext("file", filename).
WithRemediation("Ensure file has a recognized extension (e.g., .go, .ts, .py, .c, .cpp, .html, .vue, .json, .yaml)")
}
// Handle YAML and JSON separately (they don't use tree-sitter)
switch lang {
case protocol.LangYAML:
return r.ParseYAML(ctx, filename, content)
case protocol.LangJSON:
return r.ParseJSON(ctx, filename, content)
}
// Check cache (LRU cache is thread-safe)
hash := contentHash(content)
if cached, ok := r.cache.Get(hash); ok && cached.Language == lang {
errors := extractErrors(cached.Tree.RootNode(), content)
return &ParseResult{
Tree: cached.Tree,
Language: lang,
Errors: errors,
Content: content,
}, nil
}
// Get parser
parser, err := r.GetParser(lang)
if err != nil {
return nil, err
}
// Parse content - tree-sitter parsers are not thread-safe,
// so we need to hold the lock during parsing
r.mu.Lock()
tree, err := parser.ParseCtx(ctx, nil, content)
r.mu.Unlock()
if err != nil {
return nil, errors.NewParseError(string(lang), filename, err)
}
// Extract syntax errors
errors := extractErrors(tree.RootNode(), content)
// Cache result (LRU cache handles eviction automatically)
r.cache.Add(hash, &CachedTree{
Tree: tree,
Language: lang,
})
return &ParseResult{
Tree: tree,
Language: lang,
Errors: errors,
Content: content,
}, nil
}
// extractErrors finds all error nodes in the tree.
func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
var errors []SyntaxError
var walk func(n *sitter.Node)
walk = func(n *sitter.Node) {
if n == nil {
return
}
if n.IsError() || n.IsMissing() {
startPoint := n.StartPoint()
nodeType := "ERROR"
if n.IsMissing() {
nodeType = "MISSING"
}
errors = append(errors, SyntaxError{
Location: protocol.Location{
Line: int(startPoint.Row) + 1,
Column: int(startPoint.Column) + 1,
},
Message: fmt.Sprintf("syntax error: unexpected %s", n.Type()),
NodeType: nodeType,
})
}
for i := 0; i < int(n.ChildCount()); i++ {
walk(n.Child(i))
}
}
walk(node)
return errors
}
// contentHash returns a fast hash of the content for caching.
// Uses xxHash which is 5-10x faster than SHA256 for non-cryptographic purposes.
func contentHash(content []byte) string {
h := xxhash.Sum64(content)
return fmt.Sprintf("%016x", h)
}
// isBinary checks if content appears to be binary.
func isBinary(content []byte) bool {
// Check first 8000 bytes for null bytes
checkLen := min(8000, len(content))
for i := range checkLen {
if content[i] == 0 {
return true
}
}
return false
}
// Close closes all parsers and clears the cache.
func (r *Registry) Close() {
r.mu.Lock()
defer r.mu.Unlock()
for _, p := range r.parsers {
p.Close()
}
r.parsers = make(map[protocol.Language]*sitter.Parser)
// Purge LRU cache
r.cache.Purge()
}
+230
View File
@@ -0,0 +1,230 @@
package parser
import (
"context"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
func TestNewRegistry(t *testing.T) {
r := NewRegistry()
if r == nil {
t.Fatal("expected non-nil registry")
}
defer r.Close()
}
func TestGetParser(t *testing.T) {
r := NewRegistry()
defer r.Close()
tests := []struct {
lang protocol.Language
wantErr bool
}{
{protocol.LangGo, false},
{protocol.LangTypeScript, false},
{protocol.LangJavaScript, false},
{protocol.LangPython, false},
{protocol.LangC, false},
{protocol.LangCpp, false},
{protocol.LangHTML, false},
{protocol.LangVue, false},
{protocol.LangUnknown, true},
}
for _, tt := range tests {
t.Run(string(tt.lang), func(t *testing.T) {
parser, err := r.GetParser(tt.lang)
if tt.wantErr {
if err == nil {
t.Error("expected error")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if parser == nil {
t.Error("expected non-nil parser")
}
}
})
}
}
func TestParse(t *testing.T) {
r := NewRegistry()
defer r.Close()
tests := []struct {
name string
filename string
content string
wantLang protocol.Language
wantErr bool
}{
{
name: "go file",
filename: "test.go",
content: "package main\n\nfunc main() {}\n",
wantLang: protocol.LangGo,
wantErr: false,
},
{
name: "typescript file",
filename: "test.ts",
content: "function hello(): void {}\n",
wantLang: protocol.LangTypeScript,
wantErr: false,
},
{
name: "react tsx file",
filename: "Component.tsx",
content: `import React from 'react';\n\nexport const Button: React.FC = () => <button className="btn">Click</button>;`,
wantLang: protocol.LangTypeScript,
wantErr: false,
},
{
name: "react jsx file",
filename: "Component.jsx",
content: `import React from 'react';\n\nexport const Button = () => <button className="btn">Click</button>;`,
wantLang: protocol.LangJavaScript,
wantErr: false,
},
{
name: "python file",
filename: "test.py",
content: "def hello():\n pass\n",
wantLang: protocol.LangPython,
wantErr: false,
},
{
name: "html file",
filename: "test.html",
content: `<!DOCTYPE html><html><head><title>Test</title></head><body><h1 class="text-xl">Hello</h1></body></html>`,
wantLang: protocol.LangHTML,
wantErr: false,
},
{
name: "vue file",
filename: "Component.vue",
content: `<template><div class="container"><h1>{{ title }}</h1></div></template>`,
wantLang: protocol.LangVue,
wantErr: false,
},
{
name: "unknown file",
filename: "test.txt",
content: "hello world",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
result, err := r.Parse(ctx, tt.filename, []byte(tt.content))
if tt.wantErr {
if err == nil {
t.Error("expected error")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Language != tt.wantLang {
t.Errorf("expected language %s, got %s", tt.wantLang, result.Language)
}
if result.Tree == nil {
t.Error("expected non-nil tree")
}
})
}
}
func TestParseWithSyntaxErrors(t *testing.T) {
r := NewRegistry()
defer r.Close()
// Invalid Go code
content := "package main\n\nfunc main( {}\n" // Missing closing paren
ctx := context.Background()
result, err := r.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Should have parsed (tree-sitter is error-tolerant)
if result.Tree == nil {
t.Error("expected non-nil tree")
}
// Should have detected errors
if len(result.Errors) == 0 {
t.Error("expected syntax errors to be detected")
}
}
func TestIsBinary(t *testing.T) {
tests := []struct {
name string
content []byte
want bool
}{
{
name: "text file",
content: []byte("hello world"),
want: false,
},
{
name: "binary with null byte",
content: []byte{0x68, 0x65, 0x6c, 0x00, 0x6f},
want: true,
},
{
name: "empty file",
content: []byte{},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isBinary(tt.content); got != tt.want {
t.Errorf("isBinary() = %v, want %v", got, tt.want)
}
})
}
}
func TestCaching(t *testing.T) {
r := NewRegistry()
defer r.Close()
content := []byte("package main\n\nfunc main() {}\n")
ctx := context.Background()
// Parse once
result1, err := r.Parse(ctx, "test.go", content)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
// Parse again with same content
result2, err := r.Parse(ctx, "test.go", content)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
// Should return cached tree (same pointer)
if result1.Tree != result2.Tree {
t.Error("expected cached tree to be returned")
}
}
+474
View File
@@ -0,0 +1,474 @@
package parser
import (
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// ExtractSymbols extracts symbols from a parsed tree.
func ExtractSymbols(tree *sitter.Tree, content []byte, lang protocol.Language, filename string) []protocol.Symbol {
if tree == nil {
return nil
}
root := tree.RootNode()
if root == nil {
return nil
}
switch lang {
case protocol.LangGo:
return extractGoSymbols(root, content, filename)
case protocol.LangTypeScript, protocol.LangJavaScript:
return extractJSSymbols(root, content, filename)
case protocol.LangPython:
return extractPythonSymbols(root, content, filename)
case protocol.LangC, protocol.LangCpp:
return extractCSymbols(root, content, filename)
default:
return nil
}
}
// extractGoSymbols extracts symbols from Go code.
func extractGoSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
var symbols []protocol.Symbol
WalkTree(root, func(n *sitter.Node) bool {
var symbol *protocol.Symbol
switch n.Type() {
case "function_declaration":
symbol = extractGoFunction(n, content, filename)
case "method_declaration":
symbol = extractGoMethod(n, content, filename)
case "type_declaration":
symbol = extractGoType(n, content, filename)
case "const_declaration", "var_declaration":
syms := extractGoVarConst(n, content, filename)
symbols = append(symbols, syms...)
return true
}
if symbol != nil {
if doc := ExtractDocComment(n, content, protocol.LangGo); doc != nil {
symbol.Doc = FormatDocComment(doc)
}
symbols = append(symbols, *symbol)
}
return true
})
return symbols
}
func extractGoFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolFunction,
Location: NodeLocation(n, filename),
}
}
func extractGoMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
// Get receiver type
receiver := n.ChildByFieldName("receiver")
receiverType := ""
if receiver != nil {
// Find the type in the receiver
WalkTree(receiver, func(node *sitter.Node) bool {
if node.Type() == "type_identifier" {
receiverType = GetNodeText(node, content)
return false
}
return true
})
}
name := GetNodeText(nameNode, content)
if receiverType != "" {
name = "(" + receiverType + ")." + name
}
return &protocol.Symbol{
Name: name,
Kind: protocol.SymbolMethod,
Location: NodeLocation(n, filename),
}
}
func extractGoType(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
// Find type_spec child
for i := 0; i < int(n.NamedChildCount()); i++ {
child := n.NamedChild(i)
if child != nil && child.Type() == "type_spec" {
nameNode := child.ChildByFieldName("name")
if nameNode == nil {
continue
}
kind := protocol.SymbolType
typeNode := child.ChildByFieldName("type")
if typeNode != nil {
switch typeNode.Type() {
case "struct_type":
kind = protocol.SymbolStruct
case "interface_type":
kind = protocol.SymbolInterface
}
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: kind,
Location: NodeLocation(child, filename),
}
}
}
return nil
}
func extractGoVarConst(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
var symbols []protocol.Symbol
kind := protocol.SymbolVariable
if n.Type() == "const_declaration" {
kind = protocol.SymbolConstant
}
WalkTree(n, func(node *sitter.Node) bool {
if node.Type() == "const_spec" || node.Type() == "var_spec" {
nameNode := node.ChildByFieldName("name")
if nameNode != nil {
symbols = append(symbols, protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: kind,
Location: NodeLocation(node, filename),
})
}
}
return true
})
return symbols
}
// extractJSSymbols extracts symbols from JavaScript/TypeScript code.
func extractJSSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
var symbols []protocol.Symbol
WalkTree(root, func(n *sitter.Node) bool {
var symbol *protocol.Symbol
switch n.Type() {
case "function_declaration":
symbol = extractJSFunction(n, content, filename)
case "class_declaration":
symbol = extractJSClass(n, content, filename)
case "method_definition":
symbol = extractJSMethod(n, content, filename)
case "lexical_declaration", "variable_declaration":
syms := extractJSVariable(n, content, filename)
symbols = append(symbols, syms...)
return true
case "interface_declaration":
symbol = extractTSInterface(n, content, filename)
case "type_alias_declaration":
symbol = extractTSTypeAlias(n, content, filename)
}
if symbol != nil {
if doc := ExtractDocComment(n, content, protocol.LangJavaScript); doc != nil {
symbol.Doc = FormatDocComment(doc)
}
symbols = append(symbols, *symbol)
}
return true
})
return symbols
}
func extractJSFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolFunction,
Location: NodeLocation(n, filename),
}
}
func extractJSClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolClass,
Location: NodeLocation(n, filename),
}
}
func extractJSMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolMethod,
Location: NodeLocation(n, filename),
}
}
func extractJSVariable(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
var symbols []protocol.Symbol
WalkTree(n, func(node *sitter.Node) bool {
if node.Type() == "variable_declarator" {
nameNode := node.ChildByFieldName("name")
if nameNode != nil {
symbols = append(symbols, protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolVariable,
Location: NodeLocation(node, filename),
})
}
}
return true
})
return symbols
}
func extractTSInterface(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolInterface,
Location: NodeLocation(n, filename),
}
}
func extractTSTypeAlias(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolType,
Location: NodeLocation(n, filename),
}
}
// extractPythonSymbols extracts symbols from Python code.
func extractPythonSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
var symbols []protocol.Symbol
WalkTree(root, func(n *sitter.Node) bool {
var symbol *protocol.Symbol
switch n.Type() {
case "function_definition":
symbol = extractPythonFunction(n, content, filename)
case "class_definition":
symbol = extractPythonClass(n, content, filename)
}
if symbol != nil {
if doc := ExtractDocComment(n, content, protocol.LangPython); doc != nil {
symbol.Doc = FormatDocComment(doc)
}
symbols = append(symbols, *symbol)
}
return true
})
return symbols
}
func extractPythonFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
// Check if this is a method (inside a class)
parent := n.Parent()
kind := protocol.SymbolFunction
if parent != nil && parent.Type() == "block" {
grandparent := parent.Parent()
if grandparent != nil && grandparent.Type() == "class_definition" {
kind = protocol.SymbolMethod
}
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: kind,
Location: NodeLocation(n, filename),
}
}
func extractPythonClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolClass,
Location: NodeLocation(n, filename),
}
}
// extractCSymbols extracts symbols from C/C++ code.
func extractCSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
var symbols []protocol.Symbol
WalkTree(root, func(n *sitter.Node) bool {
var symbol *protocol.Symbol
switch n.Type() {
case "function_definition":
symbol = extractCFunction(n, content, filename)
case "struct_specifier":
symbol = extractCStruct(n, content, filename)
case "class_specifier":
symbol = extractCppClass(n, content, filename)
case "declaration":
// Could be function declaration or variable
if hasFunctionDeclarator(n) {
symbol = extractCFunctionDecl(n, content, filename)
}
}
if symbol != nil {
if doc := ExtractDocComment(n, content, protocol.LangC); doc != nil {
symbol.Doc = FormatDocComment(doc)
}
symbols = append(symbols, *symbol)
}
return true
})
return symbols
}
func extractCFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
declarator := n.ChildByFieldName("declarator")
if declarator == nil {
return nil
}
// Find the function name within the declarator
var name string
WalkTree(declarator, func(node *sitter.Node) bool {
if node.Type() == "identifier" {
name = GetNodeText(node, content)
return false
}
return true
})
if name == "" {
return nil
}
return &protocol.Symbol{
Name: name,
Kind: protocol.SymbolFunction,
Location: NodeLocation(n, filename),
}
}
func extractCStruct(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolStruct,
Location: NodeLocation(n, filename),
}
}
func extractCppClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
nameNode := n.ChildByFieldName("name")
if nameNode == nil {
return nil
}
return &protocol.Symbol{
Name: GetNodeText(nameNode, content),
Kind: protocol.SymbolClass,
Location: NodeLocation(n, filename),
}
}
func extractCFunctionDecl(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
declarator := n.ChildByFieldName("declarator")
if declarator == nil {
return nil
}
var name string
WalkTree(declarator, func(node *sitter.Node) bool {
if node.Type() == "identifier" {
name = GetNodeText(node, content)
return false
}
return true
})
if name == "" {
return nil
}
return &protocol.Symbol{
Name: name,
Kind: protocol.SymbolFunction,
Location: NodeLocation(n, filename),
}
}
func hasFunctionDeclarator(n *sitter.Node) bool {
found := false
WalkTree(n, func(node *sitter.Node) bool {
if node.Type() == "function_declarator" {
found = true
return false
}
return true
})
return found
}
+226
View File
@@ -0,0 +1,226 @@
package parser
import (
"context"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
func TestExtractGoSymbols(t *testing.T) {
r := NewRegistry()
defer r.Close()
content := `package main
// Hello prints a greeting
func Hello() {
println("hello")
}
// Server handles requests
type Server struct {
Port int
}
// Start starts the server
func (s *Server) Start() error {
return nil
}
const MaxConnections = 100
var globalVar = "test"
`
ctx := context.Background()
result, err := r.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangGo, "test.go")
expectedSymbols := map[string]protocol.SymbolKind{
"Hello": protocol.SymbolFunction,
"Server": protocol.SymbolStruct,
"(Server).Start": protocol.SymbolMethod,
"MaxConnections": protocol.SymbolConstant,
"globalVar": protocol.SymbolVariable,
}
found := make(map[string]bool)
for _, sym := range symbols {
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
found[sym.Name] = true
if sym.Kind != expectedKind {
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
}
}
}
for name := range expectedSymbols {
if !found[name] {
t.Errorf("expected to find symbol %s", name)
}
}
}
func TestExtractJSSymbols(t *testing.T) {
r := NewRegistry()
defer r.Close()
content := `
function greet(name) {
console.log("Hello, " + name);
}
class User {
constructor(name) {
this.name = name;
}
getName() {
return this.name;
}
}
const MAX_USERS = 100;
let currentUser = null;
`
ctx := context.Background()
result, err := r.Parse(ctx, "test.js", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangJavaScript, "test.js")
expectedSymbols := map[string]protocol.SymbolKind{
"greet": protocol.SymbolFunction,
"User": protocol.SymbolClass,
"MAX_USERS": protocol.SymbolVariable,
"currentUser": protocol.SymbolVariable,
}
found := make(map[string]bool)
for _, sym := range symbols {
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
found[sym.Name] = true
if sym.Kind != expectedKind {
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
}
}
}
for name := range expectedSymbols {
if !found[name] {
t.Errorf("expected to find symbol %s", name)
}
}
}
func TestExtractPythonSymbols(t *testing.T) {
r := NewRegistry()
defer r.Close()
content := `
def greet(name):
"""Greet a person by name."""
print(f"Hello, {name}")
class User:
"""Represents a user."""
def __init__(self, name):
self.name = name
def get_name(self):
return self.name
`
ctx := context.Background()
result, err := r.Parse(ctx, "test.py", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangPython, "test.py")
expectedSymbols := map[string]protocol.SymbolKind{
"greet": protocol.SymbolFunction,
"User": protocol.SymbolClass,
"__init__": protocol.SymbolMethod,
"get_name": protocol.SymbolMethod,
}
found := make(map[string]bool)
for _, sym := range symbols {
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
found[sym.Name] = true
if sym.Kind != expectedKind {
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
}
}
}
for name := range expectedSymbols {
if !found[name] {
t.Errorf("expected to find symbol %s", name)
}
}
}
func TestExtractCSymbols(t *testing.T) {
r := NewRegistry()
defer r.Close()
content := `
#include <stdio.h>
struct Point {
int x;
int y;
};
void print_point(struct Point p) {
printf("(%d, %d)\n", p.x, p.y);
}
int main() {
struct Point p = {1, 2};
print_point(p);
return 0;
}
`
ctx := context.Background()
result, err := r.Parse(ctx, "test.c", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangC, "test.c")
// Note: C symbol extraction is complex, checking for at least main and Point
expectedSymbols := map[string]protocol.SymbolKind{
"Point": protocol.SymbolStruct,
"main": protocol.SymbolFunction,
}
found := make(map[string]bool)
for _, sym := range symbols {
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
found[sym.Name] = true
if sym.Kind != expectedKind {
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
}
}
}
for name := range expectedSymbols {
if !found[name] {
t.Errorf("expected to find symbol %s", name)
}
}
}
+195
View File
@@ -0,0 +1,195 @@
// Package parser provides YAML and JSON parsing with AST support.
package parser
import (
"context"
"encoding/json"
"fmt"
"gopkg.in/yaml.v3"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// YAMLNode wraps yaml.Node to provide tree-sitter-like interface
type YAMLNode struct {
*yaml.Node
Content []byte
}
// JSONNode represents a JSON AST node
type JSONNode struct {
Value any
Type string
Children []*JSONNode
Line int
Column int
}
// ParseYAML parses YAML content and returns a tree-sitter-compatible result
func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
// Check file size
if len(content) > MaxFileSize {
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
}
// Parse YAML
var root yaml.Node
if err := yaml.Unmarshal(content, &root); err != nil {
return nil, errors.NewParseError("yaml", filename, err)
}
// Extract syntax errors from YAML parse
syntaxErrors := extractYAMLErrors()
// Create a pseudo tree-sitter tree for compatibility
// We'll use nil for the tree since YAML doesn't use tree-sitter
return &ParseResult{
Tree: nil, // YAML uses yaml.Node instead
Language: protocol.LangYAML,
Errors: syntaxErrors,
Content: content,
}, nil
}
// ParseJSON parses JSON content and returns a tree-sitter-compatible result
func (r *Registry) ParseJSON(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
// Check file size
if len(content) > MaxFileSize {
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
}
// Parse JSON to validate syntax
var jsonData any
if err := json.Unmarshal(content, &jsonData); err != nil {
return nil, errors.NewParseError("json", filename, err)
}
// JSON parsing succeeded, no syntax errors
return &ParseResult{
Tree: nil, // JSON uses native Go structures
Language: protocol.LangJSON,
Errors: []SyntaxError{},
Content: content,
}, nil
}
// extractYAMLErrors extracts errors from YAML nodes
func extractYAMLErrors() []SyntaxError {
// YAML parser already validates during unmarshal
// If we got here, there are no syntax errors
// However, we could add semantic validation here in the future
return []SyntaxError{}
}
// WalkYAML walks a YAML AST and calls fn for each node
func WalkYAML(node *yaml.Node, fn func(*yaml.Node) bool) {
if node == nil || !fn(node) {
return
}
for _, child := range node.Content {
WalkYAML(child, fn)
}
}
// GetYAMLNodeText returns the text representation of a YAML node
func GetYAMLNodeText(node *yaml.Node) string {
if node == nil {
return ""
}
switch node.Kind {
case yaml.DocumentNode:
if len(node.Content) > 0 {
return GetYAMLNodeText(node.Content[0])
}
return ""
case yaml.MappingNode:
return node.Value
case yaml.SequenceNode:
return node.Value
case yaml.ScalarNode:
return node.Value
case yaml.AliasNode:
return node.Value
default:
return ""
}
}
// GetYAMLNodeLocation returns the location of a YAML node
func GetYAMLNodeLocation(node *yaml.Node) protocol.Location {
if node == nil {
return protocol.Location{Line: 1, Column: 1}
}
return protocol.Location{
Line: node.Line,
Column: node.Column,
}
}
// QueryYAML performs a simple query on YAML content
// Example: "$.metadata.name" to find the name field in metadata
func QueryYAML(content []byte, query string) ([]*yaml.Node, error) {
var root yaml.Node
if err := yaml.Unmarshal(content, &root); err != nil {
return nil, fmt.Errorf("failed to parse YAML: %w", err)
}
// Simple path-based query implementation
// This is a basic implementation - can be extended with more sophisticated queries
var results []*yaml.Node
WalkYAML(&root, func(node *yaml.Node) bool {
if node.Value == query || node.Tag == query {
results = append(results, node)
}
return true
})
return results, nil
}
// QueryJSON performs a simple query on JSON content
func QueryJSON(content []byte, query string) ([]any, error) {
var data any
if err := json.Unmarshal(content, &data); err != nil {
return nil, fmt.Errorf("failed to parse JSON: %w", err)
}
// Basic implementation - can be extended with JSONPath support
var results []any
// For now, just validate that it's valid JSON
results = append(results, data)
return results, nil
}
// ValidateYAML validates YAML content without parsing to full AST
func ValidateYAML(content []byte) error {
var node yaml.Node
if err := yaml.Unmarshal(content, &node); err != nil {
return fmt.Errorf("YAML validation failed: %w", err)
}
return nil
}
// ValidateJSON validates JSON content
func ValidateJSON(content []byte) error {
var data any
if err := json.Unmarshal(content, &data); err != nil {
return fmt.Errorf("JSON validation failed: %w", err)
}
return nil
}
// ToSitterTree is a placeholder that returns nil for YAML/JSON
// These formats don't use tree-sitter, but we keep this for interface compatibility
func (yn *YAMLNode) ToSitterTree() *sitter.Tree {
return nil
}
+283
View File
@@ -0,0 +1,283 @@
package parser
import (
"context"
"testing"
"gopkg.in/yaml.v3"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
func TestParseYAML(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
tests := []struct {
name string
content string
shouldError bool
}{
{
name: "valid simple YAML",
content: `name: test
version: 1.0.0
enabled: true`,
shouldError: false,
},
{
name: "valid nested YAML",
content: `metadata:
name: test-app
namespace: default
spec:
replicas: 3
selector:
matchLabels:
app: test`,
shouldError: false,
},
{
name: "valid list YAML",
content: `items:
- name: item1
value: 100
- name: item2
value: 200`,
shouldError: false,
},
{
name: "invalid YAML - bad syntax",
content: `name: test\n bad: indent\n wrong: [unclosed`,
shouldError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.ParseYAML(context.Background(), "test.yaml", []byte(tt.content))
if tt.shouldError {
if err == nil {
t.Error("expected error but got none")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if result == nil {
t.Error("expected result but got nil")
return
}
if result.Language != protocol.LangYAML {
t.Errorf("expected language YAML, got %s", result.Language)
}
if len(result.Errors) > 0 {
t.Errorf("expected no syntax errors, got %d", len(result.Errors))
}
})
}
}
func TestParseJSON(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
tests := []struct {
name string
content string
shouldError bool
}{
{
name: "valid simple JSON",
content: `{"name": "test", "version": "1.0.0", "enabled": true}`,
shouldError: false,
},
{
name: "valid nested JSON",
content: `{
"metadata": {
"name": "test-app",
"namespace": "default"
},
"spec": {
"replicas": 3,
"selector": {
"matchLabels": {
"app": "test"
}
}
}
}`,
shouldError: false,
},
{
name: "valid array JSON",
content: `[{"name": "item1", "value": 100}, {"name": "item2", "value": 200}]`,
shouldError: false,
},
{
name: "invalid JSON - unclosed brace",
content: `{"name": "test", "value": 100`,
shouldError: true,
},
{
name: "invalid JSON - trailing comma",
content: `{"name": "test", "value": 100,}`,
shouldError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.ParseJSON(context.Background(), "test.json", []byte(tt.content))
if tt.shouldError {
if err == nil {
t.Error("expected error but got none")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if result == nil {
t.Error("expected result but got nil")
return
}
if result.Language != protocol.LangJSON {
t.Errorf("expected language JSON, got %s", result.Language)
}
if len(result.Errors) > 0 {
t.Errorf("expected no syntax errors, got %d", len(result.Errors))
}
})
}
}
func TestRegistryParse_YAML_JSON(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
yamlContent := []byte(`name: test
version: 1.0.0`)
jsonContent := []byte(`{"name": "test", "version": "1.0.0"}`)
// Test YAML through main Parse method
yamlResult, err := registry.Parse(context.Background(), "config.yaml", yamlContent)
if err != nil {
t.Errorf("failed to parse YAML: %v", err)
}
if yamlResult.Language != protocol.LangYAML {
t.Errorf("expected YAML language, got %s", yamlResult.Language)
}
// Test JSON through main Parse method
jsonResult, err := registry.Parse(context.Background(), "config.json", jsonContent)
if err != nil {
t.Errorf("failed to parse JSON: %v", err)
}
if jsonResult.Language != protocol.LangJSON {
t.Errorf("expected JSON language, got %s", jsonResult.Language)
}
// Test .yml extension
ymlResult, err := registry.Parse(context.Background(), "config.yml", yamlContent)
if err != nil {
t.Errorf("failed to parse .yml: %v", err)
}
if ymlResult.Language != protocol.LangYAML {
t.Errorf("expected YAML language for .yml extension, got %s", ymlResult.Language)
}
}
func TestWalkYAML(t *testing.T) {
content := []byte(`metadata:
name: test
labels:
app: myapp
env: prod`)
var root yaml.Node
if err := yaml.Unmarshal(content, &root); err != nil {
t.Fatalf("failed to parse YAML: %v", err)
}
nodeCount := 0
WalkYAML(&root, func(node *yaml.Node) bool {
nodeCount++
return true
})
if nodeCount == 0 {
t.Error("expected to visit nodes, but count is 0")
}
}
func TestValidateYAML(t *testing.T) {
tests := []struct {
name string
content []byte
shouldError bool
}{
{
name: "valid YAML",
content: []byte("name: test\nvalue: 100"),
shouldError: false,
},
{
name: "invalid YAML",
content: []byte("name: test\n bad:\n[unclosed"),
shouldError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateYAML(tt.content)
if (err != nil) != tt.shouldError {
t.Errorf("ValidateYAML() error = %v, shouldError = %v", err, tt.shouldError)
}
})
}
}
func TestValidateJSON(t *testing.T) {
tests := []struct {
name string
content []byte
shouldError bool
}{
{
name: "valid JSON",
content: []byte(`{"name": "test", "value": 100}`),
shouldError: false,
},
{
name: "invalid JSON",
content: []byte(`{"name": "test", "value": 100`),
shouldError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateJSON(tt.content)
if (err != nil) != tt.shouldError {
t.Errorf("ValidateJSON() error = %v, shouldError = %v", err, tt.shouldError)
}
})
}
}
+538
View File
@@ -0,0 +1,538 @@
// Package query implements a hybrid AST query language with pattern matching.
package query
import (
"context"
"fmt"
"regexp"
"strings"
"sync"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// Global regex cache for compiled patterns (thread-safe)
var regexCache sync.Map // string -> *regexp.Regexp
// compileRegex compiles a regex pattern with caching for performance.
// Cached patterns avoid repeated compilation overhead (10-50x speedup).
// Thread-safe: uses LoadOrStore to prevent race conditions.
func compileRegex(pattern string) (*regexp.Regexp, error) {
// Check cache first
if cached, ok := regexCache.Load(pattern); ok {
return cached.(*regexp.Regexp), nil
}
// Compile regex
re, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
// Try to store - if another goroutine already stored it, use theirs
// This prevents race conditions where multiple goroutines compile the same pattern
actual, _ := regexCache.LoadOrStore(pattern, re)
return actual.(*regexp.Regexp), nil
}
// ASTQuery defines a query for matching AST patterns.
type ASTQuery struct {
Pattern string `json:"pattern"` // code pattern with $VAR placeholders
Language string `json:"language"` // required
Filters QueryFilters `json:"filters,omitempty"`
}
// QueryFilters provide additional filtering criteria.
type QueryFilters struct {
HasChild *ASTQuery `json:"has_child,omitempty"`
HasParent *ASTQuery `json:"has_parent,omitempty"`
NameMatches string `json:"name_matches,omitempty"`
NameExact string `json:"name_exact,omitempty"`
InFile string `json:"in_file,omitempty"`
NotInFile string `json:"not_in_file,omitempty"`
KindIn []string `json:"kind_in,omitempty"`
}
// MatchResult represents a single match from a query.
type MatchResult struct {
Node *sitter.Node
Captures map[string]CapturedNode
File string
Text string
Location protocol.Location
}
// CapturedNode represents a captured node or nodes.
type CapturedNode struct {
Text string
Nodes []*sitter.Node
}
// CaptureType indicates the type of capture.
type CaptureType int
const (
CaptureSingle CaptureType = iota // $NAME - single node
CaptureMultiple // $$$NAME - multiple nodes
CaptureWildcard // $_ - wildcard (don't capture)
)
// Capture represents a placeholder in a pattern.
type Capture struct {
Name string
Type CaptureType
Position int // position in the pattern
}
// ParsedPattern represents a parsed code pattern.
type ParsedPattern struct {
Original string
Template string
Captures []Capture
}
// Matcher performs AST pattern matching.
type Matcher struct {
registry *parser.Registry
}
// NewMatcher creates a new pattern matcher.
func NewMatcher(registry *parser.Registry) *Matcher {
return &Matcher{registry: registry}
}
// ParsePattern parses a pattern string and extracts captures.
func ParsePattern(pattern string) (*ParsedPattern, error) {
if pattern == "" {
return nil, fmt.Errorf("empty pattern")
}
var captures []Capture
template := pattern
captureID := 0
// Find all captures: $$$ (multi), $_ (wildcard), $NAME (single)
// Order matters: check $$$ first
multiRe := regexp.MustCompile(`\$\$\$([A-Za-z_][A-Za-z0-9_]*)`)
wildcardRe := regexp.MustCompile(`\$_`)
singleRe := regexp.MustCompile(`\$([A-Za-z_][A-Za-z0-9_]*)`)
// Extract multi-node captures ($$$NAME)
for _, match := range multiRe.FindAllStringSubmatchIndex(pattern, -1) {
name := pattern[match[2]:match[3]]
captures = append(captures, Capture{
Name: name,
Type: CaptureMultiple,
Position: match[0],
})
}
// Replace multi-captures with placeholder identifiers
template = multiRe.ReplaceAllStringFunc(template, func(s string) string {
captureID++
return fmt.Sprintf("__multi_%d__", captureID)
})
// Extract wildcards ($_)
for _, match := range wildcardRe.FindAllStringIndex(pattern, -1) {
captures = append(captures, Capture{
Name: "_",
Type: CaptureWildcard,
Position: match[0],
})
}
// Replace wildcards with placeholder identifiers
template = wildcardRe.ReplaceAllStringFunc(template, func(s string) string {
captureID++
return fmt.Sprintf("__wild_%d__", captureID)
})
// Extract single-node captures ($NAME) - exclude those that are part of $$$NAME
// Check which $NAME patterns are not preceded by $$
remaining := template
for _, match := range singleRe.FindAllStringSubmatchIndex(remaining, -1) {
name := remaining[match[2]:match[3]]
// Skip if this looks like our placeholder
if strings.HasPrefix(name, "_multi_") || strings.HasPrefix(name, "_wild_") {
continue
}
captures = append(captures, Capture{
Name: name,
Type: CaptureSingle,
Position: match[0],
})
}
// Replace single captures with placeholder identifiers
template = singleRe.ReplaceAllStringFunc(template, func(s string) string {
name := strings.TrimPrefix(s, "$")
if strings.HasPrefix(name, "_multi_") || strings.HasPrefix(name, "_wild_") {
return s // keep our placeholders as is
}
captureID++
return fmt.Sprintf("__single_%d__", captureID)
})
return &ParsedPattern{
Original: pattern,
Captures: captures,
Template: template,
}, nil
}
// Match executes a query against a parsed tree.
func (m *Matcher) Match(ctx context.Context, query *ASTQuery, tree *sitter.Tree, content []byte, filename string) ([]MatchResult, error) {
if query.Pattern == "" {
return nil, fmt.Errorf("query pattern is required")
}
lang := protocol.Language(query.Language)
if lang == "" || lang == protocol.LangUnknown {
return nil, fmt.Errorf("valid language is required")
}
// Parse the pattern
parsed, err := ParsePattern(query.Pattern)
if err != nil {
return nil, fmt.Errorf("invalid pattern: %w", err)
}
var results []MatchResult
// Walk the tree and find matches
root := tree.RootNode()
if root == nil {
return results, nil
}
parser.WalkTree(root, func(n *sitter.Node) bool {
// Check for context cancellation
select {
case <-ctx.Done():
return false
default:
}
// Try to match this node against the pattern
if matched, captures := matchNode(n, parsed, content); matched {
// Apply filters
if !passesFilters(n, query.Filters, content) {
return true // continue walking
}
startPoint := n.StartPoint()
results = append(results, MatchResult{
Node: n,
Captures: captures,
File: filename,
Location: protocol.Location{
Line: int(startPoint.Row) + 1,
Column: int(startPoint.Column) + 1,
},
Text: parser.GetNodeText(n, content),
})
}
return true
})
return results, nil
}
// matchNode attempts to match a node against a parsed pattern.
// This is a simplified matcher that looks for structural similarity.
func matchNode(node *sitter.Node, pattern *ParsedPattern, content []byte) (bool, map[string]CapturedNode) {
if node == nil {
return false, nil
}
captures := make(map[string]CapturedNode)
// Use pattern keyword matching as a heuristic to find matching nodes
// A full implementation would parse both pattern and node and compare AST structure
matched := matchPatternHeuristic(node, pattern, content, captures)
return matched, captures
}
// matchPatternHeuristic uses heuristics to match patterns.
// This is a simplified implementation that matches based on node type and structure.
func matchPatternHeuristic(node *sitter.Node, pattern *ParsedPattern, content []byte, captures map[string]CapturedNode) bool {
patternLower := strings.ToLower(pattern.Original)
nodeType := node.Type()
// Match function patterns
if strings.Contains(patternLower, "func ") || strings.Contains(patternLower, "function ") {
if nodeType != "function_declaration" && nodeType != "method_declaration" && nodeType != "function_definition" {
return false
}
extractFunctionCaptures(node, pattern.Captures, content, captures)
return true
}
// Match class patterns
if strings.Contains(patternLower, "class ") {
if nodeType != "class_declaration" && nodeType != "class_definition" {
return false
}
extractClassCaptures(node, pattern.Captures, content, captures)
return true
}
// Match struct patterns (Go, C, C++)
if strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ") && strings.Contains(patternLower, "struct") {
if nodeType != "type_declaration" && nodeType != "struct_specifier" {
return false
}
extractStructCaptures(node, pattern.Captures, content, captures)
return true
}
// Match interface patterns (Go, TypeScript)
if strings.Contains(patternLower, "interface ") {
if nodeType != "interface_declaration" && nodeType != "type_declaration" {
return false
}
extractInterfaceCaptures(node, pattern.Captures, content, captures)
return true
}
return false
}
// extractFunctionCaptures extracts captures from a function node.
func extractFunctionCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
for _, cap := range capturesDef {
switch cap.Name {
case "NAME", "name":
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{nameNode},
Text: parser.GetNodeText(nameNode, content),
}
}
case "ARGS", "args", "PARAMS", "params":
if paramsNode := node.ChildByFieldName("parameters"); paramsNode != nil {
var paramNodes []*sitter.Node
for i := 0; i < int(paramsNode.NamedChildCount()); i++ {
paramNodes = append(paramNodes, paramsNode.NamedChild(i))
}
captures[cap.Name] = CapturedNode{
Nodes: paramNodes,
Text: parser.GetNodeText(paramsNode, content),
}
}
case "BODY", "body":
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{bodyNode},
Text: parser.GetNodeText(bodyNode, content),
}
}
case "RETURN", "return", "RESULT", "result":
if resultNode := node.ChildByFieldName("result"); resultNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{resultNode},
Text: parser.GetNodeText(resultNode, content),
}
}
}
}
}
// extractClassCaptures extracts captures from a class node.
func extractClassCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
for _, cap := range capturesDef {
switch cap.Name {
case "NAME", "name":
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{nameNode},
Text: parser.GetNodeText(nameNode, content),
}
}
case "BODY", "body":
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{bodyNode},
Text: parser.GetNodeText(bodyNode, content),
}
}
case "EXTENDS", "extends", "SUPERCLASS", "superclass":
if extendsNode := node.ChildByFieldName("superclass"); extendsNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{extendsNode},
Text: parser.GetNodeText(extendsNode, content),
}
}
}
}
}
// extractStructCaptures extracts captures from a struct node.
func extractStructCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
for _, cap := range capturesDef {
switch cap.Name {
case "NAME", "name":
// For Go type_declaration, we need to look at the type_spec child
if node.Type() == "type_declaration" {
for i := 0; i < int(node.NamedChildCount()); i++ {
child := node.NamedChild(i)
if child != nil && child.Type() == "type_spec" {
if nameNode := child.ChildByFieldName("name"); nameNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{nameNode},
Text: parser.GetNodeText(nameNode, content),
}
}
}
}
} else if nameNode := node.ChildByFieldName("name"); nameNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{nameNode},
Text: parser.GetNodeText(nameNode, content),
}
}
case "FIELDS", "fields", "BODY", "body":
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{bodyNode},
Text: parser.GetNodeText(bodyNode, content),
}
}
}
}
}
// extractInterfaceCaptures extracts captures from an interface node.
func extractInterfaceCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
for _, cap := range capturesDef {
switch cap.Name {
case "NAME", "name":
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{nameNode},
Text: parser.GetNodeText(nameNode, content),
}
}
case "BODY", "body", "METHODS", "methods":
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
captures[cap.Name] = CapturedNode{
Nodes: []*sitter.Node{bodyNode},
Text: parser.GetNodeText(bodyNode, content),
}
}
}
}
}
// passesFilters checks if a node passes all the specified filters.
func passesFilters(node *sitter.Node, filters QueryFilters, content []byte) bool {
// Name regex filter (uses cached compilation)
if filters.NameMatches != "" {
nameNode := node.ChildByFieldName("name")
if nameNode == nil {
return false
}
name := parser.GetNodeText(nameNode, content)
re, err := compileRegex(filters.NameMatches)
if err != nil {
return false
}
if !re.MatchString(name) {
return false
}
}
// Exact name filter
if filters.NameExact != "" {
nameNode := node.ChildByFieldName("name")
if nameNode == nil {
return false
}
name := parser.GetNodeText(nameNode, content)
if name != filters.NameExact {
return false
}
}
// Kind filter
if len(filters.KindIn) > 0 {
nodeType := node.Type()
found := false
for _, kind := range filters.KindIn {
if nodeType == kind {
found = true
break
}
}
if !found {
return false
}
}
return true
}
// FormatResults formats match results for display.
func FormatResults(results []MatchResult, maxResults int) string {
if len(results) == 0 {
return "No matches found."
}
var sb strings.Builder
sb.WriteString(fmt.Sprintf("Found %d match(es):\n\n", len(results)))
displayCount := len(results)
truncated := false
if maxResults > 0 && displayCount > maxResults {
displayCount = maxResults
truncated = true
}
for i := 0; i < displayCount; i++ {
r := results[i]
nodeType := "unknown"
if r.Node != nil {
nodeType = r.Node.Type()
}
sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType))
// Truncate very long text
text := r.Text
if len(text) > 500 {
text = text[:500] + "..."
}
sb.WriteString("```\n")
sb.WriteString(text)
sb.WriteString("\n```\n")
// Show captures
if len(r.Captures) > 0 {
sb.WriteString("Captures: ")
first := true
for name, cap := range r.Captures {
if !first {
sb.WriteString(", ")
}
first = false
capText := cap.Text
if len(capText) > 50 {
capText = capText[:50] + "..."
}
sb.WriteString(fmt.Sprintf("$%s=%s", name, capText))
}
sb.WriteString("\n")
}
sb.WriteString("\n")
}
if truncated {
sb.WriteString(fmt.Sprintf("... and %d more matches (truncated)\n", len(results)-maxResults))
}
return sb.String()
}
+559
View File
@@ -0,0 +1,559 @@
package query
import (
"context"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
func TestParsePattern(t *testing.T) {
tests := []struct {
name string
pattern string
captureNames []string
captureTypes []CaptureType
wantCaptures int
wantErr bool
}{
{
name: "empty pattern",
pattern: "",
wantErr: true,
wantCaptures: 0,
},
{
name: "single capture",
pattern: "func $NAME() {}",
wantErr: false,
wantCaptures: 1,
captureNames: []string{"NAME"},
captureTypes: []CaptureType{CaptureSingle},
},
{
name: "multiple single captures",
pattern: "func $NAME($ARGS) $RETURN",
wantErr: false,
wantCaptures: 3,
captureNames: []string{"NAME", "ARGS", "RETURN"},
captureTypes: []CaptureType{CaptureSingle, CaptureSingle, CaptureSingle},
},
{
name: "multi-node capture",
pattern: "func $NAME($$$ARGS) { $$$BODY }",
wantErr: false,
wantCaptures: 3,
captureNames: []string{"ARGS", "BODY", "NAME"},
captureTypes: []CaptureType{CaptureMultiple, CaptureMultiple, CaptureSingle},
},
{
name: "wildcard capture",
pattern: "func $NAME($_) {}",
wantErr: false,
wantCaptures: 2,
captureNames: []string{"NAME", "_"},
captureTypes: []CaptureType{CaptureSingle, CaptureWildcard},
},
{
name: "no captures",
pattern: "func main() {}",
wantErr: false,
wantCaptures: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parsed, err := ParsePattern(tt.pattern)
if tt.wantErr {
if err == nil {
t.Error("expected error")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(parsed.Captures) != tt.wantCaptures {
t.Errorf("expected %d captures, got %d", tt.wantCaptures, len(parsed.Captures))
}
// Check capture names (order may vary)
if tt.captureNames != nil {
captureMap := make(map[string]CaptureType)
for _, cap := range parsed.Captures {
captureMap[cap.Name] = cap.Type
}
for i, name := range tt.captureNames {
if _, ok := captureMap[name]; !ok {
t.Errorf("expected capture %s not found", name)
}
if captureMap[name] != tt.captureTypes[i] {
t.Errorf("capture %s: expected type %v, got %v", name, tt.captureTypes[i], captureMap[name])
}
}
}
})
}
}
func TestMatchGoFunctions(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
content := `package main
func Hello() {
println("hello")
}
func Greet(name string) error {
println("hello", name)
return nil
}
type Server struct {
Port int
}
func (s *Server) Start() error {
return nil
}
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantMatches int
}{
{
name: "match all functions",
query: &ASTQuery{
Pattern: "func $NAME($$$ARGS) { $$$BODY }",
Language: "go",
},
wantMatches: 3, // Hello, Greet, Start
},
{
name: "match functions starting with H",
query: &ASTQuery{
Pattern: "func $NAME() {}",
Language: "go",
Filters: QueryFilters{
NameMatches: "^H",
},
},
wantMatches: 1, // Hello
},
{
name: "match specific function",
query: &ASTQuery{
Pattern: "func $NAME() {}",
Language: "go",
Filters: QueryFilters{
NameExact: "Hello",
},
},
wantMatches: 1, // Hello
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
if err != nil {
t.Fatalf("match failed: %v", err)
}
if len(results) != tt.wantMatches {
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
for i, r := range results {
t.Logf("match %d: %s at line %d", i, r.Node.Type(), r.Location.Line)
}
}
})
}
}
func TestMatchGoStructs(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
content := `package main
type Server struct {
Port int
Host string
}
type Config struct {
Timeout int
}
type Logger interface {
Log(msg string)
}
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantMinimum int
}{
{
name: "match all structs",
query: &ASTQuery{
Pattern: "type $NAME struct { $$$FIELDS }",
Language: "go",
},
wantMinimum: 2, // Server, Config (may also match interface as type_declaration)
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
if err != nil {
t.Fatalf("match failed: %v", err)
}
if len(results) < tt.wantMinimum {
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
}
})
}
}
func TestMatchJSFunctions(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
content := `
function greet(name) {
console.log("Hello, " + name);
}
function sayHello() {
console.log("Hello!");
}
class User {
constructor(name) {
this.name = name;
}
getName() {
return this.name;
}
}
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.js", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantMatches int
}{
{
name: "match all functions",
query: &ASTQuery{
Pattern: "function $NAME($$$ARGS) { $$$BODY }",
Language: "javascript",
},
wantMatches: 2, // greet, sayHello
},
{
name: "match classes",
query: &ASTQuery{
Pattern: "class $NAME { $$$BODY }",
Language: "javascript",
},
wantMatches: 1, // User
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.js")
if err != nil {
t.Fatalf("match failed: %v", err)
}
if len(results) != tt.wantMatches {
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
}
})
}
}
func TestMatchPythonSymbols(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
content := `
def greet(name):
print(f"Hello, {name}")
def calculate(a, b):
return a + b
class User:
def __init__(self, name):
self.name = name
def get_name(self):
return self.name
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.py", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantMinimum int
}{
{
name: "match classes",
query: &ASTQuery{
Pattern: "class $NAME: $$$BODY",
Language: "python",
},
wantMinimum: 1, // User
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.py")
if err != nil {
t.Fatalf("match failed: %v", err)
}
if len(results) < tt.wantMinimum {
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
}
})
}
}
func TestQueryFilters(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
content := `package main
func HelloWorld() {}
func helloWorld() {}
func GoodbyeWorld() {}
func Main() {}
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
name string
filters QueryFilters
wantMatches int
}{
{
name: "regex filter - starts with H",
filters: QueryFilters{
NameMatches: "^[Hh]ello",
},
wantMatches: 2, // HelloWorld, helloWorld
},
{
name: "exact name filter",
filters: QueryFilters{
NameExact: "Main",
},
wantMatches: 1, // Main
},
{
name: "kind filter",
filters: QueryFilters{
KindIn: []string{"function_declaration"},
},
wantMatches: 4, // all functions
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query := &ASTQuery{
Pattern: "func $NAME() {}",
Language: "go",
Filters: tt.filters,
}
results, err := matcher.Match(ctx, query, result.Tree, []byte(content), "test.go")
if err != nil {
t.Fatalf("match failed: %v", err)
}
if len(results) != tt.wantMatches {
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
for _, r := range results {
if nameNode := r.Node.ChildByFieldName("name"); nameNode != nil {
t.Logf("matched: %s", parser.GetNodeText(nameNode, []byte(content)))
}
}
}
})
}
}
func TestFormatResults(t *testing.T) {
tests := []struct {
name string
results []MatchResult
maxResults int
wantEmpty bool
}{
{
name: "empty results",
results: []MatchResult{},
maxResults: 100,
wantEmpty: true,
},
{
name: "single result",
results: []MatchResult{
{
File: "test.go",
Location: protocol.Location{Line: 10, Column: 1},
Text: "func Hello() {}",
Captures: map[string]CapturedNode{
"NAME": {Text: "Hello"},
},
},
},
maxResults: 100,
wantEmpty: false,
},
{
name: "truncated results",
results: []MatchResult{
{File: "a.go", Location: protocol.Location{Line: 1}, Text: "func A() {}"},
{File: "b.go", Location: protocol.Location{Line: 1}, Text: "func B() {}"},
{File: "c.go", Location: protocol.Location{Line: 1}, Text: "func C() {}"},
},
maxResults: 2,
wantEmpty: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
output := FormatResults(tt.results, tt.maxResults)
if tt.wantEmpty {
if output != "No matches found." {
t.Errorf("expected 'No matches found.', got: %s", output)
}
} else {
if output == "No matches found." {
t.Error("expected results, got 'No matches found.'")
}
}
})
}
}
func TestQueryValidation(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
ctx := context.Background()
// Parse some valid content
content := `package main
func main() {}
`
result, err := reg.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantErr bool
}{
{
name: "empty pattern",
query: &ASTQuery{Pattern: "", Language: "go"},
wantErr: true,
},
{
name: "missing language",
query: &ASTQuery{Pattern: "func $NAME() {}", Language: ""},
wantErr: true,
},
{
name: "unknown language",
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "unknown"},
wantErr: true,
},
{
name: "valid query",
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "go"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
if tt.wantErr {
if err == nil {
t.Error("expected error")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
})
}
}
+198
View File
@@ -0,0 +1,198 @@
package query
import (
"regexp"
"sync"
"testing"
)
// TestCompileRegexCaching tests that regex compilation is cached.
func TestCompileRegexCaching(t *testing.T) {
// Clear cache before test
regexCache = sync.Map{}
pattern := `^test_\w+$`
// First compilation
re1, err := compileRegex(pattern)
if err != nil {
t.Fatalf("First compile failed: %v", err)
}
// Second compilation should return cached version
re2, err := compileRegex(pattern)
if err != nil {
t.Fatalf("Second compile failed: %v", err)
}
// Should be the exact same object
if re1 != re2 {
t.Error("Expected cached regex to be reused, got different objects")
}
// Verify it's in the cache
cached, ok := regexCache.Load(pattern)
if !ok {
t.Error("Pattern not found in cache")
}
if cached.(*regexp.Regexp) != re1 {
t.Error("Cached regex doesn't match returned regex")
}
}
// TestCompileRegexConcurrent tests concurrent regex compilation.
func TestCompileRegexConcurrent(t *testing.T) {
// Clear cache before test
regexCache = sync.Map{}
pattern := `[a-z]+_\d+`
const numGoroutines = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
results := make([]*regexp.Regexp, numGoroutines)
errors := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
i := i
go func() {
defer wg.Done()
re, err := compileRegex(pattern)
if err != nil {
errors <- err
return
}
results[i] = re
}()
}
wg.Wait()
close(errors)
// Check for errors
for err := range errors {
t.Errorf("Concurrent compile failed: %v", err)
}
// All results should be the same object (cached)
for i := 1; i < numGoroutines; i++ {
if results[i] != results[0] {
t.Errorf("Result %d is different from result 0 (cache not working)", i)
}
}
}
// TestCompileRegexInvalidPattern tests error handling for invalid patterns.
func TestCompileRegexInvalidPattern(t *testing.T) {
// Clear cache before test
regexCache = sync.Map{}
invalidPattern := `[invalid(`
_, err := compileRegex(invalidPattern)
if err == nil {
t.Error("Expected error for invalid pattern, got nil")
}
// Invalid patterns should not be cached
_, ok := regexCache.Load(invalidPattern)
if ok {
t.Error("Invalid pattern should not be cached")
}
}
// TestCompileRegexMultiplePatterns tests that different patterns are cached separately.
func TestCompileRegexMultiplePatterns(t *testing.T) {
// Clear cache before test
regexCache = sync.Map{}
patterns := []string{
`^test_\w+$`,
`^\d{4}-\d{2}-\d{2}$`,
`^[A-Z][a-z]+$`,
`\b\w+@\w+\.\w+\b`,
}
compiled := make([]*regexp.Regexp, len(patterns))
// Compile all patterns
for i, pattern := range patterns {
re, err := compileRegex(pattern)
if err != nil {
t.Fatalf("Compile failed for pattern %s: %v", pattern, err)
}
compiled[i] = re
}
// Verify all are cached
for i, pattern := range patterns {
cached, ok := regexCache.Load(pattern)
if !ok {
t.Errorf("Pattern %s not in cache", pattern)
}
if cached.(*regexp.Regexp) != compiled[i] {
t.Errorf("Cached regex for %s doesn't match compiled version", pattern)
}
}
// All should be different objects
for i := 0; i < len(compiled); i++ {
for j := i + 1; j < len(compiled); j++ {
if compiled[i] == compiled[j] {
t.Errorf("Pattern %d and %d have same regex object", i, j)
}
}
}
}
// BenchmarkCompileRegex_Uncached benchmarks regex compilation without caching.
func BenchmarkCompileRegex_Uncached(b *testing.B) {
pattern := `^\w+_[0-9]{3,5}_[a-zA-Z]+$`
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = regexp.Compile(pattern)
}
}
// BenchmarkCompileRegex_Cached benchmarks regex compilation with caching.
func BenchmarkCompileRegex_Cached(b *testing.B) {
// Clear cache
regexCache = sync.Map{}
pattern := `^\w+_[0-9]{3,5}_[a-zA-Z]+$`
// Pre-populate cache
_, _ = compileRegex(pattern)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = compileRegex(pattern)
}
}
// BenchmarkCompileRegex_MixedPatterns benchmarks realistic workload with multiple patterns.
func BenchmarkCompileRegex_MixedPatterns(b *testing.B) {
// Clear cache
regexCache = sync.Map{}
patterns := []string{
`^test_\w+$`,
`^\d{4}-\d{2}-\d{2}$`,
`^[A-Z][a-z]+$`,
`\b\w+@\w+\.\w+\b`,
`^func\s+\w+\(`,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Simulate realistic access pattern
pattern := patterns[i%len(patterns)]
_, _ = compileRegex(pattern)
}
}
+401
View File
@@ -0,0 +1,401 @@
// Package search provides text search functionality using ripgrep.
package search
import (
"bufio"
"bytes"
"context"
"fmt"
"log/slog"
"os/exec"
"path/filepath"
"strings"
json "github.com/goccy/go-json"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// Searcher provides text search functionality using ripgrep.
type Searcher struct {
cfg *config.Config
logger *slog.Logger
rgPath string
}
// Request represents a search request.
type Request struct {
Pattern string
Paths []string
FileTypes []string
ContextLines int
MaxResults int
IgnoreCase bool
Regex bool
IncludeHidden bool
FollowSymlinks bool
}
// Result represents a single search result.
type Result struct {
File string `json:"file"`
MatchText string `json:"match_text"`
Language protocol.Language `json:"language"`
Context ContextLines `json:"context"`
Line int `json:"line"`
Column int `json:"column"`
}
// ContextLines holds lines before and after a match.
type ContextLines struct {
Before []string `json:"before"`
After []string `json:"after"`
}
// SearchResults holds the complete search results.
type SearchResults struct {
Results []Result `json:"results"`
Truncated bool `json:"truncated"`
Total int `json:"total"`
}
// ripgrep JSON output types
type rgMessage struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
type rgMatch struct {
Path struct {
Text string `json:"text"`
} `json:"path"`
Lines struct {
Text string `json:"text"`
} `json:"lines"`
Submatches []struct {
Match struct {
Text string `json:"text"`
} `json:"match"`
Start int `json:"start"`
End int `json:"end"`
} `json:"submatches"`
LineNumber int `json:"line_number"`
AbsoluteOffset int `json:"absolute_offset"`
}
type rgContext struct {
Path struct {
Text string `json:"text"`
} `json:"path"`
Lines struct {
Text string `json:"text"`
} `json:"lines"`
LineNumber int `json:"line_number"`
}
type rgSummary struct {
ElapsedTotal struct {
Secs int `json:"secs"`
Nanos int `json:"nanos"`
} `json:"elapsed_total"`
Stats struct {
Searches int `json:"searches"`
SearchesWithMatch int `json:"searches_with_match"`
BytesSearched int64 `json:"bytes_searched"`
BytesPrinted int64 `json:"bytes_printed"`
MatchedLines int `json:"matched_lines"`
Matches int `json:"matches"`
} `json:"stats"`
}
// New creates a new Searcher instance.
func New(cfg *config.Config, logger *slog.Logger) (*Searcher, error) {
// Detect ripgrep binary
rgPath, err := exec.LookPath("rg")
if err != nil {
return nil, errors.NewRipgrepNotFound()
}
return &Searcher{
cfg: cfg,
logger: logger,
rgPath: rgPath,
}, nil
}
// Search executes a search and returns results.
func (s *Searcher) Search(ctx context.Context, req *Request) (*SearchResults, error) {
if req.Pattern == "" {
return nil, errors.New(errors.ErrInvalidPattern, "pattern cannot be empty").
WithRemediation("Provide a non-empty search pattern")
}
// Build ripgrep command
args := s.buildArgs(req)
s.logger.Debug("executing ripgrep", "args", args)
// Create command with timeout
ctx, cancel := context.WithTimeout(ctx, s.cfg.SearchTimeout)
defer cancel()
cmd := exec.CommandContext(ctx, s.rgPath, args...) // #nosec G204 - rgPath is validated at initialization
// Set working directory to workspace root
cmd.Dir = s.cfg.WorkspaceRoot
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
// Run command - ripgrep returns exit code 1 for no matches, which is not an error
err := cmd.Run()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return nil, errors.NewSearchTimeout(req.Pattern, s.cfg.SearchTimeout.String())
}
// Exit code 1 means no matches, which is fine
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
return &SearchResults{Results: []Result{}, Total: 0}, nil
}
// Exit code 2 means error
if stderr.Len() > 0 {
return nil, errors.Wrap(errors.ErrSearchFailed, "ripgrep search failed", err).
WithContext("pattern", req.Pattern).
WithContext("stderr", stderr.String()).
WithRemediation("Check search pattern syntax and ensure files are readable")
}
return nil, errors.Wrap(errors.ErrSearchFailed, "ripgrep search failed", err).
WithContext("pattern", req.Pattern).
WithRemediation("Check search pattern syntax and ensure ripgrep is functioning correctly")
}
// Parse JSON output
return s.parseOutput(&stdout, req.MaxResults)
}
// buildArgs builds the ripgrep command arguments.
func (s *Searcher) buildArgs(req *Request) []string {
args := []string{"--json"}
// Add context lines
if req.ContextLines > 0 {
args = append(args, fmt.Sprintf("--context=%d", req.ContextLines))
}
// File type filtering
for _, ft := range req.FileTypes {
args = append(args, "--type", ft)
}
// Case sensitivity
if req.IgnoreCase {
args = append(args, "--ignore-case")
}
// Fixed strings (non-regex)
if !req.Regex {
args = append(args, "--fixed-strings")
}
// Follow symlinks
if req.FollowSymlinks || s.cfg.FollowSymlinks {
args = append(args, "--follow")
}
// Include hidden files
if req.IncludeHidden {
args = append(args, "--hidden")
}
// Respect .gitignore (default behavior for rg)
if !s.cfg.RespectGitignore {
args = append(args, "--no-ignore")
}
// Max count per file to limit results
if req.MaxResults > 0 {
args = append(args, fmt.Sprintf("--max-count=%d", req.MaxResults))
}
// Add pattern
args = append(args, "--", req.Pattern)
// Add paths (default to current directory which is workspace root)
if len(req.Paths) > 0 {
for _, p := range req.Paths {
// Validate path is within workspace
if s.cfg.IsPathAllowed(p) {
args = append(args, p)
}
}
} else {
args = append(args, ".")
}
return args
}
// parseOutput parses ripgrep JSON output.
func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchResults, error) {
results := &SearchResults{
Results: []Result{},
}
// Track context by file and line
contextBefore := make(map[string][]string) // file -> lines before current match
currentFile := ""
scanner := bufio.NewScanner(output)
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
var msg rgMessage
if err := json.Unmarshal(line, &msg); err != nil {
s.logger.Debug("failed to parse ripgrep output line", "error", err, "line", string(line))
continue
}
switch msg.Type {
case "match":
var match rgMatch
if err := json.Unmarshal(msg.Data, &match); err != nil {
continue
}
// Check max results
if maxResults > 0 && len(results.Results) >= maxResults {
results.Truncated = true
continue
}
result := Result{
File: match.Path.Text,
Line: match.LineNumber,
MatchText: strings.TrimRight(match.Lines.Text, "\n\r"),
Language: protocol.DetectLanguage(match.Path.Text),
}
// Add column from first submatch
if len(match.Submatches) > 0 {
result.Column = match.Submatches[0].Start + 1 // 1-indexed
}
// Add context before
if ctx, ok := contextBefore[match.Path.Text]; ok {
result.Context.Before = ctx
delete(contextBefore, match.Path.Text)
}
results.Results = append(results.Results, result)
currentFile = match.Path.Text
case "context":
var ctx rgContext
if err := json.Unmarshal(msg.Data, &ctx); err != nil {
continue
}
lineText := strings.TrimRight(ctx.Lines.Text, "\n\r")
// Determine if this is before or after context
if len(results.Results) > 0 {
lastResult := &results.Results[len(results.Results)-1]
if lastResult.File == ctx.Path.Text && ctx.LineNumber > lastResult.Line {
// This is after context
lastResult.Context.After = append(lastResult.Context.After, lineText)
} else if ctx.Path.Text == currentFile || currentFile == "" {
// This is before context for a potential upcoming match
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
}
} else {
// Before any match - store as potential before context
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
}
case "summary":
var summary rgSummary
if err := json.Unmarshal(msg.Data, &summary); err != nil {
continue
}
results.Total = summary.Stats.Matches
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading ripgrep output: %w", err)
}
return results, nil
}
// FormatResults formats search results for display.
func (s *Searcher) FormatResults(results *SearchResults) string {
if len(results.Results) == 0 {
return "No matches found."
}
var sb strings.Builder
// Group results by file
fileResults := make(map[string][]Result)
var fileOrder []string
for _, r := range results.Results {
if _, exists := fileResults[r.File]; !exists {
fileOrder = append(fileOrder, r.File)
}
fileResults[r.File] = append(fileResults[r.File], r)
}
// Write summary
totalMatches := len(results.Results)
fileCount := len(fileResults)
sb.WriteString(fmt.Sprintf("Found %d matches in %d files", totalMatches, fileCount))
if results.Truncated {
sb.WriteString(fmt.Sprintf(" (truncated, total: %d)", results.Total))
}
sb.WriteString(":\n\n")
// Write results grouped by file
for _, file := range fileOrder {
// Make path relative to workspace root if possible
relPath := file
if absPath, err := filepath.Abs(file); err == nil {
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
relPath = rel
}
}
sb.WriteString(fmt.Sprintf("**%s**\n", relPath))
for _, r := range fileResults[file] {
// Write context before
for _, ctx := range r.Context.Before {
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
}
// Write match line
sb.WriteString(fmt.Sprintf("L%d│ %s\n", r.Line, truncateLine(r.MatchText, 200)))
// Write context after
for _, ctx := range r.Context.After {
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
}
}
sb.WriteString("\n")
}
return sb.String()
}
// truncateLine truncates a line if it exceeds maxLen.
func truncateLine(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen-3] + "..."
}
+326
View File
@@ -0,0 +1,326 @@
package search
import (
"bytes"
"context"
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
)
func TestNew(t *testing.T) {
cfg := config.Default()
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
searcher, err := New(cfg, logger)
if err != nil {
// ripgrep might not be installed
if strings.Contains(err.Error(), "not found") {
t.Skip("ripgrep not installed, skipping test")
}
t.Fatalf("New failed: %v", err)
}
if searcher == nil {
t.Fatal("expected non-nil searcher")
}
}
func TestBuildArgs(t *testing.T) {
cfg := config.Default()
cfg.WorkspaceRoot = "/test/workspace"
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Create searcher without checking for rg binary
s := &Searcher{
cfg: cfg,
logger: logger,
rgPath: "/usr/bin/rg",
}
tests := []struct {
name string
req *Request
expected []string
}{
{
name: "basic search",
req: &Request{
Pattern: "test",
ContextLines: 2,
Regex: true,
},
expected: []string{"--json", "--context=2", "--", "test", "."},
},
{
name: "ignore case",
req: &Request{
Pattern: "test",
IgnoreCase: true,
Regex: true,
},
expected: []string{"--json", "--ignore-case", "--", "test", "."},
},
{
name: "fixed strings",
req: &Request{
Pattern: "test",
Regex: false,
},
expected: []string{"--json", "--fixed-strings", "--", "test", "."},
},
{
name: "with file types",
req: &Request{
Pattern: "test",
FileTypes: []string{"go", "ts"},
Regex: true,
},
expected: []string{"--json", "--type", "go", "--type", "ts", "--", "test", "."},
},
{
name: "with max results",
req: &Request{
Pattern: "test",
MaxResults: 10,
Regex: true,
},
expected: []string{"--json", "--max-count=10", "--", "test", "."},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
args := s.buildArgs(tt.req)
// Check that all expected args are present
for _, exp := range tt.expected {
found := false
for _, arg := range args {
if arg == exp {
found = true
break
}
}
if !found {
t.Errorf("expected arg %q not found in %v", exp, args)
}
}
})
}
}
func TestFormatResults(t *testing.T) {
cfg := config.Default()
cfg.WorkspaceRoot = "/test/workspace"
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
s := &Searcher{
cfg: cfg,
logger: logger,
rgPath: "/usr/bin/rg",
}
tests := []struct {
name string
results *SearchResults
contains []string
}{
{
name: "empty results",
results: &SearchResults{
Results: []Result{},
},
contains: []string{"No matches found"},
},
{
name: "single result",
results: &SearchResults{
Results: []Result{
{
File: "test.go",
Line: 10,
Column: 5,
MatchText: "func TestSomething()",
},
},
Total: 1,
},
contains: []string{"test.go", "L10", "TestSomething"},
},
{
name: "truncated results",
results: &SearchResults{
Results: []Result{
{
File: "test.go",
Line: 10,
MatchText: "match",
},
},
Truncated: true,
Total: 100,
},
contains: []string{"truncated", "100"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
output := s.FormatResults(tt.results)
for _, exp := range tt.contains {
if !strings.Contains(output, exp) {
t.Errorf("expected output to contain %q, got:\n%s", exp, output)
}
}
})
}
}
func TestParseOutput(t *testing.T) {
cfg := config.Default()
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
s := &Searcher{
cfg: cfg,
logger: logger,
rgPath: "/usr/bin/rg",
}
// Sample ripgrep JSON output
jsonOutput := `{"type":"begin","data":{"path":{"text":"test.go"}}}
{"type":"match","data":{"path":{"text":"test.go"},"lines":{"text":"func TestSomething() {\n"},"line_number":10,"absolute_offset":100,"submatches":[{"match":{"text":"Test"},"start":5,"end":9}]}}
{"type":"end","data":{"path":{"text":"test.go"},"stats":{"bytes_searched":1000}}}
{"type":"summary","data":{"elapsed_total":{"secs":0,"nanos":1000000},"stats":{"searches":1,"searches_with_match":1,"bytes_searched":1000,"bytes_printed":100,"matched_lines":1,"matches":1}}}
`
buf := bytes.NewBufferString(jsonOutput)
results, err := s.parseOutput(buf, 0)
if err != nil {
t.Fatalf("parseOutput failed: %v", err)
}
if len(results.Results) != 1 {
t.Errorf("expected 1 result, got %d", len(results.Results))
}
if results.Results[0].File != "test.go" {
t.Errorf("expected file 'test.go', got %q", results.Results[0].File)
}
if results.Results[0].Line != 10 {
t.Errorf("expected line 10, got %d", results.Results[0].Line)
}
if results.Results[0].Column != 6 { // 1-indexed
t.Errorf("expected column 6, got %d", results.Results[0].Column)
}
}
func TestTruncateLine(t *testing.T) {
tests := []struct {
input string
expected string
maxLen int
}{
{
input: "short",
maxLen: 10,
expected: "short",
},
{
input: "this is a very long line that should be truncated",
maxLen: 20,
expected: "this is a very lo...",
},
{
input: "exact",
maxLen: 5,
expected: "exact",
},
}
for _, tt := range tests {
result := truncateLine(tt.input, tt.maxLen)
if result != tt.expected {
t.Errorf("truncateLine(%q, %d) = %q, want %q", tt.input, tt.maxLen, result, tt.expected)
}
}
}
func TestSearchIntegration(t *testing.T) {
// Create a temporary directory with test files
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-search-test")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test files
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
func main() {
println("Hello, World!")
}
`
err = os.WriteFile(testFile, []byte(content), 0600)
if err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := config.Default()
cfg.WorkspaceRoot = tmpDir
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
searcher, err := New(cfg, logger)
if err != nil {
t.Skip("ripgrep not installed, skipping integration test")
}
ctx := context.Background()
req := &Request{
Pattern: "Hello",
ContextLines: 1,
Regex: false,
}
results, err := searcher.Search(ctx, req)
if err != nil {
t.Fatalf("Search failed: %v", err)
}
if len(results.Results) != 1 {
t.Errorf("expected 1 result, got %d", len(results.Results))
}
if len(results.Results) > 0 && !strings.Contains(results.Results[0].MatchText, "Hello") {
t.Errorf("expected match to contain 'Hello', got %q", results.Results[0].MatchText)
}
}
func TestSearchEmptyPattern(t *testing.T) {
cfg := config.Default()
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
s := &Searcher{
cfg: cfg,
logger: logger,
rgPath: "/usr/bin/rg",
}
ctx := context.Background()
req := &Request{
Pattern: "",
}
_, err := s.Search(ctx, req)
if err == nil {
t.Error("expected error for empty pattern")
}
}
+993
View File
@@ -0,0 +1,993 @@
// Package server implements the MCP server for file operations.
package server
import (
"context"
"fmt"
"log/slog"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"github.com/lukaszraczylo/mcp-filepuff/internal/edit"
"github.com/lukaszraczylo/mcp-filepuff/internal/lsp"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
"github.com/lukaszraczylo/mcp-filepuff/internal/search"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// Server represents the MCP file operations server.
type Server struct {
cfg *config.Config
logger *slog.Logger
mcp *server.MCPServer
searcher *search.Searcher
parser *parser.Registry
matcher *query.Matcher
lspManager *lsp.Manager
editor *edit.Engine
}
// New creates a new MCP server instance.
func New(cfg *config.Config, logger *slog.Logger) (*Server, error) {
parserRegistry := parser.NewRegistry()
s := &Server{
cfg: cfg,
logger: logger,
parser: parserRegistry,
matcher: query.NewMatcher(parserRegistry),
editor: edit.NewEngine(parserRegistry),
}
// Initialize searcher
searcher, err := search.New(cfg, logger)
if err != nil {
logger.Warn("ripgrep not available, search functionality disabled", "error", err)
}
s.searcher = searcher
// Initialize LSP manager if enabled
if cfg.EnableLSP {
s.lspManager = lsp.NewManager(cfg.WorkspaceRoot, logger)
}
// Create MCP server
mcpServer := server.NewMCPServer(
"mcp-filepuff",
"1.0.0",
server.WithLogging(),
)
s.mcp = mcpServer
// Register tools
s.registerTools()
return s, nil
}
// registerTools registers all available tools with the MCP server.
func (s *Server) registerTools() {
// Register ping tool for health checks
s.mcp.AddTool(
mcp.NewTool("ping",
mcp.WithDescription("Health check - returns pong to verify the server is running"),
mcp.WithReadOnlyHintAnnotation(true),
),
s.handlePing,
)
// Register file_search tool
if s.searcher != nil {
s.mcp.AddTool(
mcp.NewTool("file_search",
mcp.WithDescription("Search for text patterns in files using ripgrep. Supports regex patterns, file type filtering, and context lines."),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("pattern",
mcp.Required(),
mcp.Description("The search pattern (regex by default)"),
),
mcp.WithArray("paths",
mcp.Description("Paths to search in (defaults to workspace root)"),
mcp.WithStringItems(),
),
mcp.WithArray("file_types",
mcp.Description("File types to search (e.g., ['go', 'ts', 'py'])"),
mcp.WithStringItems(),
),
mcp.WithBoolean("ignore_case",
mcp.Description("Case insensitive search"),
),
mcp.WithBoolean("regex",
mcp.Description("Treat pattern as regex (default: true)"),
),
mcp.WithNumber("context_lines",
mcp.Description("Number of context lines around matches (default: 2)"),
),
mcp.WithNumber("max_results",
mcp.Description("Maximum number of results to return"),
),
),
s.handleFileSearch,
)
}
// Register file_read tool
s.mcp.AddTool(
mcp.NewTool("file_read",
mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary"),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("path",
mcp.Required(),
mcp.Description("Path to the file to read"),
),
mcp.WithNumber("line_start",
mcp.Description("Starting line number (1-indexed)"),
),
mcp.WithNumber("line_end",
mcp.Description("Ending line number (inclusive)"),
),
mcp.WithBoolean("include_ast",
mcp.Description("Include AST symbol summary (functions, classes, types, etc.)"),
),
mcp.WithBoolean("symbols_only",
mcp.Description("Return only symbol summary without file content (token-efficient mode). Requires include_ast=true."),
),
mcp.WithNumber("max_lines",
mcp.Description("Maximum number of lines to return (for token efficiency). Applied after line_start/line_end."),
),
),
s.handleFileRead,
)
// Register ast_query tool
s.mcp.AddTool(
mcp.NewTool("ast_query",
mcp.WithDescription("Search for AST patterns in code files. Use code patterns with $VAR placeholders to match and capture code structures like functions, classes, and types."),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("pattern",
mcp.Required(),
mcp.Description("Code pattern with placeholders: $NAME (single), $$$ARGS (multiple), $_ (wildcard). Examples: 'func $NAME($$$ARGS) error', 'class $NAME { $$$BODY }'"),
),
mcp.WithString("language",
mcp.Required(),
mcp.Description("Target language: go, typescript, javascript, python, c, cpp"),
),
mcp.WithArray("paths",
mcp.Description("Paths to search in (defaults to workspace root)"),
mcp.WithStringItems(),
),
mcp.WithString("name_matches",
mcp.Description("Regex pattern to filter by name"),
),
mcp.WithString("name_exact",
mcp.Description("Exact name to match"),
),
mcp.WithArray("kind_in",
mcp.Description("Node types to match (e.g., function_declaration, class_declaration)"),
mcp.WithStringItems(),
),
mcp.WithNumber("max_results",
mcp.Description("Maximum number of results to return (default: 100)"),
),
),
s.handleASTQuery,
)
// Register LSP-based tools if LSP is enabled
if s.lspManager != nil {
// Register symbol_at tool
s.mcp.AddTool(
mcp.NewTool("symbol_at",
mcp.WithDescription("Get information about the symbol at a specific position in a file. Returns type, documentation, and definition location using LSP when available."),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("file",
mcp.Required(),
mcp.Description("Path to the file"),
),
mcp.WithNumber("line",
mcp.Required(),
mcp.Description("Line number (1-indexed)"),
),
mcp.WithNumber("column",
mcp.Required(),
mcp.Description("Column number (1-indexed)"),
),
),
s.handleSymbolAt,
)
// Register find_definition tool
s.mcp.AddTool(
mcp.NewTool("find_definition",
mcp.WithDescription("Find the definition of the symbol at a specific position. Uses LSP to locate where a function, variable, type, etc. is defined."),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("file",
mcp.Required(),
mcp.Description("Path to the file"),
),
mcp.WithNumber("line",
mcp.Required(),
mcp.Description("Line number (1-indexed)"),
),
mcp.WithNumber("column",
mcp.Required(),
mcp.Description("Column number (1-indexed)"),
),
),
s.handleFindDefinition,
)
// Register find_references tool
s.mcp.AddTool(
mcp.NewTool("find_references",
mcp.WithDescription("Find all references to the symbol at a specific position. Uses LSP to locate all usages of a function, variable, type, etc."),
mcp.WithReadOnlyHintAnnotation(true),
mcp.WithString("file",
mcp.Required(),
mcp.Description("Path to the file"),
),
mcp.WithNumber("line",
mcp.Required(),
mcp.Description("Line number (1-indexed)"),
),
mcp.WithNumber("column",
mcp.Required(),
mcp.Description("Column number (1-indexed)"),
),
mcp.WithBoolean("include_declaration",
mcp.Description("Include the declaration in results (default: true)"),
),
),
s.handleFindReferences,
)
}
// Register edit tools
s.mcp.AddTool(
mcp.NewTool("edit_preview",
mcp.WithDescription("Preview an edit without applying it. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++), and text-based editing for other files (Markdown, JSON, YAML, config files, etc.)."),
mcp.WithString("file",
mcp.Required(),
mcp.Description("Path to the file to edit"),
),
mcp.WithString("operation",
mcp.Required(),
mcp.Description("Edit operation: replace, insert_before, insert_after, delete"),
),
mcp.WithString("new_content",
mcp.Description("New content (required for replace/insert operations)"),
),
// AST-mode selectors (for code files)
mcp.WithString("selector_kind",
mcp.Description("AST node type to match (e.g., function_declaration, class_declaration). For code files only."),
),
mcp.WithString("selector_name",
mcp.Description("Name of the symbol to match. For code files only."),
),
// Shared selectors
mcp.WithNumber("selector_line",
mcp.Description("Line number (1-indexed). For AST mode: narrows search. For text mode: start of line range."),
),
mcp.WithNumber("selector_index",
mcp.Description("Index of the match to use if multiple matches found (default: 0)"),
),
// Text-mode selectors (for non-code files or explicit text matching)
mcp.WithNumber("selector_line_end",
mcp.Description("End line number for range selection (text mode). Used with selector_line."),
),
mcp.WithString("selector_text",
mcp.Description("Exact text to match (text mode). Must be unique or use selector_index."),
),
mcp.WithString("selector_pattern",
mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."),
),
),
s.handleEditPreview,
)
s.mcp.AddTool(
mcp.NewTool("edit_apply",
mcp.WithDescription("Apply an edit to a file. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++) with syntax validation, and text-based editing for other files (Markdown, JSON, YAML, config files, etc.)."),
mcp.WithString("file",
mcp.Required(),
mcp.Description("Path to the file to edit"),
),
mcp.WithString("operation",
mcp.Required(),
mcp.Description("Edit operation: replace, insert_before, insert_after, delete"),
),
mcp.WithString("new_content",
mcp.Description("New content (required for replace/insert operations)"),
),
// AST-mode selectors (for code files)
mcp.WithString("selector_kind",
mcp.Description("AST node type to match (e.g., function_declaration, class_declaration). For code files only."),
),
mcp.WithString("selector_name",
mcp.Description("Name of the symbol to match. For code files only."),
),
// Shared selectors
mcp.WithNumber("selector_line",
mcp.Description("Line number (1-indexed). For AST mode: narrows search. For text mode: start of line range."),
),
mcp.WithNumber("selector_index",
mcp.Description("Index of the match to use if multiple matches found (default: 0)"),
),
// Text-mode selectors (for non-code files or explicit text matching)
mcp.WithNumber("selector_line_end",
mcp.Description("End line number for range selection (text mode). Used with selector_line."),
),
mcp.WithString("selector_text",
mcp.Description("Exact text to match (text mode). Must be unique or use selector_index."),
),
mcp.WithString("selector_pattern",
mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."),
),
),
s.handleEditApply,
)
}
// handlePing handles the ping health check tool.
func (s *Server) handlePing(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText("pong"), nil
}
// handleFileSearch handles the file_search tool.
func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
start := time.Now()
defer func() {
s.logger.Debug("file_search completed",
"duration_ms", time.Since(start).Milliseconds(),
)
}()
if s.searcher == nil {
return mcp.NewToolResultError("ripgrep (rg) is not available. Please install it: https://github.com/BurntSushi/ripgrep#installation"), nil
}
// Parse request arguments using SDK helpers
pattern, err := request.RequireString("pattern")
if err != nil {
return mcp.NewToolResultError("pattern is required"), nil
}
req := &search.Request{
Pattern: pattern,
Paths: request.GetStringSlice("paths", nil),
FileTypes: request.GetStringSlice("file_types", nil),
IgnoreCase: request.GetBool("ignore_case", false),
Regex: request.GetBool("regex", true),
ContextLines: request.GetInt("context_lines", 2),
MaxResults: request.GetInt("max_results", 0),
}
// Execute search
results, err := s.searcher.Search(ctx, req)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("search error: %v", err)), nil
}
s.logger.Info("search completed",
"pattern", pattern,
"results_count", len(results.Results),
"truncated", results.Truncated,
)
// Format results
output := s.searcher.FormatResults(results)
return mcp.NewToolResultText(output), nil
}
// handleFileRead handles the file_read tool.
func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
path, err := request.RequireString("path")
if err != nil {
return mcp.NewToolResultError("path is required"), nil
}
// Validate path is within workspace
if !s.cfg.IsPathAllowed(path) {
return mcp.NewToolResultError("path is outside workspace root"), nil
}
// Read file
content, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil
}
if os.IsPermission(err) {
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
}
return mcp.NewToolResultError(fmt.Sprintf("error reading file: %v", err)), nil
}
// Check file size
if int64(len(content)) > s.cfg.MaxFileSize {
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil
}
// Handle line range
lines := splitLines(string(content))
lineStart := request.GetInt("line_start", 1)
lineEnd := request.GetInt("line_end", len(lines))
// Clamp to valid range
if lineStart < 1 {
lineStart = 1
}
if lineEnd > len(lines) {
lineEnd = len(lines)
}
if lineStart > lineEnd {
lineStart = lineEnd
}
var output strings.Builder
// Include AST summary if requested
includeAST := request.GetBool("include_ast", false)
symbolsOnly := request.GetBool("symbols_only", false)
maxLines := request.GetInt("max_lines", 0)
// Validate symbols_only requires include_ast
if symbolsOnly && !includeAST {
return mcp.NewToolResultError("symbols_only requires include_ast=true"), nil
}
if includeAST {
astSummary := s.generateASTSummary(ctx, path, content)
if astSummary != "" {
output.WriteString(astSummary)
if !symbolsOnly {
output.WriteString("\n---\n\n")
}
}
}
// Skip file content if symbols_only mode
if !symbolsOnly {
// Apply max_lines limit if specified
effectiveEnd := lineEnd
if maxLines > 0 && (lineEnd-lineStart+1) > maxLines {
effectiveEnd = lineStart + maxLines - 1
if effectiveEnd < lineEnd {
// Add note that output was truncated
defer func() {
output.WriteString(fmt.Sprintf("\n[... %d more lines omitted for token efficiency. Use line_start/line_end or increase max_lines to see more]\n", lineEnd-effectiveEnd))
}()
}
}
// Extract requested lines
for i := lineStart - 1; i < effectiveEnd && i < len(lines); i++ {
output.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, lines[i]))
}
}
return mcp.NewToolResultText(output.String()), nil
}
// generateASTSummary generates a summary of symbols in the file.
func (s *Server) generateASTSummary(ctx context.Context, path string, content []byte) string {
// Parse the file
result, err := s.parser.Parse(ctx, path, content)
if err != nil {
return "" // Silently skip AST if parsing fails
}
// Extract symbols
lang := protocol.DetectLanguage(path)
symbols := parser.ExtractSymbols(result.Tree, content, lang, path)
if len(symbols) == 0 {
return ""
}
var sb strings.Builder
// Get relative path
relPath := path
if absPath, err := filepath.Abs(path); err == nil {
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
relPath = rel
}
}
sb.WriteString(fmt.Sprintf("**%s** (%d lines, %s)\n\n", relPath, len(splitLines(string(content))), lang))
sb.WriteString("Symbols:\n")
for _, sym := range symbols {
kindStr := symbolKindIcon(sym.Kind)
sb.WriteString(fmt.Sprintf(" %s %s L%d\n", kindStr, sym.Name, sym.Location.Line))
}
return sb.String()
}
// symbolKindIcon returns an icon/prefix for a symbol kind.
func symbolKindIcon(kind protocol.SymbolKind) string {
switch kind {
case protocol.SymbolFunction:
return "func"
case protocol.SymbolMethod:
return "meth"
case protocol.SymbolClass:
return "class"
case protocol.SymbolStruct:
return "struct"
case protocol.SymbolInterface:
return "iface"
case protocol.SymbolVariable:
return "var"
case protocol.SymbolConstant:
return "const"
case protocol.SymbolType:
return "type"
case protocol.SymbolField:
return "field"
case protocol.SymbolProperty:
return "prop"
case protocol.SymbolModule:
return "mod"
case protocol.SymbolPackage:
return "pkg"
default:
return "sym"
}
}
func splitLines(s string) []string {
// Use optimized stdlib implementation (2-3x faster than manual loop)
return strings.Split(s, "\n")
}
// handleASTQuery handles the ast_query tool.
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
pattern, err := request.RequireString("pattern")
if err != nil {
return mcp.NewToolResultError("pattern is required"), nil
}
language, err := request.RequireString("language")
if err != nil {
return mcp.NewToolResultError("language is required"), nil
}
// Build query
astQuery := &query.ASTQuery{
Pattern: pattern,
Language: language,
Filters: query.QueryFilters{
NameMatches: request.GetString("name_matches", ""),
NameExact: request.GetString("name_exact", ""),
KindIn: request.GetStringSlice("kind_in", nil),
},
}
maxResults := request.GetInt("max_results", 100)
paths := request.GetStringSlice("paths", nil)
// Default to workspace root if no paths specified
if len(paths) == 0 {
paths = []string{s.cfg.WorkspaceRoot}
}
// Find files to search based on language
ext := languageToExtension(language)
if ext == "" {
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil
}
var allResults []query.MatchResult
// Walk through paths and find matching files
for _, searchPath := range paths {
// Validate path is within workspace
if !s.cfg.IsPathAllowed(searchPath) {
continue
}
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil // Skip files with errors
}
if info.IsDir() {
// Skip hidden directories
if strings.HasPrefix(info.Name(), ".") {
return filepath.SkipDir
}
return nil
}
// Check file extension matches language
if !strings.HasSuffix(path, ext) {
return nil
}
// Read and parse file
content, err := os.ReadFile(path)
if err != nil {
return nil // Skip unreadable files
}
// Check file size
if int64(len(content)) > s.cfg.MaxFileSize {
return nil // Skip large files
}
// Parse file
result, err := s.parser.Parse(ctx, path, content)
if err != nil {
return nil // Skip unparseable files
}
// Run query
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
if err != nil {
return nil // Skip on error
}
allResults = append(allResults, matches...)
// Stop if we have enough results
if maxResults > 0 && len(allResults) >= maxResults {
return filepath.SkipAll
}
return nil
})
if err != nil {
s.logger.Warn("error walking path", "path", searchPath, "error", err)
}
}
// Format and return results
output := query.FormatResults(allResults, maxResults)
return mcp.NewToolResultText(output), nil
}
// languageToExtension maps language names to file extensions.
func languageToExtension(language string) string {
switch strings.ToLower(language) {
case "go":
return ".go"
case "typescript":
return ".ts"
case "javascript":
return ".js"
case "python":
return ".py"
case "c":
return ".c"
case "cpp", "c++":
return ".cpp"
default:
return ""
}
}
// handleSymbolAt handles the symbol_at tool.
func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
// Try LSP hover
hover, err := s.lspManager.Hover(ctx, file, line, col)
if err != nil {
// Fall back to AST-based info
return s.handleSymbolAtFallback(ctx, file, line, col)
}
if hover == nil {
return mcp.NewToolResultText("No symbol information available at this position."), nil
}
var output strings.Builder
output.WriteString("**Symbol Information**\n\n")
output.WriteString(hover.Contents.Value)
return mcp.NewToolResultText(output.String()), nil
}
// handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable.
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) {
content, err := os.ReadFile(file)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %v", err)), nil
}
result, err := s.parser.Parse(ctx, file, content)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to parse file: %v", err)), nil
}
node := parser.FindNodeAtPosition(result.Tree, line, col)
if node == nil {
return mcp.NewToolResultText("No symbol at this position."), nil
}
var output strings.Builder
output.WriteString("**Symbol Information** (AST fallback)\n\n")
output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type()))
output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content)))
return mcp.NewToolResultText(output.String()), nil
}
// handleFindDefinition handles the find_definition tool.
func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
locations, err := s.lspManager.Definition(ctx, file, line, col)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %v", err)), nil
}
if len(locations) == 0 {
return mcp.NewToolResultText("No definition found."), nil
}
var output strings.Builder
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
for _, loc := range locations {
filePath := lsp.URIToFile(loc.URI)
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
// Try to read a preview snippet
preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3)
if preview != "" {
output.WriteString("```\n")
output.WriteString(preview)
output.WriteString("```\n")
}
output.WriteString("\n")
}
return mcp.NewToolResultText(output.String()), nil
}
// handleFindReferences handles the find_references tool.
func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
line := request.GetInt("line", 0)
if line <= 0 {
return mcp.NewToolResultError("line must be positive"), nil
}
col := request.GetInt("column", 0)
if col <= 0 {
return mcp.NewToolResultError("column must be positive"), nil
}
includeDecl := request.GetBool("include_declaration", true)
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
locations, err := s.lspManager.References(ctx, file, line, col, includeDecl)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %v", err)), nil
}
if len(locations) == 0 {
return mcp.NewToolResultText("No references found."), nil
}
var output strings.Builder
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations)))
// Group by file
fileGroups := make(map[string][]lsp.Location)
for _, loc := range locations {
filePath := lsp.URIToFile(loc.URI)
fileGroups[filePath] = append(fileGroups[filePath], loc)
}
for filePath, locs := range fileGroups {
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
for _, loc := range locs {
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
}
output.WriteString("\n")
}
return mcp.NewToolResultText(output.String()), nil
}
// readFilePreview reads a few lines from a file around the given line.
func readFilePreview(file string, line, contextLines int) string {
content, err := os.ReadFile(file)
if err != nil {
return ""
}
lines := splitLines(string(content))
startLine := max(1, line-contextLines)
endLine := min(line+contextLines, len(lines))
var preview strings.Builder
for i := startLine - 1; i < endLine && i < len(lines); i++ {
lineText := lines[i]
if len(lineText) > 100 {
lineText = lineText[:100] + "..."
}
prefix := " "
if i+1 == line {
prefix = "> "
}
preview.WriteString(fmt.Sprintf("%s%4d: %s\n", prefix, i+1, lineText))
}
return preview.String()
}
// handleEditPreview handles the edit_preview tool.
func (s *Server) handleEditPreview(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return s.handleEdit(ctx, request, false)
}
// handleEditApply handles the edit_apply tool.
func (s *Server) handleEditApply(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return s.handleEdit(ctx, request, true)
}
// handleEdit is the shared implementation for edit_preview and edit_apply.
func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, apply bool) (*mcp.CallToolResult, error) {
file, err := request.RequireString("file")
if err != nil {
return mcp.NewToolResultError("file is required"), nil
}
operation, err := request.RequireString("operation")
if err != nil {
return mcp.NewToolResultError("operation is required"), nil
}
// Validate path
if !s.cfg.IsPathAllowed(file) {
return mcp.NewToolResultError("file is outside workspace root"), nil
}
// Note: We no longer validate language support here.
// The edit engine automatically detects whether to use AST or text mode.
// Build edit request with both AST and text-mode selectors
astEdit := &edit.ASTEdit{
File: file,
Operation: edit.EditOperation(operation),
NewContent: request.GetString("new_content", ""),
Selector: edit.ASTSelector{
// AST-mode selectors
Kind: request.GetString("selector_kind", ""),
Name: request.GetString("selector_name", ""),
AtLine: request.GetInt("selector_line", 0),
Index: request.GetInt("selector_index", 0),
// Text-mode selectors
LineEnd: request.GetInt("selector_line_end", 0),
Text: request.GetString("selector_text", ""),
TextPattern: request.GetString("selector_pattern", ""),
},
}
// Perform edit
var result *edit.EditResult
if apply {
result, err = s.editor.Apply(ctx, astEdit)
} else {
result, err = s.editor.Preview(ctx, astEdit)
}
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("edit failed: %v", err)), nil
}
if !result.Success {
return mcp.NewToolResultError(result.Error), nil
}
// Format output
var output strings.Builder
if apply {
output.WriteString("**Edit Applied Successfully**\n\n")
} else {
output.WriteString("**Edit Preview**\n\n")
}
output.WriteString("Diff:\n```diff\n")
output.WriteString(result.Diff)
output.WriteString("```\n")
return mcp.NewToolResultText(output.String()), nil
}
// Run starts the MCP server and blocks until shutdown.
func (s *Server) Run(ctx context.Context) error {
// Set up signal handling for graceful shutdown
_, cancel := context.WithCancel(ctx)
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigChan
s.logger.Info("received shutdown signal", "signal", sig)
cancel()
}()
s.logger.Info("starting MCP server",
"workspace", s.cfg.WorkspaceRoot,
"lsp_enabled", s.cfg.EnableLSP,
)
// Start the MCP server with stdio transport
return server.ServeStdio(s.mcp)
}
// Shutdown gracefully shuts down the server.
func (s *Server) Shutdown(ctx context.Context) error {
s.logger.Info("shutting down MCP server")
// Close LSP manager
if s.lspManager != nil {
_ = s.lspManager.Close()
}
// Close parser registry
if s.parser != nil {
s.parser.Close()
}
return nil
}
+377
View File
@@ -0,0 +1,377 @@
package server
import (
"context"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
"github.com/mark3labs/mcp-go/mcp"
)
func TestNew(t *testing.T) {
// Create temp directory for testing
tmpDir := t.TempDir()
cfg := &config.Config{
WorkspaceRoot: tmpDir,
EnableLSP: false, // Disable LSP for simpler testing
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
if srv == nil {
t.Fatal("New() returned nil server")
}
if srv.cfg != cfg {
t.Error("server config mismatch")
}
if srv.parser == nil {
t.Error("parser should not be nil")
}
if srv.matcher == nil {
t.Error("matcher should not be nil")
}
if srv.editor == nil {
t.Error("editor should not be nil")
}
}
func TestHandlePing(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
result, err := srv.handlePing(ctx, req)
if err != nil {
t.Errorf("handlePing() error = %v", err)
}
if result == nil {
t.Fatal("handlePing() returned nil result")
}
// Check that the result contains "pong"
contents := result.Content
if len(contents) == 0 {
t.Fatal("handlePing() returned empty content")
}
textContent, ok := contents[0].(mcp.TextContent)
if !ok {
t.Fatal("handlePing() did not return text content")
}
if textContent.Text != "pong" {
t.Errorf("handlePing() = %v, want 'pong'", textContent.Text)
}
}
func TestHandleFileRead(t *testing.T) {
tmpDir := t.TempDir()
// Create a test file
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
// Hello says hello
func Hello() {
println("Hello, World!")
}
`
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"path": testFile,
}
result, err := srv.handleFileRead(ctx, req)
if err != nil {
t.Errorf("handleFileRead() error = %v", err)
}
if result == nil {
t.Fatal("handleFileRead() returned nil result")
}
contents := result.Content
if len(contents) == 0 {
t.Fatal("handleFileRead() returned empty content")
}
textContent, ok := contents[0].(mcp.TextContent)
if !ok {
t.Fatal("handleFileRead() did not return text content")
}
// Should contain the file content
if textContent.Text == "" {
t.Error("handleFileRead() returned empty text")
}
}
func TestHandleFileReadWithAST(t *testing.T) {
tmpDir := t.TempDir()
// Create a test file
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
// Hello says hello
func Hello() {
println("Hello, World!")
}
`
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"path": testFile,
"include_ast": true,
}
result, err := srv.handleFileRead(ctx, req)
if err != nil {
t.Errorf("handleFileRead() error = %v", err)
}
if result == nil {
t.Fatal("handleFileRead() returned nil result")
}
contents := result.Content
if len(contents) == 0 {
t.Fatal("handleFileRead() returned empty content")
}
textContent, ok := contents[0].(mcp.TextContent)
if !ok {
t.Fatal("handleFileRead() did not return text content")
}
// Should contain "Symbols:" section when include_ast is true
if textContent.Text == "" {
t.Error("handleFileRead() returned empty text")
}
}
func TestHandleFileReadNotFound(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"path": filepath.Join(tmpDir, "nonexistent.go"),
}
result, err := srv.handleFileRead(ctx, req)
// Should return error for non-existent file
if err == nil && result != nil {
// Check if result indicates an error
contents := result.Content
if len(contents) > 0 {
textContent, ok := contents[0].(mcp.TextContent)
if ok && textContent.Text == "" {
t.Log("handleFileRead() returned empty text for non-existent file")
}
}
}
}
func TestHandleASTQuery(t *testing.T) {
tmpDir := t.TempDir()
// Create a test file
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
func Hello() error {
return nil
}
func Goodbye() error {
return nil
}
`
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"pattern": "func $NAME() error",
"language": "go",
"paths": []interface{}{tmpDir},
}
result, err := srv.handleASTQuery(ctx, req)
if err != nil {
t.Errorf("handleASTQuery() error = %v", err)
}
if result == nil {
t.Fatal("handleASTQuery() returned nil result")
}
contents := result.Content
if len(contents) == 0 {
t.Fatal("handleASTQuery() returned empty content")
}
}
func TestHandleEditPreview(t *testing.T) {
tmpDir := t.TempDir()
// Create a test file
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
func Hello() {
println("Hello")
}
`
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"file": testFile,
"operation": "replace",
"selector_kind": "function_declaration",
"selector_name": "Hello",
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
}
result, err := srv.handleEdit(ctx, req, false)
if err != nil {
t.Errorf("handleEdit(preview) error = %v", err)
}
if result == nil {
t.Fatal("handleEdit(preview) returned nil result")
}
// Verify file was NOT modified (it's just a preview)
fileContent, _ := os.ReadFile(testFile)
if string(fileContent) != content {
t.Error("handleEdit(preview) should not modify the file")
}
}
func TestHandleEditApply(t *testing.T) {
tmpDir := t.TempDir()
// Create a test file
testFile := filepath.Join(tmpDir, "test.go")
content := `package main
func Hello() {
println("Hello")
}
`
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
srv, err := New(cfg, logger)
if err != nil {
t.Fatalf("New() error = %v", err)
}
ctx := context.Background()
req := mcp.CallToolRequest{}
req.Params.Arguments = map[string]interface{}{
"file": testFile,
"operation": "replace",
"selector_kind": "function_declaration",
"selector_name": "Hello",
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
}
result, err := srv.handleEdit(ctx, req, true)
if err != nil {
t.Errorf("handleEdit(apply) error = %v", err)
}
if result == nil {
t.Fatal("handleEdit(apply) returned nil result")
}
// Verify file WAS modified
fileContent, _ := os.ReadFile(testFile)
if string(fileContent) == content {
t.Error("handleEdit(apply) should modify the file")
}
}
+289
View File
@@ -0,0 +1,289 @@
// Package errors provides structured error handling with error codes and context.
package errors
import (
"fmt"
"runtime"
"strings"
)
// ErrorCode represents a specific error condition.
type ErrorCode int
// Error codes organized by category
const (
// Search errors (1000-1099)
ErrRipgrepNotFound ErrorCode = 1001
ErrRipgrepTimeout ErrorCode = 1002
ErrInvalidPattern ErrorCode = 1003
ErrSearchFailed ErrorCode = 1004
ErrNoResults ErrorCode = 1005
// Parser errors (1100-1199)
ErrParserNotFound ErrorCode = 1101
ErrParseFailed ErrorCode = 1102
ErrInvalidLanguage ErrorCode = 1103
ErrFileTooBig ErrorCode = 1104
ErrInvalidSyntax ErrorCode = 1105
// LSP errors (1200-1299)
ErrLSPServerNotFound ErrorCode = 1201
ErrLSPInitFailed ErrorCode = 1202
ErrLSPTimeout ErrorCode = 1203
ErrLSPCommunication ErrorCode = 1204
ErrNoHoverInfo ErrorCode = 1205
ErrNoDefinition ErrorCode = 1206
ErrNoReferences ErrorCode = 1207
// Edit errors (1300-1399)
ErrEditFailed ErrorCode = 1301
ErrInvalidEdit ErrorCode = 1302
ErrFileNotFound ErrorCode = 1303
ErrFileNotReadable ErrorCode = 1304
ErrFileNotWritable ErrorCode = 1305
ErrNodeNotFound ErrorCode = 1306
ErrValidationFailed ErrorCode = 1307
ErrInvalidSelection ErrorCode = 1308
// Query errors (1400-1499)
ErrInvalidQuery ErrorCode = 1401
ErrQueryTimeout ErrorCode = 1402
ErrNoMatches ErrorCode = 1403
ErrQueryCompile ErrorCode = 1404
// Config errors (1500-1599)
ErrInvalidConfig ErrorCode = 1501
ErrPathNotAllowed ErrorCode = 1502
ErrWorkspaceNotSet ErrorCode = 1503
// Internal errors (1900-1999)
ErrInternal ErrorCode = 1900
ErrCacheFailed ErrorCode = 1901
ErrConcurrency ErrorCode = 1902
)
// StructuredError represents an error with rich context and remediation info.
type StructuredError struct {
Cause error
Context map[string]any
Message string
Remediation string
Stack string
Code ErrorCode
}
// Error implements the error interface.
func (e *StructuredError) Error() string {
var sb strings.Builder
// Error code and message
sb.WriteString(fmt.Sprintf("[%d] %s", e.Code, e.Message))
// Context if available
if len(e.Context) > 0 {
sb.WriteString("\nContext:")
for k, v := range e.Context {
sb.WriteString(fmt.Sprintf("\n %s: %v", k, v))
}
}
// Remediation if available
if e.Remediation != "" {
sb.WriteString(fmt.Sprintf("\nHow to fix: %s", e.Remediation))
}
// Underlying cause if available
if e.Cause != nil {
sb.WriteString(fmt.Sprintf("\nCaused by: %v", e.Cause))
}
return sb.String()
}
// Unwrap returns the underlying cause for error chain support.
func (e *StructuredError) Unwrap() error {
return e.Cause
}
// WithContext adds context to the error.
func (e *StructuredError) WithContext(key string, value any) *StructuredError {
if e.Context == nil {
e.Context = make(map[string]any)
}
e.Context[key] = value
return e
}
// WithRemediation sets the remediation message.
func (e *StructuredError) WithRemediation(msg string) *StructuredError {
e.Remediation = msg
return e
}
// New creates a new structured error with stack trace.
func New(code ErrorCode, message string) *StructuredError {
return &StructuredError{
Code: code,
Message: message,
Context: make(map[string]interface{}),
Stack: captureStack(2),
}
}
// Wrap wraps an existing error with structured error information.
func Wrap(code ErrorCode, message string, cause error) *StructuredError {
return &StructuredError{
Code: code,
Message: message,
Context: make(map[string]interface{}),
Cause: cause,
Stack: captureStack(2),
}
}
// Is checks if an error matches the given error code.
func Is(err error, code ErrorCode) bool {
if err == nil {
return false
}
if se, ok := err.(*StructuredError); ok {
return se.Code == code
}
return false
}
// GetCode extracts the error code from an error, or returns 0 if not a structured error.
func GetCode(err error) ErrorCode {
if err == nil {
return 0
}
if se, ok := err.(*StructuredError); ok {
return se.Code
}
return 0
}
// captureStack captures the stack trace.
func captureStack(skip int) string {
const depth = 16
var pcs [depth]uintptr
n := runtime.Callers(skip+1, pcs[:])
var sb strings.Builder
frames := runtime.CallersFrames(pcs[:n])
for {
frame, more := frames.Next()
if !strings.Contains(frame.File, "runtime/") {
sb.WriteString(fmt.Sprintf("\n %s:%d %s", frame.File, frame.Line, frame.Function))
}
if !more {
break
}
}
return sb.String()
}
// Common error constructors for convenience
// NewRipgrepNotFound creates an error for missing ripgrep binary.
func NewRipgrepNotFound() *StructuredError {
os := runtime.GOOS
install := "brew install ripgrep"
switch os {
case "linux":
install = "apt-get install ripgrep (Debian/Ubuntu) or yum install ripgrep (RHEL/CentOS)"
case "windows":
install = "choco install ripgrep or scoop install ripgrep"
}
return New(ErrRipgrepNotFound, "ripgrep (rg) binary not found in system PATH").
WithContext("os", os).
WithRemediation(fmt.Sprintf("Install ripgrep: %s", install))
}
// NewLSPServerNotFound creates an error for missing LSP server.
func NewLSPServerNotFound(language, serverName string) *StructuredError {
return New(ErrLSPServerNotFound, fmt.Sprintf("LSP server '%s' not found for language %s", serverName, language)).
WithContext("language", language).
WithContext("server", serverName).
WithRemediation(fmt.Sprintf("Install the %s LSP server to enable IDE features for %s", serverName, language))
}
// NewFileTooLarge creates an error for files exceeding size limit.
func NewFileTooLarge(path string, size, limit int64) *StructuredError {
return New(ErrFileTooBig, "file exceeds maximum size limit").
WithContext("file", path).
WithContext("size_bytes", size).
WithContext("limit_bytes", limit).
WithRemediation(fmt.Sprintf("File size (%d bytes) exceeds limit (%d bytes). Consider processing smaller files or increasing the limit.", size, limit))
}
// NewParseError creates an error for parsing failures.
func NewParseError(language, file string, cause error) *StructuredError {
return Wrap(ErrParseFailed, fmt.Sprintf("failed to parse %s file", language), cause).
WithContext("language", language).
WithContext("file", file).
WithRemediation("Check file syntax and ensure it's valid source code for the specified language")
}
// NewSearchTimeout creates an error for search timeouts.
func NewSearchTimeout(pattern string, duration string) *StructuredError {
return New(ErrRipgrepTimeout, "search operation timed out").
WithContext("pattern", pattern).
WithContext("duration", duration).
WithRemediation("Try narrowing the search scope, using more specific patterns, or increasing the timeout limit")
}
// NewEditValidationError creates an error for edit validation failures.
func NewEditValidationError(file string, cause error) *StructuredError {
return Wrap(ErrValidationFailed, "edit validation failed - syntax errors detected", cause).
WithContext("file", file).
WithRemediation("Review the edit operation and ensure it produces valid syntax. The file was not modified.")
}
// NewFileNotFoundError creates an error for missing files.
func NewFileNotFoundError(file string) *StructuredError {
return New(ErrFileNotFound, fmt.Sprintf("file not found: %s", file)).
WithContext("file", file).
WithRemediation("Verify the file path is correct and the file exists")
}
// NewFileNotReadableError creates an error for unreadable files.
func NewFileNotReadableError(file string, cause error) *StructuredError {
return Wrap(ErrFileNotReadable, fmt.Sprintf("cannot read file: %s", file), cause).
WithContext("file", file).
WithRemediation("Check file permissions and ensure the file is not locked by another process")
}
// NewFileNotWritableError creates an error for write failures.
func NewFileNotWritableError(file string, cause error) *StructuredError {
return Wrap(ErrFileNotWritable, fmt.Sprintf("cannot write to file: %s", file), cause).
WithContext("file", file).
WithRemediation("Check file permissions, disk space, and ensure the file is not locked by another process")
}
// NewNodeNotFoundError creates an error when AST node selector finds no matches.
func NewNodeNotFoundError(selector string) *StructuredError {
return New(ErrNodeNotFound, "no AST nodes match the selector criteria").
WithContext("selector", selector).
WithRemediation("Verify the selector criteria (kind, name, pattern, line) match an existing code structure")
}
// NewInvalidSelectionError creates an error for ambiguous or invalid selectors.
func NewInvalidSelectionError(message string) *StructuredError {
return New(ErrInvalidSelection, message).
WithRemediation("Refine the selector to be more specific or provide a selector_index to choose between multiple matches")
}
// NewInvalidEditError creates an error for invalid edit operations.
func NewInvalidEditError(message string) *StructuredError {
return New(ErrInvalidEdit, message).
WithRemediation("Review the edit request and ensure all required fields are provided with valid values")
}
+375
View File
@@ -0,0 +1,375 @@
// Package fuzzy provides fuzzy string matching using Levenshtein distance.
package fuzzy
import (
"sort"
"strings"
"unicode"
)
// Match represents a fuzzy match result.
type Match struct {
Text string
Distance int
Similarity float64
Score float64
}
// Matcher provides fuzzy matching capabilities.
type Matcher struct {
threshold int
}
// New creates a new fuzzy matcher with the given threshold.
// Threshold is the maximum edit distance to consider a match (typically 1-3).
func New(threshold int) *Matcher {
return &Matcher{
threshold: threshold,
}
}
// Match performs fuzzy matching of query against candidates.
func (m *Matcher) Match(query string, candidates []string) []Match {
if query == "" {
return nil
}
matches := make([]Match, 0, len(candidates)/10)
queryLower := strings.ToLower(query)
for _, candidate := range candidates {
candidateLower := strings.ToLower(candidate)
// Calculate Levenshtein distance
dist := levenshteinDistance(queryLower, candidateLower)
// Skip if distance exceeds threshold
if dist > m.threshold {
// Check if it's a substring match (important for identifiers)
if !strings.Contains(candidateLower, queryLower) {
continue
}
// Allow substring matches even if edit distance is high
}
// Calculate similarity (0.0 to 1.0)
maxLen := max(len(query), len(candidate))
similarity := 1.0 - float64(dist)/float64(maxLen)
// Calculate composite score
score := m.calculateScore(queryLower, candidateLower, dist, similarity)
matches = append(matches, Match{
Text: candidate,
Distance: dist,
Similarity: similarity,
Score: score,
})
}
// Sort by score descending
sort.Slice(matches, func(i, j int) bool {
return matches[i].Score > matches[j].Score
})
return matches
}
// calculateScore computes a composite score considering multiple factors.
func (m *Matcher) calculateScore(query, candidate string, dist int, similarity float64) float64 {
score := similarity
// Bonus for exact match
if query == candidate {
score += 2.0
}
// Bonus for prefix match (important for identifier search)
if strings.HasPrefix(candidate, query) {
score += 1.0
}
// Bonus for word boundary matches (e.g., "getName" matches "get")
if containsWordBoundary(candidate, query) {
score += 0.5
}
// Penalty for length difference (prefer similar-length matches)
lenDiff := abs(len(candidate) - len(query))
score -= float64(lenDiff) * 0.01
// Penalty for edit distance
score -= float64(dist) * 0.1
return score
}
// levenshteinDistance computes the Levenshtein distance between two strings.
// Uses the Wagner-Fischer algorithm with space optimization O(min(m,n)).
func levenshteinDistance(s1, s2 string) int {
if s1 == s2 {
return 0
}
if len(s1) == 0 {
return len(s2)
}
if len(s2) == 0 {
return len(s1)
}
// Ensure s1 is the shorter string for space optimization
if len(s1) > len(s2) {
s1, s2 = s2, s1
}
// Use rune slices to handle Unicode properly
r1 := []rune(s1)
r2 := []rune(s2)
len1 := len(r1)
len2 := len(r2)
// Only need two rows of the matrix
previous := make([]int, len2+1)
current := make([]int, len2+1)
// Initialize first row
for j := 0; j <= len2; j++ {
previous[j] = j
}
// Calculate edit distance
for i := 1; i <= len1; i++ {
current[0] = i
for j := 1; j <= len2; j++ {
cost := 1
if r1[i-1] == r2[j-1] {
cost = 0
}
current[j] = min(
previous[j]+1, // deletion
current[j-1]+1, // insertion
previous[j-1]+cost, // substitution
)
}
// Swap rows
previous, current = current, previous
}
return previous[len2]
}
// DamerauLevenshteinDistance computes Damerau-Levenshtein distance (includes transpositions).
// This is more accurate for typos where adjacent characters are swapped.
func DamerauLevenshteinDistance(s1, s2 string) int {
if s1 == s2 {
return 0
}
if len(s1) == 0 {
return len(s2)
}
if len(s2) == 0 {
return len(s1)
}
r1 := []rune(s1)
r2 := []rune(s2)
len1 := len(r1)
len2 := len(r2)
// Create distance matrix
d := make([][]int, len1+1)
for i := range d {
d[i] = make([]int, len2+1)
}
// Initialize first row and column
for i := 0; i <= len1; i++ {
d[i][0] = i
}
for j := 0; j <= len2; j++ {
d[0][j] = j
}
// Calculate distances
for i := 1; i <= len1; i++ {
for j := 1; j <= len2; j++ {
cost := 1
if r1[i-1] == r2[j-1] {
cost = 0
}
d[i][j] = min(
d[i-1][j]+1, // deletion
d[i][j-1]+1, // insertion
d[i-1][j-1]+cost, // substitution
)
// Check for transposition
if i > 1 && j > 1 && r1[i-1] == r2[j-2] && r1[i-2] == r2[j-1] {
d[i][j] = min(d[i][j], d[i-2][j-2]+cost)
}
}
}
return d[len1][len2]
}
// JaroWinklerSimilarity computes Jaro-Winkler similarity (0.0 to 1.0).
// Better for short strings and names.
func JaroWinklerSimilarity(s1, s2 string) float64 {
if s1 == s2 {
return 1.0
}
r1 := []rune(s1)
r2 := []rune(s2)
if len(r1) == 0 || len(r2) == 0 {
return 0.0
}
// Calculate Jaro similarity first
jaro := jaroSimilarity(r1, r2)
// Calculate common prefix length (up to 4 characters)
prefixLen := 0
for i := 0; i < min(min(len(r1), len(r2)), 4); i++ {
if r1[i] == r2[i] {
prefixLen++
} else {
break
}
}
// Jaro-Winkler adds bonus for common prefix
const p = 0.1
return jaro + float64(prefixLen)*p*(1.0-jaro)
}
// jaroSimilarity computes Jaro similarity.
func jaroSimilarity(r1, r2 []rune) float64 {
len1 := len(r1)
len2 := len(r2)
// Maximum allowed distance
matchDist := max(len1, len2)/2 - 1
if matchDist < 0 {
matchDist = 0
}
matched1 := make([]bool, len1)
matched2 := make([]bool, len2)
matches := 0
transpositions := 0
// Find matches
for i := range len1 {
start := max(0, i-matchDist)
end := min(i+matchDist+1, len2)
for j := start; j < end; j++ {
if matched2[j] || r1[i] != r2[j] {
continue
}
matched1[i] = true
matched2[j] = true
matches++
break
}
}
if matches == 0 {
return 0.0
}
// Count transpositions
k := 0
for i := range len1 {
if !matched1[i] {
continue
}
for !matched2[k] {
k++
}
if r1[i] != r2[k] {
transpositions++
}
k++
}
return (float64(matches)/float64(len1) +
float64(matches)/float64(len2) +
float64(matches-transpositions/2)/float64(matches)) / 3.0
}
// containsWordBoundary checks if query appears at word boundaries in text.
func containsWordBoundary(text, query string) bool {
textLower := strings.ToLower(text)
queryLower := strings.ToLower(query)
idx := strings.Index(textLower, queryLower)
if idx == -1 {
return false
}
// Check if match is at start
if idx == 0 {
return true
}
// Check for underscore or non-alphanumeric boundary
prevRune := rune(text[idx-1])
if !unicode.IsLetter(prevRune) && !unicode.IsDigit(prevRune) {
return true
}
// Check for camelCase boundary (lowercase before uppercase)
if idx > 0 && len(text) > idx {
curr := rune(text[idx])
prev := rune(text[idx-1])
if unicode.IsLower(prev) && unicode.IsUpper(curr) {
return true
}
}
return false
}
// Helper functions
func min(values ...int) int {
if len(values) == 0 {
return 0
}
m := values[0]
for _, v := range values[1:] {
if v < m {
m = v
}
}
return m
}
func max(values ...int) int {
if len(values) == 0 {
return 0
}
m := values[0]
for _, v := range values[1:] {
if v > m {
m = v
}
}
return m
}
func abs(x int) int {
if x < 0 {
return -x
}
return x
}
+275
View File
@@ -0,0 +1,275 @@
package fuzzy
import (
"testing"
)
func TestLevenshteinDistance(t *testing.T) {
tests := []struct {
s1 string
s2 string
expected int
}{
{"", "", 0},
{"", "abc", 3},
{"abc", "", 3},
{"abc", "abc", 0},
{"abc", "abd", 1},
{"kitten", "sitting", 3},
{"saturday", "sunday", 3},
{"book", "back", 2},
{"café", "cafe", 1}, // Unicode handling
}
for _, tt := range tests {
got := levenshteinDistance(tt.s1, tt.s2)
if got != tt.expected {
t.Errorf("levenshteinDistance(%q, %q) = %d, want %d", tt.s1, tt.s2, got, tt.expected)
}
}
}
func TestDamerauLevenshteinDistance(t *testing.T) {
tests := []struct {
s1 string
s2 string
expected int
}{
{"abc", "abc", 0},
{"abc", "acb", 1}, // Transposition
{"ca", "abc", 3}, // Delete a, delete b, insert c = 3 operations
{"", "abc", 3},
}
for _, tt := range tests {
got := DamerauLevenshteinDistance(tt.s1, tt.s2)
if got != tt.expected {
t.Errorf("DamerauLevenshteinDistance(%q, %q) = %d, want %d", tt.s1, tt.s2, got, tt.expected)
}
}
}
func TestJaroWinklerSimilarity(t *testing.T) {
tests := []struct {
s1 string
s2 string
minScore float64 // Minimum expected similarity
}{
{"", "", 1.0},
{"abc", "abc", 1.0},
{"martha", "marhta", 0.96}, // High similarity for transposition
{"dixon", "dicksonx", 0.76}, // Moderate similarity
{"", "abc", 0.0},
}
for _, tt := range tests {
got := JaroWinklerSimilarity(tt.s1, tt.s2)
if got < tt.minScore {
t.Errorf("JaroWinklerSimilarity(%q, %q) = %.2f, want >= %.2f", tt.s1, tt.s2, got, tt.minScore)
}
}
}
func TestMatcher_Match(t *testing.T) {
m := New(2) // Allow edit distance up to 2
candidates := []string{
"getUserName",
"getUsername",
"get_user_name",
"getUserId",
"setUserName",
"findUser",
"userName",
"usernameField",
}
tests := []struct {
query string
topMatch string
expectMin int
}{
{
query: "getUserName",
expectMin: 3, // Exact + similar variants
topMatch: "getUserName",
},
{
query: "getuser",
expectMin: 2, // Should match getUserName, getUsername at minimum
topMatch: "getUserName",
},
{
query: "username",
expectMin: 2, // Case-insensitive matches
topMatch: "userName",
},
}
for _, tt := range tests {
matches := m.Match(tt.query, candidates)
if len(matches) < tt.expectMin {
t.Errorf("Match(%q) returned %d matches, want at least %d", tt.query, len(matches), tt.expectMin)
}
if len(matches) > 0 {
// Top match should have highest score
if matches[0].Score < matches[len(matches)-1].Score {
t.Errorf("Match(%q) results not sorted by score", tt.query)
}
}
}
}
func TestMatcher_EmptyQuery(t *testing.T) {
m := New(2)
candidates := []string{"test", "example"}
matches := m.Match("", candidates)
if matches != nil {
t.Errorf("Match with empty query should return nil, got %v", matches)
}
}
func TestMatcher_PrefixBonus(t *testing.T) {
m := New(2)
candidates := []string{
"getUserName", // prefix match
"findUserName", // contains but not prefix
}
matches := m.Match("get", candidates)
if len(matches) < 1 {
t.Fatal("Expected at least one match")
}
// Prefix match should score higher
if matches[0].Text != "getUserName" {
t.Errorf("Expected prefix match to rank first, got %q", matches[0].Text)
}
}
func TestMatcher_ExactMatchBonus(t *testing.T) {
m := New(2)
candidates := []string{
"test",
"testing",
"tester",
}
matches := m.Match("test", candidates)
if len(matches) < 1 {
t.Fatal("Expected at least one match")
}
// Exact match should rank first
if matches[0].Text != "test" {
t.Errorf("Expected exact match to rank first, got %q", matches[0].Text)
}
// Exact match should have highest score
if matches[0].Score < 2.0 { // Should have exact match bonus
t.Errorf("Exact match score too low: %.2f", matches[0].Score)
}
}
func TestContainsWordBoundary(t *testing.T) {
tests := []struct {
text string
query string
expected bool
}{
{"getUserName", "get", true}, // At start
{"getUserName", "user", true}, // After lowercase->uppercase boundary
{"get_user_name", "user", true}, // After underscore
{"getUserName", "Name", true}, // After lowercase->uppercase
{"getUserName", "ser", false}, // Middle of word
{"", "test", false}, // Empty text
}
for _, tt := range tests {
got := containsWordBoundary(tt.text, tt.query)
if got != tt.expected {
t.Errorf("containsWordBoundary(%q, %q) = %v, want %v", tt.text, tt.query, got, tt.expected)
}
}
}
func TestMatcher_UnicodeHandling(t *testing.T) {
m := New(2)
candidates := []string{
"café",
"resume",
"naïve",
}
// Test with Unicode characters
matches := m.Match("cafe", candidates)
if len(matches) == 0 {
t.Error("Expected matches for Unicode strings")
}
// Should find café with small edit distance
found := false
for _, match := range matches {
if match.Text == "café" && match.Distance <= 2 {
found = true
break
}
}
if !found {
t.Error("Failed to fuzzy match Unicode string 'café'")
}
}
func BenchmarkLevenshteinDistance(b *testing.B) {
s1 := "the quick brown fox jumps over the lazy dog"
s2 := "the quikc brown fox jumps ovver the lazy dog"
b.ResetTimer()
for i := range b.N {
_ = levenshteinDistance(s1, s2)
_ = i // use i to avoid unused warning
}
}
func BenchmarkDamerauLevenshteinDistance(b *testing.B) {
s1 := "the quick brown fox jumps over the lazy dog"
s2 := "the quikc brown fox jumps ovver the lazy dog"
b.ResetTimer()
for i := range b.N {
_ = DamerauLevenshteinDistance(s1, s2)
_ = i
}
}
func BenchmarkJaroWinklerSimilarity(b *testing.B) {
s1 := "martha"
s2 := "marhta"
b.ResetTimer()
for i := range b.N {
_ = JaroWinklerSimilarity(s1, s2)
_ = i
}
}
func BenchmarkMatcher_Match(b *testing.B) {
m := New(2)
candidates := []string{
"getUserName", "getUsername", "get_user_name", "getUserId",
"setUserName", "findUser", "userName", "usernameField",
"userAccount", "accountUser", "userProfile", "profileUser",
}
b.ResetTimer()
for i := range b.N {
_ = m.Match("getuser", candidates)
_ = i
}
}
+105
View File
@@ -0,0 +1,105 @@
// Package protocol defines shared types used across the MCP file operations server.
package protocol
// Location represents a position in a file.
type Location struct {
File string `json:"file"`
Line int `json:"line"`
Column int `json:"column"`
}
// Range represents a range in a file.
type Range struct {
Start Location `json:"start"`
End Location `json:"end"`
}
// SymbolKind represents the kind of a symbol.
type SymbolKind string
const (
SymbolFunction SymbolKind = "function"
SymbolMethod SymbolKind = "method"
SymbolClass SymbolKind = "class"
SymbolStruct SymbolKind = "struct"
SymbolInterface SymbolKind = "interface"
SymbolVariable SymbolKind = "variable"
SymbolConstant SymbolKind = "constant"
SymbolType SymbolKind = "type"
SymbolField SymbolKind = "field"
SymbolProperty SymbolKind = "property"
SymbolModule SymbolKind = "module"
SymbolPackage SymbolKind = "package"
)
// Symbol represents a code symbol (function, class, variable, etc.).
type Symbol struct {
Name string `json:"name"`
Kind SymbolKind `json:"kind"`
Doc string `json:"doc,omitempty"`
Location Location `json:"location"`
}
// SyntaxError represents a syntax error in a file.
type SyntaxError struct {
Message string `json:"message"`
Location Location `json:"location"`
}
// Language represents a programming language.
type Language string
const (
LangGo Language = "go"
LangTypeScript Language = "typescript"
LangJavaScript Language = "javascript"
LangPython Language = "python"
LangC Language = "c"
LangCpp Language = "cpp"
LangHTML Language = "html"
LangVue Language = "vue"
LangJSON Language = "json"
LangYAML Language = "yaml"
LangUnknown Language = "unknown"
)
// DetectLanguage detects the language from a filename.
func DetectLanguage(filename string) Language {
ext := getExtension(filename)
switch ext {
case ".go":
return LangGo
case ".ts", ".tsx":
return LangTypeScript
case ".js", ".jsx", ".mjs", ".cjs":
return LangJavaScript
case ".py", ".pyw":
return LangPython
case ".c", ".h":
return LangC
case ".cpp", ".cc", ".cxx", ".hpp", ".hxx":
return LangCpp
case ".html", ".htm":
return LangHTML
case ".vue":
return LangVue
case ".json":
return LangJSON
case ".yaml", ".yml":
return LangYAML
default:
return LangUnknown
}
}
func getExtension(filename string) string {
for i := len(filename) - 1; i >= 0; i-- {
if filename[i] == '.' {
return filename[i:]
}
if filename[i] == '/' || filename[i] == '\\' {
break
}
}
return ""
}
+69
View File
@@ -0,0 +1,69 @@
package protocol
import "testing"
func TestDetectLanguage(t *testing.T) {
tests := []struct {
filename string
expected Language
}{
{"main.go", LangGo},
{"server.go", LangGo},
{"index.ts", LangTypeScript},
{"component.tsx", LangTypeScript},
{"Button.tsx", LangTypeScript},
{"app.js", LangJavaScript},
{"component.jsx", LangJavaScript},
{"Component.jsx", LangJavaScript},
{"module.mjs", LangJavaScript},
{"common.cjs", LangJavaScript},
{"script.py", LangPython},
{"app.pyw", LangPython},
{"main.c", LangC},
{"header.h", LangC},
{"main.cpp", LangCpp},
{"main.cc", LangCpp},
{"main.cxx", LangCpp},
{"header.hpp", LangCpp},
{"header.hxx", LangCpp},
{"index.html", LangHTML},
{"page.htm", LangHTML},
{"Component.vue", LangVue},
{"unknown.txt", LangUnknown},
{"README", LangUnknown},
{"path/to/file.go", LangGo},
{"path/to/file.ts", LangTypeScript},
}
for _, tt := range tests {
t.Run(tt.filename, func(t *testing.T) {
result := DetectLanguage(tt.filename)
if result != tt.expected {
t.Errorf("DetectLanguage(%q) = %q, want %q", tt.filename, result, tt.expected)
}
})
}
}
func TestGetExtension(t *testing.T) {
tests := []struct {
filename string
expected string
}{
{"file.go", ".go"},
{"file.test.go", ".go"},
{"path/to/file.ts", ".ts"},
{"noextension", ""},
{".hidden", ".hidden"},
{"file.", "."},
}
for _, tt := range tests {
t.Run(tt.filename, func(t *testing.T) {
result := getExtension(tt.filename)
if result != tt.expected {
t.Errorf("getExtension(%q) = %q, want %q", tt.filename, result, tt.expected)
}
})
}
}
+55
View File
@@ -0,0 +1,55 @@
/**
* @file header.h
* @brief Sample header file for testing.
*/
#ifndef HEADER_H
#define HEADER_H
/**
* @brief Maximum buffer size.
*/
#define MAX_BUFFER_SIZE 1024
/**
* @brief Status codes for operations.
*/
typedef enum {
STATUS_OK = 0,
STATUS_ERROR = 1,
STATUS_NOT_FOUND = 2
} Status;
/**
* @brief Buffer structure for data storage.
*/
typedef struct {
char data[MAX_BUFFER_SIZE];
int length;
} Buffer;
/**
* @brief Initialize a buffer.
* @param buf Pointer to the buffer
*/
void buffer_init(Buffer* buf);
/**
* @brief Write data to the buffer.
* @param buf Pointer to the buffer
* @param data Data to write
* @param len Length of data
* @return Status code
*/
Status buffer_write(Buffer* buf, const char* data, int len);
/**
* @brief Read data from the buffer.
* @param buf Pointer to the buffer
* @param out Output buffer
* @param max_len Maximum length to read
* @return Number of bytes read
*/
int buffer_read(Buffer* buf, char* out, int max_len);
#endif /* HEADER_H */
+57
View File
@@ -0,0 +1,57 @@
/**
* @file valid.c
* @brief Sample C file for testing.
*/
#include <stdio.h>
#include <stdlib.h>
/**
* @brief A simple point structure.
*/
struct Point {
int x;
int y;
};
/**
* @brief Creates a new point.
* @param x The x coordinate
* @param y The y coordinate
* @return A new Point structure
*/
struct Point create_point(int x, int y) {
struct Point p;
p.x = x;
p.y = y;
return p;
}
/**
* @brief Calculates the distance from origin.
* @param p The point
* @return The squared distance from origin
*/
int distance_squared(struct Point p) {
return p.x * p.x + p.y * p.y;
}
/**
* @brief Prints a point to stdout.
* @param p The point to print
*/
void print_point(struct Point p) {
printf("Point(%d, %d)\n", p.x, p.y);
}
// Simple helper function
int add(int a, int b) {
return a + b;
}
int main(void) {
struct Point p = create_point(3, 4);
print_point(p);
printf("Distance squared: %d\n", distance_squared(p));
return 0;
}
+11
View File
@@ -0,0 +1,11 @@
package main
// This file contains intentional syntax errors for testing.
func broken( {
return
}
type Incomplete struct {
Name string
// Missing closing brace
+44
View File
@@ -0,0 +1,44 @@
package main
import "fmt"
// Server represents the main application server.
type Server struct {
Name string
Port int
}
// NewServer creates a new Server instance.
func NewServer(name string, port int) *Server {
return &Server{
Name: name,
Port: port,
}
}
// Start starts the server.
func (s *Server) Start() error {
fmt.Printf("Starting server %s on port %d\n", s.Name, s.Port)
return nil
}
// Config holds application configuration.
type Config struct {
Debug bool
Timeout int
}
const (
// DefaultPort is the default server port.
DefaultPort = 8080
)
var (
// Version is the application version.
Version = "1.0.0"
)
func main() {
srv := NewServer("main", DefaultPort)
srv.Start()
}
+29
View File
@@ -0,0 +1,29 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Test HTML File</title>
</head>
<body>
<div class="container mx-auto px-4">
<h1 class="text-3xl font-bold text-blue-600">Hello World</h1>
<p class="text-gray-700 mt-4">This is a test HTML file with Tailwind CSS classes.</p>
<div class="flex gap-4 mt-8">
<button class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded">
Primary Button
</button>
<button class="bg-gray-500 hover:bg-gray-700 text-white font-bold py-2 px-4 rounded">
Secondary Button
</button>
</div>
<ul class="list-disc list-inside mt-4">
<li>First item</li>
<li>Second item</li>
<li>Third item</li>
</ul>
</div>
</body>
</html>
+62
View File
@@ -0,0 +1,62 @@
"""
Sample Python module for testing.
"""
from typing import List, Optional
class DataProcessor:
"""Processes data records."""
def __init__(self, name: str):
"""
Initialize the processor.
Args:
name: The processor name
"""
self.name = name
self._records: List[dict] = []
def add_record(self, record: dict) -> None:
"""
Add a record to the processor.
Args:
record: The record to add
"""
self._records.append(record)
def process(self) -> List[dict]:
"""
Process all records.
Returns:
The processed records
"""
return [self._transform(r) for r in self._records]
def _transform(self, record: dict) -> dict:
"""Transform a single record."""
return {k.upper(): v for k, v in record.items()}
def calculate_sum(numbers: List[int]) -> int:
"""
Calculate the sum of numbers.
Args:
numbers: List of integers to sum
Returns:
The sum of all numbers
"""
return sum(numbers)
def find_maximum(values: List[int]) -> Optional[int]:
"""Find the maximum value in a list."""
return max(values) if values else None
DEFAULT_BATCH_SIZE = 100
+129
View File
@@ -0,0 +1,129 @@
import React, { useState, useEffect } from 'react';
interface ButtonProps {
variant?: 'primary' | 'secondary';
disabled?: boolean;
onClick?: () => void;
children: React.ReactNode;
}
/**
* A reusable button component with Tailwind CSS styling
*/
export const Button: React.FC<ButtonProps> = ({
variant = 'primary',
disabled = false,
onClick,
children
}) => {
const baseClasses = 'font-bold py-2 px-4 rounded transition-colors duration-200';
const variantClasses = {
primary: 'bg-blue-500 hover:bg-blue-700 text-white',
secondary: 'bg-gray-500 hover:bg-gray-700 text-white'
};
return (
<button
className={`${baseClasses} ${variantClasses[variant]} ${disabled ? 'opacity-50 cursor-not-allowed' : ''}`}
disabled={disabled}
onClick={onClick}
>
{children}
</button>
);
};
interface TodoItem {
id: number;
text: string;
completed: boolean;
}
/**
* Todo list component demonstrating React hooks and Tailwind
*/
export const TodoList: React.FC = () => {
const [todos, setTodos] = useState<TodoItem[]>([
{ id: 1, text: 'Learn React', completed: true },
{ id: 2, text: 'Learn TypeScript', completed: true },
{ id: 3, text: 'Build amazing apps', completed: false }
]);
const [inputValue, setInputValue] = useState('');
useEffect(() => {
console.log('Todos updated:', todos);
}, [todos]);
const addTodo = () => {
if (inputValue.trim()) {
const newTodo: TodoItem = {
id: Date.now(),
text: inputValue,
completed: false
};
setTodos([...todos, newTodo]);
setInputValue('');
}
};
const toggleTodo = (id: number) => {
setTodos(todos.map(todo =>
todo.id === id ? { ...todo, completed: !todo.completed } : todo
));
};
const deleteTodo = (id: number) => {
setTodos(todos.filter(todo => todo.id !== id));
};
return (
<div className="container mx-auto px-4 py-8 max-w-2xl">
<h1 className="text-3xl font-bold text-gray-800 mb-6">
My Todo List
</h1>
<div className="flex gap-2 mb-6">
<input
type="text"
value={inputValue}
onChange={(e) => setInputValue(e.target.value)}
onKeyPress={(e) => e.key === 'Enter' && addTodo()}
placeholder="Add a new todo..."
className="flex-1 px-4 py-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500"
/>
<Button onClick={addTodo}>Add</Button>
</div>
<ul className="space-y-2">
{todos.map(todo => (
<li
key={todo.id}
className="flex items-center gap-3 p-4 bg-white rounded-lg shadow-sm hover:shadow-md transition-shadow"
>
<input
type="checkbox"
checked={todo.completed}
onChange={() => toggleTodo(todo.id)}
className="w-5 h-5 text-blue-600 rounded focus:ring-2 focus:ring-blue-500"
/>
<span className={`flex-1 ${todo.completed ? 'line-through text-gray-400' : 'text-gray-700'}`}>
{todo.text}
</span>
<Button
variant="secondary"
onClick={() => deleteTodo(todo.id)}
>
Delete
</Button>
</li>
))}
</ul>
{todos.length === 0 && (
<div className="text-center py-12 text-gray-400">
No todos yet. Add one above!
</div>
)}
</div>
);
};
+53
View File
@@ -0,0 +1,53 @@
/**
* Represents a user in the system.
*/
interface User {
id: number;
name: string;
email: string;
}
/**
* Configuration options for the application.
*/
type Config = {
debug: boolean;
timeout: number;
};
/**
* Greeting service for handling user greetings.
*/
class GreetingService {
private prefix: string;
/**
* Creates a new GreetingService.
* @param prefix The greeting prefix
*/
constructor(prefix: string) {
this.prefix = prefix;
}
/**
* Greets a user.
* @param user The user to greet
* @returns The greeting message
*/
greet(user: User): string {
return `${this.prefix}, ${user.name}!`;
}
}
/**
* Formats a user for display.
* @param user The user to format
* @returns Formatted string
*/
function formatUser(user: User): string {
return `${user.name} <${user.email}>`;
}
const DEFAULT_TIMEOUT = 5000;
export { User, Config, GreetingService, formatUser, DEFAULT_TIMEOUT };
+76
View File
@@ -0,0 +1,76 @@
<template>
<div class="container mx-auto px-4 py-8">
<h1 class="text-3xl font-bold text-blue-600 mb-4">
{{ title }}
</h1>
<div v-if="showContent" class="bg-white shadow-md rounded-lg p-6">
<p class="text-gray-700 mb-4">{{ description }}</p>
<div class="flex gap-4 mt-4">
<button
@click="handlePrimary"
:class="['bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded', { 'opacity-50': isLoading }]"
:disabled="isLoading"
>
{{ primaryButtonText }}
</button>
<button
@click="handleSecondary"
class="bg-gray-500 hover:bg-gray-700 text-white font-bold py-2 px-4 rounded"
>
{{ secondaryButtonText }}
</button>
</div>
<ul v-for="item in items" :key="item.id" class="list-disc list-inside mt-4">
<li class="text-gray-600">{{ item.name }}</li>
</ul>
</div>
<div v-else class="text-center text-gray-500">
<p>No content to display</p>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, computed } from 'vue';
interface Item {
id: number;
name: string;
}
const title = ref('Vue Component with Tailwind');
const description = ref('This is a sample Vue 3 component using Composition API and Tailwind CSS');
const showContent = ref(true);
const isLoading = ref(false);
const items = ref<Item[]>([
{ id: 1, name: 'First item' },
{ id: 2, name: 'Second item' },
{ id: 3, name: 'Third item' },
]);
const primaryButtonText = computed(() => isLoading.value ? 'Loading...' : 'Primary Action');
const secondaryButtonText = ref('Secondary Action');
const handlePrimary = () => {
isLoading.value = true;
setTimeout(() => {
isLoading.value = false;
}, 2000);
};
const handleSecondary = () => {
console.log('Secondary button clicked');
};
</script>
<style scoped>
.container {
max-width: 1200px;
}
</style>