mirror of
https://github.com/lukaszraczylo/filepuff-mcp.git
synced 2026-06-05 22:23:50 +00:00
Ho hum.
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
name: Test, build, release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '**/release.yaml'
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24"
|
||||
docker-enabled: false
|
||||
rolling-release-tag: "v1"
|
||||
secrets: inherit
|
||||
@@ -0,0 +1,3 @@
|
||||
TODO.md
|
||||
bin/mcp-filepuff
|
||||
mcp-filepuff
|
||||
@@ -0,0 +1,122 @@
|
||||
# yaml-language-server: $schema=https://goreleaser.com/static/schema.json
|
||||
# vim: set ts=2 sw=2 tw=0 fo=cnqoj
|
||||
|
||||
version: 2
|
||||
|
||||
project_name: mcp-filepuff
|
||||
|
||||
before:
|
||||
hooks:
|
||||
- go mod tidy
|
||||
- go generate ./...
|
||||
|
||||
builds:
|
||||
- id: mcp-filepuff
|
||||
main: ./cmd/mcp-filepuff
|
||||
binary: mcp-filepuff
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
flags:
|
||||
- -trimpath
|
||||
ldflags:
|
||||
- -s -w
|
||||
- -X main.version={{.Version}}
|
||||
- -X main.commit={{.Commit}}
|
||||
- -X main.date={{.Date}}
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
|
||||
archives:
|
||||
- id: default
|
||||
formats:
|
||||
- tar.gz
|
||||
name_template: >-
|
||||
{{ .ProjectName }}_
|
||||
{{- .Version }}_
|
||||
{{- .Os }}_
|
||||
{{- .Arch }}
|
||||
files:
|
||||
- LICENSE
|
||||
- README.md
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
formats:
|
||||
- zip
|
||||
|
||||
checksum:
|
||||
name_template: 'checksums.txt'
|
||||
|
||||
snapshot:
|
||||
version_template: "{{ incpatch .Version }}-next"
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
filters:
|
||||
exclude:
|
||||
- '^docs:'
|
||||
- '^test:'
|
||||
- '^chore:'
|
||||
- Merge pull request
|
||||
- Merge branch
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: lukaszraczylo
|
||||
name: filepuff-mcp
|
||||
draft: false
|
||||
prerelease: auto
|
||||
name_template: "v{{.Version}}"
|
||||
header: |
|
||||
## MCP Filepuff v{{.Version}}
|
||||
|
||||
AST-aware file operations and LSP integration for Claude Code.
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
curl -sSL https://raw.githubusercontent.com/lukaszraczylo/filepuff-mcp/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
dockers_v2:
|
||||
- images:
|
||||
- "ghcr.io/lukaszraczylo/filepuff-mcp"
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "latest"
|
||||
- "v1"
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
build_flag_templates:
|
||||
- "--pull"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/lukaszraczylo/filepuff-mcp"
|
||||
|
||||
signs:
|
||||
- cmd: cosign
|
||||
signature: "${artifact}.sigstore.json"
|
||||
args:
|
||||
- sign-blob
|
||||
- "--bundle=${signature}"
|
||||
- "${artifact}"
|
||||
- "--yes"
|
||||
artifacts: checksum
|
||||
output: true
|
||||
|
||||
docker_signs:
|
||||
- cmd: cosign
|
||||
artifacts: manifests
|
||||
output: true
|
||||
args:
|
||||
- sign
|
||||
- "${artifact}@${digest}"
|
||||
- "--yes"
|
||||
@@ -0,0 +1,13 @@
|
||||
FROM alpine:3.21
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata git ripgrep
|
||||
|
||||
COPY ${TARGETPLATFORM}/mcp-filepuff /usr/local/bin/mcp-filepuff
|
||||
|
||||
RUN chmod +x /usr/local/bin/mcp-filepuff
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/mcp-filepuff"]
|
||||
CMD ["-workspace", "/workspace"]
|
||||
@@ -0,0 +1,80 @@
|
||||
.PHONY: build test lint clean install run deps
|
||||
|
||||
# Binary name
|
||||
BINARY_NAME=mcp-filepuff
|
||||
# Build directory
|
||||
BUILD_DIR=bin
|
||||
# Main package
|
||||
MAIN_PKG=./cmd/mcp-filepuff
|
||||
|
||||
# Go parameters
|
||||
GOCMD=go
|
||||
GOBUILD=$(GOCMD) build
|
||||
GOTEST=$(GOCMD) test
|
||||
GOGET=$(GOCMD) get
|
||||
GOMOD=$(GOCMD) mod
|
||||
GOFMT=$(GOCMD) fmt
|
||||
|
||||
# Build flags
|
||||
LDFLAGS=-ldflags "-s -w" -buildvcs=false
|
||||
|
||||
# Default target
|
||||
all: deps test build
|
||||
|
||||
# Install dependencies
|
||||
deps:
|
||||
$(GOMOD) download
|
||||
$(GOMOD) tidy
|
||||
|
||||
# Build the binary
|
||||
build:
|
||||
mkdir -p $(BUILD_DIR)
|
||||
$(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME) $(MAIN_PKG)
|
||||
|
||||
# Build for all platforms
|
||||
build-all:
|
||||
mkdir -p $(BUILD_DIR)
|
||||
GOOS=darwin GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-amd64 $(MAIN_PKG)
|
||||
GOOS=darwin GOARCH=arm64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 $(MAIN_PKG)
|
||||
GOOS=linux GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 $(MAIN_PKG)
|
||||
GOOS=linux GOARCH=arm64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 $(MAIN_PKG)
|
||||
GOOS=windows GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe $(MAIN_PKG)
|
||||
|
||||
# Run tests
|
||||
test:
|
||||
$(GOTEST) -v -race -coverprofile=coverage.out ./...
|
||||
|
||||
# Run tests with short flag
|
||||
test-short:
|
||||
$(GOTEST) -v -short ./...
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
rm -rf $(BUILD_DIR)
|
||||
rm -f coverage.out
|
||||
|
||||
# Install binary to GOPATH/bin
|
||||
install: build
|
||||
cp $(BUILD_DIR)/$(BINARY_NAME) $(GOPATH)/bin/
|
||||
|
||||
# Run the server (for development)
|
||||
run: build
|
||||
./$(BUILD_DIR)/$(BINARY_NAME) -log-level debug
|
||||
|
||||
# Run with specific workspace
|
||||
run-workspace: build
|
||||
./$(BUILD_DIR)/$(BINARY_NAME) -workspace $(WORKSPACE) -log-level debug
|
||||
|
||||
# Show help
|
||||
help:
|
||||
@echo "Available targets:"
|
||||
@echo " deps - Download and tidy dependencies"
|
||||
@echo " build - Build the binary"
|
||||
@echo " build-all - Build for all platforms"
|
||||
@echo " test - Run tests with coverage"
|
||||
@echo " test-short - Run short tests"
|
||||
@echo " lint - Run linters"
|
||||
@echo " clean - Clean build artifacts"
|
||||
@echo " install - Install binary to GOPATH/bin"
|
||||
@echo " run - Build and run the server"
|
||||
@echo " run-workspace - Run with specific workspace (WORKSPACE=/path)"
|
||||
@@ -0,0 +1,572 @@
|
||||
# mcp-filepuff
|
||||
|
||||
A Go-based MCP (Model Context Protocol) server for Claude Code providing intelligent file operations with fast search, AST-aware querying, LSP integration, and safe editing capabilities.
|
||||
|
||||
## Features
|
||||
|
||||
- **Fast Text Search**: Powered by ripgrep for blazing-fast code search with regex support
|
||||
- **AST-Aware File Reading**: Read files with symbol extraction using Tree-sitter
|
||||
- **Code Pattern Matching**: Query code using patterns with capture placeholders
|
||||
- **LSP Integration**: Go-to-definition, find references, and symbol info via language servers
|
||||
- **Safe Editing**: AST-aware file editing with syntax validation and preview
|
||||
- **Multi-Language Support**: Go, TypeScript, JavaScript, Python, C, C++, HTML, Vue, React
|
||||
- **Token Efficient**: Optimized for minimal token usage with symbols-only mode and output limiting
|
||||
|
||||
## Installation
|
||||
|
||||
### Binary Releases (Recommended)
|
||||
|
||||
Download pre-built binaries from the [releases page](https://github.com/lukaszraczylo/filepuff-mcp/releases):
|
||||
|
||||
```bash
|
||||
# macOS (ARM64)
|
||||
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-darwin-arm64.tar.gz | tar xz
|
||||
sudo mv mcp-filepuff /usr/local/bin/
|
||||
|
||||
# macOS (AMD64)
|
||||
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-darwin-amd64.tar.gz | tar xz
|
||||
sudo mv mcp-filepuff /usr/local/bin/
|
||||
|
||||
# Linux (ARM64)
|
||||
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-linux-arm64.tar.gz | tar xz
|
||||
sudo mv mcp-filepuff /usr/local/bin/
|
||||
|
||||
# Linux (AMD64)
|
||||
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-linux-amd64.tar.gz | tar xz
|
||||
sudo mv mcp-filepuff /usr/local/bin/
|
||||
|
||||
# Windows (PowerShell)
|
||||
Invoke-WebRequest -Uri "https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-windows-amd64.zip" -OutFile mcp-filepuff.zip
|
||||
Expand-Archive mcp-filepuff.zip -DestinationPath .
|
||||
Move-Item mcp-filepuff.exe C:\Windows\System32\
|
||||
```
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- [ripgrep](https://github.com/BurntSushi/ripgrep) (`rg`) installed and in PATH
|
||||
|
||||
### Optional Dependencies (for LSP features)
|
||||
|
||||
- `gopls` - Go language server
|
||||
- `typescript-language-server` - TypeScript/JavaScript language server
|
||||
- `pylsp` - Python language server
|
||||
- `clangd` - C/C++ language server
|
||||
|
||||
### Build from Source
|
||||
|
||||
```bash
|
||||
git clone https://github.com/lukaszraczylo/filepuff-mcp.git
|
||||
cd filepuff-mcp
|
||||
make build
|
||||
```
|
||||
|
||||
The binary will be available at `bin/mcp-filepuff`.
|
||||
|
||||
### Install via Claude Code
|
||||
|
||||
After downloading or building the binary, configure it in Claude Code:
|
||||
|
||||
1. **Create or edit `~/.config/claude-code/claude_desktop_config.json`**:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"filepuff": {
|
||||
"command": "/usr/local/bin/mcp-filepuff",
|
||||
"args": ["-workspace", "/path/to/your/workspace"],
|
||||
"env": {
|
||||
"MCP_LOG_LEVEL": "info"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
2. **Restart Claude Code** to load the MCP server
|
||||
|
||||
3. **Verify** by asking Claude: "Can you ping the filepuff server?"
|
||||
|
||||
See the [Claude Code MCP documentation](https://code.claude.com/docs/en/mcp) for more details.
|
||||
|
||||
## Usage
|
||||
|
||||
### Running the Server (Standalone)
|
||||
|
||||
```bash
|
||||
./bin/mcp-filepuff -workspace /path/to/workspace
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
- `-workspace string`: Workspace root directory (default: current directory)
|
||||
- `-log-level string`: Log level - debug, info, warn, error (default: "info")
|
||||
- `-log-file string`: Log file path (default: stderr)
|
||||
|
||||
### Configuration
|
||||
|
||||
The server can be configured via:
|
||||
|
||||
1. **Environment Variables**:
|
||||
- `MCP_WORKSPACE_ROOT`: Workspace root directory
|
||||
- `MCP_LSP_TIMEOUT`: LSP timeout duration (e.g., "10m")
|
||||
- `MCP_SEARCH_TIMEOUT`: Search timeout duration (e.g., "1m")
|
||||
- `MCP_ENABLE_LSP`: Enable LSP features ("true"/"false")
|
||||
- `MCP_FOLLOW_SYMLINKS`: Follow symbolic links ("true"/"false")
|
||||
- `MCP_RESPECT_GITIGNORE`: Respect .gitignore files ("true"/"false")
|
||||
|
||||
2. **Config File**: Create `.mcp-filepuff.json` in the workspace root:
|
||||
```json
|
||||
{
|
||||
"enable_lsp": true,
|
||||
"follow_symlinks": true,
|
||||
"respect_gitignore": true
|
||||
}
|
||||
```
|
||||
|
||||
### Claude Code Integration
|
||||
|
||||
To use mcp-filepuff with Claude Code, add it to your MCP server configuration:
|
||||
|
||||
1. **Global Configuration** (`~/.config/claude-code/mcp_servers.json`):
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"filepuff": {
|
||||
"command": "/path/to/mcp-filepuff",
|
||||
"args": ["-workspace", "/path/to/your/workspace"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
2. **Project-specific Configuration** (`.claude/mcp_servers.json` in your project):
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"filepuff": {
|
||||
"command": "mcp-filepuff",
|
||||
"args": ["-workspace", "."]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
After configuration, Claude Code will have access to all mcp-filepuff tools for enhanced file operations.
|
||||
|
||||
### Making Claude Code Prefer Filepuff Tools
|
||||
|
||||
By default, Claude Code uses its built-in file operation tools. To make it prefer filepuff's enhanced tools, add instructions to your `CLAUDE.md` file:
|
||||
|
||||
**Global Configuration** (`~/.claude/CLAUDE.md`):
|
||||
```markdown
|
||||
# MCP Tool Preferences
|
||||
|
||||
When performing file operations, prefer filepuff MCP tools over built-in equivalents:
|
||||
|
||||
| Operation | Use This | Instead Of |
|
||||
|-----------|----------|------------|
|
||||
| Read files | `mcp__filepuff__file_read` | Read |
|
||||
| Search content | `mcp__filepuff__file_search` | Grep |
|
||||
| AST pattern search | `mcp__filepuff__ast_query` | Grep/Glob |
|
||||
| Edit files | `mcp__filepuff__edit_preview` + `mcp__filepuff__edit_apply` | Edit |
|
||||
| Find definitions | `mcp__filepuff__find_definition` | Grep |
|
||||
| Find references | `mcp__filepuff__find_references` | Grep |
|
||||
| Symbol info | `mcp__filepuff__symbol_at` | - |
|
||||
|
||||
Benefits of filepuff tools:
|
||||
- AST-aware operations that understand code structure
|
||||
- LSP integration for accurate symbol navigation
|
||||
- Syntax validation before applying edits
|
||||
```
|
||||
|
||||
You can also place this in a project-specific `CLAUDE.md` or `.claude/CLAUDE.md` file.
|
||||
|
||||
**Optional: Restrict Built-in Tools**
|
||||
|
||||
To enforce filepuff usage, add permission restrictions in `.claude/settings.json`:
|
||||
```json
|
||||
{
|
||||
"permissions": {
|
||||
"deny": ["Read", "Edit", "Grep"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Available Tools
|
||||
|
||||
### `ping`
|
||||
Health check tool to verify the server is running.
|
||||
|
||||
**Returns**: "pong"
|
||||
|
||||
---
|
||||
|
||||
### `file_search`
|
||||
Search for text patterns in files using ripgrep.
|
||||
|
||||
**Parameters**:
|
||||
- `pattern` (required): The search pattern (regex by default)
|
||||
- `paths`: Paths to search in (defaults to workspace root)
|
||||
- `file_types`: File types to search (e.g., ["go", "ts", "py"])
|
||||
- `ignore_case`: Case insensitive search
|
||||
- `regex`: Treat pattern as regex (default: true)
|
||||
- `context_lines`: Number of context lines around matches (default: 2)
|
||||
- `max_results`: Maximum number of results to return
|
||||
|
||||
---
|
||||
|
||||
### `file_read`
|
||||
Read a file's contents with optional line range and AST symbol summary. Supports token-efficient modes for AI assistants.
|
||||
|
||||
**Parameters**:
|
||||
- `path` (required): Path to the file to read
|
||||
- `line_start`: Starting line number (1-indexed)
|
||||
- `line_end`: Ending line number (inclusive)
|
||||
- `include_ast`: Include AST symbol summary (functions, classes, types, etc.)
|
||||
- `symbols_only`: **[Token Efficient]** Return only symbol summary without file content. Requires `include_ast=true`. Reduces token usage by ~90-98%.
|
||||
- `max_lines`: **[Token Efficient]** Maximum number of lines to return. Useful for large files where you only need a preview.
|
||||
|
||||
**Example Output with AST**:
|
||||
```
|
||||
**server.go** (245 lines, go)
|
||||
|
||||
Symbols:
|
||||
func NewServer L12
|
||||
func (Server).Start L45
|
||||
struct Server L5
|
||||
type Config L150
|
||||
|
||||
---
|
||||
|
||||
12│ func NewServer(config Config) *Server {
|
||||
13│ return &Server{config: config}
|
||||
14│ }
|
||||
```
|
||||
|
||||
**Token-Efficient Example (symbols_only)**:
|
||||
```json
|
||||
{"path": "server.go", "include_ast": true, "symbols_only": true}
|
||||
```
|
||||
Returns only the symbol summary (~500 tokens instead of ~8,000 tokens for the full file):
|
||||
```
|
||||
**server.go** (245 lines, go)
|
||||
|
||||
Symbols:
|
||||
func NewServer L12
|
||||
func (Server).Start L45
|
||||
struct Server L5
|
||||
type Config L150
|
||||
```
|
||||
|
||||
**Token-Efficient Example (max_lines)**:
|
||||
```json
|
||||
{"path": "server.go", "max_lines": 50}
|
||||
```
|
||||
Returns first 50 lines with a truncation notice if the file is longer.
|
||||
|
||||
---
|
||||
|
||||
### `ast_query`
|
||||
Search for AST patterns in code files using structural pattern matching.
|
||||
|
||||
**Parameters**:
|
||||
- `pattern` (required): Code pattern with placeholders
|
||||
- `$NAME` - capture single node
|
||||
- `$$$ARGS` - capture multiple nodes
|
||||
- `$_` - wildcard (match but don't capture)
|
||||
- `language` (required): Target language (go, typescript, javascript, python, c, cpp)
|
||||
- `paths`: Paths to search in
|
||||
- `name_matches`: Regex pattern to filter by name
|
||||
- `name_exact`: Exact name to match
|
||||
- `kind_in`: Node types to match (e.g., function_declaration)
|
||||
- `max_results`: Maximum number of results (default: 100)
|
||||
|
||||
**Examples**:
|
||||
```json
|
||||
// Find all Go functions returning error
|
||||
{"pattern": "func $NAME($$$ARGS) error", "language": "go"}
|
||||
|
||||
// Find all Python classes
|
||||
{"pattern": "class $NAME: $$$BODY", "language": "python"}
|
||||
|
||||
// Find React components (functions starting with uppercase)
|
||||
{"pattern": "function $NAME($PROPS) { $$$BODY }", "language": "javascript", "name_matches": "^[A-Z]"}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### `symbol_at`
|
||||
Get information about the symbol at a specific position. Uses LSP when available, falls back to AST.
|
||||
|
||||
**Parameters**:
|
||||
- `file` (required): Path to the file
|
||||
- `line` (required): Line number (1-indexed)
|
||||
- `column` (required): Column number (1-indexed)
|
||||
|
||||
---
|
||||
|
||||
### `find_definition`
|
||||
Find the definition of the symbol at a specific position.
|
||||
|
||||
**Parameters**:
|
||||
- `file` (required): Path to the file
|
||||
- `line` (required): Line number (1-indexed)
|
||||
- `column` (required): Column number (1-indexed)
|
||||
|
||||
---
|
||||
|
||||
### `find_references`
|
||||
Find all references to the symbol at a specific position.
|
||||
|
||||
**Parameters**:
|
||||
- `file` (required): Path to the file
|
||||
- `line` (required): Line number (1-indexed)
|
||||
- `column` (required): Column number (1-indexed)
|
||||
- `include_declaration`: Include the declaration in results (default: true)
|
||||
|
||||
---
|
||||
|
||||
### `edit_preview`
|
||||
Preview an edit without applying it. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++), and text-based editing for other files (Markdown, JSON, YAML, config files, etc.).
|
||||
|
||||
**Parameters**:
|
||||
- `file` (required): Path to the file to edit
|
||||
- `operation` (required): Edit operation (replace, insert_before, insert_after, delete)
|
||||
- `new_content`: New content (required for replace/insert operations)
|
||||
|
||||
**AST-mode selectors** (for code files):
|
||||
- `selector_kind`: Node type to match (e.g., function_declaration)
|
||||
- `selector_name`: Name of the symbol to match
|
||||
|
||||
**Shared selectors**:
|
||||
- `selector_line`: Line number (1-indexed). For AST mode: narrows search. For text mode: start of line range.
|
||||
- `selector_index`: Index of the match to use if multiple matches found (default: 0)
|
||||
|
||||
**Text-mode selectors** (for non-code files or explicit text matching):
|
||||
- `selector_line_end`: End line number for range selection
|
||||
- `selector_text`: Exact text to match (must be unique or use selector_index)
|
||||
- `selector_pattern`: Regex pattern to match
|
||||
|
||||
---
|
||||
|
||||
### `edit_apply`
|
||||
Apply an edit to a file. Uses AST-aware editing for code files with syntax validation, and text-based editing for other files.
|
||||
|
||||
**Parameters**: Same as `edit_preview`
|
||||
|
||||
**Example (AST mode - Go file)**:
|
||||
```json
|
||||
{
|
||||
"file": "server.go",
|
||||
"operation": "replace",
|
||||
"selector_kind": "function_declaration",
|
||||
"selector_name": "Hello",
|
||||
"new_content": "func Hello() {\n\tprintln(\"New Hello\")\n}"
|
||||
}
|
||||
```
|
||||
|
||||
**Example (Text mode - Markdown file)**:
|
||||
```json
|
||||
{
|
||||
"file": "README.md",
|
||||
"operation": "replace",
|
||||
"selector_text": "## Installation",
|
||||
"new_content": "## Getting Started"
|
||||
}
|
||||
```
|
||||
|
||||
**Example (Text mode - JSON with regex)**:
|
||||
```json
|
||||
{
|
||||
"file": "package.json",
|
||||
"operation": "replace",
|
||||
"selector_pattern": "\"version\":\\s*\"[^\"]+\"",
|
||||
"new_content": "\"version\": \"2.0.0\""
|
||||
}
|
||||
```
|
||||
|
||||
**Example (Text mode - Line range)**:
|
||||
```json
|
||||
{
|
||||
"file": "config.yaml",
|
||||
"operation": "replace",
|
||||
"selector_line": 5,
|
||||
"selector_line_end": 10,
|
||||
"new_content": "database:\n host: production.db.example.com\n port: 5432"
|
||||
}
|
||||
```
|
||||
|
||||
## Supported Languages
|
||||
|
||||
| Language | Extensions | Search | AST | LSP | Edit |
|
||||
|----------|-----------|--------|-----|-----|------|
|
||||
| Go | .go | Yes | Yes | gopls | Yes |
|
||||
| TypeScript | .ts, .tsx | Yes | Yes | typescript-language-server | Yes |
|
||||
| JavaScript | .js, .jsx, .mjs, .cjs | Yes | Yes | typescript-language-server | Yes |
|
||||
| Python | .py, .pyw | Yes | Yes | pylsp | Yes |
|
||||
| C | .c, .h | Yes | Yes | clangd | Yes |
|
||||
| C++ | .cpp, .cc, .cxx, .hpp, .hxx | Yes | Yes | clangd | Yes |
|
||||
| HTML | .html, .htm | Yes | Yes | - | Yes |
|
||||
| Vue | .vue | Yes | Yes* | - | Yes |
|
||||
| React | .jsx, .tsx | Yes | Yes | typescript-language-server | Yes |
|
||||
|
||||
\* Vue uses HTML parser for template sections
|
||||
|
||||
## Development
|
||||
|
||||
### Build
|
||||
|
||||
```bash
|
||||
make build
|
||||
```
|
||||
|
||||
### Run Tests
|
||||
|
||||
```bash
|
||||
make test
|
||||
```
|
||||
|
||||
### Lint
|
||||
|
||||
```bash
|
||||
make lint
|
||||
```
|
||||
|
||||
### Clean
|
||||
|
||||
```bash
|
||||
make clean
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
.
|
||||
├── cmd/
|
||||
│ └── mcp-filepuff/ # Main entry point
|
||||
├── internal/
|
||||
│ ├── config/ # Configuration management
|
||||
│ ├── edit/ # AST-aware editing engine
|
||||
│ ├── lsp/ # LSP client and manager
|
||||
│ ├── parser/ # Tree-sitter integration
|
||||
│ ├── query/ # AST pattern matching
|
||||
│ ├── search/ # Ripgrep wrapper
|
||||
│ └── server/ # MCP server implementation
|
||||
├── pkg/
|
||||
│ └── protocol/ # Shared types
|
||||
├── .github/
|
||||
│ └── workflows/ # CI configuration
|
||||
├── Makefile # Build automation
|
||||
├── .goreleaser.yaml # Release configuration
|
||||
└── TODO.md # Implementation roadmap
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ MCP Server │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ Tools: file_search, file_read, ast_query, symbol_at, │
|
||||
│ find_definition, find_references, │
|
||||
│ edit_preview, edit_apply, ping │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ Core Engines │
|
||||
├───────────┬─────────────┬────────────┬─────────────────┤
|
||||
│ Search │ Parser │ LSP │ Edit │
|
||||
│ (ripgrep) │(tree-sitter)│ Manager │ Engine │
|
||||
└───────────┴─────────────┴────────────┴─────────────────┘
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### "ripgrep not found" Error
|
||||
The `file_search` tool requires ripgrep (`rg`) to be installed and in your PATH.
|
||||
|
||||
**Solution**: Install ripgrep:
|
||||
```bash
|
||||
# macOS
|
||||
brew install ripgrep
|
||||
|
||||
# Ubuntu/Debian
|
||||
sudo apt install ripgrep
|
||||
|
||||
# Windows (with Chocolatey)
|
||||
choco install ripgrep
|
||||
```
|
||||
|
||||
#### LSP Features Not Working
|
||||
LSP features (go-to-definition, find-references, symbol-at) require language servers to be installed.
|
||||
|
||||
**Solution**: Install the appropriate language server:
|
||||
```bash
|
||||
# Go
|
||||
go install golang.org/x/tools/gopls@latest
|
||||
|
||||
# TypeScript/JavaScript
|
||||
npm install -g typescript-language-server typescript
|
||||
|
||||
# Python
|
||||
pip install python-lsp-server
|
||||
|
||||
# C/C++
|
||||
# macOS: brew install llvm
|
||||
# Ubuntu: sudo apt install clangd
|
||||
```
|
||||
|
||||
#### AST Parsing Fails for Valid Code
|
||||
If AST parsing fails for code that compiles correctly, it may be a Tree-sitter grammar limitation.
|
||||
|
||||
**Solution**:
|
||||
- Ensure the file has the correct extension for its language
|
||||
- Check for unusual syntax that may not be supported by the Tree-sitter grammar
|
||||
- Try using the `file_search` tool instead for text-based operations
|
||||
|
||||
#### Edit Operations Fail with "syntax error"
|
||||
The edit engine validates syntax before and after edits.
|
||||
|
||||
**Solution**:
|
||||
- Ensure `new_content` is syntactically valid for the target language
|
||||
- Use `edit_preview` first to see the proposed changes
|
||||
- Check that the selector matches exactly one node
|
||||
|
||||
#### Timeout Errors
|
||||
Long-running operations may timeout.
|
||||
|
||||
**Solution**: Configure timeout values via environment variables:
|
||||
```bash
|
||||
export MCP_LSP_TIMEOUT="10m" # LSP operations (default: 5m)
|
||||
export MCP_SEARCH_TIMEOUT="2m" # Search operations (default: 30s)
|
||||
```
|
||||
|
||||
#### Permission Denied Errors
|
||||
The server needs read/write access to workspace files.
|
||||
|
||||
**Solution**:
|
||||
- Ensure the user running the server has appropriate file permissions
|
||||
- Check that the workspace path is correct and accessible
|
||||
- On macOS, grant terminal/IDE full disk access if needed
|
||||
|
||||
### Debug Logging
|
||||
|
||||
Enable debug logging to troubleshoot issues:
|
||||
|
||||
```bash
|
||||
./bin/mcp-filepuff -workspace /path/to/workspace -log-level debug -log-file /tmp/mcp-filepuff.log
|
||||
```
|
||||
|
||||
### Verifying Installation
|
||||
|
||||
Use the `ping` tool to verify the server is running correctly:
|
||||
|
||||
```json
|
||||
{"tool": "ping"}
|
||||
```
|
||||
|
||||
Expected response: `"pong"`
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
@@ -0,0 +1,84 @@
|
||||
// Package main is the entry point for the MCP file operations server.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Parse command line flags
|
||||
var (
|
||||
workspaceRoot = flag.String("workspace", "", "Workspace root directory (default: current directory)")
|
||||
logLevel = flag.String("log-level", "info", "Log level (debug, info, warn, error)")
|
||||
logFile = flag.String("log-file", "", "Log file path (default: stderr)")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
// Set up logging
|
||||
logger := setupLogger(*logLevel, *logFile)
|
||||
|
||||
// Load configuration
|
||||
cfg, err := config.Load(*workspaceRoot)
|
||||
if err != nil {
|
||||
logger.Error("failed to load configuration", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
logger.Info("configuration loaded",
|
||||
"workspace_root", cfg.WorkspaceRoot,
|
||||
"lsp_enabled", cfg.EnableLSP,
|
||||
)
|
||||
|
||||
// Create and run server
|
||||
srv, err := server.New(cfg, logger)
|
||||
if err != nil {
|
||||
logger.Error("failed to create server", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := srv.Run(ctx); err != nil {
|
||||
logger.Error("server error", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func setupLogger(level string, logFile string) *slog.Logger {
|
||||
var logLevel slog.Level
|
||||
switch level {
|
||||
case "debug":
|
||||
logLevel = slog.LevelDebug
|
||||
case "warn":
|
||||
logLevel = slog.LevelWarn
|
||||
case "error":
|
||||
logLevel = slog.LevelError
|
||||
default:
|
||||
logLevel = slog.LevelInfo
|
||||
}
|
||||
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: logLevel,
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
if logFile != "" {
|
||||
f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
// Fallback to stderr
|
||||
handler = slog.NewJSONHandler(os.Stderr, opts)
|
||||
} else {
|
||||
handler = slog.NewJSONHandler(f, opts)
|
||||
}
|
||||
} else {
|
||||
// Use stderr for MCP servers (stdout is for protocol messages)
|
||||
handler = slog.NewJSONHandler(os.Stderr, opts)
|
||||
}
|
||||
|
||||
return slog.New(handler)
|
||||
}
|
||||
+1479
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,24 @@
|
||||
module github.com/lukaszraczylo/mcp-filepuff
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0
|
||||
github.com/goccy/go-json v0.10.5
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||
github.com/mark3labs/mcp-go v0.43.2
|
||||
github.com/sergi/go-diff v1.4.0
|
||||
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||
github.com/mailru/easyjson v0.9.1 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8=
|
||||
github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I=
|
||||
github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
||||
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
||||
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4=
|
||||
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,174 @@
|
||||
// Package config provides configuration management for the MCP file operations server.
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
// Config holds all configuration options for the MCP server.
|
||||
type Config struct {
|
||||
Formatters map[string]string `json:"formatters"`
|
||||
WorkspaceRoot string `json:"workspace_root"`
|
||||
LSPTimeout time.Duration `json:"lsp_timeout"`
|
||||
SearchTimeout time.Duration `json:"search_timeout"`
|
||||
MaxFileSize int64 `json:"max_file_size"`
|
||||
MaxSearchResults int `json:"max_search_results"`
|
||||
MaxEditSize int64 `json:"max_edit_size"`
|
||||
EnableLSP bool `json:"enable_lsp"`
|
||||
FollowSymlinks bool `json:"follow_symlinks"`
|
||||
RespectGitignore bool `json:"respect_gitignore"`
|
||||
}
|
||||
|
||||
// Default values for configuration.
|
||||
const (
|
||||
DefaultLSPTimeout = 5 * time.Minute
|
||||
DefaultSearchTimeout = 30 * time.Second
|
||||
DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB
|
||||
DefaultMaxSearchResults = 1000
|
||||
DefaultMaxEditSize = 100 * 1024 // 100 KB
|
||||
)
|
||||
|
||||
// Default returns a Config with default values.
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
WorkspaceRoot: ".",
|
||||
LSPTimeout: DefaultLSPTimeout,
|
||||
SearchTimeout: DefaultSearchTimeout,
|
||||
MaxFileSize: DefaultMaxFileSize,
|
||||
MaxSearchResults: DefaultMaxSearchResults,
|
||||
MaxEditSize: DefaultMaxEditSize,
|
||||
EnableLSP: true,
|
||||
Formatters: make(map[string]string),
|
||||
FollowSymlinks: true,
|
||||
RespectGitignore: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads configuration from environment variables and optional config file.
|
||||
// Priority: CLI flags > environment variables > config file > defaults.
|
||||
func Load(workspaceRoot string) (*Config, error) {
|
||||
cfg := Default()
|
||||
|
||||
// Set workspace root
|
||||
if workspaceRoot != "" {
|
||||
absPath, err := filepath.Abs(workspaceRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.WorkspaceRoot = absPath
|
||||
} else if cwd, err := os.Getwd(); err == nil {
|
||||
cfg.WorkspaceRoot = cwd
|
||||
}
|
||||
|
||||
// Try to load from config file in workspace root
|
||||
configPath := filepath.Join(cfg.WorkspaceRoot, ".mcp-filepuff.json")
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := json.Unmarshal(data, cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Override from environment variables
|
||||
cfg.loadFromEnv()
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) loadFromEnv() {
|
||||
if v := os.Getenv("MCP_WORKSPACE_ROOT"); v != "" {
|
||||
if absPath, err := filepath.Abs(v); err == nil {
|
||||
c.WorkspaceRoot = absPath
|
||||
}
|
||||
}
|
||||
|
||||
if v := os.Getenv("MCP_LSP_TIMEOUT"); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
c.LSPTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
if v := os.Getenv("MCP_SEARCH_TIMEOUT"); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
c.SearchTimeout = d
|
||||
}
|
||||
}
|
||||
|
||||
if v := os.Getenv("MCP_ENABLE_LSP"); v == "false" || v == "0" {
|
||||
c.EnableLSP = false
|
||||
}
|
||||
|
||||
if v := os.Getenv("MCP_FOLLOW_SYMLINKS"); v == "false" || v == "0" {
|
||||
c.FollowSymlinks = false
|
||||
}
|
||||
|
||||
if v := os.Getenv("MCP_RESPECT_GITIGNORE"); v == "false" || v == "0" {
|
||||
c.RespectGitignore = false
|
||||
}
|
||||
}
|
||||
|
||||
// IsPathAllowed checks if a path is within the workspace root.
|
||||
// It resolves symlinks to prevent path traversal attacks.
|
||||
func (c *Config) IsPathAllowed(path string) bool {
|
||||
// Get absolute path of the target
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get absolute path of workspace root
|
||||
absRoot, err := filepath.Abs(c.WorkspaceRoot)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Always try to resolve workspace root symlinks for consistent comparison
|
||||
evalRoot, evalErr := filepath.EvalSymlinks(absRoot)
|
||||
if evalErr == nil {
|
||||
absRoot = evalRoot
|
||||
}
|
||||
|
||||
// For the target path, try to resolve symlinks
|
||||
evalPath, evalErr := filepath.EvalSymlinks(absPath)
|
||||
if evalErr == nil {
|
||||
// File exists and was resolved
|
||||
absPath = evalPath
|
||||
} else {
|
||||
// File doesn't exist - resolve parent directories to match workspace root resolution
|
||||
// Walk up the tree until we find an existing directory
|
||||
dir := filepath.Dir(absPath)
|
||||
remaining := filepath.Base(absPath)
|
||||
|
||||
for dir != "." && dir != "/" && dir != absPath {
|
||||
evalDir, evalErr := filepath.EvalSymlinks(dir)
|
||||
if evalErr == nil {
|
||||
// Found an existing directory, reconstruct the path
|
||||
absPath = filepath.Join(evalDir, remaining)
|
||||
break
|
||||
}
|
||||
// Move up one level
|
||||
newDir := filepath.Dir(dir)
|
||||
if newDir == dir {
|
||||
// Reached the root without finding an existing directory
|
||||
break
|
||||
}
|
||||
remaining = filepath.Join(filepath.Base(dir), remaining)
|
||||
dir = newDir
|
||||
}
|
||||
}
|
||||
|
||||
// Compute relative path
|
||||
rel, err := filepath.Rel(absRoot, absPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the path is within workspace (doesn't start with ..)
|
||||
// This prevents both "../" attacks and symlink bypasses
|
||||
// Also reject empty relative path (which means it's the workspace root itself)
|
||||
return rel != "." && !strings.HasPrefix(rel, "..")
|
||||
}
|
||||
@@ -0,0 +1,184 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDefault(t *testing.T) {
|
||||
cfg := Default()
|
||||
|
||||
if cfg.WorkspaceRoot != "." {
|
||||
t.Errorf("expected default workspace root '.', got %q", cfg.WorkspaceRoot)
|
||||
}
|
||||
if cfg.LSPTimeout != DefaultLSPTimeout {
|
||||
t.Errorf("expected default LSP timeout %v, got %v", DefaultLSPTimeout, cfg.LSPTimeout)
|
||||
}
|
||||
if cfg.SearchTimeout != DefaultSearchTimeout {
|
||||
t.Errorf("expected default search timeout %v, got %v", DefaultSearchTimeout, cfg.SearchTimeout)
|
||||
}
|
||||
if cfg.MaxFileSize != DefaultMaxFileSize {
|
||||
t.Errorf("expected default max file size %d, got %d", DefaultMaxFileSize, cfg.MaxFileSize)
|
||||
}
|
||||
if cfg.MaxSearchResults != DefaultMaxSearchResults {
|
||||
t.Errorf("expected default max search results %d, got %d", DefaultMaxSearchResults, cfg.MaxSearchResults)
|
||||
}
|
||||
if cfg.MaxEditSize != DefaultMaxEditSize {
|
||||
t.Errorf("expected default max edit size %d, got %d", DefaultMaxEditSize, cfg.MaxEditSize)
|
||||
}
|
||||
if !cfg.EnableLSP {
|
||||
t.Error("expected EnableLSP to be true by default")
|
||||
}
|
||||
if !cfg.FollowSymlinks {
|
||||
t.Error("expected FollowSymlinks to be true by default")
|
||||
}
|
||||
if !cfg.RespectGitignore {
|
||||
t.Error("expected RespectGitignore to be true by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
// Create a temporary directory for workspace
|
||||
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg, err := Load(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
|
||||
absPath, _ := filepath.Abs(tmpDir)
|
||||
if cfg.WorkspaceRoot != absPath {
|
||||
t.Errorf("expected workspace root %q, got %q", absPath, cfg.WorkspaceRoot)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromEnv(t *testing.T) {
|
||||
// Save original env values
|
||||
origLSPTimeout := os.Getenv("MCP_LSP_TIMEOUT")
|
||||
origSearchTimeout := os.Getenv("MCP_SEARCH_TIMEOUT")
|
||||
origEnableLSP := os.Getenv("MCP_ENABLE_LSP")
|
||||
origFollowSymlinks := os.Getenv("MCP_FOLLOW_SYMLINKS")
|
||||
origRespectGitignore := os.Getenv("MCP_RESPECT_GITIGNORE")
|
||||
|
||||
// Restore env after test
|
||||
t.Cleanup(func() {
|
||||
_ = os.Setenv("MCP_LSP_TIMEOUT", origLSPTimeout)
|
||||
_ = os.Setenv("MCP_SEARCH_TIMEOUT", origSearchTimeout)
|
||||
_ = os.Setenv("MCP_ENABLE_LSP", origEnableLSP)
|
||||
_ = os.Setenv("MCP_FOLLOW_SYMLINKS", origFollowSymlinks)
|
||||
_ = os.Setenv("MCP_RESPECT_GITIGNORE", origRespectGitignore)
|
||||
})
|
||||
|
||||
// Set test env values
|
||||
_ = os.Setenv("MCP_LSP_TIMEOUT", "10m")
|
||||
_ = os.Setenv("MCP_SEARCH_TIMEOUT", "1m")
|
||||
_ = os.Setenv("MCP_ENABLE_LSP", "false")
|
||||
_ = os.Setenv("MCP_FOLLOW_SYMLINKS", "0")
|
||||
_ = os.Setenv("MCP_RESPECT_GITIGNORE", "false")
|
||||
|
||||
cfg := Default()
|
||||
cfg.loadFromEnv()
|
||||
|
||||
if cfg.LSPTimeout != 10*time.Minute {
|
||||
t.Errorf("expected LSP timeout 10m, got %v", cfg.LSPTimeout)
|
||||
}
|
||||
if cfg.SearchTimeout != 1*time.Minute {
|
||||
t.Errorf("expected search timeout 1m, got %v", cfg.SearchTimeout)
|
||||
}
|
||||
if cfg.EnableLSP {
|
||||
t.Error("expected EnableLSP to be false")
|
||||
}
|
||||
if cfg.FollowSymlinks {
|
||||
t.Error("expected FollowSymlinks to be false")
|
||||
}
|
||||
if cfg.RespectGitignore {
|
||||
t.Error("expected RespectGitignore to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPathAllowed(t *testing.T) {
|
||||
// Create a temporary directory
|
||||
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := Default()
|
||||
cfg.WorkspaceRoot = tmpDir
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
allowed bool
|
||||
}{
|
||||
{
|
||||
name: "file in workspace",
|
||||
path: filepath.Join(tmpDir, "test.go"),
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "nested file in workspace",
|
||||
path: filepath.Join(tmpDir, "subdir", "test.go"),
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "path outside workspace",
|
||||
path: "/etc/passwd",
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "relative path traversal",
|
||||
path: filepath.Join(tmpDir, "..", "outside.txt"),
|
||||
allowed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := cfg.IsPathAllowed(tt.path)
|
||||
if result != tt.allowed {
|
||||
t.Errorf("IsPathAllowed(%q) = %v, want %v", tt.path, result, tt.allowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadWithConfigFile(t *testing.T) {
|
||||
// Create a temporary directory
|
||||
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Write config file
|
||||
configPath := filepath.Join(tmpDir, ".mcp-filepuff.json")
|
||||
configContent := `{
|
||||
"enable_lsp": false,
|
||||
"follow_symlinks": false
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configContent), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
var cfg *Config
|
||||
cfg, err = Load(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.EnableLSP {
|
||||
t.Error("expected EnableLSP to be false from config file")
|
||||
}
|
||||
if cfg.FollowSymlinks {
|
||||
t.Error("expected FollowSymlinks to be false from config file")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestIsPathAllowed_SymlinkSecurity tests the symlink security fix.
|
||||
func TestIsPathAllowed_SymlinkSecurity(t *testing.T) {
|
||||
// Create a temporary workspace
|
||||
tmpDir := t.TempDir()
|
||||
workspace := filepath.Join(tmpDir, "workspace")
|
||||
outside := filepath.Join(tmpDir, "outside")
|
||||
|
||||
if err := os.MkdirAll(workspace, 0700); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(outside, 0700); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a file outside the workspace
|
||||
outsideFile := filepath.Join(outside, "secret.txt")
|
||||
if err := os.WriteFile(outsideFile, []byte("secret data"), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
WorkspaceRoot: workspace,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
setup func() string
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "regular file inside workspace",
|
||||
setup: func() string {
|
||||
return filepath.Join(workspace, "file.txt")
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "file with parent directory traversal",
|
||||
setup: func() string {
|
||||
return filepath.Join(workspace, "../outside/secret.txt")
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "symlink pointing outside workspace",
|
||||
setup: func() string {
|
||||
symlink := filepath.Join(workspace, "link.txt")
|
||||
_ = os.Symlink(outsideFile, symlink)
|
||||
return symlink
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "symlink pointing inside workspace",
|
||||
setup: func() string {
|
||||
inside := filepath.Join(workspace, "inside.txt")
|
||||
_ = os.WriteFile(inside, []byte("ok"), 0600)
|
||||
symlink := filepath.Join(workspace, "link_inside.txt")
|
||||
_ = os.Symlink(inside, symlink)
|
||||
return symlink
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "dotfile inside workspace",
|
||||
setup: func() string {
|
||||
return filepath.Join(workspace, ".gitignore")
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "hidden directory inside workspace",
|
||||
setup: func() string {
|
||||
return filepath.Join(workspace, ".git/config")
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
path := tt.setup()
|
||||
result := cfg.IsPathAllowed(path)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsPathAllowed(%q) = %v, want %v", path, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPathAllowed_BasicCases tests basic path validation.
|
||||
func TestIsPathAllowed_BasicCases(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &Config{
|
||||
WorkspaceRoot: tmpDir,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "path inside workspace",
|
||||
path: filepath.Join(tmpDir, "file.txt"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "path outside workspace",
|
||||
path: "/etc/passwd",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "parent directory reference",
|
||||
path: filepath.Join(tmpDir, "../../../etc/passwd"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "workspace root itself",
|
||||
path: tmpDir,
|
||||
expected: false, // Empty relative path
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := cfg.IsPathAllowed(tt.path)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsPathAllowed(%q) = %v, want %v", tt.path, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
package edit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
)
|
||||
|
||||
// TestConcurrentEditLocking tests that concurrent edits to the same file are properly serialized.
|
||||
func TestConcurrentEditLocking(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
// Create initial file
|
||||
initialContent := `package main
|
||||
|
||||
func main() {
|
||||
println("hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
registry := parser.NewRegistry()
|
||||
engine := NewEngine(registry)
|
||||
|
||||
// Run 10 concurrent edits
|
||||
const numEdits = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numEdits)
|
||||
|
||||
errors := make(chan error, numEdits)
|
||||
|
||||
for i := 0; i < numEdits; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
edit := &ASTEdit{
|
||||
File: testFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
Kind: "function_declaration",
|
||||
Name: "main",
|
||||
},
|
||||
NewContent: `func main() {
|
||||
println("edit ` + string(rune(i)) + `")
|
||||
}`,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := engine.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
for err := range errors {
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent edit failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify file wasn't corrupted
|
||||
finalContent, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Parse to ensure it's still valid Go
|
||||
_, err = registry.Parse(context.Background(), testFile, finalContent)
|
||||
if err != nil {
|
||||
t.Errorf("File corrupted after concurrent edits: %v\nContent:\n%s", err, string(finalContent))
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentEditDifferentFiles tests that concurrent edits to different files don't block each other.
|
||||
func TestConcurrentEditDifferentFiles(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
registry := parser.NewRegistry()
|
||||
engine := NewEngine(registry)
|
||||
|
||||
const numFiles = 5
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numFiles)
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < numFiles; i++ {
|
||||
i := i
|
||||
testFile := filepath.Join(tmpDir, fmt.Sprintf("test%d.go", i))
|
||||
|
||||
// Create initial file
|
||||
initialContent := `package main
|
||||
|
||||
func test() {
|
||||
println("initial")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Wait for all goroutines to be ready
|
||||
<-startBarrier
|
||||
|
||||
edit := &ASTEdit{
|
||||
File: testFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
Kind: "function_declaration",
|
||||
Name: "test",
|
||||
},
|
||||
NewContent: `func test() {
|
||||
println("modified")
|
||||
}`,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := engine.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Errorf("Edit failed for %s: %v", testFile, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Errorf("Edit unsuccessful for %s: %s", testFile, result.Error)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Release all goroutines simultaneously
|
||||
close(startBarrier)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestFileLockRelease tests that file locks are properly released after edits.
|
||||
func TestFileLockRelease(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
initialContent := `package main
|
||||
|
||||
func test() {}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
registry := parser.NewRegistry()
|
||||
engine := NewEngine(registry)
|
||||
|
||||
edit := &ASTEdit{
|
||||
File: testFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
Kind: "function_declaration",
|
||||
Name: "test",
|
||||
},
|
||||
NewContent: `func test() { println("updated") }`,
|
||||
}
|
||||
|
||||
// First edit
|
||||
ctx := context.Background()
|
||||
result1, err := engine.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result1.Success {
|
||||
t.Fatalf("First edit failed: %s", result1.Error)
|
||||
}
|
||||
|
||||
// Second edit should succeed (lock was released)
|
||||
result2, err := engine.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result2.Success {
|
||||
t.Fatalf("Second edit failed (lock not released?): %s", result2.Error)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,757 @@
|
||||
// Package edit provides AST-aware file editing capabilities.
|
||||
package edit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// Global regex cache for compiled patterns (thread-safe)
|
||||
var regexCache sync.Map // string -> *regexp.Regexp
|
||||
|
||||
// compileRegex compiles a regex pattern with caching for performance.
|
||||
func compileRegex(pattern string) (*regexp.Regexp, error) {
|
||||
// Check cache first
|
||||
if cached, ok := regexCache.Load(pattern); ok {
|
||||
return cached.(*regexp.Regexp), nil
|
||||
}
|
||||
|
||||
// Compile and cache
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
regexCache.Store(pattern, re)
|
||||
return re, nil
|
||||
}
|
||||
|
||||
// EditOperation defines the type of edit operation.
|
||||
type EditOperation string
|
||||
|
||||
const (
|
||||
EditReplace EditOperation = "replace"
|
||||
EditInsertBefore EditOperation = "insert_before"
|
||||
EditInsertAfter EditOperation = "insert_after"
|
||||
EditDelete EditOperation = "delete"
|
||||
)
|
||||
|
||||
// ASTEdit represents an AST-aware edit request.
|
||||
type ASTEdit struct {
|
||||
File string `json:"file"`
|
||||
Operation EditOperation `json:"operation"`
|
||||
NewContent string `json:"new_content,omitempty"`
|
||||
Selector ASTSelector `json:"selector"`
|
||||
}
|
||||
|
||||
// ASTSelector specifies how to find the target node.
|
||||
type ASTSelector struct {
|
||||
Kind string `json:"kind,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Pattern string `json:"pattern,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
TextPattern string `json:"text_pattern,omitempty"`
|
||||
AtLine int `json:"at_line,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
LineEnd int `json:"line_end,omitempty"`
|
||||
}
|
||||
|
||||
// EditResult contains the result of an edit operation.
|
||||
type EditResult struct {
|
||||
Diff string `json:"diff,omitempty"`
|
||||
OriginalContent string `json:"original_content,omitempty"`
|
||||
NewContent string `json:"new_content,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Success bool `json:"success"`
|
||||
Applied bool `json:"applied"`
|
||||
}
|
||||
|
||||
// Engine performs AST-aware edits.
|
||||
type Engine struct {
|
||||
registry *parser.Registry
|
||||
fileLocks sync.Map // map[string]*sync.Mutex for per-file locking
|
||||
}
|
||||
|
||||
// NewEngine creates a new edit engine.
|
||||
func NewEngine(registry *parser.Registry) *Engine {
|
||||
return &Engine{
|
||||
registry: registry,
|
||||
fileLocks: sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// lockFile acquires a lock for the specified file and returns an unlock function.
|
||||
// This prevents concurrent edits to the same file which could cause corruption.
|
||||
func (e *Engine) lockFile(filePath string) func() {
|
||||
// Get or create mutex for this file
|
||||
actual, _ := e.fileLocks.LoadOrStore(filePath, &sync.Mutex{})
|
||||
mu := actual.(*sync.Mutex)
|
||||
mu.Lock()
|
||||
return mu.Unlock
|
||||
}
|
||||
|
||||
// Preview generates a preview of an edit without applying it.
|
||||
func (e *Engine) Preview(ctx context.Context, edit *ASTEdit) (*EditResult, error) {
|
||||
return e.performEdit(ctx, edit, false)
|
||||
}
|
||||
|
||||
// Apply performs an edit and writes the result to disk.
|
||||
// Uses file locking to prevent concurrent edits to the same file.
|
||||
func (e *Engine) Apply(ctx context.Context, edit *ASTEdit) (*EditResult, error) {
|
||||
unlock := e.lockFile(edit.File)
|
||||
defer unlock()
|
||||
return e.performEdit(ctx, edit, true)
|
||||
}
|
||||
|
||||
// performEdit executes an edit operation.
|
||||
func (e *Engine) performEdit(ctx context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
|
||||
// Determine if we should use text mode
|
||||
useTextMode := e.shouldUseTextMode(edit)
|
||||
|
||||
if useTextMode {
|
||||
return e.performTextEdit(ctx, edit, apply)
|
||||
}
|
||||
return e.performASTEdit(ctx, edit, apply)
|
||||
}
|
||||
|
||||
// shouldUseTextMode determines if text-based editing should be used.
|
||||
func (e *Engine) shouldUseTextMode(edit *ASTEdit) bool {
|
||||
// Use text mode if text-specific selectors are provided
|
||||
if edit.Selector.Text != "" || edit.Selector.TextPattern != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Use text mode if line range is specified without AST selectors
|
||||
if edit.Selector.AtLine > 0 && edit.Selector.LineEnd > 0 &&
|
||||
edit.Selector.Kind == "" && edit.Selector.Name == "" && edit.Selector.Pattern == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Use text mode if language is not supported for AST
|
||||
lang := protocol.DetectLanguage(edit.File)
|
||||
return lang == protocol.LangUnknown
|
||||
}
|
||||
|
||||
// performASTEdit executes an AST-aware edit operation.
|
||||
func (e *Engine) performASTEdit(ctx context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
|
||||
// Validate operation
|
||||
if err := e.validateASTEdit(edit); err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Read file
|
||||
content, err := os.ReadFile(edit.File)
|
||||
if err != nil {
|
||||
structuredErr := errors.NewFileNotReadableError(edit.File, err)
|
||||
return &EditResult{Success: false, Error: structuredErr.Error()}, nil
|
||||
}
|
||||
|
||||
// Parse file
|
||||
parseResult, err := e.registry.Parse(ctx, edit.File, content)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Find target node
|
||||
node, err := e.resolveSelector(edit.Selector, parseResult.Tree, content)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Apply edit
|
||||
newContent, err := e.applyEdit(edit, node, content)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Validate new content (re-parse)
|
||||
_, err = e.registry.Parse(ctx, edit.File, newContent)
|
||||
if err != nil {
|
||||
structuredErr := errors.NewEditValidationError(edit.File, err)
|
||||
return &EditResult{
|
||||
Success: false,
|
||||
Error: structuredErr.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate diff
|
||||
diff := generateDiff(string(content), string(newContent), edit.File)
|
||||
|
||||
result := &EditResult{
|
||||
Success: true,
|
||||
Diff: diff,
|
||||
OriginalContent: string(content),
|
||||
NewContent: string(newContent),
|
||||
Applied: false,
|
||||
}
|
||||
|
||||
// Apply changes if requested
|
||||
if apply {
|
||||
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
|
||||
structuredErr := errors.NewFileNotWritableError(edit.File, err)
|
||||
return &EditResult{
|
||||
Success: false,
|
||||
Error: structuredErr.Error(),
|
||||
}, nil
|
||||
}
|
||||
result.Applied = true
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// performTextEdit executes a text-based edit operation for non-AST files.
|
||||
func (e *Engine) performTextEdit(_ context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
|
||||
// Validate operation
|
||||
if err := e.validateTextEdit(edit); err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Read file
|
||||
content, err := os.ReadFile(edit.File)
|
||||
if err != nil {
|
||||
structuredErr := errors.NewFileNotReadableError(edit.File, err)
|
||||
return &EditResult{Success: false, Error: structuredErr.Error()}, nil
|
||||
}
|
||||
|
||||
// Find the text selection (byte range)
|
||||
start, end, err := e.resolveTextSelector(edit.Selector, content)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Apply edit
|
||||
newContent, err := e.applyTextEditOperation(edit.Operation, content, start, end, edit.NewContent)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Generate diff
|
||||
diff := generateDiff(string(content), string(newContent), edit.File)
|
||||
|
||||
result := &EditResult{
|
||||
Success: true,
|
||||
Diff: diff,
|
||||
OriginalContent: string(content),
|
||||
NewContent: string(newContent),
|
||||
Applied: false,
|
||||
}
|
||||
|
||||
// Apply changes if requested
|
||||
if apply {
|
||||
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
|
||||
structuredErr := errors.NewFileNotWritableError(edit.File, err)
|
||||
return &EditResult{
|
||||
Success: false,
|
||||
Error: structuredErr.Error(),
|
||||
}, nil
|
||||
}
|
||||
result.Applied = true
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// validateBaseEdit checks common edit request fields.
|
||||
func (e *Engine) validateBaseEdit(edit *ASTEdit) error {
|
||||
if edit.File == "" {
|
||||
return errors.NewInvalidEditError("file is required")
|
||||
}
|
||||
|
||||
if edit.Operation == "" {
|
||||
return errors.NewInvalidEditError("operation is required")
|
||||
}
|
||||
|
||||
// Validate operation type
|
||||
switch edit.Operation {
|
||||
case EditReplace, EditInsertBefore, EditInsertAfter:
|
||||
if edit.NewContent == "" {
|
||||
return errors.NewInvalidEditError(fmt.Sprintf("new_content is required for %s operation", edit.Operation))
|
||||
}
|
||||
case EditDelete:
|
||||
// new_content not required
|
||||
default:
|
||||
return errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", edit.Operation))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateASTEdit checks if an AST edit request is valid.
|
||||
func (e *Engine) validateASTEdit(edit *ASTEdit) error {
|
||||
if err := e.validateBaseEdit(edit); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate AST selector
|
||||
if edit.Selector.Kind == "" && edit.Selector.Name == "" && edit.Selector.Pattern == "" && edit.Selector.AtLine == 0 {
|
||||
return errors.NewInvalidEditError("AST selector must specify at least one of: kind, name, pattern, or at_line")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTextEdit checks if a text edit request is valid.
|
||||
func (e *Engine) validateTextEdit(edit *ASTEdit) error {
|
||||
if err := e.validateBaseEdit(edit); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate text selector - need at least one text selection method
|
||||
hasTextSelector := edit.Selector.Text != "" ||
|
||||
edit.Selector.TextPattern != "" ||
|
||||
edit.Selector.AtLine > 0
|
||||
|
||||
if !hasTextSelector {
|
||||
return errors.NewInvalidEditError("text selector must specify at least one of: text, text_pattern, or at_line")
|
||||
}
|
||||
|
||||
// Validate regex pattern if provided (uses cached compilation)
|
||||
if edit.Selector.TextPattern != "" {
|
||||
if _, err := compileRegex(edit.Selector.TextPattern); err != nil {
|
||||
return errors.Wrap(errors.ErrInvalidEdit, "invalid text_pattern regex", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveSelector finds the target node based on the selector.
|
||||
func (e *Engine) resolveSelector(sel ASTSelector, tree *sitter.Tree, content []byte) (*sitter.Node, error) {
|
||||
if tree == nil {
|
||||
return nil, errors.NewNodeNotFoundError("no AST tree available")
|
||||
}
|
||||
|
||||
root := tree.RootNode()
|
||||
if root == nil {
|
||||
return nil, errors.NewNodeNotFoundError("empty AST tree")
|
||||
}
|
||||
|
||||
var matches []*sitter.Node
|
||||
|
||||
parser.WalkTree(root, func(n *sitter.Node) bool {
|
||||
if e.matchesSelector(sel, n, content) {
|
||||
matches = append(matches, n)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(matches) == 0 {
|
||||
selectorDesc := fmt.Sprintf("kind=%s name=%s pattern=%s line=%d", sel.Kind, sel.Name, sel.Pattern, sel.AtLine)
|
||||
return nil, errors.NewNodeNotFoundError(selectorDesc)
|
||||
}
|
||||
|
||||
// Use index to select specific match
|
||||
index := sel.Index
|
||||
if index < 0 || index >= len(matches) {
|
||||
return nil, errors.NewInvalidSelectionError(fmt.Sprintf("selector matched %d nodes, but index %d is out of range", len(matches), index))
|
||||
}
|
||||
|
||||
return matches[index], nil
|
||||
}
|
||||
|
||||
// matchesSelector checks if a node matches the selector criteria.
|
||||
func (e *Engine) matchesSelector(sel ASTSelector, n *sitter.Node, content []byte) bool {
|
||||
// Check kind
|
||||
if sel.Kind != "" && n.Type() != sel.Kind {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check name (look for identifier in the node)
|
||||
if sel.Name != "" {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
// Also try to find an identifier child
|
||||
found := false
|
||||
for i := 0; i < int(n.NamedChildCount()); i++ {
|
||||
child := n.NamedChild(i)
|
||||
if child != nil && child.Type() == "identifier" {
|
||||
if parser.GetNodeText(child, content) == sel.Name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
} else if parser.GetNodeText(nameNode, content) != sel.Name {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check line
|
||||
if sel.AtLine > 0 {
|
||||
startLine := int(n.StartPoint().Row) + 1
|
||||
endLine := int(n.EndPoint().Row) + 1
|
||||
if sel.AtLine < startLine || sel.AtLine > endLine {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern matching is handled separately (simplified here)
|
||||
if sel.Pattern != "" {
|
||||
nodeText := parser.GetNodeText(n, content)
|
||||
if !strings.Contains(nodeText, sel.Pattern) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// applyEdit applies the edit operation to the content.
|
||||
func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]byte, error) {
|
||||
startByte := node.StartByte()
|
||||
endByte := node.EndByte()
|
||||
|
||||
// Detect and preserve indentation
|
||||
indentation := detectIndentation(content, startByte)
|
||||
newContent := indentContent(edit.NewContent, indentation)
|
||||
|
||||
var result []byte
|
||||
|
||||
switch edit.Operation {
|
||||
case EditReplace:
|
||||
result = append(result, content[:startByte]...)
|
||||
result = append(result, []byte(newContent)...)
|
||||
result = append(result, content[endByte:]...)
|
||||
|
||||
case EditInsertBefore:
|
||||
result = append(result, content[:startByte]...)
|
||||
result = append(result, []byte(newContent)...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, content[startByte:]...)
|
||||
|
||||
case EditInsertAfter:
|
||||
result = append(result, content[:endByte]...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, []byte(newContent)...)
|
||||
result = append(result, content[endByte:]...)
|
||||
|
||||
case EditDelete:
|
||||
result = append(result, content[:startByte]...)
|
||||
result = append(result, content[endByte:]...)
|
||||
|
||||
default:
|
||||
return nil, errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", edit.Operation))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// detectIndentation detects the indentation at a given byte position.
|
||||
func detectIndentation(content []byte, bytePos uint32) string {
|
||||
// Find the start of the line
|
||||
lineStart := int(bytePos)
|
||||
for lineStart > 0 && content[lineStart-1] != '\n' {
|
||||
lineStart--
|
||||
}
|
||||
|
||||
// Extract leading whitespace
|
||||
var indent strings.Builder
|
||||
for i := lineStart; i < int(bytePos) && i < len(content); i++ {
|
||||
c := content[i]
|
||||
if c == ' ' || c == '\t' {
|
||||
indent.WriteByte(c)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return indent.String()
|
||||
}
|
||||
|
||||
// indentContent applies indentation to multi-line content.
|
||||
func indentContent(content string, indent string) string {
|
||||
if indent == "" {
|
||||
return content
|
||||
}
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
for i, line := range lines {
|
||||
if i > 0 && line != "" {
|
||||
lines[i] = indent + line
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// generateDiff creates a unified diff between original and modified content.
|
||||
// Uses Myers diff algorithm for accurate and readable diffs.
|
||||
func generateDiff(original, modified, filename string) string {
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(original, modified, false)
|
||||
|
||||
// Cleanup for readability
|
||||
diffs = dmp.DiffCleanupSemantic(diffs)
|
||||
|
||||
// Convert to unified diff format
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(fmt.Sprintf("--- %s\n", filename))
|
||||
buf.WriteString(fmt.Sprintf("+++ %s\n", filename))
|
||||
|
||||
// Group diffs into hunks
|
||||
lineNum := 1
|
||||
for _, diff := range diffs {
|
||||
lines := strings.Split(diff.Text, "\n")
|
||||
for i, line := range lines {
|
||||
// Skip empty last line from split
|
||||
if i == len(lines)-1 && line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
switch diff.Type {
|
||||
case diffmatchpatch.DiffDelete:
|
||||
buf.WriteString(fmt.Sprintf("-%s\n", line))
|
||||
case diffmatchpatch.DiffInsert:
|
||||
buf.WriteString(fmt.Sprintf("+%s\n", line))
|
||||
case diffmatchpatch.DiffEqual:
|
||||
buf.WriteString(fmt.Sprintf(" %s\n", line))
|
||||
lineNum++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// resolveTextSelector finds the byte range for a text-based selection.
|
||||
func (e *Engine) resolveTextSelector(sel ASTSelector, content []byte) (start, end int, err error) {
|
||||
switch {
|
||||
case sel.Text != "":
|
||||
return e.findExactText(content, sel.Text, sel.Index)
|
||||
case sel.TextPattern != "":
|
||||
return e.findRegexPattern(content, sel.TextPattern, sel.Index)
|
||||
case sel.AtLine > 0:
|
||||
return e.findLineRange(content, sel.AtLine, sel.LineEnd)
|
||||
default:
|
||||
return 0, 0, errors.NewInvalidEditError("text selector requires text, text_pattern, or at_line")
|
||||
}
|
||||
}
|
||||
|
||||
// findExactText finds an exact text match in content.
|
||||
func (e *Engine) findExactText(content []byte, text string, index int) (start, end int, err error) {
|
||||
if text == "" {
|
||||
return 0, 0, errors.NewInvalidEditError("text selector cannot be empty")
|
||||
}
|
||||
|
||||
textBytes := []byte(text)
|
||||
type match struct{ start, end int }
|
||||
var matches []match
|
||||
|
||||
offset := 0
|
||||
for {
|
||||
idx := bytes.Index(content[offset:], textBytes)
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
matches = append(matches, match{
|
||||
start: offset + idx,
|
||||
end: offset + idx + len(textBytes),
|
||||
})
|
||||
offset += idx + 1
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text not found: %q", truncateString(text, 50)))
|
||||
}
|
||||
|
||||
// If multiple matches and no index specified, require explicit selection
|
||||
if len(matches) > 1 && index == 0 {
|
||||
// Check if index was explicitly set to 0 or just defaulted
|
||||
// Since we can't distinguish, we'll allow index 0 but warn about multiple matches
|
||||
// Actually, let's be strict and require explicit index for multiple matches
|
||||
locations := make([]string, 0, min(len(matches), 5))
|
||||
for i, m := range matches {
|
||||
if i >= 5 {
|
||||
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
|
||||
break
|
||||
}
|
||||
line := countLines(content[:m.start]) + 1
|
||||
locations = append(locations, fmt.Sprintf("line %d", line))
|
||||
}
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text matches %d locations (%s); use selector_index to specify which one (0-%d)",
|
||||
len(matches), strings.Join(locations, ", "), len(matches)-1))
|
||||
}
|
||||
|
||||
if index >= len(matches) {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
|
||||
}
|
||||
|
||||
return matches[index].start, matches[index].end, nil
|
||||
}
|
||||
|
||||
// findRegexPattern finds a regex pattern match in content.
|
||||
func (e *Engine) findRegexPattern(content []byte, pattern string, index int) (start, end int, err error) {
|
||||
re, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
return 0, 0, errors.Wrap(errors.ErrInvalidEdit, "invalid regex pattern", err)
|
||||
}
|
||||
|
||||
matches := re.FindAllIndex(content, -1)
|
||||
if len(matches) == 0 {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern not found: %q", truncateString(pattern, 50)))
|
||||
}
|
||||
|
||||
// If multiple matches and index is 0 (default), show error with locations
|
||||
if len(matches) > 1 && index == 0 {
|
||||
locations := make([]string, 0, min(len(matches), 5))
|
||||
for i, m := range matches {
|
||||
if i >= 5 {
|
||||
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
|
||||
break
|
||||
}
|
||||
line := countLines(content[:m[0]]) + 1
|
||||
locations = append(locations, fmt.Sprintf("line %d", line))
|
||||
}
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern matches %d locations (%s); use selector_index to specify which one (0-%d)",
|
||||
len(matches), strings.Join(locations, ", "), len(matches)-1))
|
||||
}
|
||||
|
||||
if index >= len(matches) {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
|
||||
}
|
||||
|
||||
return matches[index][0], matches[index][1], nil
|
||||
}
|
||||
|
||||
// findLineRange finds the byte range for a line range selection.
|
||||
func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, end int, err error) {
|
||||
if lineEnd == 0 {
|
||||
lineEnd = lineStart
|
||||
}
|
||||
|
||||
if lineStart < 1 {
|
||||
return 0, 0, errors.NewInvalidEditError(fmt.Sprintf("line number must be >= 1, got %d", lineStart))
|
||||
}
|
||||
|
||||
if lineEnd < lineStart {
|
||||
return 0, 0, errors.NewInvalidEditError(fmt.Sprintf("line_end (%d) must be >= line (%d)", lineEnd, lineStart))
|
||||
}
|
||||
|
||||
lines := bytes.Split(content, []byte("\n"))
|
||||
totalLines := len(lines)
|
||||
|
||||
// Convert to 0-indexed
|
||||
startIdx := lineStart - 1
|
||||
endIdx := lineEnd - 1
|
||||
|
||||
if startIdx >= totalLines {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("line %d out of range (file has %d lines)", lineStart, totalLines))
|
||||
}
|
||||
if endIdx >= totalLines {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("line_end %d out of range (file has %d lines)", lineEnd, totalLines))
|
||||
}
|
||||
|
||||
// Calculate byte positions
|
||||
start = 0
|
||||
for i := 0; i < startIdx; i++ {
|
||||
start += len(lines[i]) + 1 // +1 for newline
|
||||
}
|
||||
|
||||
end = start
|
||||
for i := startIdx; i <= endIdx; i++ {
|
||||
end += len(lines[i])
|
||||
if i < totalLines-1 {
|
||||
end += 1 // newline
|
||||
}
|
||||
}
|
||||
|
||||
return start, end, nil
|
||||
}
|
||||
|
||||
// applyTextEditOperation applies a text edit operation.
|
||||
func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start, end int, newContent string) ([]byte, error) {
|
||||
// Detect indentation at the selection point
|
||||
indentation := detectIndentationAtByte(content, start)
|
||||
indentedContent := indentContent(newContent, indentation)
|
||||
|
||||
var result []byte
|
||||
|
||||
switch op {
|
||||
case EditReplace:
|
||||
result = append(result, content[:start]...)
|
||||
result = append(result, []byte(indentedContent)...)
|
||||
result = append(result, content[end:]...)
|
||||
|
||||
case EditInsertBefore:
|
||||
result = append(result, content[:start]...)
|
||||
result = append(result, []byte(indentedContent)...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, content[start:]...)
|
||||
|
||||
case EditInsertAfter:
|
||||
result = append(result, content[:end]...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, []byte(indentedContent)...)
|
||||
result = append(result, content[end:]...)
|
||||
|
||||
case EditDelete:
|
||||
result = append(result, content[:start]...)
|
||||
result = append(result, content[end:]...)
|
||||
|
||||
default:
|
||||
return nil, errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", op))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// detectIndentationAtByte detects indentation at a byte position.
|
||||
func detectIndentationAtByte(content []byte, bytePos int) string {
|
||||
// Find the start of the line
|
||||
lineStart := bytePos
|
||||
for lineStart > 0 && content[lineStart-1] != '\n' {
|
||||
lineStart--
|
||||
}
|
||||
|
||||
// Extract leading whitespace
|
||||
var indent strings.Builder
|
||||
for i := lineStart; i < bytePos && i < len(content); i++ {
|
||||
c := content[i]
|
||||
if c == ' ' || c == '\t' {
|
||||
indent.WriteByte(c)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return indent.String()
|
||||
}
|
||||
|
||||
// truncateString truncates a string to maxLen with ellipsis.
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
// countLines counts the number of newlines in content.
|
||||
func countLines(content []byte) int {
|
||||
return bytes.Count(content, []byte("\n"))
|
||||
}
|
||||
|
||||
// ValidateLanguage checks if AST editing is supported for a file.
|
||||
// Returns nil for supported languages, error for unsupported.
|
||||
// Note: Text-based editing is always available regardless of this check.
|
||||
func ValidateLanguage(filename string) error {
|
||||
lang := protocol.DetectLanguage(filename)
|
||||
if lang == protocol.LangUnknown {
|
||||
return fmt.Errorf("unsupported file type for AST editing: %s (text-based editing is available)", filename)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,836 @@
|
||||
package edit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
)
|
||||
|
||||
func TestValidateEdit(t *testing.T) {
|
||||
e := NewEngine(parser.NewRegistry())
|
||||
|
||||
tests := []struct {
|
||||
edit *ASTEdit
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid replace",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid delete",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: EditDelete,
|
||||
Selector: ASTSelector{Name: "oldFunc"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing file",
|
||||
edit: &ASTEdit{
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing operation",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "replace without content",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown operation",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: "unknown",
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := e.validateASTEdit(tt.edit)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSelector(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
content := []byte(`package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
|
||||
func Goodbye() {
|
||||
println("goodbye")
|
||||
}
|
||||
`)
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := registry.Parse(ctx, "test.go", content)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sel ASTSelector
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "by kind",
|
||||
sel: ASTSelector{Kind: "function_declaration"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "by name",
|
||||
sel: ASTSelector{Name: "Hello"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "by kind and name",
|
||||
sel: ASTSelector{Kind: "function_declaration", Name: "Goodbye"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "by line",
|
||||
sel: ASTSelector{AtLine: 3},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
sel: ASTSelector{Name: "NonExistent"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "index out of range",
|
||||
sel: ASTSelector{Kind: "function_declaration", Index: 10},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
node, err := e.resolveSelector(tt.sel, result.Tree, content)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if node == nil {
|
||||
t.Error("expected node")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEdit(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
content := []byte(`package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
`)
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := registry.Parse(ctx, "test.go", content)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
operation EditOperation
|
||||
newCode string
|
||||
wantIn string // substring that should be in result
|
||||
}{
|
||||
{
|
||||
name: "replace",
|
||||
operation: EditReplace,
|
||||
newCode: "func NewHello() {}",
|
||||
wantIn: "NewHello",
|
||||
},
|
||||
{
|
||||
name: "insert after",
|
||||
operation: EditInsertAfter,
|
||||
newCode: "func After() {}",
|
||||
wantIn: "After",
|
||||
},
|
||||
{
|
||||
name: "insert before",
|
||||
operation: EditInsertBefore,
|
||||
newCode: "func Before() {}",
|
||||
wantIn: "Before",
|
||||
},
|
||||
{
|
||||
name: "delete",
|
||||
operation: EditDelete,
|
||||
newCode: "",
|
||||
wantIn: "package main", // Should still have package declaration
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Find the function node
|
||||
node, err := e.resolveSelector(ASTSelector{Kind: "function_declaration"}, result.Tree, content)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
|
||||
edit := &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: tt.operation,
|
||||
NewContent: tt.newCode,
|
||||
}
|
||||
|
||||
newContent, err := e.applyEdit(edit, node, content)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(newContent), tt.wantIn) {
|
||||
t.Errorf("result does not contain %q:\n%s", tt.wantIn, string(newContent))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreview(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewHello() {\n\tprintln(\"new hello\")\n}",
|
||||
}
|
||||
|
||||
result, err := e.Preview(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("preview failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("preview was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
if result.Applied {
|
||||
t.Error("preview should not apply changes")
|
||||
}
|
||||
|
||||
if result.Diff == "" {
|
||||
t.Error("expected diff in result")
|
||||
}
|
||||
|
||||
// Verify original file is unchanged
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if string(fileContent) != content {
|
||||
t.Error("original file was modified during preview")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyToFile(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewHello() {\n\tprintln(\"new hello\")\n}",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
if !result.Applied {
|
||||
t.Error("apply should set Applied=true")
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), "NewHello") {
|
||||
t.Error("file was not modified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectIndentation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
pos uint32
|
||||
}{
|
||||
{
|
||||
name: "no indent",
|
||||
content: "func main() {}",
|
||||
pos: 0,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "tab indent",
|
||||
content: "func main() {\n\tprintln(\"hello\")\n}",
|
||||
pos: 15,
|
||||
want: "\t",
|
||||
},
|
||||
{
|
||||
name: "space indent",
|
||||
content: "func main() {\n println(\"hello\")\n}",
|
||||
pos: 18,
|
||||
want: " ",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := detectIndentation([]byte(tt.content), tt.pos)
|
||||
if got != tt.want {
|
||||
t.Errorf("detectIndentation() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDiff(t *testing.T) {
|
||||
original := "line1\nline2\nline3"
|
||||
modified := "line1\nmodified\nline3"
|
||||
filename := "test.txt"
|
||||
|
||||
diff := generateDiff(original, modified, filename)
|
||||
|
||||
if !strings.Contains(diff, "---") {
|
||||
t.Error("diff should contain --- header")
|
||||
}
|
||||
if !strings.Contains(diff, "+++") {
|
||||
t.Error("diff should contain +++ header")
|
||||
}
|
||||
if !strings.Contains(diff, "-line2") {
|
||||
t.Error("diff should show removed line")
|
||||
}
|
||||
if !strings.Contains(diff, "+modified") {
|
||||
t.Error("diff should show added line")
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Text-based editing tests ====================
|
||||
|
||||
func TestTextEditWithExactText(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp markdown file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "README.md")
|
||||
|
||||
content := `# My Project
|
||||
|
||||
## Installation
|
||||
|
||||
Run the following command:
|
||||
|
||||
## Usage
|
||||
|
||||
See the docs.
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Text: "## Installation"},
|
||||
NewContent: "## Getting Started",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), "## Getting Started") {
|
||||
t.Error("file was not modified correctly")
|
||||
}
|
||||
if strings.Contains(string(fileContent), "## Installation") {
|
||||
t.Error("old text should be replaced")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditWithLineRange(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp config file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
content := `name: myapp
|
||||
version: 1.0.0
|
||||
database:
|
||||
host: localhost
|
||||
port: 5432
|
||||
logging:
|
||||
level: debug
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
AtLine: 3,
|
||||
LineEnd: 5,
|
||||
},
|
||||
NewContent: "database:\n host: production.db.example.com\n port: 5433",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), "production.db.example.com") {
|
||||
t.Error("file was not modified correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditWithRegexPattern(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp JSON file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "package.json")
|
||||
|
||||
content := `{
|
||||
"name": "my-package",
|
||||
"version": "1.0.0",
|
||||
"description": "A test package"
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{TextPattern: `"version":\s*"[^"]+"`},
|
||||
NewContent: `"version": "2.0.0"`,
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), `"version": "2.0.0"`) {
|
||||
t.Error("file was not modified correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditInsertAfter(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp env file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
content := `DATABASE_URL=postgres://localhost/mydb
|
||||
SECRET_KEY=abc123
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditInsertAfter,
|
||||
Selector: ASTSelector{Text: "DATABASE_URL=postgres://localhost/mydb"},
|
||||
NewContent: "REDIS_URL=redis://localhost:6379",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), "REDIS_URL=redis://localhost:6379") {
|
||||
t.Error("file was not modified correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditMultipleMatchesError(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp file with repeated text
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
content := `TODO: fix this
|
||||
some code here
|
||||
TODO: also fix this
|
||||
more code
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Text: "TODO"},
|
||||
NewContent: "DONE",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
// Should fail because of multiple matches
|
||||
if result.Success {
|
||||
t.Error("expected error for multiple matches without index")
|
||||
}
|
||||
if !strings.Contains(result.Error, "matches") {
|
||||
t.Errorf("error should mention multiple matches: %s", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditWithIndex(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp file with repeated text
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
content := `TODO: fix this
|
||||
some code here
|
||||
TODO: also fix this
|
||||
more code
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
Text: "TODO",
|
||||
Index: 1, // Select second match
|
||||
},
|
||||
NewContent: "DONE",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify only second TODO was replaced
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
contentStr := string(fileContent)
|
||||
if !strings.Contains(contentStr, "TODO: fix this") {
|
||||
t.Error("first TODO should not be replaced")
|
||||
}
|
||||
if !strings.Contains(contentStr, "DONE: also fix this") {
|
||||
t.Error("second TODO should be replaced")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTextEdit(t *testing.T) {
|
||||
e := NewEngine(parser.NewRegistry())
|
||||
|
||||
tests := []struct {
|
||||
edit *ASTEdit
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid text selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Text: "some text"},
|
||||
NewContent: "new text",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid pattern selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{TextPattern: "\\d+"},
|
||||
NewContent: "replaced",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid line selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{AtLine: 5},
|
||||
NewContent: "new line",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{},
|
||||
NewContent: "new text",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid regex pattern",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{TextPattern: "[invalid"},
|
||||
NewContent: "new text",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := e.validateTextEdit(tt.edit)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindLineRange(t *testing.T) {
|
||||
e := NewEngine(parser.NewRegistry())
|
||||
|
||||
content := []byte("line1\nline2\nline3\nline4\nline5")
|
||||
|
||||
// Content: "line1\nline2\nline3\nline4\nline5" (no trailing newline)
|
||||
// Positions: line1=0-5, \n=5, line2=6-10, \n=11, line3=12-16, \n=17, line4=18-22, \n=23, line5=24-28
|
||||
tests := []struct {
|
||||
name string
|
||||
lineStart int
|
||||
lineEnd int
|
||||
wantStart int
|
||||
wantEnd int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "single line",
|
||||
lineStart: 2,
|
||||
lineEnd: 0, // defaults to lineStart
|
||||
wantStart: 6,
|
||||
wantEnd: 12, // includes trailing newline
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "range of lines",
|
||||
lineStart: 2,
|
||||
lineEnd: 4,
|
||||
wantStart: 6,
|
||||
wantEnd: 24, // through end of line4 including newline
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "first line",
|
||||
lineStart: 1,
|
||||
lineEnd: 1,
|
||||
wantStart: 0,
|
||||
wantEnd: 6, // includes trailing newline
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "line out of range",
|
||||
lineStart: 10,
|
||||
lineEnd: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid line number",
|
||||
lineStart: 0,
|
||||
lineEnd: 1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "end before start",
|
||||
lineStart: 3,
|
||||
lineEnd: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
start, end, err := e.findLineRange(content, tt.lineStart, tt.lineEnd)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if start != tt.wantStart {
|
||||
t.Errorf("start = %d, want %d", start, tt.wantStart)
|
||||
}
|
||||
if end != tt.wantEnd {
|
||||
t.Errorf("end = %d, want %d", end, tt.wantEnd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,310 @@
|
||||
// Package lsp provides a generic LSP client implementation.
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
// Client represents an LSP client connection.
|
||||
type Client struct {
|
||||
stdin io.WriteCloser
|
||||
stdout io.ReadCloser
|
||||
stderr io.ReadCloser
|
||||
cmd *exec.Cmd
|
||||
pending map[int64]chan *Response
|
||||
done chan struct{}
|
||||
notifications chan *Notification
|
||||
requestID atomic.Int64
|
||||
runningMu sync.RWMutex
|
||||
stopOnce sync.Once
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// Request represents a JSON-RPC request.
|
||||
type Request struct {
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
ID int64 `json:"id"`
|
||||
}
|
||||
|
||||
// Response represents a JSON-RPC response.
|
||||
type Response struct {
|
||||
Error *ResponseError `json:"error,omitempty"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
}
|
||||
|
||||
// ResponseError represents a JSON-RPC error.
|
||||
type ResponseError struct {
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
func (e *ResponseError) Error() string {
|
||||
return fmt.Sprintf("LSP error %d: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Notification represents a JSON-RPC notification.
|
||||
type Notification struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// NewClient creates a new LSP client from a command.
|
||||
func NewClient(cmd *exec.Cmd) (*Client, error) {
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
_ = stdin.Close()
|
||||
return nil, fmt.Errorf("failed to get stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
_ = stdin.Close()
|
||||
_ = stdout.Close()
|
||||
return nil, fmt.Errorf("failed to get stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
_ = stdin.Close()
|
||||
_ = stdout.Close()
|
||||
_ = stderr.Close()
|
||||
return nil, fmt.Errorf("failed to start LSP server: %w", err)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
cmd: cmd,
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
pending: make(map[int64]chan *Response),
|
||||
done: make(chan struct{}),
|
||||
running: true,
|
||||
notifications: make(chan *Notification, 100),
|
||||
}
|
||||
|
||||
// Start reader goroutine
|
||||
go c.readLoop()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Call sends a request and waits for a response.
|
||||
func (c *Client) Call(ctx context.Context, method string, params interface{}) (*Response, error) {
|
||||
c.runningMu.RLock()
|
||||
if !c.running {
|
||||
c.runningMu.RUnlock()
|
||||
return nil, fmt.Errorf("client is not running")
|
||||
}
|
||||
c.runningMu.RUnlock()
|
||||
|
||||
id := c.requestID.Add(1)
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
respChan := make(chan *Response, 1)
|
||||
c.mu.Lock()
|
||||
c.pending[id] = respChan
|
||||
c.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Send request
|
||||
if err := c.send(req); err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
// Wait for response
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-c.done:
|
||||
return nil, fmt.Errorf("client closed")
|
||||
case resp := <-respChan:
|
||||
if resp.Error != nil {
|
||||
return nil, resp.Error
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Notify sends a notification (no response expected).
|
||||
func (c *Client) Notify(method string, params interface{}) error {
|
||||
c.runningMu.RLock()
|
||||
if !c.running {
|
||||
c.runningMu.RUnlock()
|
||||
return fmt.Errorf("client is not running")
|
||||
}
|
||||
c.runningMu.RUnlock()
|
||||
|
||||
notif := struct {
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
}{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
|
||||
return c.send(notif)
|
||||
}
|
||||
|
||||
// Notifications returns a channel for receiving server notifications.
|
||||
func (c *Client) Notifications() <-chan *Notification {
|
||||
return c.notifications
|
||||
}
|
||||
|
||||
// Close shuts down the client and the LSP server.
|
||||
func (c *Client) Close() error {
|
||||
var err error
|
||||
c.stopOnce.Do(func() {
|
||||
c.runningMu.Lock()
|
||||
c.running = false
|
||||
c.runningMu.Unlock()
|
||||
|
||||
close(c.done)
|
||||
|
||||
// Close stdin to signal the server
|
||||
_ = c.stdin.Close()
|
||||
|
||||
// Wait for process to exit with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = c.cmd.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Clean exit
|
||||
case <-time.After(5 * time.Second):
|
||||
// Force kill
|
||||
_ = c.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
close(c.notifications)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// send writes a JSON-RPC message to the server.
|
||||
func (c *Client) send(msg interface{}) error {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Format with Content-Length header
|
||||
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data))
|
||||
_, err = c.stdin.Write([]byte(header))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
|
||||
_, err = c.stdin.Write(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write body: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// readLoop reads and dispatches messages from the server.
|
||||
func (c *Client) readLoop() {
|
||||
reader := bufio.NewReader(c.stdout)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read headers
|
||||
contentLength := -1
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
break
|
||||
}
|
||||
if strings.HasPrefix(line, "Content-Length:") {
|
||||
lengthStr := strings.TrimSpace(strings.TrimPrefix(line, "Content-Length:"))
|
||||
contentLength, _ = strconv.Atoi(lengthStr)
|
||||
}
|
||||
}
|
||||
|
||||
if contentLength <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read body
|
||||
body := make([]byte, contentLength)
|
||||
_, err := io.ReadFull(reader, body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to parse as response first
|
||||
var resp Response
|
||||
if err := json.Unmarshal(body, &resp); err == nil && resp.ID != 0 {
|
||||
c.mu.Lock()
|
||||
if ch, ok := c.pending[resp.ID]; ok {
|
||||
ch <- &resp
|
||||
}
|
||||
c.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to parse as notification
|
||||
var notif Notification
|
||||
if err := json.Unmarshal(body, ¬if); err == nil && notif.Method != "" {
|
||||
select {
|
||||
case c.notifications <- ¬if:
|
||||
default:
|
||||
// Drop notification if channel is full
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsRunning returns whether the client is running.
|
||||
func (c *Client) IsRunning() bool {
|
||||
c.runningMu.RLock()
|
||||
defer c.runningMu.RUnlock()
|
||||
return c.running
|
||||
}
|
||||
@@ -0,0 +1,535 @@
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
// Manager manages LSP servers for different languages.
|
||||
type Manager struct {
|
||||
servers map[protocol.Language]*ManagedServer
|
||||
logger *slog.Logger
|
||||
stopReaper chan struct{}
|
||||
workspaceRoot string
|
||||
timeout time.Duration
|
||||
idleTimeout time.Duration
|
||||
mu sync.RWMutex
|
||||
stopped bool
|
||||
}
|
||||
|
||||
// ManagedServer represents a managed LSP server instance.
|
||||
type ManagedServer struct {
|
||||
lastUsed time.Time
|
||||
initErr error
|
||||
client *Client
|
||||
openDocs map[string]int
|
||||
language protocol.Language
|
||||
capabilities ServerCapabilities
|
||||
mu sync.Mutex
|
||||
ready bool
|
||||
}
|
||||
|
||||
// ServerConfig contains the configuration for an LSP server.
|
||||
type ServerConfig struct {
|
||||
Command []string
|
||||
Args []string
|
||||
}
|
||||
|
||||
// DefaultServerConfigs contains default configurations for LSP servers.
|
||||
var DefaultServerConfigs = map[protocol.Language]ServerConfig{
|
||||
protocol.LangGo: {
|
||||
Command: []string{"gopls"},
|
||||
Args: []string{"serve"},
|
||||
},
|
||||
protocol.LangTypeScript: {
|
||||
Command: []string{"typescript-language-server"},
|
||||
Args: []string{"--stdio"},
|
||||
},
|
||||
protocol.LangJavaScript: {
|
||||
Command: []string{"typescript-language-server"},
|
||||
Args: []string{"--stdio"},
|
||||
},
|
||||
protocol.LangPython: {
|
||||
Command: []string{"pylsp"},
|
||||
},
|
||||
protocol.LangC: {
|
||||
Command: []string{"clangd"},
|
||||
},
|
||||
protocol.LangCpp: {
|
||||
Command: []string{"clangd"},
|
||||
},
|
||||
}
|
||||
|
||||
// NewManager creates a new LSP manager.
|
||||
func NewManager(workspaceRoot string, logger *slog.Logger) *Manager {
|
||||
m := &Manager{
|
||||
servers: make(map[protocol.Language]*ManagedServer),
|
||||
timeout: 10 * time.Second,
|
||||
idleTimeout: 5 * time.Minute,
|
||||
workspaceRoot: workspaceRoot,
|
||||
logger: logger,
|
||||
stopReaper: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start idle reaper
|
||||
go m.reapIdleServers()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// GetServer returns or creates an LSP server for the given language.
|
||||
func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*ManagedServer, error) {
|
||||
m.mu.RLock()
|
||||
srv, exists := m.servers[lang]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists && srv.ready {
|
||||
// Update lastUsed with server's own lock to avoid race condition
|
||||
srv.mu.Lock()
|
||||
srv.lastUsed = time.Now()
|
||||
srv.mu.Unlock()
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// Create new server
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if srv, ok := m.servers[lang]; ok && srv.ready {
|
||||
srv.mu.Lock()
|
||||
srv.lastUsed = time.Now()
|
||||
srv.mu.Unlock()
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// Check if server config exists
|
||||
config, ok := DefaultServerConfigs[lang]
|
||||
if !ok {
|
||||
return nil, errors.New(errors.ErrLSPServerNotFound, fmt.Sprintf("no LSP server configured for language: %s", lang)).
|
||||
WithContext("language", string(lang)).
|
||||
WithRemediation("Configure an LSP server for this language or use a supported language")
|
||||
}
|
||||
|
||||
// Check if command is available
|
||||
cmdPath, err := exec.LookPath(config.Command[0])
|
||||
if err != nil {
|
||||
return nil, errors.NewLSPServerNotFound(string(lang), config.Command[0])
|
||||
}
|
||||
|
||||
// Create command
|
||||
args := append(config.Command[1:], config.Args...)
|
||||
cmd := exec.CommandContext(ctx, cmdPath, args...)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Dir = m.workspaceRoot
|
||||
|
||||
// Create client
|
||||
client, err := NewClient(cmd)
|
||||
if err != nil {
|
||||
// Ensure process is killed if client creation fails
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
return nil, errors.Wrap(errors.ErrLSPCommunication, "failed to create LSP client", err).
|
||||
WithContext("language", string(lang)).
|
||||
WithContext("command", config.Command[0]).
|
||||
WithRemediation("Ensure the LSP server binary is executable and compatible with your system")
|
||||
}
|
||||
|
||||
newSrv := &ManagedServer{
|
||||
client: client,
|
||||
language: lang,
|
||||
lastUsed: time.Now(),
|
||||
openDocs: make(map[string]int),
|
||||
}
|
||||
|
||||
// Initialize server
|
||||
if err := m.initializeServer(ctx, newSrv); err != nil {
|
||||
_ = client.Close()
|
||||
// Ensure process is killed on initialization failure
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
newSrv.initErr = err
|
||||
return nil, errors.Wrap(errors.ErrLSPInitFailed, "LSP server initialization failed", err).
|
||||
WithContext("language", string(lang)).
|
||||
WithContext("command", config.Command[0]).
|
||||
WithRemediation("Check LSP server logs for initialization errors")
|
||||
}
|
||||
|
||||
newSrv.ready = true
|
||||
m.servers[lang] = newSrv
|
||||
m.logger.Info("started LSP server", "language", lang, "command", config.Command[0])
|
||||
|
||||
return newSrv, nil
|
||||
}
|
||||
|
||||
// initializeServer performs the LSP initialization handshake.
|
||||
func (m *Manager) initializeServer(ctx context.Context, srv *ManagedServer) error {
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Build root URI
|
||||
rootURI := "file://" + m.workspaceRoot
|
||||
|
||||
// Send initialize request
|
||||
params := InitializeParams{
|
||||
ProcessID: os.Getpid(),
|
||||
RootURI: rootURI,
|
||||
Capabilities: Capabilities{
|
||||
TextDocument: TextDocumentClientCapabilities{
|
||||
Hover: HoverCapability{
|
||||
ContentFormat: []string{"markdown", "plaintext"},
|
||||
},
|
||||
Definition: DefinitionCapability{
|
||||
LinkSupport: true,
|
||||
},
|
||||
References: ReferencesCapability{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := srv.client.Call(ctx, "initialize", params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse capabilities
|
||||
var result InitializeResult
|
||||
if err := json.Unmarshal(resp.Result, &result); err != nil {
|
||||
return fmt.Errorf("failed to parse initialize result: %w", err)
|
||||
}
|
||||
srv.capabilities = result.Capabilities
|
||||
|
||||
// Send initialized notification
|
||||
if err := srv.client.Notify("initialized", struct{}{}); err != nil {
|
||||
return fmt.Errorf("initialized notification failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hover performs a hover request at the given position.
|
||||
func (m *Manager) Hover(ctx context.Context, file string, line, col int) (*HoverResult, error) {
|
||||
lang := protocol.DetectLanguage(file)
|
||||
srv, err := m.GetServer(ctx, lang)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure document is open
|
||||
err = m.ensureDocumentOpen(ctx, srv, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := HoverParams{
|
||||
TextDocumentPositionParams: TextDocumentPositionParams{
|
||||
TextDocument: TextDocumentIdentifier{
|
||||
URI: fileToURI(file),
|
||||
},
|
||||
Position: Position{
|
||||
Line: line - 1, // Convert to 0-indexed
|
||||
Character: col - 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := srv.client.Call(ctx, "textDocument/hover", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hover request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.Result == nil || string(resp.Result) == "null" {
|
||||
return nil, nil // No hover info
|
||||
}
|
||||
|
||||
var result HoverResult
|
||||
if err := json.Unmarshal(resp.Result, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse hover result: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Definition finds the definition of the symbol at the given position.
|
||||
func (m *Manager) Definition(ctx context.Context, file string, line, col int) ([]Location, error) {
|
||||
lang := protocol.DetectLanguage(file)
|
||||
srv, err := m.GetServer(ctx, lang)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure document is open
|
||||
err = m.ensureDocumentOpen(ctx, srv, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := DefinitionParams{
|
||||
TextDocumentPositionParams: TextDocumentPositionParams{
|
||||
TextDocument: TextDocumentIdentifier{
|
||||
URI: fileToURI(file),
|
||||
},
|
||||
Position: Position{
|
||||
Line: line - 1,
|
||||
Character: col - 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := srv.client.Call(ctx, "textDocument/definition", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("definition request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.Result == nil || string(resp.Result) == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Result can be Location, []Location, or []LocationLink
|
||||
var locations []Location
|
||||
if err := json.Unmarshal(resp.Result, &locations); err != nil {
|
||||
// Try single location
|
||||
var single Location
|
||||
if err := json.Unmarshal(resp.Result, &single); err == nil {
|
||||
locations = []Location{single}
|
||||
}
|
||||
}
|
||||
|
||||
return locations, nil
|
||||
}
|
||||
|
||||
// References finds all references to the symbol at the given position.
|
||||
func (m *Manager) References(ctx context.Context, file string, line, col int, includeDeclaration bool) ([]Location, error) {
|
||||
lang := protocol.DetectLanguage(file)
|
||||
srv, err := m.GetServer(ctx, lang)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure document is open
|
||||
err = m.ensureDocumentOpen(ctx, srv, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := ReferenceParams{
|
||||
TextDocumentPositionParams: TextDocumentPositionParams{
|
||||
TextDocument: TextDocumentIdentifier{
|
||||
URI: fileToURI(file),
|
||||
},
|
||||
Position: Position{
|
||||
Line: line - 1,
|
||||
Character: col - 1,
|
||||
},
|
||||
},
|
||||
Context: ReferenceContext{
|
||||
IncludeDeclaration: includeDeclaration,
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := srv.client.Call(ctx, "textDocument/references", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("references request failed: %w", err)
|
||||
}
|
||||
|
||||
if resp.Result == nil || string(resp.Result) == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var locations []Location
|
||||
if err := json.Unmarshal(resp.Result, &locations); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse references result: %w", err)
|
||||
}
|
||||
|
||||
return locations, nil
|
||||
}
|
||||
|
||||
// ensureDocumentOpen opens a document if not already open.
|
||||
func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, file string) error {
|
||||
uri := fileToURI(file)
|
||||
|
||||
srv.mu.Lock()
|
||||
if _, ok := srv.openDocs[uri]; ok {
|
||||
srv.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
srv.mu.Unlock()
|
||||
|
||||
// Read file content
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
// Get language ID
|
||||
langID := languageToLSPID(srv.language)
|
||||
|
||||
params := DidOpenTextDocumentParams{
|
||||
TextDocument: TextDocumentItem{
|
||||
URI: uri,
|
||||
LanguageID: langID,
|
||||
Version: 1,
|
||||
Text: string(content),
|
||||
},
|
||||
}
|
||||
|
||||
if err := srv.client.Notify("textDocument/didOpen", params); err != nil {
|
||||
return fmt.Errorf("didOpen failed: %w", err)
|
||||
}
|
||||
|
||||
srv.mu.Lock()
|
||||
srv.openDocs[uri] = 1
|
||||
srv.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseDocument closes a document in the server.
|
||||
func (m *Manager) CloseDocument(_ context.Context, lang protocol.Language, file string) error {
|
||||
m.mu.RLock()
|
||||
srv, ok := m.servers[lang]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok || !srv.ready {
|
||||
return nil
|
||||
}
|
||||
|
||||
uri := fileToURI(file)
|
||||
|
||||
srv.mu.Lock()
|
||||
if _, ok := srv.openDocs[uri]; !ok {
|
||||
srv.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
delete(srv.openDocs, uri)
|
||||
srv.mu.Unlock()
|
||||
|
||||
params := DidCloseTextDocumentParams{
|
||||
TextDocument: TextDocumentIdentifier{
|
||||
URI: uri,
|
||||
},
|
||||
}
|
||||
|
||||
return srv.client.Notify("textDocument/didClose", params)
|
||||
}
|
||||
|
||||
// reapIdleServers periodically closes idle servers.
|
||||
func (m *Manager) reapIdleServers() {
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopReaper:
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.mu.Lock()
|
||||
for lang, srv := range m.servers {
|
||||
// Check lastUsed with server's lock to avoid race condition
|
||||
srv.mu.Lock()
|
||||
idle := time.Since(srv.lastUsed) > m.idleTimeout
|
||||
srv.mu.Unlock()
|
||||
|
||||
if idle {
|
||||
m.logger.Info("closing idle LSP server", "language", lang)
|
||||
_ = srv.client.Close()
|
||||
delete(m.servers, lang)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down all LSP servers.
|
||||
func (m *Manager) Close() error {
|
||||
close(m.stopReaper)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.stopped = true
|
||||
|
||||
for lang, srv := range m.servers {
|
||||
m.logger.Info("shutting down LSP server", "language", lang)
|
||||
// Try graceful shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
_, _ = srv.client.Call(ctx, "shutdown", nil)
|
||||
cancel()
|
||||
_ = srv.client.Notify("exit", nil)
|
||||
_ = srv.client.Close()
|
||||
}
|
||||
|
||||
m.servers = make(map[protocol.Language]*ManagedServer)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAvailable checks if an LSP server is available for the given language.
|
||||
func (m *Manager) IsAvailable(lang protocol.Language) bool {
|
||||
config, ok := DefaultServerConfigs[lang]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err := exec.LookPath(config.Command[0])
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// fileToURI converts a file path to a file URI.
|
||||
func fileToURI(file string) string {
|
||||
absPath, err := filepath.Abs(file)
|
||||
if err != nil {
|
||||
return "file://" + file
|
||||
}
|
||||
return "file://" + absPath
|
||||
}
|
||||
|
||||
// URIToFile converts a file URI to a file path.
|
||||
func URIToFile(uri string) string {
|
||||
if len(uri) > 7 && uri[:7] == "file://" {
|
||||
return uri[7:]
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// languageToLSPID converts a language to LSP language ID.
|
||||
func languageToLSPID(lang protocol.Language) string {
|
||||
switch lang {
|
||||
case protocol.LangGo:
|
||||
return "go"
|
||||
case protocol.LangTypeScript:
|
||||
return "typescript"
|
||||
case protocol.LangJavaScript:
|
||||
return "javascript"
|
||||
case protocol.LangPython:
|
||||
return "python"
|
||||
case protocol.LangC:
|
||||
return "c"
|
||||
case protocol.LangCpp:
|
||||
return "cpp"
|
||||
default:
|
||||
return string(lang)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package lsp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestFileToURI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
file string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "absolute path",
|
||||
file: "/Users/test/file.go",
|
||||
want: "file:///Users/test/file.go",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := fileToURI(tt.file)
|
||||
if got != tt.want {
|
||||
t.Errorf("fileToURI() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestURIToFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "file uri",
|
||||
uri: "file:///Users/test/file.go",
|
||||
want: "/Users/test/file.go",
|
||||
},
|
||||
{
|
||||
name: "not a file uri",
|
||||
uri: "/Users/test/file.go",
|
||||
want: "/Users/test/file.go",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := URIToFile(tt.uri)
|
||||
if got != tt.want {
|
||||
t.Errorf("URIToFile() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLanguageToLSPID(t *testing.T) {
|
||||
tests := []struct {
|
||||
lang protocol.Language
|
||||
want string
|
||||
}{
|
||||
{protocol.LangGo, "go"},
|
||||
{protocol.LangTypeScript, "typescript"},
|
||||
{protocol.LangJavaScript, "javascript"},
|
||||
{protocol.LangPython, "python"},
|
||||
{protocol.LangC, "c"},
|
||||
{protocol.LangCpp, "cpp"},
|
||||
{protocol.LangUnknown, "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.lang), func(t *testing.T) {
|
||||
got := languageToLSPID(tt.lang)
|
||||
if got != tt.want {
|
||||
t.Errorf("languageToLSPID() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAvailable(t *testing.T) {
|
||||
// This tests the structure of the manager without actually spawning servers
|
||||
// which requires the actual LSP servers to be installed
|
||||
|
||||
// Just verify the DefaultServerConfigs structure
|
||||
expectedLanguages := []protocol.Language{
|
||||
protocol.LangGo,
|
||||
protocol.LangTypeScript,
|
||||
protocol.LangJavaScript,
|
||||
protocol.LangPython,
|
||||
protocol.LangC,
|
||||
protocol.LangCpp,
|
||||
}
|
||||
|
||||
for _, lang := range expectedLanguages {
|
||||
if _, ok := DefaultServerConfigs[lang]; !ok {
|
||||
t.Errorf("missing server config for language: %s", lang)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultServerConfigs(t *testing.T) {
|
||||
// Verify the command structure
|
||||
for lang, config := range DefaultServerConfigs {
|
||||
if len(config.Command) == 0 {
|
||||
t.Errorf("language %s has empty command", lang)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package lsp
|
||||
|
||||
// InitializeParams are the parameters for the initialize request.
|
||||
type InitializeParams struct {
|
||||
RootURI string `json:"rootUri"`
|
||||
Capabilities Capabilities `json:"capabilities"`
|
||||
ProcessID int `json:"processId"`
|
||||
}
|
||||
|
||||
// Capabilities represents client capabilities.
|
||||
type Capabilities struct {
|
||||
TextDocument TextDocumentClientCapabilities `json:"textDocument"`
|
||||
}
|
||||
|
||||
// TextDocumentClientCapabilities represents text document capabilities.
|
||||
type TextDocumentClientCapabilities struct {
|
||||
Hover HoverCapability `json:"hover,omitempty"`
|
||||
Definition DefinitionCapability `json:"definition,omitempty"`
|
||||
References ReferencesCapability `json:"references,omitempty"`
|
||||
}
|
||||
|
||||
// HoverCapability represents hover capabilities.
|
||||
type HoverCapability struct {
|
||||
ContentFormat []string `json:"contentFormat,omitempty"`
|
||||
}
|
||||
|
||||
// DefinitionCapability represents definition capabilities.
|
||||
type DefinitionCapability struct {
|
||||
LinkSupport bool `json:"linkSupport,omitempty"`
|
||||
}
|
||||
|
||||
// ReferencesCapability represents references capabilities.
|
||||
type ReferencesCapability struct{}
|
||||
|
||||
// InitializeResult is the result of the initialize request.
|
||||
type InitializeResult struct {
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
}
|
||||
|
||||
// ServerCapabilities represents server capabilities.
|
||||
type ServerCapabilities struct {
|
||||
HoverProvider bool `json:"hoverProvider,omitempty"`
|
||||
DefinitionProvider bool `json:"definitionProvider,omitempty"`
|
||||
ReferencesProvider bool `json:"referencesProvider,omitempty"`
|
||||
DocumentSymbolProvider bool `json:"documentSymbolProvider,omitempty"`
|
||||
TextDocumentSync int `json:"textDocumentSync,omitempty"`
|
||||
}
|
||||
|
||||
// Position represents a position in a document.
|
||||
type Position struct {
|
||||
Line int `json:"line"` // 0-indexed
|
||||
Character int `json:"character"` // 0-indexed
|
||||
}
|
||||
|
||||
// Range represents a range in a document.
|
||||
type Range struct {
|
||||
Start Position `json:"start"`
|
||||
End Position `json:"end"`
|
||||
}
|
||||
|
||||
// Location represents a location in a document.
|
||||
type Location struct {
|
||||
URI string `json:"uri"`
|
||||
Range Range `json:"range"`
|
||||
}
|
||||
|
||||
// TextDocumentIdentifier identifies a text document.
|
||||
type TextDocumentIdentifier struct {
|
||||
URI string `json:"uri"`
|
||||
}
|
||||
|
||||
// TextDocumentPositionParams represents position parameters.
|
||||
type TextDocumentPositionParams struct {
|
||||
TextDocument TextDocumentIdentifier `json:"textDocument"`
|
||||
Position Position `json:"position"`
|
||||
}
|
||||
|
||||
// HoverParams are the parameters for the hover request.
|
||||
type HoverParams struct {
|
||||
TextDocumentPositionParams
|
||||
}
|
||||
|
||||
// HoverResult is the result of the hover request.
|
||||
type HoverResult struct {
|
||||
Range *Range `json:"range,omitempty"`
|
||||
Contents MarkupContent `json:"contents"`
|
||||
}
|
||||
|
||||
// MarkupContent represents markup content.
|
||||
type MarkupContent struct {
|
||||
Kind string `json:"kind"` // "plaintext" or "markdown"
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// DefinitionParams are the parameters for the definition request.
|
||||
type DefinitionParams struct {
|
||||
TextDocumentPositionParams
|
||||
}
|
||||
|
||||
// ReferenceParams are the parameters for the references request.
|
||||
type ReferenceParams struct {
|
||||
TextDocumentPositionParams
|
||||
Context ReferenceContext `json:"context"`
|
||||
}
|
||||
|
||||
// ReferenceContext represents reference context.
|
||||
type ReferenceContext struct {
|
||||
IncludeDeclaration bool `json:"includeDeclaration"`
|
||||
}
|
||||
|
||||
// TextDocumentItem represents a text document.
|
||||
type TextDocumentItem struct {
|
||||
URI string `json:"uri"`
|
||||
LanguageID string `json:"languageId"`
|
||||
Text string `json:"text"`
|
||||
Version int `json:"version"`
|
||||
}
|
||||
|
||||
// DidOpenTextDocumentParams are the parameters for didOpen.
|
||||
type DidOpenTextDocumentParams struct {
|
||||
TextDocument TextDocumentItem `json:"textDocument"`
|
||||
}
|
||||
|
||||
// DidCloseTextDocumentParams are the parameters for didClose.
|
||||
type DidCloseTextDocumentParams struct {
|
||||
TextDocument TextDocumentIdentifier `json:"textDocument"`
|
||||
}
|
||||
|
||||
// DocumentSymbol represents a symbol in a document.
|
||||
type DocumentSymbol struct {
|
||||
Name string `json:"name"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
Children []DocumentSymbol `json:"children,omitempty"`
|
||||
Range Range `json:"range"`
|
||||
SelectionRange Range `json:"selectionRange"`
|
||||
Kind int `json:"kind"`
|
||||
}
|
||||
|
||||
// SymbolInformation represents symbol information.
|
||||
type SymbolInformation struct {
|
||||
Name string `json:"name"`
|
||||
ContainerName string `json:"containerName,omitempty"`
|
||||
Location Location `json:"location"`
|
||||
Kind int `json:"kind"`
|
||||
}
|
||||
|
||||
// DocumentSymbolParams are the parameters for documentSymbol.
|
||||
type DocumentSymbolParams struct {
|
||||
TextDocument TextDocumentIdentifier `json:"textDocument"`
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// FindNodeAtPosition finds the node at the given line and column.
|
||||
func FindNodeAtPosition(tree *sitter.Tree, line, col int) *sitter.Node {
|
||||
if tree == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
root := tree.RootNode()
|
||||
if root == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert to 0-indexed
|
||||
point := sitter.Point{
|
||||
Row: uint32(line - 1), // #nosec G115 - line numbers are bounded by file size
|
||||
Column: uint32(col - 1), // #nosec G115 - column numbers are bounded by line length
|
||||
}
|
||||
|
||||
return findNodeAtPoint(root, point)
|
||||
}
|
||||
|
||||
// findNodeAtPoint recursively finds the smallest node containing the point.
|
||||
func findNodeAtPoint(node *sitter.Node, point sitter.Point) *sitter.Node {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
startPoint := node.StartPoint()
|
||||
endPoint := node.EndPoint()
|
||||
|
||||
// Check if point is within this node
|
||||
if !pointInRange(point, startPoint, endPoint) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to find a more specific child node
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
if result := findNodeAtPoint(child, point); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// No child contains the point, return this node
|
||||
return node
|
||||
}
|
||||
|
||||
// pointInRange checks if a point is within a range.
|
||||
func pointInRange(point, start, end sitter.Point) bool {
|
||||
// Before start?
|
||||
if point.Row < start.Row || (point.Row == start.Row && point.Column < start.Column) {
|
||||
return false
|
||||
}
|
||||
// After end?
|
||||
if point.Row > end.Row || (point.Row == end.Row && point.Column >= end.Column) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// FindParentOfKind finds the nearest ancestor of the given node type.
|
||||
func FindParentOfKind(node *sitter.Node, kind string) *sitter.Node {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
current := node.Parent()
|
||||
for current != nil {
|
||||
if current.Type() == kind {
|
||||
return current
|
||||
}
|
||||
current = current.Parent()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNodeText returns the text content of a node.
|
||||
func GetNodeText(node *sitter.Node, content []byte) string {
|
||||
if node == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
start := node.StartByte()
|
||||
end := node.EndByte()
|
||||
|
||||
if int(start) >= len(content) || int(end) > len(content) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(content[start:end])
|
||||
}
|
||||
|
||||
// WalkTree walks the tree calling fn for each node.
|
||||
// If fn returns false, the walk stops.
|
||||
func WalkTree(node *sitter.Node, fn func(*sitter.Node) bool) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !fn(node) {
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
WalkTree(node.Child(i), fn)
|
||||
}
|
||||
}
|
||||
|
||||
// FindNodesByKind finds all nodes of a given kind.
|
||||
func FindNodesByKind(root *sitter.Node, kind string) []*sitter.Node {
|
||||
var nodes []*sitter.Node
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
if n.Type() == kind {
|
||||
nodes = append(nodes, n)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
// FindNamedChildren returns all named (non-anonymous) children of a node.
|
||||
func FindNamedChildren(node *sitter.Node) []*sitter.Node {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var children []*sitter.Node
|
||||
for i := 0; i < int(node.NamedChildCount()); i++ {
|
||||
if child := node.NamedChild(i); child != nil {
|
||||
children = append(children, child)
|
||||
}
|
||||
}
|
||||
return children
|
||||
}
|
||||
|
||||
// GetChildByFieldName returns the child node with the given field name.
|
||||
func GetChildByFieldName(node *sitter.Node, fieldName string) *sitter.Node {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return node.ChildByFieldName(fieldName)
|
||||
}
|
||||
|
||||
// NodeLocation returns the location of a node.
|
||||
func NodeLocation(node *sitter.Node, filename string) protocol.Location {
|
||||
if node == nil {
|
||||
return protocol.Location{}
|
||||
}
|
||||
|
||||
startPoint := node.StartPoint()
|
||||
return protocol.Location{
|
||||
File: filename,
|
||||
Line: int(startPoint.Row) + 1,
|
||||
Column: int(startPoint.Column) + 1,
|
||||
}
|
||||
}
|
||||
|
||||
// NodeRange returns the range of a node.
|
||||
func NodeRange(node *sitter.Node, filename string) protocol.Range {
|
||||
if node == nil {
|
||||
return protocol.Range{}
|
||||
}
|
||||
|
||||
startPoint := node.StartPoint()
|
||||
endPoint := node.EndPoint()
|
||||
|
||||
return protocol.Range{
|
||||
Start: protocol.Location{
|
||||
File: filename,
|
||||
Line: int(startPoint.Row) + 1,
|
||||
Column: int(startPoint.Column) + 1,
|
||||
},
|
||||
End: protocol.Location{
|
||||
File: filename,
|
||||
Line: int(endPoint.Row) + 1,
|
||||
Column: int(endPoint.Column) + 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLRUCacheEviction tests that the LRU cache properly evicts old entries.
|
||||
func TestLRUCacheEviction(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 101 unique Go files (cache size is 100)
|
||||
for i := 0; i < 101; i++ {
|
||||
content := []byte(fmt.Sprintf("package main\n\nfunc test%d() {}\n", i))
|
||||
filename := "test.go"
|
||||
|
||||
_, err := registry.Parse(ctx, filename, content)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed for iteration %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// The LRU cache should have evicted the oldest entry
|
||||
// Verify cache size is capped at 100
|
||||
cacheLen := registry.cache.Len()
|
||||
if cacheLen > 100 {
|
||||
t.Errorf("Cache size %d exceeds max size 100", cacheLen)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheHit tests that repeated parsing of the same content uses cache.
|
||||
func TestCacheHit(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
ctx := context.Background()
|
||||
|
||||
content := []byte("package main\n\nfunc test() {}\n")
|
||||
filename := "test.go"
|
||||
|
||||
// First parse
|
||||
result1, err := registry.Parse(ctx, filename, content)
|
||||
if err != nil {
|
||||
t.Fatalf("First parse failed: %v", err)
|
||||
}
|
||||
|
||||
// Second parse should use cache
|
||||
result2, err := registry.Parse(ctx, filename, content)
|
||||
if err != nil {
|
||||
t.Fatalf("Second parse failed: %v", err)
|
||||
}
|
||||
|
||||
// The tree should be the same object (cached)
|
||||
if result1.Tree != result2.Tree {
|
||||
t.Error("Expected cached tree to be reused, but got different tree objects")
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentHashCollisionResistance tests that different content produces different hashes.
|
||||
func TestContentHashCollisionResistance(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
content1 []byte
|
||||
content2 []byte
|
||||
}{
|
||||
{
|
||||
name: "different content",
|
||||
content1: []byte("package main"),
|
||||
content2: []byte("package test"),
|
||||
},
|
||||
{
|
||||
name: "same prefix different suffix",
|
||||
content1: []byte("package main\nfunc a() {}"),
|
||||
content2: []byte("package main\nfunc b() {}"),
|
||||
},
|
||||
{
|
||||
name: "different length",
|
||||
content1: []byte("short"),
|
||||
content2: []byte("much longer content here"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hash1 := contentHash(tc.content1)
|
||||
hash2 := contentHash(tc.content2)
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Errorf("Hash collision: %s == %s for different content", hash1, hash2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentHashConsistency tests that the same content always produces the same hash.
|
||||
func TestContentHashConsistency(t *testing.T) {
|
||||
content := []byte("package main\n\nfunc test() {}\n")
|
||||
|
||||
hash1 := contentHash(content)
|
||||
hash2 := contentHash(content)
|
||||
hash3 := contentHash(content)
|
||||
|
||||
if hash1 != hash2 || hash2 != hash3 {
|
||||
t.Errorf("Hash inconsistency: %s, %s, %s", hash1, hash2, hash3)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkContentHash_xxHash benchmarks the xxHash implementation.
|
||||
func BenchmarkContentHash_xxHash(b *testing.B) {
|
||||
// Typical file content size (10KB)
|
||||
content := make([]byte, 10*1024)
|
||||
for i := range content {
|
||||
content[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = contentHash(content)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCacheHitRate benchmarks cache performance with realistic workload.
|
||||
func BenchmarkCacheHitRate(b *testing.B) {
|
||||
registry := NewRegistry()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a set of common files that get parsed repeatedly
|
||||
files := [][]byte{
|
||||
[]byte("package main\n\nfunc main() {}\n"),
|
||||
[]byte("package test\n\nimport \"testing\"\n"),
|
||||
[]byte("package util\n\nfunc helper() string { return \"\" }\n"),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simulate realistic access pattern with cache hits
|
||||
content := files[i%len(files)]
|
||||
_, _ = registry.Parse(ctx, "test.go", content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,550 @@
|
||||
// Package parser provides documentation extraction for multiple languages.
|
||||
package parser
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// DocComment represents an extracted documentation comment.
|
||||
type DocComment struct {
|
||||
Tags map[string]string
|
||||
Text string
|
||||
Raw string
|
||||
Style CommentStyle
|
||||
StartLine int
|
||||
EndLine int
|
||||
}
|
||||
|
||||
// CommentStyle indicates the type of comment.
|
||||
type CommentStyle string
|
||||
|
||||
const (
|
||||
CommentStyleLine CommentStyle = "line" // // comment
|
||||
CommentStyleBlock CommentStyle = "block" // /* comment */
|
||||
CommentStyleJSDoc CommentStyle = "jsdoc" // /** comment */
|
||||
CommentStyleDoxygen CommentStyle = "doxygen" // /** comment */ or /// comment
|
||||
CommentStyleDocstring CommentStyle = "docstring" // """comment""" or '''comment'''
|
||||
CommentStyleHash CommentStyle = "hash" // # comment (Python)
|
||||
)
|
||||
|
||||
// ExtractDocComment extracts the documentation comment for a node.
|
||||
func ExtractDocComment(n *sitter.Node, content []byte, lang protocol.Language) *DocComment {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch lang {
|
||||
case protocol.LangGo:
|
||||
return extractGoDocComment(n, content)
|
||||
case protocol.LangTypeScript, protocol.LangJavaScript:
|
||||
return extractJSDocComment(n, content)
|
||||
case protocol.LangPython:
|
||||
return extractPythonDocComment(n, content)
|
||||
case protocol.LangC, protocol.LangCpp:
|
||||
return extractCDocComment(n, content)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// extractGoDocComment extracts Go documentation comments.
|
||||
// Go uses // or /* */ comments immediately preceding a declaration.
|
||||
func extractGoDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parts []string
|
||||
var raw []string
|
||||
startLine := -1
|
||||
endLine := -1
|
||||
|
||||
for _, c := range comments {
|
||||
text := GetNodeText(c, content)
|
||||
raw = append(raw, text)
|
||||
|
||||
if startLine == -1 {
|
||||
startLine = int(c.StartPoint().Row) + 1
|
||||
}
|
||||
endLine = int(c.EndPoint().Row) + 1
|
||||
|
||||
cleaned := cleanGoComment(text)
|
||||
if cleaned != "" {
|
||||
parts = append(parts, cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DocComment{
|
||||
Text: strings.Join(parts, "\n"),
|
||||
Raw: strings.Join(raw, "\n"),
|
||||
Style: detectCommentStyle(raw[0]),
|
||||
Tags: nil, // Go doesn't use JSDoc-style tags
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
}
|
||||
}
|
||||
|
||||
// extractJSDocComment extracts JSDoc-style documentation comments.
|
||||
func extractJSDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// JSDoc prefers the last comment block if it's a JSDoc comment
|
||||
var jsDocComment *sitter.Node
|
||||
for i := len(comments) - 1; i >= 0; i-- {
|
||||
text := GetNodeText(comments[i], content)
|
||||
if strings.HasPrefix(strings.TrimSpace(text), "/**") {
|
||||
jsDocComment = comments[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if jsDocComment != nil {
|
||||
text := GetNodeText(jsDocComment, content)
|
||||
cleaned, tags := parseJSDoc(text)
|
||||
return &DocComment{
|
||||
Text: cleaned,
|
||||
Raw: text,
|
||||
Style: CommentStyleJSDoc,
|
||||
Tags: tags,
|
||||
StartLine: int(jsDocComment.StartPoint().Row) + 1,
|
||||
EndLine: int(jsDocComment.EndPoint().Row) + 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to regular comments
|
||||
var parts []string
|
||||
var raw []string
|
||||
startLine := -1
|
||||
endLine := -1
|
||||
|
||||
for _, c := range comments {
|
||||
text := GetNodeText(c, content)
|
||||
raw = append(raw, text)
|
||||
|
||||
if startLine == -1 {
|
||||
startLine = int(c.StartPoint().Row) + 1
|
||||
}
|
||||
endLine = int(c.EndPoint().Row) + 1
|
||||
|
||||
cleaned := cleanJSComment(text)
|
||||
if cleaned != "" {
|
||||
parts = append(parts, cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DocComment{
|
||||
Text: strings.Join(parts, "\n"),
|
||||
Raw: strings.Join(raw, "\n"),
|
||||
Style: CommentStyleLine,
|
||||
Tags: nil,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
}
|
||||
}
|
||||
|
||||
// extractPythonDocComment extracts Python docstrings.
|
||||
// Python docstrings are triple-quoted strings inside the function/class body.
|
||||
func extractPythonDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||
// Python docstrings are inside the body, not before
|
||||
body := n.ChildByFieldName("body")
|
||||
if body == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// First statement should be the docstring if present
|
||||
if body.NamedChildCount() > 0 {
|
||||
first := body.NamedChild(0)
|
||||
if first != nil && first.Type() == "expression_statement" {
|
||||
if first.NamedChildCount() > 0 {
|
||||
expr := first.NamedChild(0)
|
||||
if expr != nil && expr.Type() == "string" {
|
||||
text := GetNodeText(expr, content)
|
||||
cleaned := cleanPythonDocstring(text)
|
||||
return &DocComment{
|
||||
Text: cleaned,
|
||||
Raw: text,
|
||||
Style: CommentStyleDocstring,
|
||||
Tags: nil,
|
||||
StartLine: int(expr.StartPoint().Row) + 1,
|
||||
EndLine: int(expr.EndPoint().Row) + 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check for # comments before the definition
|
||||
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parts []string
|
||||
var raw []string
|
||||
startLine := -1
|
||||
endLine := -1
|
||||
|
||||
for _, c := range comments {
|
||||
text := GetNodeText(c, content)
|
||||
raw = append(raw, text)
|
||||
|
||||
if startLine == -1 {
|
||||
startLine = int(c.StartPoint().Row) + 1
|
||||
}
|
||||
endLine = int(c.EndPoint().Row) + 1
|
||||
|
||||
// Clean # comment
|
||||
cleaned := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(text), "#"))
|
||||
if cleaned != "" {
|
||||
parts = append(parts, cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DocComment{
|
||||
Text: strings.Join(parts, "\n"),
|
||||
Raw: strings.Join(raw, "\n"),
|
||||
Style: CommentStyleHash,
|
||||
Tags: nil,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
}
|
||||
}
|
||||
|
||||
// extractCDocComment extracts C/C++ documentation comments (Doxygen style).
|
||||
func extractCDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Look for Doxygen-style comment
|
||||
var doxyComment *sitter.Node
|
||||
for i := len(comments) - 1; i >= 0; i-- {
|
||||
text := GetNodeText(comments[i], content)
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if strings.HasPrefix(trimmed, "/**") || strings.HasPrefix(trimmed, "///") || strings.HasPrefix(trimmed, "//!") {
|
||||
doxyComment = comments[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if doxyComment != nil {
|
||||
text := GetNodeText(doxyComment, content)
|
||||
cleaned, tags := parseDoxygen(text)
|
||||
return &DocComment{
|
||||
Text: cleaned,
|
||||
Raw: text,
|
||||
Style: CommentStyleDoxygen,
|
||||
Tags: tags,
|
||||
StartLine: int(doxyComment.StartPoint().Row) + 1,
|
||||
EndLine: int(doxyComment.EndPoint().Row) + 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to regular comments
|
||||
var parts []string
|
||||
var raw []string
|
||||
startLine := -1
|
||||
endLine := -1
|
||||
|
||||
for _, c := range comments {
|
||||
text := GetNodeText(c, content)
|
||||
raw = append(raw, text)
|
||||
|
||||
if startLine == -1 {
|
||||
startLine = int(c.StartPoint().Row) + 1
|
||||
}
|
||||
endLine = int(c.EndPoint().Row) + 1
|
||||
|
||||
cleaned := cleanCComment(text)
|
||||
if cleaned != "" {
|
||||
parts = append(parts, cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DocComment{
|
||||
Text: strings.Join(parts, "\n"),
|
||||
Raw: strings.Join(raw, "\n"),
|
||||
Style: detectCommentStyle(raw[0]),
|
||||
Tags: nil,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
}
|
||||
}
|
||||
|
||||
// collectPrecedingComments collects all comment nodes immediately before a node.
|
||||
func collectPrecedingComments(n *sitter.Node, _ []byte, commentTypes []string) []*sitter.Node {
|
||||
var comments []*sitter.Node
|
||||
|
||||
// Walk backwards through siblings
|
||||
prev := n.PrevSibling()
|
||||
lastCommentLine := int(n.StartPoint().Row)
|
||||
|
||||
for prev != nil {
|
||||
isComment := false
|
||||
nodeType := prev.Type()
|
||||
for _, ct := range commentTypes {
|
||||
if nodeType == ct {
|
||||
isComment = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isComment {
|
||||
break
|
||||
}
|
||||
|
||||
commentEndLine := int(prev.EndPoint().Row)
|
||||
|
||||
// Check if there's a blank line gap
|
||||
if lastCommentLine-commentEndLine > 1 {
|
||||
break
|
||||
}
|
||||
|
||||
comments = append([]*sitter.Node{prev}, comments...)
|
||||
lastCommentLine = int(prev.StartPoint().Row)
|
||||
prev = prev.PrevSibling()
|
||||
}
|
||||
|
||||
return comments
|
||||
}
|
||||
|
||||
// detectCommentStyle determines the style of a comment.
|
||||
func detectCommentStyle(comment string) CommentStyle {
|
||||
trimmed := strings.TrimSpace(comment)
|
||||
if strings.HasPrefix(trimmed, "/**") {
|
||||
return CommentStyleJSDoc
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "///") || strings.HasPrefix(trimmed, "//!") {
|
||||
return CommentStyleDoxygen
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "/*") {
|
||||
return CommentStyleBlock
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "//") {
|
||||
return CommentStyleLine
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
return CommentStyleHash
|
||||
}
|
||||
if strings.HasPrefix(trimmed, `"""`) || strings.HasPrefix(trimmed, `'''`) {
|
||||
return CommentStyleDocstring
|
||||
}
|
||||
return CommentStyleLine
|
||||
}
|
||||
|
||||
// cleanGoComment cleans a Go comment.
|
||||
func cleanGoComment(comment string) string {
|
||||
comment = strings.TrimSpace(comment)
|
||||
|
||||
// Handle // comments
|
||||
if after, found := strings.CutPrefix(comment, "//"); found {
|
||||
return strings.TrimSpace(after)
|
||||
}
|
||||
|
||||
// Handle /* */ comments
|
||||
if strings.HasPrefix(comment, "/*") && strings.HasSuffix(comment, "*/") {
|
||||
comment = strings.TrimPrefix(comment, "/*")
|
||||
comment = strings.TrimSuffix(comment, "*/")
|
||||
return cleanBlockComment(comment)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(comment)
|
||||
}
|
||||
|
||||
// cleanJSComment cleans a JavaScript/TypeScript comment.
|
||||
func cleanJSComment(comment string) string {
|
||||
return cleanGoComment(comment) // Same rules
|
||||
}
|
||||
|
||||
// cleanCComment cleans a C/C++ comment.
|
||||
func cleanCComment(comment string) string {
|
||||
return cleanGoComment(comment) // Same rules
|
||||
}
|
||||
|
||||
// cleanBlockComment cleans the content of a block comment.
|
||||
func cleanBlockComment(comment string) string {
|
||||
lines := strings.Split(comment, "\n")
|
||||
var cleaned []string
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
// Remove leading * from each line (common in block comments)
|
||||
line = strings.TrimPrefix(line, "*")
|
||||
line = strings.TrimSpace(line)
|
||||
cleaned = append(cleaned, line)
|
||||
}
|
||||
|
||||
// Remove empty leading/trailing lines
|
||||
for len(cleaned) > 0 && cleaned[0] == "" {
|
||||
cleaned = cleaned[1:]
|
||||
}
|
||||
for len(cleaned) > 0 && cleaned[len(cleaned)-1] == "" {
|
||||
cleaned = cleaned[:len(cleaned)-1]
|
||||
}
|
||||
|
||||
return strings.Join(cleaned, "\n")
|
||||
}
|
||||
|
||||
// parseJSDoc parses a JSDoc comment and extracts tags.
|
||||
func parseJSDoc(comment string) (string, map[string]string) {
|
||||
comment = strings.TrimSpace(comment)
|
||||
|
||||
// Remove /** and */
|
||||
comment = strings.TrimPrefix(comment, "/**")
|
||||
comment = strings.TrimSuffix(comment, "*/")
|
||||
|
||||
lines := strings.Split(comment, "\n")
|
||||
var descLines []string
|
||||
tags := make(map[string]string)
|
||||
|
||||
// Regex for JSDoc tags
|
||||
tagPattern := regexp.MustCompile(`^\s*\*?\s*@(\w+)\s*(.*)$`)
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.TrimPrefix(line, "*")
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if matches := tagPattern.FindStringSubmatch(line); matches != nil {
|
||||
tagName := matches[1]
|
||||
tagValue := strings.TrimSpace(matches[2])
|
||||
if existing, ok := tags[tagName]; ok {
|
||||
tags[tagName] = existing + "\n" + tagValue
|
||||
} else {
|
||||
tags[tagName] = tagValue
|
||||
}
|
||||
} else if line != "" {
|
||||
descLines = append(descLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(descLines, "\n"), tags
|
||||
}
|
||||
|
||||
// parseDoxygen parses a Doxygen comment and extracts tags.
|
||||
func parseDoxygen(comment string) (string, map[string]string) {
|
||||
comment = strings.TrimSpace(comment)
|
||||
|
||||
// Handle /// and //! style comments
|
||||
comment = strings.TrimPrefix(comment, "///")
|
||||
comment = strings.TrimPrefix(comment, "//!")
|
||||
|
||||
// Handle /** */ style comments
|
||||
comment = strings.TrimPrefix(comment, "/**")
|
||||
comment = strings.TrimSuffix(comment, "*/")
|
||||
|
||||
lines := strings.Split(comment, "\n")
|
||||
var descLines []string
|
||||
tags := make(map[string]string)
|
||||
|
||||
// Regex for Doxygen tags (@param, @return, \param, \return, etc.)
|
||||
tagPattern := regexp.MustCompile(`^\s*\*?\s*[@\\](\w+)\s*(.*)$`)
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.TrimPrefix(line, "*")
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if matches := tagPattern.FindStringSubmatch(line); matches != nil {
|
||||
tagName := matches[1]
|
||||
tagValue := strings.TrimSpace(matches[2])
|
||||
if existing, ok := tags[tagName]; ok {
|
||||
tags[tagName] = existing + "\n" + tagValue
|
||||
} else {
|
||||
tags[tagName] = tagValue
|
||||
}
|
||||
} else if line != "" {
|
||||
descLines = append(descLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(descLines, "\n"), tags
|
||||
}
|
||||
|
||||
// FormatDocComment formats a DocComment for display.
|
||||
func FormatDocComment(doc *DocComment) string {
|
||||
if doc == nil || doc.Text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(doc.Text)
|
||||
|
||||
if len(doc.Tags) > 0 {
|
||||
sb.WriteString("\n\n")
|
||||
// Order: description, params, returns, other
|
||||
paramOrder := []string{"param", "parameter", "arg", "argument"}
|
||||
returnOrder := []string{"return", "returns", "retval"}
|
||||
|
||||
// Write params first
|
||||
for _, tagName := range paramOrder {
|
||||
if val, ok := doc.Tags[tagName]; ok {
|
||||
for _, line := range strings.Split(val, "\n") {
|
||||
sb.WriteString("@" + tagName + " " + line + "\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write returns
|
||||
for _, tagName := range returnOrder {
|
||||
if val, ok := doc.Tags[tagName]; ok {
|
||||
sb.WriteString("@" + tagName + " " + val + "\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Write remaining tags
|
||||
written := make(map[string]bool)
|
||||
for _, t := range paramOrder {
|
||||
written[t] = true
|
||||
}
|
||||
for _, t := range returnOrder {
|
||||
written[t] = true
|
||||
}
|
||||
|
||||
for tagName, val := range doc.Tags {
|
||||
if !written[tagName] {
|
||||
sb.WriteString("@" + tagName + " " + val + "\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// cleanPythonDocstring cleans a Python docstring.
|
||||
func cleanPythonDocstring(doc string) string {
|
||||
doc = strings.TrimSpace(doc)
|
||||
|
||||
// Remove triple quotes
|
||||
doc = strings.TrimPrefix(doc, `"""`)
|
||||
doc = strings.TrimSuffix(doc, `"""`)
|
||||
doc = strings.TrimPrefix(doc, `'''`)
|
||||
doc = strings.TrimSuffix(doc, `'''`)
|
||||
|
||||
return strings.TrimSpace(doc)
|
||||
}
|
||||
@@ -0,0 +1,630 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
func TestExtractGoDocComment(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
nodeKind string
|
||||
wantText string
|
||||
wantStyle CommentStyle
|
||||
}{
|
||||
{
|
||||
name: "single line comment",
|
||||
code: `package main
|
||||
|
||||
// Hello says hello
|
||||
func Hello() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "Hello says hello",
|
||||
wantStyle: CommentStyleLine,
|
||||
},
|
||||
{
|
||||
name: "multi-line comments",
|
||||
code: `package main
|
||||
|
||||
// This is a function
|
||||
// that does something
|
||||
// important
|
||||
func DoSomething() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "This is a function\nthat does something\nimportant",
|
||||
wantStyle: CommentStyleLine,
|
||||
},
|
||||
{
|
||||
name: "block comment",
|
||||
code: `package main
|
||||
|
||||
/* This is a block comment
|
||||
describing the function */
|
||||
func BlockCommented() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "This is a block comment\ndescribing the function",
|
||||
wantStyle: CommentStyleBlock,
|
||||
},
|
||||
{
|
||||
name: "doc comment with asterisks",
|
||||
code: `package main
|
||||
|
||||
/*
|
||||
* This is a properly formatted
|
||||
* block comment with asterisks
|
||||
*/
|
||||
func FormattedBlock() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "This is a properly formatted\nblock comment with asterisks",
|
||||
wantStyle: CommentStyleBlock,
|
||||
},
|
||||
{
|
||||
name: "no comment",
|
||||
code: `package main
|
||||
|
||||
func NoComment() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.Parse(context.Background(), "test.go", []byte(tt.code))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
// Find the target node
|
||||
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||
if targetNode == nil {
|
||||
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||
}
|
||||
|
||||
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangGo)
|
||||
|
||||
if tt.wantText == "" {
|
||||
if doc != nil && doc.Text != "" {
|
||||
t.Errorf("expected no doc, got %q", doc.Text)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||
}
|
||||
|
||||
if doc.Style != tt.wantStyle {
|
||||
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractJSDocComment(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
wantTags map[string]string
|
||||
name string
|
||||
code string
|
||||
nodeKind string
|
||||
wantText string
|
||||
wantStyle CommentStyle
|
||||
}{
|
||||
{
|
||||
name: "JSDoc comment",
|
||||
code: `/**
|
||||
* Adds two numbers together.
|
||||
* @param a The first number
|
||||
* @param b The second number
|
||||
* @returns The sum of a and b
|
||||
*/
|
||||
function add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "Adds two numbers together.",
|
||||
wantStyle: CommentStyleJSDoc,
|
||||
wantTags: map[string]string{
|
||||
"param": "a The first number\nb The second number",
|
||||
"returns": "The sum of a and b",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple line comment",
|
||||
code: `// This is a simple function
|
||||
function simple() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "This is a simple function",
|
||||
wantStyle: CommentStyleLine,
|
||||
},
|
||||
{
|
||||
name: "JSDoc with types",
|
||||
code: `/**
|
||||
* @param {string} name - The name
|
||||
* @returns {boolean} True if valid
|
||||
*/
|
||||
function validate(name) {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "",
|
||||
wantStyle: CommentStyleJSDoc,
|
||||
wantTags: map[string]string{
|
||||
"param": "{string} name - The name",
|
||||
"returns": "{boolean} True if valid",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.Parse(context.Background(), "test.js", []byte(tt.code))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||
if targetNode == nil {
|
||||
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||
}
|
||||
|
||||
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangJavaScript)
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||
}
|
||||
|
||||
if doc.Style != tt.wantStyle {
|
||||
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||
}
|
||||
|
||||
if tt.wantTags != nil {
|
||||
for k, want := range tt.wantTags {
|
||||
if got := doc.Tags[k]; got != want {
|
||||
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPythonDocComment(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
nodeKind string
|
||||
wantText string
|
||||
wantStyle CommentStyle
|
||||
}{
|
||||
{
|
||||
name: "docstring",
|
||||
code: `def greet(name):
|
||||
"""Greet a person by name."""
|
||||
print(f"Hello, {name}!")
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Greet a person by name.",
|
||||
wantStyle: CommentStyleDocstring,
|
||||
},
|
||||
{
|
||||
name: "multi-line docstring",
|
||||
code: `def calculate(x, y):
|
||||
"""
|
||||
Calculate the sum of two numbers.
|
||||
|
||||
Args:
|
||||
x: First number
|
||||
y: Second number
|
||||
|
||||
Returns:
|
||||
The sum of x and y
|
||||
"""
|
||||
return x + y
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Calculate the sum of two numbers.\n\n Args:\n x: First number\n y: Second number\n\n Returns:\n The sum of x and y",
|
||||
wantStyle: CommentStyleDocstring,
|
||||
},
|
||||
{
|
||||
name: "class docstring",
|
||||
code: `class MyClass:
|
||||
"""This is a class description."""
|
||||
pass
|
||||
`,
|
||||
nodeKind: "class_definition",
|
||||
wantText: "This is a class description.",
|
||||
wantStyle: CommentStyleDocstring,
|
||||
},
|
||||
{
|
||||
name: "single quote docstring",
|
||||
code: `def func():
|
||||
'''Single quote docstring'''
|
||||
pass
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Single quote docstring",
|
||||
wantStyle: CommentStyleDocstring,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.Parse(context.Background(), "test.py", []byte(tt.code))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||
if targetNode == nil {
|
||||
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||
}
|
||||
|
||||
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangPython)
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||
}
|
||||
|
||||
if doc.Style != tt.wantStyle {
|
||||
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCDocComment(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
wantTags map[string]string
|
||||
name string
|
||||
code string
|
||||
nodeKind string
|
||||
wantText string
|
||||
wantStyle CommentStyle
|
||||
}{
|
||||
{
|
||||
name: "Doxygen block comment",
|
||||
code: `/**
|
||||
* Adds two numbers.
|
||||
* @param a First number
|
||||
* @param b Second number
|
||||
* @return Sum of a and b
|
||||
*/
|
||||
int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Adds two numbers.",
|
||||
wantStyle: CommentStyleDoxygen,
|
||||
wantTags: map[string]string{
|
||||
"param": "a First number\nb Second number",
|
||||
"return": "Sum of a and b",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "regular block comment",
|
||||
code: `/* This is a regular comment */
|
||||
int regular() { return 0; }
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "This is a regular comment",
|
||||
wantStyle: CommentStyleBlock,
|
||||
},
|
||||
{
|
||||
name: "line comment",
|
||||
code: `// Simple function
|
||||
int simple() { return 1; }
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Simple function",
|
||||
wantStyle: CommentStyleLine,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.Parse(context.Background(), "test.c", []byte(tt.code))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||
if targetNode == nil {
|
||||
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||
}
|
||||
|
||||
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangC)
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||
}
|
||||
|
||||
if doc.Style != tt.wantStyle {
|
||||
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||
}
|
||||
|
||||
if tt.wantTags != nil {
|
||||
for k, want := range tt.wantTags {
|
||||
if got := doc.Tags[k]; got != want {
|
||||
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSDoc(t *testing.T) {
|
||||
tests := []struct {
|
||||
wantTags map[string]string
|
||||
name string
|
||||
input string
|
||||
wantText string
|
||||
}{
|
||||
{
|
||||
name: "complete jsdoc",
|
||||
input: `/**
|
||||
* This is a description.
|
||||
* Multiple lines.
|
||||
* @param {string} name The name
|
||||
* @returns {boolean} Result
|
||||
*/`,
|
||||
wantText: "This is a description.\nMultiple lines.",
|
||||
wantTags: map[string]string{
|
||||
"param": "{string} name The name",
|
||||
"returns": "{boolean} Result",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty jsdoc",
|
||||
input: `/** */`,
|
||||
wantText: "",
|
||||
wantTags: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "only description",
|
||||
input: `/** Simple description */`,
|
||||
wantText: "Simple description",
|
||||
wantTags: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
text, tags := parseJSDoc(tt.input)
|
||||
|
||||
if text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
|
||||
}
|
||||
|
||||
if len(tags) != len(tt.wantTags) {
|
||||
t.Errorf("tag count mismatch: got %d, want %d", len(tags), len(tt.wantTags))
|
||||
}
|
||||
|
||||
for k, want := range tt.wantTags {
|
||||
if got := tags[k]; got != want {
|
||||
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDoxygen(t *testing.T) {
|
||||
tests := []struct {
|
||||
wantTags map[string]string
|
||||
name string
|
||||
input string
|
||||
wantText string
|
||||
}{
|
||||
{
|
||||
name: "doxygen with @ tags",
|
||||
input: `/**
|
||||
* Brief description.
|
||||
* @param x Value
|
||||
* @return Result
|
||||
*/`,
|
||||
wantText: "Brief description.",
|
||||
wantTags: map[string]string{
|
||||
"param": "x Value",
|
||||
"return": "Result",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "doxygen with backslash tags",
|
||||
input: `/**
|
||||
* Description.
|
||||
* \param y Input
|
||||
* \retval Output value
|
||||
*/`,
|
||||
wantText: "Description.",
|
||||
wantTags: map[string]string{
|
||||
"param": "y Input",
|
||||
"retval": "Output value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "triple slash",
|
||||
input: `/// Simple description`,
|
||||
wantText: "Simple description",
|
||||
wantTags: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
text, tags := parseDoxygen(tt.input)
|
||||
|
||||
if text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
|
||||
}
|
||||
|
||||
for k, want := range tt.wantTags {
|
||||
if got := tags[k]; got != want {
|
||||
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatDocComment(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
doc *DocComment
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "with tags",
|
||||
doc: &DocComment{
|
||||
Text: "This is a function.",
|
||||
Tags: map[string]string{
|
||||
"param": "x The value",
|
||||
"returns": "The result",
|
||||
},
|
||||
},
|
||||
want: "This is a function.\n\n@param x The value\n@returns The result",
|
||||
},
|
||||
{
|
||||
name: "no tags",
|
||||
doc: &DocComment{
|
||||
Text: "Simple description.",
|
||||
Tags: nil,
|
||||
},
|
||||
want: "Simple description.",
|
||||
},
|
||||
{
|
||||
name: "nil doc",
|
||||
doc: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty text",
|
||||
doc: &DocComment{
|
||||
Text: "",
|
||||
Tags: nil,
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FormatDocComment(tt.doc)
|
||||
if got != tt.want {
|
||||
t.Errorf("mismatch:\ngot: %q\nwant: %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectCommentStyle(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want CommentStyle
|
||||
}{
|
||||
{"/** JSDoc */", CommentStyleJSDoc},
|
||||
{"/// Doxygen", CommentStyleDoxygen},
|
||||
{"//! Doxygen", CommentStyleDoxygen},
|
||||
{"/* block */", CommentStyleBlock},
|
||||
{"// line", CommentStyleLine},
|
||||
{"# hash", CommentStyleHash},
|
||||
{`"""docstring"""`, CommentStyleDocstring},
|
||||
{`'''docstring'''`, CommentStyleDocstring},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := detectCommentStyle(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// findNodeByKind finds the first node of the given kind.
|
||||
func findNodeByKind(root *sitter.Node, nodeType string) *sitter.Node {
|
||||
if root == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result *sitter.Node
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
if n.Type() == nodeType {
|
||||
result = n
|
||||
return false // stop walking
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func TestCleanBlockComment(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
input: "\n * Line 1\n * Line 2\n ",
|
||||
want: "Line 1\nLine 2",
|
||||
},
|
||||
{
|
||||
input: "Simple",
|
||||
want: "Simple",
|
||||
},
|
||||
{
|
||||
input: "\n\nWith blank lines\n\n",
|
||||
want: "With blank lines",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input[:min(10, len(tt.input))], func(t *testing.T) {
|
||||
got := cleanBlockComment(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("got %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
// Package parser provides Tree-sitter based parsing for multiple languages.
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
"github.com/smacker/go-tree-sitter/c"
|
||||
"github.com/smacker/go-tree-sitter/cpp"
|
||||
"github.com/smacker/go-tree-sitter/golang"
|
||||
"github.com/smacker/go-tree-sitter/html"
|
||||
"github.com/smacker/go-tree-sitter/javascript"
|
||||
"github.com/smacker/go-tree-sitter/python"
|
||||
"github.com/smacker/go-tree-sitter/typescript/typescript"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
// MaxFileSize is the maximum file size we'll parse (10MB).
|
||||
const MaxFileSize = 10 * 1024 * 1024
|
||||
|
||||
// Registry manages Tree-sitter parsers for different languages.
|
||||
type Registry struct {
|
||||
parsers map[protocol.Language]*sitter.Parser
|
||||
cache *lru.Cache[string, *CachedTree]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// CachedTree stores a parsed tree with its metadata.
|
||||
// Content is not stored to reduce memory usage.
|
||||
type CachedTree struct {
|
||||
Tree *sitter.Tree
|
||||
Language protocol.Language
|
||||
}
|
||||
|
||||
// ParseResult contains the result of parsing a file.
|
||||
type ParseResult struct {
|
||||
Tree *sitter.Tree
|
||||
Language protocol.Language
|
||||
Errors []SyntaxError
|
||||
Content []byte
|
||||
}
|
||||
|
||||
// SyntaxError represents a syntax error found during parsing.
|
||||
type SyntaxError struct {
|
||||
Message string
|
||||
NodeType string
|
||||
Location protocol.Location
|
||||
}
|
||||
|
||||
// NewRegistry creates a new parser registry.
|
||||
func NewRegistry() *Registry {
|
||||
// Create LRU cache with capacity of 100 trees
|
||||
cache, err := lru.New[string, *CachedTree](100)
|
||||
if err != nil {
|
||||
// LRU.New only errors if size <= 0, which won't happen here
|
||||
panic(fmt.Sprintf("failed to create LRU cache: %v", err))
|
||||
}
|
||||
|
||||
return &Registry{
|
||||
parsers: make(map[protocol.Language]*sitter.Parser),
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
// getLanguage returns the Tree-sitter language for a given language.
|
||||
func getLanguage(lang protocol.Language) (*sitter.Language, error) {
|
||||
switch lang {
|
||||
case protocol.LangGo:
|
||||
return golang.GetLanguage(), nil
|
||||
case protocol.LangTypeScript:
|
||||
return typescript.GetLanguage(), nil
|
||||
case protocol.LangJavaScript:
|
||||
return javascript.GetLanguage(), nil
|
||||
case protocol.LangPython:
|
||||
return python.GetLanguage(), nil
|
||||
case protocol.LangC:
|
||||
return c.GetLanguage(), nil
|
||||
case protocol.LangCpp:
|
||||
return cpp.GetLanguage(), nil
|
||||
case protocol.LangHTML:
|
||||
return html.GetLanguage(), nil
|
||||
case protocol.LangVue:
|
||||
// Vue SFC files use HTML-like template syntax, so we use the HTML parser
|
||||
return html.GetLanguage(), nil
|
||||
default:
|
||||
return nil, errors.New(errors.ErrInvalidLanguage, fmt.Sprintf("language %s is not supported", lang)).
|
||||
WithContext("language", string(lang)).
|
||||
WithRemediation("Supported languages: Go, TypeScript, JavaScript, Python, C, C++, HTML, Vue")
|
||||
}
|
||||
}
|
||||
|
||||
// GetParser returns a parser for the given language.
|
||||
func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
|
||||
r.mu.RLock()
|
||||
if p, ok := r.parsers[lang]; ok {
|
||||
r.mu.RUnlock()
|
||||
return p, nil
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
// Create new parser
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if p, ok := r.parsers[lang]; ok {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
sitterLang, err := getLanguage(lang)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parser := sitter.NewParser()
|
||||
parser.SetLanguage(sitterLang)
|
||||
r.parsers[lang] = parser
|
||||
|
||||
return parser, nil
|
||||
}
|
||||
|
||||
// Parse parses the given content for the specified language.
|
||||
func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||
// Check file size
|
||||
if len(content) > MaxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||
}
|
||||
|
||||
// Detect binary files
|
||||
if isBinary(content) {
|
||||
return nil, errors.New(errors.ErrParseFailed, "binary file detected").
|
||||
WithContext("file", filename).
|
||||
WithRemediation("This appears to be a binary file and cannot be parsed as source code")
|
||||
}
|
||||
|
||||
// Detect language
|
||||
lang := protocol.DetectLanguage(filename)
|
||||
if lang == protocol.LangUnknown {
|
||||
return nil, errors.New(errors.ErrInvalidLanguage, "could not detect language from filename").
|
||||
WithContext("file", filename).
|
||||
WithRemediation("Ensure file has a recognized extension (e.g., .go, .ts, .py, .c, .cpp, .html, .vue, .json, .yaml)")
|
||||
}
|
||||
|
||||
// Handle YAML and JSON separately (they don't use tree-sitter)
|
||||
switch lang {
|
||||
case protocol.LangYAML:
|
||||
return r.ParseYAML(ctx, filename, content)
|
||||
case protocol.LangJSON:
|
||||
return r.ParseJSON(ctx, filename, content)
|
||||
}
|
||||
|
||||
// Check cache (LRU cache is thread-safe)
|
||||
hash := contentHash(content)
|
||||
if cached, ok := r.cache.Get(hash); ok && cached.Language == lang {
|
||||
errors := extractErrors(cached.Tree.RootNode(), content)
|
||||
return &ParseResult{
|
||||
Tree: cached.Tree,
|
||||
Language: lang,
|
||||
Errors: errors,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get parser
|
||||
parser, err := r.GetParser(lang)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse content - tree-sitter parsers are not thread-safe,
|
||||
// so we need to hold the lock during parsing
|
||||
r.mu.Lock()
|
||||
tree, err := parser.ParseCtx(ctx, nil, content)
|
||||
r.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.NewParseError(string(lang), filename, err)
|
||||
}
|
||||
|
||||
// Extract syntax errors
|
||||
errors := extractErrors(tree.RootNode(), content)
|
||||
|
||||
// Cache result (LRU cache handles eviction automatically)
|
||||
r.cache.Add(hash, &CachedTree{
|
||||
Tree: tree,
|
||||
Language: lang,
|
||||
})
|
||||
|
||||
return &ParseResult{
|
||||
Tree: tree,
|
||||
Language: lang,
|
||||
Errors: errors,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractErrors finds all error nodes in the tree.
|
||||
func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
|
||||
var errors []SyntaxError
|
||||
|
||||
var walk func(n *sitter.Node)
|
||||
walk = func(n *sitter.Node) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if n.IsError() || n.IsMissing() {
|
||||
startPoint := n.StartPoint()
|
||||
nodeType := "ERROR"
|
||||
if n.IsMissing() {
|
||||
nodeType = "MISSING"
|
||||
}
|
||||
|
||||
errors = append(errors, SyntaxError{
|
||||
Location: protocol.Location{
|
||||
Line: int(startPoint.Row) + 1,
|
||||
Column: int(startPoint.Column) + 1,
|
||||
},
|
||||
Message: fmt.Sprintf("syntax error: unexpected %s", n.Type()),
|
||||
NodeType: nodeType,
|
||||
})
|
||||
}
|
||||
|
||||
for i := 0; i < int(n.ChildCount()); i++ {
|
||||
walk(n.Child(i))
|
||||
}
|
||||
}
|
||||
|
||||
walk(node)
|
||||
return errors
|
||||
}
|
||||
|
||||
// contentHash returns a fast hash of the content for caching.
|
||||
// Uses xxHash which is 5-10x faster than SHA256 for non-cryptographic purposes.
|
||||
func contentHash(content []byte) string {
|
||||
h := xxhash.Sum64(content)
|
||||
return fmt.Sprintf("%016x", h)
|
||||
}
|
||||
|
||||
// isBinary checks if content appears to be binary.
|
||||
func isBinary(content []byte) bool {
|
||||
// Check first 8000 bytes for null bytes
|
||||
checkLen := min(8000, len(content))
|
||||
|
||||
for i := range checkLen {
|
||||
if content[i] == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Close closes all parsers and clears the cache.
|
||||
func (r *Registry) Close() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for _, p := range r.parsers {
|
||||
p.Close()
|
||||
}
|
||||
r.parsers = make(map[protocol.Language]*sitter.Parser)
|
||||
|
||||
// Purge LRU cache
|
||||
r.cache.Purge()
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
if r == nil {
|
||||
t.Fatal("expected non-nil registry")
|
||||
}
|
||||
defer r.Close()
|
||||
}
|
||||
|
||||
func TestGetParser(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
tests := []struct {
|
||||
lang protocol.Language
|
||||
wantErr bool
|
||||
}{
|
||||
{protocol.LangGo, false},
|
||||
{protocol.LangTypeScript, false},
|
||||
{protocol.LangJavaScript, false},
|
||||
{protocol.LangPython, false},
|
||||
{protocol.LangC, false},
|
||||
{protocol.LangCpp, false},
|
||||
{protocol.LangHTML, false},
|
||||
{protocol.LangVue, false},
|
||||
{protocol.LangUnknown, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.lang), func(t *testing.T) {
|
||||
parser, err := r.GetParser(tt.lang)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if parser == nil {
|
||||
t.Error("expected non-nil parser")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
content string
|
||||
wantLang protocol.Language
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "go file",
|
||||
filename: "test.go",
|
||||
content: "package main\n\nfunc main() {}\n",
|
||||
wantLang: protocol.LangGo,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "typescript file",
|
||||
filename: "test.ts",
|
||||
content: "function hello(): void {}\n",
|
||||
wantLang: protocol.LangTypeScript,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "react tsx file",
|
||||
filename: "Component.tsx",
|
||||
content: `import React from 'react';\n\nexport const Button: React.FC = () => <button className="btn">Click</button>;`,
|
||||
wantLang: protocol.LangTypeScript,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "react jsx file",
|
||||
filename: "Component.jsx",
|
||||
content: `import React from 'react';\n\nexport const Button = () => <button className="btn">Click</button>;`,
|
||||
wantLang: protocol.LangJavaScript,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "python file",
|
||||
filename: "test.py",
|
||||
content: "def hello():\n pass\n",
|
||||
wantLang: protocol.LangPython,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "html file",
|
||||
filename: "test.html",
|
||||
content: `<!DOCTYPE html><html><head><title>Test</title></head><body><h1 class="text-xl">Hello</h1></body></html>`,
|
||||
wantLang: protocol.LangHTML,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "vue file",
|
||||
filename: "Component.vue",
|
||||
content: `<template><div class="container"><h1>{{ title }}</h1></div></template>`,
|
||||
wantLang: protocol.LangVue,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unknown file",
|
||||
filename: "test.txt",
|
||||
content: "hello world",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, tt.filename, []byte(tt.content))
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Language != tt.wantLang {
|
||||
t.Errorf("expected language %s, got %s", tt.wantLang, result.Language)
|
||||
}
|
||||
|
||||
if result.Tree == nil {
|
||||
t.Error("expected non-nil tree")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWithSyntaxErrors(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
// Invalid Go code
|
||||
content := "package main\n\nfunc main( {}\n" // Missing closing paren
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := r.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should have parsed (tree-sitter is error-tolerant)
|
||||
if result.Tree == nil {
|
||||
t.Error("expected non-nil tree")
|
||||
}
|
||||
|
||||
// Should have detected errors
|
||||
if len(result.Errors) == 0 {
|
||||
t.Error("expected syntax errors to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBinary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content []byte
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "text file",
|
||||
content: []byte("hello world"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "binary with null byte",
|
||||
content: []byte{0x68, 0x65, 0x6c, 0x00, 0x6f},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
content: []byte{},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isBinary(tt.content); got != tt.want {
|
||||
t.Errorf("isBinary() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaching(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := []byte("package main\n\nfunc main() {}\n")
|
||||
ctx := context.Background()
|
||||
|
||||
// Parse once
|
||||
result1, err := r.Parse(ctx, "test.go", content)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
// Parse again with same content
|
||||
result2, err := r.Parse(ctx, "test.go", content)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
// Should return cached tree (same pointer)
|
||||
if result1.Tree != result2.Tree {
|
||||
t.Error("expected cached tree to be returned")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,474 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// ExtractSymbols extracts symbols from a parsed tree.
|
||||
func ExtractSymbols(tree *sitter.Tree, content []byte, lang protocol.Language, filename string) []protocol.Symbol {
|
||||
if tree == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
root := tree.RootNode()
|
||||
if root == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch lang {
|
||||
case protocol.LangGo:
|
||||
return extractGoSymbols(root, content, filename)
|
||||
case protocol.LangTypeScript, protocol.LangJavaScript:
|
||||
return extractJSSymbols(root, content, filename)
|
||||
case protocol.LangPython:
|
||||
return extractPythonSymbols(root, content, filename)
|
||||
case protocol.LangC, protocol.LangCpp:
|
||||
return extractCSymbols(root, content, filename)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// extractGoSymbols extracts symbols from Go code.
|
||||
func extractGoSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
var symbol *protocol.Symbol
|
||||
|
||||
switch n.Type() {
|
||||
case "function_declaration":
|
||||
symbol = extractGoFunction(n, content, filename)
|
||||
case "method_declaration":
|
||||
symbol = extractGoMethod(n, content, filename)
|
||||
case "type_declaration":
|
||||
symbol = extractGoType(n, content, filename)
|
||||
case "const_declaration", "var_declaration":
|
||||
syms := extractGoVarConst(n, content, filename)
|
||||
symbols = append(symbols, syms...)
|
||||
return true
|
||||
}
|
||||
|
||||
if symbol != nil {
|
||||
if doc := ExtractDocComment(n, content, protocol.LangGo); doc != nil {
|
||||
symbol.Doc = FormatDocComment(doc)
|
||||
}
|
||||
symbols = append(symbols, *symbol)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractGoFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolFunction,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractGoMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get receiver type
|
||||
receiver := n.ChildByFieldName("receiver")
|
||||
receiverType := ""
|
||||
if receiver != nil {
|
||||
// Find the type in the receiver
|
||||
WalkTree(receiver, func(node *sitter.Node) bool {
|
||||
if node.Type() == "type_identifier" {
|
||||
receiverType = GetNodeText(node, content)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
name := GetNodeText(nameNode, content)
|
||||
if receiverType != "" {
|
||||
name = "(" + receiverType + ")." + name
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: name,
|
||||
Kind: protocol.SymbolMethod,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractGoType(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
// Find type_spec child
|
||||
for i := 0; i < int(n.NamedChildCount()); i++ {
|
||||
child := n.NamedChild(i)
|
||||
if child != nil && child.Type() == "type_spec" {
|
||||
nameNode := child.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
kind := protocol.SymbolType
|
||||
typeNode := child.ChildByFieldName("type")
|
||||
if typeNode != nil {
|
||||
switch typeNode.Type() {
|
||||
case "struct_type":
|
||||
kind = protocol.SymbolStruct
|
||||
case "interface_type":
|
||||
kind = protocol.SymbolInterface
|
||||
}
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: kind,
|
||||
Location: NodeLocation(child, filename),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractGoVarConst(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
kind := protocol.SymbolVariable
|
||||
if n.Type() == "const_declaration" {
|
||||
kind = protocol.SymbolConstant
|
||||
}
|
||||
|
||||
WalkTree(n, func(node *sitter.Node) bool {
|
||||
if node.Type() == "const_spec" || node.Type() == "var_spec" {
|
||||
nameNode := node.ChildByFieldName("name")
|
||||
if nameNode != nil {
|
||||
symbols = append(symbols, protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: kind,
|
||||
Location: NodeLocation(node, filename),
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
// extractJSSymbols extracts symbols from JavaScript/TypeScript code.
|
||||
func extractJSSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
var symbol *protocol.Symbol
|
||||
|
||||
switch n.Type() {
|
||||
case "function_declaration":
|
||||
symbol = extractJSFunction(n, content, filename)
|
||||
case "class_declaration":
|
||||
symbol = extractJSClass(n, content, filename)
|
||||
case "method_definition":
|
||||
symbol = extractJSMethod(n, content, filename)
|
||||
case "lexical_declaration", "variable_declaration":
|
||||
syms := extractJSVariable(n, content, filename)
|
||||
symbols = append(symbols, syms...)
|
||||
return true
|
||||
case "interface_declaration":
|
||||
symbol = extractTSInterface(n, content, filename)
|
||||
case "type_alias_declaration":
|
||||
symbol = extractTSTypeAlias(n, content, filename)
|
||||
}
|
||||
|
||||
if symbol != nil {
|
||||
if doc := ExtractDocComment(n, content, protocol.LangJavaScript); doc != nil {
|
||||
symbol.Doc = FormatDocComment(doc)
|
||||
}
|
||||
symbols = append(symbols, *symbol)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractJSFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolFunction,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractJSClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolClass,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractJSMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolMethod,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractJSVariable(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(n, func(node *sitter.Node) bool {
|
||||
if node.Type() == "variable_declarator" {
|
||||
nameNode := node.ChildByFieldName("name")
|
||||
if nameNode != nil {
|
||||
symbols = append(symbols, protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolVariable,
|
||||
Location: NodeLocation(node, filename),
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractTSInterface(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolInterface,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractTSTypeAlias(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolType,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
// extractPythonSymbols extracts symbols from Python code.
|
||||
func extractPythonSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
var symbol *protocol.Symbol
|
||||
|
||||
switch n.Type() {
|
||||
case "function_definition":
|
||||
symbol = extractPythonFunction(n, content, filename)
|
||||
case "class_definition":
|
||||
symbol = extractPythonClass(n, content, filename)
|
||||
}
|
||||
|
||||
if symbol != nil {
|
||||
if doc := ExtractDocComment(n, content, protocol.LangPython); doc != nil {
|
||||
symbol.Doc = FormatDocComment(doc)
|
||||
}
|
||||
symbols = append(symbols, *symbol)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractPythonFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if this is a method (inside a class)
|
||||
parent := n.Parent()
|
||||
kind := protocol.SymbolFunction
|
||||
if parent != nil && parent.Type() == "block" {
|
||||
grandparent := parent.Parent()
|
||||
if grandparent != nil && grandparent.Type() == "class_definition" {
|
||||
kind = protocol.SymbolMethod
|
||||
}
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: kind,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractPythonClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolClass,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
// extractCSymbols extracts symbols from C/C++ code.
|
||||
func extractCSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
var symbol *protocol.Symbol
|
||||
|
||||
switch n.Type() {
|
||||
case "function_definition":
|
||||
symbol = extractCFunction(n, content, filename)
|
||||
case "struct_specifier":
|
||||
symbol = extractCStruct(n, content, filename)
|
||||
case "class_specifier":
|
||||
symbol = extractCppClass(n, content, filename)
|
||||
case "declaration":
|
||||
// Could be function declaration or variable
|
||||
if hasFunctionDeclarator(n) {
|
||||
symbol = extractCFunctionDecl(n, content, filename)
|
||||
}
|
||||
}
|
||||
|
||||
if symbol != nil {
|
||||
if doc := ExtractDocComment(n, content, protocol.LangC); doc != nil {
|
||||
symbol.Doc = FormatDocComment(doc)
|
||||
}
|
||||
symbols = append(symbols, *symbol)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractCFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
declarator := n.ChildByFieldName("declarator")
|
||||
if declarator == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find the function name within the declarator
|
||||
var name string
|
||||
WalkTree(declarator, func(node *sitter.Node) bool {
|
||||
if node.Type() == "identifier" {
|
||||
name = GetNodeText(node, content)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: name,
|
||||
Kind: protocol.SymbolFunction,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractCStruct(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolStruct,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractCppClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolClass,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractCFunctionDecl(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
declarator := n.ChildByFieldName("declarator")
|
||||
if declarator == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var name string
|
||||
WalkTree(declarator, func(node *sitter.Node) bool {
|
||||
if node.Type() == "identifier" {
|
||||
name = GetNodeText(node, content)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: name,
|
||||
Kind: protocol.SymbolFunction,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func hasFunctionDeclarator(n *sitter.Node) bool {
|
||||
found := false
|
||||
WalkTree(n, func(node *sitter.Node) bool {
|
||||
if node.Type() == "function_declarator" {
|
||||
found = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestExtractGoSymbols(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := `package main
|
||||
|
||||
// Hello prints a greeting
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
|
||||
// Server handles requests
|
||||
type Server struct {
|
||||
Port int
|
||||
}
|
||||
|
||||
// Start starts the server
|
||||
func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const MaxConnections = 100
|
||||
var globalVar = "test"
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangGo, "test.go")
|
||||
|
||||
expectedSymbols := map[string]protocol.SymbolKind{
|
||||
"Hello": protocol.SymbolFunction,
|
||||
"Server": protocol.SymbolStruct,
|
||||
"(Server).Start": protocol.SymbolMethod,
|
||||
"MaxConnections": protocol.SymbolConstant,
|
||||
"globalVar": protocol.SymbolVariable,
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
for _, sym := range symbols {
|
||||
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||
found[sym.Name] = true
|
||||
if sym.Kind != expectedKind {
|
||||
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name := range expectedSymbols {
|
||||
if !found[name] {
|
||||
t.Errorf("expected to find symbol %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractJSSymbols(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := `
|
||||
function greet(name) {
|
||||
console.log("Hello, " + name);
|
||||
}
|
||||
|
||||
class User {
|
||||
constructor(name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
getName() {
|
||||
return this.name;
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_USERS = 100;
|
||||
let currentUser = null;
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, "test.js", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangJavaScript, "test.js")
|
||||
|
||||
expectedSymbols := map[string]protocol.SymbolKind{
|
||||
"greet": protocol.SymbolFunction,
|
||||
"User": protocol.SymbolClass,
|
||||
"MAX_USERS": protocol.SymbolVariable,
|
||||
"currentUser": protocol.SymbolVariable,
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
for _, sym := range symbols {
|
||||
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||
found[sym.Name] = true
|
||||
if sym.Kind != expectedKind {
|
||||
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name := range expectedSymbols {
|
||||
if !found[name] {
|
||||
t.Errorf("expected to find symbol %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPythonSymbols(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := `
|
||||
def greet(name):
|
||||
"""Greet a person by name."""
|
||||
print(f"Hello, {name}")
|
||||
|
||||
class User:
|
||||
"""Represents a user."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, "test.py", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangPython, "test.py")
|
||||
|
||||
expectedSymbols := map[string]protocol.SymbolKind{
|
||||
"greet": protocol.SymbolFunction,
|
||||
"User": protocol.SymbolClass,
|
||||
"__init__": protocol.SymbolMethod,
|
||||
"get_name": protocol.SymbolMethod,
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
for _, sym := range symbols {
|
||||
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||
found[sym.Name] = true
|
||||
if sym.Kind != expectedKind {
|
||||
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name := range expectedSymbols {
|
||||
if !found[name] {
|
||||
t.Errorf("expected to find symbol %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCSymbols(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := `
|
||||
#include <stdio.h>
|
||||
|
||||
struct Point {
|
||||
int x;
|
||||
int y;
|
||||
};
|
||||
|
||||
void print_point(struct Point p) {
|
||||
printf("(%d, %d)\n", p.x, p.y);
|
||||
}
|
||||
|
||||
int main() {
|
||||
struct Point p = {1, 2};
|
||||
print_point(p);
|
||||
return 0;
|
||||
}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, "test.c", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangC, "test.c")
|
||||
|
||||
// Note: C symbol extraction is complex, checking for at least main and Point
|
||||
expectedSymbols := map[string]protocol.SymbolKind{
|
||||
"Point": protocol.SymbolStruct,
|
||||
"main": protocol.SymbolFunction,
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
for _, sym := range symbols {
|
||||
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||
found[sym.Name] = true
|
||||
if sym.Kind != expectedKind {
|
||||
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name := range expectedSymbols {
|
||||
if !found[name] {
|
||||
t.Errorf("expected to find symbol %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
// Package parser provides YAML and JSON parsing with AST support.
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// YAMLNode wraps yaml.Node to provide tree-sitter-like interface
|
||||
type YAMLNode struct {
|
||||
*yaml.Node
|
||||
Content []byte
|
||||
}
|
||||
|
||||
// JSONNode represents a JSON AST node
|
||||
type JSONNode struct {
|
||||
Value any
|
||||
Type string
|
||||
Children []*JSONNode
|
||||
Line int
|
||||
Column int
|
||||
}
|
||||
|
||||
// ParseYAML parses YAML content and returns a tree-sitter-compatible result
|
||||
func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||
// Check file size
|
||||
if len(content) > MaxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||
}
|
||||
|
||||
// Parse YAML
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(content, &root); err != nil {
|
||||
return nil, errors.NewParseError("yaml", filename, err)
|
||||
}
|
||||
|
||||
// Extract syntax errors from YAML parse
|
||||
syntaxErrors := extractYAMLErrors()
|
||||
|
||||
// Create a pseudo tree-sitter tree for compatibility
|
||||
// We'll use nil for the tree since YAML doesn't use tree-sitter
|
||||
return &ParseResult{
|
||||
Tree: nil, // YAML uses yaml.Node instead
|
||||
Language: protocol.LangYAML,
|
||||
Errors: syntaxErrors,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseJSON parses JSON content and returns a tree-sitter-compatible result
|
||||
func (r *Registry) ParseJSON(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||
// Check file size
|
||||
if len(content) > MaxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||
}
|
||||
|
||||
// Parse JSON to validate syntax
|
||||
var jsonData any
|
||||
if err := json.Unmarshal(content, &jsonData); err != nil {
|
||||
return nil, errors.NewParseError("json", filename, err)
|
||||
}
|
||||
|
||||
// JSON parsing succeeded, no syntax errors
|
||||
return &ParseResult{
|
||||
Tree: nil, // JSON uses native Go structures
|
||||
Language: protocol.LangJSON,
|
||||
Errors: []SyntaxError{},
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractYAMLErrors extracts errors from YAML nodes
|
||||
func extractYAMLErrors() []SyntaxError {
|
||||
// YAML parser already validates during unmarshal
|
||||
// If we got here, there are no syntax errors
|
||||
// However, we could add semantic validation here in the future
|
||||
return []SyntaxError{}
|
||||
}
|
||||
|
||||
// WalkYAML walks a YAML AST and calls fn for each node
|
||||
func WalkYAML(node *yaml.Node, fn func(*yaml.Node) bool) {
|
||||
if node == nil || !fn(node) {
|
||||
return
|
||||
}
|
||||
|
||||
for _, child := range node.Content {
|
||||
WalkYAML(child, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// GetYAMLNodeText returns the text representation of a YAML node
|
||||
func GetYAMLNodeText(node *yaml.Node) string {
|
||||
if node == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch node.Kind {
|
||||
case yaml.DocumentNode:
|
||||
if len(node.Content) > 0 {
|
||||
return GetYAMLNodeText(node.Content[0])
|
||||
}
|
||||
return ""
|
||||
case yaml.MappingNode:
|
||||
return node.Value
|
||||
case yaml.SequenceNode:
|
||||
return node.Value
|
||||
case yaml.ScalarNode:
|
||||
return node.Value
|
||||
case yaml.AliasNode:
|
||||
return node.Value
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// GetYAMLNodeLocation returns the location of a YAML node
|
||||
func GetYAMLNodeLocation(node *yaml.Node) protocol.Location {
|
||||
if node == nil {
|
||||
return protocol.Location{Line: 1, Column: 1}
|
||||
}
|
||||
|
||||
return protocol.Location{
|
||||
Line: node.Line,
|
||||
Column: node.Column,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryYAML performs a simple query on YAML content
|
||||
// Example: "$.metadata.name" to find the name field in metadata
|
||||
func QueryYAML(content []byte, query string) ([]*yaml.Node, error) {
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(content, &root); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML: %w", err)
|
||||
}
|
||||
|
||||
// Simple path-based query implementation
|
||||
// This is a basic implementation - can be extended with more sophisticated queries
|
||||
var results []*yaml.Node
|
||||
|
||||
WalkYAML(&root, func(node *yaml.Node) bool {
|
||||
if node.Value == query || node.Tag == query {
|
||||
results = append(results, node)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// QueryJSON performs a simple query on JSON content
|
||||
func QueryJSON(content []byte, query string) ([]any, error) {
|
||||
var data any
|
||||
if err := json.Unmarshal(content, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON: %w", err)
|
||||
}
|
||||
|
||||
// Basic implementation - can be extended with JSONPath support
|
||||
var results []any
|
||||
|
||||
// For now, just validate that it's valid JSON
|
||||
results = append(results, data)
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ValidateYAML validates YAML content without parsing to full AST
|
||||
func ValidateYAML(content []byte) error {
|
||||
var node yaml.Node
|
||||
if err := yaml.Unmarshal(content, &node); err != nil {
|
||||
return fmt.Errorf("YAML validation failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateJSON validates JSON content
|
||||
func ValidateJSON(content []byte) error {
|
||||
var data any
|
||||
if err := json.Unmarshal(content, &data); err != nil {
|
||||
return fmt.Errorf("JSON validation failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToSitterTree is a placeholder that returns nil for YAML/JSON
|
||||
// These formats don't use tree-sitter, but we keep this for interface compatibility
|
||||
func (yn *YAMLNode) ToSitterTree() *sitter.Tree {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestParseYAML(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid simple YAML",
|
||||
content: `name: test
|
||||
version: 1.0.0
|
||||
enabled: true`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid nested YAML",
|
||||
content: `metadata:
|
||||
name: test-app
|
||||
namespace: default
|
||||
spec:
|
||||
replicas: 3
|
||||
selector:
|
||||
matchLabels:
|
||||
app: test`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid list YAML",
|
||||
content: `items:
|
||||
- name: item1
|
||||
value: 100
|
||||
- name: item2
|
||||
value: 200`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid YAML - bad syntax",
|
||||
content: `name: test\n bad: indent\n wrong: [unclosed`,
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.ParseYAML(context.Background(), "test.yaml", []byte(tt.content))
|
||||
|
||||
if tt.shouldError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Error("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if result.Language != protocol.LangYAML {
|
||||
t.Errorf("expected language YAML, got %s", result.Language)
|
||||
}
|
||||
|
||||
if len(result.Errors) > 0 {
|
||||
t.Errorf("expected no syntax errors, got %d", len(result.Errors))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSON(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid simple JSON",
|
||||
content: `{"name": "test", "version": "1.0.0", "enabled": true}`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid nested JSON",
|
||||
content: `{
|
||||
"metadata": {
|
||||
"name": "test-app",
|
||||
"namespace": "default"
|
||||
},
|
||||
"spec": {
|
||||
"replicas": 3,
|
||||
"selector": {
|
||||
"matchLabels": {
|
||||
"app": "test"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid array JSON",
|
||||
content: `[{"name": "item1", "value": 100}, {"name": "item2", "value": 200}]`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON - unclosed brace",
|
||||
content: `{"name": "test", "value": 100`,
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON - trailing comma",
|
||||
content: `{"name": "test", "value": 100,}`,
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.ParseJSON(context.Background(), "test.json", []byte(tt.content))
|
||||
|
||||
if tt.shouldError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Error("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if result.Language != protocol.LangJSON {
|
||||
t.Errorf("expected language JSON, got %s", result.Language)
|
||||
}
|
||||
|
||||
if len(result.Errors) > 0 {
|
||||
t.Errorf("expected no syntax errors, got %d", len(result.Errors))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryParse_YAML_JSON(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
yamlContent := []byte(`name: test
|
||||
version: 1.0.0`)
|
||||
|
||||
jsonContent := []byte(`{"name": "test", "version": "1.0.0"}`)
|
||||
|
||||
// Test YAML through main Parse method
|
||||
yamlResult, err := registry.Parse(context.Background(), "config.yaml", yamlContent)
|
||||
if err != nil {
|
||||
t.Errorf("failed to parse YAML: %v", err)
|
||||
}
|
||||
if yamlResult.Language != protocol.LangYAML {
|
||||
t.Errorf("expected YAML language, got %s", yamlResult.Language)
|
||||
}
|
||||
|
||||
// Test JSON through main Parse method
|
||||
jsonResult, err := registry.Parse(context.Background(), "config.json", jsonContent)
|
||||
if err != nil {
|
||||
t.Errorf("failed to parse JSON: %v", err)
|
||||
}
|
||||
if jsonResult.Language != protocol.LangJSON {
|
||||
t.Errorf("expected JSON language, got %s", jsonResult.Language)
|
||||
}
|
||||
|
||||
// Test .yml extension
|
||||
ymlResult, err := registry.Parse(context.Background(), "config.yml", yamlContent)
|
||||
if err != nil {
|
||||
t.Errorf("failed to parse .yml: %v", err)
|
||||
}
|
||||
if ymlResult.Language != protocol.LangYAML {
|
||||
t.Errorf("expected YAML language for .yml extension, got %s", ymlResult.Language)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalkYAML(t *testing.T) {
|
||||
content := []byte(`metadata:
|
||||
name: test
|
||||
labels:
|
||||
app: myapp
|
||||
env: prod`)
|
||||
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(content, &root); err != nil {
|
||||
t.Fatalf("failed to parse YAML: %v", err)
|
||||
}
|
||||
|
||||
nodeCount := 0
|
||||
WalkYAML(&root, func(node *yaml.Node) bool {
|
||||
nodeCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if nodeCount == 0 {
|
||||
t.Error("expected to visit nodes, but count is 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content []byte
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid YAML",
|
||||
content: []byte("name: test\nvalue: 100"),
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid YAML",
|
||||
content: []byte("name: test\n bad:\n[unclosed"),
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateYAML(tt.content)
|
||||
if (err != nil) != tt.shouldError {
|
||||
t.Errorf("ValidateYAML() error = %v, shouldError = %v", err, tt.shouldError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content []byte
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid JSON",
|
||||
content: []byte(`{"name": "test", "value": 100}`),
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
content: []byte(`{"name": "test", "value": 100`),
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateJSON(tt.content)
|
||||
if (err != nil) != tt.shouldError {
|
||||
t.Errorf("ValidateJSON() error = %v, shouldError = %v", err, tt.shouldError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,538 @@
|
||||
// Package query implements a hybrid AST query language with pattern matching.
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// Global regex cache for compiled patterns (thread-safe)
|
||||
var regexCache sync.Map // string -> *regexp.Regexp
|
||||
|
||||
// compileRegex compiles a regex pattern with caching for performance.
|
||||
// Cached patterns avoid repeated compilation overhead (10-50x speedup).
|
||||
// Thread-safe: uses LoadOrStore to prevent race conditions.
|
||||
func compileRegex(pattern string) (*regexp.Regexp, error) {
|
||||
// Check cache first
|
||||
if cached, ok := regexCache.Load(pattern); ok {
|
||||
return cached.(*regexp.Regexp), nil
|
||||
}
|
||||
|
||||
// Compile regex
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try to store - if another goroutine already stored it, use theirs
|
||||
// This prevents race conditions where multiple goroutines compile the same pattern
|
||||
actual, _ := regexCache.LoadOrStore(pattern, re)
|
||||
return actual.(*regexp.Regexp), nil
|
||||
}
|
||||
|
||||
// ASTQuery defines a query for matching AST patterns.
|
||||
type ASTQuery struct {
|
||||
Pattern string `json:"pattern"` // code pattern with $VAR placeholders
|
||||
Language string `json:"language"` // required
|
||||
Filters QueryFilters `json:"filters,omitempty"`
|
||||
}
|
||||
|
||||
// QueryFilters provide additional filtering criteria.
|
||||
type QueryFilters struct {
|
||||
HasChild *ASTQuery `json:"has_child,omitempty"`
|
||||
HasParent *ASTQuery `json:"has_parent,omitempty"`
|
||||
NameMatches string `json:"name_matches,omitempty"`
|
||||
NameExact string `json:"name_exact,omitempty"`
|
||||
InFile string `json:"in_file,omitempty"`
|
||||
NotInFile string `json:"not_in_file,omitempty"`
|
||||
KindIn []string `json:"kind_in,omitempty"`
|
||||
}
|
||||
|
||||
// MatchResult represents a single match from a query.
|
||||
type MatchResult struct {
|
||||
Node *sitter.Node
|
||||
Captures map[string]CapturedNode
|
||||
File string
|
||||
Text string
|
||||
Location protocol.Location
|
||||
}
|
||||
|
||||
// CapturedNode represents a captured node or nodes.
|
||||
type CapturedNode struct {
|
||||
Text string
|
||||
Nodes []*sitter.Node
|
||||
}
|
||||
|
||||
// CaptureType indicates the type of capture.
|
||||
type CaptureType int
|
||||
|
||||
const (
|
||||
CaptureSingle CaptureType = iota // $NAME - single node
|
||||
CaptureMultiple // $$$NAME - multiple nodes
|
||||
CaptureWildcard // $_ - wildcard (don't capture)
|
||||
)
|
||||
|
||||
// Capture represents a placeholder in a pattern.
|
||||
type Capture struct {
|
||||
Name string
|
||||
Type CaptureType
|
||||
Position int // position in the pattern
|
||||
}
|
||||
|
||||
// ParsedPattern represents a parsed code pattern.
|
||||
type ParsedPattern struct {
|
||||
Original string
|
||||
Template string
|
||||
Captures []Capture
|
||||
}
|
||||
|
||||
// Matcher performs AST pattern matching.
|
||||
type Matcher struct {
|
||||
registry *parser.Registry
|
||||
}
|
||||
|
||||
// NewMatcher creates a new pattern matcher.
|
||||
func NewMatcher(registry *parser.Registry) *Matcher {
|
||||
return &Matcher{registry: registry}
|
||||
}
|
||||
|
||||
// ParsePattern parses a pattern string and extracts captures.
|
||||
func ParsePattern(pattern string) (*ParsedPattern, error) {
|
||||
if pattern == "" {
|
||||
return nil, fmt.Errorf("empty pattern")
|
||||
}
|
||||
|
||||
var captures []Capture
|
||||
template := pattern
|
||||
captureID := 0
|
||||
|
||||
// Find all captures: $$$ (multi), $_ (wildcard), $NAME (single)
|
||||
// Order matters: check $$$ first
|
||||
multiRe := regexp.MustCompile(`\$\$\$([A-Za-z_][A-Za-z0-9_]*)`)
|
||||
wildcardRe := regexp.MustCompile(`\$_`)
|
||||
singleRe := regexp.MustCompile(`\$([A-Za-z_][A-Za-z0-9_]*)`)
|
||||
|
||||
// Extract multi-node captures ($$$NAME)
|
||||
for _, match := range multiRe.FindAllStringSubmatchIndex(pattern, -1) {
|
||||
name := pattern[match[2]:match[3]]
|
||||
captures = append(captures, Capture{
|
||||
Name: name,
|
||||
Type: CaptureMultiple,
|
||||
Position: match[0],
|
||||
})
|
||||
}
|
||||
|
||||
// Replace multi-captures with placeholder identifiers
|
||||
template = multiRe.ReplaceAllStringFunc(template, func(s string) string {
|
||||
captureID++
|
||||
return fmt.Sprintf("__multi_%d__", captureID)
|
||||
})
|
||||
|
||||
// Extract wildcards ($_)
|
||||
for _, match := range wildcardRe.FindAllStringIndex(pattern, -1) {
|
||||
captures = append(captures, Capture{
|
||||
Name: "_",
|
||||
Type: CaptureWildcard,
|
||||
Position: match[0],
|
||||
})
|
||||
}
|
||||
|
||||
// Replace wildcards with placeholder identifiers
|
||||
template = wildcardRe.ReplaceAllStringFunc(template, func(s string) string {
|
||||
captureID++
|
||||
return fmt.Sprintf("__wild_%d__", captureID)
|
||||
})
|
||||
|
||||
// Extract single-node captures ($NAME) - exclude those that are part of $$$NAME
|
||||
// Check which $NAME patterns are not preceded by $$
|
||||
remaining := template
|
||||
for _, match := range singleRe.FindAllStringSubmatchIndex(remaining, -1) {
|
||||
name := remaining[match[2]:match[3]]
|
||||
// Skip if this looks like our placeholder
|
||||
if strings.HasPrefix(name, "_multi_") || strings.HasPrefix(name, "_wild_") {
|
||||
continue
|
||||
}
|
||||
captures = append(captures, Capture{
|
||||
Name: name,
|
||||
Type: CaptureSingle,
|
||||
Position: match[0],
|
||||
})
|
||||
}
|
||||
|
||||
// Replace single captures with placeholder identifiers
|
||||
template = singleRe.ReplaceAllStringFunc(template, func(s string) string {
|
||||
name := strings.TrimPrefix(s, "$")
|
||||
if strings.HasPrefix(name, "_multi_") || strings.HasPrefix(name, "_wild_") {
|
||||
return s // keep our placeholders as is
|
||||
}
|
||||
captureID++
|
||||
return fmt.Sprintf("__single_%d__", captureID)
|
||||
})
|
||||
|
||||
return &ParsedPattern{
|
||||
Original: pattern,
|
||||
Captures: captures,
|
||||
Template: template,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Match executes a query against a parsed tree.
|
||||
func (m *Matcher) Match(ctx context.Context, query *ASTQuery, tree *sitter.Tree, content []byte, filename string) ([]MatchResult, error) {
|
||||
if query.Pattern == "" {
|
||||
return nil, fmt.Errorf("query pattern is required")
|
||||
}
|
||||
|
||||
lang := protocol.Language(query.Language)
|
||||
if lang == "" || lang == protocol.LangUnknown {
|
||||
return nil, fmt.Errorf("valid language is required")
|
||||
}
|
||||
|
||||
// Parse the pattern
|
||||
parsed, err := ParsePattern(query.Pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid pattern: %w", err)
|
||||
}
|
||||
|
||||
var results []MatchResult
|
||||
|
||||
// Walk the tree and find matches
|
||||
root := tree.RootNode()
|
||||
if root == nil {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
parser.WalkTree(root, func(n *sitter.Node) bool {
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
default:
|
||||
}
|
||||
|
||||
// Try to match this node against the pattern
|
||||
if matched, captures := matchNode(n, parsed, content); matched {
|
||||
// Apply filters
|
||||
if !passesFilters(n, query.Filters, content) {
|
||||
return true // continue walking
|
||||
}
|
||||
|
||||
startPoint := n.StartPoint()
|
||||
results = append(results, MatchResult{
|
||||
Node: n,
|
||||
Captures: captures,
|
||||
File: filename,
|
||||
Location: protocol.Location{
|
||||
Line: int(startPoint.Row) + 1,
|
||||
Column: int(startPoint.Column) + 1,
|
||||
},
|
||||
Text: parser.GetNodeText(n, content),
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// matchNode attempts to match a node against a parsed pattern.
|
||||
// This is a simplified matcher that looks for structural similarity.
|
||||
func matchNode(node *sitter.Node, pattern *ParsedPattern, content []byte) (bool, map[string]CapturedNode) {
|
||||
if node == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
captures := make(map[string]CapturedNode)
|
||||
|
||||
// Use pattern keyword matching as a heuristic to find matching nodes
|
||||
// A full implementation would parse both pattern and node and compare AST structure
|
||||
matched := matchPatternHeuristic(node, pattern, content, captures)
|
||||
|
||||
return matched, captures
|
||||
}
|
||||
|
||||
// matchPatternHeuristic uses heuristics to match patterns.
|
||||
// This is a simplified implementation that matches based on node type and structure.
|
||||
func matchPatternHeuristic(node *sitter.Node, pattern *ParsedPattern, content []byte, captures map[string]CapturedNode) bool {
|
||||
patternLower := strings.ToLower(pattern.Original)
|
||||
nodeType := node.Type()
|
||||
|
||||
// Match function patterns
|
||||
if strings.Contains(patternLower, "func ") || strings.Contains(patternLower, "function ") {
|
||||
if nodeType != "function_declaration" && nodeType != "method_declaration" && nodeType != "function_definition" {
|
||||
return false
|
||||
}
|
||||
extractFunctionCaptures(node, pattern.Captures, content, captures)
|
||||
return true
|
||||
}
|
||||
|
||||
// Match class patterns
|
||||
if strings.Contains(patternLower, "class ") {
|
||||
if nodeType != "class_declaration" && nodeType != "class_definition" {
|
||||
return false
|
||||
}
|
||||
extractClassCaptures(node, pattern.Captures, content, captures)
|
||||
return true
|
||||
}
|
||||
|
||||
// Match struct patterns (Go, C, C++)
|
||||
if strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ") && strings.Contains(patternLower, "struct") {
|
||||
if nodeType != "type_declaration" && nodeType != "struct_specifier" {
|
||||
return false
|
||||
}
|
||||
extractStructCaptures(node, pattern.Captures, content, captures)
|
||||
return true
|
||||
}
|
||||
|
||||
// Match interface patterns (Go, TypeScript)
|
||||
if strings.Contains(patternLower, "interface ") {
|
||||
if nodeType != "interface_declaration" && nodeType != "type_declaration" {
|
||||
return false
|
||||
}
|
||||
extractInterfaceCaptures(node, pattern.Captures, content, captures)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// extractFunctionCaptures extracts captures from a function node.
|
||||
func extractFunctionCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||
for _, cap := range capturesDef {
|
||||
switch cap.Name {
|
||||
case "NAME", "name":
|
||||
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
case "ARGS", "args", "PARAMS", "params":
|
||||
if paramsNode := node.ChildByFieldName("parameters"); paramsNode != nil {
|
||||
var paramNodes []*sitter.Node
|
||||
for i := 0; i < int(paramsNode.NamedChildCount()); i++ {
|
||||
paramNodes = append(paramNodes, paramsNode.NamedChild(i))
|
||||
}
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: paramNodes,
|
||||
Text: parser.GetNodeText(paramsNode, content),
|
||||
}
|
||||
}
|
||||
case "BODY", "body":
|
||||
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{bodyNode},
|
||||
Text: parser.GetNodeText(bodyNode, content),
|
||||
}
|
||||
}
|
||||
case "RETURN", "return", "RESULT", "result":
|
||||
if resultNode := node.ChildByFieldName("result"); resultNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{resultNode},
|
||||
Text: parser.GetNodeText(resultNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractClassCaptures extracts captures from a class node.
|
||||
func extractClassCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||
for _, cap := range capturesDef {
|
||||
switch cap.Name {
|
||||
case "NAME", "name":
|
||||
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
case "BODY", "body":
|
||||
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{bodyNode},
|
||||
Text: parser.GetNodeText(bodyNode, content),
|
||||
}
|
||||
}
|
||||
case "EXTENDS", "extends", "SUPERCLASS", "superclass":
|
||||
if extendsNode := node.ChildByFieldName("superclass"); extendsNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{extendsNode},
|
||||
Text: parser.GetNodeText(extendsNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractStructCaptures extracts captures from a struct node.
|
||||
func extractStructCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||
for _, cap := range capturesDef {
|
||||
switch cap.Name {
|
||||
case "NAME", "name":
|
||||
// For Go type_declaration, we need to look at the type_spec child
|
||||
if node.Type() == "type_declaration" {
|
||||
for i := 0; i < int(node.NamedChildCount()); i++ {
|
||||
child := node.NamedChild(i)
|
||||
if child != nil && child.Type() == "type_spec" {
|
||||
if nameNode := child.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
case "FIELDS", "fields", "BODY", "body":
|
||||
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{bodyNode},
|
||||
Text: parser.GetNodeText(bodyNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractInterfaceCaptures extracts captures from an interface node.
|
||||
func extractInterfaceCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||
for _, cap := range capturesDef {
|
||||
switch cap.Name {
|
||||
case "NAME", "name":
|
||||
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
case "BODY", "body", "METHODS", "methods":
|
||||
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{bodyNode},
|
||||
Text: parser.GetNodeText(bodyNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// passesFilters checks if a node passes all the specified filters.
|
||||
func passesFilters(node *sitter.Node, filters QueryFilters, content []byte) bool {
|
||||
// Name regex filter (uses cached compilation)
|
||||
if filters.NameMatches != "" {
|
||||
nameNode := node.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return false
|
||||
}
|
||||
name := parser.GetNodeText(nameNode, content)
|
||||
re, err := compileRegex(filters.NameMatches)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !re.MatchString(name) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Exact name filter
|
||||
if filters.NameExact != "" {
|
||||
nameNode := node.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return false
|
||||
}
|
||||
name := parser.GetNodeText(nameNode, content)
|
||||
if name != filters.NameExact {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Kind filter
|
||||
if len(filters.KindIn) > 0 {
|
||||
nodeType := node.Type()
|
||||
found := false
|
||||
for _, kind := range filters.KindIn {
|
||||
if nodeType == kind {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// FormatResults formats match results for display.
|
||||
func FormatResults(results []MatchResult, maxResults int) string {
|
||||
if len(results) == 0 {
|
||||
return "No matches found."
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Found %d match(es):\n\n", len(results)))
|
||||
|
||||
displayCount := len(results)
|
||||
truncated := false
|
||||
if maxResults > 0 && displayCount > maxResults {
|
||||
displayCount = maxResults
|
||||
truncated = true
|
||||
}
|
||||
|
||||
for i := 0; i < displayCount; i++ {
|
||||
r := results[i]
|
||||
nodeType := "unknown"
|
||||
if r.Node != nil {
|
||||
nodeType = r.Node.Type()
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType))
|
||||
|
||||
// Truncate very long text
|
||||
text := r.Text
|
||||
if len(text) > 500 {
|
||||
text = text[:500] + "..."
|
||||
}
|
||||
sb.WriteString("```\n")
|
||||
sb.WriteString(text)
|
||||
sb.WriteString("\n```\n")
|
||||
|
||||
// Show captures
|
||||
if len(r.Captures) > 0 {
|
||||
sb.WriteString("Captures: ")
|
||||
first := true
|
||||
for name, cap := range r.Captures {
|
||||
if !first {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
first = false
|
||||
capText := cap.Text
|
||||
if len(capText) > 50 {
|
||||
capText = capText[:50] + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("$%s=%s", name, capText))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
if truncated {
|
||||
sb.WriteString(fmt.Sprintf("... and %d more matches (truncated)\n", len(results)-maxResults))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,559 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestParsePattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
captureNames []string
|
||||
captureTypes []CaptureType
|
||||
wantCaptures int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty pattern",
|
||||
pattern: "",
|
||||
wantErr: true,
|
||||
wantCaptures: 0,
|
||||
},
|
||||
{
|
||||
name: "single capture",
|
||||
pattern: "func $NAME() {}",
|
||||
wantErr: false,
|
||||
wantCaptures: 1,
|
||||
captureNames: []string{"NAME"},
|
||||
captureTypes: []CaptureType{CaptureSingle},
|
||||
},
|
||||
{
|
||||
name: "multiple single captures",
|
||||
pattern: "func $NAME($ARGS) $RETURN",
|
||||
wantErr: false,
|
||||
wantCaptures: 3,
|
||||
captureNames: []string{"NAME", "ARGS", "RETURN"},
|
||||
captureTypes: []CaptureType{CaptureSingle, CaptureSingle, CaptureSingle},
|
||||
},
|
||||
{
|
||||
name: "multi-node capture",
|
||||
pattern: "func $NAME($$$ARGS) { $$$BODY }",
|
||||
wantErr: false,
|
||||
wantCaptures: 3,
|
||||
captureNames: []string{"ARGS", "BODY", "NAME"},
|
||||
captureTypes: []CaptureType{CaptureMultiple, CaptureMultiple, CaptureSingle},
|
||||
},
|
||||
{
|
||||
name: "wildcard capture",
|
||||
pattern: "func $NAME($_) {}",
|
||||
wantErr: false,
|
||||
wantCaptures: 2,
|
||||
captureNames: []string{"NAME", "_"},
|
||||
captureTypes: []CaptureType{CaptureSingle, CaptureWildcard},
|
||||
},
|
||||
{
|
||||
name: "no captures",
|
||||
pattern: "func main() {}",
|
||||
wantErr: false,
|
||||
wantCaptures: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsed, err := ParsePattern(tt.pattern)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(parsed.Captures) != tt.wantCaptures {
|
||||
t.Errorf("expected %d captures, got %d", tt.wantCaptures, len(parsed.Captures))
|
||||
}
|
||||
|
||||
// Check capture names (order may vary)
|
||||
if tt.captureNames != nil {
|
||||
captureMap := make(map[string]CaptureType)
|
||||
for _, cap := range parsed.Captures {
|
||||
captureMap[cap.Name] = cap.Type
|
||||
}
|
||||
|
||||
for i, name := range tt.captureNames {
|
||||
if _, ok := captureMap[name]; !ok {
|
||||
t.Errorf("expected capture %s not found", name)
|
||||
}
|
||||
if captureMap[name] != tt.captureTypes[i] {
|
||||
t.Errorf("capture %s: expected type %v, got %v", name, tt.captureTypes[i], captureMap[name])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchGoFunctions(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
|
||||
func Greet(name string) error {
|
||||
println("hello", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
Port int
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantMatches int
|
||||
}{
|
||||
{
|
||||
name: "match all functions",
|
||||
query: &ASTQuery{
|
||||
Pattern: "func $NAME($$$ARGS) { $$$BODY }",
|
||||
Language: "go",
|
||||
},
|
||||
wantMatches: 3, // Hello, Greet, Start
|
||||
},
|
||||
{
|
||||
name: "match functions starting with H",
|
||||
query: &ASTQuery{
|
||||
Pattern: "func $NAME() {}",
|
||||
Language: "go",
|
||||
Filters: QueryFilters{
|
||||
NameMatches: "^H",
|
||||
},
|
||||
},
|
||||
wantMatches: 1, // Hello
|
||||
},
|
||||
{
|
||||
name: "match specific function",
|
||||
query: &ASTQuery{
|
||||
Pattern: "func $NAME() {}",
|
||||
Language: "go",
|
||||
Filters: QueryFilters{
|
||||
NameExact: "Hello",
|
||||
},
|
||||
},
|
||||
wantMatches: 1, // Hello
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != tt.wantMatches {
|
||||
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||
for i, r := range results {
|
||||
t.Logf("match %d: %s at line %d", i, r.Node.Type(), r.Location.Line)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchGoStructs(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `package main
|
||||
|
||||
type Server struct {
|
||||
Port int
|
||||
Host string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Timeout int
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Log(msg string)
|
||||
}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantMinimum int
|
||||
}{
|
||||
{
|
||||
name: "match all structs",
|
||||
query: &ASTQuery{
|
||||
Pattern: "type $NAME struct { $$$FIELDS }",
|
||||
Language: "go",
|
||||
},
|
||||
wantMinimum: 2, // Server, Config (may also match interface as type_declaration)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) < tt.wantMinimum {
|
||||
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchJSFunctions(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `
|
||||
function greet(name) {
|
||||
console.log("Hello, " + name);
|
||||
}
|
||||
|
||||
function sayHello() {
|
||||
console.log("Hello!");
|
||||
}
|
||||
|
||||
class User {
|
||||
constructor(name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
getName() {
|
||||
return this.name;
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.js", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantMatches int
|
||||
}{
|
||||
{
|
||||
name: "match all functions",
|
||||
query: &ASTQuery{
|
||||
Pattern: "function $NAME($$$ARGS) { $$$BODY }",
|
||||
Language: "javascript",
|
||||
},
|
||||
wantMatches: 2, // greet, sayHello
|
||||
},
|
||||
{
|
||||
name: "match classes",
|
||||
query: &ASTQuery{
|
||||
Pattern: "class $NAME { $$$BODY }",
|
||||
Language: "javascript",
|
||||
},
|
||||
wantMatches: 1, // User
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.js")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != tt.wantMatches {
|
||||
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchPythonSymbols(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `
|
||||
def greet(name):
|
||||
print(f"Hello, {name}")
|
||||
|
||||
def calculate(a, b):
|
||||
return a + b
|
||||
|
||||
class User:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.py", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantMinimum int
|
||||
}{
|
||||
{
|
||||
name: "match classes",
|
||||
query: &ASTQuery{
|
||||
Pattern: "class $NAME: $$$BODY",
|
||||
Language: "python",
|
||||
},
|
||||
wantMinimum: 1, // User
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.py")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) < tt.wantMinimum {
|
||||
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryFilters(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `package main
|
||||
|
||||
func HelloWorld() {}
|
||||
func helloWorld() {}
|
||||
func GoodbyeWorld() {}
|
||||
func Main() {}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filters QueryFilters
|
||||
wantMatches int
|
||||
}{
|
||||
{
|
||||
name: "regex filter - starts with H",
|
||||
filters: QueryFilters{
|
||||
NameMatches: "^[Hh]ello",
|
||||
},
|
||||
wantMatches: 2, // HelloWorld, helloWorld
|
||||
},
|
||||
{
|
||||
name: "exact name filter",
|
||||
filters: QueryFilters{
|
||||
NameExact: "Main",
|
||||
},
|
||||
wantMatches: 1, // Main
|
||||
},
|
||||
{
|
||||
name: "kind filter",
|
||||
filters: QueryFilters{
|
||||
KindIn: []string{"function_declaration"},
|
||||
},
|
||||
wantMatches: 4, // all functions
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
query := &ASTQuery{
|
||||
Pattern: "func $NAME() {}",
|
||||
Language: "go",
|
||||
Filters: tt.filters,
|
||||
}
|
||||
|
||||
results, err := matcher.Match(ctx, query, result.Tree, []byte(content), "test.go")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != tt.wantMatches {
|
||||
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||
for _, r := range results {
|
||||
if nameNode := r.Node.ChildByFieldName("name"); nameNode != nil {
|
||||
t.Logf("matched: %s", parser.GetNodeText(nameNode, []byte(content)))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
results []MatchResult
|
||||
maxResults int
|
||||
wantEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "empty results",
|
||||
results: []MatchResult{},
|
||||
maxResults: 100,
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "single result",
|
||||
results: []MatchResult{
|
||||
{
|
||||
File: "test.go",
|
||||
Location: protocol.Location{Line: 10, Column: 1},
|
||||
Text: "func Hello() {}",
|
||||
Captures: map[string]CapturedNode{
|
||||
"NAME": {Text: "Hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
maxResults: 100,
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "truncated results",
|
||||
results: []MatchResult{
|
||||
{File: "a.go", Location: protocol.Location{Line: 1}, Text: "func A() {}"},
|
||||
{File: "b.go", Location: protocol.Location{Line: 1}, Text: "func B() {}"},
|
||||
{File: "c.go", Location: protocol.Location{Line: 1}, Text: "func C() {}"},
|
||||
},
|
||||
maxResults: 2,
|
||||
wantEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
output := FormatResults(tt.results, tt.maxResults)
|
||||
|
||||
if tt.wantEmpty {
|
||||
if output != "No matches found." {
|
||||
t.Errorf("expected 'No matches found.', got: %s", output)
|
||||
}
|
||||
} else {
|
||||
if output == "No matches found." {
|
||||
t.Error("expected results, got 'No matches found.'")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryValidation(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
ctx := context.Background()
|
||||
|
||||
// Parse some valid content
|
||||
content := `package main
|
||||
func main() {}
|
||||
`
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty pattern",
|
||||
query: &ASTQuery{Pattern: "", Language: "go"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing language",
|
||||
query: &ASTQuery{Pattern: "func $NAME() {}", Language: ""},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown language",
|
||||
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "unknown"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid query",
|
||||
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "go"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCompileRegexCaching tests that regex compilation is cached.
|
||||
func TestCompileRegexCaching(t *testing.T) {
|
||||
// Clear cache before test
|
||||
regexCache = sync.Map{}
|
||||
|
||||
pattern := `^test_\w+$`
|
||||
|
||||
// First compilation
|
||||
re1, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("First compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Second compilation should return cached version
|
||||
re2, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("Second compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Should be the exact same object
|
||||
if re1 != re2 {
|
||||
t.Error("Expected cached regex to be reused, got different objects")
|
||||
}
|
||||
|
||||
// Verify it's in the cache
|
||||
cached, ok := regexCache.Load(pattern)
|
||||
if !ok {
|
||||
t.Error("Pattern not found in cache")
|
||||
}
|
||||
|
||||
if cached.(*regexp.Regexp) != re1 {
|
||||
t.Error("Cached regex doesn't match returned regex")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompileRegexConcurrent tests concurrent regex compilation.
|
||||
func TestCompileRegexConcurrent(t *testing.T) {
|
||||
// Clear cache before test
|
||||
regexCache = sync.Map{}
|
||||
|
||||
pattern := `[a-z]+_\d+`
|
||||
const numGoroutines = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
results := make([]*regexp.Regexp, numGoroutines)
|
||||
errors := make(chan error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
re, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
results[i] = re
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
for err := range errors {
|
||||
t.Errorf("Concurrent compile failed: %v", err)
|
||||
}
|
||||
|
||||
// All results should be the same object (cached)
|
||||
for i := 1; i < numGoroutines; i++ {
|
||||
if results[i] != results[0] {
|
||||
t.Errorf("Result %d is different from result 0 (cache not working)", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompileRegexInvalidPattern tests error handling for invalid patterns.
|
||||
func TestCompileRegexInvalidPattern(t *testing.T) {
|
||||
// Clear cache before test
|
||||
regexCache = sync.Map{}
|
||||
|
||||
invalidPattern := `[invalid(`
|
||||
|
||||
_, err := compileRegex(invalidPattern)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid pattern, got nil")
|
||||
}
|
||||
|
||||
// Invalid patterns should not be cached
|
||||
_, ok := regexCache.Load(invalidPattern)
|
||||
if ok {
|
||||
t.Error("Invalid pattern should not be cached")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompileRegexMultiplePatterns tests that different patterns are cached separately.
|
||||
func TestCompileRegexMultiplePatterns(t *testing.T) {
|
||||
// Clear cache before test
|
||||
regexCache = sync.Map{}
|
||||
|
||||
patterns := []string{
|
||||
`^test_\w+$`,
|
||||
`^\d{4}-\d{2}-\d{2}$`,
|
||||
`^[A-Z][a-z]+$`,
|
||||
`\b\w+@\w+\.\w+\b`,
|
||||
}
|
||||
|
||||
compiled := make([]*regexp.Regexp, len(patterns))
|
||||
|
||||
// Compile all patterns
|
||||
for i, pattern := range patterns {
|
||||
re, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("Compile failed for pattern %s: %v", pattern, err)
|
||||
}
|
||||
compiled[i] = re
|
||||
}
|
||||
|
||||
// Verify all are cached
|
||||
for i, pattern := range patterns {
|
||||
cached, ok := regexCache.Load(pattern)
|
||||
if !ok {
|
||||
t.Errorf("Pattern %s not in cache", pattern)
|
||||
}
|
||||
|
||||
if cached.(*regexp.Regexp) != compiled[i] {
|
||||
t.Errorf("Cached regex for %s doesn't match compiled version", pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// All should be different objects
|
||||
for i := 0; i < len(compiled); i++ {
|
||||
for j := i + 1; j < len(compiled); j++ {
|
||||
if compiled[i] == compiled[j] {
|
||||
t.Errorf("Pattern %d and %d have same regex object", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompileRegex_Uncached benchmarks regex compilation without caching.
|
||||
func BenchmarkCompileRegex_Uncached(b *testing.B) {
|
||||
pattern := `^\w+_[0-9]{3,5}_[a-zA-Z]+$`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = regexp.Compile(pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompileRegex_Cached benchmarks regex compilation with caching.
|
||||
func BenchmarkCompileRegex_Cached(b *testing.B) {
|
||||
// Clear cache
|
||||
regexCache = sync.Map{}
|
||||
|
||||
pattern := `^\w+_[0-9]{3,5}_[a-zA-Z]+$`
|
||||
|
||||
// Pre-populate cache
|
||||
_, _ = compileRegex(pattern)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = compileRegex(pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompileRegex_MixedPatterns benchmarks realistic workload with multiple patterns.
|
||||
func BenchmarkCompileRegex_MixedPatterns(b *testing.B) {
|
||||
// Clear cache
|
||||
regexCache = sync.Map{}
|
||||
|
||||
patterns := []string{
|
||||
`^test_\w+$`,
|
||||
`^\d{4}-\d{2}-\d{2}$`,
|
||||
`^[A-Z][a-z]+$`,
|
||||
`\b\w+@\w+\.\w+\b`,
|
||||
`^func\s+\w+\(`,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simulate realistic access pattern
|
||||
pattern := patterns[i%len(patterns)]
|
||||
_, _ = compileRegex(pattern)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,401 @@
|
||||
// Package search provides text search functionality using ripgrep.
|
||||
package search
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
// Searcher provides text search functionality using ripgrep.
|
||||
type Searcher struct {
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
rgPath string
|
||||
}
|
||||
|
||||
// Request represents a search request.
|
||||
type Request struct {
|
||||
Pattern string
|
||||
Paths []string
|
||||
FileTypes []string
|
||||
ContextLines int
|
||||
MaxResults int
|
||||
IgnoreCase bool
|
||||
Regex bool
|
||||
IncludeHidden bool
|
||||
FollowSymlinks bool
|
||||
}
|
||||
|
||||
// Result represents a single search result.
|
||||
type Result struct {
|
||||
File string `json:"file"`
|
||||
MatchText string `json:"match_text"`
|
||||
Language protocol.Language `json:"language"`
|
||||
Context ContextLines `json:"context"`
|
||||
Line int `json:"line"`
|
||||
Column int `json:"column"`
|
||||
}
|
||||
|
||||
// ContextLines holds lines before and after a match.
|
||||
type ContextLines struct {
|
||||
Before []string `json:"before"`
|
||||
After []string `json:"after"`
|
||||
}
|
||||
|
||||
// SearchResults holds the complete search results.
|
||||
type SearchResults struct {
|
||||
Results []Result `json:"results"`
|
||||
Truncated bool `json:"truncated"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
// ripgrep JSON output types
|
||||
type rgMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
type rgMatch struct {
|
||||
Path struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"path"`
|
||||
Lines struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"lines"`
|
||||
Submatches []struct {
|
||||
Match struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"match"`
|
||||
Start int `json:"start"`
|
||||
End int `json:"end"`
|
||||
} `json:"submatches"`
|
||||
LineNumber int `json:"line_number"`
|
||||
AbsoluteOffset int `json:"absolute_offset"`
|
||||
}
|
||||
|
||||
type rgContext struct {
|
||||
Path struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"path"`
|
||||
Lines struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"lines"`
|
||||
LineNumber int `json:"line_number"`
|
||||
}
|
||||
|
||||
type rgSummary struct {
|
||||
ElapsedTotal struct {
|
||||
Secs int `json:"secs"`
|
||||
Nanos int `json:"nanos"`
|
||||
} `json:"elapsed_total"`
|
||||
Stats struct {
|
||||
Searches int `json:"searches"`
|
||||
SearchesWithMatch int `json:"searches_with_match"`
|
||||
BytesSearched int64 `json:"bytes_searched"`
|
||||
BytesPrinted int64 `json:"bytes_printed"`
|
||||
MatchedLines int `json:"matched_lines"`
|
||||
Matches int `json:"matches"`
|
||||
} `json:"stats"`
|
||||
}
|
||||
|
||||
// New creates a new Searcher instance.
|
||||
func New(cfg *config.Config, logger *slog.Logger) (*Searcher, error) {
|
||||
// Detect ripgrep binary
|
||||
rgPath, err := exec.LookPath("rg")
|
||||
if err != nil {
|
||||
return nil, errors.NewRipgrepNotFound()
|
||||
}
|
||||
|
||||
return &Searcher{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
rgPath: rgPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Search executes a search and returns results.
|
||||
func (s *Searcher) Search(ctx context.Context, req *Request) (*SearchResults, error) {
|
||||
if req.Pattern == "" {
|
||||
return nil, errors.New(errors.ErrInvalidPattern, "pattern cannot be empty").
|
||||
WithRemediation("Provide a non-empty search pattern")
|
||||
}
|
||||
|
||||
// Build ripgrep command
|
||||
args := s.buildArgs(req)
|
||||
|
||||
s.logger.Debug("executing ripgrep", "args", args)
|
||||
|
||||
// Create command with timeout
|
||||
ctx, cancel := context.WithTimeout(ctx, s.cfg.SearchTimeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, s.rgPath, args...) // #nosec G204 - rgPath is validated at initialization
|
||||
|
||||
// Set working directory to workspace root
|
||||
cmd.Dir = s.cfg.WorkspaceRoot
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
// Run command - ripgrep returns exit code 1 for no matches, which is not an error
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return nil, errors.NewSearchTimeout(req.Pattern, s.cfg.SearchTimeout.String())
|
||||
}
|
||||
// Exit code 1 means no matches, which is fine
|
||||
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
|
||||
return &SearchResults{Results: []Result{}, Total: 0}, nil
|
||||
}
|
||||
// Exit code 2 means error
|
||||
if stderr.Len() > 0 {
|
||||
return nil, errors.Wrap(errors.ErrSearchFailed, "ripgrep search failed", err).
|
||||
WithContext("pattern", req.Pattern).
|
||||
WithContext("stderr", stderr.String()).
|
||||
WithRemediation("Check search pattern syntax and ensure files are readable")
|
||||
}
|
||||
return nil, errors.Wrap(errors.ErrSearchFailed, "ripgrep search failed", err).
|
||||
WithContext("pattern", req.Pattern).
|
||||
WithRemediation("Check search pattern syntax and ensure ripgrep is functioning correctly")
|
||||
}
|
||||
|
||||
// Parse JSON output
|
||||
return s.parseOutput(&stdout, req.MaxResults)
|
||||
}
|
||||
|
||||
// buildArgs builds the ripgrep command arguments.
|
||||
func (s *Searcher) buildArgs(req *Request) []string {
|
||||
args := []string{"--json"}
|
||||
|
||||
// Add context lines
|
||||
if req.ContextLines > 0 {
|
||||
args = append(args, fmt.Sprintf("--context=%d", req.ContextLines))
|
||||
}
|
||||
|
||||
// File type filtering
|
||||
for _, ft := range req.FileTypes {
|
||||
args = append(args, "--type", ft)
|
||||
}
|
||||
|
||||
// Case sensitivity
|
||||
if req.IgnoreCase {
|
||||
args = append(args, "--ignore-case")
|
||||
}
|
||||
|
||||
// Fixed strings (non-regex)
|
||||
if !req.Regex {
|
||||
args = append(args, "--fixed-strings")
|
||||
}
|
||||
|
||||
// Follow symlinks
|
||||
if req.FollowSymlinks || s.cfg.FollowSymlinks {
|
||||
args = append(args, "--follow")
|
||||
}
|
||||
|
||||
// Include hidden files
|
||||
if req.IncludeHidden {
|
||||
args = append(args, "--hidden")
|
||||
}
|
||||
|
||||
// Respect .gitignore (default behavior for rg)
|
||||
if !s.cfg.RespectGitignore {
|
||||
args = append(args, "--no-ignore")
|
||||
}
|
||||
|
||||
// Max count per file to limit results
|
||||
if req.MaxResults > 0 {
|
||||
args = append(args, fmt.Sprintf("--max-count=%d", req.MaxResults))
|
||||
}
|
||||
|
||||
// Add pattern
|
||||
args = append(args, "--", req.Pattern)
|
||||
|
||||
// Add paths (default to current directory which is workspace root)
|
||||
if len(req.Paths) > 0 {
|
||||
for _, p := range req.Paths {
|
||||
// Validate path is within workspace
|
||||
if s.cfg.IsPathAllowed(p) {
|
||||
args = append(args, p)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
args = append(args, ".")
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// parseOutput parses ripgrep JSON output.
|
||||
func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchResults, error) {
|
||||
results := &SearchResults{
|
||||
Results: []Result{},
|
||||
}
|
||||
|
||||
// Track context by file and line
|
||||
contextBefore := make(map[string][]string) // file -> lines before current match
|
||||
currentFile := ""
|
||||
|
||||
scanner := bufio.NewScanner(output)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var msg rgMessage
|
||||
if err := json.Unmarshal(line, &msg); err != nil {
|
||||
s.logger.Debug("failed to parse ripgrep output line", "error", err, "line", string(line))
|
||||
continue
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case "match":
|
||||
var match rgMatch
|
||||
if err := json.Unmarshal(msg.Data, &match); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check max results
|
||||
if maxResults > 0 && len(results.Results) >= maxResults {
|
||||
results.Truncated = true
|
||||
continue
|
||||
}
|
||||
|
||||
result := Result{
|
||||
File: match.Path.Text,
|
||||
Line: match.LineNumber,
|
||||
MatchText: strings.TrimRight(match.Lines.Text, "\n\r"),
|
||||
Language: protocol.DetectLanguage(match.Path.Text),
|
||||
}
|
||||
|
||||
// Add column from first submatch
|
||||
if len(match.Submatches) > 0 {
|
||||
result.Column = match.Submatches[0].Start + 1 // 1-indexed
|
||||
}
|
||||
|
||||
// Add context before
|
||||
if ctx, ok := contextBefore[match.Path.Text]; ok {
|
||||
result.Context.Before = ctx
|
||||
delete(contextBefore, match.Path.Text)
|
||||
}
|
||||
|
||||
results.Results = append(results.Results, result)
|
||||
currentFile = match.Path.Text
|
||||
|
||||
case "context":
|
||||
var ctx rgContext
|
||||
if err := json.Unmarshal(msg.Data, &ctx); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lineText := strings.TrimRight(ctx.Lines.Text, "\n\r")
|
||||
|
||||
// Determine if this is before or after context
|
||||
if len(results.Results) > 0 {
|
||||
lastResult := &results.Results[len(results.Results)-1]
|
||||
if lastResult.File == ctx.Path.Text && ctx.LineNumber > lastResult.Line {
|
||||
// This is after context
|
||||
lastResult.Context.After = append(lastResult.Context.After, lineText)
|
||||
} else if ctx.Path.Text == currentFile || currentFile == "" {
|
||||
// This is before context for a potential upcoming match
|
||||
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
|
||||
}
|
||||
} else {
|
||||
// Before any match - store as potential before context
|
||||
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
|
||||
}
|
||||
|
||||
case "summary":
|
||||
var summary rgSummary
|
||||
if err := json.Unmarshal(msg.Data, &summary); err != nil {
|
||||
continue
|
||||
}
|
||||
results.Total = summary.Stats.Matches
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading ripgrep output: %w", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// FormatResults formats search results for display.
|
||||
func (s *Searcher) FormatResults(results *SearchResults) string {
|
||||
if len(results.Results) == 0 {
|
||||
return "No matches found."
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Group results by file
|
||||
fileResults := make(map[string][]Result)
|
||||
var fileOrder []string
|
||||
for _, r := range results.Results {
|
||||
if _, exists := fileResults[r.File]; !exists {
|
||||
fileOrder = append(fileOrder, r.File)
|
||||
}
|
||||
fileResults[r.File] = append(fileResults[r.File], r)
|
||||
}
|
||||
|
||||
// Write summary
|
||||
totalMatches := len(results.Results)
|
||||
fileCount := len(fileResults)
|
||||
sb.WriteString(fmt.Sprintf("Found %d matches in %d files", totalMatches, fileCount))
|
||||
if results.Truncated {
|
||||
sb.WriteString(fmt.Sprintf(" (truncated, total: %d)", results.Total))
|
||||
}
|
||||
sb.WriteString(":\n\n")
|
||||
|
||||
// Write results grouped by file
|
||||
for _, file := range fileOrder {
|
||||
// Make path relative to workspace root if possible
|
||||
relPath := file
|
||||
if absPath, err := filepath.Abs(file); err == nil {
|
||||
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
|
||||
relPath = rel
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("**%s**\n", relPath))
|
||||
|
||||
for _, r := range fileResults[file] {
|
||||
// Write context before
|
||||
for _, ctx := range r.Context.Before {
|
||||
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
|
||||
}
|
||||
|
||||
// Write match line
|
||||
sb.WriteString(fmt.Sprintf("L%d│ %s\n", r.Line, truncateLine(r.MatchText, 200)))
|
||||
|
||||
// Write context after
|
||||
for _, ctx := range r.Context.After {
|
||||
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// truncateLine truncates a line if it exceeds maxLen.
|
||||
func truncateLine(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
@@ -0,0 +1,326 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
cfg := config.Default()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
searcher, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
// ripgrep might not be installed
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
t.Skip("ripgrep not installed, skipping test")
|
||||
}
|
||||
t.Fatalf("New failed: %v", err)
|
||||
}
|
||||
|
||||
if searcher == nil {
|
||||
t.Fatal("expected non-nil searcher")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildArgs(t *testing.T) {
|
||||
cfg := config.Default()
|
||||
cfg.WorkspaceRoot = "/test/workspace"
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
// Create searcher without checking for rg binary
|
||||
s := &Searcher{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
rgPath: "/usr/bin/rg",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *Request
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "basic search",
|
||||
req: &Request{
|
||||
Pattern: "test",
|
||||
ContextLines: 2,
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--context=2", "--", "test", "."},
|
||||
},
|
||||
{
|
||||
name: "ignore case",
|
||||
req: &Request{
|
||||
Pattern: "test",
|
||||
IgnoreCase: true,
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--ignore-case", "--", "test", "."},
|
||||
},
|
||||
{
|
||||
name: "fixed strings",
|
||||
req: &Request{
|
||||
Pattern: "test",
|
||||
Regex: false,
|
||||
},
|
||||
expected: []string{"--json", "--fixed-strings", "--", "test", "."},
|
||||
},
|
||||
{
|
||||
name: "with file types",
|
||||
req: &Request{
|
||||
Pattern: "test",
|
||||
FileTypes: []string{"go", "ts"},
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--type", "go", "--type", "ts", "--", "test", "."},
|
||||
},
|
||||
{
|
||||
name: "with max results",
|
||||
req: &Request{
|
||||
Pattern: "test",
|
||||
MaxResults: 10,
|
||||
Regex: true,
|
||||
},
|
||||
expected: []string{"--json", "--max-count=10", "--", "test", "."},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
args := s.buildArgs(tt.req)
|
||||
|
||||
// Check that all expected args are present
|
||||
for _, exp := range tt.expected {
|
||||
found := false
|
||||
for _, arg := range args {
|
||||
if arg == exp {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected arg %q not found in %v", exp, args)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResults(t *testing.T) {
|
||||
cfg := config.Default()
|
||||
cfg.WorkspaceRoot = "/test/workspace"
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
s := &Searcher{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
rgPath: "/usr/bin/rg",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
results *SearchResults
|
||||
contains []string
|
||||
}{
|
||||
{
|
||||
name: "empty results",
|
||||
results: &SearchResults{
|
||||
Results: []Result{},
|
||||
},
|
||||
contains: []string{"No matches found"},
|
||||
},
|
||||
{
|
||||
name: "single result",
|
||||
results: &SearchResults{
|
||||
Results: []Result{
|
||||
{
|
||||
File: "test.go",
|
||||
Line: 10,
|
||||
Column: 5,
|
||||
MatchText: "func TestSomething()",
|
||||
},
|
||||
},
|
||||
Total: 1,
|
||||
},
|
||||
contains: []string{"test.go", "L10", "TestSomething"},
|
||||
},
|
||||
{
|
||||
name: "truncated results",
|
||||
results: &SearchResults{
|
||||
Results: []Result{
|
||||
{
|
||||
File: "test.go",
|
||||
Line: 10,
|
||||
MatchText: "match",
|
||||
},
|
||||
},
|
||||
Truncated: true,
|
||||
Total: 100,
|
||||
},
|
||||
contains: []string{"truncated", "100"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
output := s.FormatResults(tt.results)
|
||||
|
||||
for _, exp := range tt.contains {
|
||||
if !strings.Contains(output, exp) {
|
||||
t.Errorf("expected output to contain %q, got:\n%s", exp, output)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOutput(t *testing.T) {
|
||||
cfg := config.Default()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
s := &Searcher{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
rgPath: "/usr/bin/rg",
|
||||
}
|
||||
|
||||
// Sample ripgrep JSON output
|
||||
jsonOutput := `{"type":"begin","data":{"path":{"text":"test.go"}}}
|
||||
{"type":"match","data":{"path":{"text":"test.go"},"lines":{"text":"func TestSomething() {\n"},"line_number":10,"absolute_offset":100,"submatches":[{"match":{"text":"Test"},"start":5,"end":9}]}}
|
||||
{"type":"end","data":{"path":{"text":"test.go"},"stats":{"bytes_searched":1000}}}
|
||||
{"type":"summary","data":{"elapsed_total":{"secs":0,"nanos":1000000},"stats":{"searches":1,"searches_with_match":1,"bytes_searched":1000,"bytes_printed":100,"matched_lines":1,"matches":1}}}
|
||||
`
|
||||
buf := bytes.NewBufferString(jsonOutput)
|
||||
|
||||
results, err := s.parseOutput(buf, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("parseOutput failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results.Results) != 1 {
|
||||
t.Errorf("expected 1 result, got %d", len(results.Results))
|
||||
}
|
||||
|
||||
if results.Results[0].File != "test.go" {
|
||||
t.Errorf("expected file 'test.go', got %q", results.Results[0].File)
|
||||
}
|
||||
|
||||
if results.Results[0].Line != 10 {
|
||||
t.Errorf("expected line 10, got %d", results.Results[0].Line)
|
||||
}
|
||||
|
||||
if results.Results[0].Column != 6 { // 1-indexed
|
||||
t.Errorf("expected column 6, got %d", results.Results[0].Column)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateLine(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
maxLen int
|
||||
}{
|
||||
{
|
||||
input: "short",
|
||||
maxLen: 10,
|
||||
expected: "short",
|
||||
},
|
||||
{
|
||||
input: "this is a very long line that should be truncated",
|
||||
maxLen: 20,
|
||||
expected: "this is a very lo...",
|
||||
},
|
||||
{
|
||||
input: "exact",
|
||||
maxLen: 5,
|
||||
expected: "exact",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := truncateLine(tt.input, tt.maxLen)
|
||||
if result != tt.expected {
|
||||
t.Errorf("truncateLine(%q, %d) = %q, want %q", tt.input, tt.maxLen, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchIntegration(t *testing.T) {
|
||||
// Create a temporary directory with test files
|
||||
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-search-test")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test files
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := `package main
|
||||
|
||||
func main() {
|
||||
println("Hello, World!")
|
||||
}
|
||||
`
|
||||
err = os.WriteFile(testFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := config.Default()
|
||||
cfg.WorkspaceRoot = tmpDir
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
searcher, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Skip("ripgrep not installed, skipping integration test")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := &Request{
|
||||
Pattern: "Hello",
|
||||
ContextLines: 1,
|
||||
Regex: false,
|
||||
}
|
||||
|
||||
results, err := searcher.Search(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Search failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results.Results) != 1 {
|
||||
t.Errorf("expected 1 result, got %d", len(results.Results))
|
||||
}
|
||||
|
||||
if len(results.Results) > 0 && !strings.Contains(results.Results[0].MatchText, "Hello") {
|
||||
t.Errorf("expected match to contain 'Hello', got %q", results.Results[0].MatchText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchEmptyPattern(t *testing.T) {
|
||||
cfg := config.Default()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||
|
||||
s := &Searcher{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
rgPath: "/usr/bin/rg",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := &Request{
|
||||
Pattern: "",
|
||||
}
|
||||
|
||||
_, err := s.Search(ctx, req)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty pattern")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,993 @@
|
||||
// Package server implements the MCP server for file operations.
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/edit"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/lsp"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/search"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// Server represents the MCP file operations server.
|
||||
type Server struct {
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
mcp *server.MCPServer
|
||||
searcher *search.Searcher
|
||||
parser *parser.Registry
|
||||
matcher *query.Matcher
|
||||
lspManager *lsp.Manager
|
||||
editor *edit.Engine
|
||||
}
|
||||
|
||||
// New creates a new MCP server instance.
|
||||
func New(cfg *config.Config, logger *slog.Logger) (*Server, error) {
|
||||
parserRegistry := parser.NewRegistry()
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
parser: parserRegistry,
|
||||
matcher: query.NewMatcher(parserRegistry),
|
||||
editor: edit.NewEngine(parserRegistry),
|
||||
}
|
||||
|
||||
// Initialize searcher
|
||||
searcher, err := search.New(cfg, logger)
|
||||
if err != nil {
|
||||
logger.Warn("ripgrep not available, search functionality disabled", "error", err)
|
||||
}
|
||||
s.searcher = searcher
|
||||
|
||||
// Initialize LSP manager if enabled
|
||||
if cfg.EnableLSP {
|
||||
s.lspManager = lsp.NewManager(cfg.WorkspaceRoot, logger)
|
||||
}
|
||||
|
||||
// Create MCP server
|
||||
mcpServer := server.NewMCPServer(
|
||||
"mcp-filepuff",
|
||||
"1.0.0",
|
||||
server.WithLogging(),
|
||||
)
|
||||
s.mcp = mcpServer
|
||||
|
||||
// Register tools
|
||||
s.registerTools()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// registerTools registers all available tools with the MCP server.
|
||||
func (s *Server) registerTools() {
|
||||
// Register ping tool for health checks
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("ping",
|
||||
mcp.WithDescription("Health check - returns pong to verify the server is running"),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
),
|
||||
s.handlePing,
|
||||
)
|
||||
|
||||
// Register file_search tool
|
||||
if s.searcher != nil {
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("file_search",
|
||||
mcp.WithDescription("Search for text patterns in files using ripgrep. Supports regex patterns, file type filtering, and context lines."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("pattern",
|
||||
mcp.Required(),
|
||||
mcp.Description("The search pattern (regex by default)"),
|
||||
),
|
||||
mcp.WithArray("paths",
|
||||
mcp.Description("Paths to search in (defaults to workspace root)"),
|
||||
mcp.WithStringItems(),
|
||||
),
|
||||
mcp.WithArray("file_types",
|
||||
mcp.Description("File types to search (e.g., ['go', 'ts', 'py'])"),
|
||||
mcp.WithStringItems(),
|
||||
),
|
||||
mcp.WithBoolean("ignore_case",
|
||||
mcp.Description("Case insensitive search"),
|
||||
),
|
||||
mcp.WithBoolean("regex",
|
||||
mcp.Description("Treat pattern as regex (default: true)"),
|
||||
),
|
||||
mcp.WithNumber("context_lines",
|
||||
mcp.Description("Number of context lines around matches (default: 2)"),
|
||||
),
|
||||
mcp.WithNumber("max_results",
|
||||
mcp.Description("Maximum number of results to return"),
|
||||
),
|
||||
),
|
||||
s.handleFileSearch,
|
||||
)
|
||||
}
|
||||
|
||||
// Register file_read tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("file_read",
|
||||
mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary"),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("path",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file to read"),
|
||||
),
|
||||
mcp.WithNumber("line_start",
|
||||
mcp.Description("Starting line number (1-indexed)"),
|
||||
),
|
||||
mcp.WithNumber("line_end",
|
||||
mcp.Description("Ending line number (inclusive)"),
|
||||
),
|
||||
mcp.WithBoolean("include_ast",
|
||||
mcp.Description("Include AST symbol summary (functions, classes, types, etc.)"),
|
||||
),
|
||||
mcp.WithBoolean("symbols_only",
|
||||
mcp.Description("Return only symbol summary without file content (token-efficient mode). Requires include_ast=true."),
|
||||
),
|
||||
mcp.WithNumber("max_lines",
|
||||
mcp.Description("Maximum number of lines to return (for token efficiency). Applied after line_start/line_end."),
|
||||
),
|
||||
),
|
||||
s.handleFileRead,
|
||||
)
|
||||
|
||||
// Register ast_query tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("ast_query",
|
||||
mcp.WithDescription("Search for AST patterns in code files. Use code patterns with $VAR placeholders to match and capture code structures like functions, classes, and types."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("pattern",
|
||||
mcp.Required(),
|
||||
mcp.Description("Code pattern with placeholders: $NAME (single), $$$ARGS (multiple), $_ (wildcard). Examples: 'func $NAME($$$ARGS) error', 'class $NAME { $$$BODY }'"),
|
||||
),
|
||||
mcp.WithString("language",
|
||||
mcp.Required(),
|
||||
mcp.Description("Target language: go, typescript, javascript, python, c, cpp"),
|
||||
),
|
||||
mcp.WithArray("paths",
|
||||
mcp.Description("Paths to search in (defaults to workspace root)"),
|
||||
mcp.WithStringItems(),
|
||||
),
|
||||
mcp.WithString("name_matches",
|
||||
mcp.Description("Regex pattern to filter by name"),
|
||||
),
|
||||
mcp.WithString("name_exact",
|
||||
mcp.Description("Exact name to match"),
|
||||
),
|
||||
mcp.WithArray("kind_in",
|
||||
mcp.Description("Node types to match (e.g., function_declaration, class_declaration)"),
|
||||
mcp.WithStringItems(),
|
||||
),
|
||||
mcp.WithNumber("max_results",
|
||||
mcp.Description("Maximum number of results to return (default: 100)"),
|
||||
),
|
||||
),
|
||||
s.handleASTQuery,
|
||||
)
|
||||
|
||||
// Register LSP-based tools if LSP is enabled
|
||||
if s.lspManager != nil {
|
||||
// Register symbol_at tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("symbol_at",
|
||||
mcp.WithDescription("Get information about the symbol at a specific position in a file. Returns type, documentation, and definition location using LSP when available."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("file",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file"),
|
||||
),
|
||||
mcp.WithNumber("line",
|
||||
mcp.Required(),
|
||||
mcp.Description("Line number (1-indexed)"),
|
||||
),
|
||||
mcp.WithNumber("column",
|
||||
mcp.Required(),
|
||||
mcp.Description("Column number (1-indexed)"),
|
||||
),
|
||||
),
|
||||
s.handleSymbolAt,
|
||||
)
|
||||
|
||||
// Register find_definition tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("find_definition",
|
||||
mcp.WithDescription("Find the definition of the symbol at a specific position. Uses LSP to locate where a function, variable, type, etc. is defined."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("file",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file"),
|
||||
),
|
||||
mcp.WithNumber("line",
|
||||
mcp.Required(),
|
||||
mcp.Description("Line number (1-indexed)"),
|
||||
),
|
||||
mcp.WithNumber("column",
|
||||
mcp.Required(),
|
||||
mcp.Description("Column number (1-indexed)"),
|
||||
),
|
||||
),
|
||||
s.handleFindDefinition,
|
||||
)
|
||||
|
||||
// Register find_references tool
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("find_references",
|
||||
mcp.WithDescription("Find all references to the symbol at a specific position. Uses LSP to locate all usages of a function, variable, type, etc."),
|
||||
mcp.WithReadOnlyHintAnnotation(true),
|
||||
mcp.WithString("file",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file"),
|
||||
),
|
||||
mcp.WithNumber("line",
|
||||
mcp.Required(),
|
||||
mcp.Description("Line number (1-indexed)"),
|
||||
),
|
||||
mcp.WithNumber("column",
|
||||
mcp.Required(),
|
||||
mcp.Description("Column number (1-indexed)"),
|
||||
),
|
||||
mcp.WithBoolean("include_declaration",
|
||||
mcp.Description("Include the declaration in results (default: true)"),
|
||||
),
|
||||
),
|
||||
s.handleFindReferences,
|
||||
)
|
||||
}
|
||||
|
||||
// Register edit tools
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("edit_preview",
|
||||
mcp.WithDescription("Preview an edit without applying it. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++), and text-based editing for other files (Markdown, JSON, YAML, config files, etc.)."),
|
||||
mcp.WithString("file",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file to edit"),
|
||||
),
|
||||
mcp.WithString("operation",
|
||||
mcp.Required(),
|
||||
mcp.Description("Edit operation: replace, insert_before, insert_after, delete"),
|
||||
),
|
||||
mcp.WithString("new_content",
|
||||
mcp.Description("New content (required for replace/insert operations)"),
|
||||
),
|
||||
// AST-mode selectors (for code files)
|
||||
mcp.WithString("selector_kind",
|
||||
mcp.Description("AST node type to match (e.g., function_declaration, class_declaration). For code files only."),
|
||||
),
|
||||
mcp.WithString("selector_name",
|
||||
mcp.Description("Name of the symbol to match. For code files only."),
|
||||
),
|
||||
// Shared selectors
|
||||
mcp.WithNumber("selector_line",
|
||||
mcp.Description("Line number (1-indexed). For AST mode: narrows search. For text mode: start of line range."),
|
||||
),
|
||||
mcp.WithNumber("selector_index",
|
||||
mcp.Description("Index of the match to use if multiple matches found (default: 0)"),
|
||||
),
|
||||
// Text-mode selectors (for non-code files or explicit text matching)
|
||||
mcp.WithNumber("selector_line_end",
|
||||
mcp.Description("End line number for range selection (text mode). Used with selector_line."),
|
||||
),
|
||||
mcp.WithString("selector_text",
|
||||
mcp.Description("Exact text to match (text mode). Must be unique or use selector_index."),
|
||||
),
|
||||
mcp.WithString("selector_pattern",
|
||||
mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."),
|
||||
),
|
||||
),
|
||||
s.handleEditPreview,
|
||||
)
|
||||
|
||||
s.mcp.AddTool(
|
||||
mcp.NewTool("edit_apply",
|
||||
mcp.WithDescription("Apply an edit to a file. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++) with syntax validation, and text-based editing for other files (Markdown, JSON, YAML, config files, etc.)."),
|
||||
mcp.WithString("file",
|
||||
mcp.Required(),
|
||||
mcp.Description("Path to the file to edit"),
|
||||
),
|
||||
mcp.WithString("operation",
|
||||
mcp.Required(),
|
||||
mcp.Description("Edit operation: replace, insert_before, insert_after, delete"),
|
||||
),
|
||||
mcp.WithString("new_content",
|
||||
mcp.Description("New content (required for replace/insert operations)"),
|
||||
),
|
||||
// AST-mode selectors (for code files)
|
||||
mcp.WithString("selector_kind",
|
||||
mcp.Description("AST node type to match (e.g., function_declaration, class_declaration). For code files only."),
|
||||
),
|
||||
mcp.WithString("selector_name",
|
||||
mcp.Description("Name of the symbol to match. For code files only."),
|
||||
),
|
||||
// Shared selectors
|
||||
mcp.WithNumber("selector_line",
|
||||
mcp.Description("Line number (1-indexed). For AST mode: narrows search. For text mode: start of line range."),
|
||||
),
|
||||
mcp.WithNumber("selector_index",
|
||||
mcp.Description("Index of the match to use if multiple matches found (default: 0)"),
|
||||
),
|
||||
// Text-mode selectors (for non-code files or explicit text matching)
|
||||
mcp.WithNumber("selector_line_end",
|
||||
mcp.Description("End line number for range selection (text mode). Used with selector_line."),
|
||||
),
|
||||
mcp.WithString("selector_text",
|
||||
mcp.Description("Exact text to match (text mode). Must be unique or use selector_index."),
|
||||
),
|
||||
mcp.WithString("selector_pattern",
|
||||
mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."),
|
||||
),
|
||||
),
|
||||
s.handleEditApply,
|
||||
)
|
||||
}
|
||||
|
||||
// handlePing handles the ping health check tool.
|
||||
func (s *Server) handlePing(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("pong"), nil
|
||||
}
|
||||
|
||||
// handleFileSearch handles the file_search tool.
|
||||
func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
s.logger.Debug("file_search completed",
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
}()
|
||||
|
||||
if s.searcher == nil {
|
||||
return mcp.NewToolResultError("ripgrep (rg) is not available. Please install it: https://github.com/BurntSushi/ripgrep#installation"), nil
|
||||
}
|
||||
|
||||
// Parse request arguments using SDK helpers
|
||||
pattern, err := request.RequireString("pattern")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("pattern is required"), nil
|
||||
}
|
||||
|
||||
req := &search.Request{
|
||||
Pattern: pattern,
|
||||
Paths: request.GetStringSlice("paths", nil),
|
||||
FileTypes: request.GetStringSlice("file_types", nil),
|
||||
IgnoreCase: request.GetBool("ignore_case", false),
|
||||
Regex: request.GetBool("regex", true),
|
||||
ContextLines: request.GetInt("context_lines", 2),
|
||||
MaxResults: request.GetInt("max_results", 0),
|
||||
}
|
||||
|
||||
// Execute search
|
||||
results, err := s.searcher.Search(ctx, req)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("search error: %v", err)), nil
|
||||
}
|
||||
|
||||
s.logger.Info("search completed",
|
||||
"pattern", pattern,
|
||||
"results_count", len(results.Results),
|
||||
"truncated", results.Truncated,
|
||||
)
|
||||
|
||||
// Format results
|
||||
output := s.searcher.FormatResults(results)
|
||||
return mcp.NewToolResultText(output), nil
|
||||
}
|
||||
|
||||
// handleFileRead handles the file_read tool.
|
||||
func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
path, err := request.RequireString("path")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("path is required"), nil
|
||||
}
|
||||
|
||||
// Validate path is within workspace
|
||||
if !s.cfg.IsPathAllowed(path) {
|
||||
return mcp.NewToolResultError("path is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Read file
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil
|
||||
}
|
||||
if os.IsPermission(err) {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
|
||||
}
|
||||
return mcp.NewToolResultError(fmt.Sprintf("error reading file: %v", err)), nil
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil
|
||||
}
|
||||
|
||||
// Handle line range
|
||||
lines := splitLines(string(content))
|
||||
lineStart := request.GetInt("line_start", 1)
|
||||
lineEnd := request.GetInt("line_end", len(lines))
|
||||
|
||||
// Clamp to valid range
|
||||
if lineStart < 1 {
|
||||
lineStart = 1
|
||||
}
|
||||
if lineEnd > len(lines) {
|
||||
lineEnd = len(lines)
|
||||
}
|
||||
if lineStart > lineEnd {
|
||||
lineStart = lineEnd
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
|
||||
// Include AST summary if requested
|
||||
includeAST := request.GetBool("include_ast", false)
|
||||
symbolsOnly := request.GetBool("symbols_only", false)
|
||||
maxLines := request.GetInt("max_lines", 0)
|
||||
|
||||
// Validate symbols_only requires include_ast
|
||||
if symbolsOnly && !includeAST {
|
||||
return mcp.NewToolResultError("symbols_only requires include_ast=true"), nil
|
||||
}
|
||||
|
||||
if includeAST {
|
||||
astSummary := s.generateASTSummary(ctx, path, content)
|
||||
if astSummary != "" {
|
||||
output.WriteString(astSummary)
|
||||
if !symbolsOnly {
|
||||
output.WriteString("\n---\n\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip file content if symbols_only mode
|
||||
if !symbolsOnly {
|
||||
// Apply max_lines limit if specified
|
||||
effectiveEnd := lineEnd
|
||||
if maxLines > 0 && (lineEnd-lineStart+1) > maxLines {
|
||||
effectiveEnd = lineStart + maxLines - 1
|
||||
if effectiveEnd < lineEnd {
|
||||
// Add note that output was truncated
|
||||
defer func() {
|
||||
output.WriteString(fmt.Sprintf("\n[... %d more lines omitted for token efficiency. Use line_start/line_end or increase max_lines to see more]\n", lineEnd-effectiveEnd))
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Extract requested lines
|
||||
for i := lineStart - 1; i < effectiveEnd && i < len(lines); i++ {
|
||||
output.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, lines[i]))
|
||||
}
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// generateASTSummary generates a summary of symbols in the file.
|
||||
func (s *Server) generateASTSummary(ctx context.Context, path string, content []byte) string {
|
||||
// Parse the file
|
||||
result, err := s.parser.Parse(ctx, path, content)
|
||||
if err != nil {
|
||||
return "" // Silently skip AST if parsing fails
|
||||
}
|
||||
|
||||
// Extract symbols
|
||||
lang := protocol.DetectLanguage(path)
|
||||
symbols := parser.ExtractSymbols(result.Tree, content, lang, path)
|
||||
if len(symbols) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Get relative path
|
||||
relPath := path
|
||||
if absPath, err := filepath.Abs(path); err == nil {
|
||||
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
|
||||
relPath = rel
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("**%s** (%d lines, %s)\n\n", relPath, len(splitLines(string(content))), lang))
|
||||
sb.WriteString("Symbols:\n")
|
||||
|
||||
for _, sym := range symbols {
|
||||
kindStr := symbolKindIcon(sym.Kind)
|
||||
sb.WriteString(fmt.Sprintf(" %s %s L%d\n", kindStr, sym.Name, sym.Location.Line))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// symbolKindIcon returns an icon/prefix for a symbol kind.
|
||||
func symbolKindIcon(kind protocol.SymbolKind) string {
|
||||
switch kind {
|
||||
case protocol.SymbolFunction:
|
||||
return "func"
|
||||
case protocol.SymbolMethod:
|
||||
return "meth"
|
||||
case protocol.SymbolClass:
|
||||
return "class"
|
||||
case protocol.SymbolStruct:
|
||||
return "struct"
|
||||
case protocol.SymbolInterface:
|
||||
return "iface"
|
||||
case protocol.SymbolVariable:
|
||||
return "var"
|
||||
case protocol.SymbolConstant:
|
||||
return "const"
|
||||
case protocol.SymbolType:
|
||||
return "type"
|
||||
case protocol.SymbolField:
|
||||
return "field"
|
||||
case protocol.SymbolProperty:
|
||||
return "prop"
|
||||
case protocol.SymbolModule:
|
||||
return "mod"
|
||||
case protocol.SymbolPackage:
|
||||
return "pkg"
|
||||
default:
|
||||
return "sym"
|
||||
}
|
||||
}
|
||||
|
||||
func splitLines(s string) []string {
|
||||
// Use optimized stdlib implementation (2-3x faster than manual loop)
|
||||
return strings.Split(s, "\n")
|
||||
}
|
||||
|
||||
// handleASTQuery handles the ast_query tool.
|
||||
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
pattern, err := request.RequireString("pattern")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("pattern is required"), nil
|
||||
}
|
||||
|
||||
language, err := request.RequireString("language")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("language is required"), nil
|
||||
}
|
||||
|
||||
// Build query
|
||||
astQuery := &query.ASTQuery{
|
||||
Pattern: pattern,
|
||||
Language: language,
|
||||
Filters: query.QueryFilters{
|
||||
NameMatches: request.GetString("name_matches", ""),
|
||||
NameExact: request.GetString("name_exact", ""),
|
||||
KindIn: request.GetStringSlice("kind_in", nil),
|
||||
},
|
||||
}
|
||||
|
||||
maxResults := request.GetInt("max_results", 100)
|
||||
paths := request.GetStringSlice("paths", nil)
|
||||
|
||||
// Default to workspace root if no paths specified
|
||||
if len(paths) == 0 {
|
||||
paths = []string{s.cfg.WorkspaceRoot}
|
||||
}
|
||||
|
||||
// Find files to search based on language
|
||||
ext := languageToExtension(language)
|
||||
if ext == "" {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil
|
||||
}
|
||||
|
||||
var allResults []query.MatchResult
|
||||
|
||||
// Walk through paths and find matching files
|
||||
for _, searchPath := range paths {
|
||||
// Validate path is within workspace
|
||||
if !s.cfg.IsPathAllowed(searchPath) {
|
||||
continue
|
||||
}
|
||||
|
||||
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip files with errors
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// Skip hidden directories
|
||||
if strings.HasPrefix(info.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check file extension matches language
|
||||
if !strings.HasSuffix(path, ext) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read and parse file
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil // Skip unreadable files
|
||||
}
|
||||
|
||||
// Check file size
|
||||
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||
return nil // Skip large files
|
||||
}
|
||||
|
||||
// Parse file
|
||||
result, err := s.parser.Parse(ctx, path, content)
|
||||
if err != nil {
|
||||
return nil // Skip unparseable files
|
||||
}
|
||||
|
||||
// Run query
|
||||
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
|
||||
if err != nil {
|
||||
return nil // Skip on error
|
||||
}
|
||||
|
||||
allResults = append(allResults, matches...)
|
||||
|
||||
// Stop if we have enough results
|
||||
if maxResults > 0 && len(allResults) >= maxResults {
|
||||
return filepath.SkipAll
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn("error walking path", "path", searchPath, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Format and return results
|
||||
output := query.FormatResults(allResults, maxResults)
|
||||
return mcp.NewToolResultText(output), nil
|
||||
}
|
||||
|
||||
// languageToExtension maps language names to file extensions.
|
||||
func languageToExtension(language string) string {
|
||||
switch strings.ToLower(language) {
|
||||
case "go":
|
||||
return ".go"
|
||||
case "typescript":
|
||||
return ".ts"
|
||||
case "javascript":
|
||||
return ".js"
|
||||
case "python":
|
||||
return ".py"
|
||||
case "c":
|
||||
return ".c"
|
||||
case "cpp", "c++":
|
||||
return ".cpp"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// handleSymbolAt handles the symbol_at tool.
|
||||
func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Try LSP hover
|
||||
hover, err := s.lspManager.Hover(ctx, file, line, col)
|
||||
if err != nil {
|
||||
// Fall back to AST-based info
|
||||
return s.handleSymbolAtFallback(ctx, file, line, col)
|
||||
}
|
||||
|
||||
if hover == nil {
|
||||
return mcp.NewToolResultText("No symbol information available at this position."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("**Symbol Information**\n\n")
|
||||
output.WriteString(hover.Contents.Value)
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable.
|
||||
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %v", err)), nil
|
||||
}
|
||||
|
||||
result, err := s.parser.Parse(ctx, file, content)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to parse file: %v", err)), nil
|
||||
}
|
||||
|
||||
node := parser.FindNodeAtPosition(result.Tree, line, col)
|
||||
if node == nil {
|
||||
return mcp.NewToolResultText("No symbol at this position."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("**Symbol Information** (AST fallback)\n\n")
|
||||
output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type()))
|
||||
output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content)))
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleFindDefinition handles the find_definition tool.
|
||||
func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
locations, err := s.lspManager.Definition(ctx, file, line, col)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %v", err)), nil
|
||||
}
|
||||
|
||||
if len(locations) == 0 {
|
||||
return mcp.NewToolResultText("No definition found."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
|
||||
|
||||
for _, loc := range locations {
|
||||
filePath := lsp.URIToFile(loc.URI)
|
||||
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
|
||||
// Try to read a preview snippet
|
||||
preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3)
|
||||
if preview != "" {
|
||||
output.WriteString("```\n")
|
||||
output.WriteString(preview)
|
||||
output.WriteString("```\n")
|
||||
}
|
||||
output.WriteString("\n")
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// handleFindReferences handles the find_references tool.
|
||||
func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
line := request.GetInt("line", 0)
|
||||
if line <= 0 {
|
||||
return mcp.NewToolResultError("line must be positive"), nil
|
||||
}
|
||||
|
||||
col := request.GetInt("column", 0)
|
||||
if col <= 0 {
|
||||
return mcp.NewToolResultError("column must be positive"), nil
|
||||
}
|
||||
|
||||
includeDecl := request.GetBool("include_declaration", true)
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
locations, err := s.lspManager.References(ctx, file, line, col, includeDecl)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %v", err)), nil
|
||||
}
|
||||
|
||||
if len(locations) == 0 {
|
||||
return mcp.NewToolResultText("No references found."), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations)))
|
||||
|
||||
// Group by file
|
||||
fileGroups := make(map[string][]lsp.Location)
|
||||
for _, loc := range locations {
|
||||
filePath := lsp.URIToFile(loc.URI)
|
||||
fileGroups[filePath] = append(fileGroups[filePath], loc)
|
||||
}
|
||||
|
||||
for filePath, locs := range fileGroups {
|
||||
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
|
||||
for _, loc := range locs {
|
||||
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||
}
|
||||
output.WriteString("\n")
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// readFilePreview reads a few lines from a file around the given line.
|
||||
func readFilePreview(file string, line, contextLines int) string {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
lines := splitLines(string(content))
|
||||
startLine := max(1, line-contextLines)
|
||||
endLine := min(line+contextLines, len(lines))
|
||||
|
||||
var preview strings.Builder
|
||||
for i := startLine - 1; i < endLine && i < len(lines); i++ {
|
||||
lineText := lines[i]
|
||||
if len(lineText) > 100 {
|
||||
lineText = lineText[:100] + "..."
|
||||
}
|
||||
prefix := " "
|
||||
if i+1 == line {
|
||||
prefix = "> "
|
||||
}
|
||||
preview.WriteString(fmt.Sprintf("%s%4d: %s\n", prefix, i+1, lineText))
|
||||
}
|
||||
|
||||
return preview.String()
|
||||
}
|
||||
|
||||
// handleEditPreview handles the edit_preview tool.
|
||||
func (s *Server) handleEditPreview(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return s.handleEdit(ctx, request, false)
|
||||
}
|
||||
|
||||
// handleEditApply handles the edit_apply tool.
|
||||
func (s *Server) handleEditApply(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return s.handleEdit(ctx, request, true)
|
||||
}
|
||||
|
||||
// handleEdit is the shared implementation for edit_preview and edit_apply.
|
||||
func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, apply bool) (*mcp.CallToolResult, error) {
|
||||
file, err := request.RequireString("file")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("file is required"), nil
|
||||
}
|
||||
|
||||
operation, err := request.RequireString("operation")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("operation is required"), nil
|
||||
}
|
||||
|
||||
// Validate path
|
||||
if !s.cfg.IsPathAllowed(file) {
|
||||
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||
}
|
||||
|
||||
// Note: We no longer validate language support here.
|
||||
// The edit engine automatically detects whether to use AST or text mode.
|
||||
|
||||
// Build edit request with both AST and text-mode selectors
|
||||
astEdit := &edit.ASTEdit{
|
||||
File: file,
|
||||
Operation: edit.EditOperation(operation),
|
||||
NewContent: request.GetString("new_content", ""),
|
||||
Selector: edit.ASTSelector{
|
||||
// AST-mode selectors
|
||||
Kind: request.GetString("selector_kind", ""),
|
||||
Name: request.GetString("selector_name", ""),
|
||||
AtLine: request.GetInt("selector_line", 0),
|
||||
Index: request.GetInt("selector_index", 0),
|
||||
// Text-mode selectors
|
||||
LineEnd: request.GetInt("selector_line_end", 0),
|
||||
Text: request.GetString("selector_text", ""),
|
||||
TextPattern: request.GetString("selector_pattern", ""),
|
||||
},
|
||||
}
|
||||
|
||||
// Perform edit
|
||||
var result *edit.EditResult
|
||||
if apply {
|
||||
result, err = s.editor.Apply(ctx, astEdit)
|
||||
} else {
|
||||
result, err = s.editor.Preview(ctx, astEdit)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("edit failed: %v", err)), nil
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
return mcp.NewToolResultError(result.Error), nil
|
||||
}
|
||||
|
||||
// Format output
|
||||
var output strings.Builder
|
||||
if apply {
|
||||
output.WriteString("**Edit Applied Successfully**\n\n")
|
||||
} else {
|
||||
output.WriteString("**Edit Preview**\n\n")
|
||||
}
|
||||
|
||||
output.WriteString("Diff:\n```diff\n")
|
||||
output.WriteString(result.Diff)
|
||||
output.WriteString("```\n")
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// Run starts the MCP server and blocks until shutdown.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
// Set up signal handling for graceful shutdown
|
||||
_, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
sig := <-sigChan
|
||||
s.logger.Info("received shutdown signal", "signal", sig)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
s.logger.Info("starting MCP server",
|
||||
"workspace", s.cfg.WorkspaceRoot,
|
||||
"lsp_enabled", s.cfg.EnableLSP,
|
||||
)
|
||||
|
||||
// Start the MCP server with stdio transport
|
||||
return server.ServeStdio(s.mcp)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
s.logger.Info("shutting down MCP server")
|
||||
|
||||
// Close LSP manager
|
||||
if s.lspManager != nil {
|
||||
_ = s.lspManager.Close()
|
||||
}
|
||||
|
||||
// Close parser registry
|
||||
if s.parser != nil {
|
||||
s.parser.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
// Create temp directory for testing
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
WorkspaceRoot: tmpDir,
|
||||
EnableLSP: false, // Disable LSP for simpler testing
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
if srv == nil {
|
||||
t.Fatal("New() returned nil server")
|
||||
}
|
||||
|
||||
if srv.cfg != cfg {
|
||||
t.Error("server config mismatch")
|
||||
}
|
||||
|
||||
if srv.parser == nil {
|
||||
t.Error("parser should not be nil")
|
||||
}
|
||||
|
||||
if srv.matcher == nil {
|
||||
t.Error("matcher should not be nil")
|
||||
}
|
||||
|
||||
if srv.editor == nil {
|
||||
t.Error("editor should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePing(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
|
||||
result, err := srv.handlePing(ctx, req)
|
||||
if err != nil {
|
||||
t.Errorf("handlePing() error = %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handlePing() returned nil result")
|
||||
}
|
||||
|
||||
// Check that the result contains "pong"
|
||||
contents := result.Content
|
||||
if len(contents) == 0 {
|
||||
t.Fatal("handlePing() returned empty content")
|
||||
}
|
||||
|
||||
textContent, ok := contents[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("handlePing() did not return text content")
|
||||
}
|
||||
|
||||
if textContent.Text != "pong" {
|
||||
t.Errorf("handlePing() = %v, want 'pong'", textContent.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleFileRead(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a test file
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := `package main
|
||||
|
||||
// Hello says hello
|
||||
func Hello() {
|
||||
println("Hello, World!")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"path": testFile,
|
||||
}
|
||||
|
||||
result, err := srv.handleFileRead(ctx, req)
|
||||
if err != nil {
|
||||
t.Errorf("handleFileRead() error = %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleFileRead() returned nil result")
|
||||
}
|
||||
|
||||
contents := result.Content
|
||||
if len(contents) == 0 {
|
||||
t.Fatal("handleFileRead() returned empty content")
|
||||
}
|
||||
|
||||
textContent, ok := contents[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("handleFileRead() did not return text content")
|
||||
}
|
||||
|
||||
// Should contain the file content
|
||||
if textContent.Text == "" {
|
||||
t.Error("handleFileRead() returned empty text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleFileReadWithAST(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a test file
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := `package main
|
||||
|
||||
// Hello says hello
|
||||
func Hello() {
|
||||
println("Hello, World!")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"path": testFile,
|
||||
"include_ast": true,
|
||||
}
|
||||
|
||||
result, err := srv.handleFileRead(ctx, req)
|
||||
if err != nil {
|
||||
t.Errorf("handleFileRead() error = %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleFileRead() returned nil result")
|
||||
}
|
||||
|
||||
contents := result.Content
|
||||
if len(contents) == 0 {
|
||||
t.Fatal("handleFileRead() returned empty content")
|
||||
}
|
||||
|
||||
textContent, ok := contents[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatal("handleFileRead() did not return text content")
|
||||
}
|
||||
|
||||
// Should contain "Symbols:" section when include_ast is true
|
||||
if textContent.Text == "" {
|
||||
t.Error("handleFileRead() returned empty text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleFileReadNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"path": filepath.Join(tmpDir, "nonexistent.go"),
|
||||
}
|
||||
|
||||
result, err := srv.handleFileRead(ctx, req)
|
||||
// Should return error for non-existent file
|
||||
if err == nil && result != nil {
|
||||
// Check if result indicates an error
|
||||
contents := result.Content
|
||||
if len(contents) > 0 {
|
||||
textContent, ok := contents[0].(mcp.TextContent)
|
||||
if ok && textContent.Text == "" {
|
||||
t.Log("handleFileRead() returned empty text for non-existent file")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleASTQuery(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a test file
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := `package main
|
||||
|
||||
func Hello() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Goodbye() error {
|
||||
return nil
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"pattern": "func $NAME() error",
|
||||
"language": "go",
|
||||
"paths": []interface{}{tmpDir},
|
||||
}
|
||||
|
||||
result, err := srv.handleASTQuery(ctx, req)
|
||||
if err != nil {
|
||||
t.Errorf("handleASTQuery() error = %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleASTQuery() returned nil result")
|
||||
}
|
||||
|
||||
contents := result.Content
|
||||
if len(contents) == 0 {
|
||||
t.Fatal("handleASTQuery() returned empty content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleEditPreview(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a test file
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("Hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"file": testFile,
|
||||
"operation": "replace",
|
||||
"selector_kind": "function_declaration",
|
||||
"selector_name": "Hello",
|
||||
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
|
||||
}
|
||||
|
||||
result, err := srv.handleEdit(ctx, req, false)
|
||||
if err != nil {
|
||||
t.Errorf("handleEdit(preview) error = %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleEdit(preview) returned nil result")
|
||||
}
|
||||
|
||||
// Verify file was NOT modified (it's just a preview)
|
||||
fileContent, _ := os.ReadFile(testFile)
|
||||
if string(fileContent) != content {
|
||||
t.Error("handleEdit(preview) should not modify the file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleEditApply(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a test file
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("Hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
srv, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = map[string]interface{}{
|
||||
"file": testFile,
|
||||
"operation": "replace",
|
||||
"selector_kind": "function_declaration",
|
||||
"selector_name": "Hello",
|
||||
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
|
||||
}
|
||||
|
||||
result, err := srv.handleEdit(ctx, req, true)
|
||||
if err != nil {
|
||||
t.Errorf("handleEdit(apply) error = %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("handleEdit(apply) returned nil result")
|
||||
}
|
||||
|
||||
// Verify file WAS modified
|
||||
fileContent, _ := os.ReadFile(testFile)
|
||||
if string(fileContent) == content {
|
||||
t.Error("handleEdit(apply) should modify the file")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,289 @@
|
||||
// Package errors provides structured error handling with error codes and context.
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrorCode represents a specific error condition.
|
||||
type ErrorCode int
|
||||
|
||||
// Error codes organized by category
|
||||
const (
|
||||
// Search errors (1000-1099)
|
||||
ErrRipgrepNotFound ErrorCode = 1001
|
||||
ErrRipgrepTimeout ErrorCode = 1002
|
||||
ErrInvalidPattern ErrorCode = 1003
|
||||
ErrSearchFailed ErrorCode = 1004
|
||||
ErrNoResults ErrorCode = 1005
|
||||
|
||||
// Parser errors (1100-1199)
|
||||
ErrParserNotFound ErrorCode = 1101
|
||||
ErrParseFailed ErrorCode = 1102
|
||||
ErrInvalidLanguage ErrorCode = 1103
|
||||
ErrFileTooBig ErrorCode = 1104
|
||||
ErrInvalidSyntax ErrorCode = 1105
|
||||
|
||||
// LSP errors (1200-1299)
|
||||
ErrLSPServerNotFound ErrorCode = 1201
|
||||
ErrLSPInitFailed ErrorCode = 1202
|
||||
ErrLSPTimeout ErrorCode = 1203
|
||||
ErrLSPCommunication ErrorCode = 1204
|
||||
ErrNoHoverInfo ErrorCode = 1205
|
||||
ErrNoDefinition ErrorCode = 1206
|
||||
ErrNoReferences ErrorCode = 1207
|
||||
|
||||
// Edit errors (1300-1399)
|
||||
ErrEditFailed ErrorCode = 1301
|
||||
ErrInvalidEdit ErrorCode = 1302
|
||||
ErrFileNotFound ErrorCode = 1303
|
||||
ErrFileNotReadable ErrorCode = 1304
|
||||
ErrFileNotWritable ErrorCode = 1305
|
||||
ErrNodeNotFound ErrorCode = 1306
|
||||
ErrValidationFailed ErrorCode = 1307
|
||||
ErrInvalidSelection ErrorCode = 1308
|
||||
|
||||
// Query errors (1400-1499)
|
||||
ErrInvalidQuery ErrorCode = 1401
|
||||
ErrQueryTimeout ErrorCode = 1402
|
||||
ErrNoMatches ErrorCode = 1403
|
||||
ErrQueryCompile ErrorCode = 1404
|
||||
|
||||
// Config errors (1500-1599)
|
||||
ErrInvalidConfig ErrorCode = 1501
|
||||
ErrPathNotAllowed ErrorCode = 1502
|
||||
ErrWorkspaceNotSet ErrorCode = 1503
|
||||
|
||||
// Internal errors (1900-1999)
|
||||
ErrInternal ErrorCode = 1900
|
||||
ErrCacheFailed ErrorCode = 1901
|
||||
ErrConcurrency ErrorCode = 1902
|
||||
)
|
||||
|
||||
// StructuredError represents an error with rich context and remediation info.
|
||||
type StructuredError struct {
|
||||
Cause error
|
||||
Context map[string]any
|
||||
Message string
|
||||
Remediation string
|
||||
Stack string
|
||||
Code ErrorCode
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *StructuredError) Error() string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Error code and message
|
||||
sb.WriteString(fmt.Sprintf("[%d] %s", e.Code, e.Message))
|
||||
|
||||
// Context if available
|
||||
if len(e.Context) > 0 {
|
||||
sb.WriteString("\nContext:")
|
||||
for k, v := range e.Context {
|
||||
sb.WriteString(fmt.Sprintf("\n %s: %v", k, v))
|
||||
}
|
||||
}
|
||||
|
||||
// Remediation if available
|
||||
if e.Remediation != "" {
|
||||
sb.WriteString(fmt.Sprintf("\nHow to fix: %s", e.Remediation))
|
||||
}
|
||||
|
||||
// Underlying cause if available
|
||||
if e.Cause != nil {
|
||||
sb.WriteString(fmt.Sprintf("\nCaused by: %v", e.Cause))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying cause for error chain support.
|
||||
func (e *StructuredError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// WithContext adds context to the error.
|
||||
func (e *StructuredError) WithContext(key string, value any) *StructuredError {
|
||||
if e.Context == nil {
|
||||
e.Context = make(map[string]any)
|
||||
}
|
||||
e.Context[key] = value
|
||||
return e
|
||||
}
|
||||
|
||||
// WithRemediation sets the remediation message.
|
||||
func (e *StructuredError) WithRemediation(msg string) *StructuredError {
|
||||
e.Remediation = msg
|
||||
return e
|
||||
}
|
||||
|
||||
// New creates a new structured error with stack trace.
|
||||
func New(code ErrorCode, message string) *StructuredError {
|
||||
return &StructuredError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Context: make(map[string]interface{}),
|
||||
Stack: captureStack(2),
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap wraps an existing error with structured error information.
|
||||
func Wrap(code ErrorCode, message string, cause error) *StructuredError {
|
||||
return &StructuredError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Context: make(map[string]interface{}),
|
||||
Cause: cause,
|
||||
Stack: captureStack(2),
|
||||
}
|
||||
}
|
||||
|
||||
// Is checks if an error matches the given error code.
|
||||
func Is(err error, code ErrorCode) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if se, ok := err.(*StructuredError); ok {
|
||||
return se.Code == code
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCode extracts the error code from an error, or returns 0 if not a structured error.
|
||||
func GetCode(err error) ErrorCode {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if se, ok := err.(*StructuredError); ok {
|
||||
return se.Code
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// captureStack captures the stack trace.
|
||||
func captureStack(skip int) string {
|
||||
const depth = 16
|
||||
var pcs [depth]uintptr
|
||||
n := runtime.Callers(skip+1, pcs[:])
|
||||
|
||||
var sb strings.Builder
|
||||
frames := runtime.CallersFrames(pcs[:n])
|
||||
|
||||
for {
|
||||
frame, more := frames.Next()
|
||||
if !strings.Contains(frame.File, "runtime/") {
|
||||
sb.WriteString(fmt.Sprintf("\n %s:%d %s", frame.File, frame.Line, frame.Function))
|
||||
}
|
||||
if !more {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// Common error constructors for convenience
|
||||
|
||||
// NewRipgrepNotFound creates an error for missing ripgrep binary.
|
||||
func NewRipgrepNotFound() *StructuredError {
|
||||
os := runtime.GOOS
|
||||
install := "brew install ripgrep"
|
||||
|
||||
switch os {
|
||||
case "linux":
|
||||
install = "apt-get install ripgrep (Debian/Ubuntu) or yum install ripgrep (RHEL/CentOS)"
|
||||
case "windows":
|
||||
install = "choco install ripgrep or scoop install ripgrep"
|
||||
}
|
||||
|
||||
return New(ErrRipgrepNotFound, "ripgrep (rg) binary not found in system PATH").
|
||||
WithContext("os", os).
|
||||
WithRemediation(fmt.Sprintf("Install ripgrep: %s", install))
|
||||
}
|
||||
|
||||
// NewLSPServerNotFound creates an error for missing LSP server.
|
||||
func NewLSPServerNotFound(language, serverName string) *StructuredError {
|
||||
return New(ErrLSPServerNotFound, fmt.Sprintf("LSP server '%s' not found for language %s", serverName, language)).
|
||||
WithContext("language", language).
|
||||
WithContext("server", serverName).
|
||||
WithRemediation(fmt.Sprintf("Install the %s LSP server to enable IDE features for %s", serverName, language))
|
||||
}
|
||||
|
||||
// NewFileTooLarge creates an error for files exceeding size limit.
|
||||
func NewFileTooLarge(path string, size, limit int64) *StructuredError {
|
||||
return New(ErrFileTooBig, "file exceeds maximum size limit").
|
||||
WithContext("file", path).
|
||||
WithContext("size_bytes", size).
|
||||
WithContext("limit_bytes", limit).
|
||||
WithRemediation(fmt.Sprintf("File size (%d bytes) exceeds limit (%d bytes). Consider processing smaller files or increasing the limit.", size, limit))
|
||||
}
|
||||
|
||||
// NewParseError creates an error for parsing failures.
|
||||
func NewParseError(language, file string, cause error) *StructuredError {
|
||||
return Wrap(ErrParseFailed, fmt.Sprintf("failed to parse %s file", language), cause).
|
||||
WithContext("language", language).
|
||||
WithContext("file", file).
|
||||
WithRemediation("Check file syntax and ensure it's valid source code for the specified language")
|
||||
}
|
||||
|
||||
// NewSearchTimeout creates an error for search timeouts.
|
||||
func NewSearchTimeout(pattern string, duration string) *StructuredError {
|
||||
return New(ErrRipgrepTimeout, "search operation timed out").
|
||||
WithContext("pattern", pattern).
|
||||
WithContext("duration", duration).
|
||||
WithRemediation("Try narrowing the search scope, using more specific patterns, or increasing the timeout limit")
|
||||
}
|
||||
|
||||
// NewEditValidationError creates an error for edit validation failures.
|
||||
func NewEditValidationError(file string, cause error) *StructuredError {
|
||||
return Wrap(ErrValidationFailed, "edit validation failed - syntax errors detected", cause).
|
||||
WithContext("file", file).
|
||||
WithRemediation("Review the edit operation and ensure it produces valid syntax. The file was not modified.")
|
||||
}
|
||||
|
||||
// NewFileNotFoundError creates an error for missing files.
|
||||
func NewFileNotFoundError(file string) *StructuredError {
|
||||
return New(ErrFileNotFound, fmt.Sprintf("file not found: %s", file)).
|
||||
WithContext("file", file).
|
||||
WithRemediation("Verify the file path is correct and the file exists")
|
||||
}
|
||||
|
||||
// NewFileNotReadableError creates an error for unreadable files.
|
||||
func NewFileNotReadableError(file string, cause error) *StructuredError {
|
||||
return Wrap(ErrFileNotReadable, fmt.Sprintf("cannot read file: %s", file), cause).
|
||||
WithContext("file", file).
|
||||
WithRemediation("Check file permissions and ensure the file is not locked by another process")
|
||||
}
|
||||
|
||||
// NewFileNotWritableError creates an error for write failures.
|
||||
func NewFileNotWritableError(file string, cause error) *StructuredError {
|
||||
return Wrap(ErrFileNotWritable, fmt.Sprintf("cannot write to file: %s", file), cause).
|
||||
WithContext("file", file).
|
||||
WithRemediation("Check file permissions, disk space, and ensure the file is not locked by another process")
|
||||
}
|
||||
|
||||
// NewNodeNotFoundError creates an error when AST node selector finds no matches.
|
||||
func NewNodeNotFoundError(selector string) *StructuredError {
|
||||
return New(ErrNodeNotFound, "no AST nodes match the selector criteria").
|
||||
WithContext("selector", selector).
|
||||
WithRemediation("Verify the selector criteria (kind, name, pattern, line) match an existing code structure")
|
||||
}
|
||||
|
||||
// NewInvalidSelectionError creates an error for ambiguous or invalid selectors.
|
||||
func NewInvalidSelectionError(message string) *StructuredError {
|
||||
return New(ErrInvalidSelection, message).
|
||||
WithRemediation("Refine the selector to be more specific or provide a selector_index to choose between multiple matches")
|
||||
}
|
||||
|
||||
// NewInvalidEditError creates an error for invalid edit operations.
|
||||
func NewInvalidEditError(message string) *StructuredError {
|
||||
return New(ErrInvalidEdit, message).
|
||||
WithRemediation("Review the edit request and ensure all required fields are provided with valid values")
|
||||
}
|
||||
@@ -0,0 +1,375 @@
|
||||
// Package fuzzy provides fuzzy string matching using Levenshtein distance.
|
||||
package fuzzy
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Match represents a fuzzy match result.
|
||||
type Match struct {
|
||||
Text string
|
||||
Distance int
|
||||
Similarity float64
|
||||
Score float64
|
||||
}
|
||||
|
||||
// Matcher provides fuzzy matching capabilities.
|
||||
type Matcher struct {
|
||||
threshold int
|
||||
}
|
||||
|
||||
// New creates a new fuzzy matcher with the given threshold.
|
||||
// Threshold is the maximum edit distance to consider a match (typically 1-3).
|
||||
func New(threshold int) *Matcher {
|
||||
return &Matcher{
|
||||
threshold: threshold,
|
||||
}
|
||||
}
|
||||
|
||||
// Match performs fuzzy matching of query against candidates.
|
||||
func (m *Matcher) Match(query string, candidates []string) []Match {
|
||||
if query == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
matches := make([]Match, 0, len(candidates)/10)
|
||||
queryLower := strings.ToLower(query)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
candidateLower := strings.ToLower(candidate)
|
||||
|
||||
// Calculate Levenshtein distance
|
||||
dist := levenshteinDistance(queryLower, candidateLower)
|
||||
|
||||
// Skip if distance exceeds threshold
|
||||
if dist > m.threshold {
|
||||
// Check if it's a substring match (important for identifiers)
|
||||
if !strings.Contains(candidateLower, queryLower) {
|
||||
continue
|
||||
}
|
||||
// Allow substring matches even if edit distance is high
|
||||
}
|
||||
|
||||
// Calculate similarity (0.0 to 1.0)
|
||||
maxLen := max(len(query), len(candidate))
|
||||
similarity := 1.0 - float64(dist)/float64(maxLen)
|
||||
|
||||
// Calculate composite score
|
||||
score := m.calculateScore(queryLower, candidateLower, dist, similarity)
|
||||
|
||||
matches = append(matches, Match{
|
||||
Text: candidate,
|
||||
Distance: dist,
|
||||
Similarity: similarity,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by score descending
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].Score > matches[j].Score
|
||||
})
|
||||
|
||||
return matches
|
||||
}
|
||||
|
||||
// calculateScore computes a composite score considering multiple factors.
|
||||
func (m *Matcher) calculateScore(query, candidate string, dist int, similarity float64) float64 {
|
||||
score := similarity
|
||||
|
||||
// Bonus for exact match
|
||||
if query == candidate {
|
||||
score += 2.0
|
||||
}
|
||||
|
||||
// Bonus for prefix match (important for identifier search)
|
||||
if strings.HasPrefix(candidate, query) {
|
||||
score += 1.0
|
||||
}
|
||||
|
||||
// Bonus for word boundary matches (e.g., "getName" matches "get")
|
||||
if containsWordBoundary(candidate, query) {
|
||||
score += 0.5
|
||||
}
|
||||
|
||||
// Penalty for length difference (prefer similar-length matches)
|
||||
lenDiff := abs(len(candidate) - len(query))
|
||||
score -= float64(lenDiff) * 0.01
|
||||
|
||||
// Penalty for edit distance
|
||||
score -= float64(dist) * 0.1
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// levenshteinDistance computes the Levenshtein distance between two strings.
|
||||
// Uses the Wagner-Fischer algorithm with space optimization O(min(m,n)).
|
||||
func levenshteinDistance(s1, s2 string) int {
|
||||
if s1 == s2 {
|
||||
return 0
|
||||
}
|
||||
if len(s1) == 0 {
|
||||
return len(s2)
|
||||
}
|
||||
if len(s2) == 0 {
|
||||
return len(s1)
|
||||
}
|
||||
|
||||
// Ensure s1 is the shorter string for space optimization
|
||||
if len(s1) > len(s2) {
|
||||
s1, s2 = s2, s1
|
||||
}
|
||||
|
||||
// Use rune slices to handle Unicode properly
|
||||
r1 := []rune(s1)
|
||||
r2 := []rune(s2)
|
||||
len1 := len(r1)
|
||||
len2 := len(r2)
|
||||
|
||||
// Only need two rows of the matrix
|
||||
previous := make([]int, len2+1)
|
||||
current := make([]int, len2+1)
|
||||
|
||||
// Initialize first row
|
||||
for j := 0; j <= len2; j++ {
|
||||
previous[j] = j
|
||||
}
|
||||
|
||||
// Calculate edit distance
|
||||
for i := 1; i <= len1; i++ {
|
||||
current[0] = i
|
||||
|
||||
for j := 1; j <= len2; j++ {
|
||||
cost := 1
|
||||
if r1[i-1] == r2[j-1] {
|
||||
cost = 0
|
||||
}
|
||||
|
||||
current[j] = min(
|
||||
previous[j]+1, // deletion
|
||||
current[j-1]+1, // insertion
|
||||
previous[j-1]+cost, // substitution
|
||||
)
|
||||
}
|
||||
|
||||
// Swap rows
|
||||
previous, current = current, previous
|
||||
}
|
||||
|
||||
return previous[len2]
|
||||
}
|
||||
|
||||
// DamerauLevenshteinDistance computes Damerau-Levenshtein distance (includes transpositions).
|
||||
// This is more accurate for typos where adjacent characters are swapped.
|
||||
func DamerauLevenshteinDistance(s1, s2 string) int {
|
||||
if s1 == s2 {
|
||||
return 0
|
||||
}
|
||||
if len(s1) == 0 {
|
||||
return len(s2)
|
||||
}
|
||||
if len(s2) == 0 {
|
||||
return len(s1)
|
||||
}
|
||||
|
||||
r1 := []rune(s1)
|
||||
r2 := []rune(s2)
|
||||
len1 := len(r1)
|
||||
len2 := len(r2)
|
||||
|
||||
// Create distance matrix
|
||||
d := make([][]int, len1+1)
|
||||
for i := range d {
|
||||
d[i] = make([]int, len2+1)
|
||||
}
|
||||
|
||||
// Initialize first row and column
|
||||
for i := 0; i <= len1; i++ {
|
||||
d[i][0] = i
|
||||
}
|
||||
for j := 0; j <= len2; j++ {
|
||||
d[0][j] = j
|
||||
}
|
||||
|
||||
// Calculate distances
|
||||
for i := 1; i <= len1; i++ {
|
||||
for j := 1; j <= len2; j++ {
|
||||
cost := 1
|
||||
if r1[i-1] == r2[j-1] {
|
||||
cost = 0
|
||||
}
|
||||
|
||||
d[i][j] = min(
|
||||
d[i-1][j]+1, // deletion
|
||||
d[i][j-1]+1, // insertion
|
||||
d[i-1][j-1]+cost, // substitution
|
||||
)
|
||||
|
||||
// Check for transposition
|
||||
if i > 1 && j > 1 && r1[i-1] == r2[j-2] && r1[i-2] == r2[j-1] {
|
||||
d[i][j] = min(d[i][j], d[i-2][j-2]+cost)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return d[len1][len2]
|
||||
}
|
||||
|
||||
// JaroWinklerSimilarity computes Jaro-Winkler similarity (0.0 to 1.0).
|
||||
// Better for short strings and names.
|
||||
func JaroWinklerSimilarity(s1, s2 string) float64 {
|
||||
if s1 == s2 {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
r1 := []rune(s1)
|
||||
r2 := []rune(s2)
|
||||
|
||||
if len(r1) == 0 || len(r2) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// Calculate Jaro similarity first
|
||||
jaro := jaroSimilarity(r1, r2)
|
||||
|
||||
// Calculate common prefix length (up to 4 characters)
|
||||
prefixLen := 0
|
||||
for i := 0; i < min(min(len(r1), len(r2)), 4); i++ {
|
||||
if r1[i] == r2[i] {
|
||||
prefixLen++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Jaro-Winkler adds bonus for common prefix
|
||||
const p = 0.1
|
||||
return jaro + float64(prefixLen)*p*(1.0-jaro)
|
||||
}
|
||||
|
||||
// jaroSimilarity computes Jaro similarity.
|
||||
func jaroSimilarity(r1, r2 []rune) float64 {
|
||||
len1 := len(r1)
|
||||
len2 := len(r2)
|
||||
|
||||
// Maximum allowed distance
|
||||
matchDist := max(len1, len2)/2 - 1
|
||||
if matchDist < 0 {
|
||||
matchDist = 0
|
||||
}
|
||||
|
||||
matched1 := make([]bool, len1)
|
||||
matched2 := make([]bool, len2)
|
||||
|
||||
matches := 0
|
||||
transpositions := 0
|
||||
|
||||
// Find matches
|
||||
for i := range len1 {
|
||||
start := max(0, i-matchDist)
|
||||
end := min(i+matchDist+1, len2)
|
||||
|
||||
for j := start; j < end; j++ {
|
||||
if matched2[j] || r1[i] != r2[j] {
|
||||
continue
|
||||
}
|
||||
matched1[i] = true
|
||||
matched2[j] = true
|
||||
matches++
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matches == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// Count transpositions
|
||||
k := 0
|
||||
for i := range len1 {
|
||||
if !matched1[i] {
|
||||
continue
|
||||
}
|
||||
for !matched2[k] {
|
||||
k++
|
||||
}
|
||||
if r1[i] != r2[k] {
|
||||
transpositions++
|
||||
}
|
||||
k++
|
||||
}
|
||||
|
||||
return (float64(matches)/float64(len1) +
|
||||
float64(matches)/float64(len2) +
|
||||
float64(matches-transpositions/2)/float64(matches)) / 3.0
|
||||
}
|
||||
|
||||
// containsWordBoundary checks if query appears at word boundaries in text.
|
||||
func containsWordBoundary(text, query string) bool {
|
||||
textLower := strings.ToLower(text)
|
||||
queryLower := strings.ToLower(query)
|
||||
|
||||
idx := strings.Index(textLower, queryLower)
|
||||
if idx == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if match is at start
|
||||
if idx == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for underscore or non-alphanumeric boundary
|
||||
prevRune := rune(text[idx-1])
|
||||
if !unicode.IsLetter(prevRune) && !unicode.IsDigit(prevRune) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for camelCase boundary (lowercase before uppercase)
|
||||
if idx > 0 && len(text) > idx {
|
||||
curr := rune(text[idx])
|
||||
prev := rune(text[idx-1])
|
||||
if unicode.IsLower(prev) && unicode.IsUpper(curr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func min(values ...int) int {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
m := values[0]
|
||||
for _, v := range values[1:] {
|
||||
if v < m {
|
||||
m = v
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func max(values ...int) int {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
m := values[0]
|
||||
for _, v := range values[1:] {
|
||||
if v > m {
|
||||
m = v
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func abs(x int) int {
|
||||
if x < 0 {
|
||||
return -x
|
||||
}
|
||||
return x
|
||||
}
|
||||
@@ -0,0 +1,275 @@
|
||||
package fuzzy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLevenshteinDistance(t *testing.T) {
|
||||
tests := []struct {
|
||||
s1 string
|
||||
s2 string
|
||||
expected int
|
||||
}{
|
||||
{"", "", 0},
|
||||
{"", "abc", 3},
|
||||
{"abc", "", 3},
|
||||
{"abc", "abc", 0},
|
||||
{"abc", "abd", 1},
|
||||
{"kitten", "sitting", 3},
|
||||
{"saturday", "sunday", 3},
|
||||
{"book", "back", 2},
|
||||
{"café", "cafe", 1}, // Unicode handling
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := levenshteinDistance(tt.s1, tt.s2)
|
||||
if got != tt.expected {
|
||||
t.Errorf("levenshteinDistance(%q, %q) = %d, want %d", tt.s1, tt.s2, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDamerauLevenshteinDistance(t *testing.T) {
|
||||
tests := []struct {
|
||||
s1 string
|
||||
s2 string
|
||||
expected int
|
||||
}{
|
||||
{"abc", "abc", 0},
|
||||
{"abc", "acb", 1}, // Transposition
|
||||
{"ca", "abc", 3}, // Delete a, delete b, insert c = 3 operations
|
||||
{"", "abc", 3},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := DamerauLevenshteinDistance(tt.s1, tt.s2)
|
||||
if got != tt.expected {
|
||||
t.Errorf("DamerauLevenshteinDistance(%q, %q) = %d, want %d", tt.s1, tt.s2, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJaroWinklerSimilarity(t *testing.T) {
|
||||
tests := []struct {
|
||||
s1 string
|
||||
s2 string
|
||||
minScore float64 // Minimum expected similarity
|
||||
}{
|
||||
{"", "", 1.0},
|
||||
{"abc", "abc", 1.0},
|
||||
{"martha", "marhta", 0.96}, // High similarity for transposition
|
||||
{"dixon", "dicksonx", 0.76}, // Moderate similarity
|
||||
{"", "abc", 0.0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := JaroWinklerSimilarity(tt.s1, tt.s2)
|
||||
if got < tt.minScore {
|
||||
t.Errorf("JaroWinklerSimilarity(%q, %q) = %.2f, want >= %.2f", tt.s1, tt.s2, got, tt.minScore)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatcher_Match(t *testing.T) {
|
||||
m := New(2) // Allow edit distance up to 2
|
||||
|
||||
candidates := []string{
|
||||
"getUserName",
|
||||
"getUsername",
|
||||
"get_user_name",
|
||||
"getUserId",
|
||||
"setUserName",
|
||||
"findUser",
|
||||
"userName",
|
||||
"usernameField",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query string
|
||||
topMatch string
|
||||
expectMin int
|
||||
}{
|
||||
{
|
||||
query: "getUserName",
|
||||
expectMin: 3, // Exact + similar variants
|
||||
topMatch: "getUserName",
|
||||
},
|
||||
{
|
||||
query: "getuser",
|
||||
expectMin: 2, // Should match getUserName, getUsername at minimum
|
||||
topMatch: "getUserName",
|
||||
},
|
||||
{
|
||||
query: "username",
|
||||
expectMin: 2, // Case-insensitive matches
|
||||
topMatch: "userName",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
matches := m.Match(tt.query, candidates)
|
||||
|
||||
if len(matches) < tt.expectMin {
|
||||
t.Errorf("Match(%q) returned %d matches, want at least %d", tt.query, len(matches), tt.expectMin)
|
||||
}
|
||||
|
||||
if len(matches) > 0 {
|
||||
// Top match should have highest score
|
||||
if matches[0].Score < matches[len(matches)-1].Score {
|
||||
t.Errorf("Match(%q) results not sorted by score", tt.query)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatcher_EmptyQuery(t *testing.T) {
|
||||
m := New(2)
|
||||
candidates := []string{"test", "example"}
|
||||
|
||||
matches := m.Match("", candidates)
|
||||
if matches != nil {
|
||||
t.Errorf("Match with empty query should return nil, got %v", matches)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatcher_PrefixBonus(t *testing.T) {
|
||||
m := New(2)
|
||||
candidates := []string{
|
||||
"getUserName", // prefix match
|
||||
"findUserName", // contains but not prefix
|
||||
}
|
||||
|
||||
matches := m.Match("get", candidates)
|
||||
|
||||
if len(matches) < 1 {
|
||||
t.Fatal("Expected at least one match")
|
||||
}
|
||||
|
||||
// Prefix match should score higher
|
||||
if matches[0].Text != "getUserName" {
|
||||
t.Errorf("Expected prefix match to rank first, got %q", matches[0].Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatcher_ExactMatchBonus(t *testing.T) {
|
||||
m := New(2)
|
||||
candidates := []string{
|
||||
"test",
|
||||
"testing",
|
||||
"tester",
|
||||
}
|
||||
|
||||
matches := m.Match("test", candidates)
|
||||
|
||||
if len(matches) < 1 {
|
||||
t.Fatal("Expected at least one match")
|
||||
}
|
||||
|
||||
// Exact match should rank first
|
||||
if matches[0].Text != "test" {
|
||||
t.Errorf("Expected exact match to rank first, got %q", matches[0].Text)
|
||||
}
|
||||
|
||||
// Exact match should have highest score
|
||||
if matches[0].Score < 2.0 { // Should have exact match bonus
|
||||
t.Errorf("Exact match score too low: %.2f", matches[0].Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsWordBoundary(t *testing.T) {
|
||||
tests := []struct {
|
||||
text string
|
||||
query string
|
||||
expected bool
|
||||
}{
|
||||
{"getUserName", "get", true}, // At start
|
||||
{"getUserName", "user", true}, // After lowercase->uppercase boundary
|
||||
{"get_user_name", "user", true}, // After underscore
|
||||
{"getUserName", "Name", true}, // After lowercase->uppercase
|
||||
{"getUserName", "ser", false}, // Middle of word
|
||||
{"", "test", false}, // Empty text
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := containsWordBoundary(tt.text, tt.query)
|
||||
if got != tt.expected {
|
||||
t.Errorf("containsWordBoundary(%q, %q) = %v, want %v", tt.text, tt.query, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatcher_UnicodeHandling(t *testing.T) {
|
||||
m := New(2)
|
||||
candidates := []string{
|
||||
"café",
|
||||
"resume",
|
||||
"naïve",
|
||||
}
|
||||
|
||||
// Test with Unicode characters
|
||||
matches := m.Match("cafe", candidates)
|
||||
if len(matches) == 0 {
|
||||
t.Error("Expected matches for Unicode strings")
|
||||
}
|
||||
|
||||
// Should find café with small edit distance
|
||||
found := false
|
||||
for _, match := range matches {
|
||||
if match.Text == "café" && match.Distance <= 2 {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Error("Failed to fuzzy match Unicode string 'café'")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLevenshteinDistance(b *testing.B) {
|
||||
s1 := "the quick brown fox jumps over the lazy dog"
|
||||
s2 := "the quikc brown fox jumps ovver the lazy dog"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := range b.N {
|
||||
_ = levenshteinDistance(s1, s2)
|
||||
_ = i // use i to avoid unused warning
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDamerauLevenshteinDistance(b *testing.B) {
|
||||
s1 := "the quick brown fox jumps over the lazy dog"
|
||||
s2 := "the quikc brown fox jumps ovver the lazy dog"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := range b.N {
|
||||
_ = DamerauLevenshteinDistance(s1, s2)
|
||||
_ = i
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkJaroWinklerSimilarity(b *testing.B) {
|
||||
s1 := "martha"
|
||||
s2 := "marhta"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := range b.N {
|
||||
_ = JaroWinklerSimilarity(s1, s2)
|
||||
_ = i
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMatcher_Match(b *testing.B) {
|
||||
m := New(2)
|
||||
candidates := []string{
|
||||
"getUserName", "getUsername", "get_user_name", "getUserId",
|
||||
"setUserName", "findUser", "userName", "usernameField",
|
||||
"userAccount", "accountUser", "userProfile", "profileUser",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := range b.N {
|
||||
_ = m.Match("getuser", candidates)
|
||||
_ = i
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
// Package protocol defines shared types used across the MCP file operations server.
|
||||
package protocol
|
||||
|
||||
// Location represents a position in a file.
|
||||
type Location struct {
|
||||
File string `json:"file"`
|
||||
Line int `json:"line"`
|
||||
Column int `json:"column"`
|
||||
}
|
||||
|
||||
// Range represents a range in a file.
|
||||
type Range struct {
|
||||
Start Location `json:"start"`
|
||||
End Location `json:"end"`
|
||||
}
|
||||
|
||||
// SymbolKind represents the kind of a symbol.
|
||||
type SymbolKind string
|
||||
|
||||
const (
|
||||
SymbolFunction SymbolKind = "function"
|
||||
SymbolMethod SymbolKind = "method"
|
||||
SymbolClass SymbolKind = "class"
|
||||
SymbolStruct SymbolKind = "struct"
|
||||
SymbolInterface SymbolKind = "interface"
|
||||
SymbolVariable SymbolKind = "variable"
|
||||
SymbolConstant SymbolKind = "constant"
|
||||
SymbolType SymbolKind = "type"
|
||||
SymbolField SymbolKind = "field"
|
||||
SymbolProperty SymbolKind = "property"
|
||||
SymbolModule SymbolKind = "module"
|
||||
SymbolPackage SymbolKind = "package"
|
||||
)
|
||||
|
||||
// Symbol represents a code symbol (function, class, variable, etc.).
|
||||
type Symbol struct {
|
||||
Name string `json:"name"`
|
||||
Kind SymbolKind `json:"kind"`
|
||||
Doc string `json:"doc,omitempty"`
|
||||
Location Location `json:"location"`
|
||||
}
|
||||
|
||||
// SyntaxError represents a syntax error in a file.
|
||||
type SyntaxError struct {
|
||||
Message string `json:"message"`
|
||||
Location Location `json:"location"`
|
||||
}
|
||||
|
||||
// Language represents a programming language.
|
||||
type Language string
|
||||
|
||||
const (
|
||||
LangGo Language = "go"
|
||||
LangTypeScript Language = "typescript"
|
||||
LangJavaScript Language = "javascript"
|
||||
LangPython Language = "python"
|
||||
LangC Language = "c"
|
||||
LangCpp Language = "cpp"
|
||||
LangHTML Language = "html"
|
||||
LangVue Language = "vue"
|
||||
LangJSON Language = "json"
|
||||
LangYAML Language = "yaml"
|
||||
LangUnknown Language = "unknown"
|
||||
)
|
||||
|
||||
// DetectLanguage detects the language from a filename.
|
||||
func DetectLanguage(filename string) Language {
|
||||
ext := getExtension(filename)
|
||||
switch ext {
|
||||
case ".go":
|
||||
return LangGo
|
||||
case ".ts", ".tsx":
|
||||
return LangTypeScript
|
||||
case ".js", ".jsx", ".mjs", ".cjs":
|
||||
return LangJavaScript
|
||||
case ".py", ".pyw":
|
||||
return LangPython
|
||||
case ".c", ".h":
|
||||
return LangC
|
||||
case ".cpp", ".cc", ".cxx", ".hpp", ".hxx":
|
||||
return LangCpp
|
||||
case ".html", ".htm":
|
||||
return LangHTML
|
||||
case ".vue":
|
||||
return LangVue
|
||||
case ".json":
|
||||
return LangJSON
|
||||
case ".yaml", ".yml":
|
||||
return LangYAML
|
||||
default:
|
||||
return LangUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func getExtension(filename string) string {
|
||||
for i := len(filename) - 1; i >= 0; i-- {
|
||||
if filename[i] == '.' {
|
||||
return filename[i:]
|
||||
}
|
||||
if filename[i] == '/' || filename[i] == '\\' {
|
||||
break
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package protocol
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDetectLanguage(t *testing.T) {
|
||||
tests := []struct {
|
||||
filename string
|
||||
expected Language
|
||||
}{
|
||||
{"main.go", LangGo},
|
||||
{"server.go", LangGo},
|
||||
{"index.ts", LangTypeScript},
|
||||
{"component.tsx", LangTypeScript},
|
||||
{"Button.tsx", LangTypeScript},
|
||||
{"app.js", LangJavaScript},
|
||||
{"component.jsx", LangJavaScript},
|
||||
{"Component.jsx", LangJavaScript},
|
||||
{"module.mjs", LangJavaScript},
|
||||
{"common.cjs", LangJavaScript},
|
||||
{"script.py", LangPython},
|
||||
{"app.pyw", LangPython},
|
||||
{"main.c", LangC},
|
||||
{"header.h", LangC},
|
||||
{"main.cpp", LangCpp},
|
||||
{"main.cc", LangCpp},
|
||||
{"main.cxx", LangCpp},
|
||||
{"header.hpp", LangCpp},
|
||||
{"header.hxx", LangCpp},
|
||||
{"index.html", LangHTML},
|
||||
{"page.htm", LangHTML},
|
||||
{"Component.vue", LangVue},
|
||||
{"unknown.txt", LangUnknown},
|
||||
{"README", LangUnknown},
|
||||
{"path/to/file.go", LangGo},
|
||||
{"path/to/file.ts", LangTypeScript},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.filename, func(t *testing.T) {
|
||||
result := DetectLanguage(tt.filename)
|
||||
if result != tt.expected {
|
||||
t.Errorf("DetectLanguage(%q) = %q, want %q", tt.filename, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExtension(t *testing.T) {
|
||||
tests := []struct {
|
||||
filename string
|
||||
expected string
|
||||
}{
|
||||
{"file.go", ".go"},
|
||||
{"file.test.go", ".go"},
|
||||
{"path/to/file.ts", ".ts"},
|
||||
{"noextension", ""},
|
||||
{".hidden", ".hidden"},
|
||||
{"file.", "."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.filename, func(t *testing.T) {
|
||||
result := getExtension(tt.filename)
|
||||
if result != tt.expected {
|
||||
t.Errorf("getExtension(%q) = %q, want %q", tt.filename, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Vendored
+55
@@ -0,0 +1,55 @@
|
||||
/**
|
||||
* @file header.h
|
||||
* @brief Sample header file for testing.
|
||||
*/
|
||||
|
||||
#ifndef HEADER_H
|
||||
#define HEADER_H
|
||||
|
||||
/**
|
||||
* @brief Maximum buffer size.
|
||||
*/
|
||||
#define MAX_BUFFER_SIZE 1024
|
||||
|
||||
/**
|
||||
* @brief Status codes for operations.
|
||||
*/
|
||||
typedef enum {
|
||||
STATUS_OK = 0,
|
||||
STATUS_ERROR = 1,
|
||||
STATUS_NOT_FOUND = 2
|
||||
} Status;
|
||||
|
||||
/**
|
||||
* @brief Buffer structure for data storage.
|
||||
*/
|
||||
typedef struct {
|
||||
char data[MAX_BUFFER_SIZE];
|
||||
int length;
|
||||
} Buffer;
|
||||
|
||||
/**
|
||||
* @brief Initialize a buffer.
|
||||
* @param buf Pointer to the buffer
|
||||
*/
|
||||
void buffer_init(Buffer* buf);
|
||||
|
||||
/**
|
||||
* @brief Write data to the buffer.
|
||||
* @param buf Pointer to the buffer
|
||||
* @param data Data to write
|
||||
* @param len Length of data
|
||||
* @return Status code
|
||||
*/
|
||||
Status buffer_write(Buffer* buf, const char* data, int len);
|
||||
|
||||
/**
|
||||
* @brief Read data from the buffer.
|
||||
* @param buf Pointer to the buffer
|
||||
* @param out Output buffer
|
||||
* @param max_len Maximum length to read
|
||||
* @return Number of bytes read
|
||||
*/
|
||||
int buffer_read(Buffer* buf, char* out, int max_len);
|
||||
|
||||
#endif /* HEADER_H */
|
||||
Vendored
+57
@@ -0,0 +1,57 @@
|
||||
/**
|
||||
* @file valid.c
|
||||
* @brief Sample C file for testing.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
/**
|
||||
* @brief A simple point structure.
|
||||
*/
|
||||
struct Point {
|
||||
int x;
|
||||
int y;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Creates a new point.
|
||||
* @param x The x coordinate
|
||||
* @param y The y coordinate
|
||||
* @return A new Point structure
|
||||
*/
|
||||
struct Point create_point(int x, int y) {
|
||||
struct Point p;
|
||||
p.x = x;
|
||||
p.y = y;
|
||||
return p;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the distance from origin.
|
||||
* @param p The point
|
||||
* @return The squared distance from origin
|
||||
*/
|
||||
int distance_squared(struct Point p) {
|
||||
return p.x * p.x + p.y * p.y;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Prints a point to stdout.
|
||||
* @param p The point to print
|
||||
*/
|
||||
void print_point(struct Point p) {
|
||||
printf("Point(%d, %d)\n", p.x, p.y);
|
||||
}
|
||||
|
||||
// Simple helper function
|
||||
int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
struct Point p = create_point(3, 4);
|
||||
print_point(p);
|
||||
printf("Distance squared: %d\n", distance_squared(p));
|
||||
return 0;
|
||||
}
|
||||
Vendored
+11
@@ -0,0 +1,11 @@
|
||||
package main
|
||||
|
||||
// This file contains intentional syntax errors for testing.
|
||||
|
||||
func broken( {
|
||||
return
|
||||
}
|
||||
|
||||
type Incomplete struct {
|
||||
Name string
|
||||
// Missing closing brace
|
||||
Vendored
+44
@@ -0,0 +1,44 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Server represents the main application server.
|
||||
type Server struct {
|
||||
Name string
|
||||
Port int
|
||||
}
|
||||
|
||||
// NewServer creates a new Server instance.
|
||||
func NewServer(name string, port int) *Server {
|
||||
return &Server{
|
||||
Name: name,
|
||||
Port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the server.
|
||||
func (s *Server) Start() error {
|
||||
fmt.Printf("Starting server %s on port %d\n", s.Name, s.Port)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Config holds application configuration.
|
||||
type Config struct {
|
||||
Debug bool
|
||||
Timeout int
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultPort is the default server port.
|
||||
DefaultPort = 8080
|
||||
)
|
||||
|
||||
var (
|
||||
// Version is the application version.
|
||||
Version = "1.0.0"
|
||||
)
|
||||
|
||||
func main() {
|
||||
srv := NewServer("main", DefaultPort)
|
||||
srv.Start()
|
||||
}
|
||||
Vendored
+29
@@ -0,0 +1,29 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Test HTML File</title>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container mx-auto px-4">
|
||||
<h1 class="text-3xl font-bold text-blue-600">Hello World</h1>
|
||||
<p class="text-gray-700 mt-4">This is a test HTML file with Tailwind CSS classes.</p>
|
||||
|
||||
<div class="flex gap-4 mt-8">
|
||||
<button class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded">
|
||||
Primary Button
|
||||
</button>
|
||||
<button class="bg-gray-500 hover:bg-gray-700 text-white font-bold py-2 px-4 rounded">
|
||||
Secondary Button
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<ul class="list-disc list-inside mt-4">
|
||||
<li>First item</li>
|
||||
<li>Second item</li>
|
||||
<li>Third item</li>
|
||||
</ul>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
Vendored
+62
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Sample Python module for testing.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
"""Processes data records."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
"""
|
||||
Initialize the processor.
|
||||
|
||||
Args:
|
||||
name: The processor name
|
||||
"""
|
||||
self.name = name
|
||||
self._records: List[dict] = []
|
||||
|
||||
def add_record(self, record: dict) -> None:
|
||||
"""
|
||||
Add a record to the processor.
|
||||
|
||||
Args:
|
||||
record: The record to add
|
||||
"""
|
||||
self._records.append(record)
|
||||
|
||||
def process(self) -> List[dict]:
|
||||
"""
|
||||
Process all records.
|
||||
|
||||
Returns:
|
||||
The processed records
|
||||
"""
|
||||
return [self._transform(r) for r in self._records]
|
||||
|
||||
def _transform(self, record: dict) -> dict:
|
||||
"""Transform a single record."""
|
||||
return {k.upper(): v for k, v in record.items()}
|
||||
|
||||
|
||||
def calculate_sum(numbers: List[int]) -> int:
|
||||
"""
|
||||
Calculate the sum of numbers.
|
||||
|
||||
Args:
|
||||
numbers: List of integers to sum
|
||||
|
||||
Returns:
|
||||
The sum of all numbers
|
||||
"""
|
||||
return sum(numbers)
|
||||
|
||||
|
||||
def find_maximum(values: List[int]) -> Optional[int]:
|
||||
"""Find the maximum value in a list."""
|
||||
return max(values) if values else None
|
||||
|
||||
|
||||
DEFAULT_BATCH_SIZE = 100
|
||||
Vendored
+129
@@ -0,0 +1,129 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
|
||||
interface ButtonProps {
|
||||
variant?: 'primary' | 'secondary';
|
||||
disabled?: boolean;
|
||||
onClick?: () => void;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* A reusable button component with Tailwind CSS styling
|
||||
*/
|
||||
export const Button: React.FC<ButtonProps> = ({
|
||||
variant = 'primary',
|
||||
disabled = false,
|
||||
onClick,
|
||||
children
|
||||
}) => {
|
||||
const baseClasses = 'font-bold py-2 px-4 rounded transition-colors duration-200';
|
||||
const variantClasses = {
|
||||
primary: 'bg-blue-500 hover:bg-blue-700 text-white',
|
||||
secondary: 'bg-gray-500 hover:bg-gray-700 text-white'
|
||||
};
|
||||
|
||||
return (
|
||||
<button
|
||||
className={`${baseClasses} ${variantClasses[variant]} ${disabled ? 'opacity-50 cursor-not-allowed' : ''}`}
|
||||
disabled={disabled}
|
||||
onClick={onClick}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
);
|
||||
};
|
||||
|
||||
interface TodoItem {
|
||||
id: number;
|
||||
text: string;
|
||||
completed: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Todo list component demonstrating React hooks and Tailwind
|
||||
*/
|
||||
export const TodoList: React.FC = () => {
|
||||
const [todos, setTodos] = useState<TodoItem[]>([
|
||||
{ id: 1, text: 'Learn React', completed: true },
|
||||
{ id: 2, text: 'Learn TypeScript', completed: true },
|
||||
{ id: 3, text: 'Build amazing apps', completed: false }
|
||||
]);
|
||||
const [inputValue, setInputValue] = useState('');
|
||||
|
||||
useEffect(() => {
|
||||
console.log('Todos updated:', todos);
|
||||
}, [todos]);
|
||||
|
||||
const addTodo = () => {
|
||||
if (inputValue.trim()) {
|
||||
const newTodo: TodoItem = {
|
||||
id: Date.now(),
|
||||
text: inputValue,
|
||||
completed: false
|
||||
};
|
||||
setTodos([...todos, newTodo]);
|
||||
setInputValue('');
|
||||
}
|
||||
};
|
||||
|
||||
const toggleTodo = (id: number) => {
|
||||
setTodos(todos.map(todo =>
|
||||
todo.id === id ? { ...todo, completed: !todo.completed } : todo
|
||||
));
|
||||
};
|
||||
|
||||
const deleteTodo = (id: number) => {
|
||||
setTodos(todos.filter(todo => todo.id !== id));
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="container mx-auto px-4 py-8 max-w-2xl">
|
||||
<h1 className="text-3xl font-bold text-gray-800 mb-6">
|
||||
My Todo List
|
||||
</h1>
|
||||
|
||||
<div className="flex gap-2 mb-6">
|
||||
<input
|
||||
type="text"
|
||||
value={inputValue}
|
||||
onChange={(e) => setInputValue(e.target.value)}
|
||||
onKeyPress={(e) => e.key === 'Enter' && addTodo()}
|
||||
placeholder="Add a new todo..."
|
||||
className="flex-1 px-4 py-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||
/>
|
||||
<Button onClick={addTodo}>Add</Button>
|
||||
</div>
|
||||
|
||||
<ul className="space-y-2">
|
||||
{todos.map(todo => (
|
||||
<li
|
||||
key={todo.id}
|
||||
className="flex items-center gap-3 p-4 bg-white rounded-lg shadow-sm hover:shadow-md transition-shadow"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={todo.completed}
|
||||
onChange={() => toggleTodo(todo.id)}
|
||||
className="w-5 h-5 text-blue-600 rounded focus:ring-2 focus:ring-blue-500"
|
||||
/>
|
||||
<span className={`flex-1 ${todo.completed ? 'line-through text-gray-400' : 'text-gray-700'}`}>
|
||||
{todo.text}
|
||||
</span>
|
||||
<Button
|
||||
variant="secondary"
|
||||
onClick={() => deleteTodo(todo.id)}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
|
||||
{todos.length === 0 && (
|
||||
<div className="text-center py-12 text-gray-400">
|
||||
No todos yet. Add one above!
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
Vendored
+53
@@ -0,0 +1,53 @@
|
||||
/**
|
||||
* Represents a user in the system.
|
||||
*/
|
||||
interface User {
|
||||
id: number;
|
||||
name: string;
|
||||
email: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration options for the application.
|
||||
*/
|
||||
type Config = {
|
||||
debug: boolean;
|
||||
timeout: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Greeting service for handling user greetings.
|
||||
*/
|
||||
class GreetingService {
|
||||
private prefix: string;
|
||||
|
||||
/**
|
||||
* Creates a new GreetingService.
|
||||
* @param prefix The greeting prefix
|
||||
*/
|
||||
constructor(prefix: string) {
|
||||
this.prefix = prefix;
|
||||
}
|
||||
|
||||
/**
|
||||
* Greets a user.
|
||||
* @param user The user to greet
|
||||
* @returns The greeting message
|
||||
*/
|
||||
greet(user: User): string {
|
||||
return `${this.prefix}, ${user.name}!`;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Formats a user for display.
|
||||
* @param user The user to format
|
||||
* @returns Formatted string
|
||||
*/
|
||||
function formatUser(user: User): string {
|
||||
return `${user.name} <${user.email}>`;
|
||||
}
|
||||
|
||||
const DEFAULT_TIMEOUT = 5000;
|
||||
|
||||
export { User, Config, GreetingService, formatUser, DEFAULT_TIMEOUT };
|
||||
Vendored
+76
@@ -0,0 +1,76 @@
|
||||
<template>
|
||||
<div class="container mx-auto px-4 py-8">
|
||||
<h1 class="text-3xl font-bold text-blue-600 mb-4">
|
||||
{{ title }}
|
||||
</h1>
|
||||
|
||||
<div v-if="showContent" class="bg-white shadow-md rounded-lg p-6">
|
||||
<p class="text-gray-700 mb-4">{{ description }}</p>
|
||||
|
||||
<div class="flex gap-4 mt-4">
|
||||
<button
|
||||
@click="handlePrimary"
|
||||
:class="['bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded', { 'opacity-50': isLoading }]"
|
||||
:disabled="isLoading"
|
||||
>
|
||||
{{ primaryButtonText }}
|
||||
</button>
|
||||
|
||||
<button
|
||||
@click="handleSecondary"
|
||||
class="bg-gray-500 hover:bg-gray-700 text-white font-bold py-2 px-4 rounded"
|
||||
>
|
||||
{{ secondaryButtonText }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<ul v-for="item in items" :key="item.id" class="list-disc list-inside mt-4">
|
||||
<li class="text-gray-600">{{ item.name }}</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div v-else class="text-center text-gray-500">
|
||||
<p>No content to display</p>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue';
|
||||
|
||||
interface Item {
|
||||
id: number;
|
||||
name: string;
|
||||
}
|
||||
|
||||
const title = ref('Vue Component with Tailwind');
|
||||
const description = ref('This is a sample Vue 3 component using Composition API and Tailwind CSS');
|
||||
const showContent = ref(true);
|
||||
const isLoading = ref(false);
|
||||
|
||||
const items = ref<Item[]>([
|
||||
{ id: 1, name: 'First item' },
|
||||
{ id: 2, name: 'Second item' },
|
||||
{ id: 3, name: 'Third item' },
|
||||
]);
|
||||
|
||||
const primaryButtonText = computed(() => isLoading.value ? 'Loading...' : 'Primary Action');
|
||||
const secondaryButtonText = ref('Secondary Action');
|
||||
|
||||
const handlePrimary = () => {
|
||||
isLoading.value = true;
|
||||
setTimeout(() => {
|
||||
isLoading.value = false;
|
||||
}, 2000);
|
||||
};
|
||||
|
||||
const handleSecondary = () => {
|
||||
console.log('Secondary button clicked');
|
||||
};
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
}
|
||||
</style>
|
||||
Reference in New Issue
Block a user