mirror of
https://github.com/lukaszraczylo/filepuff-mcp.git
synced 2026-06-05 22:23:50 +00:00
Ho hum.
This commit is contained in:
@@ -0,0 +1,27 @@
|
|||||||
|
name: Test, build, release
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
push:
|
||||||
|
paths-ignore:
|
||||||
|
- '**.md'
|
||||||
|
- '**/release.yaml'
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*'
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
id-token: write
|
||||||
|
contents: write
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release:
|
||||||
|
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||||
|
with:
|
||||||
|
go-version: "1.24"
|
||||||
|
docker-enabled: false
|
||||||
|
rolling-release-tag: "v1"
|
||||||
|
secrets: inherit
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
TODO.md
|
||||||
|
bin/mcp-filepuff
|
||||||
|
mcp-filepuff
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
# yaml-language-server: $schema=https://goreleaser.com/static/schema.json
|
||||||
|
# vim: set ts=2 sw=2 tw=0 fo=cnqoj
|
||||||
|
|
||||||
|
version: 2
|
||||||
|
|
||||||
|
project_name: mcp-filepuff
|
||||||
|
|
||||||
|
before:
|
||||||
|
hooks:
|
||||||
|
- go mod tidy
|
||||||
|
- go generate ./...
|
||||||
|
|
||||||
|
builds:
|
||||||
|
- id: mcp-filepuff
|
||||||
|
main: ./cmd/mcp-filepuff
|
||||||
|
binary: mcp-filepuff
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=0
|
||||||
|
flags:
|
||||||
|
- -trimpath
|
||||||
|
ldflags:
|
||||||
|
- -s -w
|
||||||
|
- -X main.version={{.Version}}
|
||||||
|
- -X main.commit={{.Commit}}
|
||||||
|
- -X main.date={{.Date}}
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
- darwin
|
||||||
|
- windows
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
|
||||||
|
archives:
|
||||||
|
- id: default
|
||||||
|
formats:
|
||||||
|
- tar.gz
|
||||||
|
name_template: >-
|
||||||
|
{{ .ProjectName }}_
|
||||||
|
{{- .Version }}_
|
||||||
|
{{- .Os }}_
|
||||||
|
{{- .Arch }}
|
||||||
|
files:
|
||||||
|
- LICENSE
|
||||||
|
- README.md
|
||||||
|
format_overrides:
|
||||||
|
- goos: windows
|
||||||
|
formats:
|
||||||
|
- zip
|
||||||
|
|
||||||
|
checksum:
|
||||||
|
name_template: 'checksums.txt'
|
||||||
|
|
||||||
|
snapshot:
|
||||||
|
version_template: "{{ incpatch .Version }}-next"
|
||||||
|
|
||||||
|
changelog:
|
||||||
|
sort: asc
|
||||||
|
filters:
|
||||||
|
exclude:
|
||||||
|
- '^docs:'
|
||||||
|
- '^test:'
|
||||||
|
- '^chore:'
|
||||||
|
- Merge pull request
|
||||||
|
- Merge branch
|
||||||
|
|
||||||
|
release:
|
||||||
|
github:
|
||||||
|
owner: lukaszraczylo
|
||||||
|
name: filepuff-mcp
|
||||||
|
draft: false
|
||||||
|
prerelease: auto
|
||||||
|
name_template: "v{{.Version}}"
|
||||||
|
header: |
|
||||||
|
## MCP Filepuff v{{.Version}}
|
||||||
|
|
||||||
|
AST-aware file operations and LSP integration for Claude Code.
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -sSL https://raw.githubusercontent.com/lukaszraczylo/filepuff-mcp/main/scripts/install.sh | bash
|
||||||
|
```
|
||||||
|
|
||||||
|
dockers_v2:
|
||||||
|
- images:
|
||||||
|
- "ghcr.io/lukaszraczylo/filepuff-mcp"
|
||||||
|
tags:
|
||||||
|
- "{{ .Version }}"
|
||||||
|
- "latest"
|
||||||
|
- "v1"
|
||||||
|
platforms:
|
||||||
|
- linux/amd64
|
||||||
|
- linux/arm64
|
||||||
|
dockerfile: Dockerfile.goreleaser
|
||||||
|
build_flag_templates:
|
||||||
|
- "--pull"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/lukaszraczylo/filepuff-mcp"
|
||||||
|
|
||||||
|
signs:
|
||||||
|
- cmd: cosign
|
||||||
|
signature: "${artifact}.sigstore.json"
|
||||||
|
args:
|
||||||
|
- sign-blob
|
||||||
|
- "--bundle=${signature}"
|
||||||
|
- "${artifact}"
|
||||||
|
- "--yes"
|
||||||
|
artifacts: checksum
|
||||||
|
output: true
|
||||||
|
|
||||||
|
docker_signs:
|
||||||
|
- cmd: cosign
|
||||||
|
artifacts: manifests
|
||||||
|
output: true
|
||||||
|
args:
|
||||||
|
- sign
|
||||||
|
- "${artifact}@${digest}"
|
||||||
|
- "--yes"
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
FROM alpine:3.21
|
||||||
|
ARG TARGETPLATFORM
|
||||||
|
|
||||||
|
RUN apk add --no-cache ca-certificates tzdata git ripgrep
|
||||||
|
|
||||||
|
COPY ${TARGETPLATFORM}/mcp-filepuff /usr/local/bin/mcp-filepuff
|
||||||
|
|
||||||
|
RUN chmod +x /usr/local/bin/mcp-filepuff
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
ENTRYPOINT ["/usr/local/bin/mcp-filepuff"]
|
||||||
|
CMD ["-workspace", "/workspace"]
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
.PHONY: build test lint clean install run deps
|
||||||
|
|
||||||
|
# Binary name
|
||||||
|
BINARY_NAME=mcp-filepuff
|
||||||
|
# Build directory
|
||||||
|
BUILD_DIR=bin
|
||||||
|
# Main package
|
||||||
|
MAIN_PKG=./cmd/mcp-filepuff
|
||||||
|
|
||||||
|
# Go parameters
|
||||||
|
GOCMD=go
|
||||||
|
GOBUILD=$(GOCMD) build
|
||||||
|
GOTEST=$(GOCMD) test
|
||||||
|
GOGET=$(GOCMD) get
|
||||||
|
GOMOD=$(GOCMD) mod
|
||||||
|
GOFMT=$(GOCMD) fmt
|
||||||
|
|
||||||
|
# Build flags
|
||||||
|
LDFLAGS=-ldflags "-s -w" -buildvcs=false
|
||||||
|
|
||||||
|
# Default target
|
||||||
|
all: deps test build
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
deps:
|
||||||
|
$(GOMOD) download
|
||||||
|
$(GOMOD) tidy
|
||||||
|
|
||||||
|
# Build the binary
|
||||||
|
build:
|
||||||
|
mkdir -p $(BUILD_DIR)
|
||||||
|
$(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME) $(MAIN_PKG)
|
||||||
|
|
||||||
|
# Build for all platforms
|
||||||
|
build-all:
|
||||||
|
mkdir -p $(BUILD_DIR)
|
||||||
|
GOOS=darwin GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-amd64 $(MAIN_PKG)
|
||||||
|
GOOS=darwin GOARCH=arm64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 $(MAIN_PKG)
|
||||||
|
GOOS=linux GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 $(MAIN_PKG)
|
||||||
|
GOOS=linux GOARCH=arm64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 $(MAIN_PKG)
|
||||||
|
GOOS=windows GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe $(MAIN_PKG)
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
test:
|
||||||
|
$(GOTEST) -v -race -coverprofile=coverage.out ./...
|
||||||
|
|
||||||
|
# Run tests with short flag
|
||||||
|
test-short:
|
||||||
|
$(GOTEST) -v -short ./...
|
||||||
|
|
||||||
|
# Clean build artifacts
|
||||||
|
clean:
|
||||||
|
rm -rf $(BUILD_DIR)
|
||||||
|
rm -f coverage.out
|
||||||
|
|
||||||
|
# Install binary to GOPATH/bin
|
||||||
|
install: build
|
||||||
|
cp $(BUILD_DIR)/$(BINARY_NAME) $(GOPATH)/bin/
|
||||||
|
|
||||||
|
# Run the server (for development)
|
||||||
|
run: build
|
||||||
|
./$(BUILD_DIR)/$(BINARY_NAME) -log-level debug
|
||||||
|
|
||||||
|
# Run with specific workspace
|
||||||
|
run-workspace: build
|
||||||
|
./$(BUILD_DIR)/$(BINARY_NAME) -workspace $(WORKSPACE) -log-level debug
|
||||||
|
|
||||||
|
# Show help
|
||||||
|
help:
|
||||||
|
@echo "Available targets:"
|
||||||
|
@echo " deps - Download and tidy dependencies"
|
||||||
|
@echo " build - Build the binary"
|
||||||
|
@echo " build-all - Build for all platforms"
|
||||||
|
@echo " test - Run tests with coverage"
|
||||||
|
@echo " test-short - Run short tests"
|
||||||
|
@echo " lint - Run linters"
|
||||||
|
@echo " clean - Clean build artifacts"
|
||||||
|
@echo " install - Install binary to GOPATH/bin"
|
||||||
|
@echo " run - Build and run the server"
|
||||||
|
@echo " run-workspace - Run with specific workspace (WORKSPACE=/path)"
|
||||||
@@ -0,0 +1,572 @@
|
|||||||
|
# mcp-filepuff
|
||||||
|
|
||||||
|
A Go-based MCP (Model Context Protocol) server for Claude Code providing intelligent file operations with fast search, AST-aware querying, LSP integration, and safe editing capabilities.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Fast Text Search**: Powered by ripgrep for blazing-fast code search with regex support
|
||||||
|
- **AST-Aware File Reading**: Read files with symbol extraction using Tree-sitter
|
||||||
|
- **Code Pattern Matching**: Query code using patterns with capture placeholders
|
||||||
|
- **LSP Integration**: Go-to-definition, find references, and symbol info via language servers
|
||||||
|
- **Safe Editing**: AST-aware file editing with syntax validation and preview
|
||||||
|
- **Multi-Language Support**: Go, TypeScript, JavaScript, Python, C, C++, HTML, Vue, React
|
||||||
|
- **Token Efficient**: Optimized for minimal token usage with symbols-only mode and output limiting
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
### Binary Releases (Recommended)
|
||||||
|
|
||||||
|
Download pre-built binaries from the [releases page](https://github.com/lukaszraczylo/filepuff-mcp/releases):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# macOS (ARM64)
|
||||||
|
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-darwin-arm64.tar.gz | tar xz
|
||||||
|
sudo mv mcp-filepuff /usr/local/bin/
|
||||||
|
|
||||||
|
# macOS (AMD64)
|
||||||
|
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-darwin-amd64.tar.gz | tar xz
|
||||||
|
sudo mv mcp-filepuff /usr/local/bin/
|
||||||
|
|
||||||
|
# Linux (ARM64)
|
||||||
|
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-linux-arm64.tar.gz | tar xz
|
||||||
|
sudo mv mcp-filepuff /usr/local/bin/
|
||||||
|
|
||||||
|
# Linux (AMD64)
|
||||||
|
curl -L https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-linux-amd64.tar.gz | tar xz
|
||||||
|
sudo mv mcp-filepuff /usr/local/bin/
|
||||||
|
|
||||||
|
# Windows (PowerShell)
|
||||||
|
Invoke-WebRequest -Uri "https://github.com/lukaszraczylo/filepuff-mcp/releases/latest/download/mcp-filepuff-windows-amd64.zip" -OutFile mcp-filepuff.zip
|
||||||
|
Expand-Archive mcp-filepuff.zip -DestinationPath .
|
||||||
|
Move-Item mcp-filepuff.exe C:\Windows\System32\
|
||||||
|
```
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- [ripgrep](https://github.com/BurntSushi/ripgrep) (`rg`) installed and in PATH
|
||||||
|
|
||||||
|
### Optional Dependencies (for LSP features)
|
||||||
|
|
||||||
|
- `gopls` - Go language server
|
||||||
|
- `typescript-language-server` - TypeScript/JavaScript language server
|
||||||
|
- `pylsp` - Python language server
|
||||||
|
- `clangd` - C/C++ language server
|
||||||
|
|
||||||
|
### Build from Source
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/lukaszraczylo/filepuff-mcp.git
|
||||||
|
cd filepuff-mcp
|
||||||
|
make build
|
||||||
|
```
|
||||||
|
|
||||||
|
The binary will be available at `bin/mcp-filepuff`.
|
||||||
|
|
||||||
|
### Install via Claude Code
|
||||||
|
|
||||||
|
After downloading or building the binary, configure it in Claude Code:
|
||||||
|
|
||||||
|
1. **Create or edit `~/.config/claude-code/claude_desktop_config.json`**:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"filepuff": {
|
||||||
|
"command": "/usr/local/bin/mcp-filepuff",
|
||||||
|
"args": ["-workspace", "/path/to/your/workspace"],
|
||||||
|
"env": {
|
||||||
|
"MCP_LOG_LEVEL": "info"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Restart Claude Code** to load the MCP server
|
||||||
|
|
||||||
|
3. **Verify** by asking Claude: "Can you ping the filepuff server?"
|
||||||
|
|
||||||
|
See the [Claude Code MCP documentation](https://code.claude.com/docs/en/mcp) for more details.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Running the Server (Standalone)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./bin/mcp-filepuff -workspace /path/to/workspace
|
||||||
|
```
|
||||||
|
|
||||||
|
### Command Line Options
|
||||||
|
|
||||||
|
- `-workspace string`: Workspace root directory (default: current directory)
|
||||||
|
- `-log-level string`: Log level - debug, info, warn, error (default: "info")
|
||||||
|
- `-log-file string`: Log file path (default: stderr)
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
The server can be configured via:
|
||||||
|
|
||||||
|
1. **Environment Variables**:
|
||||||
|
- `MCP_WORKSPACE_ROOT`: Workspace root directory
|
||||||
|
- `MCP_LSP_TIMEOUT`: LSP timeout duration (e.g., "10m")
|
||||||
|
- `MCP_SEARCH_TIMEOUT`: Search timeout duration (e.g., "1m")
|
||||||
|
- `MCP_ENABLE_LSP`: Enable LSP features ("true"/"false")
|
||||||
|
- `MCP_FOLLOW_SYMLINKS`: Follow symbolic links ("true"/"false")
|
||||||
|
- `MCP_RESPECT_GITIGNORE`: Respect .gitignore files ("true"/"false")
|
||||||
|
|
||||||
|
2. **Config File**: Create `.mcp-filepuff.json` in the workspace root:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"enable_lsp": true,
|
||||||
|
"follow_symlinks": true,
|
||||||
|
"respect_gitignore": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Claude Code Integration
|
||||||
|
|
||||||
|
To use mcp-filepuff with Claude Code, add it to your MCP server configuration:
|
||||||
|
|
||||||
|
1. **Global Configuration** (`~/.config/claude-code/mcp_servers.json`):
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"filepuff": {
|
||||||
|
"command": "/path/to/mcp-filepuff",
|
||||||
|
"args": ["-workspace", "/path/to/your/workspace"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Project-specific Configuration** (`.claude/mcp_servers.json` in your project):
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"filepuff": {
|
||||||
|
"command": "mcp-filepuff",
|
||||||
|
"args": ["-workspace", "."]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
After configuration, Claude Code will have access to all mcp-filepuff tools for enhanced file operations.
|
||||||
|
|
||||||
|
### Making Claude Code Prefer Filepuff Tools
|
||||||
|
|
||||||
|
By default, Claude Code uses its built-in file operation tools. To make it prefer filepuff's enhanced tools, add instructions to your `CLAUDE.md` file:
|
||||||
|
|
||||||
|
**Global Configuration** (`~/.claude/CLAUDE.md`):
|
||||||
|
```markdown
|
||||||
|
# MCP Tool Preferences
|
||||||
|
|
||||||
|
When performing file operations, prefer filepuff MCP tools over built-in equivalents:
|
||||||
|
|
||||||
|
| Operation | Use This | Instead Of |
|
||||||
|
|-----------|----------|------------|
|
||||||
|
| Read files | `mcp__filepuff__file_read` | Read |
|
||||||
|
| Search content | `mcp__filepuff__file_search` | Grep |
|
||||||
|
| AST pattern search | `mcp__filepuff__ast_query` | Grep/Glob |
|
||||||
|
| Edit files | `mcp__filepuff__edit_preview` + `mcp__filepuff__edit_apply` | Edit |
|
||||||
|
| Find definitions | `mcp__filepuff__find_definition` | Grep |
|
||||||
|
| Find references | `mcp__filepuff__find_references` | Grep |
|
||||||
|
| Symbol info | `mcp__filepuff__symbol_at` | - |
|
||||||
|
|
||||||
|
Benefits of filepuff tools:
|
||||||
|
- AST-aware operations that understand code structure
|
||||||
|
- LSP integration for accurate symbol navigation
|
||||||
|
- Syntax validation before applying edits
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also place this in a project-specific `CLAUDE.md` or `.claude/CLAUDE.md` file.
|
||||||
|
|
||||||
|
**Optional: Restrict Built-in Tools**
|
||||||
|
|
||||||
|
To enforce filepuff usage, add permission restrictions in `.claude/settings.json`:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"permissions": {
|
||||||
|
"deny": ["Read", "Edit", "Grep"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Available Tools
|
||||||
|
|
||||||
|
### `ping`
|
||||||
|
Health check tool to verify the server is running.
|
||||||
|
|
||||||
|
**Returns**: "pong"
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `file_search`
|
||||||
|
Search for text patterns in files using ripgrep.
|
||||||
|
|
||||||
|
**Parameters**:
|
||||||
|
- `pattern` (required): The search pattern (regex by default)
|
||||||
|
- `paths`: Paths to search in (defaults to workspace root)
|
||||||
|
- `file_types`: File types to search (e.g., ["go", "ts", "py"])
|
||||||
|
- `ignore_case`: Case insensitive search
|
||||||
|
- `regex`: Treat pattern as regex (default: true)
|
||||||
|
- `context_lines`: Number of context lines around matches (default: 2)
|
||||||
|
- `max_results`: Maximum number of results to return
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `file_read`
|
||||||
|
Read a file's contents with optional line range and AST symbol summary. Supports token-efficient modes for AI assistants.
|
||||||
|
|
||||||
|
**Parameters**:
|
||||||
|
- `path` (required): Path to the file to read
|
||||||
|
- `line_start`: Starting line number (1-indexed)
|
||||||
|
- `line_end`: Ending line number (inclusive)
|
||||||
|
- `include_ast`: Include AST symbol summary (functions, classes, types, etc.)
|
||||||
|
- `symbols_only`: **[Token Efficient]** Return only symbol summary without file content. Requires `include_ast=true`. Reduces token usage by ~90-98%.
|
||||||
|
- `max_lines`: **[Token Efficient]** Maximum number of lines to return. Useful for large files where you only need a preview.
|
||||||
|
|
||||||
|
**Example Output with AST**:
|
||||||
|
```
|
||||||
|
**server.go** (245 lines, go)
|
||||||
|
|
||||||
|
Symbols:
|
||||||
|
func NewServer L12
|
||||||
|
func (Server).Start L45
|
||||||
|
struct Server L5
|
||||||
|
type Config L150
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
12│ func NewServer(config Config) *Server {
|
||||||
|
13│ return &Server{config: config}
|
||||||
|
14│ }
|
||||||
|
```
|
||||||
|
|
||||||
|
**Token-Efficient Example (symbols_only)**:
|
||||||
|
```json
|
||||||
|
{"path": "server.go", "include_ast": true, "symbols_only": true}
|
||||||
|
```
|
||||||
|
Returns only the symbol summary (~500 tokens instead of ~8,000 tokens for the full file):
|
||||||
|
```
|
||||||
|
**server.go** (245 lines, go)
|
||||||
|
|
||||||
|
Symbols:
|
||||||
|
func NewServer L12
|
||||||
|
func (Server).Start L45
|
||||||
|
struct Server L5
|
||||||
|
type Config L150
|
||||||
|
```
|
||||||
|
|
||||||
|
**Token-Efficient Example (max_lines)**:
|
||||||
|
```json
|
||||||
|
{"path": "server.go", "max_lines": 50}
|
||||||
|
```
|
||||||
|
Returns first 50 lines with a truncation notice if the file is longer.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `ast_query`
|
||||||
|
Search for AST patterns in code files using structural pattern matching.
|
||||||
|
|
||||||
|
**Parameters**:
|
||||||
|
- `pattern` (required): Code pattern with placeholders
|
||||||
|
- `$NAME` - capture single node
|
||||||
|
- `$$$ARGS` - capture multiple nodes
|
||||||
|
- `$_` - wildcard (match but don't capture)
|
||||||
|
- `language` (required): Target language (go, typescript, javascript, python, c, cpp)
|
||||||
|
- `paths`: Paths to search in
|
||||||
|
- `name_matches`: Regex pattern to filter by name
|
||||||
|
- `name_exact`: Exact name to match
|
||||||
|
- `kind_in`: Node types to match (e.g., function_declaration)
|
||||||
|
- `max_results`: Maximum number of results (default: 100)
|
||||||
|
|
||||||
|
**Examples**:
|
||||||
|
```json
|
||||||
|
// Find all Go functions returning error
|
||||||
|
{"pattern": "func $NAME($$$ARGS) error", "language": "go"}
|
||||||
|
|
||||||
|
// Find all Python classes
|
||||||
|
{"pattern": "class $NAME: $$$BODY", "language": "python"}
|
||||||
|
|
||||||
|
// Find React components (functions starting with uppercase)
|
||||||
|
{"pattern": "function $NAME($PROPS) { $$$BODY }", "language": "javascript", "name_matches": "^[A-Z]"}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `symbol_at`
|
||||||
|
Get information about the symbol at a specific position. Uses LSP when available, falls back to AST.
|
||||||
|
|
||||||
|
**Parameters**:
|
||||||
|
- `file` (required): Path to the file
|
||||||
|
- `line` (required): Line number (1-indexed)
|
||||||
|
- `column` (required): Column number (1-indexed)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `find_definition`
|
||||||
|
Find the definition of the symbol at a specific position.
|
||||||
|
|
||||||
|
**Parameters**:
|
||||||
|
- `file` (required): Path to the file
|
||||||
|
- `line` (required): Line number (1-indexed)
|
||||||
|
- `column` (required): Column number (1-indexed)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `find_references`
|
||||||
|
Find all references to the symbol at a specific position.
|
||||||
|
|
||||||
|
**Parameters**:
|
||||||
|
- `file` (required): Path to the file
|
||||||
|
- `line` (required): Line number (1-indexed)
|
||||||
|
- `column` (required): Column number (1-indexed)
|
||||||
|
- `include_declaration`: Include the declaration in results (default: true)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `edit_preview`
|
||||||
|
Preview an edit without applying it. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++), and text-based editing for other files (Markdown, JSON, YAML, config files, etc.).
|
||||||
|
|
||||||
|
**Parameters**:
|
||||||
|
- `file` (required): Path to the file to edit
|
||||||
|
- `operation` (required): Edit operation (replace, insert_before, insert_after, delete)
|
||||||
|
- `new_content`: New content (required for replace/insert operations)
|
||||||
|
|
||||||
|
**AST-mode selectors** (for code files):
|
||||||
|
- `selector_kind`: Node type to match (e.g., function_declaration)
|
||||||
|
- `selector_name`: Name of the symbol to match
|
||||||
|
|
||||||
|
**Shared selectors**:
|
||||||
|
- `selector_line`: Line number (1-indexed). For AST mode: narrows search. For text mode: start of line range.
|
||||||
|
- `selector_index`: Index of the match to use if multiple matches found (default: 0)
|
||||||
|
|
||||||
|
**Text-mode selectors** (for non-code files or explicit text matching):
|
||||||
|
- `selector_line_end`: End line number for range selection
|
||||||
|
- `selector_text`: Exact text to match (must be unique or use selector_index)
|
||||||
|
- `selector_pattern`: Regex pattern to match
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `edit_apply`
|
||||||
|
Apply an edit to a file. Uses AST-aware editing for code files with syntax validation, and text-based editing for other files.
|
||||||
|
|
||||||
|
**Parameters**: Same as `edit_preview`
|
||||||
|
|
||||||
|
**Example (AST mode - Go file)**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"file": "server.go",
|
||||||
|
"operation": "replace",
|
||||||
|
"selector_kind": "function_declaration",
|
||||||
|
"selector_name": "Hello",
|
||||||
|
"new_content": "func Hello() {\n\tprintln(\"New Hello\")\n}"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example (Text mode - Markdown file)**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"file": "README.md",
|
||||||
|
"operation": "replace",
|
||||||
|
"selector_text": "## Installation",
|
||||||
|
"new_content": "## Getting Started"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example (Text mode - JSON with regex)**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"file": "package.json",
|
||||||
|
"operation": "replace",
|
||||||
|
"selector_pattern": "\"version\":\\s*\"[^\"]+\"",
|
||||||
|
"new_content": "\"version\": \"2.0.0\""
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example (Text mode - Line range)**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"file": "config.yaml",
|
||||||
|
"operation": "replace",
|
||||||
|
"selector_line": 5,
|
||||||
|
"selector_line_end": 10,
|
||||||
|
"new_content": "database:\n host: production.db.example.com\n port: 5432"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Languages
|
||||||
|
|
||||||
|
| Language | Extensions | Search | AST | LSP | Edit |
|
||||||
|
|----------|-----------|--------|-----|-----|------|
|
||||||
|
| Go | .go | Yes | Yes | gopls | Yes |
|
||||||
|
| TypeScript | .ts, .tsx | Yes | Yes | typescript-language-server | Yes |
|
||||||
|
| JavaScript | .js, .jsx, .mjs, .cjs | Yes | Yes | typescript-language-server | Yes |
|
||||||
|
| Python | .py, .pyw | Yes | Yes | pylsp | Yes |
|
||||||
|
| C | .c, .h | Yes | Yes | clangd | Yes |
|
||||||
|
| C++ | .cpp, .cc, .cxx, .hpp, .hxx | Yes | Yes | clangd | Yes |
|
||||||
|
| HTML | .html, .htm | Yes | Yes | - | Yes |
|
||||||
|
| Vue | .vue | Yes | Yes* | - | Yes |
|
||||||
|
| React | .jsx, .tsx | Yes | Yes | typescript-language-server | Yes |
|
||||||
|
|
||||||
|
\* Vue uses HTML parser for template sections
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
### Build
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make build
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make test
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lint
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make lint
|
||||||
|
```
|
||||||
|
|
||||||
|
### Clean
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make clean
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
.
|
||||||
|
├── cmd/
|
||||||
|
│ └── mcp-filepuff/ # Main entry point
|
||||||
|
├── internal/
|
||||||
|
│ ├── config/ # Configuration management
|
||||||
|
│ ├── edit/ # AST-aware editing engine
|
||||||
|
│ ├── lsp/ # LSP client and manager
|
||||||
|
│ ├── parser/ # Tree-sitter integration
|
||||||
|
│ ├── query/ # AST pattern matching
|
||||||
|
│ ├── search/ # Ripgrep wrapper
|
||||||
|
│ └── server/ # MCP server implementation
|
||||||
|
├── pkg/
|
||||||
|
│ └── protocol/ # Shared types
|
||||||
|
├── .github/
|
||||||
|
│ └── workflows/ # CI configuration
|
||||||
|
├── Makefile # Build automation
|
||||||
|
├── .goreleaser.yaml # Release configuration
|
||||||
|
└── TODO.md # Implementation roadmap
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────┐
|
||||||
|
│ MCP Server │
|
||||||
|
├─────────────────────────────────────────────────────────┤
|
||||||
|
│ Tools: file_search, file_read, ast_query, symbol_at, │
|
||||||
|
│ find_definition, find_references, │
|
||||||
|
│ edit_preview, edit_apply, ping │
|
||||||
|
├─────────────────────────────────────────────────────────┤
|
||||||
|
│ Core Engines │
|
||||||
|
├───────────┬─────────────┬────────────┬─────────────────┤
|
||||||
|
│ Search │ Parser │ LSP │ Edit │
|
||||||
|
│ (ripgrep) │(tree-sitter)│ Manager │ Engine │
|
||||||
|
└───────────┴─────────────┴────────────┴─────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
#### "ripgrep not found" Error
|
||||||
|
The `file_search` tool requires ripgrep (`rg`) to be installed and in your PATH.
|
||||||
|
|
||||||
|
**Solution**: Install ripgrep:
|
||||||
|
```bash
|
||||||
|
# macOS
|
||||||
|
brew install ripgrep
|
||||||
|
|
||||||
|
# Ubuntu/Debian
|
||||||
|
sudo apt install ripgrep
|
||||||
|
|
||||||
|
# Windows (with Chocolatey)
|
||||||
|
choco install ripgrep
|
||||||
|
```
|
||||||
|
|
||||||
|
#### LSP Features Not Working
|
||||||
|
LSP features (go-to-definition, find-references, symbol-at) require language servers to be installed.
|
||||||
|
|
||||||
|
**Solution**: Install the appropriate language server:
|
||||||
|
```bash
|
||||||
|
# Go
|
||||||
|
go install golang.org/x/tools/gopls@latest
|
||||||
|
|
||||||
|
# TypeScript/JavaScript
|
||||||
|
npm install -g typescript-language-server typescript
|
||||||
|
|
||||||
|
# Python
|
||||||
|
pip install python-lsp-server
|
||||||
|
|
||||||
|
# C/C++
|
||||||
|
# macOS: brew install llvm
|
||||||
|
# Ubuntu: sudo apt install clangd
|
||||||
|
```
|
||||||
|
|
||||||
|
#### AST Parsing Fails for Valid Code
|
||||||
|
If AST parsing fails for code that compiles correctly, it may be a Tree-sitter grammar limitation.
|
||||||
|
|
||||||
|
**Solution**:
|
||||||
|
- Ensure the file has the correct extension for its language
|
||||||
|
- Check for unusual syntax that may not be supported by the Tree-sitter grammar
|
||||||
|
- Try using the `file_search` tool instead for text-based operations
|
||||||
|
|
||||||
|
#### Edit Operations Fail with "syntax error"
|
||||||
|
The edit engine validates syntax before and after edits.
|
||||||
|
|
||||||
|
**Solution**:
|
||||||
|
- Ensure `new_content` is syntactically valid for the target language
|
||||||
|
- Use `edit_preview` first to see the proposed changes
|
||||||
|
- Check that the selector matches exactly one node
|
||||||
|
|
||||||
|
#### Timeout Errors
|
||||||
|
Long-running operations may timeout.
|
||||||
|
|
||||||
|
**Solution**: Configure timeout values via environment variables:
|
||||||
|
```bash
|
||||||
|
export MCP_LSP_TIMEOUT="10m" # LSP operations (default: 5m)
|
||||||
|
export MCP_SEARCH_TIMEOUT="2m" # Search operations (default: 30s)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Permission Denied Errors
|
||||||
|
The server needs read/write access to workspace files.
|
||||||
|
|
||||||
|
**Solution**:
|
||||||
|
- Ensure the user running the server has appropriate file permissions
|
||||||
|
- Check that the workspace path is correct and accessible
|
||||||
|
- On macOS, grant terminal/IDE full disk access if needed
|
||||||
|
|
||||||
|
### Debug Logging
|
||||||
|
|
||||||
|
Enable debug logging to troubleshoot issues:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./bin/mcp-filepuff -workspace /path/to/workspace -log-level debug -log-file /tmp/mcp-filepuff.log
|
||||||
|
```
|
||||||
|
|
||||||
|
### Verifying Installation
|
||||||
|
|
||||||
|
Use the `ping` tool to verify the server is running correctly:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"tool": "ping"}
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected response: `"pong"`
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
// Package main is the entry point for the MCP file operations server.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Parse command line flags
|
||||||
|
var (
|
||||||
|
workspaceRoot = flag.String("workspace", "", "Workspace root directory (default: current directory)")
|
||||||
|
logLevel = flag.String("log-level", "info", "Log level (debug, info, warn, error)")
|
||||||
|
logFile = flag.String("log-file", "", "Log file path (default: stderr)")
|
||||||
|
)
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Set up logging
|
||||||
|
logger := setupLogger(*logLevel, *logFile)
|
||||||
|
|
||||||
|
// Load configuration
|
||||||
|
cfg, err := config.Load(*workspaceRoot)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to load configuration", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("configuration loaded",
|
||||||
|
"workspace_root", cfg.WorkspaceRoot,
|
||||||
|
"lsp_enabled", cfg.EnableLSP,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create and run server
|
||||||
|
srv, err := server.New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to create server", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := srv.Run(ctx); err != nil {
|
||||||
|
logger.Error("server error", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupLogger(level string, logFile string) *slog.Logger {
|
||||||
|
var logLevel slog.Level
|
||||||
|
switch level {
|
||||||
|
case "debug":
|
||||||
|
logLevel = slog.LevelDebug
|
||||||
|
case "warn":
|
||||||
|
logLevel = slog.LevelWarn
|
||||||
|
case "error":
|
||||||
|
logLevel = slog.LevelError
|
||||||
|
default:
|
||||||
|
logLevel = slog.LevelInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := &slog.HandlerOptions{
|
||||||
|
Level: logLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
var handler slog.Handler
|
||||||
|
if logFile != "" {
|
||||||
|
f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
|
||||||
|
if err != nil {
|
||||||
|
// Fallback to stderr
|
||||||
|
handler = slog.NewJSONHandler(os.Stderr, opts)
|
||||||
|
} else {
|
||||||
|
handler = slog.NewJSONHandler(f, opts)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Use stderr for MCP servers (stdout is for protocol messages)
|
||||||
|
handler = slog.NewJSONHandler(os.Stderr, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
return slog.New(handler)
|
||||||
|
}
|
||||||
+1479
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,24 @@
|
|||||||
|
module github.com/lukaszraczylo/mcp-filepuff
|
||||||
|
|
||||||
|
go 1.25.5
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0
|
||||||
|
github.com/goccy/go-json v0.10.5
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||||
|
github.com/mark3labs/mcp-go v0.43.2
|
||||||
|
github.com/sergi/go-diff v1.4.0
|
||||||
|
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82
|
||||||
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||||
|
github.com/buger/jsonparser v1.1.1 // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/invopop/jsonschema v0.13.0 // indirect
|
||||||
|
github.com/mailru/easyjson v0.9.1 // indirect
|
||||||
|
github.com/spf13/cast v1.10.0 // indirect
|
||||||
|
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||||
|
)
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||||
|
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||||
|
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||||
|
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
|
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||||
|
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
|
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||||
|
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
|
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
|
||||||
|
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||||
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8=
|
||||||
|
github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||||
|
github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I=
|
||||||
|
github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||||
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
|
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
||||||
|
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
||||||
|
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4=
|
||||||
|
github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw=
|
||||||
|
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||||
|
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||||
|
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
@@ -0,0 +1,174 @@
|
|||||||
|
// Package config provides configuration management for the MCP file operations server.
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
json "github.com/goccy/go-json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds all configuration options for the MCP server.
|
||||||
|
type Config struct {
|
||||||
|
Formatters map[string]string `json:"formatters"`
|
||||||
|
WorkspaceRoot string `json:"workspace_root"`
|
||||||
|
LSPTimeout time.Duration `json:"lsp_timeout"`
|
||||||
|
SearchTimeout time.Duration `json:"search_timeout"`
|
||||||
|
MaxFileSize int64 `json:"max_file_size"`
|
||||||
|
MaxSearchResults int `json:"max_search_results"`
|
||||||
|
MaxEditSize int64 `json:"max_edit_size"`
|
||||||
|
EnableLSP bool `json:"enable_lsp"`
|
||||||
|
FollowSymlinks bool `json:"follow_symlinks"`
|
||||||
|
RespectGitignore bool `json:"respect_gitignore"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default values for configuration.
|
||||||
|
const (
|
||||||
|
DefaultLSPTimeout = 5 * time.Minute
|
||||||
|
DefaultSearchTimeout = 30 * time.Second
|
||||||
|
DefaultMaxFileSize = 10 * 1024 * 1024 // 10 MB
|
||||||
|
DefaultMaxSearchResults = 1000
|
||||||
|
DefaultMaxEditSize = 100 * 1024 // 100 KB
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default returns a Config with default values.
|
||||||
|
func Default() *Config {
|
||||||
|
return &Config{
|
||||||
|
WorkspaceRoot: ".",
|
||||||
|
LSPTimeout: DefaultLSPTimeout,
|
||||||
|
SearchTimeout: DefaultSearchTimeout,
|
||||||
|
MaxFileSize: DefaultMaxFileSize,
|
||||||
|
MaxSearchResults: DefaultMaxSearchResults,
|
||||||
|
MaxEditSize: DefaultMaxEditSize,
|
||||||
|
EnableLSP: true,
|
||||||
|
Formatters: make(map[string]string),
|
||||||
|
FollowSymlinks: true,
|
||||||
|
RespectGitignore: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load loads configuration from environment variables and optional config file.
|
||||||
|
// Priority: CLI flags > environment variables > config file > defaults.
|
||||||
|
func Load(workspaceRoot string) (*Config, error) {
|
||||||
|
cfg := Default()
|
||||||
|
|
||||||
|
// Set workspace root
|
||||||
|
if workspaceRoot != "" {
|
||||||
|
absPath, err := filepath.Abs(workspaceRoot)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg.WorkspaceRoot = absPath
|
||||||
|
} else if cwd, err := os.Getwd(); err == nil {
|
||||||
|
cfg.WorkspaceRoot = cwd
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to load from config file in workspace root
|
||||||
|
configPath := filepath.Join(cfg.WorkspaceRoot, ".mcp-filepuff.json")
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
if err := json.Unmarshal(data, cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override from environment variables
|
||||||
|
cfg.loadFromEnv()
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) loadFromEnv() {
|
||||||
|
if v := os.Getenv("MCP_WORKSPACE_ROOT"); v != "" {
|
||||||
|
if absPath, err := filepath.Abs(v); err == nil {
|
||||||
|
c.WorkspaceRoot = absPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := os.Getenv("MCP_LSP_TIMEOUT"); v != "" {
|
||||||
|
if d, err := time.ParseDuration(v); err == nil {
|
||||||
|
c.LSPTimeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := os.Getenv("MCP_SEARCH_TIMEOUT"); v != "" {
|
||||||
|
if d, err := time.ParseDuration(v); err == nil {
|
||||||
|
c.SearchTimeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := os.Getenv("MCP_ENABLE_LSP"); v == "false" || v == "0" {
|
||||||
|
c.EnableLSP = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := os.Getenv("MCP_FOLLOW_SYMLINKS"); v == "false" || v == "0" {
|
||||||
|
c.FollowSymlinks = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := os.Getenv("MCP_RESPECT_GITIGNORE"); v == "false" || v == "0" {
|
||||||
|
c.RespectGitignore = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPathAllowed checks if a path is within the workspace root.
|
||||||
|
// It resolves symlinks to prevent path traversal attacks.
|
||||||
|
func (c *Config) IsPathAllowed(path string) bool {
|
||||||
|
// Get absolute path of the target
|
||||||
|
absPath, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get absolute path of workspace root
|
||||||
|
absRoot, err := filepath.Abs(c.WorkspaceRoot)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always try to resolve workspace root symlinks for consistent comparison
|
||||||
|
evalRoot, evalErr := filepath.EvalSymlinks(absRoot)
|
||||||
|
if evalErr == nil {
|
||||||
|
absRoot = evalRoot
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the target path, try to resolve symlinks
|
||||||
|
evalPath, evalErr := filepath.EvalSymlinks(absPath)
|
||||||
|
if evalErr == nil {
|
||||||
|
// File exists and was resolved
|
||||||
|
absPath = evalPath
|
||||||
|
} else {
|
||||||
|
// File doesn't exist - resolve parent directories to match workspace root resolution
|
||||||
|
// Walk up the tree until we find an existing directory
|
||||||
|
dir := filepath.Dir(absPath)
|
||||||
|
remaining := filepath.Base(absPath)
|
||||||
|
|
||||||
|
for dir != "." && dir != "/" && dir != absPath {
|
||||||
|
evalDir, evalErr := filepath.EvalSymlinks(dir)
|
||||||
|
if evalErr == nil {
|
||||||
|
// Found an existing directory, reconstruct the path
|
||||||
|
absPath = filepath.Join(evalDir, remaining)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Move up one level
|
||||||
|
newDir := filepath.Dir(dir)
|
||||||
|
if newDir == dir {
|
||||||
|
// Reached the root without finding an existing directory
|
||||||
|
break
|
||||||
|
}
|
||||||
|
remaining = filepath.Join(filepath.Base(dir), remaining)
|
||||||
|
dir = newDir
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute relative path
|
||||||
|
rel, err := filepath.Rel(absRoot, absPath)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the path is within workspace (doesn't start with ..)
|
||||||
|
// This prevents both "../" attacks and symlink bypasses
|
||||||
|
// Also reject empty relative path (which means it's the workspace root itself)
|
||||||
|
return rel != "." && !strings.HasPrefix(rel, "..")
|
||||||
|
}
|
||||||
@@ -0,0 +1,184 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefault(t *testing.T) {
|
||||||
|
cfg := Default()
|
||||||
|
|
||||||
|
if cfg.WorkspaceRoot != "." {
|
||||||
|
t.Errorf("expected default workspace root '.', got %q", cfg.WorkspaceRoot)
|
||||||
|
}
|
||||||
|
if cfg.LSPTimeout != DefaultLSPTimeout {
|
||||||
|
t.Errorf("expected default LSP timeout %v, got %v", DefaultLSPTimeout, cfg.LSPTimeout)
|
||||||
|
}
|
||||||
|
if cfg.SearchTimeout != DefaultSearchTimeout {
|
||||||
|
t.Errorf("expected default search timeout %v, got %v", DefaultSearchTimeout, cfg.SearchTimeout)
|
||||||
|
}
|
||||||
|
if cfg.MaxFileSize != DefaultMaxFileSize {
|
||||||
|
t.Errorf("expected default max file size %d, got %d", DefaultMaxFileSize, cfg.MaxFileSize)
|
||||||
|
}
|
||||||
|
if cfg.MaxSearchResults != DefaultMaxSearchResults {
|
||||||
|
t.Errorf("expected default max search results %d, got %d", DefaultMaxSearchResults, cfg.MaxSearchResults)
|
||||||
|
}
|
||||||
|
if cfg.MaxEditSize != DefaultMaxEditSize {
|
||||||
|
t.Errorf("expected default max edit size %d, got %d", DefaultMaxEditSize, cfg.MaxEditSize)
|
||||||
|
}
|
||||||
|
if !cfg.EnableLSP {
|
||||||
|
t.Error("expected EnableLSP to be true by default")
|
||||||
|
}
|
||||||
|
if !cfg.FollowSymlinks {
|
||||||
|
t.Error("expected FollowSymlinks to be true by default")
|
||||||
|
}
|
||||||
|
if !cfg.RespectGitignore {
|
||||||
|
t.Error("expected RespectGitignore to be true by default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoad(t *testing.T) {
|
||||||
|
// Create a temporary directory for workspace
|
||||||
|
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
cfg, err := Load(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
absPath, _ := filepath.Abs(tmpDir)
|
||||||
|
if cfg.WorkspaceRoot != absPath {
|
||||||
|
t.Errorf("expected workspace root %q, got %q", absPath, cfg.WorkspaceRoot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadFromEnv(t *testing.T) {
|
||||||
|
// Save original env values
|
||||||
|
origLSPTimeout := os.Getenv("MCP_LSP_TIMEOUT")
|
||||||
|
origSearchTimeout := os.Getenv("MCP_SEARCH_TIMEOUT")
|
||||||
|
origEnableLSP := os.Getenv("MCP_ENABLE_LSP")
|
||||||
|
origFollowSymlinks := os.Getenv("MCP_FOLLOW_SYMLINKS")
|
||||||
|
origRespectGitignore := os.Getenv("MCP_RESPECT_GITIGNORE")
|
||||||
|
|
||||||
|
// Restore env after test
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.Setenv("MCP_LSP_TIMEOUT", origLSPTimeout)
|
||||||
|
_ = os.Setenv("MCP_SEARCH_TIMEOUT", origSearchTimeout)
|
||||||
|
_ = os.Setenv("MCP_ENABLE_LSP", origEnableLSP)
|
||||||
|
_ = os.Setenv("MCP_FOLLOW_SYMLINKS", origFollowSymlinks)
|
||||||
|
_ = os.Setenv("MCP_RESPECT_GITIGNORE", origRespectGitignore)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Set test env values
|
||||||
|
_ = os.Setenv("MCP_LSP_TIMEOUT", "10m")
|
||||||
|
_ = os.Setenv("MCP_SEARCH_TIMEOUT", "1m")
|
||||||
|
_ = os.Setenv("MCP_ENABLE_LSP", "false")
|
||||||
|
_ = os.Setenv("MCP_FOLLOW_SYMLINKS", "0")
|
||||||
|
_ = os.Setenv("MCP_RESPECT_GITIGNORE", "false")
|
||||||
|
|
||||||
|
cfg := Default()
|
||||||
|
cfg.loadFromEnv()
|
||||||
|
|
||||||
|
if cfg.LSPTimeout != 10*time.Minute {
|
||||||
|
t.Errorf("expected LSP timeout 10m, got %v", cfg.LSPTimeout)
|
||||||
|
}
|
||||||
|
if cfg.SearchTimeout != 1*time.Minute {
|
||||||
|
t.Errorf("expected search timeout 1m, got %v", cfg.SearchTimeout)
|
||||||
|
}
|
||||||
|
if cfg.EnableLSP {
|
||||||
|
t.Error("expected EnableLSP to be false")
|
||||||
|
}
|
||||||
|
if cfg.FollowSymlinks {
|
||||||
|
t.Error("expected FollowSymlinks to be false")
|
||||||
|
}
|
||||||
|
if cfg.RespectGitignore {
|
||||||
|
t.Error("expected RespectGitignore to be false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPathAllowed(t *testing.T) {
|
||||||
|
// Create a temporary directory
|
||||||
|
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
cfg := Default()
|
||||||
|
cfg.WorkspaceRoot = tmpDir
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
allowed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "file in workspace",
|
||||||
|
path: filepath.Join(tmpDir, "test.go"),
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested file in workspace",
|
||||||
|
path: filepath.Join(tmpDir, "subdir", "test.go"),
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path outside workspace",
|
||||||
|
path: "/etc/passwd",
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "relative path traversal",
|
||||||
|
path: filepath.Join(tmpDir, "..", "outside.txt"),
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := cfg.IsPathAllowed(tt.path)
|
||||||
|
if result != tt.allowed {
|
||||||
|
t.Errorf("IsPathAllowed(%q) = %v, want %v", tt.path, result, tt.allowed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadWithConfigFile(t *testing.T) {
|
||||||
|
// Create a temporary directory
|
||||||
|
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
// Write config file
|
||||||
|
configPath := filepath.Join(tmpDir, ".mcp-filepuff.json")
|
||||||
|
configContent := `{
|
||||||
|
"enable_lsp": false,
|
||||||
|
"follow_symlinks": false
|
||||||
|
}`
|
||||||
|
err = os.WriteFile(configPath, []byte(configContent), 0600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg *Config
|
||||||
|
cfg, err = Load(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.EnableLSP {
|
||||||
|
t.Error("expected EnableLSP to be false from config file")
|
||||||
|
}
|
||||||
|
if cfg.FollowSymlinks {
|
||||||
|
t.Error("expected FollowSymlinks to be false from config file")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,141 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestIsPathAllowed_SymlinkSecurity tests the symlink security fix.
|
||||||
|
func TestIsPathAllowed_SymlinkSecurity(t *testing.T) {
|
||||||
|
// Create a temporary workspace
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
workspace := filepath.Join(tmpDir, "workspace")
|
||||||
|
outside := filepath.Join(tmpDir, "outside")
|
||||||
|
|
||||||
|
if err := os.MkdirAll(workspace, 0700); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(outside, 0700); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a file outside the workspace
|
||||||
|
outsideFile := filepath.Join(outside, "secret.txt")
|
||||||
|
if err := os.WriteFile(outsideFile, []byte("secret data"), 0600); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &Config{
|
||||||
|
WorkspaceRoot: workspace,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
setup func() string
|
||||||
|
name string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "regular file inside workspace",
|
||||||
|
setup: func() string {
|
||||||
|
return filepath.Join(workspace, "file.txt")
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "file with parent directory traversal",
|
||||||
|
setup: func() string {
|
||||||
|
return filepath.Join(workspace, "../outside/secret.txt")
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "symlink pointing outside workspace",
|
||||||
|
setup: func() string {
|
||||||
|
symlink := filepath.Join(workspace, "link.txt")
|
||||||
|
_ = os.Symlink(outsideFile, symlink)
|
||||||
|
return symlink
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "symlink pointing inside workspace",
|
||||||
|
setup: func() string {
|
||||||
|
inside := filepath.Join(workspace, "inside.txt")
|
||||||
|
_ = os.WriteFile(inside, []byte("ok"), 0600)
|
||||||
|
symlink := filepath.Join(workspace, "link_inside.txt")
|
||||||
|
_ = os.Symlink(inside, symlink)
|
||||||
|
return symlink
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dotfile inside workspace",
|
||||||
|
setup: func() string {
|
||||||
|
return filepath.Join(workspace, ".gitignore")
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hidden directory inside workspace",
|
||||||
|
setup: func() string {
|
||||||
|
return filepath.Join(workspace, ".git/config")
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
path := tt.setup()
|
||||||
|
result := cfg.IsPathAllowed(path)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsPathAllowed(%q) = %v, want %v", path, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsPathAllowed_BasicCases tests basic path validation.
|
||||||
|
func TestIsPathAllowed_BasicCases(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
cfg := &Config{
|
||||||
|
WorkspaceRoot: tmpDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "path inside workspace",
|
||||||
|
path: filepath.Join(tmpDir, "file.txt"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path outside workspace",
|
||||||
|
path: "/etc/passwd",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parent directory reference",
|
||||||
|
path: filepath.Join(tmpDir, "../../../etc/passwd"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "workspace root itself",
|
||||||
|
path: tmpDir,
|
||||||
|
expected: false, // Empty relative path
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := cfg.IsPathAllowed(tt.path)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsPathAllowed(%q) = %v, want %v", tt.path, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,203 @@
|
|||||||
|
package edit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConcurrentEditLocking tests that concurrent edits to the same file are properly serialized.
|
||||||
|
func TestConcurrentEditLocking(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
testFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
|
||||||
|
// Create initial file
|
||||||
|
initialContent := `package main
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
println("hello")
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
engine := NewEngine(registry)
|
||||||
|
|
||||||
|
// Run 10 concurrent edits
|
||||||
|
const numEdits = 10
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numEdits)
|
||||||
|
|
||||||
|
errors := make(chan error, numEdits)
|
||||||
|
|
||||||
|
for i := 0; i < numEdits; i++ {
|
||||||
|
i := i
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: testFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{
|
||||||
|
Kind: "function_declaration",
|
||||||
|
Name: "main",
|
||||||
|
},
|
||||||
|
NewContent: `func main() {
|
||||||
|
println("edit ` + string(rune(i)) + `")
|
||||||
|
}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := engine.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errors)
|
||||||
|
|
||||||
|
// Check for errors
|
||||||
|
for err := range errors {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Concurrent edit failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file wasn't corrupted
|
||||||
|
finalContent, err := os.ReadFile(testFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse to ensure it's still valid Go
|
||||||
|
_, err = registry.Parse(context.Background(), testFile, finalContent)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("File corrupted after concurrent edits: %v\nContent:\n%s", err, string(finalContent))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcurrentEditDifferentFiles tests that concurrent edits to different files don't block each other.
|
||||||
|
func TestConcurrentEditDifferentFiles(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
engine := NewEngine(registry)
|
||||||
|
|
||||||
|
const numFiles = 5
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numFiles)
|
||||||
|
|
||||||
|
startBarrier := make(chan struct{})
|
||||||
|
|
||||||
|
for i := 0; i < numFiles; i++ {
|
||||||
|
i := i
|
||||||
|
testFile := filepath.Join(tmpDir, fmt.Sprintf("test%d.go", i))
|
||||||
|
|
||||||
|
// Create initial file
|
||||||
|
initialContent := `package main
|
||||||
|
|
||||||
|
func test() {
|
||||||
|
println("initial")
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Wait for all goroutines to be ready
|
||||||
|
<-startBarrier
|
||||||
|
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: testFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{
|
||||||
|
Kind: "function_declaration",
|
||||||
|
Name: "test",
|
||||||
|
},
|
||||||
|
NewContent: `func test() {
|
||||||
|
println("modified")
|
||||||
|
}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := engine.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Edit failed for %s: %v", testFile, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
t.Errorf("Edit unsuccessful for %s: %s", testFile, result.Error)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release all goroutines simultaneously
|
||||||
|
close(startBarrier)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFileLockRelease tests that file locks are properly released after edits.
|
||||||
|
func TestFileLockRelease(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
testFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
|
||||||
|
initialContent := `package main
|
||||||
|
|
||||||
|
func test() {}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
engine := NewEngine(registry)
|
||||||
|
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: testFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{
|
||||||
|
Kind: "function_declaration",
|
||||||
|
Name: "test",
|
||||||
|
},
|
||||||
|
NewContent: `func test() { println("updated") }`,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First edit
|
||||||
|
ctx := context.Background()
|
||||||
|
result1, err := engine.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !result1.Success {
|
||||||
|
t.Fatalf("First edit failed: %s", result1.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second edit should succeed (lock was released)
|
||||||
|
result2, err := engine.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !result2.Success {
|
||||||
|
t.Fatalf("Second edit failed (lock not released?): %s", result2.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,757 @@
|
|||||||
|
// Package edit provides AST-aware file editing capabilities.
|
||||||
|
package edit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
"github.com/sergi/go-diff/diffmatchpatch"
|
||||||
|
sitter "github.com/smacker/go-tree-sitter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Global regex cache for compiled patterns (thread-safe)
|
||||||
|
var regexCache sync.Map // string -> *regexp.Regexp
|
||||||
|
|
||||||
|
// compileRegex compiles a regex pattern with caching for performance.
|
||||||
|
func compileRegex(pattern string) (*regexp.Regexp, error) {
|
||||||
|
// Check cache first
|
||||||
|
if cached, ok := regexCache.Load(pattern); ok {
|
||||||
|
return cached.(*regexp.Regexp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile and cache
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
regexCache.Store(pattern, re)
|
||||||
|
return re, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EditOperation defines the type of edit operation.
|
||||||
|
type EditOperation string
|
||||||
|
|
||||||
|
const (
|
||||||
|
EditReplace EditOperation = "replace"
|
||||||
|
EditInsertBefore EditOperation = "insert_before"
|
||||||
|
EditInsertAfter EditOperation = "insert_after"
|
||||||
|
EditDelete EditOperation = "delete"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ASTEdit represents an AST-aware edit request.
|
||||||
|
type ASTEdit struct {
|
||||||
|
File string `json:"file"`
|
||||||
|
Operation EditOperation `json:"operation"`
|
||||||
|
NewContent string `json:"new_content,omitempty"`
|
||||||
|
Selector ASTSelector `json:"selector"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ASTSelector specifies how to find the target node.
|
||||||
|
type ASTSelector struct {
|
||||||
|
Kind string `json:"kind,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Pattern string `json:"pattern,omitempty"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
TextPattern string `json:"text_pattern,omitempty"`
|
||||||
|
AtLine int `json:"at_line,omitempty"`
|
||||||
|
Index int `json:"index,omitempty"`
|
||||||
|
LineEnd int `json:"line_end,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EditResult contains the result of an edit operation.
|
||||||
|
type EditResult struct {
|
||||||
|
Diff string `json:"diff,omitempty"`
|
||||||
|
OriginalContent string `json:"original_content,omitempty"`
|
||||||
|
NewContent string `json:"new_content,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Applied bool `json:"applied"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Engine performs AST-aware edits.
|
||||||
|
type Engine struct {
|
||||||
|
registry *parser.Registry
|
||||||
|
fileLocks sync.Map // map[string]*sync.Mutex for per-file locking
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEngine creates a new edit engine.
|
||||||
|
func NewEngine(registry *parser.Registry) *Engine {
|
||||||
|
return &Engine{
|
||||||
|
registry: registry,
|
||||||
|
fileLocks: sync.Map{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lockFile acquires a lock for the specified file and returns an unlock function.
|
||||||
|
// This prevents concurrent edits to the same file which could cause corruption.
|
||||||
|
func (e *Engine) lockFile(filePath string) func() {
|
||||||
|
// Get or create mutex for this file
|
||||||
|
actual, _ := e.fileLocks.LoadOrStore(filePath, &sync.Mutex{})
|
||||||
|
mu := actual.(*sync.Mutex)
|
||||||
|
mu.Lock()
|
||||||
|
return mu.Unlock
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preview generates a preview of an edit without applying it.
|
||||||
|
func (e *Engine) Preview(ctx context.Context, edit *ASTEdit) (*EditResult, error) {
|
||||||
|
return e.performEdit(ctx, edit, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply performs an edit and writes the result to disk.
|
||||||
|
// Uses file locking to prevent concurrent edits to the same file.
|
||||||
|
func (e *Engine) Apply(ctx context.Context, edit *ASTEdit) (*EditResult, error) {
|
||||||
|
unlock := e.lockFile(edit.File)
|
||||||
|
defer unlock()
|
||||||
|
return e.performEdit(ctx, edit, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// performEdit executes an edit operation.
|
||||||
|
func (e *Engine) performEdit(ctx context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
|
||||||
|
// Determine if we should use text mode
|
||||||
|
useTextMode := e.shouldUseTextMode(edit)
|
||||||
|
|
||||||
|
if useTextMode {
|
||||||
|
return e.performTextEdit(ctx, edit, apply)
|
||||||
|
}
|
||||||
|
return e.performASTEdit(ctx, edit, apply)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldUseTextMode determines if text-based editing should be used.
|
||||||
|
func (e *Engine) shouldUseTextMode(edit *ASTEdit) bool {
|
||||||
|
// Use text mode if text-specific selectors are provided
|
||||||
|
if edit.Selector.Text != "" || edit.Selector.TextPattern != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use text mode if line range is specified without AST selectors
|
||||||
|
if edit.Selector.AtLine > 0 && edit.Selector.LineEnd > 0 &&
|
||||||
|
edit.Selector.Kind == "" && edit.Selector.Name == "" && edit.Selector.Pattern == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use text mode if language is not supported for AST
|
||||||
|
lang := protocol.DetectLanguage(edit.File)
|
||||||
|
return lang == protocol.LangUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
// performASTEdit executes an AST-aware edit operation.
|
||||||
|
func (e *Engine) performASTEdit(ctx context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
|
||||||
|
// Validate operation
|
||||||
|
if err := e.validateASTEdit(edit); err != nil {
|
||||||
|
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read file
|
||||||
|
content, err := os.ReadFile(edit.File)
|
||||||
|
if err != nil {
|
||||||
|
structuredErr := errors.NewFileNotReadableError(edit.File, err)
|
||||||
|
return &EditResult{Success: false, Error: structuredErr.Error()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse file
|
||||||
|
parseResult, err := e.registry.Parse(ctx, edit.File, content)
|
||||||
|
if err != nil {
|
||||||
|
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find target node
|
||||||
|
node, err := e.resolveSelector(edit.Selector, parseResult.Tree, content)
|
||||||
|
if err != nil {
|
||||||
|
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply edit
|
||||||
|
newContent, err := e.applyEdit(edit, node, content)
|
||||||
|
if err != nil {
|
||||||
|
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate new content (re-parse)
|
||||||
|
_, err = e.registry.Parse(ctx, edit.File, newContent)
|
||||||
|
if err != nil {
|
||||||
|
structuredErr := errors.NewEditValidationError(edit.File, err)
|
||||||
|
return &EditResult{
|
||||||
|
Success: false,
|
||||||
|
Error: structuredErr.Error(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate diff
|
||||||
|
diff := generateDiff(string(content), string(newContent), edit.File)
|
||||||
|
|
||||||
|
result := &EditResult{
|
||||||
|
Success: true,
|
||||||
|
Diff: diff,
|
||||||
|
OriginalContent: string(content),
|
||||||
|
NewContent: string(newContent),
|
||||||
|
Applied: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply changes if requested
|
||||||
|
if apply {
|
||||||
|
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
|
||||||
|
structuredErr := errors.NewFileNotWritableError(edit.File, err)
|
||||||
|
return &EditResult{
|
||||||
|
Success: false,
|
||||||
|
Error: structuredErr.Error(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
result.Applied = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// performTextEdit executes a text-based edit operation for non-AST files.
|
||||||
|
func (e *Engine) performTextEdit(_ context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
|
||||||
|
// Validate operation
|
||||||
|
if err := e.validateTextEdit(edit); err != nil {
|
||||||
|
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read file
|
||||||
|
content, err := os.ReadFile(edit.File)
|
||||||
|
if err != nil {
|
||||||
|
structuredErr := errors.NewFileNotReadableError(edit.File, err)
|
||||||
|
return &EditResult{Success: false, Error: structuredErr.Error()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the text selection (byte range)
|
||||||
|
start, end, err := e.resolveTextSelector(edit.Selector, content)
|
||||||
|
if err != nil {
|
||||||
|
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply edit
|
||||||
|
newContent, err := e.applyTextEditOperation(edit.Operation, content, start, end, edit.NewContent)
|
||||||
|
if err != nil {
|
||||||
|
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate diff
|
||||||
|
diff := generateDiff(string(content), string(newContent), edit.File)
|
||||||
|
|
||||||
|
result := &EditResult{
|
||||||
|
Success: true,
|
||||||
|
Diff: diff,
|
||||||
|
OriginalContent: string(content),
|
||||||
|
NewContent: string(newContent),
|
||||||
|
Applied: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply changes if requested
|
||||||
|
if apply {
|
||||||
|
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
|
||||||
|
structuredErr := errors.NewFileNotWritableError(edit.File, err)
|
||||||
|
return &EditResult{
|
||||||
|
Success: false,
|
||||||
|
Error: structuredErr.Error(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
result.Applied = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateBaseEdit checks common edit request fields.
|
||||||
|
func (e *Engine) validateBaseEdit(edit *ASTEdit) error {
|
||||||
|
if edit.File == "" {
|
||||||
|
return errors.NewInvalidEditError("file is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if edit.Operation == "" {
|
||||||
|
return errors.NewInvalidEditError("operation is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate operation type
|
||||||
|
switch edit.Operation {
|
||||||
|
case EditReplace, EditInsertBefore, EditInsertAfter:
|
||||||
|
if edit.NewContent == "" {
|
||||||
|
return errors.NewInvalidEditError(fmt.Sprintf("new_content is required for %s operation", edit.Operation))
|
||||||
|
}
|
||||||
|
case EditDelete:
|
||||||
|
// new_content not required
|
||||||
|
default:
|
||||||
|
return errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", edit.Operation))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateASTEdit checks if an AST edit request is valid.
|
||||||
|
func (e *Engine) validateASTEdit(edit *ASTEdit) error {
|
||||||
|
if err := e.validateBaseEdit(edit); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate AST selector
|
||||||
|
if edit.Selector.Kind == "" && edit.Selector.Name == "" && edit.Selector.Pattern == "" && edit.Selector.AtLine == 0 {
|
||||||
|
return errors.NewInvalidEditError("AST selector must specify at least one of: kind, name, pattern, or at_line")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateTextEdit checks if a text edit request is valid.
|
||||||
|
func (e *Engine) validateTextEdit(edit *ASTEdit) error {
|
||||||
|
if err := e.validateBaseEdit(edit); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate text selector - need at least one text selection method
|
||||||
|
hasTextSelector := edit.Selector.Text != "" ||
|
||||||
|
edit.Selector.TextPattern != "" ||
|
||||||
|
edit.Selector.AtLine > 0
|
||||||
|
|
||||||
|
if !hasTextSelector {
|
||||||
|
return errors.NewInvalidEditError("text selector must specify at least one of: text, text_pattern, or at_line")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate regex pattern if provided (uses cached compilation)
|
||||||
|
if edit.Selector.TextPattern != "" {
|
||||||
|
if _, err := compileRegex(edit.Selector.TextPattern); err != nil {
|
||||||
|
return errors.Wrap(errors.ErrInvalidEdit, "invalid text_pattern regex", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveSelector finds the target node based on the selector.
|
||||||
|
func (e *Engine) resolveSelector(sel ASTSelector, tree *sitter.Tree, content []byte) (*sitter.Node, error) {
|
||||||
|
if tree == nil {
|
||||||
|
return nil, errors.NewNodeNotFoundError("no AST tree available")
|
||||||
|
}
|
||||||
|
|
||||||
|
root := tree.RootNode()
|
||||||
|
if root == nil {
|
||||||
|
return nil, errors.NewNodeNotFoundError("empty AST tree")
|
||||||
|
}
|
||||||
|
|
||||||
|
var matches []*sitter.Node
|
||||||
|
|
||||||
|
parser.WalkTree(root, func(n *sitter.Node) bool {
|
||||||
|
if e.matchesSelector(sel, n, content) {
|
||||||
|
matches = append(matches, n)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(matches) == 0 {
|
||||||
|
selectorDesc := fmt.Sprintf("kind=%s name=%s pattern=%s line=%d", sel.Kind, sel.Name, sel.Pattern, sel.AtLine)
|
||||||
|
return nil, errors.NewNodeNotFoundError(selectorDesc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use index to select specific match
|
||||||
|
index := sel.Index
|
||||||
|
if index < 0 || index >= len(matches) {
|
||||||
|
return nil, errors.NewInvalidSelectionError(fmt.Sprintf("selector matched %d nodes, but index %d is out of range", len(matches), index))
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches[index], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesSelector checks if a node matches the selector criteria.
|
||||||
|
func (e *Engine) matchesSelector(sel ASTSelector, n *sitter.Node, content []byte) bool {
|
||||||
|
// Check kind
|
||||||
|
if sel.Kind != "" && n.Type() != sel.Kind {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check name (look for identifier in the node)
|
||||||
|
if sel.Name != "" {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
// Also try to find an identifier child
|
||||||
|
found := false
|
||||||
|
for i := 0; i < int(n.NamedChildCount()); i++ {
|
||||||
|
child := n.NamedChild(i)
|
||||||
|
if child != nil && child.Type() == "identifier" {
|
||||||
|
if parser.GetNodeText(child, content) == sel.Name {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else if parser.GetNodeText(nameNode, content) != sel.Name {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check line
|
||||||
|
if sel.AtLine > 0 {
|
||||||
|
startLine := int(n.StartPoint().Row) + 1
|
||||||
|
endLine := int(n.EndPoint().Row) + 1
|
||||||
|
if sel.AtLine < startLine || sel.AtLine > endLine {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern matching is handled separately (simplified here)
|
||||||
|
if sel.Pattern != "" {
|
||||||
|
nodeText := parser.GetNodeText(n, content)
|
||||||
|
if !strings.Contains(nodeText, sel.Pattern) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyEdit applies the edit operation to the content.
|
||||||
|
func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]byte, error) {
|
||||||
|
startByte := node.StartByte()
|
||||||
|
endByte := node.EndByte()
|
||||||
|
|
||||||
|
// Detect and preserve indentation
|
||||||
|
indentation := detectIndentation(content, startByte)
|
||||||
|
newContent := indentContent(edit.NewContent, indentation)
|
||||||
|
|
||||||
|
var result []byte
|
||||||
|
|
||||||
|
switch edit.Operation {
|
||||||
|
case EditReplace:
|
||||||
|
result = append(result, content[:startByte]...)
|
||||||
|
result = append(result, []byte(newContent)...)
|
||||||
|
result = append(result, content[endByte:]...)
|
||||||
|
|
||||||
|
case EditInsertBefore:
|
||||||
|
result = append(result, content[:startByte]...)
|
||||||
|
result = append(result, []byte(newContent)...)
|
||||||
|
result = append(result, '\n')
|
||||||
|
result = append(result, content[startByte:]...)
|
||||||
|
|
||||||
|
case EditInsertAfter:
|
||||||
|
result = append(result, content[:endByte]...)
|
||||||
|
result = append(result, '\n')
|
||||||
|
result = append(result, []byte(newContent)...)
|
||||||
|
result = append(result, content[endByte:]...)
|
||||||
|
|
||||||
|
case EditDelete:
|
||||||
|
result = append(result, content[:startByte]...)
|
||||||
|
result = append(result, content[endByte:]...)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", edit.Operation))
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectIndentation detects the indentation at a given byte position.
|
||||||
|
func detectIndentation(content []byte, bytePos uint32) string {
|
||||||
|
// Find the start of the line
|
||||||
|
lineStart := int(bytePos)
|
||||||
|
for lineStart > 0 && content[lineStart-1] != '\n' {
|
||||||
|
lineStart--
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract leading whitespace
|
||||||
|
var indent strings.Builder
|
||||||
|
for i := lineStart; i < int(bytePos) && i < len(content); i++ {
|
||||||
|
c := content[i]
|
||||||
|
if c == ' ' || c == '\t' {
|
||||||
|
indent.WriteByte(c)
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return indent.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// indentContent applies indentation to multi-line content.
|
||||||
|
func indentContent(content string, indent string) string {
|
||||||
|
if indent == "" {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(content, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
if i > 0 && line != "" {
|
||||||
|
lines[i] = indent + line
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateDiff creates a unified diff between original and modified content.
|
||||||
|
// Uses Myers diff algorithm for accurate and readable diffs.
|
||||||
|
func generateDiff(original, modified, filename string) string {
|
||||||
|
dmp := diffmatchpatch.New()
|
||||||
|
diffs := dmp.DiffMain(original, modified, false)
|
||||||
|
|
||||||
|
// Cleanup for readability
|
||||||
|
diffs = dmp.DiffCleanupSemantic(diffs)
|
||||||
|
|
||||||
|
// Convert to unified diff format
|
||||||
|
var buf bytes.Buffer
|
||||||
|
buf.WriteString(fmt.Sprintf("--- %s\n", filename))
|
||||||
|
buf.WriteString(fmt.Sprintf("+++ %s\n", filename))
|
||||||
|
|
||||||
|
// Group diffs into hunks
|
||||||
|
lineNum := 1
|
||||||
|
for _, diff := range diffs {
|
||||||
|
lines := strings.Split(diff.Text, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
// Skip empty last line from split
|
||||||
|
if i == len(lines)-1 && line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch diff.Type {
|
||||||
|
case diffmatchpatch.DiffDelete:
|
||||||
|
buf.WriteString(fmt.Sprintf("-%s\n", line))
|
||||||
|
case diffmatchpatch.DiffInsert:
|
||||||
|
buf.WriteString(fmt.Sprintf("+%s\n", line))
|
||||||
|
case diffmatchpatch.DiffEqual:
|
||||||
|
buf.WriteString(fmt.Sprintf(" %s\n", line))
|
||||||
|
lineNum++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveTextSelector finds the byte range for a text-based selection.
|
||||||
|
func (e *Engine) resolveTextSelector(sel ASTSelector, content []byte) (start, end int, err error) {
|
||||||
|
switch {
|
||||||
|
case sel.Text != "":
|
||||||
|
return e.findExactText(content, sel.Text, sel.Index)
|
||||||
|
case sel.TextPattern != "":
|
||||||
|
return e.findRegexPattern(content, sel.TextPattern, sel.Index)
|
||||||
|
case sel.AtLine > 0:
|
||||||
|
return e.findLineRange(content, sel.AtLine, sel.LineEnd)
|
||||||
|
default:
|
||||||
|
return 0, 0, errors.NewInvalidEditError("text selector requires text, text_pattern, or at_line")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findExactText finds an exact text match in content.
|
||||||
|
func (e *Engine) findExactText(content []byte, text string, index int) (start, end int, err error) {
|
||||||
|
if text == "" {
|
||||||
|
return 0, 0, errors.NewInvalidEditError("text selector cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
textBytes := []byte(text)
|
||||||
|
type match struct{ start, end int }
|
||||||
|
var matches []match
|
||||||
|
|
||||||
|
offset := 0
|
||||||
|
for {
|
||||||
|
idx := bytes.Index(content[offset:], textBytes)
|
||||||
|
if idx == -1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
matches = append(matches, match{
|
||||||
|
start: offset + idx,
|
||||||
|
end: offset + idx + len(textBytes),
|
||||||
|
})
|
||||||
|
offset += idx + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text not found: %q", truncateString(text, 50)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If multiple matches and no index specified, require explicit selection
|
||||||
|
if len(matches) > 1 && index == 0 {
|
||||||
|
// Check if index was explicitly set to 0 or just defaulted
|
||||||
|
// Since we can't distinguish, we'll allow index 0 but warn about multiple matches
|
||||||
|
// Actually, let's be strict and require explicit index for multiple matches
|
||||||
|
locations := make([]string, 0, min(len(matches), 5))
|
||||||
|
for i, m := range matches {
|
||||||
|
if i >= 5 {
|
||||||
|
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
line := countLines(content[:m.start]) + 1
|
||||||
|
locations = append(locations, fmt.Sprintf("line %d", line))
|
||||||
|
}
|
||||||
|
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text matches %d locations (%s); use selector_index to specify which one (0-%d)",
|
||||||
|
len(matches), strings.Join(locations, ", "), len(matches)-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
if index >= len(matches) {
|
||||||
|
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches[index].start, matches[index].end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findRegexPattern finds a regex pattern match in content.
|
||||||
|
func (e *Engine) findRegexPattern(content []byte, pattern string, index int) (start, end int, err error) {
|
||||||
|
re, err := compileRegex(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, errors.Wrap(errors.ErrInvalidEdit, "invalid regex pattern", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := re.FindAllIndex(content, -1)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern not found: %q", truncateString(pattern, 50)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If multiple matches and index is 0 (default), show error with locations
|
||||||
|
if len(matches) > 1 && index == 0 {
|
||||||
|
locations := make([]string, 0, min(len(matches), 5))
|
||||||
|
for i, m := range matches {
|
||||||
|
if i >= 5 {
|
||||||
|
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
line := countLines(content[:m[0]]) + 1
|
||||||
|
locations = append(locations, fmt.Sprintf("line %d", line))
|
||||||
|
}
|
||||||
|
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern matches %d locations (%s); use selector_index to specify which one (0-%d)",
|
||||||
|
len(matches), strings.Join(locations, ", "), len(matches)-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
if index >= len(matches) {
|
||||||
|
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches[index][0], matches[index][1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findLineRange finds the byte range for a line range selection.
|
||||||
|
func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, end int, err error) {
|
||||||
|
if lineEnd == 0 {
|
||||||
|
lineEnd = lineStart
|
||||||
|
}
|
||||||
|
|
||||||
|
if lineStart < 1 {
|
||||||
|
return 0, 0, errors.NewInvalidEditError(fmt.Sprintf("line number must be >= 1, got %d", lineStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
if lineEnd < lineStart {
|
||||||
|
return 0, 0, errors.NewInvalidEditError(fmt.Sprintf("line_end (%d) must be >= line (%d)", lineEnd, lineStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := bytes.Split(content, []byte("\n"))
|
||||||
|
totalLines := len(lines)
|
||||||
|
|
||||||
|
// Convert to 0-indexed
|
||||||
|
startIdx := lineStart - 1
|
||||||
|
endIdx := lineEnd - 1
|
||||||
|
|
||||||
|
if startIdx >= totalLines {
|
||||||
|
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("line %d out of range (file has %d lines)", lineStart, totalLines))
|
||||||
|
}
|
||||||
|
if endIdx >= totalLines {
|
||||||
|
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("line_end %d out of range (file has %d lines)", lineEnd, totalLines))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate byte positions
|
||||||
|
start = 0
|
||||||
|
for i := 0; i < startIdx; i++ {
|
||||||
|
start += len(lines[i]) + 1 // +1 for newline
|
||||||
|
}
|
||||||
|
|
||||||
|
end = start
|
||||||
|
for i := startIdx; i <= endIdx; i++ {
|
||||||
|
end += len(lines[i])
|
||||||
|
if i < totalLines-1 {
|
||||||
|
end += 1 // newline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return start, end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyTextEditOperation applies a text edit operation.
|
||||||
|
func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start, end int, newContent string) ([]byte, error) {
|
||||||
|
// Detect indentation at the selection point
|
||||||
|
indentation := detectIndentationAtByte(content, start)
|
||||||
|
indentedContent := indentContent(newContent, indentation)
|
||||||
|
|
||||||
|
var result []byte
|
||||||
|
|
||||||
|
switch op {
|
||||||
|
case EditReplace:
|
||||||
|
result = append(result, content[:start]...)
|
||||||
|
result = append(result, []byte(indentedContent)...)
|
||||||
|
result = append(result, content[end:]...)
|
||||||
|
|
||||||
|
case EditInsertBefore:
|
||||||
|
result = append(result, content[:start]...)
|
||||||
|
result = append(result, []byte(indentedContent)...)
|
||||||
|
result = append(result, '\n')
|
||||||
|
result = append(result, content[start:]...)
|
||||||
|
|
||||||
|
case EditInsertAfter:
|
||||||
|
result = append(result, content[:end]...)
|
||||||
|
result = append(result, '\n')
|
||||||
|
result = append(result, []byte(indentedContent)...)
|
||||||
|
result = append(result, content[end:]...)
|
||||||
|
|
||||||
|
case EditDelete:
|
||||||
|
result = append(result, content[:start]...)
|
||||||
|
result = append(result, content[end:]...)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", op))
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectIndentationAtByte detects indentation at a byte position.
|
||||||
|
func detectIndentationAtByte(content []byte, bytePos int) string {
|
||||||
|
// Find the start of the line
|
||||||
|
lineStart := bytePos
|
||||||
|
for lineStart > 0 && content[lineStart-1] != '\n' {
|
||||||
|
lineStart--
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract leading whitespace
|
||||||
|
var indent strings.Builder
|
||||||
|
for i := lineStart; i < bytePos && i < len(content); i++ {
|
||||||
|
c := content[i]
|
||||||
|
if c == ' ' || c == '\t' {
|
||||||
|
indent.WriteByte(c)
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return indent.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncateString truncates a string to maxLen with ellipsis.
|
||||||
|
func truncateString(s string, maxLen int) string {
|
||||||
|
if len(s) <= maxLen {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:maxLen-3] + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
// countLines counts the number of newlines in content.
|
||||||
|
func countLines(content []byte) int {
|
||||||
|
return bytes.Count(content, []byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateLanguage checks if AST editing is supported for a file.
|
||||||
|
// Returns nil for supported languages, error for unsupported.
|
||||||
|
// Note: Text-based editing is always available regardless of this check.
|
||||||
|
func ValidateLanguage(filename string) error {
|
||||||
|
lang := protocol.DetectLanguage(filename)
|
||||||
|
if lang == protocol.LangUnknown {
|
||||||
|
return fmt.Errorf("unsupported file type for AST editing: %s (text-based editing is available)", filename)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,836 @@
|
|||||||
|
package edit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateEdit(t *testing.T) {
|
||||||
|
e := NewEngine(parser.NewRegistry())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
edit *ASTEdit
|
||||||
|
name string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid replace",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.go",
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{Kind: "function_declaration"},
|
||||||
|
NewContent: "func NewFunc() {}",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid delete",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.go",
|
||||||
|
Operation: EditDelete,
|
||||||
|
Selector: ASTSelector{Name: "oldFunc"},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing file",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{Kind: "function_declaration"},
|
||||||
|
NewContent: "func NewFunc() {}",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing operation",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.go",
|
||||||
|
Selector: ASTSelector{Kind: "function_declaration"},
|
||||||
|
NewContent: "func NewFunc() {}",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "replace without content",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.go",
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{Kind: "function_declaration"},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty selector",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.go",
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{},
|
||||||
|
NewContent: "func NewFunc() {}",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown operation",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.go",
|
||||||
|
Operation: "unknown",
|
||||||
|
Selector: ASTSelector{Kind: "function_declaration"},
|
||||||
|
NewContent: "func NewFunc() {}",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := e.validateASTEdit(tt.edit)
|
||||||
|
if tt.wantErr && err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
if !tt.wantErr && err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveSelector(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
content := []byte(`package main
|
||||||
|
|
||||||
|
func Hello() {
|
||||||
|
println("hello")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Goodbye() {
|
||||||
|
println("goodbye")
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := registry.Parse(ctx, "test.go", content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sel ASTSelector
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "by kind",
|
||||||
|
sel: ASTSelector{Kind: "function_declaration"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "by name",
|
||||||
|
sel: ASTSelector{Name: "Hello"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "by kind and name",
|
||||||
|
sel: ASTSelector{Kind: "function_declaration", Name: "Goodbye"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "by line",
|
||||||
|
sel: ASTSelector{AtLine: 3},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no match",
|
||||||
|
sel: ASTSelector{Name: "NonExistent"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "index out of range",
|
||||||
|
sel: ASTSelector{Kind: "function_declaration", Index: 10},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
node, err := e.resolveSelector(tt.sel, result.Tree, content)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if node == nil {
|
||||||
|
t.Error("expected node")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyEdit(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
content := []byte(`package main
|
||||||
|
|
||||||
|
func Hello() {
|
||||||
|
println("hello")
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := registry.Parse(ctx, "test.go", content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
operation EditOperation
|
||||||
|
newCode string
|
||||||
|
wantIn string // substring that should be in result
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "replace",
|
||||||
|
operation: EditReplace,
|
||||||
|
newCode: "func NewHello() {}",
|
||||||
|
wantIn: "NewHello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insert after",
|
||||||
|
operation: EditInsertAfter,
|
||||||
|
newCode: "func After() {}",
|
||||||
|
wantIn: "After",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "insert before",
|
||||||
|
operation: EditInsertBefore,
|
||||||
|
newCode: "func Before() {}",
|
||||||
|
wantIn: "Before",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete",
|
||||||
|
operation: EditDelete,
|
||||||
|
newCode: "",
|
||||||
|
wantIn: "package main", // Should still have package declaration
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Find the function node
|
||||||
|
node, err := e.resolveSelector(ASTSelector{Kind: "function_declaration"}, result.Tree, content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolve failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: "test.go",
|
||||||
|
Operation: tt.operation,
|
||||||
|
NewContent: tt.newCode,
|
||||||
|
}
|
||||||
|
|
||||||
|
newContent, err := e.applyEdit(edit, node, content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("apply failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(newContent), tt.wantIn) {
|
||||||
|
t.Errorf("result does not contain %q:\n%s", tt.wantIn, string(newContent))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreview(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
// Create a temp file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
func Hello() {
|
||||||
|
println("hello")
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: tmpFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{Kind: "function_declaration"},
|
||||||
|
NewContent: "func NewHello() {\n\tprintln(\"new hello\")\n}",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := e.Preview(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("preview failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
t.Fatalf("preview was not successful: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Applied {
|
||||||
|
t.Error("preview should not apply changes")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Diff == "" {
|
||||||
|
t.Error("expected diff in result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify original file is unchanged
|
||||||
|
fileContent, _ := os.ReadFile(tmpFile)
|
||||||
|
if string(fileContent) != content {
|
||||||
|
t.Error("original file was modified during preview")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyToFile(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
// Create a temp file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
func Hello() {
|
||||||
|
println("hello")
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: tmpFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{Kind: "function_declaration"},
|
||||||
|
NewContent: "func NewHello() {\n\tprintln(\"new hello\")\n}",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := e.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("apply failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
t.Fatalf("apply was not successful: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Applied {
|
||||||
|
t.Error("apply should set Applied=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file was modified
|
||||||
|
fileContent, _ := os.ReadFile(tmpFile)
|
||||||
|
if !strings.Contains(string(fileContent), "NewHello") {
|
||||||
|
t.Error("file was not modified")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectIndentation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
want string
|
||||||
|
pos uint32
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no indent",
|
||||||
|
content: "func main() {}",
|
||||||
|
pos: 0,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tab indent",
|
||||||
|
content: "func main() {\n\tprintln(\"hello\")\n}",
|
||||||
|
pos: 15,
|
||||||
|
want: "\t",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "space indent",
|
||||||
|
content: "func main() {\n println(\"hello\")\n}",
|
||||||
|
pos: 18,
|
||||||
|
want: " ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := detectIndentation([]byte(tt.content), tt.pos)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("detectIndentation() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateDiff(t *testing.T) {
|
||||||
|
original := "line1\nline2\nline3"
|
||||||
|
modified := "line1\nmodified\nline3"
|
||||||
|
filename := "test.txt"
|
||||||
|
|
||||||
|
diff := generateDiff(original, modified, filename)
|
||||||
|
|
||||||
|
if !strings.Contains(diff, "---") {
|
||||||
|
t.Error("diff should contain --- header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(diff, "+++") {
|
||||||
|
t.Error("diff should contain +++ header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(diff, "-line2") {
|
||||||
|
t.Error("diff should show removed line")
|
||||||
|
}
|
||||||
|
if !strings.Contains(diff, "+modified") {
|
||||||
|
t.Error("diff should show added line")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Text-based editing tests ====================
|
||||||
|
|
||||||
|
func TestTextEditWithExactText(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
// Create a temp markdown file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpFile := filepath.Join(tmpDir, "README.md")
|
||||||
|
|
||||||
|
content := `# My Project
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Run the following command:
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
See the docs.
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: tmpFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{Text: "## Installation"},
|
||||||
|
NewContent: "## Getting Started",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := e.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("apply failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
t.Fatalf("apply was not successful: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file was modified
|
||||||
|
fileContent, _ := os.ReadFile(tmpFile)
|
||||||
|
if !strings.Contains(string(fileContent), "## Getting Started") {
|
||||||
|
t.Error("file was not modified correctly")
|
||||||
|
}
|
||||||
|
if strings.Contains(string(fileContent), "## Installation") {
|
||||||
|
t.Error("old text should be replaced")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextEditWithLineRange(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
// Create a temp config file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpFile := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
|
||||||
|
content := `name: myapp
|
||||||
|
version: 1.0.0
|
||||||
|
database:
|
||||||
|
host: localhost
|
||||||
|
port: 5432
|
||||||
|
logging:
|
||||||
|
level: debug
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: tmpFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{
|
||||||
|
AtLine: 3,
|
||||||
|
LineEnd: 5,
|
||||||
|
},
|
||||||
|
NewContent: "database:\n host: production.db.example.com\n port: 5433",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := e.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("apply failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
t.Fatalf("apply was not successful: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file was modified
|
||||||
|
fileContent, _ := os.ReadFile(tmpFile)
|
||||||
|
if !strings.Contains(string(fileContent), "production.db.example.com") {
|
||||||
|
t.Error("file was not modified correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextEditWithRegexPattern(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
// Create a temp JSON file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpFile := filepath.Join(tmpDir, "package.json")
|
||||||
|
|
||||||
|
content := `{
|
||||||
|
"name": "my-package",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "A test package"
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: tmpFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{TextPattern: `"version":\s*"[^"]+"`},
|
||||||
|
NewContent: `"version": "2.0.0"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := e.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("apply failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
t.Fatalf("apply was not successful: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file was modified
|
||||||
|
fileContent, _ := os.ReadFile(tmpFile)
|
||||||
|
if !strings.Contains(string(fileContent), `"version": "2.0.0"`) {
|
||||||
|
t.Error("file was not modified correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextEditInsertAfter(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
// Create a temp env file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpFile := filepath.Join(tmpDir, ".env")
|
||||||
|
|
||||||
|
content := `DATABASE_URL=postgres://localhost/mydb
|
||||||
|
SECRET_KEY=abc123
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: tmpFile,
|
||||||
|
Operation: EditInsertAfter,
|
||||||
|
Selector: ASTSelector{Text: "DATABASE_URL=postgres://localhost/mydb"},
|
||||||
|
NewContent: "REDIS_URL=redis://localhost:6379",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := e.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("apply failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
t.Fatalf("apply was not successful: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file was modified
|
||||||
|
fileContent, _ := os.ReadFile(tmpFile)
|
||||||
|
if !strings.Contains(string(fileContent), "REDIS_URL=redis://localhost:6379") {
|
||||||
|
t.Error("file was not modified correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextEditMultipleMatchesError(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
// Create a temp file with repeated text
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||||
|
|
||||||
|
content := `TODO: fix this
|
||||||
|
some code here
|
||||||
|
TODO: also fix this
|
||||||
|
more code
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: tmpFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{Text: "TODO"},
|
||||||
|
NewContent: "DONE",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := e.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("apply failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should fail because of multiple matches
|
||||||
|
if result.Success {
|
||||||
|
t.Error("expected error for multiple matches without index")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.Error, "matches") {
|
||||||
|
t.Errorf("error should mention multiple matches: %s", result.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextEditWithIndex(t *testing.T) {
|
||||||
|
registry := parser.NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
e := NewEngine(registry)
|
||||||
|
|
||||||
|
// Create a temp file with repeated text
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||||
|
|
||||||
|
content := `TODO: fix this
|
||||||
|
some code here
|
||||||
|
TODO: also fix this
|
||||||
|
more code
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
edit := &ASTEdit{
|
||||||
|
File: tmpFile,
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{
|
||||||
|
Text: "TODO",
|
||||||
|
Index: 1, // Select second match
|
||||||
|
},
|
||||||
|
NewContent: "DONE",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := e.Apply(ctx, edit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("apply failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
t.Fatalf("apply was not successful: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify only second TODO was replaced
|
||||||
|
fileContent, _ := os.ReadFile(tmpFile)
|
||||||
|
contentStr := string(fileContent)
|
||||||
|
if !strings.Contains(contentStr, "TODO: fix this") {
|
||||||
|
t.Error("first TODO should not be replaced")
|
||||||
|
}
|
||||||
|
if !strings.Contains(contentStr, "DONE: also fix this") {
|
||||||
|
t.Error("second TODO should be replaced")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTextEdit(t *testing.T) {
|
||||||
|
e := NewEngine(parser.NewRegistry())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
edit *ASTEdit
|
||||||
|
name string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid text selector",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.md",
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{Text: "some text"},
|
||||||
|
NewContent: "new text",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid pattern selector",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.md",
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{TextPattern: "\\d+"},
|
||||||
|
NewContent: "replaced",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid line selector",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.md",
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{AtLine: 5},
|
||||||
|
NewContent: "new line",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty selector",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.md",
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{},
|
||||||
|
NewContent: "new text",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid regex pattern",
|
||||||
|
edit: &ASTEdit{
|
||||||
|
File: "test.md",
|
||||||
|
Operation: EditReplace,
|
||||||
|
Selector: ASTSelector{TextPattern: "[invalid"},
|
||||||
|
NewContent: "new text",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := e.validateTextEdit(tt.edit)
|
||||||
|
if tt.wantErr && err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
if !tt.wantErr && err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindLineRange(t *testing.T) {
|
||||||
|
e := NewEngine(parser.NewRegistry())
|
||||||
|
|
||||||
|
content := []byte("line1\nline2\nline3\nline4\nline5")
|
||||||
|
|
||||||
|
// Content: "line1\nline2\nline3\nline4\nline5" (no trailing newline)
|
||||||
|
// Positions: line1=0-5, \n=5, line2=6-10, \n=11, line3=12-16, \n=17, line4=18-22, \n=23, line5=24-28
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
lineStart int
|
||||||
|
lineEnd int
|
||||||
|
wantStart int
|
||||||
|
wantEnd int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single line",
|
||||||
|
lineStart: 2,
|
||||||
|
lineEnd: 0, // defaults to lineStart
|
||||||
|
wantStart: 6,
|
||||||
|
wantEnd: 12, // includes trailing newline
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "range of lines",
|
||||||
|
lineStart: 2,
|
||||||
|
lineEnd: 4,
|
||||||
|
wantStart: 6,
|
||||||
|
wantEnd: 24, // through end of line4 including newline
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "first line",
|
||||||
|
lineStart: 1,
|
||||||
|
lineEnd: 1,
|
||||||
|
wantStart: 0,
|
||||||
|
wantEnd: 6, // includes trailing newline
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "line out of range",
|
||||||
|
lineStart: 10,
|
||||||
|
lineEnd: 10,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid line number",
|
||||||
|
lineStart: 0,
|
||||||
|
lineEnd: 1,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "end before start",
|
||||||
|
lineStart: 3,
|
||||||
|
lineEnd: 2,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
start, end, err := e.findLineRange(content, tt.lineStart, tt.lineEnd)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if start != tt.wantStart {
|
||||||
|
t.Errorf("start = %d, want %d", start, tt.wantStart)
|
||||||
|
}
|
||||||
|
if end != tt.wantEnd {
|
||||||
|
t.Errorf("end = %d, want %d", end, tt.wantEnd)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,310 @@
|
|||||||
|
// Package lsp provides a generic LSP client implementation.
|
||||||
|
package lsp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
json "github.com/goccy/go-json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client represents an LSP client connection.
|
||||||
|
type Client struct {
|
||||||
|
stdin io.WriteCloser
|
||||||
|
stdout io.ReadCloser
|
||||||
|
stderr io.ReadCloser
|
||||||
|
cmd *exec.Cmd
|
||||||
|
pending map[int64]chan *Response
|
||||||
|
done chan struct{}
|
||||||
|
notifications chan *Notification
|
||||||
|
requestID atomic.Int64
|
||||||
|
runningMu sync.RWMutex
|
||||||
|
stopOnce sync.Once
|
||||||
|
mu sync.Mutex
|
||||||
|
running bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request represents a JSON-RPC request.
|
||||||
|
type Request struct {
|
||||||
|
Params interface{} `json:"params,omitempty"`
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response represents a JSON-RPC response.
|
||||||
|
type Response struct {
|
||||||
|
Error *ResponseError `json:"error,omitempty"`
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
Result json.RawMessage `json:"result,omitempty"`
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseError represents a JSON-RPC error.
|
||||||
|
type ResponseError struct {
|
||||||
|
Data interface{} `json:"data,omitempty"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Code int `json:"code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ResponseError) Error() string {
|
||||||
|
return fmt.Sprintf("LSP error %d: %s", e.Code, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notification represents a JSON-RPC notification.
|
||||||
|
type Notification struct {
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
Params json.RawMessage `json:"params,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new LSP client from a command.
|
||||||
|
func NewClient(cmd *exec.Cmd) (*Client, error) {
|
||||||
|
stdin, err := cmd.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get stdin pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
_ = stdin.Close()
|
||||||
|
return nil, fmt.Errorf("failed to get stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stderr, err := cmd.StderrPipe()
|
||||||
|
if err != nil {
|
||||||
|
_ = stdin.Close()
|
||||||
|
_ = stdout.Close()
|
||||||
|
return nil, fmt.Errorf("failed to get stderr pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
_ = stdin.Close()
|
||||||
|
_ = stdout.Close()
|
||||||
|
_ = stderr.Close()
|
||||||
|
return nil, fmt.Errorf("failed to start LSP server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &Client{
|
||||||
|
cmd: cmd,
|
||||||
|
stdin: stdin,
|
||||||
|
stdout: stdout,
|
||||||
|
stderr: stderr,
|
||||||
|
pending: make(map[int64]chan *Response),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
running: true,
|
||||||
|
notifications: make(chan *Notification, 100),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start reader goroutine
|
||||||
|
go c.readLoop()
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call sends a request and waits for a response.
|
||||||
|
func (c *Client) Call(ctx context.Context, method string, params interface{}) (*Response, error) {
|
||||||
|
c.runningMu.RLock()
|
||||||
|
if !c.running {
|
||||||
|
c.runningMu.RUnlock()
|
||||||
|
return nil, fmt.Errorf("client is not running")
|
||||||
|
}
|
||||||
|
c.runningMu.RUnlock()
|
||||||
|
|
||||||
|
id := c.requestID.Add(1)
|
||||||
|
req := &Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: id,
|
||||||
|
Method: method,
|
||||||
|
Params: params,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create response channel
|
||||||
|
respChan := make(chan *Response, 1)
|
||||||
|
c.mu.Lock()
|
||||||
|
c.pending[id] = respChan
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.pending, id)
|
||||||
|
c.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
if err := c.send(req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for response
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-c.done:
|
||||||
|
return nil, fmt.Errorf("client closed")
|
||||||
|
case resp := <-respChan:
|
||||||
|
if resp.Error != nil {
|
||||||
|
return nil, resp.Error
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify sends a notification (no response expected).
|
||||||
|
func (c *Client) Notify(method string, params interface{}) error {
|
||||||
|
c.runningMu.RLock()
|
||||||
|
if !c.running {
|
||||||
|
c.runningMu.RUnlock()
|
||||||
|
return fmt.Errorf("client is not running")
|
||||||
|
}
|
||||||
|
c.runningMu.RUnlock()
|
||||||
|
|
||||||
|
notif := struct {
|
||||||
|
Params interface{} `json:"params,omitempty"`
|
||||||
|
JSONRPC string `json:"jsonrpc"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
}{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
Method: method,
|
||||||
|
Params: params,
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.send(notif)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notifications returns a channel for receiving server notifications.
|
||||||
|
func (c *Client) Notifications() <-chan *Notification {
|
||||||
|
return c.notifications
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the client and the LSP server.
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
var err error
|
||||||
|
c.stopOnce.Do(func() {
|
||||||
|
c.runningMu.Lock()
|
||||||
|
c.running = false
|
||||||
|
c.runningMu.Unlock()
|
||||||
|
|
||||||
|
close(c.done)
|
||||||
|
|
||||||
|
// Close stdin to signal the server
|
||||||
|
_ = c.stdin.Close()
|
||||||
|
|
||||||
|
// Wait for process to exit with timeout
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_ = c.cmd.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Clean exit
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
// Force kill
|
||||||
|
_ = c.cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
|
||||||
|
close(c.notifications)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// send writes a JSON-RPC message to the server.
|
||||||
|
func (c *Client) send(msg interface{}) error {
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format with Content-Length header
|
||||||
|
header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data))
|
||||||
|
_, err = c.stdin.Write([]byte(header))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.stdin.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLoop reads and dispatches messages from the server.
|
||||||
|
func (c *Client) readLoop() {
|
||||||
|
reader := bufio.NewReader(c.stdout)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read headers
|
||||||
|
contentLength := -1
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(line, "Content-Length:") {
|
||||||
|
lengthStr := strings.TrimSpace(strings.TrimPrefix(line, "Content-Length:"))
|
||||||
|
contentLength, _ = strconv.Atoi(lengthStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if contentLength <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read body
|
||||||
|
body := make([]byte, contentLength)
|
||||||
|
_, err := io.ReadFull(reader, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as response first
|
||||||
|
var resp Response
|
||||||
|
if err := json.Unmarshal(body, &resp); err == nil && resp.ID != 0 {
|
||||||
|
c.mu.Lock()
|
||||||
|
if ch, ok := c.pending[resp.ID]; ok {
|
||||||
|
ch <- &resp
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as notification
|
||||||
|
var notif Notification
|
||||||
|
if err := json.Unmarshal(body, ¬if); err == nil && notif.Method != "" {
|
||||||
|
select {
|
||||||
|
case c.notifications <- ¬if:
|
||||||
|
default:
|
||||||
|
// Drop notification if channel is full
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning returns whether the client is running.
|
||||||
|
func (c *Client) IsRunning() bool {
|
||||||
|
c.runningMu.RLock()
|
||||||
|
defer c.runningMu.RUnlock()
|
||||||
|
return c.running
|
||||||
|
}
|
||||||
@@ -0,0 +1,535 @@
|
|||||||
|
package lsp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
json "github.com/goccy/go-json"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager manages LSP servers for different languages.
|
||||||
|
type Manager struct {
|
||||||
|
servers map[protocol.Language]*ManagedServer
|
||||||
|
logger *slog.Logger
|
||||||
|
stopReaper chan struct{}
|
||||||
|
workspaceRoot string
|
||||||
|
timeout time.Duration
|
||||||
|
idleTimeout time.Duration
|
||||||
|
mu sync.RWMutex
|
||||||
|
stopped bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagedServer represents a managed LSP server instance.
|
||||||
|
type ManagedServer struct {
|
||||||
|
lastUsed time.Time
|
||||||
|
initErr error
|
||||||
|
client *Client
|
||||||
|
openDocs map[string]int
|
||||||
|
language protocol.Language
|
||||||
|
capabilities ServerCapabilities
|
||||||
|
mu sync.Mutex
|
||||||
|
ready bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig contains the configuration for an LSP server.
|
||||||
|
type ServerConfig struct {
|
||||||
|
Command []string
|
||||||
|
Args []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultServerConfigs contains default configurations for LSP servers.
|
||||||
|
var DefaultServerConfigs = map[protocol.Language]ServerConfig{
|
||||||
|
protocol.LangGo: {
|
||||||
|
Command: []string{"gopls"},
|
||||||
|
Args: []string{"serve"},
|
||||||
|
},
|
||||||
|
protocol.LangTypeScript: {
|
||||||
|
Command: []string{"typescript-language-server"},
|
||||||
|
Args: []string{"--stdio"},
|
||||||
|
},
|
||||||
|
protocol.LangJavaScript: {
|
||||||
|
Command: []string{"typescript-language-server"},
|
||||||
|
Args: []string{"--stdio"},
|
||||||
|
},
|
||||||
|
protocol.LangPython: {
|
||||||
|
Command: []string{"pylsp"},
|
||||||
|
},
|
||||||
|
protocol.LangC: {
|
||||||
|
Command: []string{"clangd"},
|
||||||
|
},
|
||||||
|
protocol.LangCpp: {
|
||||||
|
Command: []string{"clangd"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new LSP manager.
|
||||||
|
func NewManager(workspaceRoot string, logger *slog.Logger) *Manager {
|
||||||
|
m := &Manager{
|
||||||
|
servers: make(map[protocol.Language]*ManagedServer),
|
||||||
|
timeout: 10 * time.Second,
|
||||||
|
idleTimeout: 5 * time.Minute,
|
||||||
|
workspaceRoot: workspaceRoot,
|
||||||
|
logger: logger,
|
||||||
|
stopReaper: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start idle reaper
|
||||||
|
go m.reapIdleServers()
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServer returns or creates an LSP server for the given language.
|
||||||
|
func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*ManagedServer, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
srv, exists := m.servers[lang]
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if exists && srv.ready {
|
||||||
|
// Update lastUsed with server's own lock to avoid race condition
|
||||||
|
srv.mu.Lock()
|
||||||
|
srv.lastUsed = time.Now()
|
||||||
|
srv.mu.Unlock()
|
||||||
|
return srv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new server
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Double-check after acquiring write lock
|
||||||
|
if srv, ok := m.servers[lang]; ok && srv.ready {
|
||||||
|
srv.mu.Lock()
|
||||||
|
srv.lastUsed = time.Now()
|
||||||
|
srv.mu.Unlock()
|
||||||
|
return srv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if server config exists
|
||||||
|
config, ok := DefaultServerConfigs[lang]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New(errors.ErrLSPServerNotFound, fmt.Sprintf("no LSP server configured for language: %s", lang)).
|
||||||
|
WithContext("language", string(lang)).
|
||||||
|
WithRemediation("Configure an LSP server for this language or use a supported language")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if command is available
|
||||||
|
cmdPath, err := exec.LookPath(config.Command[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.NewLSPServerNotFound(string(lang), config.Command[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create command
|
||||||
|
args := append(config.Command[1:], config.Args...)
|
||||||
|
cmd := exec.CommandContext(ctx, cmdPath, args...)
|
||||||
|
cmd.Env = os.Environ()
|
||||||
|
cmd.Dir = m.workspaceRoot
|
||||||
|
|
||||||
|
// Create client
|
||||||
|
client, err := NewClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
// Ensure process is killed if client creation fails
|
||||||
|
if cmd.Process != nil {
|
||||||
|
_ = cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(errors.ErrLSPCommunication, "failed to create LSP client", err).
|
||||||
|
WithContext("language", string(lang)).
|
||||||
|
WithContext("command", config.Command[0]).
|
||||||
|
WithRemediation("Ensure the LSP server binary is executable and compatible with your system")
|
||||||
|
}
|
||||||
|
|
||||||
|
newSrv := &ManagedServer{
|
||||||
|
client: client,
|
||||||
|
language: lang,
|
||||||
|
lastUsed: time.Now(),
|
||||||
|
openDocs: make(map[string]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize server
|
||||||
|
if err := m.initializeServer(ctx, newSrv); err != nil {
|
||||||
|
_ = client.Close()
|
||||||
|
// Ensure process is killed on initialization failure
|
||||||
|
if cmd.Process != nil {
|
||||||
|
_ = cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
newSrv.initErr = err
|
||||||
|
return nil, errors.Wrap(errors.ErrLSPInitFailed, "LSP server initialization failed", err).
|
||||||
|
WithContext("language", string(lang)).
|
||||||
|
WithContext("command", config.Command[0]).
|
||||||
|
WithRemediation("Check LSP server logs for initialization errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
newSrv.ready = true
|
||||||
|
m.servers[lang] = newSrv
|
||||||
|
m.logger.Info("started LSP server", "language", lang, "command", config.Command[0])
|
||||||
|
|
||||||
|
return newSrv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializeServer performs the LSP initialization handshake.
|
||||||
|
func (m *Manager) initializeServer(ctx context.Context, srv *ManagedServer) error {
|
||||||
|
// Create context with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Build root URI
|
||||||
|
rootURI := "file://" + m.workspaceRoot
|
||||||
|
|
||||||
|
// Send initialize request
|
||||||
|
params := InitializeParams{
|
||||||
|
ProcessID: os.Getpid(),
|
||||||
|
RootURI: rootURI,
|
||||||
|
Capabilities: Capabilities{
|
||||||
|
TextDocument: TextDocumentClientCapabilities{
|
||||||
|
Hover: HoverCapability{
|
||||||
|
ContentFormat: []string{"markdown", "plaintext"},
|
||||||
|
},
|
||||||
|
Definition: DefinitionCapability{
|
||||||
|
LinkSupport: true,
|
||||||
|
},
|
||||||
|
References: ReferencesCapability{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := srv.client.Call(ctx, "initialize", params)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("initialize failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse capabilities
|
||||||
|
var result InitializeResult
|
||||||
|
if err := json.Unmarshal(resp.Result, &result); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse initialize result: %w", err)
|
||||||
|
}
|
||||||
|
srv.capabilities = result.Capabilities
|
||||||
|
|
||||||
|
// Send initialized notification
|
||||||
|
if err := srv.client.Notify("initialized", struct{}{}); err != nil {
|
||||||
|
return fmt.Errorf("initialized notification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hover performs a hover request at the given position.
|
||||||
|
func (m *Manager) Hover(ctx context.Context, file string, line, col int) (*HoverResult, error) {
|
||||||
|
lang := protocol.DetectLanguage(file)
|
||||||
|
srv, err := m.GetServer(ctx, lang)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure document is open
|
||||||
|
err = m.ensureDocumentOpen(ctx, srv, file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
params := HoverParams{
|
||||||
|
TextDocumentPositionParams: TextDocumentPositionParams{
|
||||||
|
TextDocument: TextDocumentIdentifier{
|
||||||
|
URI: fileToURI(file),
|
||||||
|
},
|
||||||
|
Position: Position{
|
||||||
|
Line: line - 1, // Convert to 0-indexed
|
||||||
|
Character: col - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := srv.client.Call(ctx, "textDocument/hover", params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("hover request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Result == nil || string(resp.Result) == "null" {
|
||||||
|
return nil, nil // No hover info
|
||||||
|
}
|
||||||
|
|
||||||
|
var result HoverResult
|
||||||
|
if err := json.Unmarshal(resp.Result, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse hover result: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Definition finds the definition of the symbol at the given position.
|
||||||
|
func (m *Manager) Definition(ctx context.Context, file string, line, col int) ([]Location, error) {
|
||||||
|
lang := protocol.DetectLanguage(file)
|
||||||
|
srv, err := m.GetServer(ctx, lang)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure document is open
|
||||||
|
err = m.ensureDocumentOpen(ctx, srv, file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
params := DefinitionParams{
|
||||||
|
TextDocumentPositionParams: TextDocumentPositionParams{
|
||||||
|
TextDocument: TextDocumentIdentifier{
|
||||||
|
URI: fileToURI(file),
|
||||||
|
},
|
||||||
|
Position: Position{
|
||||||
|
Line: line - 1,
|
||||||
|
Character: col - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := srv.client.Call(ctx, "textDocument/definition", params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("definition request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Result == nil || string(resp.Result) == "null" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Result can be Location, []Location, or []LocationLink
|
||||||
|
var locations []Location
|
||||||
|
if err := json.Unmarshal(resp.Result, &locations); err != nil {
|
||||||
|
// Try single location
|
||||||
|
var single Location
|
||||||
|
if err := json.Unmarshal(resp.Result, &single); err == nil {
|
||||||
|
locations = []Location{single}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return locations, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// References finds all references to the symbol at the given position.
|
||||||
|
func (m *Manager) References(ctx context.Context, file string, line, col int, includeDeclaration bool) ([]Location, error) {
|
||||||
|
lang := protocol.DetectLanguage(file)
|
||||||
|
srv, err := m.GetServer(ctx, lang)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure document is open
|
||||||
|
err = m.ensureDocumentOpen(ctx, srv, file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
params := ReferenceParams{
|
||||||
|
TextDocumentPositionParams: TextDocumentPositionParams{
|
||||||
|
TextDocument: TextDocumentIdentifier{
|
||||||
|
URI: fileToURI(file),
|
||||||
|
},
|
||||||
|
Position: Position{
|
||||||
|
Line: line - 1,
|
||||||
|
Character: col - 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Context: ReferenceContext{
|
||||||
|
IncludeDeclaration: includeDeclaration,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := srv.client.Call(ctx, "textDocument/references", params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("references request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Result == nil || string(resp.Result) == "null" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var locations []Location
|
||||||
|
if err := json.Unmarshal(resp.Result, &locations); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse references result: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return locations, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureDocumentOpen opens a document if not already open.
|
||||||
|
func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, file string) error {
|
||||||
|
uri := fileToURI(file)
|
||||||
|
|
||||||
|
srv.mu.Lock()
|
||||||
|
if _, ok := srv.openDocs[uri]; ok {
|
||||||
|
srv.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
srv.mu.Unlock()
|
||||||
|
|
||||||
|
// Read file content
|
||||||
|
content, err := os.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get language ID
|
||||||
|
langID := languageToLSPID(srv.language)
|
||||||
|
|
||||||
|
params := DidOpenTextDocumentParams{
|
||||||
|
TextDocument: TextDocumentItem{
|
||||||
|
URI: uri,
|
||||||
|
LanguageID: langID,
|
||||||
|
Version: 1,
|
||||||
|
Text: string(content),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := srv.client.Notify("textDocument/didOpen", params); err != nil {
|
||||||
|
return fmt.Errorf("didOpen failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.mu.Lock()
|
||||||
|
srv.openDocs[uri] = 1
|
||||||
|
srv.mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseDocument closes a document in the server.
|
||||||
|
func (m *Manager) CloseDocument(_ context.Context, lang protocol.Language, file string) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
srv, ok := m.servers[lang]
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok || !srv.ready {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fileToURI(file)
|
||||||
|
|
||||||
|
srv.mu.Lock()
|
||||||
|
if _, ok := srv.openDocs[uri]; !ok {
|
||||||
|
srv.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
delete(srv.openDocs, uri)
|
||||||
|
srv.mu.Unlock()
|
||||||
|
|
||||||
|
params := DidCloseTextDocumentParams{
|
||||||
|
TextDocument: TextDocumentIdentifier{
|
||||||
|
URI: uri,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return srv.client.Notify("textDocument/didClose", params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// reapIdleServers periodically closes idle servers.
|
||||||
|
func (m *Manager) reapIdleServers() {
|
||||||
|
ticker := time.NewTicker(60 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.stopReaper:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
m.mu.Lock()
|
||||||
|
for lang, srv := range m.servers {
|
||||||
|
// Check lastUsed with server's lock to avoid race condition
|
||||||
|
srv.mu.Lock()
|
||||||
|
idle := time.Since(srv.lastUsed) > m.idleTimeout
|
||||||
|
srv.mu.Unlock()
|
||||||
|
|
||||||
|
if idle {
|
||||||
|
m.logger.Info("closing idle LSP server", "language", lang)
|
||||||
|
_ = srv.client.Close()
|
||||||
|
delete(m.servers, lang)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down all LSP servers.
|
||||||
|
func (m *Manager) Close() error {
|
||||||
|
close(m.stopReaper)
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.stopped = true
|
||||||
|
|
||||||
|
for lang, srv := range m.servers {
|
||||||
|
m.logger.Info("shutting down LSP server", "language", lang)
|
||||||
|
// Try graceful shutdown
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
_, _ = srv.client.Call(ctx, "shutdown", nil)
|
||||||
|
cancel()
|
||||||
|
_ = srv.client.Notify("exit", nil)
|
||||||
|
_ = srv.client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.servers = make(map[protocol.Language]*ManagedServer)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAvailable checks if an LSP server is available for the given language.
|
||||||
|
func (m *Manager) IsAvailable(lang protocol.Language) bool {
|
||||||
|
config, ok := DefaultServerConfigs[lang]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := exec.LookPath(config.Command[0])
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileToURI converts a file path to a file URI.
|
||||||
|
func fileToURI(file string) string {
|
||||||
|
absPath, err := filepath.Abs(file)
|
||||||
|
if err != nil {
|
||||||
|
return "file://" + file
|
||||||
|
}
|
||||||
|
return "file://" + absPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// URIToFile converts a file URI to a file path.
|
||||||
|
func URIToFile(uri string) string {
|
||||||
|
if len(uri) > 7 && uri[:7] == "file://" {
|
||||||
|
return uri[7:]
|
||||||
|
}
|
||||||
|
return uri
|
||||||
|
}
|
||||||
|
|
||||||
|
// languageToLSPID converts a language to LSP language ID.
|
||||||
|
func languageToLSPID(lang protocol.Language) string {
|
||||||
|
switch lang {
|
||||||
|
case protocol.LangGo:
|
||||||
|
return "go"
|
||||||
|
case protocol.LangTypeScript:
|
||||||
|
return "typescript"
|
||||||
|
case protocol.LangJavaScript:
|
||||||
|
return "javascript"
|
||||||
|
case protocol.LangPython:
|
||||||
|
return "python"
|
||||||
|
case protocol.LangC:
|
||||||
|
return "c"
|
||||||
|
case protocol.LangCpp:
|
||||||
|
return "cpp"
|
||||||
|
default:
|
||||||
|
return string(lang)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
package lsp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFileToURI(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
file string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "absolute path",
|
||||||
|
file: "/Users/test/file.go",
|
||||||
|
want: "file:///Users/test/file.go",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := fileToURI(tt.file)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("fileToURI() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestURIToFile(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
uri string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "file uri",
|
||||||
|
uri: "file:///Users/test/file.go",
|
||||||
|
want: "/Users/test/file.go",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not a file uri",
|
||||||
|
uri: "/Users/test/file.go",
|
||||||
|
want: "/Users/test/file.go",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := URIToFile(tt.uri)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("URIToFile() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLanguageToLSPID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
lang protocol.Language
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{protocol.LangGo, "go"},
|
||||||
|
{protocol.LangTypeScript, "typescript"},
|
||||||
|
{protocol.LangJavaScript, "javascript"},
|
||||||
|
{protocol.LangPython, "python"},
|
||||||
|
{protocol.LangC, "c"},
|
||||||
|
{protocol.LangCpp, "cpp"},
|
||||||
|
{protocol.LangUnknown, "unknown"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(string(tt.lang), func(t *testing.T) {
|
||||||
|
got := languageToLSPID(tt.lang)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("languageToLSPID() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAvailable(t *testing.T) {
|
||||||
|
// This tests the structure of the manager without actually spawning servers
|
||||||
|
// which requires the actual LSP servers to be installed
|
||||||
|
|
||||||
|
// Just verify the DefaultServerConfigs structure
|
||||||
|
expectedLanguages := []protocol.Language{
|
||||||
|
protocol.LangGo,
|
||||||
|
protocol.LangTypeScript,
|
||||||
|
protocol.LangJavaScript,
|
||||||
|
protocol.LangPython,
|
||||||
|
protocol.LangC,
|
||||||
|
protocol.LangCpp,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, lang := range expectedLanguages {
|
||||||
|
if _, ok := DefaultServerConfigs[lang]; !ok {
|
||||||
|
t.Errorf("missing server config for language: %s", lang)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultServerConfigs(t *testing.T) {
|
||||||
|
// Verify the command structure
|
||||||
|
for lang, config := range DefaultServerConfigs {
|
||||||
|
if len(config.Command) == 0 {
|
||||||
|
t.Errorf("language %s has empty command", lang)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
package lsp
|
||||||
|
|
||||||
|
// InitializeParams are the parameters for the initialize request.
|
||||||
|
type InitializeParams struct {
|
||||||
|
RootURI string `json:"rootUri"`
|
||||||
|
Capabilities Capabilities `json:"capabilities"`
|
||||||
|
ProcessID int `json:"processId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capabilities represents client capabilities.
|
||||||
|
type Capabilities struct {
|
||||||
|
TextDocument TextDocumentClientCapabilities `json:"textDocument"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextDocumentClientCapabilities represents text document capabilities.
|
||||||
|
type TextDocumentClientCapabilities struct {
|
||||||
|
Hover HoverCapability `json:"hover,omitempty"`
|
||||||
|
Definition DefinitionCapability `json:"definition,omitempty"`
|
||||||
|
References ReferencesCapability `json:"references,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HoverCapability represents hover capabilities.
|
||||||
|
type HoverCapability struct {
|
||||||
|
ContentFormat []string `json:"contentFormat,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefinitionCapability represents definition capabilities.
|
||||||
|
type DefinitionCapability struct {
|
||||||
|
LinkSupport bool `json:"linkSupport,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReferencesCapability represents references capabilities.
|
||||||
|
type ReferencesCapability struct{}
|
||||||
|
|
||||||
|
// InitializeResult is the result of the initialize request.
|
||||||
|
type InitializeResult struct {
|
||||||
|
Capabilities ServerCapabilities `json:"capabilities"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerCapabilities represents server capabilities.
|
||||||
|
type ServerCapabilities struct {
|
||||||
|
HoverProvider bool `json:"hoverProvider,omitempty"`
|
||||||
|
DefinitionProvider bool `json:"definitionProvider,omitempty"`
|
||||||
|
ReferencesProvider bool `json:"referencesProvider,omitempty"`
|
||||||
|
DocumentSymbolProvider bool `json:"documentSymbolProvider,omitempty"`
|
||||||
|
TextDocumentSync int `json:"textDocumentSync,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Position represents a position in a document.
|
||||||
|
type Position struct {
|
||||||
|
Line int `json:"line"` // 0-indexed
|
||||||
|
Character int `json:"character"` // 0-indexed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Range represents a range in a document.
|
||||||
|
type Range struct {
|
||||||
|
Start Position `json:"start"`
|
||||||
|
End Position `json:"end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Location represents a location in a document.
|
||||||
|
type Location struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
Range Range `json:"range"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextDocumentIdentifier identifies a text document.
|
||||||
|
type TextDocumentIdentifier struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextDocumentPositionParams represents position parameters.
|
||||||
|
type TextDocumentPositionParams struct {
|
||||||
|
TextDocument TextDocumentIdentifier `json:"textDocument"`
|
||||||
|
Position Position `json:"position"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HoverParams are the parameters for the hover request.
|
||||||
|
type HoverParams struct {
|
||||||
|
TextDocumentPositionParams
|
||||||
|
}
|
||||||
|
|
||||||
|
// HoverResult is the result of the hover request.
|
||||||
|
type HoverResult struct {
|
||||||
|
Range *Range `json:"range,omitempty"`
|
||||||
|
Contents MarkupContent `json:"contents"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkupContent represents markup content.
|
||||||
|
type MarkupContent struct {
|
||||||
|
Kind string `json:"kind"` // "plaintext" or "markdown"
|
||||||
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefinitionParams are the parameters for the definition request.
|
||||||
|
type DefinitionParams struct {
|
||||||
|
TextDocumentPositionParams
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReferenceParams are the parameters for the references request.
|
||||||
|
type ReferenceParams struct {
|
||||||
|
TextDocumentPositionParams
|
||||||
|
Context ReferenceContext `json:"context"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReferenceContext represents reference context.
|
||||||
|
type ReferenceContext struct {
|
||||||
|
IncludeDeclaration bool `json:"includeDeclaration"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextDocumentItem represents a text document.
|
||||||
|
type TextDocumentItem struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
LanguageID string `json:"languageId"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
Version int `json:"version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DidOpenTextDocumentParams are the parameters for didOpen.
|
||||||
|
type DidOpenTextDocumentParams struct {
|
||||||
|
TextDocument TextDocumentItem `json:"textDocument"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DidCloseTextDocumentParams are the parameters for didClose.
|
||||||
|
type DidCloseTextDocumentParams struct {
|
||||||
|
TextDocument TextDocumentIdentifier `json:"textDocument"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DocumentSymbol represents a symbol in a document.
|
||||||
|
type DocumentSymbol struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Detail string `json:"detail,omitempty"`
|
||||||
|
Children []DocumentSymbol `json:"children,omitempty"`
|
||||||
|
Range Range `json:"range"`
|
||||||
|
SelectionRange Range `json:"selectionRange"`
|
||||||
|
Kind int `json:"kind"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SymbolInformation represents symbol information.
|
||||||
|
type SymbolInformation struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
ContainerName string `json:"containerName,omitempty"`
|
||||||
|
Location Location `json:"location"`
|
||||||
|
Kind int `json:"kind"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DocumentSymbolParams are the parameters for documentSymbol.
|
||||||
|
type DocumentSymbolParams struct {
|
||||||
|
TextDocument TextDocumentIdentifier `json:"textDocument"`
|
||||||
|
}
|
||||||
@@ -0,0 +1,190 @@
|
|||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
sitter "github.com/smacker/go-tree-sitter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FindNodeAtPosition finds the node at the given line and column.
|
||||||
|
func FindNodeAtPosition(tree *sitter.Tree, line, col int) *sitter.Node {
|
||||||
|
if tree == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
root := tree.RootNode()
|
||||||
|
if root == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to 0-indexed
|
||||||
|
point := sitter.Point{
|
||||||
|
Row: uint32(line - 1), // #nosec G115 - line numbers are bounded by file size
|
||||||
|
Column: uint32(col - 1), // #nosec G115 - column numbers are bounded by line length
|
||||||
|
}
|
||||||
|
|
||||||
|
return findNodeAtPoint(root, point)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findNodeAtPoint recursively finds the smallest node containing the point.
|
||||||
|
func findNodeAtPoint(node *sitter.Node, point sitter.Point) *sitter.Node {
|
||||||
|
if node == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
startPoint := node.StartPoint()
|
||||||
|
endPoint := node.EndPoint()
|
||||||
|
|
||||||
|
// Check if point is within this node
|
||||||
|
if !pointInRange(point, startPoint, endPoint) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to find a more specific child node
|
||||||
|
for i := 0; i < int(node.ChildCount()); i++ {
|
||||||
|
child := node.Child(i)
|
||||||
|
if child == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if result := findNodeAtPoint(child, point); result != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No child contains the point, return this node
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// pointInRange checks if a point is within a range.
|
||||||
|
func pointInRange(point, start, end sitter.Point) bool {
|
||||||
|
// Before start?
|
||||||
|
if point.Row < start.Row || (point.Row == start.Row && point.Column < start.Column) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// After end?
|
||||||
|
if point.Row > end.Row || (point.Row == end.Row && point.Column >= end.Column) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindParentOfKind finds the nearest ancestor of the given node type.
|
||||||
|
func FindParentOfKind(node *sitter.Node, kind string) *sitter.Node {
|
||||||
|
if node == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
current := node.Parent()
|
||||||
|
for current != nil {
|
||||||
|
if current.Type() == kind {
|
||||||
|
return current
|
||||||
|
}
|
||||||
|
current = current.Parent()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNodeText returns the text content of a node.
|
||||||
|
func GetNodeText(node *sitter.Node, content []byte) string {
|
||||||
|
if node == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
start := node.StartByte()
|
||||||
|
end := node.EndByte()
|
||||||
|
|
||||||
|
if int(start) >= len(content) || int(end) > len(content) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(content[start:end])
|
||||||
|
}
|
||||||
|
|
||||||
|
// WalkTree walks the tree calling fn for each node.
|
||||||
|
// If fn returns false, the walk stops.
|
||||||
|
func WalkTree(node *sitter.Node, fn func(*sitter.Node) bool) {
|
||||||
|
if node == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fn(node) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < int(node.ChildCount()); i++ {
|
||||||
|
WalkTree(node.Child(i), fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindNodesByKind finds all nodes of a given kind.
|
||||||
|
func FindNodesByKind(root *sitter.Node, kind string) []*sitter.Node {
|
||||||
|
var nodes []*sitter.Node
|
||||||
|
|
||||||
|
WalkTree(root, func(n *sitter.Node) bool {
|
||||||
|
if n.Type() == kind {
|
||||||
|
nodes = append(nodes, n)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindNamedChildren returns all named (non-anonymous) children of a node.
|
||||||
|
func FindNamedChildren(node *sitter.Node) []*sitter.Node {
|
||||||
|
if node == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var children []*sitter.Node
|
||||||
|
for i := 0; i < int(node.NamedChildCount()); i++ {
|
||||||
|
if child := node.NamedChild(i); child != nil {
|
||||||
|
children = append(children, child)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return children
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChildByFieldName returns the child node with the given field name.
|
||||||
|
func GetChildByFieldName(node *sitter.Node, fieldName string) *sitter.Node {
|
||||||
|
if node == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return node.ChildByFieldName(fieldName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodeLocation returns the location of a node.
|
||||||
|
func NodeLocation(node *sitter.Node, filename string) protocol.Location {
|
||||||
|
if node == nil {
|
||||||
|
return protocol.Location{}
|
||||||
|
}
|
||||||
|
|
||||||
|
startPoint := node.StartPoint()
|
||||||
|
return protocol.Location{
|
||||||
|
File: filename,
|
||||||
|
Line: int(startPoint.Row) + 1,
|
||||||
|
Column: int(startPoint.Column) + 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodeRange returns the range of a node.
|
||||||
|
func NodeRange(node *sitter.Node, filename string) protocol.Range {
|
||||||
|
if node == nil {
|
||||||
|
return protocol.Range{}
|
||||||
|
}
|
||||||
|
|
||||||
|
startPoint := node.StartPoint()
|
||||||
|
endPoint := node.EndPoint()
|
||||||
|
|
||||||
|
return protocol.Range{
|
||||||
|
Start: protocol.Location{
|
||||||
|
File: filename,
|
||||||
|
Line: int(startPoint.Row) + 1,
|
||||||
|
Column: int(startPoint.Column) + 1,
|
||||||
|
},
|
||||||
|
End: protocol.Location{
|
||||||
|
File: filename,
|
||||||
|
Line: int(endPoint.Row) + 1,
|
||||||
|
Column: int(endPoint.Column) + 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,140 @@
|
|||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestLRUCacheEviction tests that the LRU cache properly evicts old entries.
|
||||||
|
func TestLRUCacheEviction(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create 101 unique Go files (cache size is 100)
|
||||||
|
for i := 0; i < 101; i++ {
|
||||||
|
content := []byte(fmt.Sprintf("package main\n\nfunc test%d() {}\n", i))
|
||||||
|
filename := "test.go"
|
||||||
|
|
||||||
|
_, err := registry.Parse(ctx, filename, content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse failed for iteration %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The LRU cache should have evicted the oldest entry
|
||||||
|
// Verify cache size is capped at 100
|
||||||
|
cacheLen := registry.cache.Len()
|
||||||
|
if cacheLen > 100 {
|
||||||
|
t.Errorf("Cache size %d exceeds max size 100", cacheLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCacheHit tests that repeated parsing of the same content uses cache.
|
||||||
|
func TestCacheHit(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
content := []byte("package main\n\nfunc test() {}\n")
|
||||||
|
filename := "test.go"
|
||||||
|
|
||||||
|
// First parse
|
||||||
|
result1, err := registry.Parse(ctx, filename, content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("First parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second parse should use cache
|
||||||
|
result2, err := registry.Parse(ctx, filename, content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The tree should be the same object (cached)
|
||||||
|
if result1.Tree != result2.Tree {
|
||||||
|
t.Error("Expected cached tree to be reused, but got different tree objects")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestContentHashCollisionResistance tests that different content produces different hashes.
|
||||||
|
func TestContentHashCollisionResistance(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
content1 []byte
|
||||||
|
content2 []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "different content",
|
||||||
|
content1: []byte("package main"),
|
||||||
|
content2: []byte("package test"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same prefix different suffix",
|
||||||
|
content1: []byte("package main\nfunc a() {}"),
|
||||||
|
content2: []byte("package main\nfunc b() {}"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different length",
|
||||||
|
content1: []byte("short"),
|
||||||
|
content2: []byte("much longer content here"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
hash1 := contentHash(tc.content1)
|
||||||
|
hash2 := contentHash(tc.content2)
|
||||||
|
|
||||||
|
if hash1 == hash2 {
|
||||||
|
t.Errorf("Hash collision: %s == %s for different content", hash1, hash2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestContentHashConsistency tests that the same content always produces the same hash.
|
||||||
|
func TestContentHashConsistency(t *testing.T) {
|
||||||
|
content := []byte("package main\n\nfunc test() {}\n")
|
||||||
|
|
||||||
|
hash1 := contentHash(content)
|
||||||
|
hash2 := contentHash(content)
|
||||||
|
hash3 := contentHash(content)
|
||||||
|
|
||||||
|
if hash1 != hash2 || hash2 != hash3 {
|
||||||
|
t.Errorf("Hash inconsistency: %s, %s, %s", hash1, hash2, hash3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkContentHash_xxHash benchmarks the xxHash implementation.
|
||||||
|
func BenchmarkContentHash_xxHash(b *testing.B) {
|
||||||
|
// Typical file content size (10KB)
|
||||||
|
content := make([]byte, 10*1024)
|
||||||
|
for i := range content {
|
||||||
|
content[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = contentHash(content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkCacheHitRate benchmarks cache performance with realistic workload.
|
||||||
|
func BenchmarkCacheHitRate(b *testing.B) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a set of common files that get parsed repeatedly
|
||||||
|
files := [][]byte{
|
||||||
|
[]byte("package main\n\nfunc main() {}\n"),
|
||||||
|
[]byte("package test\n\nimport \"testing\"\n"),
|
||||||
|
[]byte("package util\n\nfunc helper() string { return \"\" }\n"),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Simulate realistic access pattern with cache hits
|
||||||
|
content := files[i%len(files)]
|
||||||
|
_, _ = registry.Parse(ctx, "test.go", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,550 @@
|
|||||||
|
// Package parser provides documentation extraction for multiple languages.
|
||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
sitter "github.com/smacker/go-tree-sitter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DocComment represents an extracted documentation comment.
|
||||||
|
type DocComment struct {
|
||||||
|
Tags map[string]string
|
||||||
|
Text string
|
||||||
|
Raw string
|
||||||
|
Style CommentStyle
|
||||||
|
StartLine int
|
||||||
|
EndLine int
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommentStyle indicates the type of comment.
|
||||||
|
type CommentStyle string
|
||||||
|
|
||||||
|
const (
|
||||||
|
CommentStyleLine CommentStyle = "line" // // comment
|
||||||
|
CommentStyleBlock CommentStyle = "block" // /* comment */
|
||||||
|
CommentStyleJSDoc CommentStyle = "jsdoc" // /** comment */
|
||||||
|
CommentStyleDoxygen CommentStyle = "doxygen" // /** comment */ or /// comment
|
||||||
|
CommentStyleDocstring CommentStyle = "docstring" // """comment""" or '''comment'''
|
||||||
|
CommentStyleHash CommentStyle = "hash" // # comment (Python)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExtractDocComment extracts the documentation comment for a node.
|
||||||
|
func ExtractDocComment(n *sitter.Node, content []byte, lang protocol.Language) *DocComment {
|
||||||
|
if n == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch lang {
|
||||||
|
case protocol.LangGo:
|
||||||
|
return extractGoDocComment(n, content)
|
||||||
|
case protocol.LangTypeScript, protocol.LangJavaScript:
|
||||||
|
return extractJSDocComment(n, content)
|
||||||
|
case protocol.LangPython:
|
||||||
|
return extractPythonDocComment(n, content)
|
||||||
|
case protocol.LangC, protocol.LangCpp:
|
||||||
|
return extractCDocComment(n, content)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractGoDocComment extracts Go documentation comments.
|
||||||
|
// Go uses // or /* */ comments immediately preceding a declaration.
|
||||||
|
func extractGoDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||||
|
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||||
|
if len(comments) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []string
|
||||||
|
var raw []string
|
||||||
|
startLine := -1
|
||||||
|
endLine := -1
|
||||||
|
|
||||||
|
for _, c := range comments {
|
||||||
|
text := GetNodeText(c, content)
|
||||||
|
raw = append(raw, text)
|
||||||
|
|
||||||
|
if startLine == -1 {
|
||||||
|
startLine = int(c.StartPoint().Row) + 1
|
||||||
|
}
|
||||||
|
endLine = int(c.EndPoint().Row) + 1
|
||||||
|
|
||||||
|
cleaned := cleanGoComment(text)
|
||||||
|
if cleaned != "" {
|
||||||
|
parts = append(parts, cleaned)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DocComment{
|
||||||
|
Text: strings.Join(parts, "\n"),
|
||||||
|
Raw: strings.Join(raw, "\n"),
|
||||||
|
Style: detectCommentStyle(raw[0]),
|
||||||
|
Tags: nil, // Go doesn't use JSDoc-style tags
|
||||||
|
StartLine: startLine,
|
||||||
|
EndLine: endLine,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractJSDocComment extracts JSDoc-style documentation comments.
|
||||||
|
func extractJSDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||||
|
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||||
|
if len(comments) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSDoc prefers the last comment block if it's a JSDoc comment
|
||||||
|
var jsDocComment *sitter.Node
|
||||||
|
for i := len(comments) - 1; i >= 0; i-- {
|
||||||
|
text := GetNodeText(comments[i], content)
|
||||||
|
if strings.HasPrefix(strings.TrimSpace(text), "/**") {
|
||||||
|
jsDocComment = comments[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if jsDocComment != nil {
|
||||||
|
text := GetNodeText(jsDocComment, content)
|
||||||
|
cleaned, tags := parseJSDoc(text)
|
||||||
|
return &DocComment{
|
||||||
|
Text: cleaned,
|
||||||
|
Raw: text,
|
||||||
|
Style: CommentStyleJSDoc,
|
||||||
|
Tags: tags,
|
||||||
|
StartLine: int(jsDocComment.StartPoint().Row) + 1,
|
||||||
|
EndLine: int(jsDocComment.EndPoint().Row) + 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to regular comments
|
||||||
|
var parts []string
|
||||||
|
var raw []string
|
||||||
|
startLine := -1
|
||||||
|
endLine := -1
|
||||||
|
|
||||||
|
for _, c := range comments {
|
||||||
|
text := GetNodeText(c, content)
|
||||||
|
raw = append(raw, text)
|
||||||
|
|
||||||
|
if startLine == -1 {
|
||||||
|
startLine = int(c.StartPoint().Row) + 1
|
||||||
|
}
|
||||||
|
endLine = int(c.EndPoint().Row) + 1
|
||||||
|
|
||||||
|
cleaned := cleanJSComment(text)
|
||||||
|
if cleaned != "" {
|
||||||
|
parts = append(parts, cleaned)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DocComment{
|
||||||
|
Text: strings.Join(parts, "\n"),
|
||||||
|
Raw: strings.Join(raw, "\n"),
|
||||||
|
Style: CommentStyleLine,
|
||||||
|
Tags: nil,
|
||||||
|
StartLine: startLine,
|
||||||
|
EndLine: endLine,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPythonDocComment extracts Python docstrings.
|
||||||
|
// Python docstrings are triple-quoted strings inside the function/class body.
|
||||||
|
func extractPythonDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||||
|
// Python docstrings are inside the body, not before
|
||||||
|
body := n.ChildByFieldName("body")
|
||||||
|
if body == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// First statement should be the docstring if present
|
||||||
|
if body.NamedChildCount() > 0 {
|
||||||
|
first := body.NamedChild(0)
|
||||||
|
if first != nil && first.Type() == "expression_statement" {
|
||||||
|
if first.NamedChildCount() > 0 {
|
||||||
|
expr := first.NamedChild(0)
|
||||||
|
if expr != nil && expr.Type() == "string" {
|
||||||
|
text := GetNodeText(expr, content)
|
||||||
|
cleaned := cleanPythonDocstring(text)
|
||||||
|
return &DocComment{
|
||||||
|
Text: cleaned,
|
||||||
|
Raw: text,
|
||||||
|
Style: CommentStyleDocstring,
|
||||||
|
Tags: nil,
|
||||||
|
StartLine: int(expr.StartPoint().Row) + 1,
|
||||||
|
EndLine: int(expr.EndPoint().Row) + 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also check for # comments before the definition
|
||||||
|
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||||
|
if len(comments) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []string
|
||||||
|
var raw []string
|
||||||
|
startLine := -1
|
||||||
|
endLine := -1
|
||||||
|
|
||||||
|
for _, c := range comments {
|
||||||
|
text := GetNodeText(c, content)
|
||||||
|
raw = append(raw, text)
|
||||||
|
|
||||||
|
if startLine == -1 {
|
||||||
|
startLine = int(c.StartPoint().Row) + 1
|
||||||
|
}
|
||||||
|
endLine = int(c.EndPoint().Row) + 1
|
||||||
|
|
||||||
|
// Clean # comment
|
||||||
|
cleaned := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(text), "#"))
|
||||||
|
if cleaned != "" {
|
||||||
|
parts = append(parts, cleaned)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DocComment{
|
||||||
|
Text: strings.Join(parts, "\n"),
|
||||||
|
Raw: strings.Join(raw, "\n"),
|
||||||
|
Style: CommentStyleHash,
|
||||||
|
Tags: nil,
|
||||||
|
StartLine: startLine,
|
||||||
|
EndLine: endLine,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCDocComment extracts C/C++ documentation comments (Doxygen style).
|
||||||
|
func extractCDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||||
|
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||||
|
if len(comments) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for Doxygen-style comment
|
||||||
|
var doxyComment *sitter.Node
|
||||||
|
for i := len(comments) - 1; i >= 0; i-- {
|
||||||
|
text := GetNodeText(comments[i], content)
|
||||||
|
trimmed := strings.TrimSpace(text)
|
||||||
|
if strings.HasPrefix(trimmed, "/**") || strings.HasPrefix(trimmed, "///") || strings.HasPrefix(trimmed, "//!") {
|
||||||
|
doxyComment = comments[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if doxyComment != nil {
|
||||||
|
text := GetNodeText(doxyComment, content)
|
||||||
|
cleaned, tags := parseDoxygen(text)
|
||||||
|
return &DocComment{
|
||||||
|
Text: cleaned,
|
||||||
|
Raw: text,
|
||||||
|
Style: CommentStyleDoxygen,
|
||||||
|
Tags: tags,
|
||||||
|
StartLine: int(doxyComment.StartPoint().Row) + 1,
|
||||||
|
EndLine: int(doxyComment.EndPoint().Row) + 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to regular comments
|
||||||
|
var parts []string
|
||||||
|
var raw []string
|
||||||
|
startLine := -1
|
||||||
|
endLine := -1
|
||||||
|
|
||||||
|
for _, c := range comments {
|
||||||
|
text := GetNodeText(c, content)
|
||||||
|
raw = append(raw, text)
|
||||||
|
|
||||||
|
if startLine == -1 {
|
||||||
|
startLine = int(c.StartPoint().Row) + 1
|
||||||
|
}
|
||||||
|
endLine = int(c.EndPoint().Row) + 1
|
||||||
|
|
||||||
|
cleaned := cleanCComment(text)
|
||||||
|
if cleaned != "" {
|
||||||
|
parts = append(parts, cleaned)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DocComment{
|
||||||
|
Text: strings.Join(parts, "\n"),
|
||||||
|
Raw: strings.Join(raw, "\n"),
|
||||||
|
Style: detectCommentStyle(raw[0]),
|
||||||
|
Tags: nil,
|
||||||
|
StartLine: startLine,
|
||||||
|
EndLine: endLine,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectPrecedingComments collects all comment nodes immediately before a node.
|
||||||
|
func collectPrecedingComments(n *sitter.Node, _ []byte, commentTypes []string) []*sitter.Node {
|
||||||
|
var comments []*sitter.Node
|
||||||
|
|
||||||
|
// Walk backwards through siblings
|
||||||
|
prev := n.PrevSibling()
|
||||||
|
lastCommentLine := int(n.StartPoint().Row)
|
||||||
|
|
||||||
|
for prev != nil {
|
||||||
|
isComment := false
|
||||||
|
nodeType := prev.Type()
|
||||||
|
for _, ct := range commentTypes {
|
||||||
|
if nodeType == ct {
|
||||||
|
isComment = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isComment {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
commentEndLine := int(prev.EndPoint().Row)
|
||||||
|
|
||||||
|
// Check if there's a blank line gap
|
||||||
|
if lastCommentLine-commentEndLine > 1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
comments = append([]*sitter.Node{prev}, comments...)
|
||||||
|
lastCommentLine = int(prev.StartPoint().Row)
|
||||||
|
prev = prev.PrevSibling()
|
||||||
|
}
|
||||||
|
|
||||||
|
return comments
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectCommentStyle determines the style of a comment.
|
||||||
|
func detectCommentStyle(comment string) CommentStyle {
|
||||||
|
trimmed := strings.TrimSpace(comment)
|
||||||
|
if strings.HasPrefix(trimmed, "/**") {
|
||||||
|
return CommentStyleJSDoc
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(trimmed, "///") || strings.HasPrefix(trimmed, "//!") {
|
||||||
|
return CommentStyleDoxygen
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(trimmed, "/*") {
|
||||||
|
return CommentStyleBlock
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(trimmed, "//") {
|
||||||
|
return CommentStyleLine
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
return CommentStyleHash
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(trimmed, `"""`) || strings.HasPrefix(trimmed, `'''`) {
|
||||||
|
return CommentStyleDocstring
|
||||||
|
}
|
||||||
|
return CommentStyleLine
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanGoComment cleans a Go comment.
|
||||||
|
func cleanGoComment(comment string) string {
|
||||||
|
comment = strings.TrimSpace(comment)
|
||||||
|
|
||||||
|
// Handle // comments
|
||||||
|
if after, found := strings.CutPrefix(comment, "//"); found {
|
||||||
|
return strings.TrimSpace(after)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle /* */ comments
|
||||||
|
if strings.HasPrefix(comment, "/*") && strings.HasSuffix(comment, "*/") {
|
||||||
|
comment = strings.TrimPrefix(comment, "/*")
|
||||||
|
comment = strings.TrimSuffix(comment, "*/")
|
||||||
|
return cleanBlockComment(comment)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(comment)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanJSComment cleans a JavaScript/TypeScript comment.
|
||||||
|
func cleanJSComment(comment string) string {
|
||||||
|
return cleanGoComment(comment) // Same rules
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanCComment cleans a C/C++ comment.
|
||||||
|
func cleanCComment(comment string) string {
|
||||||
|
return cleanGoComment(comment) // Same rules
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanBlockComment cleans the content of a block comment.
|
||||||
|
func cleanBlockComment(comment string) string {
|
||||||
|
lines := strings.Split(comment, "\n")
|
||||||
|
var cleaned []string
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
// Remove leading * from each line (common in block comments)
|
||||||
|
line = strings.TrimPrefix(line, "*")
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
cleaned = append(cleaned, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove empty leading/trailing lines
|
||||||
|
for len(cleaned) > 0 && cleaned[0] == "" {
|
||||||
|
cleaned = cleaned[1:]
|
||||||
|
}
|
||||||
|
for len(cleaned) > 0 && cleaned[len(cleaned)-1] == "" {
|
||||||
|
cleaned = cleaned[:len(cleaned)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(cleaned, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseJSDoc parses a JSDoc comment and extracts tags.
|
||||||
|
func parseJSDoc(comment string) (string, map[string]string) {
|
||||||
|
comment = strings.TrimSpace(comment)
|
||||||
|
|
||||||
|
// Remove /** and */
|
||||||
|
comment = strings.TrimPrefix(comment, "/**")
|
||||||
|
comment = strings.TrimSuffix(comment, "*/")
|
||||||
|
|
||||||
|
lines := strings.Split(comment, "\n")
|
||||||
|
var descLines []string
|
||||||
|
tags := make(map[string]string)
|
||||||
|
|
||||||
|
// Regex for JSDoc tags
|
||||||
|
tagPattern := regexp.MustCompile(`^\s*\*?\s*@(\w+)\s*(.*)$`)
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
line = strings.TrimPrefix(line, "*")
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
if matches := tagPattern.FindStringSubmatch(line); matches != nil {
|
||||||
|
tagName := matches[1]
|
||||||
|
tagValue := strings.TrimSpace(matches[2])
|
||||||
|
if existing, ok := tags[tagName]; ok {
|
||||||
|
tags[tagName] = existing + "\n" + tagValue
|
||||||
|
} else {
|
||||||
|
tags[tagName] = tagValue
|
||||||
|
}
|
||||||
|
} else if line != "" {
|
||||||
|
descLines = append(descLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(descLines, "\n"), tags
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDoxygen parses a Doxygen comment and extracts tags.
|
||||||
|
func parseDoxygen(comment string) (string, map[string]string) {
|
||||||
|
comment = strings.TrimSpace(comment)
|
||||||
|
|
||||||
|
// Handle /// and //! style comments
|
||||||
|
comment = strings.TrimPrefix(comment, "///")
|
||||||
|
comment = strings.TrimPrefix(comment, "//!")
|
||||||
|
|
||||||
|
// Handle /** */ style comments
|
||||||
|
comment = strings.TrimPrefix(comment, "/**")
|
||||||
|
comment = strings.TrimSuffix(comment, "*/")
|
||||||
|
|
||||||
|
lines := strings.Split(comment, "\n")
|
||||||
|
var descLines []string
|
||||||
|
tags := make(map[string]string)
|
||||||
|
|
||||||
|
// Regex for Doxygen tags (@param, @return, \param, \return, etc.)
|
||||||
|
tagPattern := regexp.MustCompile(`^\s*\*?\s*[@\\](\w+)\s*(.*)$`)
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
line = strings.TrimPrefix(line, "*")
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
if matches := tagPattern.FindStringSubmatch(line); matches != nil {
|
||||||
|
tagName := matches[1]
|
||||||
|
tagValue := strings.TrimSpace(matches[2])
|
||||||
|
if existing, ok := tags[tagName]; ok {
|
||||||
|
tags[tagName] = existing + "\n" + tagValue
|
||||||
|
} else {
|
||||||
|
tags[tagName] = tagValue
|
||||||
|
}
|
||||||
|
} else if line != "" {
|
||||||
|
descLines = append(descLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(descLines, "\n"), tags
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatDocComment formats a DocComment for display.
|
||||||
|
func FormatDocComment(doc *DocComment) string {
|
||||||
|
if doc == nil || doc.Text == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString(doc.Text)
|
||||||
|
|
||||||
|
if len(doc.Tags) > 0 {
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
// Order: description, params, returns, other
|
||||||
|
paramOrder := []string{"param", "parameter", "arg", "argument"}
|
||||||
|
returnOrder := []string{"return", "returns", "retval"}
|
||||||
|
|
||||||
|
// Write params first
|
||||||
|
for _, tagName := range paramOrder {
|
||||||
|
if val, ok := doc.Tags[tagName]; ok {
|
||||||
|
for _, line := range strings.Split(val, "\n") {
|
||||||
|
sb.WriteString("@" + tagName + " " + line + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write returns
|
||||||
|
for _, tagName := range returnOrder {
|
||||||
|
if val, ok := doc.Tags[tagName]; ok {
|
||||||
|
sb.WriteString("@" + tagName + " " + val + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write remaining tags
|
||||||
|
written := make(map[string]bool)
|
||||||
|
for _, t := range paramOrder {
|
||||||
|
written[t] = true
|
||||||
|
}
|
||||||
|
for _, t := range returnOrder {
|
||||||
|
written[t] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for tagName, val := range doc.Tags {
|
||||||
|
if !written[tagName] {
|
||||||
|
sb.WriteString("@" + tagName + " " + val + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(sb.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanPythonDocstring cleans a Python docstring.
|
||||||
|
func cleanPythonDocstring(doc string) string {
|
||||||
|
doc = strings.TrimSpace(doc)
|
||||||
|
|
||||||
|
// Remove triple quotes
|
||||||
|
doc = strings.TrimPrefix(doc, `"""`)
|
||||||
|
doc = strings.TrimSuffix(doc, `"""`)
|
||||||
|
doc = strings.TrimPrefix(doc, `'''`)
|
||||||
|
doc = strings.TrimSuffix(doc, `'''`)
|
||||||
|
|
||||||
|
return strings.TrimSpace(doc)
|
||||||
|
}
|
||||||
@@ -0,0 +1,630 @@
|
|||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
sitter "github.com/smacker/go-tree-sitter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractGoDocComment(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code string
|
||||||
|
nodeKind string
|
||||||
|
wantText string
|
||||||
|
wantStyle CommentStyle
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single line comment",
|
||||||
|
code: `package main
|
||||||
|
|
||||||
|
// Hello says hello
|
||||||
|
func Hello() {}
|
||||||
|
`,
|
||||||
|
nodeKind: "function_declaration",
|
||||||
|
wantText: "Hello says hello",
|
||||||
|
wantStyle: CommentStyleLine,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-line comments",
|
||||||
|
code: `package main
|
||||||
|
|
||||||
|
// This is a function
|
||||||
|
// that does something
|
||||||
|
// important
|
||||||
|
func DoSomething() {}
|
||||||
|
`,
|
||||||
|
nodeKind: "function_declaration",
|
||||||
|
wantText: "This is a function\nthat does something\nimportant",
|
||||||
|
wantStyle: CommentStyleLine,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "block comment",
|
||||||
|
code: `package main
|
||||||
|
|
||||||
|
/* This is a block comment
|
||||||
|
describing the function */
|
||||||
|
func BlockCommented() {}
|
||||||
|
`,
|
||||||
|
nodeKind: "function_declaration",
|
||||||
|
wantText: "This is a block comment\ndescribing the function",
|
||||||
|
wantStyle: CommentStyleBlock,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "doc comment with asterisks",
|
||||||
|
code: `package main
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This is a properly formatted
|
||||||
|
* block comment with asterisks
|
||||||
|
*/
|
||||||
|
func FormattedBlock() {}
|
||||||
|
`,
|
||||||
|
nodeKind: "function_declaration",
|
||||||
|
wantText: "This is a properly formatted\nblock comment with asterisks",
|
||||||
|
wantStyle: CommentStyleBlock,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no comment",
|
||||||
|
code: `package main
|
||||||
|
|
||||||
|
func NoComment() {}
|
||||||
|
`,
|
||||||
|
nodeKind: "function_declaration",
|
||||||
|
wantText: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := registry.Parse(context.Background(), "test.go", []byte(tt.code))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the target node
|
||||||
|
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||||
|
if targetNode == nil {
|
||||||
|
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||||
|
}
|
||||||
|
|
||||||
|
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangGo)
|
||||||
|
|
||||||
|
if tt.wantText == "" {
|
||||||
|
if doc != nil && doc.Text != "" {
|
||||||
|
t.Errorf("expected no doc, got %q", doc.Text)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if doc == nil {
|
||||||
|
t.Fatal("expected doc, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if doc.Text != tt.wantText {
|
||||||
|
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||||
|
}
|
||||||
|
|
||||||
|
if doc.Style != tt.wantStyle {
|
||||||
|
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractJSDocComment(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
wantTags map[string]string
|
||||||
|
name string
|
||||||
|
code string
|
||||||
|
nodeKind string
|
||||||
|
wantText string
|
||||||
|
wantStyle CommentStyle
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "JSDoc comment",
|
||||||
|
code: `/**
|
||||||
|
* Adds two numbers together.
|
||||||
|
* @param a The first number
|
||||||
|
* @param b The second number
|
||||||
|
* @returns The sum of a and b
|
||||||
|
*/
|
||||||
|
function add(a, b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
nodeKind: "function_declaration",
|
||||||
|
wantText: "Adds two numbers together.",
|
||||||
|
wantStyle: CommentStyleJSDoc,
|
||||||
|
wantTags: map[string]string{
|
||||||
|
"param": "a The first number\nb The second number",
|
||||||
|
"returns": "The sum of a and b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple line comment",
|
||||||
|
code: `// This is a simple function
|
||||||
|
function simple() {}
|
||||||
|
`,
|
||||||
|
nodeKind: "function_declaration",
|
||||||
|
wantText: "This is a simple function",
|
||||||
|
wantStyle: CommentStyleLine,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JSDoc with types",
|
||||||
|
code: `/**
|
||||||
|
* @param {string} name - The name
|
||||||
|
* @returns {boolean} True if valid
|
||||||
|
*/
|
||||||
|
function validate(name) {}
|
||||||
|
`,
|
||||||
|
nodeKind: "function_declaration",
|
||||||
|
wantText: "",
|
||||||
|
wantStyle: CommentStyleJSDoc,
|
||||||
|
wantTags: map[string]string{
|
||||||
|
"param": "{string} name - The name",
|
||||||
|
"returns": "{boolean} True if valid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := registry.Parse(context.Background(), "test.js", []byte(tt.code))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||||
|
if targetNode == nil {
|
||||||
|
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||||
|
}
|
||||||
|
|
||||||
|
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangJavaScript)
|
||||||
|
|
||||||
|
if doc == nil {
|
||||||
|
t.Fatal("expected doc, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if doc.Text != tt.wantText {
|
||||||
|
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||||
|
}
|
||||||
|
|
||||||
|
if doc.Style != tt.wantStyle {
|
||||||
|
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantTags != nil {
|
||||||
|
for k, want := range tt.wantTags {
|
||||||
|
if got := doc.Tags[k]; got != want {
|
||||||
|
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractPythonDocComment(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
code string
|
||||||
|
nodeKind string
|
||||||
|
wantText string
|
||||||
|
wantStyle CommentStyle
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "docstring",
|
||||||
|
code: `def greet(name):
|
||||||
|
"""Greet a person by name."""
|
||||||
|
print(f"Hello, {name}!")
|
||||||
|
`,
|
||||||
|
nodeKind: "function_definition",
|
||||||
|
wantText: "Greet a person by name.",
|
||||||
|
wantStyle: CommentStyleDocstring,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-line docstring",
|
||||||
|
code: `def calculate(x, y):
|
||||||
|
"""
|
||||||
|
Calculate the sum of two numbers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: First number
|
||||||
|
y: Second number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The sum of x and y
|
||||||
|
"""
|
||||||
|
return x + y
|
||||||
|
`,
|
||||||
|
nodeKind: "function_definition",
|
||||||
|
wantText: "Calculate the sum of two numbers.\n\n Args:\n x: First number\n y: Second number\n\n Returns:\n The sum of x and y",
|
||||||
|
wantStyle: CommentStyleDocstring,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "class docstring",
|
||||||
|
code: `class MyClass:
|
||||||
|
"""This is a class description."""
|
||||||
|
pass
|
||||||
|
`,
|
||||||
|
nodeKind: "class_definition",
|
||||||
|
wantText: "This is a class description.",
|
||||||
|
wantStyle: CommentStyleDocstring,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single quote docstring",
|
||||||
|
code: `def func():
|
||||||
|
'''Single quote docstring'''
|
||||||
|
pass
|
||||||
|
`,
|
||||||
|
nodeKind: "function_definition",
|
||||||
|
wantText: "Single quote docstring",
|
||||||
|
wantStyle: CommentStyleDocstring,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := registry.Parse(context.Background(), "test.py", []byte(tt.code))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||||
|
if targetNode == nil {
|
||||||
|
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||||
|
}
|
||||||
|
|
||||||
|
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangPython)
|
||||||
|
|
||||||
|
if doc == nil {
|
||||||
|
t.Fatal("expected doc, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if doc.Text != tt.wantText {
|
||||||
|
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||||
|
}
|
||||||
|
|
||||||
|
if doc.Style != tt.wantStyle {
|
||||||
|
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractCDocComment(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
wantTags map[string]string
|
||||||
|
name string
|
||||||
|
code string
|
||||||
|
nodeKind string
|
||||||
|
wantText string
|
||||||
|
wantStyle CommentStyle
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Doxygen block comment",
|
||||||
|
code: `/**
|
||||||
|
* Adds two numbers.
|
||||||
|
* @param a First number
|
||||||
|
* @param b Second number
|
||||||
|
* @return Sum of a and b
|
||||||
|
*/
|
||||||
|
int add(int a, int b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
nodeKind: "function_definition",
|
||||||
|
wantText: "Adds two numbers.",
|
||||||
|
wantStyle: CommentStyleDoxygen,
|
||||||
|
wantTags: map[string]string{
|
||||||
|
"param": "a First number\nb Second number",
|
||||||
|
"return": "Sum of a and b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "regular block comment",
|
||||||
|
code: `/* This is a regular comment */
|
||||||
|
int regular() { return 0; }
|
||||||
|
`,
|
||||||
|
nodeKind: "function_definition",
|
||||||
|
wantText: "This is a regular comment",
|
||||||
|
wantStyle: CommentStyleBlock,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "line comment",
|
||||||
|
code: `// Simple function
|
||||||
|
int simple() { return 1; }
|
||||||
|
`,
|
||||||
|
nodeKind: "function_definition",
|
||||||
|
wantText: "Simple function",
|
||||||
|
wantStyle: CommentStyleLine,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := registry.Parse(context.Background(), "test.c", []byte(tt.code))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||||
|
if targetNode == nil {
|
||||||
|
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||||
|
}
|
||||||
|
|
||||||
|
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangC)
|
||||||
|
|
||||||
|
if doc == nil {
|
||||||
|
t.Fatal("expected doc, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if doc.Text != tt.wantText {
|
||||||
|
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||||
|
}
|
||||||
|
|
||||||
|
if doc.Style != tt.wantStyle {
|
||||||
|
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantTags != nil {
|
||||||
|
for k, want := range tt.wantTags {
|
||||||
|
if got := doc.Tags[k]; got != want {
|
||||||
|
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseJSDoc(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
wantTags map[string]string
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "complete jsdoc",
|
||||||
|
input: `/**
|
||||||
|
* This is a description.
|
||||||
|
* Multiple lines.
|
||||||
|
* @param {string} name The name
|
||||||
|
* @returns {boolean} Result
|
||||||
|
*/`,
|
||||||
|
wantText: "This is a description.\nMultiple lines.",
|
||||||
|
wantTags: map[string]string{
|
||||||
|
"param": "{string} name The name",
|
||||||
|
"returns": "{boolean} Result",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty jsdoc",
|
||||||
|
input: `/** */`,
|
||||||
|
wantText: "",
|
||||||
|
wantTags: map[string]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only description",
|
||||||
|
input: `/** Simple description */`,
|
||||||
|
wantText: "Simple description",
|
||||||
|
wantTags: map[string]string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
text, tags := parseJSDoc(tt.input)
|
||||||
|
|
||||||
|
if text != tt.wantText {
|
||||||
|
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tags) != len(tt.wantTags) {
|
||||||
|
t.Errorf("tag count mismatch: got %d, want %d", len(tags), len(tt.wantTags))
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, want := range tt.wantTags {
|
||||||
|
if got := tags[k]; got != want {
|
||||||
|
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDoxygen(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
wantTags map[string]string
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "doxygen with @ tags",
|
||||||
|
input: `/**
|
||||||
|
* Brief description.
|
||||||
|
* @param x Value
|
||||||
|
* @return Result
|
||||||
|
*/`,
|
||||||
|
wantText: "Brief description.",
|
||||||
|
wantTags: map[string]string{
|
||||||
|
"param": "x Value",
|
||||||
|
"return": "Result",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "doxygen with backslash tags",
|
||||||
|
input: `/**
|
||||||
|
* Description.
|
||||||
|
* \param y Input
|
||||||
|
* \retval Output value
|
||||||
|
*/`,
|
||||||
|
wantText: "Description.",
|
||||||
|
wantTags: map[string]string{
|
||||||
|
"param": "y Input",
|
||||||
|
"retval": "Output value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "triple slash",
|
||||||
|
input: `/// Simple description`,
|
||||||
|
wantText: "Simple description",
|
||||||
|
wantTags: map[string]string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
text, tags := parseDoxygen(tt.input)
|
||||||
|
|
||||||
|
if text != tt.wantText {
|
||||||
|
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, want := range tt.wantTags {
|
||||||
|
if got := tags[k]; got != want {
|
||||||
|
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatDocComment(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
doc *DocComment
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with tags",
|
||||||
|
doc: &DocComment{
|
||||||
|
Text: "This is a function.",
|
||||||
|
Tags: map[string]string{
|
||||||
|
"param": "x The value",
|
||||||
|
"returns": "The result",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "This is a function.\n\n@param x The value\n@returns The result",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no tags",
|
||||||
|
doc: &DocComment{
|
||||||
|
Text: "Simple description.",
|
||||||
|
Tags: nil,
|
||||||
|
},
|
||||||
|
want: "Simple description.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil doc",
|
||||||
|
doc: nil,
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty text",
|
||||||
|
doc: &DocComment{
|
||||||
|
Text: "",
|
||||||
|
Tags: nil,
|
||||||
|
},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := FormatDocComment(tt.doc)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("mismatch:\ngot: %q\nwant: %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetectCommentStyle(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want CommentStyle
|
||||||
|
}{
|
||||||
|
{"/** JSDoc */", CommentStyleJSDoc},
|
||||||
|
{"/// Doxygen", CommentStyleDoxygen},
|
||||||
|
{"//! Doxygen", CommentStyleDoxygen},
|
||||||
|
{"/* block */", CommentStyleBlock},
|
||||||
|
{"// line", CommentStyleLine},
|
||||||
|
{"# hash", CommentStyleHash},
|
||||||
|
{`"""docstring"""`, CommentStyleDocstring},
|
||||||
|
{`'''docstring'''`, CommentStyleDocstring},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := detectCommentStyle(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("got %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findNodeByKind finds the first node of the given kind.
|
||||||
|
func findNodeByKind(root *sitter.Node, nodeType string) *sitter.Node {
|
||||||
|
if root == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *sitter.Node
|
||||||
|
WalkTree(root, func(n *sitter.Node) bool {
|
||||||
|
if n.Type() == nodeType {
|
||||||
|
result = n
|
||||||
|
return false // stop walking
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanBlockComment(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
input: "\n * Line 1\n * Line 2\n ",
|
||||||
|
want: "Line 1\nLine 2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "Simple",
|
||||||
|
want: "Simple",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "\n\nWith blank lines\n\n",
|
||||||
|
want: "With blank lines",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input[:min(10, len(tt.input))], func(t *testing.T) {
|
||||||
|
got := cleanBlockComment(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("got %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,271 @@
|
|||||||
|
// Package parser provides Tree-sitter based parsing for multiple languages.
|
||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/cespare/xxhash/v2"
|
||||||
|
lru "github.com/hashicorp/golang-lru/v2"
|
||||||
|
sitter "github.com/smacker/go-tree-sitter"
|
||||||
|
"github.com/smacker/go-tree-sitter/c"
|
||||||
|
"github.com/smacker/go-tree-sitter/cpp"
|
||||||
|
"github.com/smacker/go-tree-sitter/golang"
|
||||||
|
"github.com/smacker/go-tree-sitter/html"
|
||||||
|
"github.com/smacker/go-tree-sitter/javascript"
|
||||||
|
"github.com/smacker/go-tree-sitter/python"
|
||||||
|
"github.com/smacker/go-tree-sitter/typescript/typescript"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MaxFileSize is the maximum file size we'll parse (10MB).
|
||||||
|
const MaxFileSize = 10 * 1024 * 1024
|
||||||
|
|
||||||
|
// Registry manages Tree-sitter parsers for different languages.
|
||||||
|
type Registry struct {
|
||||||
|
parsers map[protocol.Language]*sitter.Parser
|
||||||
|
cache *lru.Cache[string, *CachedTree]
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// CachedTree stores a parsed tree with its metadata.
|
||||||
|
// Content is not stored to reduce memory usage.
|
||||||
|
type CachedTree struct {
|
||||||
|
Tree *sitter.Tree
|
||||||
|
Language protocol.Language
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseResult contains the result of parsing a file.
|
||||||
|
type ParseResult struct {
|
||||||
|
Tree *sitter.Tree
|
||||||
|
Language protocol.Language
|
||||||
|
Errors []SyntaxError
|
||||||
|
Content []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyntaxError represents a syntax error found during parsing.
|
||||||
|
type SyntaxError struct {
|
||||||
|
Message string
|
||||||
|
NodeType string
|
||||||
|
Location protocol.Location
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegistry creates a new parser registry.
|
||||||
|
func NewRegistry() *Registry {
|
||||||
|
// Create LRU cache with capacity of 100 trees
|
||||||
|
cache, err := lru.New[string, *CachedTree](100)
|
||||||
|
if err != nil {
|
||||||
|
// LRU.New only errors if size <= 0, which won't happen here
|
||||||
|
panic(fmt.Sprintf("failed to create LRU cache: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Registry{
|
||||||
|
parsers: make(map[protocol.Language]*sitter.Parser),
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLanguage returns the Tree-sitter language for a given language.
|
||||||
|
func getLanguage(lang protocol.Language) (*sitter.Language, error) {
|
||||||
|
switch lang {
|
||||||
|
case protocol.LangGo:
|
||||||
|
return golang.GetLanguage(), nil
|
||||||
|
case protocol.LangTypeScript:
|
||||||
|
return typescript.GetLanguage(), nil
|
||||||
|
case protocol.LangJavaScript:
|
||||||
|
return javascript.GetLanguage(), nil
|
||||||
|
case protocol.LangPython:
|
||||||
|
return python.GetLanguage(), nil
|
||||||
|
case protocol.LangC:
|
||||||
|
return c.GetLanguage(), nil
|
||||||
|
case protocol.LangCpp:
|
||||||
|
return cpp.GetLanguage(), nil
|
||||||
|
case protocol.LangHTML:
|
||||||
|
return html.GetLanguage(), nil
|
||||||
|
case protocol.LangVue:
|
||||||
|
// Vue SFC files use HTML-like template syntax, so we use the HTML parser
|
||||||
|
return html.GetLanguage(), nil
|
||||||
|
default:
|
||||||
|
return nil, errors.New(errors.ErrInvalidLanguage, fmt.Sprintf("language %s is not supported", lang)).
|
||||||
|
WithContext("language", string(lang)).
|
||||||
|
WithRemediation("Supported languages: Go, TypeScript, JavaScript, Python, C, C++, HTML, Vue")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParser returns a parser for the given language.
|
||||||
|
func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
if p, ok := r.parsers[lang]; ok {
|
||||||
|
r.mu.RUnlock()
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
r.mu.RUnlock()
|
||||||
|
|
||||||
|
// Create new parser
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
// Double-check after acquiring write lock
|
||||||
|
if p, ok := r.parsers[lang]; ok {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sitterLang, err := getLanguage(lang)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
parser := sitter.NewParser()
|
||||||
|
parser.SetLanguage(sitterLang)
|
||||||
|
r.parsers[lang] = parser
|
||||||
|
|
||||||
|
return parser, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse parses the given content for the specified language.
|
||||||
|
func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||||
|
// Check file size
|
||||||
|
if len(content) > MaxFileSize {
|
||||||
|
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect binary files
|
||||||
|
if isBinary(content) {
|
||||||
|
return nil, errors.New(errors.ErrParseFailed, "binary file detected").
|
||||||
|
WithContext("file", filename).
|
||||||
|
WithRemediation("This appears to be a binary file and cannot be parsed as source code")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect language
|
||||||
|
lang := protocol.DetectLanguage(filename)
|
||||||
|
if lang == protocol.LangUnknown {
|
||||||
|
return nil, errors.New(errors.ErrInvalidLanguage, "could not detect language from filename").
|
||||||
|
WithContext("file", filename).
|
||||||
|
WithRemediation("Ensure file has a recognized extension (e.g., .go, .ts, .py, .c, .cpp, .html, .vue, .json, .yaml)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle YAML and JSON separately (they don't use tree-sitter)
|
||||||
|
switch lang {
|
||||||
|
case protocol.LangYAML:
|
||||||
|
return r.ParseYAML(ctx, filename, content)
|
||||||
|
case protocol.LangJSON:
|
||||||
|
return r.ParseJSON(ctx, filename, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check cache (LRU cache is thread-safe)
|
||||||
|
hash := contentHash(content)
|
||||||
|
if cached, ok := r.cache.Get(hash); ok && cached.Language == lang {
|
||||||
|
errors := extractErrors(cached.Tree.RootNode(), content)
|
||||||
|
return &ParseResult{
|
||||||
|
Tree: cached.Tree,
|
||||||
|
Language: lang,
|
||||||
|
Errors: errors,
|
||||||
|
Content: content,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get parser
|
||||||
|
parser, err := r.GetParser(lang)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse content - tree-sitter parsers are not thread-safe,
|
||||||
|
// so we need to hold the lock during parsing
|
||||||
|
r.mu.Lock()
|
||||||
|
tree, err := parser.ParseCtx(ctx, nil, content)
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.NewParseError(string(lang), filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract syntax errors
|
||||||
|
errors := extractErrors(tree.RootNode(), content)
|
||||||
|
|
||||||
|
// Cache result (LRU cache handles eviction automatically)
|
||||||
|
r.cache.Add(hash, &CachedTree{
|
||||||
|
Tree: tree,
|
||||||
|
Language: lang,
|
||||||
|
})
|
||||||
|
|
||||||
|
return &ParseResult{
|
||||||
|
Tree: tree,
|
||||||
|
Language: lang,
|
||||||
|
Errors: errors,
|
||||||
|
Content: content,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractErrors finds all error nodes in the tree.
|
||||||
|
func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
|
||||||
|
var errors []SyntaxError
|
||||||
|
|
||||||
|
var walk func(n *sitter.Node)
|
||||||
|
walk = func(n *sitter.Node) {
|
||||||
|
if n == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.IsError() || n.IsMissing() {
|
||||||
|
startPoint := n.StartPoint()
|
||||||
|
nodeType := "ERROR"
|
||||||
|
if n.IsMissing() {
|
||||||
|
nodeType = "MISSING"
|
||||||
|
}
|
||||||
|
|
||||||
|
errors = append(errors, SyntaxError{
|
||||||
|
Location: protocol.Location{
|
||||||
|
Line: int(startPoint.Row) + 1,
|
||||||
|
Column: int(startPoint.Column) + 1,
|
||||||
|
},
|
||||||
|
Message: fmt.Sprintf("syntax error: unexpected %s", n.Type()),
|
||||||
|
NodeType: nodeType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < int(n.ChildCount()); i++ {
|
||||||
|
walk(n.Child(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
walk(node)
|
||||||
|
return errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// contentHash returns a fast hash of the content for caching.
|
||||||
|
// Uses xxHash which is 5-10x faster than SHA256 for non-cryptographic purposes.
|
||||||
|
func contentHash(content []byte) string {
|
||||||
|
h := xxhash.Sum64(content)
|
||||||
|
return fmt.Sprintf("%016x", h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBinary checks if content appears to be binary.
|
||||||
|
func isBinary(content []byte) bool {
|
||||||
|
// Check first 8000 bytes for null bytes
|
||||||
|
checkLen := min(8000, len(content))
|
||||||
|
|
||||||
|
for i := range checkLen {
|
||||||
|
if content[i] == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes all parsers and clears the cache.
|
||||||
|
func (r *Registry) Close() {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
for _, p := range r.parsers {
|
||||||
|
p.Close()
|
||||||
|
}
|
||||||
|
r.parsers = make(map[protocol.Language]*sitter.Parser)
|
||||||
|
|
||||||
|
// Purge LRU cache
|
||||||
|
r.cache.Purge()
|
||||||
|
}
|
||||||
@@ -0,0 +1,230 @@
|
|||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewRegistry(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
if r == nil {
|
||||||
|
t.Fatal("expected non-nil registry")
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetParser(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
lang protocol.Language
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{protocol.LangGo, false},
|
||||||
|
{protocol.LangTypeScript, false},
|
||||||
|
{protocol.LangJavaScript, false},
|
||||||
|
{protocol.LangPython, false},
|
||||||
|
{protocol.LangC, false},
|
||||||
|
{protocol.LangCpp, false},
|
||||||
|
{protocol.LangHTML, false},
|
||||||
|
{protocol.LangVue, false},
|
||||||
|
{protocol.LangUnknown, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(string(tt.lang), func(t *testing.T) {
|
||||||
|
parser, err := r.GetParser(tt.lang)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if parser == nil {
|
||||||
|
t.Error("expected non-nil parser")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filename string
|
||||||
|
content string
|
||||||
|
wantLang protocol.Language
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "go file",
|
||||||
|
filename: "test.go",
|
||||||
|
content: "package main\n\nfunc main() {}\n",
|
||||||
|
wantLang: protocol.LangGo,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "typescript file",
|
||||||
|
filename: "test.ts",
|
||||||
|
content: "function hello(): void {}\n",
|
||||||
|
wantLang: protocol.LangTypeScript,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "react tsx file",
|
||||||
|
filename: "Component.tsx",
|
||||||
|
content: `import React from 'react';\n\nexport const Button: React.FC = () => <button className="btn">Click</button>;`,
|
||||||
|
wantLang: protocol.LangTypeScript,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "react jsx file",
|
||||||
|
filename: "Component.jsx",
|
||||||
|
content: `import React from 'react';\n\nexport const Button = () => <button className="btn">Click</button>;`,
|
||||||
|
wantLang: protocol.LangJavaScript,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "python file",
|
||||||
|
filename: "test.py",
|
||||||
|
content: "def hello():\n pass\n",
|
||||||
|
wantLang: protocol.LangPython,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "html file",
|
||||||
|
filename: "test.html",
|
||||||
|
content: `<!DOCTYPE html><html><head><title>Test</title></head><body><h1 class="text-xl">Hello</h1></body></html>`,
|
||||||
|
wantLang: protocol.LangHTML,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "vue file",
|
||||||
|
filename: "Component.vue",
|
||||||
|
content: `<template><div class="container"><h1>{{ title }}</h1></div></template>`,
|
||||||
|
wantLang: protocol.LangVue,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown file",
|
||||||
|
filename: "test.txt",
|
||||||
|
content: "hello world",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := r.Parse(ctx, tt.filename, []byte(tt.content))
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Language != tt.wantLang {
|
||||||
|
t.Errorf("expected language %s, got %s", tt.wantLang, result.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Tree == nil {
|
||||||
|
t.Error("expected non-nil tree")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithSyntaxErrors(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
// Invalid Go code
|
||||||
|
content := "package main\n\nfunc main( {}\n" // Missing closing paren
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result, err := r.Parse(ctx, "test.go", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have parsed (tree-sitter is error-tolerant)
|
||||||
|
if result.Tree == nil {
|
||||||
|
t.Error("expected non-nil tree")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have detected errors
|
||||||
|
if len(result.Errors) == 0 {
|
||||||
|
t.Error("expected syntax errors to be detected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsBinary(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content []byte
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "text file",
|
||||||
|
content: []byte("hello world"),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "binary with null byte",
|
||||||
|
content: []byte{0x68, 0x65, 0x6c, 0x00, 0x6f},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty file",
|
||||||
|
content: []byte{},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := isBinary(tt.content); got != tt.want {
|
||||||
|
t.Errorf("isBinary() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCaching(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
content := []byte("package main\n\nfunc main() {}\n")
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Parse once
|
||||||
|
result1, err := r.Parse(ctx, "test.go", content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse again with same content
|
||||||
|
result2, err := r.Parse(ctx, "test.go", content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return cached tree (same pointer)
|
||||||
|
if result1.Tree != result2.Tree {
|
||||||
|
t.Error("expected cached tree to be returned")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,474 @@
|
|||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
sitter "github.com/smacker/go-tree-sitter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExtractSymbols extracts symbols from a parsed tree.
|
||||||
|
func ExtractSymbols(tree *sitter.Tree, content []byte, lang protocol.Language, filename string) []protocol.Symbol {
|
||||||
|
if tree == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
root := tree.RootNode()
|
||||||
|
if root == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch lang {
|
||||||
|
case protocol.LangGo:
|
||||||
|
return extractGoSymbols(root, content, filename)
|
||||||
|
case protocol.LangTypeScript, protocol.LangJavaScript:
|
||||||
|
return extractJSSymbols(root, content, filename)
|
||||||
|
case protocol.LangPython:
|
||||||
|
return extractPythonSymbols(root, content, filename)
|
||||||
|
case protocol.LangC, protocol.LangCpp:
|
||||||
|
return extractCSymbols(root, content, filename)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractGoSymbols extracts symbols from Go code.
|
||||||
|
func extractGoSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||||
|
var symbols []protocol.Symbol
|
||||||
|
|
||||||
|
WalkTree(root, func(n *sitter.Node) bool {
|
||||||
|
var symbol *protocol.Symbol
|
||||||
|
|
||||||
|
switch n.Type() {
|
||||||
|
case "function_declaration":
|
||||||
|
symbol = extractGoFunction(n, content, filename)
|
||||||
|
case "method_declaration":
|
||||||
|
symbol = extractGoMethod(n, content, filename)
|
||||||
|
case "type_declaration":
|
||||||
|
symbol = extractGoType(n, content, filename)
|
||||||
|
case "const_declaration", "var_declaration":
|
||||||
|
syms := extractGoVarConst(n, content, filename)
|
||||||
|
symbols = append(symbols, syms...)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if symbol != nil {
|
||||||
|
if doc := ExtractDocComment(n, content, protocol.LangGo); doc != nil {
|
||||||
|
symbol.Doc = FormatDocComment(doc)
|
||||||
|
}
|
||||||
|
symbols = append(symbols, *symbol)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractGoFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolFunction,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractGoMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get receiver type
|
||||||
|
receiver := n.ChildByFieldName("receiver")
|
||||||
|
receiverType := ""
|
||||||
|
if receiver != nil {
|
||||||
|
// Find the type in the receiver
|
||||||
|
WalkTree(receiver, func(node *sitter.Node) bool {
|
||||||
|
if node.Type() == "type_identifier" {
|
||||||
|
receiverType = GetNodeText(node, content)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
name := GetNodeText(nameNode, content)
|
||||||
|
if receiverType != "" {
|
||||||
|
name = "(" + receiverType + ")." + name
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: name,
|
||||||
|
Kind: protocol.SymbolMethod,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractGoType(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
// Find type_spec child
|
||||||
|
for i := 0; i < int(n.NamedChildCount()); i++ {
|
||||||
|
child := n.NamedChild(i)
|
||||||
|
if child != nil && child.Type() == "type_spec" {
|
||||||
|
nameNode := child.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
kind := protocol.SymbolType
|
||||||
|
typeNode := child.ChildByFieldName("type")
|
||||||
|
if typeNode != nil {
|
||||||
|
switch typeNode.Type() {
|
||||||
|
case "struct_type":
|
||||||
|
kind = protocol.SymbolStruct
|
||||||
|
case "interface_type":
|
||||||
|
kind = protocol.SymbolInterface
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: kind,
|
||||||
|
Location: NodeLocation(child, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractGoVarConst(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||||
|
var symbols []protocol.Symbol
|
||||||
|
kind := protocol.SymbolVariable
|
||||||
|
if n.Type() == "const_declaration" {
|
||||||
|
kind = protocol.SymbolConstant
|
||||||
|
}
|
||||||
|
|
||||||
|
WalkTree(n, func(node *sitter.Node) bool {
|
||||||
|
if node.Type() == "const_spec" || node.Type() == "var_spec" {
|
||||||
|
nameNode := node.ChildByFieldName("name")
|
||||||
|
if nameNode != nil {
|
||||||
|
symbols = append(symbols, protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: kind,
|
||||||
|
Location: NodeLocation(node, filename),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractJSSymbols extracts symbols from JavaScript/TypeScript code.
|
||||||
|
func extractJSSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||||
|
var symbols []protocol.Symbol
|
||||||
|
|
||||||
|
WalkTree(root, func(n *sitter.Node) bool {
|
||||||
|
var symbol *protocol.Symbol
|
||||||
|
|
||||||
|
switch n.Type() {
|
||||||
|
case "function_declaration":
|
||||||
|
symbol = extractJSFunction(n, content, filename)
|
||||||
|
case "class_declaration":
|
||||||
|
symbol = extractJSClass(n, content, filename)
|
||||||
|
case "method_definition":
|
||||||
|
symbol = extractJSMethod(n, content, filename)
|
||||||
|
case "lexical_declaration", "variable_declaration":
|
||||||
|
syms := extractJSVariable(n, content, filename)
|
||||||
|
symbols = append(symbols, syms...)
|
||||||
|
return true
|
||||||
|
case "interface_declaration":
|
||||||
|
symbol = extractTSInterface(n, content, filename)
|
||||||
|
case "type_alias_declaration":
|
||||||
|
symbol = extractTSTypeAlias(n, content, filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
if symbol != nil {
|
||||||
|
if doc := ExtractDocComment(n, content, protocol.LangJavaScript); doc != nil {
|
||||||
|
symbol.Doc = FormatDocComment(doc)
|
||||||
|
}
|
||||||
|
symbols = append(symbols, *symbol)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractJSFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolFunction,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractJSClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolClass,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractJSMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolMethod,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractJSVariable(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||||
|
var symbols []protocol.Symbol
|
||||||
|
|
||||||
|
WalkTree(n, func(node *sitter.Node) bool {
|
||||||
|
if node.Type() == "variable_declarator" {
|
||||||
|
nameNode := node.ChildByFieldName("name")
|
||||||
|
if nameNode != nil {
|
||||||
|
symbols = append(symbols, protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolVariable,
|
||||||
|
Location: NodeLocation(node, filename),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTSInterface(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolInterface,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTSTypeAlias(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolType,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPythonSymbols extracts symbols from Python code.
|
||||||
|
func extractPythonSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||||
|
var symbols []protocol.Symbol
|
||||||
|
|
||||||
|
WalkTree(root, func(n *sitter.Node) bool {
|
||||||
|
var symbol *protocol.Symbol
|
||||||
|
|
||||||
|
switch n.Type() {
|
||||||
|
case "function_definition":
|
||||||
|
symbol = extractPythonFunction(n, content, filename)
|
||||||
|
case "class_definition":
|
||||||
|
symbol = extractPythonClass(n, content, filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
if symbol != nil {
|
||||||
|
if doc := ExtractDocComment(n, content, protocol.LangPython); doc != nil {
|
||||||
|
symbol.Doc = FormatDocComment(doc)
|
||||||
|
}
|
||||||
|
symbols = append(symbols, *symbol)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractPythonFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a method (inside a class)
|
||||||
|
parent := n.Parent()
|
||||||
|
kind := protocol.SymbolFunction
|
||||||
|
if parent != nil && parent.Type() == "block" {
|
||||||
|
grandparent := parent.Parent()
|
||||||
|
if grandparent != nil && grandparent.Type() == "class_definition" {
|
||||||
|
kind = protocol.SymbolMethod
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: kind,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractPythonClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolClass,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCSymbols extracts symbols from C/C++ code.
|
||||||
|
func extractCSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||||
|
var symbols []protocol.Symbol
|
||||||
|
|
||||||
|
WalkTree(root, func(n *sitter.Node) bool {
|
||||||
|
var symbol *protocol.Symbol
|
||||||
|
|
||||||
|
switch n.Type() {
|
||||||
|
case "function_definition":
|
||||||
|
symbol = extractCFunction(n, content, filename)
|
||||||
|
case "struct_specifier":
|
||||||
|
symbol = extractCStruct(n, content, filename)
|
||||||
|
case "class_specifier":
|
||||||
|
symbol = extractCppClass(n, content, filename)
|
||||||
|
case "declaration":
|
||||||
|
// Could be function declaration or variable
|
||||||
|
if hasFunctionDeclarator(n) {
|
||||||
|
symbol = extractCFunctionDecl(n, content, filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if symbol != nil {
|
||||||
|
if doc := ExtractDocComment(n, content, protocol.LangC); doc != nil {
|
||||||
|
symbol.Doc = FormatDocComment(doc)
|
||||||
|
}
|
||||||
|
symbols = append(symbols, *symbol)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
declarator := n.ChildByFieldName("declarator")
|
||||||
|
if declarator == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the function name within the declarator
|
||||||
|
var name string
|
||||||
|
WalkTree(declarator, func(node *sitter.Node) bool {
|
||||||
|
if node.Type() == "identifier" {
|
||||||
|
name = GetNodeText(node, content)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: name,
|
||||||
|
Kind: protocol.SymbolFunction,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCStruct(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolStruct,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCppClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
nameNode := n.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: GetNodeText(nameNode, content),
|
||||||
|
Kind: protocol.SymbolClass,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCFunctionDecl(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||||
|
declarator := n.ChildByFieldName("declarator")
|
||||||
|
if declarator == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var name string
|
||||||
|
WalkTree(declarator, func(node *sitter.Node) bool {
|
||||||
|
if node.Type() == "identifier" {
|
||||||
|
name = GetNodeText(node, content)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &protocol.Symbol{
|
||||||
|
Name: name,
|
||||||
|
Kind: protocol.SymbolFunction,
|
||||||
|
Location: NodeLocation(n, filename),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasFunctionDeclarator(n *sitter.Node) bool {
|
||||||
|
found := false
|
||||||
|
WalkTree(n, func(node *sitter.Node) bool {
|
||||||
|
if node.Type() == "function_declarator" {
|
||||||
|
found = true
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return found
|
||||||
|
}
|
||||||
@@ -0,0 +1,226 @@
|
|||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractGoSymbols(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
// Hello prints a greeting
|
||||||
|
func Hello() {
|
||||||
|
println("hello")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server handles requests
|
||||||
|
type Server struct {
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the server
|
||||||
|
func (s *Server) Start() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const MaxConnections = 100
|
||||||
|
var globalVar = "test"
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := r.Parse(ctx, "test.go", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangGo, "test.go")
|
||||||
|
|
||||||
|
expectedSymbols := map[string]protocol.SymbolKind{
|
||||||
|
"Hello": protocol.SymbolFunction,
|
||||||
|
"Server": protocol.SymbolStruct,
|
||||||
|
"(Server).Start": protocol.SymbolMethod,
|
||||||
|
"MaxConnections": protocol.SymbolConstant,
|
||||||
|
"globalVar": protocol.SymbolVariable,
|
||||||
|
}
|
||||||
|
|
||||||
|
found := make(map[string]bool)
|
||||||
|
for _, sym := range symbols {
|
||||||
|
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||||
|
found[sym.Name] = true
|
||||||
|
if sym.Kind != expectedKind {
|
||||||
|
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name := range expectedSymbols {
|
||||||
|
if !found[name] {
|
||||||
|
t.Errorf("expected to find symbol %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractJSSymbols(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
content := `
|
||||||
|
function greet(name) {
|
||||||
|
console.log("Hello, " + name);
|
||||||
|
}
|
||||||
|
|
||||||
|
class User {
|
||||||
|
constructor(name) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
getName() {
|
||||||
|
return this.name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const MAX_USERS = 100;
|
||||||
|
let currentUser = null;
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := r.Parse(ctx, "test.js", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangJavaScript, "test.js")
|
||||||
|
|
||||||
|
expectedSymbols := map[string]protocol.SymbolKind{
|
||||||
|
"greet": protocol.SymbolFunction,
|
||||||
|
"User": protocol.SymbolClass,
|
||||||
|
"MAX_USERS": protocol.SymbolVariable,
|
||||||
|
"currentUser": protocol.SymbolVariable,
|
||||||
|
}
|
||||||
|
|
||||||
|
found := make(map[string]bool)
|
||||||
|
for _, sym := range symbols {
|
||||||
|
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||||
|
found[sym.Name] = true
|
||||||
|
if sym.Kind != expectedKind {
|
||||||
|
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name := range expectedSymbols {
|
||||||
|
if !found[name] {
|
||||||
|
t.Errorf("expected to find symbol %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractPythonSymbols(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
content := `
|
||||||
|
def greet(name):
|
||||||
|
"""Greet a person by name."""
|
||||||
|
print(f"Hello, {name}")
|
||||||
|
|
||||||
|
class User:
|
||||||
|
"""Represents a user."""
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def get_name(self):
|
||||||
|
return self.name
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := r.Parse(ctx, "test.py", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangPython, "test.py")
|
||||||
|
|
||||||
|
expectedSymbols := map[string]protocol.SymbolKind{
|
||||||
|
"greet": protocol.SymbolFunction,
|
||||||
|
"User": protocol.SymbolClass,
|
||||||
|
"__init__": protocol.SymbolMethod,
|
||||||
|
"get_name": protocol.SymbolMethod,
|
||||||
|
}
|
||||||
|
|
||||||
|
found := make(map[string]bool)
|
||||||
|
for _, sym := range symbols {
|
||||||
|
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||||
|
found[sym.Name] = true
|
||||||
|
if sym.Kind != expectedKind {
|
||||||
|
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name := range expectedSymbols {
|
||||||
|
if !found[name] {
|
||||||
|
t.Errorf("expected to find symbol %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractCSymbols(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
content := `
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
struct Point {
|
||||||
|
int x;
|
||||||
|
int y;
|
||||||
|
};
|
||||||
|
|
||||||
|
void print_point(struct Point p) {
|
||||||
|
printf("(%d, %d)\n", p.x, p.y);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
struct Point p = {1, 2};
|
||||||
|
print_point(p);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := r.Parse(ctx, "test.c", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangC, "test.c")
|
||||||
|
|
||||||
|
// Note: C symbol extraction is complex, checking for at least main and Point
|
||||||
|
expectedSymbols := map[string]protocol.SymbolKind{
|
||||||
|
"Point": protocol.SymbolStruct,
|
||||||
|
"main": protocol.SymbolFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
found := make(map[string]bool)
|
||||||
|
for _, sym := range symbols {
|
||||||
|
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||||
|
found[sym.Name] = true
|
||||||
|
if sym.Kind != expectedKind {
|
||||||
|
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name := range expectedSymbols {
|
||||||
|
if !found[name] {
|
||||||
|
t.Errorf("expected to find symbol %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,195 @@
|
|||||||
|
// Package parser provides YAML and JSON parsing with AST support.
|
||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
sitter "github.com/smacker/go-tree-sitter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// YAMLNode wraps yaml.Node to provide tree-sitter-like interface
|
||||||
|
type YAMLNode struct {
|
||||||
|
*yaml.Node
|
||||||
|
Content []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSONNode represents a JSON AST node
|
||||||
|
type JSONNode struct {
|
||||||
|
Value any
|
||||||
|
Type string
|
||||||
|
Children []*JSONNode
|
||||||
|
Line int
|
||||||
|
Column int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseYAML parses YAML content and returns a tree-sitter-compatible result
|
||||||
|
func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||||
|
// Check file size
|
||||||
|
if len(content) > MaxFileSize {
|
||||||
|
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse YAML
|
||||||
|
var root yaml.Node
|
||||||
|
if err := yaml.Unmarshal(content, &root); err != nil {
|
||||||
|
return nil, errors.NewParseError("yaml", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract syntax errors from YAML parse
|
||||||
|
syntaxErrors := extractYAMLErrors()
|
||||||
|
|
||||||
|
// Create a pseudo tree-sitter tree for compatibility
|
||||||
|
// We'll use nil for the tree since YAML doesn't use tree-sitter
|
||||||
|
return &ParseResult{
|
||||||
|
Tree: nil, // YAML uses yaml.Node instead
|
||||||
|
Language: protocol.LangYAML,
|
||||||
|
Errors: syntaxErrors,
|
||||||
|
Content: content,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseJSON parses JSON content and returns a tree-sitter-compatible result
|
||||||
|
func (r *Registry) ParseJSON(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||||
|
// Check file size
|
||||||
|
if len(content) > MaxFileSize {
|
||||||
|
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JSON to validate syntax
|
||||||
|
var jsonData any
|
||||||
|
if err := json.Unmarshal(content, &jsonData); err != nil {
|
||||||
|
return nil, errors.NewParseError("json", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSON parsing succeeded, no syntax errors
|
||||||
|
return &ParseResult{
|
||||||
|
Tree: nil, // JSON uses native Go structures
|
||||||
|
Language: protocol.LangJSON,
|
||||||
|
Errors: []SyntaxError{},
|
||||||
|
Content: content,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractYAMLErrors extracts errors from YAML nodes
|
||||||
|
func extractYAMLErrors() []SyntaxError {
|
||||||
|
// YAML parser already validates during unmarshal
|
||||||
|
// If we got here, there are no syntax errors
|
||||||
|
// However, we could add semantic validation here in the future
|
||||||
|
return []SyntaxError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WalkYAML walks a YAML AST and calls fn for each node
|
||||||
|
func WalkYAML(node *yaml.Node, fn func(*yaml.Node) bool) {
|
||||||
|
if node == nil || !fn(node) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, child := range node.Content {
|
||||||
|
WalkYAML(child, fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetYAMLNodeText returns the text representation of a YAML node
|
||||||
|
func GetYAMLNodeText(node *yaml.Node) string {
|
||||||
|
if node == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch node.Kind {
|
||||||
|
case yaml.DocumentNode:
|
||||||
|
if len(node.Content) > 0 {
|
||||||
|
return GetYAMLNodeText(node.Content[0])
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
case yaml.MappingNode:
|
||||||
|
return node.Value
|
||||||
|
case yaml.SequenceNode:
|
||||||
|
return node.Value
|
||||||
|
case yaml.ScalarNode:
|
||||||
|
return node.Value
|
||||||
|
case yaml.AliasNode:
|
||||||
|
return node.Value
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetYAMLNodeLocation returns the location of a YAML node
|
||||||
|
func GetYAMLNodeLocation(node *yaml.Node) protocol.Location {
|
||||||
|
if node == nil {
|
||||||
|
return protocol.Location{Line: 1, Column: 1}
|
||||||
|
}
|
||||||
|
|
||||||
|
return protocol.Location{
|
||||||
|
Line: node.Line,
|
||||||
|
Column: node.Column,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryYAML performs a simple query on YAML content
|
||||||
|
// Example: "$.metadata.name" to find the name field in metadata
|
||||||
|
func QueryYAML(content []byte, query string) ([]*yaml.Node, error) {
|
||||||
|
var root yaml.Node
|
||||||
|
if err := yaml.Unmarshal(content, &root); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse YAML: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple path-based query implementation
|
||||||
|
// This is a basic implementation - can be extended with more sophisticated queries
|
||||||
|
var results []*yaml.Node
|
||||||
|
|
||||||
|
WalkYAML(&root, func(node *yaml.Node) bool {
|
||||||
|
if node.Value == query || node.Tag == query {
|
||||||
|
results = append(results, node)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryJSON performs a simple query on JSON content
|
||||||
|
func QueryJSON(content []byte, query string) ([]any, error) {
|
||||||
|
var data any
|
||||||
|
if err := json.Unmarshal(content, &data); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JSON: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic implementation - can be extended with JSONPath support
|
||||||
|
var results []any
|
||||||
|
|
||||||
|
// For now, just validate that it's valid JSON
|
||||||
|
results = append(results, data)
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateYAML validates YAML content without parsing to full AST
|
||||||
|
func ValidateYAML(content []byte) error {
|
||||||
|
var node yaml.Node
|
||||||
|
if err := yaml.Unmarshal(content, &node); err != nil {
|
||||||
|
return fmt.Errorf("YAML validation failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateJSON validates JSON content
|
||||||
|
func ValidateJSON(content []byte) error {
|
||||||
|
var data any
|
||||||
|
if err := json.Unmarshal(content, &data); err != nil {
|
||||||
|
return fmt.Errorf("JSON validation failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToSitterTree is a placeholder that returns nil for YAML/JSON
|
||||||
|
// These formats don't use tree-sitter, but we keep this for interface compatibility
|
||||||
|
func (yn *YAMLNode) ToSitterTree() *sitter.Tree {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,283 @@
|
|||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseYAML(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid simple YAML",
|
||||||
|
content: `name: test
|
||||||
|
version: 1.0.0
|
||||||
|
enabled: true`,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid nested YAML",
|
||||||
|
content: `metadata:
|
||||||
|
name: test-app
|
||||||
|
namespace: default
|
||||||
|
spec:
|
||||||
|
replicas: 3
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
app: test`,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid list YAML",
|
||||||
|
content: `items:
|
||||||
|
- name: item1
|
||||||
|
value: 100
|
||||||
|
- name: item2
|
||||||
|
value: 200`,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid YAML - bad syntax",
|
||||||
|
content: `name: test\n bad: indent\n wrong: [unclosed`,
|
||||||
|
shouldError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := registry.ParseYAML(context.Background(), "test.yaml", []byte(tt.content))
|
||||||
|
|
||||||
|
if tt.shouldError {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error but got none")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Error("expected result but got nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Language != protocol.LangYAML {
|
||||||
|
t.Errorf("expected language YAML, got %s", result.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Errors) > 0 {
|
||||||
|
t.Errorf("expected no syntax errors, got %d", len(result.Errors))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseJSON(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid simple JSON",
|
||||||
|
content: `{"name": "test", "version": "1.0.0", "enabled": true}`,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid nested JSON",
|
||||||
|
content: `{
|
||||||
|
"metadata": {
|
||||||
|
"name": "test-app",
|
||||||
|
"namespace": "default"
|
||||||
|
},
|
||||||
|
"spec": {
|
||||||
|
"replicas": 3,
|
||||||
|
"selector": {
|
||||||
|
"matchLabels": {
|
||||||
|
"app": "test"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid array JSON",
|
||||||
|
content: `[{"name": "item1", "value": 100}, {"name": "item2", "value": 200}]`,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON - unclosed brace",
|
||||||
|
content: `{"name": "test", "value": 100`,
|
||||||
|
shouldError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON - trailing comma",
|
||||||
|
content: `{"name": "test", "value": 100,}`,
|
||||||
|
shouldError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := registry.ParseJSON(context.Background(), "test.json", []byte(tt.content))
|
||||||
|
|
||||||
|
if tt.shouldError {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error but got none")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Error("expected result but got nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Language != protocol.LangJSON {
|
||||||
|
t.Errorf("expected language JSON, got %s", result.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Errors) > 0 {
|
||||||
|
t.Errorf("expected no syntax errors, got %d", len(result.Errors))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryParse_YAML_JSON(t *testing.T) {
|
||||||
|
registry := NewRegistry()
|
||||||
|
defer registry.Close()
|
||||||
|
|
||||||
|
yamlContent := []byte(`name: test
|
||||||
|
version: 1.0.0`)
|
||||||
|
|
||||||
|
jsonContent := []byte(`{"name": "test", "version": "1.0.0"}`)
|
||||||
|
|
||||||
|
// Test YAML through main Parse method
|
||||||
|
yamlResult, err := registry.Parse(context.Background(), "config.yaml", yamlContent)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to parse YAML: %v", err)
|
||||||
|
}
|
||||||
|
if yamlResult.Language != protocol.LangYAML {
|
||||||
|
t.Errorf("expected YAML language, got %s", yamlResult.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test JSON through main Parse method
|
||||||
|
jsonResult, err := registry.Parse(context.Background(), "config.json", jsonContent)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to parse JSON: %v", err)
|
||||||
|
}
|
||||||
|
if jsonResult.Language != protocol.LangJSON {
|
||||||
|
t.Errorf("expected JSON language, got %s", jsonResult.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test .yml extension
|
||||||
|
ymlResult, err := registry.Parse(context.Background(), "config.yml", yamlContent)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to parse .yml: %v", err)
|
||||||
|
}
|
||||||
|
if ymlResult.Language != protocol.LangYAML {
|
||||||
|
t.Errorf("expected YAML language for .yml extension, got %s", ymlResult.Language)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalkYAML(t *testing.T) {
|
||||||
|
content := []byte(`metadata:
|
||||||
|
name: test
|
||||||
|
labels:
|
||||||
|
app: myapp
|
||||||
|
env: prod`)
|
||||||
|
|
||||||
|
var root yaml.Node
|
||||||
|
if err := yaml.Unmarshal(content, &root); err != nil {
|
||||||
|
t.Fatalf("failed to parse YAML: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeCount := 0
|
||||||
|
WalkYAML(&root, func(node *yaml.Node) bool {
|
||||||
|
nodeCount++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if nodeCount == 0 {
|
||||||
|
t.Error("expected to visit nodes, but count is 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateYAML(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content []byte
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid YAML",
|
||||||
|
content: []byte("name: test\nvalue: 100"),
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid YAML",
|
||||||
|
content: []byte("name: test\n bad:\n[unclosed"),
|
||||||
|
shouldError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := ValidateYAML(tt.content)
|
||||||
|
if (err != nil) != tt.shouldError {
|
||||||
|
t.Errorf("ValidateYAML() error = %v, shouldError = %v", err, tt.shouldError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateJSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content []byte
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid JSON",
|
||||||
|
content: []byte(`{"name": "test", "value": 100}`),
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON",
|
||||||
|
content: []byte(`{"name": "test", "value": 100`),
|
||||||
|
shouldError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := ValidateJSON(tt.content)
|
||||||
|
if (err != nil) != tt.shouldError {
|
||||||
|
t.Errorf("ValidateJSON() error = %v, shouldError = %v", err, tt.shouldError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,538 @@
|
|||||||
|
// Package query implements a hybrid AST query language with pattern matching.
|
||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
sitter "github.com/smacker/go-tree-sitter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Global regex cache for compiled patterns (thread-safe)
|
||||||
|
var regexCache sync.Map // string -> *regexp.Regexp
|
||||||
|
|
||||||
|
// compileRegex compiles a regex pattern with caching for performance.
|
||||||
|
// Cached patterns avoid repeated compilation overhead (10-50x speedup).
|
||||||
|
// Thread-safe: uses LoadOrStore to prevent race conditions.
|
||||||
|
func compileRegex(pattern string) (*regexp.Regexp, error) {
|
||||||
|
// Check cache first
|
||||||
|
if cached, ok := regexCache.Load(pattern); ok {
|
||||||
|
return cached.(*regexp.Regexp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile regex
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to store - if another goroutine already stored it, use theirs
|
||||||
|
// This prevents race conditions where multiple goroutines compile the same pattern
|
||||||
|
actual, _ := regexCache.LoadOrStore(pattern, re)
|
||||||
|
return actual.(*regexp.Regexp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ASTQuery defines a query for matching AST patterns.
|
||||||
|
type ASTQuery struct {
|
||||||
|
Pattern string `json:"pattern"` // code pattern with $VAR placeholders
|
||||||
|
Language string `json:"language"` // required
|
||||||
|
Filters QueryFilters `json:"filters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryFilters provide additional filtering criteria.
|
||||||
|
type QueryFilters struct {
|
||||||
|
HasChild *ASTQuery `json:"has_child,omitempty"`
|
||||||
|
HasParent *ASTQuery `json:"has_parent,omitempty"`
|
||||||
|
NameMatches string `json:"name_matches,omitempty"`
|
||||||
|
NameExact string `json:"name_exact,omitempty"`
|
||||||
|
InFile string `json:"in_file,omitempty"`
|
||||||
|
NotInFile string `json:"not_in_file,omitempty"`
|
||||||
|
KindIn []string `json:"kind_in,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchResult represents a single match from a query.
|
||||||
|
type MatchResult struct {
|
||||||
|
Node *sitter.Node
|
||||||
|
Captures map[string]CapturedNode
|
||||||
|
File string
|
||||||
|
Text string
|
||||||
|
Location protocol.Location
|
||||||
|
}
|
||||||
|
|
||||||
|
// CapturedNode represents a captured node or nodes.
|
||||||
|
type CapturedNode struct {
|
||||||
|
Text string
|
||||||
|
Nodes []*sitter.Node
|
||||||
|
}
|
||||||
|
|
||||||
|
// CaptureType indicates the type of capture.
|
||||||
|
type CaptureType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
CaptureSingle CaptureType = iota // $NAME - single node
|
||||||
|
CaptureMultiple // $$$NAME - multiple nodes
|
||||||
|
CaptureWildcard // $_ - wildcard (don't capture)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Capture represents a placeholder in a pattern.
|
||||||
|
type Capture struct {
|
||||||
|
Name string
|
||||||
|
Type CaptureType
|
||||||
|
Position int // position in the pattern
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsedPattern represents a parsed code pattern.
|
||||||
|
type ParsedPattern struct {
|
||||||
|
Original string
|
||||||
|
Template string
|
||||||
|
Captures []Capture
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matcher performs AST pattern matching.
|
||||||
|
type Matcher struct {
|
||||||
|
registry *parser.Registry
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMatcher creates a new pattern matcher.
|
||||||
|
func NewMatcher(registry *parser.Registry) *Matcher {
|
||||||
|
return &Matcher{registry: registry}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePattern parses a pattern string and extracts captures.
|
||||||
|
func ParsePattern(pattern string) (*ParsedPattern, error) {
|
||||||
|
if pattern == "" {
|
||||||
|
return nil, fmt.Errorf("empty pattern")
|
||||||
|
}
|
||||||
|
|
||||||
|
var captures []Capture
|
||||||
|
template := pattern
|
||||||
|
captureID := 0
|
||||||
|
|
||||||
|
// Find all captures: $$$ (multi), $_ (wildcard), $NAME (single)
|
||||||
|
// Order matters: check $$$ first
|
||||||
|
multiRe := regexp.MustCompile(`\$\$\$([A-Za-z_][A-Za-z0-9_]*)`)
|
||||||
|
wildcardRe := regexp.MustCompile(`\$_`)
|
||||||
|
singleRe := regexp.MustCompile(`\$([A-Za-z_][A-Za-z0-9_]*)`)
|
||||||
|
|
||||||
|
// Extract multi-node captures ($$$NAME)
|
||||||
|
for _, match := range multiRe.FindAllStringSubmatchIndex(pattern, -1) {
|
||||||
|
name := pattern[match[2]:match[3]]
|
||||||
|
captures = append(captures, Capture{
|
||||||
|
Name: name,
|
||||||
|
Type: CaptureMultiple,
|
||||||
|
Position: match[0],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace multi-captures with placeholder identifiers
|
||||||
|
template = multiRe.ReplaceAllStringFunc(template, func(s string) string {
|
||||||
|
captureID++
|
||||||
|
return fmt.Sprintf("__multi_%d__", captureID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Extract wildcards ($_)
|
||||||
|
for _, match := range wildcardRe.FindAllStringIndex(pattern, -1) {
|
||||||
|
captures = append(captures, Capture{
|
||||||
|
Name: "_",
|
||||||
|
Type: CaptureWildcard,
|
||||||
|
Position: match[0],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace wildcards with placeholder identifiers
|
||||||
|
template = wildcardRe.ReplaceAllStringFunc(template, func(s string) string {
|
||||||
|
captureID++
|
||||||
|
return fmt.Sprintf("__wild_%d__", captureID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Extract single-node captures ($NAME) - exclude those that are part of $$$NAME
|
||||||
|
// Check which $NAME patterns are not preceded by $$
|
||||||
|
remaining := template
|
||||||
|
for _, match := range singleRe.FindAllStringSubmatchIndex(remaining, -1) {
|
||||||
|
name := remaining[match[2]:match[3]]
|
||||||
|
// Skip if this looks like our placeholder
|
||||||
|
if strings.HasPrefix(name, "_multi_") || strings.HasPrefix(name, "_wild_") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
captures = append(captures, Capture{
|
||||||
|
Name: name,
|
||||||
|
Type: CaptureSingle,
|
||||||
|
Position: match[0],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace single captures with placeholder identifiers
|
||||||
|
template = singleRe.ReplaceAllStringFunc(template, func(s string) string {
|
||||||
|
name := strings.TrimPrefix(s, "$")
|
||||||
|
if strings.HasPrefix(name, "_multi_") || strings.HasPrefix(name, "_wild_") {
|
||||||
|
return s // keep our placeholders as is
|
||||||
|
}
|
||||||
|
captureID++
|
||||||
|
return fmt.Sprintf("__single_%d__", captureID)
|
||||||
|
})
|
||||||
|
|
||||||
|
return &ParsedPattern{
|
||||||
|
Original: pattern,
|
||||||
|
Captures: captures,
|
||||||
|
Template: template,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match executes a query against a parsed tree.
|
||||||
|
func (m *Matcher) Match(ctx context.Context, query *ASTQuery, tree *sitter.Tree, content []byte, filename string) ([]MatchResult, error) {
|
||||||
|
if query.Pattern == "" {
|
||||||
|
return nil, fmt.Errorf("query pattern is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
lang := protocol.Language(query.Language)
|
||||||
|
if lang == "" || lang == protocol.LangUnknown {
|
||||||
|
return nil, fmt.Errorf("valid language is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the pattern
|
||||||
|
parsed, err := ParsePattern(query.Pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid pattern: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []MatchResult
|
||||||
|
|
||||||
|
// Walk the tree and find matches
|
||||||
|
root := tree.RootNode()
|
||||||
|
if root == nil {
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.WalkTree(root, func(n *sitter.Node) bool {
|
||||||
|
// Check for context cancellation
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to match this node against the pattern
|
||||||
|
if matched, captures := matchNode(n, parsed, content); matched {
|
||||||
|
// Apply filters
|
||||||
|
if !passesFilters(n, query.Filters, content) {
|
||||||
|
return true // continue walking
|
||||||
|
}
|
||||||
|
|
||||||
|
startPoint := n.StartPoint()
|
||||||
|
results = append(results, MatchResult{
|
||||||
|
Node: n,
|
||||||
|
Captures: captures,
|
||||||
|
File: filename,
|
||||||
|
Location: protocol.Location{
|
||||||
|
Line: int(startPoint.Row) + 1,
|
||||||
|
Column: int(startPoint.Column) + 1,
|
||||||
|
},
|
||||||
|
Text: parser.GetNodeText(n, content),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchNode attempts to match a node against a parsed pattern.
|
||||||
|
// This is a simplified matcher that looks for structural similarity.
|
||||||
|
func matchNode(node *sitter.Node, pattern *ParsedPattern, content []byte) (bool, map[string]CapturedNode) {
|
||||||
|
if node == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
captures := make(map[string]CapturedNode)
|
||||||
|
|
||||||
|
// Use pattern keyword matching as a heuristic to find matching nodes
|
||||||
|
// A full implementation would parse both pattern and node and compare AST structure
|
||||||
|
matched := matchPatternHeuristic(node, pattern, content, captures)
|
||||||
|
|
||||||
|
return matched, captures
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchPatternHeuristic uses heuristics to match patterns.
|
||||||
|
// This is a simplified implementation that matches based on node type and structure.
|
||||||
|
func matchPatternHeuristic(node *sitter.Node, pattern *ParsedPattern, content []byte, captures map[string]CapturedNode) bool {
|
||||||
|
patternLower := strings.ToLower(pattern.Original)
|
||||||
|
nodeType := node.Type()
|
||||||
|
|
||||||
|
// Match function patterns
|
||||||
|
if strings.Contains(patternLower, "func ") || strings.Contains(patternLower, "function ") {
|
||||||
|
if nodeType != "function_declaration" && nodeType != "method_declaration" && nodeType != "function_definition" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
extractFunctionCaptures(node, pattern.Captures, content, captures)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match class patterns
|
||||||
|
if strings.Contains(patternLower, "class ") {
|
||||||
|
if nodeType != "class_declaration" && nodeType != "class_definition" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
extractClassCaptures(node, pattern.Captures, content, captures)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match struct patterns (Go, C, C++)
|
||||||
|
if strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ") && strings.Contains(patternLower, "struct") {
|
||||||
|
if nodeType != "type_declaration" && nodeType != "struct_specifier" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
extractStructCaptures(node, pattern.Captures, content, captures)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match interface patterns (Go, TypeScript)
|
||||||
|
if strings.Contains(patternLower, "interface ") {
|
||||||
|
if nodeType != "interface_declaration" && nodeType != "type_declaration" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
extractInterfaceCaptures(node, pattern.Captures, content, captures)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFunctionCaptures extracts captures from a function node.
|
||||||
|
func extractFunctionCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||||
|
for _, cap := range capturesDef {
|
||||||
|
switch cap.Name {
|
||||||
|
case "NAME", "name":
|
||||||
|
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{nameNode},
|
||||||
|
Text: parser.GetNodeText(nameNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "ARGS", "args", "PARAMS", "params":
|
||||||
|
if paramsNode := node.ChildByFieldName("parameters"); paramsNode != nil {
|
||||||
|
var paramNodes []*sitter.Node
|
||||||
|
for i := 0; i < int(paramsNode.NamedChildCount()); i++ {
|
||||||
|
paramNodes = append(paramNodes, paramsNode.NamedChild(i))
|
||||||
|
}
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: paramNodes,
|
||||||
|
Text: parser.GetNodeText(paramsNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "BODY", "body":
|
||||||
|
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{bodyNode},
|
||||||
|
Text: parser.GetNodeText(bodyNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "RETURN", "return", "RESULT", "result":
|
||||||
|
if resultNode := node.ChildByFieldName("result"); resultNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{resultNode},
|
||||||
|
Text: parser.GetNodeText(resultNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractClassCaptures extracts captures from a class node.
|
||||||
|
func extractClassCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||||
|
for _, cap := range capturesDef {
|
||||||
|
switch cap.Name {
|
||||||
|
case "NAME", "name":
|
||||||
|
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{nameNode},
|
||||||
|
Text: parser.GetNodeText(nameNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "BODY", "body":
|
||||||
|
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{bodyNode},
|
||||||
|
Text: parser.GetNodeText(bodyNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "EXTENDS", "extends", "SUPERCLASS", "superclass":
|
||||||
|
if extendsNode := node.ChildByFieldName("superclass"); extendsNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{extendsNode},
|
||||||
|
Text: parser.GetNodeText(extendsNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractStructCaptures extracts captures from a struct node.
|
||||||
|
func extractStructCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||||
|
for _, cap := range capturesDef {
|
||||||
|
switch cap.Name {
|
||||||
|
case "NAME", "name":
|
||||||
|
// For Go type_declaration, we need to look at the type_spec child
|
||||||
|
if node.Type() == "type_declaration" {
|
||||||
|
for i := 0; i < int(node.NamedChildCount()); i++ {
|
||||||
|
child := node.NamedChild(i)
|
||||||
|
if child != nil && child.Type() == "type_spec" {
|
||||||
|
if nameNode := child.ChildByFieldName("name"); nameNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{nameNode},
|
||||||
|
Text: parser.GetNodeText(nameNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{nameNode},
|
||||||
|
Text: parser.GetNodeText(nameNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "FIELDS", "fields", "BODY", "body":
|
||||||
|
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{bodyNode},
|
||||||
|
Text: parser.GetNodeText(bodyNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractInterfaceCaptures extracts captures from an interface node.
|
||||||
|
func extractInterfaceCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||||
|
for _, cap := range capturesDef {
|
||||||
|
switch cap.Name {
|
||||||
|
case "NAME", "name":
|
||||||
|
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{nameNode},
|
||||||
|
Text: parser.GetNodeText(nameNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "BODY", "body", "METHODS", "methods":
|
||||||
|
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||||
|
captures[cap.Name] = CapturedNode{
|
||||||
|
Nodes: []*sitter.Node{bodyNode},
|
||||||
|
Text: parser.GetNodeText(bodyNode, content),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// passesFilters checks if a node passes all the specified filters.
|
||||||
|
func passesFilters(node *sitter.Node, filters QueryFilters, content []byte) bool {
|
||||||
|
// Name regex filter (uses cached compilation)
|
||||||
|
if filters.NameMatches != "" {
|
||||||
|
nameNode := node.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
name := parser.GetNodeText(nameNode, content)
|
||||||
|
re, err := compileRegex(filters.NameMatches)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !re.MatchString(name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exact name filter
|
||||||
|
if filters.NameExact != "" {
|
||||||
|
nameNode := node.ChildByFieldName("name")
|
||||||
|
if nameNode == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
name := parser.GetNodeText(nameNode, content)
|
||||||
|
if name != filters.NameExact {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kind filter
|
||||||
|
if len(filters.KindIn) > 0 {
|
||||||
|
nodeType := node.Type()
|
||||||
|
found := false
|
||||||
|
for _, kind := range filters.KindIn {
|
||||||
|
if nodeType == kind {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatResults formats match results for display.
|
||||||
|
func FormatResults(results []MatchResult, maxResults int) string {
|
||||||
|
if len(results) == 0 {
|
||||||
|
return "No matches found."
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString(fmt.Sprintf("Found %d match(es):\n\n", len(results)))
|
||||||
|
|
||||||
|
displayCount := len(results)
|
||||||
|
truncated := false
|
||||||
|
if maxResults > 0 && displayCount > maxResults {
|
||||||
|
displayCount = maxResults
|
||||||
|
truncated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < displayCount; i++ {
|
||||||
|
r := results[i]
|
||||||
|
nodeType := "unknown"
|
||||||
|
if r.Node != nil {
|
||||||
|
nodeType = r.Node.Type()
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType))
|
||||||
|
|
||||||
|
// Truncate very long text
|
||||||
|
text := r.Text
|
||||||
|
if len(text) > 500 {
|
||||||
|
text = text[:500] + "..."
|
||||||
|
}
|
||||||
|
sb.WriteString("```\n")
|
||||||
|
sb.WriteString(text)
|
||||||
|
sb.WriteString("\n```\n")
|
||||||
|
|
||||||
|
// Show captures
|
||||||
|
if len(r.Captures) > 0 {
|
||||||
|
sb.WriteString("Captures: ")
|
||||||
|
first := true
|
||||||
|
for name, cap := range r.Captures {
|
||||||
|
if !first {
|
||||||
|
sb.WriteString(", ")
|
||||||
|
}
|
||||||
|
first = false
|
||||||
|
capText := cap.Text
|
||||||
|
if len(capText) > 50 {
|
||||||
|
capText = capText[:50] + "..."
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf("$%s=%s", name, capText))
|
||||||
|
}
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if truncated {
|
||||||
|
sb.WriteString(fmt.Sprintf("... and %d more matches (truncated)\n", len(results)-maxResults))
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
@@ -0,0 +1,559 @@
|
|||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParsePattern(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pattern string
|
||||||
|
captureNames []string
|
||||||
|
captureTypes []CaptureType
|
||||||
|
wantCaptures int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty pattern",
|
||||||
|
pattern: "",
|
||||||
|
wantErr: true,
|
||||||
|
wantCaptures: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single capture",
|
||||||
|
pattern: "func $NAME() {}",
|
||||||
|
wantErr: false,
|
||||||
|
wantCaptures: 1,
|
||||||
|
captureNames: []string{"NAME"},
|
||||||
|
captureTypes: []CaptureType{CaptureSingle},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple single captures",
|
||||||
|
pattern: "func $NAME($ARGS) $RETURN",
|
||||||
|
wantErr: false,
|
||||||
|
wantCaptures: 3,
|
||||||
|
captureNames: []string{"NAME", "ARGS", "RETURN"},
|
||||||
|
captureTypes: []CaptureType{CaptureSingle, CaptureSingle, CaptureSingle},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-node capture",
|
||||||
|
pattern: "func $NAME($$$ARGS) { $$$BODY }",
|
||||||
|
wantErr: false,
|
||||||
|
wantCaptures: 3,
|
||||||
|
captureNames: []string{"ARGS", "BODY", "NAME"},
|
||||||
|
captureTypes: []CaptureType{CaptureMultiple, CaptureMultiple, CaptureSingle},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard capture",
|
||||||
|
pattern: "func $NAME($_) {}",
|
||||||
|
wantErr: false,
|
||||||
|
wantCaptures: 2,
|
||||||
|
captureNames: []string{"NAME", "_"},
|
||||||
|
captureTypes: []CaptureType{CaptureSingle, CaptureWildcard},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no captures",
|
||||||
|
pattern: "func main() {}",
|
||||||
|
wantErr: false,
|
||||||
|
wantCaptures: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parsed, err := ParsePattern(tt.pattern)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parsed.Captures) != tt.wantCaptures {
|
||||||
|
t.Errorf("expected %d captures, got %d", tt.wantCaptures, len(parsed.Captures))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check capture names (order may vary)
|
||||||
|
if tt.captureNames != nil {
|
||||||
|
captureMap := make(map[string]CaptureType)
|
||||||
|
for _, cap := range parsed.Captures {
|
||||||
|
captureMap[cap.Name] = cap.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, name := range tt.captureNames {
|
||||||
|
if _, ok := captureMap[name]; !ok {
|
||||||
|
t.Errorf("expected capture %s not found", name)
|
||||||
|
}
|
||||||
|
if captureMap[name] != tt.captureTypes[i] {
|
||||||
|
t.Errorf("capture %s: expected type %v, got %v", name, tt.captureTypes[i], captureMap[name])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchGoFunctions(t *testing.T) {
|
||||||
|
reg := parser.NewRegistry()
|
||||||
|
defer reg.Close()
|
||||||
|
|
||||||
|
matcher := NewMatcher(reg)
|
||||||
|
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
func Hello() {
|
||||||
|
println("hello")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Greet(name string) error {
|
||||||
|
println("hello", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Start() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
query *ASTQuery
|
||||||
|
name string
|
||||||
|
wantMatches int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match all functions",
|
||||||
|
query: &ASTQuery{
|
||||||
|
Pattern: "func $NAME($$$ARGS) { $$$BODY }",
|
||||||
|
Language: "go",
|
||||||
|
},
|
||||||
|
wantMatches: 3, // Hello, Greet, Start
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match functions starting with H",
|
||||||
|
query: &ASTQuery{
|
||||||
|
Pattern: "func $NAME() {}",
|
||||||
|
Language: "go",
|
||||||
|
Filters: QueryFilters{
|
||||||
|
NameMatches: "^H",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantMatches: 1, // Hello
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match specific function",
|
||||||
|
query: &ASTQuery{
|
||||||
|
Pattern: "func $NAME() {}",
|
||||||
|
Language: "go",
|
||||||
|
Filters: QueryFilters{
|
||||||
|
NameExact: "Hello",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantMatches: 1, // Hello
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("match failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != tt.wantMatches {
|
||||||
|
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||||
|
for i, r := range results {
|
||||||
|
t.Logf("match %d: %s at line %d", i, r.Node.Type(), r.Location.Line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchGoStructs(t *testing.T) {
|
||||||
|
reg := parser.NewRegistry()
|
||||||
|
defer reg.Close()
|
||||||
|
|
||||||
|
matcher := NewMatcher(reg)
|
||||||
|
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
Port int
|
||||||
|
Host string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Timeout int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Log(msg string)
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
query *ASTQuery
|
||||||
|
name string
|
||||||
|
wantMinimum int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match all structs",
|
||||||
|
query: &ASTQuery{
|
||||||
|
Pattern: "type $NAME struct { $$$FIELDS }",
|
||||||
|
Language: "go",
|
||||||
|
},
|
||||||
|
wantMinimum: 2, // Server, Config (may also match interface as type_declaration)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("match failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) < tt.wantMinimum {
|
||||||
|
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchJSFunctions(t *testing.T) {
|
||||||
|
reg := parser.NewRegistry()
|
||||||
|
defer reg.Close()
|
||||||
|
|
||||||
|
matcher := NewMatcher(reg)
|
||||||
|
|
||||||
|
content := `
|
||||||
|
function greet(name) {
|
||||||
|
console.log("Hello, " + name);
|
||||||
|
}
|
||||||
|
|
||||||
|
function sayHello() {
|
||||||
|
console.log("Hello!");
|
||||||
|
}
|
||||||
|
|
||||||
|
class User {
|
||||||
|
constructor(name) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
getName() {
|
||||||
|
return this.name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := reg.Parse(ctx, "test.js", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
query *ASTQuery
|
||||||
|
name string
|
||||||
|
wantMatches int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match all functions",
|
||||||
|
query: &ASTQuery{
|
||||||
|
Pattern: "function $NAME($$$ARGS) { $$$BODY }",
|
||||||
|
Language: "javascript",
|
||||||
|
},
|
||||||
|
wantMatches: 2, // greet, sayHello
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match classes",
|
||||||
|
query: &ASTQuery{
|
||||||
|
Pattern: "class $NAME { $$$BODY }",
|
||||||
|
Language: "javascript",
|
||||||
|
},
|
||||||
|
wantMatches: 1, // User
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.js")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("match failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != tt.wantMatches {
|
||||||
|
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchPythonSymbols(t *testing.T) {
|
||||||
|
reg := parser.NewRegistry()
|
||||||
|
defer reg.Close()
|
||||||
|
|
||||||
|
matcher := NewMatcher(reg)
|
||||||
|
|
||||||
|
content := `
|
||||||
|
def greet(name):
|
||||||
|
print(f"Hello, {name}")
|
||||||
|
|
||||||
|
def calculate(a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
class User:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def get_name(self):
|
||||||
|
return self.name
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := reg.Parse(ctx, "test.py", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
query *ASTQuery
|
||||||
|
name string
|
||||||
|
wantMinimum int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match classes",
|
||||||
|
query: &ASTQuery{
|
||||||
|
Pattern: "class $NAME: $$$BODY",
|
||||||
|
Language: "python",
|
||||||
|
},
|
||||||
|
wantMinimum: 1, // User
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.py")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("match failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) < tt.wantMinimum {
|
||||||
|
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryFilters(t *testing.T) {
|
||||||
|
reg := parser.NewRegistry()
|
||||||
|
defer reg.Close()
|
||||||
|
|
||||||
|
matcher := NewMatcher(reg)
|
||||||
|
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
func HelloWorld() {}
|
||||||
|
func helloWorld() {}
|
||||||
|
func GoodbyeWorld() {}
|
||||||
|
func Main() {}
|
||||||
|
`
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filters QueryFilters
|
||||||
|
wantMatches int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "regex filter - starts with H",
|
||||||
|
filters: QueryFilters{
|
||||||
|
NameMatches: "^[Hh]ello",
|
||||||
|
},
|
||||||
|
wantMatches: 2, // HelloWorld, helloWorld
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact name filter",
|
||||||
|
filters: QueryFilters{
|
||||||
|
NameExact: "Main",
|
||||||
|
},
|
||||||
|
wantMatches: 1, // Main
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "kind filter",
|
||||||
|
filters: QueryFilters{
|
||||||
|
KindIn: []string{"function_declaration"},
|
||||||
|
},
|
||||||
|
wantMatches: 4, // all functions
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
query := &ASTQuery{
|
||||||
|
Pattern: "func $NAME() {}",
|
||||||
|
Language: "go",
|
||||||
|
Filters: tt.filters,
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := matcher.Match(ctx, query, result.Tree, []byte(content), "test.go")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("match failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != tt.wantMatches {
|
||||||
|
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||||
|
for _, r := range results {
|
||||||
|
if nameNode := r.Node.ChildByFieldName("name"); nameNode != nil {
|
||||||
|
t.Logf("matched: %s", parser.GetNodeText(nameNode, []byte(content)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatResults(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
results []MatchResult
|
||||||
|
maxResults int
|
||||||
|
wantEmpty bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty results",
|
||||||
|
results: []MatchResult{},
|
||||||
|
maxResults: 100,
|
||||||
|
wantEmpty: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single result",
|
||||||
|
results: []MatchResult{
|
||||||
|
{
|
||||||
|
File: "test.go",
|
||||||
|
Location: protocol.Location{Line: 10, Column: 1},
|
||||||
|
Text: "func Hello() {}",
|
||||||
|
Captures: map[string]CapturedNode{
|
||||||
|
"NAME": {Text: "Hello"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
maxResults: 100,
|
||||||
|
wantEmpty: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "truncated results",
|
||||||
|
results: []MatchResult{
|
||||||
|
{File: "a.go", Location: protocol.Location{Line: 1}, Text: "func A() {}"},
|
||||||
|
{File: "b.go", Location: protocol.Location{Line: 1}, Text: "func B() {}"},
|
||||||
|
{File: "c.go", Location: protocol.Location{Line: 1}, Text: "func C() {}"},
|
||||||
|
},
|
||||||
|
maxResults: 2,
|
||||||
|
wantEmpty: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
output := FormatResults(tt.results, tt.maxResults)
|
||||||
|
|
||||||
|
if tt.wantEmpty {
|
||||||
|
if output != "No matches found." {
|
||||||
|
t.Errorf("expected 'No matches found.', got: %s", output)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if output == "No matches found." {
|
||||||
|
t.Error("expected results, got 'No matches found.'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryValidation(t *testing.T) {
|
||||||
|
reg := parser.NewRegistry()
|
||||||
|
defer reg.Close()
|
||||||
|
|
||||||
|
matcher := NewMatcher(reg)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Parse some valid content
|
||||||
|
content := `package main
|
||||||
|
func main() {}
|
||||||
|
`
|
||||||
|
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
query *ASTQuery
|
||||||
|
name string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty pattern",
|
||||||
|
query: &ASTQuery{Pattern: "", Language: "go"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing language",
|
||||||
|
query: &ASTQuery{Pattern: "func $NAME() {}", Language: ""},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown language",
|
||||||
|
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "unknown"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid query",
|
||||||
|
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "go"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestCompileRegexCaching tests that regex compilation is cached.
|
||||||
|
func TestCompileRegexCaching(t *testing.T) {
|
||||||
|
// Clear cache before test
|
||||||
|
regexCache = sync.Map{}
|
||||||
|
|
||||||
|
pattern := `^test_\w+$`
|
||||||
|
|
||||||
|
// First compilation
|
||||||
|
re1, err := compileRegex(pattern)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("First compile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second compilation should return cached version
|
||||||
|
re2, err := compileRegex(pattern)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Second compile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be the exact same object
|
||||||
|
if re1 != re2 {
|
||||||
|
t.Error("Expected cached regex to be reused, got different objects")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's in the cache
|
||||||
|
cached, ok := regexCache.Load(pattern)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Pattern not found in cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cached.(*regexp.Regexp) != re1 {
|
||||||
|
t.Error("Cached regex doesn't match returned regex")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCompileRegexConcurrent tests concurrent regex compilation.
|
||||||
|
func TestCompileRegexConcurrent(t *testing.T) {
|
||||||
|
// Clear cache before test
|
||||||
|
regexCache = sync.Map{}
|
||||||
|
|
||||||
|
pattern := `[a-z]+_\d+`
|
||||||
|
const numGoroutines = 100
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numGoroutines)
|
||||||
|
|
||||||
|
results := make([]*regexp.Regexp, numGoroutines)
|
||||||
|
errors := make(chan error, numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
i := i
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
re, err := compileRegex(pattern)
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
results[i] = re
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errors)
|
||||||
|
|
||||||
|
// Check for errors
|
||||||
|
for err := range errors {
|
||||||
|
t.Errorf("Concurrent compile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All results should be the same object (cached)
|
||||||
|
for i := 1; i < numGoroutines; i++ {
|
||||||
|
if results[i] != results[0] {
|
||||||
|
t.Errorf("Result %d is different from result 0 (cache not working)", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCompileRegexInvalidPattern tests error handling for invalid patterns.
|
||||||
|
func TestCompileRegexInvalidPattern(t *testing.T) {
|
||||||
|
// Clear cache before test
|
||||||
|
regexCache = sync.Map{}
|
||||||
|
|
||||||
|
invalidPattern := `[invalid(`
|
||||||
|
|
||||||
|
_, err := compileRegex(invalidPattern)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid pattern, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid patterns should not be cached
|
||||||
|
_, ok := regexCache.Load(invalidPattern)
|
||||||
|
if ok {
|
||||||
|
t.Error("Invalid pattern should not be cached")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCompileRegexMultiplePatterns tests that different patterns are cached separately.
|
||||||
|
func TestCompileRegexMultiplePatterns(t *testing.T) {
|
||||||
|
// Clear cache before test
|
||||||
|
regexCache = sync.Map{}
|
||||||
|
|
||||||
|
patterns := []string{
|
||||||
|
`^test_\w+$`,
|
||||||
|
`^\d{4}-\d{2}-\d{2}$`,
|
||||||
|
`^[A-Z][a-z]+$`,
|
||||||
|
`\b\w+@\w+\.\w+\b`,
|
||||||
|
}
|
||||||
|
|
||||||
|
compiled := make([]*regexp.Regexp, len(patterns))
|
||||||
|
|
||||||
|
// Compile all patterns
|
||||||
|
for i, pattern := range patterns {
|
||||||
|
re, err := compileRegex(pattern)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Compile failed for pattern %s: %v", pattern, err)
|
||||||
|
}
|
||||||
|
compiled[i] = re
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all are cached
|
||||||
|
for i, pattern := range patterns {
|
||||||
|
cached, ok := regexCache.Load(pattern)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Pattern %s not in cache", pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cached.(*regexp.Regexp) != compiled[i] {
|
||||||
|
t.Errorf("Cached regex for %s doesn't match compiled version", pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All should be different objects
|
||||||
|
for i := 0; i < len(compiled); i++ {
|
||||||
|
for j := i + 1; j < len(compiled); j++ {
|
||||||
|
if compiled[i] == compiled[j] {
|
||||||
|
t.Errorf("Pattern %d and %d have same regex object", i, j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkCompileRegex_Uncached benchmarks regex compilation without caching.
|
||||||
|
func BenchmarkCompileRegex_Uncached(b *testing.B) {
|
||||||
|
pattern := `^\w+_[0-9]{3,5}_[a-zA-Z]+$`
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = regexp.Compile(pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkCompileRegex_Cached benchmarks regex compilation with caching.
|
||||||
|
func BenchmarkCompileRegex_Cached(b *testing.B) {
|
||||||
|
// Clear cache
|
||||||
|
regexCache = sync.Map{}
|
||||||
|
|
||||||
|
pattern := `^\w+_[0-9]{3,5}_[a-zA-Z]+$`
|
||||||
|
|
||||||
|
// Pre-populate cache
|
||||||
|
_, _ = compileRegex(pattern)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = compileRegex(pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkCompileRegex_MixedPatterns benchmarks realistic workload with multiple patterns.
|
||||||
|
func BenchmarkCompileRegex_MixedPatterns(b *testing.B) {
|
||||||
|
// Clear cache
|
||||||
|
regexCache = sync.Map{}
|
||||||
|
|
||||||
|
patterns := []string{
|
||||||
|
`^test_\w+$`,
|
||||||
|
`^\d{4}-\d{2}-\d{2}$`,
|
||||||
|
`^[A-Z][a-z]+$`,
|
||||||
|
`\b\w+@\w+\.\w+\b`,
|
||||||
|
`^func\s+\w+\(`,
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Simulate realistic access pattern
|
||||||
|
pattern := patterns[i%len(patterns)]
|
||||||
|
_, _ = compileRegex(pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,401 @@
|
|||||||
|
// Package search provides text search functionality using ripgrep.
|
||||||
|
package search
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
json "github.com/goccy/go-json"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Searcher provides text search functionality using ripgrep.
|
||||||
|
type Searcher struct {
|
||||||
|
cfg *config.Config
|
||||||
|
logger *slog.Logger
|
||||||
|
rgPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request represents a search request.
|
||||||
|
type Request struct {
|
||||||
|
Pattern string
|
||||||
|
Paths []string
|
||||||
|
FileTypes []string
|
||||||
|
ContextLines int
|
||||||
|
MaxResults int
|
||||||
|
IgnoreCase bool
|
||||||
|
Regex bool
|
||||||
|
IncludeHidden bool
|
||||||
|
FollowSymlinks bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Result represents a single search result.
|
||||||
|
type Result struct {
|
||||||
|
File string `json:"file"`
|
||||||
|
MatchText string `json:"match_text"`
|
||||||
|
Language protocol.Language `json:"language"`
|
||||||
|
Context ContextLines `json:"context"`
|
||||||
|
Line int `json:"line"`
|
||||||
|
Column int `json:"column"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContextLines holds lines before and after a match.
|
||||||
|
type ContextLines struct {
|
||||||
|
Before []string `json:"before"`
|
||||||
|
After []string `json:"after"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchResults holds the complete search results.
|
||||||
|
type SearchResults struct {
|
||||||
|
Results []Result `json:"results"`
|
||||||
|
Truncated bool `json:"truncated"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ripgrep JSON output types
|
||||||
|
type rgMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type rgMatch struct {
|
||||||
|
Path struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"path"`
|
||||||
|
Lines struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"lines"`
|
||||||
|
Submatches []struct {
|
||||||
|
Match struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"match"`
|
||||||
|
Start int `json:"start"`
|
||||||
|
End int `json:"end"`
|
||||||
|
} `json:"submatches"`
|
||||||
|
LineNumber int `json:"line_number"`
|
||||||
|
AbsoluteOffset int `json:"absolute_offset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type rgContext struct {
|
||||||
|
Path struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"path"`
|
||||||
|
Lines struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"lines"`
|
||||||
|
LineNumber int `json:"line_number"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type rgSummary struct {
|
||||||
|
ElapsedTotal struct {
|
||||||
|
Secs int `json:"secs"`
|
||||||
|
Nanos int `json:"nanos"`
|
||||||
|
} `json:"elapsed_total"`
|
||||||
|
Stats struct {
|
||||||
|
Searches int `json:"searches"`
|
||||||
|
SearchesWithMatch int `json:"searches_with_match"`
|
||||||
|
BytesSearched int64 `json:"bytes_searched"`
|
||||||
|
BytesPrinted int64 `json:"bytes_printed"`
|
||||||
|
MatchedLines int `json:"matched_lines"`
|
||||||
|
Matches int `json:"matches"`
|
||||||
|
} `json:"stats"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Searcher instance.
|
||||||
|
func New(cfg *config.Config, logger *slog.Logger) (*Searcher, error) {
|
||||||
|
// Detect ripgrep binary
|
||||||
|
rgPath, err := exec.LookPath("rg")
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.NewRipgrepNotFound()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Searcher{
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
rgPath: rgPath,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search executes a search and returns results.
|
||||||
|
func (s *Searcher) Search(ctx context.Context, req *Request) (*SearchResults, error) {
|
||||||
|
if req.Pattern == "" {
|
||||||
|
return nil, errors.New(errors.ErrInvalidPattern, "pattern cannot be empty").
|
||||||
|
WithRemediation("Provide a non-empty search pattern")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build ripgrep command
|
||||||
|
args := s.buildArgs(req)
|
||||||
|
|
||||||
|
s.logger.Debug("executing ripgrep", "args", args)
|
||||||
|
|
||||||
|
// Create command with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, s.cfg.SearchTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, s.rgPath, args...) // #nosec G204 - rgPath is validated at initialization
|
||||||
|
|
||||||
|
// Set working directory to workspace root
|
||||||
|
cmd.Dir = s.cfg.WorkspaceRoot
|
||||||
|
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
// Run command - ripgrep returns exit code 1 for no matches, which is not an error
|
||||||
|
err := cmd.Run()
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
return nil, errors.NewSearchTimeout(req.Pattern, s.cfg.SearchTimeout.String())
|
||||||
|
}
|
||||||
|
// Exit code 1 means no matches, which is fine
|
||||||
|
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
|
||||||
|
return &SearchResults{Results: []Result{}, Total: 0}, nil
|
||||||
|
}
|
||||||
|
// Exit code 2 means error
|
||||||
|
if stderr.Len() > 0 {
|
||||||
|
return nil, errors.Wrap(errors.ErrSearchFailed, "ripgrep search failed", err).
|
||||||
|
WithContext("pattern", req.Pattern).
|
||||||
|
WithContext("stderr", stderr.String()).
|
||||||
|
WithRemediation("Check search pattern syntax and ensure files are readable")
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(errors.ErrSearchFailed, "ripgrep search failed", err).
|
||||||
|
WithContext("pattern", req.Pattern).
|
||||||
|
WithRemediation("Check search pattern syntax and ensure ripgrep is functioning correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JSON output
|
||||||
|
return s.parseOutput(&stdout, req.MaxResults)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildArgs builds the ripgrep command arguments.
|
||||||
|
func (s *Searcher) buildArgs(req *Request) []string {
|
||||||
|
args := []string{"--json"}
|
||||||
|
|
||||||
|
// Add context lines
|
||||||
|
if req.ContextLines > 0 {
|
||||||
|
args = append(args, fmt.Sprintf("--context=%d", req.ContextLines))
|
||||||
|
}
|
||||||
|
|
||||||
|
// File type filtering
|
||||||
|
for _, ft := range req.FileTypes {
|
||||||
|
args = append(args, "--type", ft)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case sensitivity
|
||||||
|
if req.IgnoreCase {
|
||||||
|
args = append(args, "--ignore-case")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fixed strings (non-regex)
|
||||||
|
if !req.Regex {
|
||||||
|
args = append(args, "--fixed-strings")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Follow symlinks
|
||||||
|
if req.FollowSymlinks || s.cfg.FollowSymlinks {
|
||||||
|
args = append(args, "--follow")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include hidden files
|
||||||
|
if req.IncludeHidden {
|
||||||
|
args = append(args, "--hidden")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Respect .gitignore (default behavior for rg)
|
||||||
|
if !s.cfg.RespectGitignore {
|
||||||
|
args = append(args, "--no-ignore")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Max count per file to limit results
|
||||||
|
if req.MaxResults > 0 {
|
||||||
|
args = append(args, fmt.Sprintf("--max-count=%d", req.MaxResults))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add pattern
|
||||||
|
args = append(args, "--", req.Pattern)
|
||||||
|
|
||||||
|
// Add paths (default to current directory which is workspace root)
|
||||||
|
if len(req.Paths) > 0 {
|
||||||
|
for _, p := range req.Paths {
|
||||||
|
// Validate path is within workspace
|
||||||
|
if s.cfg.IsPathAllowed(p) {
|
||||||
|
args = append(args, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
args = append(args, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOutput parses ripgrep JSON output.
|
||||||
|
func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchResults, error) {
|
||||||
|
results := &SearchResults{
|
||||||
|
Results: []Result{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track context by file and line
|
||||||
|
contextBefore := make(map[string][]string) // file -> lines before current match
|
||||||
|
currentFile := ""
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(output)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg rgMessage
|
||||||
|
if err := json.Unmarshal(line, &msg); err != nil {
|
||||||
|
s.logger.Debug("failed to parse ripgrep output line", "error", err, "line", string(line))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.Type {
|
||||||
|
case "match":
|
||||||
|
var match rgMatch
|
||||||
|
if err := json.Unmarshal(msg.Data, &match); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check max results
|
||||||
|
if maxResults > 0 && len(results.Results) >= maxResults {
|
||||||
|
results.Truncated = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Result{
|
||||||
|
File: match.Path.Text,
|
||||||
|
Line: match.LineNumber,
|
||||||
|
MatchText: strings.TrimRight(match.Lines.Text, "\n\r"),
|
||||||
|
Language: protocol.DetectLanguage(match.Path.Text),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add column from first submatch
|
||||||
|
if len(match.Submatches) > 0 {
|
||||||
|
result.Column = match.Submatches[0].Start + 1 // 1-indexed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add context before
|
||||||
|
if ctx, ok := contextBefore[match.Path.Text]; ok {
|
||||||
|
result.Context.Before = ctx
|
||||||
|
delete(contextBefore, match.Path.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
results.Results = append(results.Results, result)
|
||||||
|
currentFile = match.Path.Text
|
||||||
|
|
||||||
|
case "context":
|
||||||
|
var ctx rgContext
|
||||||
|
if err := json.Unmarshal(msg.Data, &ctx); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lineText := strings.TrimRight(ctx.Lines.Text, "\n\r")
|
||||||
|
|
||||||
|
// Determine if this is before or after context
|
||||||
|
if len(results.Results) > 0 {
|
||||||
|
lastResult := &results.Results[len(results.Results)-1]
|
||||||
|
if lastResult.File == ctx.Path.Text && ctx.LineNumber > lastResult.Line {
|
||||||
|
// This is after context
|
||||||
|
lastResult.Context.After = append(lastResult.Context.After, lineText)
|
||||||
|
} else if ctx.Path.Text == currentFile || currentFile == "" {
|
||||||
|
// This is before context for a potential upcoming match
|
||||||
|
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Before any match - store as potential before context
|
||||||
|
contextBefore[ctx.Path.Text] = append(contextBefore[ctx.Path.Text], lineText)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "summary":
|
||||||
|
var summary rgSummary
|
||||||
|
if err := json.Unmarshal(msg.Data, &summary); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
results.Total = summary.Stats.Matches
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error reading ripgrep output: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatResults formats search results for display.
|
||||||
|
func (s *Searcher) FormatResults(results *SearchResults) string {
|
||||||
|
if len(results.Results) == 0 {
|
||||||
|
return "No matches found."
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// Group results by file
|
||||||
|
fileResults := make(map[string][]Result)
|
||||||
|
var fileOrder []string
|
||||||
|
for _, r := range results.Results {
|
||||||
|
if _, exists := fileResults[r.File]; !exists {
|
||||||
|
fileOrder = append(fileOrder, r.File)
|
||||||
|
}
|
||||||
|
fileResults[r.File] = append(fileResults[r.File], r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write summary
|
||||||
|
totalMatches := len(results.Results)
|
||||||
|
fileCount := len(fileResults)
|
||||||
|
sb.WriteString(fmt.Sprintf("Found %d matches in %d files", totalMatches, fileCount))
|
||||||
|
if results.Truncated {
|
||||||
|
sb.WriteString(fmt.Sprintf(" (truncated, total: %d)", results.Total))
|
||||||
|
}
|
||||||
|
sb.WriteString(":\n\n")
|
||||||
|
|
||||||
|
// Write results grouped by file
|
||||||
|
for _, file := range fileOrder {
|
||||||
|
// Make path relative to workspace root if possible
|
||||||
|
relPath := file
|
||||||
|
if absPath, err := filepath.Abs(file); err == nil {
|
||||||
|
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
|
||||||
|
relPath = rel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(fmt.Sprintf("**%s**\n", relPath))
|
||||||
|
|
||||||
|
for _, r := range fileResults[file] {
|
||||||
|
// Write context before
|
||||||
|
for _, ctx := range r.Context.Before {
|
||||||
|
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write match line
|
||||||
|
sb.WriteString(fmt.Sprintf("L%d│ %s\n", r.Line, truncateLine(r.MatchText, 200)))
|
||||||
|
|
||||||
|
// Write context after
|
||||||
|
for _, ctx := range r.Context.After {
|
||||||
|
sb.WriteString(fmt.Sprintf(" │ %s\n", truncateLine(ctx, 200)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncateLine truncates a line if it exceeds maxLen.
|
||||||
|
func truncateLine(s string, maxLen int) string {
|
||||||
|
if len(s) <= maxLen {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:maxLen-3] + "..."
|
||||||
|
}
|
||||||
@@ -0,0 +1,326 @@
|
|||||||
|
package search
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
cfg := config.Default()
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
|
searcher, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
// ripgrep might not be installed
|
||||||
|
if strings.Contains(err.Error(), "not found") {
|
||||||
|
t.Skip("ripgrep not installed, skipping test")
|
||||||
|
}
|
||||||
|
t.Fatalf("New failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if searcher == nil {
|
||||||
|
t.Fatal("expected non-nil searcher")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildArgs(t *testing.T) {
|
||||||
|
cfg := config.Default()
|
||||||
|
cfg.WorkspaceRoot = "/test/workspace"
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
|
// Create searcher without checking for rg binary
|
||||||
|
s := &Searcher{
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
rgPath: "/usr/bin/rg",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req *Request
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic search",
|
||||||
|
req: &Request{
|
||||||
|
Pattern: "test",
|
||||||
|
ContextLines: 2,
|
||||||
|
Regex: true,
|
||||||
|
},
|
||||||
|
expected: []string{"--json", "--context=2", "--", "test", "."},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ignore case",
|
||||||
|
req: &Request{
|
||||||
|
Pattern: "test",
|
||||||
|
IgnoreCase: true,
|
||||||
|
Regex: true,
|
||||||
|
},
|
||||||
|
expected: []string{"--json", "--ignore-case", "--", "test", "."},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fixed strings",
|
||||||
|
req: &Request{
|
||||||
|
Pattern: "test",
|
||||||
|
Regex: false,
|
||||||
|
},
|
||||||
|
expected: []string{"--json", "--fixed-strings", "--", "test", "."},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with file types",
|
||||||
|
req: &Request{
|
||||||
|
Pattern: "test",
|
||||||
|
FileTypes: []string{"go", "ts"},
|
||||||
|
Regex: true,
|
||||||
|
},
|
||||||
|
expected: []string{"--json", "--type", "go", "--type", "ts", "--", "test", "."},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with max results",
|
||||||
|
req: &Request{
|
||||||
|
Pattern: "test",
|
||||||
|
MaxResults: 10,
|
||||||
|
Regex: true,
|
||||||
|
},
|
||||||
|
expected: []string{"--json", "--max-count=10", "--", "test", "."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
args := s.buildArgs(tt.req)
|
||||||
|
|
||||||
|
// Check that all expected args are present
|
||||||
|
for _, exp := range tt.expected {
|
||||||
|
found := false
|
||||||
|
for _, arg := range args {
|
||||||
|
if arg == exp {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("expected arg %q not found in %v", exp, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatResults(t *testing.T) {
|
||||||
|
cfg := config.Default()
|
||||||
|
cfg.WorkspaceRoot = "/test/workspace"
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
|
s := &Searcher{
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
rgPath: "/usr/bin/rg",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
results *SearchResults
|
||||||
|
contains []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty results",
|
||||||
|
results: &SearchResults{
|
||||||
|
Results: []Result{},
|
||||||
|
},
|
||||||
|
contains: []string{"No matches found"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single result",
|
||||||
|
results: &SearchResults{
|
||||||
|
Results: []Result{
|
||||||
|
{
|
||||||
|
File: "test.go",
|
||||||
|
Line: 10,
|
||||||
|
Column: 5,
|
||||||
|
MatchText: "func TestSomething()",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Total: 1,
|
||||||
|
},
|
||||||
|
contains: []string{"test.go", "L10", "TestSomething"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "truncated results",
|
||||||
|
results: &SearchResults{
|
||||||
|
Results: []Result{
|
||||||
|
{
|
||||||
|
File: "test.go",
|
||||||
|
Line: 10,
|
||||||
|
MatchText: "match",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Truncated: true,
|
||||||
|
Total: 100,
|
||||||
|
},
|
||||||
|
contains: []string{"truncated", "100"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
output := s.FormatResults(tt.results)
|
||||||
|
|
||||||
|
for _, exp := range tt.contains {
|
||||||
|
if !strings.Contains(output, exp) {
|
||||||
|
t.Errorf("expected output to contain %q, got:\n%s", exp, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOutput(t *testing.T) {
|
||||||
|
cfg := config.Default()
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
|
s := &Searcher{
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
rgPath: "/usr/bin/rg",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sample ripgrep JSON output
|
||||||
|
jsonOutput := `{"type":"begin","data":{"path":{"text":"test.go"}}}
|
||||||
|
{"type":"match","data":{"path":{"text":"test.go"},"lines":{"text":"func TestSomething() {\n"},"line_number":10,"absolute_offset":100,"submatches":[{"match":{"text":"Test"},"start":5,"end":9}]}}
|
||||||
|
{"type":"end","data":{"path":{"text":"test.go"},"stats":{"bytes_searched":1000}}}
|
||||||
|
{"type":"summary","data":{"elapsed_total":{"secs":0,"nanos":1000000},"stats":{"searches":1,"searches_with_match":1,"bytes_searched":1000,"bytes_printed":100,"matched_lines":1,"matches":1}}}
|
||||||
|
`
|
||||||
|
buf := bytes.NewBufferString(jsonOutput)
|
||||||
|
|
||||||
|
results, err := s.parseOutput(buf, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseOutput failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results.Results) != 1 {
|
||||||
|
t.Errorf("expected 1 result, got %d", len(results.Results))
|
||||||
|
}
|
||||||
|
|
||||||
|
if results.Results[0].File != "test.go" {
|
||||||
|
t.Errorf("expected file 'test.go', got %q", results.Results[0].File)
|
||||||
|
}
|
||||||
|
|
||||||
|
if results.Results[0].Line != 10 {
|
||||||
|
t.Errorf("expected line 10, got %d", results.Results[0].Line)
|
||||||
|
}
|
||||||
|
|
||||||
|
if results.Results[0].Column != 6 { // 1-indexed
|
||||||
|
t.Errorf("expected column 6, got %d", results.Results[0].Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateLine(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
maxLen int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
input: "short",
|
||||||
|
maxLen: 10,
|
||||||
|
expected: "short",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "this is a very long line that should be truncated",
|
||||||
|
maxLen: 20,
|
||||||
|
expected: "this is a very lo...",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "exact",
|
||||||
|
maxLen: 5,
|
||||||
|
expected: "exact",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := truncateLine(tt.input, tt.maxLen)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("truncateLine(%q, %d) = %q, want %q", tt.input, tt.maxLen, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchIntegration(t *testing.T) {
|
||||||
|
// Create a temporary directory with test files
|
||||||
|
tmpDir, err := os.MkdirTemp("", "mcp-filepuff-search-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
// Create test files
|
||||||
|
testFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
println("Hello, World!")
|
||||||
|
}
|
||||||
|
`
|
||||||
|
err = os.WriteFile(testFile, []byte(content), 0600)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := config.Default()
|
||||||
|
cfg.WorkspaceRoot = tmpDir
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
|
searcher, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("ripgrep not installed, skipping integration test")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := &Request{
|
||||||
|
Pattern: "Hello",
|
||||||
|
ContextLines: 1,
|
||||||
|
Regex: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := searcher.Search(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Search failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results.Results) != 1 {
|
||||||
|
t.Errorf("expected 1 result, got %d", len(results.Results))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results.Results) > 0 && !strings.Contains(results.Results[0].MatchText, "Hello") {
|
||||||
|
t.Errorf("expected match to contain 'Hello', got %q", results.Results[0].MatchText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchEmptyPattern(t *testing.T) {
|
||||||
|
cfg := config.Default()
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
||||||
|
|
||||||
|
s := &Searcher{
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
rgPath: "/usr/bin/rg",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := &Request{
|
||||||
|
Pattern: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := s.Search(ctx, req)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty pattern")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,993 @@
|
|||||||
|
// Package server implements the MCP server for file operations.
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/edit"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/lsp"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/search"
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
"github.com/mark3labs/mcp-go/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server represents the MCP file operations server.
|
||||||
|
type Server struct {
|
||||||
|
cfg *config.Config
|
||||||
|
logger *slog.Logger
|
||||||
|
mcp *server.MCPServer
|
||||||
|
searcher *search.Searcher
|
||||||
|
parser *parser.Registry
|
||||||
|
matcher *query.Matcher
|
||||||
|
lspManager *lsp.Manager
|
||||||
|
editor *edit.Engine
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new MCP server instance.
|
||||||
|
func New(cfg *config.Config, logger *slog.Logger) (*Server, error) {
|
||||||
|
parserRegistry := parser.NewRegistry()
|
||||||
|
s := &Server{
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
parser: parserRegistry,
|
||||||
|
matcher: query.NewMatcher(parserRegistry),
|
||||||
|
editor: edit.NewEngine(parserRegistry),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize searcher
|
||||||
|
searcher, err := search.New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("ripgrep not available, search functionality disabled", "error", err)
|
||||||
|
}
|
||||||
|
s.searcher = searcher
|
||||||
|
|
||||||
|
// Initialize LSP manager if enabled
|
||||||
|
if cfg.EnableLSP {
|
||||||
|
s.lspManager = lsp.NewManager(cfg.WorkspaceRoot, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create MCP server
|
||||||
|
mcpServer := server.NewMCPServer(
|
||||||
|
"mcp-filepuff",
|
||||||
|
"1.0.0",
|
||||||
|
server.WithLogging(),
|
||||||
|
)
|
||||||
|
s.mcp = mcpServer
|
||||||
|
|
||||||
|
// Register tools
|
||||||
|
s.registerTools()
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerTools registers all available tools with the MCP server.
|
||||||
|
func (s *Server) registerTools() {
|
||||||
|
// Register ping tool for health checks
|
||||||
|
s.mcp.AddTool(
|
||||||
|
mcp.NewTool("ping",
|
||||||
|
mcp.WithDescription("Health check - returns pong to verify the server is running"),
|
||||||
|
mcp.WithReadOnlyHintAnnotation(true),
|
||||||
|
),
|
||||||
|
s.handlePing,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Register file_search tool
|
||||||
|
if s.searcher != nil {
|
||||||
|
s.mcp.AddTool(
|
||||||
|
mcp.NewTool("file_search",
|
||||||
|
mcp.WithDescription("Search for text patterns in files using ripgrep. Supports regex patterns, file type filtering, and context lines."),
|
||||||
|
mcp.WithReadOnlyHintAnnotation(true),
|
||||||
|
mcp.WithString("pattern",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("The search pattern (regex by default)"),
|
||||||
|
),
|
||||||
|
mcp.WithArray("paths",
|
||||||
|
mcp.Description("Paths to search in (defaults to workspace root)"),
|
||||||
|
mcp.WithStringItems(),
|
||||||
|
),
|
||||||
|
mcp.WithArray("file_types",
|
||||||
|
mcp.Description("File types to search (e.g., ['go', 'ts', 'py'])"),
|
||||||
|
mcp.WithStringItems(),
|
||||||
|
),
|
||||||
|
mcp.WithBoolean("ignore_case",
|
||||||
|
mcp.Description("Case insensitive search"),
|
||||||
|
),
|
||||||
|
mcp.WithBoolean("regex",
|
||||||
|
mcp.Description("Treat pattern as regex (default: true)"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("context_lines",
|
||||||
|
mcp.Description("Number of context lines around matches (default: 2)"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("max_results",
|
||||||
|
mcp.Description("Maximum number of results to return"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
s.handleFileSearch,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register file_read tool
|
||||||
|
s.mcp.AddTool(
|
||||||
|
mcp.NewTool("file_read",
|
||||||
|
mcp.WithDescription("Read a file's contents with optional line range and AST symbol summary"),
|
||||||
|
mcp.WithReadOnlyHintAnnotation(true),
|
||||||
|
mcp.WithString("path",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Path to the file to read"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("line_start",
|
||||||
|
mcp.Description("Starting line number (1-indexed)"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("line_end",
|
||||||
|
mcp.Description("Ending line number (inclusive)"),
|
||||||
|
),
|
||||||
|
mcp.WithBoolean("include_ast",
|
||||||
|
mcp.Description("Include AST symbol summary (functions, classes, types, etc.)"),
|
||||||
|
),
|
||||||
|
mcp.WithBoolean("symbols_only",
|
||||||
|
mcp.Description("Return only symbol summary without file content (token-efficient mode). Requires include_ast=true."),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("max_lines",
|
||||||
|
mcp.Description("Maximum number of lines to return (for token efficiency). Applied after line_start/line_end."),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
s.handleFileRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Register ast_query tool
|
||||||
|
s.mcp.AddTool(
|
||||||
|
mcp.NewTool("ast_query",
|
||||||
|
mcp.WithDescription("Search for AST patterns in code files. Use code patterns with $VAR placeholders to match and capture code structures like functions, classes, and types."),
|
||||||
|
mcp.WithReadOnlyHintAnnotation(true),
|
||||||
|
mcp.WithString("pattern",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Code pattern with placeholders: $NAME (single), $$$ARGS (multiple), $_ (wildcard). Examples: 'func $NAME($$$ARGS) error', 'class $NAME { $$$BODY }'"),
|
||||||
|
),
|
||||||
|
mcp.WithString("language",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Target language: go, typescript, javascript, python, c, cpp"),
|
||||||
|
),
|
||||||
|
mcp.WithArray("paths",
|
||||||
|
mcp.Description("Paths to search in (defaults to workspace root)"),
|
||||||
|
mcp.WithStringItems(),
|
||||||
|
),
|
||||||
|
mcp.WithString("name_matches",
|
||||||
|
mcp.Description("Regex pattern to filter by name"),
|
||||||
|
),
|
||||||
|
mcp.WithString("name_exact",
|
||||||
|
mcp.Description("Exact name to match"),
|
||||||
|
),
|
||||||
|
mcp.WithArray("kind_in",
|
||||||
|
mcp.Description("Node types to match (e.g., function_declaration, class_declaration)"),
|
||||||
|
mcp.WithStringItems(),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("max_results",
|
||||||
|
mcp.Description("Maximum number of results to return (default: 100)"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
s.handleASTQuery,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Register LSP-based tools if LSP is enabled
|
||||||
|
if s.lspManager != nil {
|
||||||
|
// Register symbol_at tool
|
||||||
|
s.mcp.AddTool(
|
||||||
|
mcp.NewTool("symbol_at",
|
||||||
|
mcp.WithDescription("Get information about the symbol at a specific position in a file. Returns type, documentation, and definition location using LSP when available."),
|
||||||
|
mcp.WithReadOnlyHintAnnotation(true),
|
||||||
|
mcp.WithString("file",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Path to the file"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("line",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Line number (1-indexed)"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("column",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Column number (1-indexed)"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
s.handleSymbolAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Register find_definition tool
|
||||||
|
s.mcp.AddTool(
|
||||||
|
mcp.NewTool("find_definition",
|
||||||
|
mcp.WithDescription("Find the definition of the symbol at a specific position. Uses LSP to locate where a function, variable, type, etc. is defined."),
|
||||||
|
mcp.WithReadOnlyHintAnnotation(true),
|
||||||
|
mcp.WithString("file",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Path to the file"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("line",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Line number (1-indexed)"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("column",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Column number (1-indexed)"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
s.handleFindDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Register find_references tool
|
||||||
|
s.mcp.AddTool(
|
||||||
|
mcp.NewTool("find_references",
|
||||||
|
mcp.WithDescription("Find all references to the symbol at a specific position. Uses LSP to locate all usages of a function, variable, type, etc."),
|
||||||
|
mcp.WithReadOnlyHintAnnotation(true),
|
||||||
|
mcp.WithString("file",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Path to the file"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("line",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Line number (1-indexed)"),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("column",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Column number (1-indexed)"),
|
||||||
|
),
|
||||||
|
mcp.WithBoolean("include_declaration",
|
||||||
|
mcp.Description("Include the declaration in results (default: true)"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
s.handleFindReferences,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register edit tools
|
||||||
|
s.mcp.AddTool(
|
||||||
|
mcp.NewTool("edit_preview",
|
||||||
|
mcp.WithDescription("Preview an edit without applying it. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++), and text-based editing for other files (Markdown, JSON, YAML, config files, etc.)."),
|
||||||
|
mcp.WithString("file",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Path to the file to edit"),
|
||||||
|
),
|
||||||
|
mcp.WithString("operation",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Edit operation: replace, insert_before, insert_after, delete"),
|
||||||
|
),
|
||||||
|
mcp.WithString("new_content",
|
||||||
|
mcp.Description("New content (required for replace/insert operations)"),
|
||||||
|
),
|
||||||
|
// AST-mode selectors (for code files)
|
||||||
|
mcp.WithString("selector_kind",
|
||||||
|
mcp.Description("AST node type to match (e.g., function_declaration, class_declaration). For code files only."),
|
||||||
|
),
|
||||||
|
mcp.WithString("selector_name",
|
||||||
|
mcp.Description("Name of the symbol to match. For code files only."),
|
||||||
|
),
|
||||||
|
// Shared selectors
|
||||||
|
mcp.WithNumber("selector_line",
|
||||||
|
mcp.Description("Line number (1-indexed). For AST mode: narrows search. For text mode: start of line range."),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("selector_index",
|
||||||
|
mcp.Description("Index of the match to use if multiple matches found (default: 0)"),
|
||||||
|
),
|
||||||
|
// Text-mode selectors (for non-code files or explicit text matching)
|
||||||
|
mcp.WithNumber("selector_line_end",
|
||||||
|
mcp.Description("End line number for range selection (text mode). Used with selector_line."),
|
||||||
|
),
|
||||||
|
mcp.WithString("selector_text",
|
||||||
|
mcp.Description("Exact text to match (text mode). Must be unique or use selector_index."),
|
||||||
|
),
|
||||||
|
mcp.WithString("selector_pattern",
|
||||||
|
mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
s.handleEditPreview,
|
||||||
|
)
|
||||||
|
|
||||||
|
s.mcp.AddTool(
|
||||||
|
mcp.NewTool("edit_apply",
|
||||||
|
mcp.WithDescription("Apply an edit to a file. Uses AST-aware editing for code files (Go, TypeScript, JavaScript, Python, C, C++) with syntax validation, and text-based editing for other files (Markdown, JSON, YAML, config files, etc.)."),
|
||||||
|
mcp.WithString("file",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Path to the file to edit"),
|
||||||
|
),
|
||||||
|
mcp.WithString("operation",
|
||||||
|
mcp.Required(),
|
||||||
|
mcp.Description("Edit operation: replace, insert_before, insert_after, delete"),
|
||||||
|
),
|
||||||
|
mcp.WithString("new_content",
|
||||||
|
mcp.Description("New content (required for replace/insert operations)"),
|
||||||
|
),
|
||||||
|
// AST-mode selectors (for code files)
|
||||||
|
mcp.WithString("selector_kind",
|
||||||
|
mcp.Description("AST node type to match (e.g., function_declaration, class_declaration). For code files only."),
|
||||||
|
),
|
||||||
|
mcp.WithString("selector_name",
|
||||||
|
mcp.Description("Name of the symbol to match. For code files only."),
|
||||||
|
),
|
||||||
|
// Shared selectors
|
||||||
|
mcp.WithNumber("selector_line",
|
||||||
|
mcp.Description("Line number (1-indexed). For AST mode: narrows search. For text mode: start of line range."),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("selector_index",
|
||||||
|
mcp.Description("Index of the match to use if multiple matches found (default: 0)"),
|
||||||
|
),
|
||||||
|
// Text-mode selectors (for non-code files or explicit text matching)
|
||||||
|
mcp.WithNumber("selector_line_end",
|
||||||
|
mcp.Description("End line number for range selection (text mode). Used with selector_line."),
|
||||||
|
),
|
||||||
|
mcp.WithString("selector_text",
|
||||||
|
mcp.Description("Exact text to match (text mode). Must be unique or use selector_index."),
|
||||||
|
),
|
||||||
|
mcp.WithString("selector_pattern",
|
||||||
|
mcp.Description("Regex pattern to match (text mode). Must be unique or use selector_index."),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
s.handleEditApply,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlePing handles the ping health check tool.
|
||||||
|
func (s *Server) handlePing(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
return mcp.NewToolResultText("pong"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFileSearch handles the file_search tool.
|
||||||
|
func (s *Server) handleFileSearch(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
start := time.Now()
|
||||||
|
defer func() {
|
||||||
|
s.logger.Debug("file_search completed",
|
||||||
|
"duration_ms", time.Since(start).Milliseconds(),
|
||||||
|
)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if s.searcher == nil {
|
||||||
|
return mcp.NewToolResultError("ripgrep (rg) is not available. Please install it: https://github.com/BurntSushi/ripgrep#installation"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse request arguments using SDK helpers
|
||||||
|
pattern, err := request.RequireString("pattern")
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError("pattern is required"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &search.Request{
|
||||||
|
Pattern: pattern,
|
||||||
|
Paths: request.GetStringSlice("paths", nil),
|
||||||
|
FileTypes: request.GetStringSlice("file_types", nil),
|
||||||
|
IgnoreCase: request.GetBool("ignore_case", false),
|
||||||
|
Regex: request.GetBool("regex", true),
|
||||||
|
ContextLines: request.GetInt("context_lines", 2),
|
||||||
|
MaxResults: request.GetInt("max_results", 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute search
|
||||||
|
results, err := s.searcher.Search(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("search error: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("search completed",
|
||||||
|
"pattern", pattern,
|
||||||
|
"results_count", len(results.Results),
|
||||||
|
"truncated", results.Truncated,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Format results
|
||||||
|
output := s.searcher.FormatResults(results)
|
||||||
|
return mcp.NewToolResultText(output), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFileRead handles the file_read tool.
|
||||||
|
func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
path, err := request.RequireString("path")
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError("path is required"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate path is within workspace
|
||||||
|
if !s.cfg.IsPathAllowed(path) {
|
||||||
|
return mcp.NewToolResultError("path is outside workspace root"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read file
|
||||||
|
content, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil
|
||||||
|
}
|
||||||
|
if os.IsPermission(err) {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil
|
||||||
|
}
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("error reading file: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check file size
|
||||||
|
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle line range
|
||||||
|
lines := splitLines(string(content))
|
||||||
|
lineStart := request.GetInt("line_start", 1)
|
||||||
|
lineEnd := request.GetInt("line_end", len(lines))
|
||||||
|
|
||||||
|
// Clamp to valid range
|
||||||
|
if lineStart < 1 {
|
||||||
|
lineStart = 1
|
||||||
|
}
|
||||||
|
if lineEnd > len(lines) {
|
||||||
|
lineEnd = len(lines)
|
||||||
|
}
|
||||||
|
if lineStart > lineEnd {
|
||||||
|
lineStart = lineEnd
|
||||||
|
}
|
||||||
|
|
||||||
|
var output strings.Builder
|
||||||
|
|
||||||
|
// Include AST summary if requested
|
||||||
|
includeAST := request.GetBool("include_ast", false)
|
||||||
|
symbolsOnly := request.GetBool("symbols_only", false)
|
||||||
|
maxLines := request.GetInt("max_lines", 0)
|
||||||
|
|
||||||
|
// Validate symbols_only requires include_ast
|
||||||
|
if symbolsOnly && !includeAST {
|
||||||
|
return mcp.NewToolResultError("symbols_only requires include_ast=true"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if includeAST {
|
||||||
|
astSummary := s.generateASTSummary(ctx, path, content)
|
||||||
|
if astSummary != "" {
|
||||||
|
output.WriteString(astSummary)
|
||||||
|
if !symbolsOnly {
|
||||||
|
output.WriteString("\n---\n\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip file content if symbols_only mode
|
||||||
|
if !symbolsOnly {
|
||||||
|
// Apply max_lines limit if specified
|
||||||
|
effectiveEnd := lineEnd
|
||||||
|
if maxLines > 0 && (lineEnd-lineStart+1) > maxLines {
|
||||||
|
effectiveEnd = lineStart + maxLines - 1
|
||||||
|
if effectiveEnd < lineEnd {
|
||||||
|
// Add note that output was truncated
|
||||||
|
defer func() {
|
||||||
|
output.WriteString(fmt.Sprintf("\n[... %d more lines omitted for token efficiency. Use line_start/line_end or increase max_lines to see more]\n", lineEnd-effectiveEnd))
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract requested lines
|
||||||
|
for i := lineStart - 1; i < effectiveEnd && i < len(lines); i++ {
|
||||||
|
output.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, lines[i]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mcp.NewToolResultText(output.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateASTSummary generates a summary of symbols in the file.
|
||||||
|
func (s *Server) generateASTSummary(ctx context.Context, path string, content []byte) string {
|
||||||
|
// Parse the file
|
||||||
|
result, err := s.parser.Parse(ctx, path, content)
|
||||||
|
if err != nil {
|
||||||
|
return "" // Silently skip AST if parsing fails
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract symbols
|
||||||
|
lang := protocol.DetectLanguage(path)
|
||||||
|
symbols := parser.ExtractSymbols(result.Tree, content, lang, path)
|
||||||
|
if len(symbols) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// Get relative path
|
||||||
|
relPath := path
|
||||||
|
if absPath, err := filepath.Abs(path); err == nil {
|
||||||
|
if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil && !strings.HasPrefix(rel, "..") {
|
||||||
|
relPath = rel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(fmt.Sprintf("**%s** (%d lines, %s)\n\n", relPath, len(splitLines(string(content))), lang))
|
||||||
|
sb.WriteString("Symbols:\n")
|
||||||
|
|
||||||
|
for _, sym := range symbols {
|
||||||
|
kindStr := symbolKindIcon(sym.Kind)
|
||||||
|
sb.WriteString(fmt.Sprintf(" %s %s L%d\n", kindStr, sym.Name, sym.Location.Line))
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// symbolKindIcon returns an icon/prefix for a symbol kind.
|
||||||
|
func symbolKindIcon(kind protocol.SymbolKind) string {
|
||||||
|
switch kind {
|
||||||
|
case protocol.SymbolFunction:
|
||||||
|
return "func"
|
||||||
|
case protocol.SymbolMethod:
|
||||||
|
return "meth"
|
||||||
|
case protocol.SymbolClass:
|
||||||
|
return "class"
|
||||||
|
case protocol.SymbolStruct:
|
||||||
|
return "struct"
|
||||||
|
case protocol.SymbolInterface:
|
||||||
|
return "iface"
|
||||||
|
case protocol.SymbolVariable:
|
||||||
|
return "var"
|
||||||
|
case protocol.SymbolConstant:
|
||||||
|
return "const"
|
||||||
|
case protocol.SymbolType:
|
||||||
|
return "type"
|
||||||
|
case protocol.SymbolField:
|
||||||
|
return "field"
|
||||||
|
case protocol.SymbolProperty:
|
||||||
|
return "prop"
|
||||||
|
case protocol.SymbolModule:
|
||||||
|
return "mod"
|
||||||
|
case protocol.SymbolPackage:
|
||||||
|
return "pkg"
|
||||||
|
default:
|
||||||
|
return "sym"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitLines(s string) []string {
|
||||||
|
// Use optimized stdlib implementation (2-3x faster than manual loop)
|
||||||
|
return strings.Split(s, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleASTQuery handles the ast_query tool.
|
||||||
|
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
pattern, err := request.RequireString("pattern")
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError("pattern is required"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
language, err := request.RequireString("language")
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError("language is required"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build query
|
||||||
|
astQuery := &query.ASTQuery{
|
||||||
|
Pattern: pattern,
|
||||||
|
Language: language,
|
||||||
|
Filters: query.QueryFilters{
|
||||||
|
NameMatches: request.GetString("name_matches", ""),
|
||||||
|
NameExact: request.GetString("name_exact", ""),
|
||||||
|
KindIn: request.GetStringSlice("kind_in", nil),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
maxResults := request.GetInt("max_results", 100)
|
||||||
|
paths := request.GetStringSlice("paths", nil)
|
||||||
|
|
||||||
|
// Default to workspace root if no paths specified
|
||||||
|
if len(paths) == 0 {
|
||||||
|
paths = []string{s.cfg.WorkspaceRoot}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find files to search based on language
|
||||||
|
ext := languageToExtension(language)
|
||||||
|
if ext == "" {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var allResults []query.MatchResult
|
||||||
|
|
||||||
|
// Walk through paths and find matching files
|
||||||
|
for _, searchPath := range paths {
|
||||||
|
// Validate path is within workspace
|
||||||
|
if !s.cfg.IsPathAllowed(searchPath) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return nil // Skip files with errors
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
// Skip hidden directories
|
||||||
|
if strings.HasPrefix(info.Name(), ".") {
|
||||||
|
return filepath.SkipDir
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check file extension matches language
|
||||||
|
if !strings.HasSuffix(path, ext) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and parse file
|
||||||
|
content, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil // Skip unreadable files
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check file size
|
||||||
|
if int64(len(content)) > s.cfg.MaxFileSize {
|
||||||
|
return nil // Skip large files
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse file
|
||||||
|
result, err := s.parser.Parse(ctx, path, content)
|
||||||
|
if err != nil {
|
||||||
|
return nil // Skip unparseable files
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run query
|
||||||
|
matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
|
||||||
|
if err != nil {
|
||||||
|
return nil // Skip on error
|
||||||
|
}
|
||||||
|
|
||||||
|
allResults = append(allResults, matches...)
|
||||||
|
|
||||||
|
// Stop if we have enough results
|
||||||
|
if maxResults > 0 && len(allResults) >= maxResults {
|
||||||
|
return filepath.SkipAll
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("error walking path", "path", searchPath, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format and return results
|
||||||
|
output := query.FormatResults(allResults, maxResults)
|
||||||
|
return mcp.NewToolResultText(output), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// languageToExtension maps language names to file extensions.
|
||||||
|
func languageToExtension(language string) string {
|
||||||
|
switch strings.ToLower(language) {
|
||||||
|
case "go":
|
||||||
|
return ".go"
|
||||||
|
case "typescript":
|
||||||
|
return ".ts"
|
||||||
|
case "javascript":
|
||||||
|
return ".js"
|
||||||
|
case "python":
|
||||||
|
return ".py"
|
||||||
|
case "c":
|
||||||
|
return ".c"
|
||||||
|
case "cpp", "c++":
|
||||||
|
return ".cpp"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSymbolAt handles the symbol_at tool.
|
||||||
|
func (s *Server) handleSymbolAt(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
file, err := request.RequireString("file")
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError("file is required"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line := request.GetInt("line", 0)
|
||||||
|
if line <= 0 {
|
||||||
|
return mcp.NewToolResultError("line must be positive"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
col := request.GetInt("column", 0)
|
||||||
|
if col <= 0 {
|
||||||
|
return mcp.NewToolResultError("column must be positive"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate path
|
||||||
|
if !s.cfg.IsPathAllowed(file) {
|
||||||
|
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try LSP hover
|
||||||
|
hover, err := s.lspManager.Hover(ctx, file, line, col)
|
||||||
|
if err != nil {
|
||||||
|
// Fall back to AST-based info
|
||||||
|
return s.handleSymbolAtFallback(ctx, file, line, col)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hover == nil {
|
||||||
|
return mcp.NewToolResultText("No symbol information available at this position."), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var output strings.Builder
|
||||||
|
output.WriteString("**Symbol Information**\n\n")
|
||||||
|
output.WriteString(hover.Contents.Value)
|
||||||
|
|
||||||
|
return mcp.NewToolResultText(output.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSymbolAtFallback provides AST-based symbol info when LSP is unavailable.
|
||||||
|
func (s *Server) handleSymbolAtFallback(ctx context.Context, file string, line, col int) (*mcp.CallToolResult, error) {
|
||||||
|
content, err := os.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("failed to read file: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := s.parser.Parse(ctx, file, content)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("failed to parse file: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
node := parser.FindNodeAtPosition(result.Tree, line, col)
|
||||||
|
if node == nil {
|
||||||
|
return mcp.NewToolResultText("No symbol at this position."), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var output strings.Builder
|
||||||
|
output.WriteString("**Symbol Information** (AST fallback)\n\n")
|
||||||
|
output.WriteString(fmt.Sprintf("Node type: `%s`\n", node.Type()))
|
||||||
|
output.WriteString(fmt.Sprintf("Text: `%s`\n", parser.GetNodeText(node, content)))
|
||||||
|
|
||||||
|
return mcp.NewToolResultText(output.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFindDefinition handles the find_definition tool.
|
||||||
|
func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
file, err := request.RequireString("file")
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError("file is required"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line := request.GetInt("line", 0)
|
||||||
|
if line <= 0 {
|
||||||
|
return mcp.NewToolResultError("line must be positive"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
col := request.GetInt("column", 0)
|
||||||
|
if col <= 0 {
|
||||||
|
return mcp.NewToolResultError("column must be positive"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate path
|
||||||
|
if !s.cfg.IsPathAllowed(file) {
|
||||||
|
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
locations, err := s.lspManager.Definition(ctx, file, line, col)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("definition lookup failed: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(locations) == 0 {
|
||||||
|
return mcp.NewToolResultText("No definition found."), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var output strings.Builder
|
||||||
|
output.WriteString(fmt.Sprintf("Found %d definition(s):\n\n", len(locations)))
|
||||||
|
|
||||||
|
for _, loc := range locations {
|
||||||
|
filePath := lsp.URIToFile(loc.URI)
|
||||||
|
output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||||
|
|
||||||
|
// Try to read a preview snippet
|
||||||
|
preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3)
|
||||||
|
if preview != "" {
|
||||||
|
output.WriteString("```\n")
|
||||||
|
output.WriteString(preview)
|
||||||
|
output.WriteString("```\n")
|
||||||
|
}
|
||||||
|
output.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return mcp.NewToolResultText(output.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFindReferences handles the find_references tool.
|
||||||
|
func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
file, err := request.RequireString("file")
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError("file is required"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line := request.GetInt("line", 0)
|
||||||
|
if line <= 0 {
|
||||||
|
return mcp.NewToolResultError("line must be positive"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
col := request.GetInt("column", 0)
|
||||||
|
if col <= 0 {
|
||||||
|
return mcp.NewToolResultError("column must be positive"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
includeDecl := request.GetBool("include_declaration", true)
|
||||||
|
|
||||||
|
// Validate path
|
||||||
|
if !s.cfg.IsPathAllowed(file) {
|
||||||
|
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
locations, err := s.lspManager.References(ctx, file, line, col, includeDecl)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("references lookup failed: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(locations) == 0 {
|
||||||
|
return mcp.NewToolResultText("No references found."), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var output strings.Builder
|
||||||
|
output.WriteString(fmt.Sprintf("Found %d reference(s):\n\n", len(locations)))
|
||||||
|
|
||||||
|
// Group by file
|
||||||
|
fileGroups := make(map[string][]lsp.Location)
|
||||||
|
for _, loc := range locations {
|
||||||
|
filePath := lsp.URIToFile(loc.URI)
|
||||||
|
fileGroups[filePath] = append(fileGroups[filePath], loc)
|
||||||
|
}
|
||||||
|
|
||||||
|
for filePath, locs := range fileGroups {
|
||||||
|
output.WriteString(fmt.Sprintf("**%s** (%d)\n", filePath, len(locs)))
|
||||||
|
for _, loc := range locs {
|
||||||
|
output.WriteString(fmt.Sprintf(" L%d:%d\n", loc.Range.Start.Line+1, loc.Range.Start.Character+1))
|
||||||
|
}
|
||||||
|
output.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return mcp.NewToolResultText(output.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readFilePreview reads a few lines from a file around the given line.
|
||||||
|
func readFilePreview(file string, line, contextLines int) string {
|
||||||
|
content, err := os.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := splitLines(string(content))
|
||||||
|
startLine := max(1, line-contextLines)
|
||||||
|
endLine := min(line+contextLines, len(lines))
|
||||||
|
|
||||||
|
var preview strings.Builder
|
||||||
|
for i := startLine - 1; i < endLine && i < len(lines); i++ {
|
||||||
|
lineText := lines[i]
|
||||||
|
if len(lineText) > 100 {
|
||||||
|
lineText = lineText[:100] + "..."
|
||||||
|
}
|
||||||
|
prefix := " "
|
||||||
|
if i+1 == line {
|
||||||
|
prefix = "> "
|
||||||
|
}
|
||||||
|
preview.WriteString(fmt.Sprintf("%s%4d: %s\n", prefix, i+1, lineText))
|
||||||
|
}
|
||||||
|
|
||||||
|
return preview.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEditPreview handles the edit_preview tool.
|
||||||
|
func (s *Server) handleEditPreview(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
return s.handleEdit(ctx, request, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEditApply handles the edit_apply tool.
|
||||||
|
func (s *Server) handleEditApply(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
return s.handleEdit(ctx, request, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEdit is the shared implementation for edit_preview and edit_apply.
|
||||||
|
func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, apply bool) (*mcp.CallToolResult, error) {
|
||||||
|
file, err := request.RequireString("file")
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError("file is required"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
operation, err := request.RequireString("operation")
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError("operation is required"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate path
|
||||||
|
if !s.cfg.IsPathAllowed(file) {
|
||||||
|
return mcp.NewToolResultError("file is outside workspace root"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: We no longer validate language support here.
|
||||||
|
// The edit engine automatically detects whether to use AST or text mode.
|
||||||
|
|
||||||
|
// Build edit request with both AST and text-mode selectors
|
||||||
|
astEdit := &edit.ASTEdit{
|
||||||
|
File: file,
|
||||||
|
Operation: edit.EditOperation(operation),
|
||||||
|
NewContent: request.GetString("new_content", ""),
|
||||||
|
Selector: edit.ASTSelector{
|
||||||
|
// AST-mode selectors
|
||||||
|
Kind: request.GetString("selector_kind", ""),
|
||||||
|
Name: request.GetString("selector_name", ""),
|
||||||
|
AtLine: request.GetInt("selector_line", 0),
|
||||||
|
Index: request.GetInt("selector_index", 0),
|
||||||
|
// Text-mode selectors
|
||||||
|
LineEnd: request.GetInt("selector_line_end", 0),
|
||||||
|
Text: request.GetString("selector_text", ""),
|
||||||
|
TextPattern: request.GetString("selector_pattern", ""),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform edit
|
||||||
|
var result *edit.EditResult
|
||||||
|
if apply {
|
||||||
|
result, err = s.editor.Apply(ctx, astEdit)
|
||||||
|
} else {
|
||||||
|
result, err = s.editor.Preview(ctx, astEdit)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("edit failed: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Success {
|
||||||
|
return mcp.NewToolResultError(result.Error), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format output
|
||||||
|
var output strings.Builder
|
||||||
|
if apply {
|
||||||
|
output.WriteString("**Edit Applied Successfully**\n\n")
|
||||||
|
} else {
|
||||||
|
output.WriteString("**Edit Preview**\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
output.WriteString("Diff:\n```diff\n")
|
||||||
|
output.WriteString(result.Diff)
|
||||||
|
output.WriteString("```\n")
|
||||||
|
|
||||||
|
return mcp.NewToolResultText(output.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the MCP server and blocks until shutdown.
|
||||||
|
func (s *Server) Run(ctx context.Context) error {
|
||||||
|
// Set up signal handling for graceful shutdown
|
||||||
|
_, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
sig := <-sigChan
|
||||||
|
s.logger.Info("received shutdown signal", "signal", sig)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.logger.Info("starting MCP server",
|
||||||
|
"workspace", s.cfg.WorkspaceRoot,
|
||||||
|
"lsp_enabled", s.cfg.EnableLSP,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Start the MCP server with stdio transport
|
||||||
|
return server.ServeStdio(s.mcp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully shuts down the server.
|
||||||
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
|
s.logger.Info("shutting down MCP server")
|
||||||
|
|
||||||
|
// Close LSP manager
|
||||||
|
if s.lspManager != nil {
|
||||||
|
_ = s.lspManager.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close parser registry
|
||||||
|
if s.parser != nil {
|
||||||
|
s.parser.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,377 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lukaszraczylo/mcp-filepuff/internal/config"
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
// Create temp directory for testing
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
WorkspaceRoot: tmpDir,
|
||||||
|
EnableLSP: false, // Disable LSP for simpler testing
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||||
|
|
||||||
|
srv, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv == nil {
|
||||||
|
t.Fatal("New() returned nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.cfg != cfg {
|
||||||
|
t.Error("server config mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.parser == nil {
|
||||||
|
t.Error("parser should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.matcher == nil {
|
||||||
|
t.Error("matcher should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.editor == nil {
|
||||||
|
t.Error("editor should not be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlePing(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||||
|
|
||||||
|
srv, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := mcp.CallToolRequest{}
|
||||||
|
|
||||||
|
result, err := srv.handlePing(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("handlePing() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("handlePing() returned nil result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the result contains "pong"
|
||||||
|
contents := result.Content
|
||||||
|
if len(contents) == 0 {
|
||||||
|
t.Fatal("handlePing() returned empty content")
|
||||||
|
}
|
||||||
|
|
||||||
|
textContent, ok := contents[0].(mcp.TextContent)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("handlePing() did not return text content")
|
||||||
|
}
|
||||||
|
|
||||||
|
if textContent.Text != "pong" {
|
||||||
|
t.Errorf("handlePing() = %v, want 'pong'", textContent.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleFileRead(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create a test file
|
||||||
|
testFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
// Hello says hello
|
||||||
|
func Hello() {
|
||||||
|
println("Hello, World!")
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||||
|
|
||||||
|
srv, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := mcp.CallToolRequest{}
|
||||||
|
req.Params.Arguments = map[string]interface{}{
|
||||||
|
"path": testFile,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := srv.handleFileRead(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("handleFileRead() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("handleFileRead() returned nil result")
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := result.Content
|
||||||
|
if len(contents) == 0 {
|
||||||
|
t.Fatal("handleFileRead() returned empty content")
|
||||||
|
}
|
||||||
|
|
||||||
|
textContent, ok := contents[0].(mcp.TextContent)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("handleFileRead() did not return text content")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should contain the file content
|
||||||
|
if textContent.Text == "" {
|
||||||
|
t.Error("handleFileRead() returned empty text")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleFileReadWithAST(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create a test file
|
||||||
|
testFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
// Hello says hello
|
||||||
|
func Hello() {
|
||||||
|
println("Hello, World!")
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||||
|
|
||||||
|
srv, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := mcp.CallToolRequest{}
|
||||||
|
req.Params.Arguments = map[string]interface{}{
|
||||||
|
"path": testFile,
|
||||||
|
"include_ast": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := srv.handleFileRead(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("handleFileRead() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("handleFileRead() returned nil result")
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := result.Content
|
||||||
|
if len(contents) == 0 {
|
||||||
|
t.Fatal("handleFileRead() returned empty content")
|
||||||
|
}
|
||||||
|
|
||||||
|
textContent, ok := contents[0].(mcp.TextContent)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("handleFileRead() did not return text content")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should contain "Symbols:" section when include_ast is true
|
||||||
|
if textContent.Text == "" {
|
||||||
|
t.Error("handleFileRead() returned empty text")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleFileReadNotFound(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||||
|
|
||||||
|
srv, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := mcp.CallToolRequest{}
|
||||||
|
req.Params.Arguments = map[string]interface{}{
|
||||||
|
"path": filepath.Join(tmpDir, "nonexistent.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := srv.handleFileRead(ctx, req)
|
||||||
|
// Should return error for non-existent file
|
||||||
|
if err == nil && result != nil {
|
||||||
|
// Check if result indicates an error
|
||||||
|
contents := result.Content
|
||||||
|
if len(contents) > 0 {
|
||||||
|
textContent, ok := contents[0].(mcp.TextContent)
|
||||||
|
if ok && textContent.Text == "" {
|
||||||
|
t.Log("handleFileRead() returned empty text for non-existent file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleASTQuery(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create a test file
|
||||||
|
testFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
func Hello() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Goodbye() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||||
|
|
||||||
|
srv, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := mcp.CallToolRequest{}
|
||||||
|
req.Params.Arguments = map[string]interface{}{
|
||||||
|
"pattern": "func $NAME() error",
|
||||||
|
"language": "go",
|
||||||
|
"paths": []interface{}{tmpDir},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := srv.handleASTQuery(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("handleASTQuery() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("handleASTQuery() returned nil result")
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := result.Content
|
||||||
|
if len(contents) == 0 {
|
||||||
|
t.Fatal("handleASTQuery() returned empty content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEditPreview(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create a test file
|
||||||
|
testFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
func Hello() {
|
||||||
|
println("Hello")
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||||
|
|
||||||
|
srv, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := mcp.CallToolRequest{}
|
||||||
|
req.Params.Arguments = map[string]interface{}{
|
||||||
|
"file": testFile,
|
||||||
|
"operation": "replace",
|
||||||
|
"selector_kind": "function_declaration",
|
||||||
|
"selector_name": "Hello",
|
||||||
|
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := srv.handleEdit(ctx, req, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("handleEdit(preview) error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("handleEdit(preview) returned nil result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file was NOT modified (it's just a preview)
|
||||||
|
fileContent, _ := os.ReadFile(testFile)
|
||||||
|
if string(fileContent) != content {
|
||||||
|
t.Error("handleEdit(preview) should not modify the file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleEditApply(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create a test file
|
||||||
|
testFile := filepath.Join(tmpDir, "test.go")
|
||||||
|
content := `package main
|
||||||
|
|
||||||
|
func Hello() {
|
||||||
|
println("Hello")
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(testFile, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{WorkspaceRoot: tmpDir, EnableLSP: false}
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||||
|
|
||||||
|
srv, err := New(cfg, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := mcp.CallToolRequest{}
|
||||||
|
req.Params.Arguments = map[string]interface{}{
|
||||||
|
"file": testFile,
|
||||||
|
"operation": "replace",
|
||||||
|
"selector_kind": "function_declaration",
|
||||||
|
"selector_name": "Hello",
|
||||||
|
"new_content": "func Hello() {\n\tprintln(\"Goodbye\")\n}",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := srv.handleEdit(ctx, req, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("handleEdit(apply) error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("handleEdit(apply) returned nil result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file WAS modified
|
||||||
|
fileContent, _ := os.ReadFile(testFile)
|
||||||
|
if string(fileContent) == content {
|
||||||
|
t.Error("handleEdit(apply) should modify the file")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,289 @@
|
|||||||
|
// Package errors provides structured error handling with error codes and context.
|
||||||
|
package errors
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorCode represents a specific error condition.
|
||||||
|
type ErrorCode int
|
||||||
|
|
||||||
|
// Error codes organized by category
|
||||||
|
const (
|
||||||
|
// Search errors (1000-1099)
|
||||||
|
ErrRipgrepNotFound ErrorCode = 1001
|
||||||
|
ErrRipgrepTimeout ErrorCode = 1002
|
||||||
|
ErrInvalidPattern ErrorCode = 1003
|
||||||
|
ErrSearchFailed ErrorCode = 1004
|
||||||
|
ErrNoResults ErrorCode = 1005
|
||||||
|
|
||||||
|
// Parser errors (1100-1199)
|
||||||
|
ErrParserNotFound ErrorCode = 1101
|
||||||
|
ErrParseFailed ErrorCode = 1102
|
||||||
|
ErrInvalidLanguage ErrorCode = 1103
|
||||||
|
ErrFileTooBig ErrorCode = 1104
|
||||||
|
ErrInvalidSyntax ErrorCode = 1105
|
||||||
|
|
||||||
|
// LSP errors (1200-1299)
|
||||||
|
ErrLSPServerNotFound ErrorCode = 1201
|
||||||
|
ErrLSPInitFailed ErrorCode = 1202
|
||||||
|
ErrLSPTimeout ErrorCode = 1203
|
||||||
|
ErrLSPCommunication ErrorCode = 1204
|
||||||
|
ErrNoHoverInfo ErrorCode = 1205
|
||||||
|
ErrNoDefinition ErrorCode = 1206
|
||||||
|
ErrNoReferences ErrorCode = 1207
|
||||||
|
|
||||||
|
// Edit errors (1300-1399)
|
||||||
|
ErrEditFailed ErrorCode = 1301
|
||||||
|
ErrInvalidEdit ErrorCode = 1302
|
||||||
|
ErrFileNotFound ErrorCode = 1303
|
||||||
|
ErrFileNotReadable ErrorCode = 1304
|
||||||
|
ErrFileNotWritable ErrorCode = 1305
|
||||||
|
ErrNodeNotFound ErrorCode = 1306
|
||||||
|
ErrValidationFailed ErrorCode = 1307
|
||||||
|
ErrInvalidSelection ErrorCode = 1308
|
||||||
|
|
||||||
|
// Query errors (1400-1499)
|
||||||
|
ErrInvalidQuery ErrorCode = 1401
|
||||||
|
ErrQueryTimeout ErrorCode = 1402
|
||||||
|
ErrNoMatches ErrorCode = 1403
|
||||||
|
ErrQueryCompile ErrorCode = 1404
|
||||||
|
|
||||||
|
// Config errors (1500-1599)
|
||||||
|
ErrInvalidConfig ErrorCode = 1501
|
||||||
|
ErrPathNotAllowed ErrorCode = 1502
|
||||||
|
ErrWorkspaceNotSet ErrorCode = 1503
|
||||||
|
|
||||||
|
// Internal errors (1900-1999)
|
||||||
|
ErrInternal ErrorCode = 1900
|
||||||
|
ErrCacheFailed ErrorCode = 1901
|
||||||
|
ErrConcurrency ErrorCode = 1902
|
||||||
|
)
|
||||||
|
|
||||||
|
// StructuredError represents an error with rich context and remediation info.
|
||||||
|
type StructuredError struct {
|
||||||
|
Cause error
|
||||||
|
Context map[string]any
|
||||||
|
Message string
|
||||||
|
Remediation string
|
||||||
|
Stack string
|
||||||
|
Code ErrorCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface.
|
||||||
|
func (e *StructuredError) Error() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// Error code and message
|
||||||
|
sb.WriteString(fmt.Sprintf("[%d] %s", e.Code, e.Message))
|
||||||
|
|
||||||
|
// Context if available
|
||||||
|
if len(e.Context) > 0 {
|
||||||
|
sb.WriteString("\nContext:")
|
||||||
|
for k, v := range e.Context {
|
||||||
|
sb.WriteString(fmt.Sprintf("\n %s: %v", k, v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remediation if available
|
||||||
|
if e.Remediation != "" {
|
||||||
|
sb.WriteString(fmt.Sprintf("\nHow to fix: %s", e.Remediation))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Underlying cause if available
|
||||||
|
if e.Cause != nil {
|
||||||
|
sb.WriteString(fmt.Sprintf("\nCaused by: %v", e.Cause))
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the underlying cause for error chain support.
|
||||||
|
func (e *StructuredError) Unwrap() error {
|
||||||
|
return e.Cause
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContext adds context to the error.
|
||||||
|
func (e *StructuredError) WithContext(key string, value any) *StructuredError {
|
||||||
|
if e.Context == nil {
|
||||||
|
e.Context = make(map[string]any)
|
||||||
|
}
|
||||||
|
e.Context[key] = value
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRemediation sets the remediation message.
|
||||||
|
func (e *StructuredError) WithRemediation(msg string) *StructuredError {
|
||||||
|
e.Remediation = msg
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new structured error with stack trace.
|
||||||
|
func New(code ErrorCode, message string) *StructuredError {
|
||||||
|
return &StructuredError{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
Context: make(map[string]interface{}),
|
||||||
|
Stack: captureStack(2),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap wraps an existing error with structured error information.
|
||||||
|
func Wrap(code ErrorCode, message string, cause error) *StructuredError {
|
||||||
|
return &StructuredError{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
Context: make(map[string]interface{}),
|
||||||
|
Cause: cause,
|
||||||
|
Stack: captureStack(2),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is checks if an error matches the given error code.
|
||||||
|
func Is(err error, code ErrorCode) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if se, ok := err.(*StructuredError); ok {
|
||||||
|
return se.Code == code
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCode extracts the error code from an error, or returns 0 if not a structured error.
|
||||||
|
func GetCode(err error) ErrorCode {
|
||||||
|
if err == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if se, ok := err.(*StructuredError); ok {
|
||||||
|
return se.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureStack captures the stack trace.
|
||||||
|
func captureStack(skip int) string {
|
||||||
|
const depth = 16
|
||||||
|
var pcs [depth]uintptr
|
||||||
|
n := runtime.Callers(skip+1, pcs[:])
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
frames := runtime.CallersFrames(pcs[:n])
|
||||||
|
|
||||||
|
for {
|
||||||
|
frame, more := frames.Next()
|
||||||
|
if !strings.Contains(frame.File, "runtime/") {
|
||||||
|
sb.WriteString(fmt.Sprintf("\n %s:%d %s", frame.File, frame.Line, frame.Function))
|
||||||
|
}
|
||||||
|
if !more {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common error constructors for convenience
|
||||||
|
|
||||||
|
// NewRipgrepNotFound creates an error for missing ripgrep binary.
|
||||||
|
func NewRipgrepNotFound() *StructuredError {
|
||||||
|
os := runtime.GOOS
|
||||||
|
install := "brew install ripgrep"
|
||||||
|
|
||||||
|
switch os {
|
||||||
|
case "linux":
|
||||||
|
install = "apt-get install ripgrep (Debian/Ubuntu) or yum install ripgrep (RHEL/CentOS)"
|
||||||
|
case "windows":
|
||||||
|
install = "choco install ripgrep or scoop install ripgrep"
|
||||||
|
}
|
||||||
|
|
||||||
|
return New(ErrRipgrepNotFound, "ripgrep (rg) binary not found in system PATH").
|
||||||
|
WithContext("os", os).
|
||||||
|
WithRemediation(fmt.Sprintf("Install ripgrep: %s", install))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLSPServerNotFound creates an error for missing LSP server.
|
||||||
|
func NewLSPServerNotFound(language, serverName string) *StructuredError {
|
||||||
|
return New(ErrLSPServerNotFound, fmt.Sprintf("LSP server '%s' not found for language %s", serverName, language)).
|
||||||
|
WithContext("language", language).
|
||||||
|
WithContext("server", serverName).
|
||||||
|
WithRemediation(fmt.Sprintf("Install the %s LSP server to enable IDE features for %s", serverName, language))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileTooLarge creates an error for files exceeding size limit.
|
||||||
|
func NewFileTooLarge(path string, size, limit int64) *StructuredError {
|
||||||
|
return New(ErrFileTooBig, "file exceeds maximum size limit").
|
||||||
|
WithContext("file", path).
|
||||||
|
WithContext("size_bytes", size).
|
||||||
|
WithContext("limit_bytes", limit).
|
||||||
|
WithRemediation(fmt.Sprintf("File size (%d bytes) exceeds limit (%d bytes). Consider processing smaller files or increasing the limit.", size, limit))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewParseError creates an error for parsing failures.
|
||||||
|
func NewParseError(language, file string, cause error) *StructuredError {
|
||||||
|
return Wrap(ErrParseFailed, fmt.Sprintf("failed to parse %s file", language), cause).
|
||||||
|
WithContext("language", language).
|
||||||
|
WithContext("file", file).
|
||||||
|
WithRemediation("Check file syntax and ensure it's valid source code for the specified language")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSearchTimeout creates an error for search timeouts.
|
||||||
|
func NewSearchTimeout(pattern string, duration string) *StructuredError {
|
||||||
|
return New(ErrRipgrepTimeout, "search operation timed out").
|
||||||
|
WithContext("pattern", pattern).
|
||||||
|
WithContext("duration", duration).
|
||||||
|
WithRemediation("Try narrowing the search scope, using more specific patterns, or increasing the timeout limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEditValidationError creates an error for edit validation failures.
|
||||||
|
func NewEditValidationError(file string, cause error) *StructuredError {
|
||||||
|
return Wrap(ErrValidationFailed, "edit validation failed - syntax errors detected", cause).
|
||||||
|
WithContext("file", file).
|
||||||
|
WithRemediation("Review the edit operation and ensure it produces valid syntax. The file was not modified.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileNotFoundError creates an error for missing files.
|
||||||
|
func NewFileNotFoundError(file string) *StructuredError {
|
||||||
|
return New(ErrFileNotFound, fmt.Sprintf("file not found: %s", file)).
|
||||||
|
WithContext("file", file).
|
||||||
|
WithRemediation("Verify the file path is correct and the file exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileNotReadableError creates an error for unreadable files.
|
||||||
|
func NewFileNotReadableError(file string, cause error) *StructuredError {
|
||||||
|
return Wrap(ErrFileNotReadable, fmt.Sprintf("cannot read file: %s", file), cause).
|
||||||
|
WithContext("file", file).
|
||||||
|
WithRemediation("Check file permissions and ensure the file is not locked by another process")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileNotWritableError creates an error for write failures.
|
||||||
|
func NewFileNotWritableError(file string, cause error) *StructuredError {
|
||||||
|
return Wrap(ErrFileNotWritable, fmt.Sprintf("cannot write to file: %s", file), cause).
|
||||||
|
WithContext("file", file).
|
||||||
|
WithRemediation("Check file permissions, disk space, and ensure the file is not locked by another process")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNodeNotFoundError creates an error when AST node selector finds no matches.
|
||||||
|
func NewNodeNotFoundError(selector string) *StructuredError {
|
||||||
|
return New(ErrNodeNotFound, "no AST nodes match the selector criteria").
|
||||||
|
WithContext("selector", selector).
|
||||||
|
WithRemediation("Verify the selector criteria (kind, name, pattern, line) match an existing code structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInvalidSelectionError creates an error for ambiguous or invalid selectors.
|
||||||
|
func NewInvalidSelectionError(message string) *StructuredError {
|
||||||
|
return New(ErrInvalidSelection, message).
|
||||||
|
WithRemediation("Refine the selector to be more specific or provide a selector_index to choose between multiple matches")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInvalidEditError creates an error for invalid edit operations.
|
||||||
|
func NewInvalidEditError(message string) *StructuredError {
|
||||||
|
return New(ErrInvalidEdit, message).
|
||||||
|
WithRemediation("Review the edit request and ensure all required fields are provided with valid values")
|
||||||
|
}
|
||||||
@@ -0,0 +1,375 @@
|
|||||||
|
// Package fuzzy provides fuzzy string matching using Levenshtein distance.
|
||||||
|
package fuzzy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Match represents a fuzzy match result.
|
||||||
|
type Match struct {
|
||||||
|
Text string
|
||||||
|
Distance int
|
||||||
|
Similarity float64
|
||||||
|
Score float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matcher provides fuzzy matching capabilities.
|
||||||
|
type Matcher struct {
|
||||||
|
threshold int
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new fuzzy matcher with the given threshold.
|
||||||
|
// Threshold is the maximum edit distance to consider a match (typically 1-3).
|
||||||
|
func New(threshold int) *Matcher {
|
||||||
|
return &Matcher{
|
||||||
|
threshold: threshold,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match performs fuzzy matching of query against candidates.
|
||||||
|
func (m *Matcher) Match(query string, candidates []string) []Match {
|
||||||
|
if query == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := make([]Match, 0, len(candidates)/10)
|
||||||
|
queryLower := strings.ToLower(query)
|
||||||
|
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
candidateLower := strings.ToLower(candidate)
|
||||||
|
|
||||||
|
// Calculate Levenshtein distance
|
||||||
|
dist := levenshteinDistance(queryLower, candidateLower)
|
||||||
|
|
||||||
|
// Skip if distance exceeds threshold
|
||||||
|
if dist > m.threshold {
|
||||||
|
// Check if it's a substring match (important for identifiers)
|
||||||
|
if !strings.Contains(candidateLower, queryLower) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Allow substring matches even if edit distance is high
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate similarity (0.0 to 1.0)
|
||||||
|
maxLen := max(len(query), len(candidate))
|
||||||
|
similarity := 1.0 - float64(dist)/float64(maxLen)
|
||||||
|
|
||||||
|
// Calculate composite score
|
||||||
|
score := m.calculateScore(queryLower, candidateLower, dist, similarity)
|
||||||
|
|
||||||
|
matches = append(matches, Match{
|
||||||
|
Text: candidate,
|
||||||
|
Distance: dist,
|
||||||
|
Similarity: similarity,
|
||||||
|
Score: score,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by score descending
|
||||||
|
sort.Slice(matches, func(i, j int) bool {
|
||||||
|
return matches[i].Score > matches[j].Score
|
||||||
|
})
|
||||||
|
|
||||||
|
return matches
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateScore computes a composite score considering multiple factors.
|
||||||
|
func (m *Matcher) calculateScore(query, candidate string, dist int, similarity float64) float64 {
|
||||||
|
score := similarity
|
||||||
|
|
||||||
|
// Bonus for exact match
|
||||||
|
if query == candidate {
|
||||||
|
score += 2.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bonus for prefix match (important for identifier search)
|
||||||
|
if strings.HasPrefix(candidate, query) {
|
||||||
|
score += 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bonus for word boundary matches (e.g., "getName" matches "get")
|
||||||
|
if containsWordBoundary(candidate, query) {
|
||||||
|
score += 0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
// Penalty for length difference (prefer similar-length matches)
|
||||||
|
lenDiff := abs(len(candidate) - len(query))
|
||||||
|
score -= float64(lenDiff) * 0.01
|
||||||
|
|
||||||
|
// Penalty for edit distance
|
||||||
|
score -= float64(dist) * 0.1
|
||||||
|
|
||||||
|
return score
|
||||||
|
}
|
||||||
|
|
||||||
|
// levenshteinDistance computes the Levenshtein distance between two strings.
|
||||||
|
// Uses the Wagner-Fischer algorithm with space optimization O(min(m,n)).
|
||||||
|
func levenshteinDistance(s1, s2 string) int {
|
||||||
|
if s1 == s2 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if len(s1) == 0 {
|
||||||
|
return len(s2)
|
||||||
|
}
|
||||||
|
if len(s2) == 0 {
|
||||||
|
return len(s1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure s1 is the shorter string for space optimization
|
||||||
|
if len(s1) > len(s2) {
|
||||||
|
s1, s2 = s2, s1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use rune slices to handle Unicode properly
|
||||||
|
r1 := []rune(s1)
|
||||||
|
r2 := []rune(s2)
|
||||||
|
len1 := len(r1)
|
||||||
|
len2 := len(r2)
|
||||||
|
|
||||||
|
// Only need two rows of the matrix
|
||||||
|
previous := make([]int, len2+1)
|
||||||
|
current := make([]int, len2+1)
|
||||||
|
|
||||||
|
// Initialize first row
|
||||||
|
for j := 0; j <= len2; j++ {
|
||||||
|
previous[j] = j
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate edit distance
|
||||||
|
for i := 1; i <= len1; i++ {
|
||||||
|
current[0] = i
|
||||||
|
|
||||||
|
for j := 1; j <= len2; j++ {
|
||||||
|
cost := 1
|
||||||
|
if r1[i-1] == r2[j-1] {
|
||||||
|
cost = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
current[j] = min(
|
||||||
|
previous[j]+1, // deletion
|
||||||
|
current[j-1]+1, // insertion
|
||||||
|
previous[j-1]+cost, // substitution
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap rows
|
||||||
|
previous, current = current, previous
|
||||||
|
}
|
||||||
|
|
||||||
|
return previous[len2]
|
||||||
|
}
|
||||||
|
|
||||||
|
// DamerauLevenshteinDistance computes Damerau-Levenshtein distance (includes transpositions).
|
||||||
|
// This is more accurate for typos where adjacent characters are swapped.
|
||||||
|
func DamerauLevenshteinDistance(s1, s2 string) int {
|
||||||
|
if s1 == s2 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if len(s1) == 0 {
|
||||||
|
return len(s2)
|
||||||
|
}
|
||||||
|
if len(s2) == 0 {
|
||||||
|
return len(s1)
|
||||||
|
}
|
||||||
|
|
||||||
|
r1 := []rune(s1)
|
||||||
|
r2 := []rune(s2)
|
||||||
|
len1 := len(r1)
|
||||||
|
len2 := len(r2)
|
||||||
|
|
||||||
|
// Create distance matrix
|
||||||
|
d := make([][]int, len1+1)
|
||||||
|
for i := range d {
|
||||||
|
d[i] = make([]int, len2+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize first row and column
|
||||||
|
for i := 0; i <= len1; i++ {
|
||||||
|
d[i][0] = i
|
||||||
|
}
|
||||||
|
for j := 0; j <= len2; j++ {
|
||||||
|
d[0][j] = j
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate distances
|
||||||
|
for i := 1; i <= len1; i++ {
|
||||||
|
for j := 1; j <= len2; j++ {
|
||||||
|
cost := 1
|
||||||
|
if r1[i-1] == r2[j-1] {
|
||||||
|
cost = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
d[i][j] = min(
|
||||||
|
d[i-1][j]+1, // deletion
|
||||||
|
d[i][j-1]+1, // insertion
|
||||||
|
d[i-1][j-1]+cost, // substitution
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check for transposition
|
||||||
|
if i > 1 && j > 1 && r1[i-1] == r2[j-2] && r1[i-2] == r2[j-1] {
|
||||||
|
d[i][j] = min(d[i][j], d[i-2][j-2]+cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return d[len1][len2]
|
||||||
|
}
|
||||||
|
|
||||||
|
// JaroWinklerSimilarity computes Jaro-Winkler similarity (0.0 to 1.0).
|
||||||
|
// Better for short strings and names.
|
||||||
|
func JaroWinklerSimilarity(s1, s2 string) float64 {
|
||||||
|
if s1 == s2 {
|
||||||
|
return 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
r1 := []rune(s1)
|
||||||
|
r2 := []rune(s2)
|
||||||
|
|
||||||
|
if len(r1) == 0 || len(r2) == 0 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate Jaro similarity first
|
||||||
|
jaro := jaroSimilarity(r1, r2)
|
||||||
|
|
||||||
|
// Calculate common prefix length (up to 4 characters)
|
||||||
|
prefixLen := 0
|
||||||
|
for i := 0; i < min(min(len(r1), len(r2)), 4); i++ {
|
||||||
|
if r1[i] == r2[i] {
|
||||||
|
prefixLen++
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Jaro-Winkler adds bonus for common prefix
|
||||||
|
const p = 0.1
|
||||||
|
return jaro + float64(prefixLen)*p*(1.0-jaro)
|
||||||
|
}
|
||||||
|
|
||||||
|
// jaroSimilarity computes Jaro similarity.
|
||||||
|
func jaroSimilarity(r1, r2 []rune) float64 {
|
||||||
|
len1 := len(r1)
|
||||||
|
len2 := len(r2)
|
||||||
|
|
||||||
|
// Maximum allowed distance
|
||||||
|
matchDist := max(len1, len2)/2 - 1
|
||||||
|
if matchDist < 0 {
|
||||||
|
matchDist = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
matched1 := make([]bool, len1)
|
||||||
|
matched2 := make([]bool, len2)
|
||||||
|
|
||||||
|
matches := 0
|
||||||
|
transpositions := 0
|
||||||
|
|
||||||
|
// Find matches
|
||||||
|
for i := range len1 {
|
||||||
|
start := max(0, i-matchDist)
|
||||||
|
end := min(i+matchDist+1, len2)
|
||||||
|
|
||||||
|
for j := start; j < end; j++ {
|
||||||
|
if matched2[j] || r1[i] != r2[j] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
matched1[i] = true
|
||||||
|
matched2[j] = true
|
||||||
|
matches++
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if matches == 0 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count transpositions
|
||||||
|
k := 0
|
||||||
|
for i := range len1 {
|
||||||
|
if !matched1[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for !matched2[k] {
|
||||||
|
k++
|
||||||
|
}
|
||||||
|
if r1[i] != r2[k] {
|
||||||
|
transpositions++
|
||||||
|
}
|
||||||
|
k++
|
||||||
|
}
|
||||||
|
|
||||||
|
return (float64(matches)/float64(len1) +
|
||||||
|
float64(matches)/float64(len2) +
|
||||||
|
float64(matches-transpositions/2)/float64(matches)) / 3.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsWordBoundary checks if query appears at word boundaries in text.
|
||||||
|
func containsWordBoundary(text, query string) bool {
|
||||||
|
textLower := strings.ToLower(text)
|
||||||
|
queryLower := strings.ToLower(query)
|
||||||
|
|
||||||
|
idx := strings.Index(textLower, queryLower)
|
||||||
|
if idx == -1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if match is at start
|
||||||
|
if idx == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for underscore or non-alphanumeric boundary
|
||||||
|
prevRune := rune(text[idx-1])
|
||||||
|
if !unicode.IsLetter(prevRune) && !unicode.IsDigit(prevRune) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for camelCase boundary (lowercase before uppercase)
|
||||||
|
if idx > 0 && len(text) > idx {
|
||||||
|
curr := rune(text[idx])
|
||||||
|
prev := rune(text[idx-1])
|
||||||
|
if unicode.IsLower(prev) && unicode.IsUpper(curr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
|
||||||
|
func min(values ...int) int {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
m := values[0]
|
||||||
|
for _, v := range values[1:] {
|
||||||
|
if v < m {
|
||||||
|
m = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func max(values ...int) int {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
m := values[0]
|
||||||
|
for _, v := range values[1:] {
|
||||||
|
if v > m {
|
||||||
|
m = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func abs(x int) int {
|
||||||
|
if x < 0 {
|
||||||
|
return -x
|
||||||
|
}
|
||||||
|
return x
|
||||||
|
}
|
||||||
@@ -0,0 +1,275 @@
|
|||||||
|
package fuzzy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLevenshteinDistance(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
s1 string
|
||||||
|
s2 string
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{"", "", 0},
|
||||||
|
{"", "abc", 3},
|
||||||
|
{"abc", "", 3},
|
||||||
|
{"abc", "abc", 0},
|
||||||
|
{"abc", "abd", 1},
|
||||||
|
{"kitten", "sitting", 3},
|
||||||
|
{"saturday", "sunday", 3},
|
||||||
|
{"book", "back", 2},
|
||||||
|
{"café", "cafe", 1}, // Unicode handling
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := levenshteinDistance(tt.s1, tt.s2)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("levenshteinDistance(%q, %q) = %d, want %d", tt.s1, tt.s2, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDamerauLevenshteinDistance(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
s1 string
|
||||||
|
s2 string
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{"abc", "abc", 0},
|
||||||
|
{"abc", "acb", 1}, // Transposition
|
||||||
|
{"ca", "abc", 3}, // Delete a, delete b, insert c = 3 operations
|
||||||
|
{"", "abc", 3},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := DamerauLevenshteinDistance(tt.s1, tt.s2)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("DamerauLevenshteinDistance(%q, %q) = %d, want %d", tt.s1, tt.s2, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJaroWinklerSimilarity(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
s1 string
|
||||||
|
s2 string
|
||||||
|
minScore float64 // Minimum expected similarity
|
||||||
|
}{
|
||||||
|
{"", "", 1.0},
|
||||||
|
{"abc", "abc", 1.0},
|
||||||
|
{"martha", "marhta", 0.96}, // High similarity for transposition
|
||||||
|
{"dixon", "dicksonx", 0.76}, // Moderate similarity
|
||||||
|
{"", "abc", 0.0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := JaroWinklerSimilarity(tt.s1, tt.s2)
|
||||||
|
if got < tt.minScore {
|
||||||
|
t.Errorf("JaroWinklerSimilarity(%q, %q) = %.2f, want >= %.2f", tt.s1, tt.s2, got, tt.minScore)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatcher_Match(t *testing.T) {
|
||||||
|
m := New(2) // Allow edit distance up to 2
|
||||||
|
|
||||||
|
candidates := []string{
|
||||||
|
"getUserName",
|
||||||
|
"getUsername",
|
||||||
|
"get_user_name",
|
||||||
|
"getUserId",
|
||||||
|
"setUserName",
|
||||||
|
"findUser",
|
||||||
|
"userName",
|
||||||
|
"usernameField",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
query string
|
||||||
|
topMatch string
|
||||||
|
expectMin int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
query: "getUserName",
|
||||||
|
expectMin: 3, // Exact + similar variants
|
||||||
|
topMatch: "getUserName",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
query: "getuser",
|
||||||
|
expectMin: 2, // Should match getUserName, getUsername at minimum
|
||||||
|
topMatch: "getUserName",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
query: "username",
|
||||||
|
expectMin: 2, // Case-insensitive matches
|
||||||
|
topMatch: "userName",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
matches := m.Match(tt.query, candidates)
|
||||||
|
|
||||||
|
if len(matches) < tt.expectMin {
|
||||||
|
t.Errorf("Match(%q) returned %d matches, want at least %d", tt.query, len(matches), tt.expectMin)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matches) > 0 {
|
||||||
|
// Top match should have highest score
|
||||||
|
if matches[0].Score < matches[len(matches)-1].Score {
|
||||||
|
t.Errorf("Match(%q) results not sorted by score", tt.query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatcher_EmptyQuery(t *testing.T) {
|
||||||
|
m := New(2)
|
||||||
|
candidates := []string{"test", "example"}
|
||||||
|
|
||||||
|
matches := m.Match("", candidates)
|
||||||
|
if matches != nil {
|
||||||
|
t.Errorf("Match with empty query should return nil, got %v", matches)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatcher_PrefixBonus(t *testing.T) {
|
||||||
|
m := New(2)
|
||||||
|
candidates := []string{
|
||||||
|
"getUserName", // prefix match
|
||||||
|
"findUserName", // contains but not prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := m.Match("get", candidates)
|
||||||
|
|
||||||
|
if len(matches) < 1 {
|
||||||
|
t.Fatal("Expected at least one match")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefix match should score higher
|
||||||
|
if matches[0].Text != "getUserName" {
|
||||||
|
t.Errorf("Expected prefix match to rank first, got %q", matches[0].Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatcher_ExactMatchBonus(t *testing.T) {
|
||||||
|
m := New(2)
|
||||||
|
candidates := []string{
|
||||||
|
"test",
|
||||||
|
"testing",
|
||||||
|
"tester",
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := m.Match("test", candidates)
|
||||||
|
|
||||||
|
if len(matches) < 1 {
|
||||||
|
t.Fatal("Expected at least one match")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exact match should rank first
|
||||||
|
if matches[0].Text != "test" {
|
||||||
|
t.Errorf("Expected exact match to rank first, got %q", matches[0].Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exact match should have highest score
|
||||||
|
if matches[0].Score < 2.0 { // Should have exact match bonus
|
||||||
|
t.Errorf("Exact match score too low: %.2f", matches[0].Score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContainsWordBoundary(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
text string
|
||||||
|
query string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"getUserName", "get", true}, // At start
|
||||||
|
{"getUserName", "user", true}, // After lowercase->uppercase boundary
|
||||||
|
{"get_user_name", "user", true}, // After underscore
|
||||||
|
{"getUserName", "Name", true}, // After lowercase->uppercase
|
||||||
|
{"getUserName", "ser", false}, // Middle of word
|
||||||
|
{"", "test", false}, // Empty text
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := containsWordBoundary(tt.text, tt.query)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("containsWordBoundary(%q, %q) = %v, want %v", tt.text, tt.query, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatcher_UnicodeHandling(t *testing.T) {
|
||||||
|
m := New(2)
|
||||||
|
candidates := []string{
|
||||||
|
"café",
|
||||||
|
"resume",
|
||||||
|
"naïve",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with Unicode characters
|
||||||
|
matches := m.Match("cafe", candidates)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
t.Error("Expected matches for Unicode strings")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should find café with small edit distance
|
||||||
|
found := false
|
||||||
|
for _, match := range matches {
|
||||||
|
if match.Text == "café" && match.Distance <= 2 {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
t.Error("Failed to fuzzy match Unicode string 'café'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLevenshteinDistance(b *testing.B) {
|
||||||
|
s1 := "the quick brown fox jumps over the lazy dog"
|
||||||
|
s2 := "the quikc brown fox jumps ovver the lazy dog"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := range b.N {
|
||||||
|
_ = levenshteinDistance(s1, s2)
|
||||||
|
_ = i // use i to avoid unused warning
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDamerauLevenshteinDistance(b *testing.B) {
|
||||||
|
s1 := "the quick brown fox jumps over the lazy dog"
|
||||||
|
s2 := "the quikc brown fox jumps ovver the lazy dog"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := range b.N {
|
||||||
|
_ = DamerauLevenshteinDistance(s1, s2)
|
||||||
|
_ = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkJaroWinklerSimilarity(b *testing.B) {
|
||||||
|
s1 := "martha"
|
||||||
|
s2 := "marhta"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := range b.N {
|
||||||
|
_ = JaroWinklerSimilarity(s1, s2)
|
||||||
|
_ = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMatcher_Match(b *testing.B) {
|
||||||
|
m := New(2)
|
||||||
|
candidates := []string{
|
||||||
|
"getUserName", "getUsername", "get_user_name", "getUserId",
|
||||||
|
"setUserName", "findUser", "userName", "usernameField",
|
||||||
|
"userAccount", "accountUser", "userProfile", "profileUser",
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := range b.N {
|
||||||
|
_ = m.Match("getuser", candidates)
|
||||||
|
_ = i
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
// Package protocol defines shared types used across the MCP file operations server.
|
||||||
|
package protocol
|
||||||
|
|
||||||
|
// Location represents a position in a file.
|
||||||
|
type Location struct {
|
||||||
|
File string `json:"file"`
|
||||||
|
Line int `json:"line"`
|
||||||
|
Column int `json:"column"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Range represents a range in a file.
|
||||||
|
type Range struct {
|
||||||
|
Start Location `json:"start"`
|
||||||
|
End Location `json:"end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SymbolKind represents the kind of a symbol.
|
||||||
|
type SymbolKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SymbolFunction SymbolKind = "function"
|
||||||
|
SymbolMethod SymbolKind = "method"
|
||||||
|
SymbolClass SymbolKind = "class"
|
||||||
|
SymbolStruct SymbolKind = "struct"
|
||||||
|
SymbolInterface SymbolKind = "interface"
|
||||||
|
SymbolVariable SymbolKind = "variable"
|
||||||
|
SymbolConstant SymbolKind = "constant"
|
||||||
|
SymbolType SymbolKind = "type"
|
||||||
|
SymbolField SymbolKind = "field"
|
||||||
|
SymbolProperty SymbolKind = "property"
|
||||||
|
SymbolModule SymbolKind = "module"
|
||||||
|
SymbolPackage SymbolKind = "package"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Symbol represents a code symbol (function, class, variable, etc.).
|
||||||
|
type Symbol struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Kind SymbolKind `json:"kind"`
|
||||||
|
Doc string `json:"doc,omitempty"`
|
||||||
|
Location Location `json:"location"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyntaxError represents a syntax error in a file.
|
||||||
|
type SyntaxError struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Location Location `json:"location"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Language represents a programming language.
|
||||||
|
type Language string
|
||||||
|
|
||||||
|
const (
|
||||||
|
LangGo Language = "go"
|
||||||
|
LangTypeScript Language = "typescript"
|
||||||
|
LangJavaScript Language = "javascript"
|
||||||
|
LangPython Language = "python"
|
||||||
|
LangC Language = "c"
|
||||||
|
LangCpp Language = "cpp"
|
||||||
|
LangHTML Language = "html"
|
||||||
|
LangVue Language = "vue"
|
||||||
|
LangJSON Language = "json"
|
||||||
|
LangYAML Language = "yaml"
|
||||||
|
LangUnknown Language = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DetectLanguage detects the language from a filename.
|
||||||
|
func DetectLanguage(filename string) Language {
|
||||||
|
ext := getExtension(filename)
|
||||||
|
switch ext {
|
||||||
|
case ".go":
|
||||||
|
return LangGo
|
||||||
|
case ".ts", ".tsx":
|
||||||
|
return LangTypeScript
|
||||||
|
case ".js", ".jsx", ".mjs", ".cjs":
|
||||||
|
return LangJavaScript
|
||||||
|
case ".py", ".pyw":
|
||||||
|
return LangPython
|
||||||
|
case ".c", ".h":
|
||||||
|
return LangC
|
||||||
|
case ".cpp", ".cc", ".cxx", ".hpp", ".hxx":
|
||||||
|
return LangCpp
|
||||||
|
case ".html", ".htm":
|
||||||
|
return LangHTML
|
||||||
|
case ".vue":
|
||||||
|
return LangVue
|
||||||
|
case ".json":
|
||||||
|
return LangJSON
|
||||||
|
case ".yaml", ".yml":
|
||||||
|
return LangYAML
|
||||||
|
default:
|
||||||
|
return LangUnknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getExtension(filename string) string {
|
||||||
|
for i := len(filename) - 1; i >= 0; i-- {
|
||||||
|
if filename[i] == '.' {
|
||||||
|
return filename[i:]
|
||||||
|
}
|
||||||
|
if filename[i] == '/' || filename[i] == '\\' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package protocol
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDetectLanguage(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
filename string
|
||||||
|
expected Language
|
||||||
|
}{
|
||||||
|
{"main.go", LangGo},
|
||||||
|
{"server.go", LangGo},
|
||||||
|
{"index.ts", LangTypeScript},
|
||||||
|
{"component.tsx", LangTypeScript},
|
||||||
|
{"Button.tsx", LangTypeScript},
|
||||||
|
{"app.js", LangJavaScript},
|
||||||
|
{"component.jsx", LangJavaScript},
|
||||||
|
{"Component.jsx", LangJavaScript},
|
||||||
|
{"module.mjs", LangJavaScript},
|
||||||
|
{"common.cjs", LangJavaScript},
|
||||||
|
{"script.py", LangPython},
|
||||||
|
{"app.pyw", LangPython},
|
||||||
|
{"main.c", LangC},
|
||||||
|
{"header.h", LangC},
|
||||||
|
{"main.cpp", LangCpp},
|
||||||
|
{"main.cc", LangCpp},
|
||||||
|
{"main.cxx", LangCpp},
|
||||||
|
{"header.hpp", LangCpp},
|
||||||
|
{"header.hxx", LangCpp},
|
||||||
|
{"index.html", LangHTML},
|
||||||
|
{"page.htm", LangHTML},
|
||||||
|
{"Component.vue", LangVue},
|
||||||
|
{"unknown.txt", LangUnknown},
|
||||||
|
{"README", LangUnknown},
|
||||||
|
{"path/to/file.go", LangGo},
|
||||||
|
{"path/to/file.ts", LangTypeScript},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.filename, func(t *testing.T) {
|
||||||
|
result := DetectLanguage(tt.filename)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("DetectLanguage(%q) = %q, want %q", tt.filename, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetExtension(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
filename string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"file.go", ".go"},
|
||||||
|
{"file.test.go", ".go"},
|
||||||
|
{"path/to/file.ts", ".ts"},
|
||||||
|
{"noextension", ""},
|
||||||
|
{".hidden", ".hidden"},
|
||||||
|
{"file.", "."},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.filename, func(t *testing.T) {
|
||||||
|
result := getExtension(tt.filename)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("getExtension(%q) = %q, want %q", tt.filename, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Vendored
+55
@@ -0,0 +1,55 @@
|
|||||||
|
/**
|
||||||
|
* @file header.h
|
||||||
|
* @brief Sample header file for testing.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef HEADER_H
|
||||||
|
#define HEADER_H
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Maximum buffer size.
|
||||||
|
*/
|
||||||
|
#define MAX_BUFFER_SIZE 1024
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Status codes for operations.
|
||||||
|
*/
|
||||||
|
typedef enum {
|
||||||
|
STATUS_OK = 0,
|
||||||
|
STATUS_ERROR = 1,
|
||||||
|
STATUS_NOT_FOUND = 2
|
||||||
|
} Status;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Buffer structure for data storage.
|
||||||
|
*/
|
||||||
|
typedef struct {
|
||||||
|
char data[MAX_BUFFER_SIZE];
|
||||||
|
int length;
|
||||||
|
} Buffer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Initialize a buffer.
|
||||||
|
* @param buf Pointer to the buffer
|
||||||
|
*/
|
||||||
|
void buffer_init(Buffer* buf);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Write data to the buffer.
|
||||||
|
* @param buf Pointer to the buffer
|
||||||
|
* @param data Data to write
|
||||||
|
* @param len Length of data
|
||||||
|
* @return Status code
|
||||||
|
*/
|
||||||
|
Status buffer_write(Buffer* buf, const char* data, int len);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Read data from the buffer.
|
||||||
|
* @param buf Pointer to the buffer
|
||||||
|
* @param out Output buffer
|
||||||
|
* @param max_len Maximum length to read
|
||||||
|
* @return Number of bytes read
|
||||||
|
*/
|
||||||
|
int buffer_read(Buffer* buf, char* out, int max_len);
|
||||||
|
|
||||||
|
#endif /* HEADER_H */
|
||||||
Vendored
+57
@@ -0,0 +1,57 @@
|
|||||||
|
/**
|
||||||
|
* @file valid.c
|
||||||
|
* @brief Sample C file for testing.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A simple point structure.
|
||||||
|
*/
|
||||||
|
struct Point {
|
||||||
|
int x;
|
||||||
|
int y;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Creates a new point.
|
||||||
|
* @param x The x coordinate
|
||||||
|
* @param y The y coordinate
|
||||||
|
* @return A new Point structure
|
||||||
|
*/
|
||||||
|
struct Point create_point(int x, int y) {
|
||||||
|
struct Point p;
|
||||||
|
p.x = x;
|
||||||
|
p.y = y;
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates the distance from origin.
|
||||||
|
* @param p The point
|
||||||
|
* @return The squared distance from origin
|
||||||
|
*/
|
||||||
|
int distance_squared(struct Point p) {
|
||||||
|
return p.x * p.x + p.y * p.y;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Prints a point to stdout.
|
||||||
|
* @param p The point to print
|
||||||
|
*/
|
||||||
|
void print_point(struct Point p) {
|
||||||
|
printf("Point(%d, %d)\n", p.x, p.y);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple helper function
|
||||||
|
int add(int a, int b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(void) {
|
||||||
|
struct Point p = create_point(3, 4);
|
||||||
|
print_point(p);
|
||||||
|
printf("Distance squared: %d\n", distance_squared(p));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
Vendored
+11
@@ -0,0 +1,11 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// This file contains intentional syntax errors for testing.
|
||||||
|
|
||||||
|
func broken( {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Incomplete struct {
|
||||||
|
Name string
|
||||||
|
// Missing closing brace
|
||||||
Vendored
+44
@@ -0,0 +1,44 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// Server represents the main application server.
|
||||||
|
type Server struct {
|
||||||
|
Name string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new Server instance.
|
||||||
|
func NewServer(name string, port int) *Server {
|
||||||
|
return &Server{
|
||||||
|
Name: name,
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the server.
|
||||||
|
func (s *Server) Start() error {
|
||||||
|
fmt.Printf("Starting server %s on port %d\n", s.Name, s.Port)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds application configuration.
|
||||||
|
type Config struct {
|
||||||
|
Debug bool
|
||||||
|
Timeout int
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultPort is the default server port.
|
||||||
|
DefaultPort = 8080
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Version is the application version.
|
||||||
|
Version = "1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
srv := NewServer("main", DefaultPort)
|
||||||
|
srv.Start()
|
||||||
|
}
|
||||||
Vendored
+29
@@ -0,0 +1,29 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Test HTML File</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container mx-auto px-4">
|
||||||
|
<h1 class="text-3xl font-bold text-blue-600">Hello World</h1>
|
||||||
|
<p class="text-gray-700 mt-4">This is a test HTML file with Tailwind CSS classes.</p>
|
||||||
|
|
||||||
|
<div class="flex gap-4 mt-8">
|
||||||
|
<button class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded">
|
||||||
|
Primary Button
|
||||||
|
</button>
|
||||||
|
<button class="bg-gray-500 hover:bg-gray-700 text-white font-bold py-2 px-4 rounded">
|
||||||
|
Secondary Button
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<ul class="list-disc list-inside mt-4">
|
||||||
|
<li>First item</li>
|
||||||
|
<li>Second item</li>
|
||||||
|
<li>Third item</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
Vendored
+62
@@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
Sample Python module for testing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class DataProcessor:
|
||||||
|
"""Processes data records."""
|
||||||
|
|
||||||
|
def __init__(self, name: str):
|
||||||
|
"""
|
||||||
|
Initialize the processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The processor name
|
||||||
|
"""
|
||||||
|
self.name = name
|
||||||
|
self._records: List[dict] = []
|
||||||
|
|
||||||
|
def add_record(self, record: dict) -> None:
|
||||||
|
"""
|
||||||
|
Add a record to the processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
record: The record to add
|
||||||
|
"""
|
||||||
|
self._records.append(record)
|
||||||
|
|
||||||
|
def process(self) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Process all records.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed records
|
||||||
|
"""
|
||||||
|
return [self._transform(r) for r in self._records]
|
||||||
|
|
||||||
|
def _transform(self, record: dict) -> dict:
|
||||||
|
"""Transform a single record."""
|
||||||
|
return {k.upper(): v for k, v in record.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_sum(numbers: List[int]) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the sum of numbers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
numbers: List of integers to sum
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The sum of all numbers
|
||||||
|
"""
|
||||||
|
return sum(numbers)
|
||||||
|
|
||||||
|
|
||||||
|
def find_maximum(values: List[int]) -> Optional[int]:
|
||||||
|
"""Find the maximum value in a list."""
|
||||||
|
return max(values) if values else None
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_BATCH_SIZE = 100
|
||||||
Vendored
+129
@@ -0,0 +1,129 @@
|
|||||||
|
import React, { useState, useEffect } from 'react';
|
||||||
|
|
||||||
|
interface ButtonProps {
|
||||||
|
variant?: 'primary' | 'secondary';
|
||||||
|
disabled?: boolean;
|
||||||
|
onClick?: () => void;
|
||||||
|
children: React.ReactNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A reusable button component with Tailwind CSS styling
|
||||||
|
*/
|
||||||
|
export const Button: React.FC<ButtonProps> = ({
|
||||||
|
variant = 'primary',
|
||||||
|
disabled = false,
|
||||||
|
onClick,
|
||||||
|
children
|
||||||
|
}) => {
|
||||||
|
const baseClasses = 'font-bold py-2 px-4 rounded transition-colors duration-200';
|
||||||
|
const variantClasses = {
|
||||||
|
primary: 'bg-blue-500 hover:bg-blue-700 text-white',
|
||||||
|
secondary: 'bg-gray-500 hover:bg-gray-700 text-white'
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
className={`${baseClasses} ${variantClasses[variant]} ${disabled ? 'opacity-50 cursor-not-allowed' : ''}`}
|
||||||
|
disabled={disabled}
|
||||||
|
onClick={onClick}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
interface TodoItem {
|
||||||
|
id: number;
|
||||||
|
text: string;
|
||||||
|
completed: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Todo list component demonstrating React hooks and Tailwind
|
||||||
|
*/
|
||||||
|
export const TodoList: React.FC = () => {
|
||||||
|
const [todos, setTodos] = useState<TodoItem[]>([
|
||||||
|
{ id: 1, text: 'Learn React', completed: true },
|
||||||
|
{ id: 2, text: 'Learn TypeScript', completed: true },
|
||||||
|
{ id: 3, text: 'Build amazing apps', completed: false }
|
||||||
|
]);
|
||||||
|
const [inputValue, setInputValue] = useState('');
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
console.log('Todos updated:', todos);
|
||||||
|
}, [todos]);
|
||||||
|
|
||||||
|
const addTodo = () => {
|
||||||
|
if (inputValue.trim()) {
|
||||||
|
const newTodo: TodoItem = {
|
||||||
|
id: Date.now(),
|
||||||
|
text: inputValue,
|
||||||
|
completed: false
|
||||||
|
};
|
||||||
|
setTodos([...todos, newTodo]);
|
||||||
|
setInputValue('');
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const toggleTodo = (id: number) => {
|
||||||
|
setTodos(todos.map(todo =>
|
||||||
|
todo.id === id ? { ...todo, completed: !todo.completed } : todo
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
const deleteTodo = (id: number) => {
|
||||||
|
setTodos(todos.filter(todo => todo.id !== id));
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="container mx-auto px-4 py-8 max-w-2xl">
|
||||||
|
<h1 className="text-3xl font-bold text-gray-800 mb-6">
|
||||||
|
My Todo List
|
||||||
|
</h1>
|
||||||
|
|
||||||
|
<div className="flex gap-2 mb-6">
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={inputValue}
|
||||||
|
onChange={(e) => setInputValue(e.target.value)}
|
||||||
|
onKeyPress={(e) => e.key === 'Enter' && addTodo()}
|
||||||
|
placeholder="Add a new todo..."
|
||||||
|
className="flex-1 px-4 py-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
<Button onClick={addTodo}>Add</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<ul className="space-y-2">
|
||||||
|
{todos.map(todo => (
|
||||||
|
<li
|
||||||
|
key={todo.id}
|
||||||
|
className="flex items-center gap-3 p-4 bg-white rounded-lg shadow-sm hover:shadow-md transition-shadow"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={todo.completed}
|
||||||
|
onChange={() => toggleTodo(todo.id)}
|
||||||
|
className="w-5 h-5 text-blue-600 rounded focus:ring-2 focus:ring-blue-500"
|
||||||
|
/>
|
||||||
|
<span className={`flex-1 ${todo.completed ? 'line-through text-gray-400' : 'text-gray-700'}`}>
|
||||||
|
{todo.text}
|
||||||
|
</span>
|
||||||
|
<Button
|
||||||
|
variant="secondary"
|
||||||
|
onClick={() => deleteTodo(todo.id)}
|
||||||
|
>
|
||||||
|
Delete
|
||||||
|
</Button>
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
{todos.length === 0 && (
|
||||||
|
<div className="text-center py-12 text-gray-400">
|
||||||
|
No todos yet. Add one above!
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
Vendored
+53
@@ -0,0 +1,53 @@
|
|||||||
|
/**
|
||||||
|
* Represents a user in the system.
|
||||||
|
*/
|
||||||
|
interface User {
|
||||||
|
id: number;
|
||||||
|
name: string;
|
||||||
|
email: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration options for the application.
|
||||||
|
*/
|
||||||
|
type Config = {
|
||||||
|
debug: boolean;
|
||||||
|
timeout: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Greeting service for handling user greetings.
|
||||||
|
*/
|
||||||
|
class GreetingService {
|
||||||
|
private prefix: string;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new GreetingService.
|
||||||
|
* @param prefix The greeting prefix
|
||||||
|
*/
|
||||||
|
constructor(prefix: string) {
|
||||||
|
this.prefix = prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Greets a user.
|
||||||
|
* @param user The user to greet
|
||||||
|
* @returns The greeting message
|
||||||
|
*/
|
||||||
|
greet(user: User): string {
|
||||||
|
return `${this.prefix}, ${user.name}!`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Formats a user for display.
|
||||||
|
* @param user The user to format
|
||||||
|
* @returns Formatted string
|
||||||
|
*/
|
||||||
|
function formatUser(user: User): string {
|
||||||
|
return `${user.name} <${user.email}>`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const DEFAULT_TIMEOUT = 5000;
|
||||||
|
|
||||||
|
export { User, Config, GreetingService, formatUser, DEFAULT_TIMEOUT };
|
||||||
Vendored
+76
@@ -0,0 +1,76 @@
|
|||||||
|
<template>
|
||||||
|
<div class="container mx-auto px-4 py-8">
|
||||||
|
<h1 class="text-3xl font-bold text-blue-600 mb-4">
|
||||||
|
{{ title }}
|
||||||
|
</h1>
|
||||||
|
|
||||||
|
<div v-if="showContent" class="bg-white shadow-md rounded-lg p-6">
|
||||||
|
<p class="text-gray-700 mb-4">{{ description }}</p>
|
||||||
|
|
||||||
|
<div class="flex gap-4 mt-4">
|
||||||
|
<button
|
||||||
|
@click="handlePrimary"
|
||||||
|
:class="['bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded', { 'opacity-50': isLoading }]"
|
||||||
|
:disabled="isLoading"
|
||||||
|
>
|
||||||
|
{{ primaryButtonText }}
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<button
|
||||||
|
@click="handleSecondary"
|
||||||
|
class="bg-gray-500 hover:bg-gray-700 text-white font-bold py-2 px-4 rounded"
|
||||||
|
>
|
||||||
|
{{ secondaryButtonText }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<ul v-for="item in items" :key="item.id" class="list-disc list-inside mt-4">
|
||||||
|
<li class="text-gray-600">{{ item.name }}</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else class="text-center text-gray-500">
|
||||||
|
<p>No content to display</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, computed } from 'vue';
|
||||||
|
|
||||||
|
interface Item {
|
||||||
|
id: number;
|
||||||
|
name: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const title = ref('Vue Component with Tailwind');
|
||||||
|
const description = ref('This is a sample Vue 3 component using Composition API and Tailwind CSS');
|
||||||
|
const showContent = ref(true);
|
||||||
|
const isLoading = ref(false);
|
||||||
|
|
||||||
|
const items = ref<Item[]>([
|
||||||
|
{ id: 1, name: 'First item' },
|
||||||
|
{ id: 2, name: 'Second item' },
|
||||||
|
{ id: 3, name: 'Third item' },
|
||||||
|
]);
|
||||||
|
|
||||||
|
const primaryButtonText = computed(() => isLoading.value ? 'Loading...' : 'Primary Action');
|
||||||
|
const secondaryButtonText = ref('Secondary Action');
|
||||||
|
|
||||||
|
const handlePrimary = () => {
|
||||||
|
isLoading.value = true;
|
||||||
|
setTimeout(() => {
|
||||||
|
isLoading.value = false;
|
||||||
|
}, 2000);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSecondary = () => {
|
||||||
|
console.log('Secondary button clicked');
|
||||||
|
};
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.container {
|
||||||
|
max-width: 1200px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
Reference in New Issue
Block a user