mirror of
https://github.com/lukaszraczylo/filepuff-mcp.git
synced 2026-06-13 02:51:20 +00:00
Ho hum.
This commit is contained in:
@@ -0,0 +1,203 @@
|
||||
package edit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
)
|
||||
|
||||
// TestConcurrentEditLocking tests that concurrent edits to the same file are properly serialized.
|
||||
func TestConcurrentEditLocking(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
// Create initial file
|
||||
initialContent := `package main
|
||||
|
||||
func main() {
|
||||
println("hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
registry := parser.NewRegistry()
|
||||
engine := NewEngine(registry)
|
||||
|
||||
// Run 10 concurrent edits
|
||||
const numEdits = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numEdits)
|
||||
|
||||
errors := make(chan error, numEdits)
|
||||
|
||||
for i := 0; i < numEdits; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
edit := &ASTEdit{
|
||||
File: testFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
Kind: "function_declaration",
|
||||
Name: "main",
|
||||
},
|
||||
NewContent: `func main() {
|
||||
println("edit ` + string(rune(i)) + `")
|
||||
}`,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := engine.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
for err := range errors {
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent edit failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify file wasn't corrupted
|
||||
finalContent, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Parse to ensure it's still valid Go
|
||||
_, err = registry.Parse(context.Background(), testFile, finalContent)
|
||||
if err != nil {
|
||||
t.Errorf("File corrupted after concurrent edits: %v\nContent:\n%s", err, string(finalContent))
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentEditDifferentFiles tests that concurrent edits to different files don't block each other.
|
||||
func TestConcurrentEditDifferentFiles(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
registry := parser.NewRegistry()
|
||||
engine := NewEngine(registry)
|
||||
|
||||
const numFiles = 5
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numFiles)
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < numFiles; i++ {
|
||||
i := i
|
||||
testFile := filepath.Join(tmpDir, fmt.Sprintf("test%d.go", i))
|
||||
|
||||
// Create initial file
|
||||
initialContent := `package main
|
||||
|
||||
func test() {
|
||||
println("initial")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Wait for all goroutines to be ready
|
||||
<-startBarrier
|
||||
|
||||
edit := &ASTEdit{
|
||||
File: testFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
Kind: "function_declaration",
|
||||
Name: "test",
|
||||
},
|
||||
NewContent: `func test() {
|
||||
println("modified")
|
||||
}`,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := engine.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Errorf("Edit failed for %s: %v", testFile, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Errorf("Edit unsuccessful for %s: %s", testFile, result.Error)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Release all goroutines simultaneously
|
||||
close(startBarrier)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestFileLockRelease tests that file locks are properly released after edits.
|
||||
func TestFileLockRelease(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
initialContent := `package main
|
||||
|
||||
func test() {}
|
||||
`
|
||||
if err := os.WriteFile(testFile, []byte(initialContent), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
registry := parser.NewRegistry()
|
||||
engine := NewEngine(registry)
|
||||
|
||||
edit := &ASTEdit{
|
||||
File: testFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
Kind: "function_declaration",
|
||||
Name: "test",
|
||||
},
|
||||
NewContent: `func test() { println("updated") }`,
|
||||
}
|
||||
|
||||
// First edit
|
||||
ctx := context.Background()
|
||||
result1, err := engine.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result1.Success {
|
||||
t.Fatalf("First edit failed: %s", result1.Error)
|
||||
}
|
||||
|
||||
// Second edit should succeed (lock was released)
|
||||
result2, err := engine.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result2.Success {
|
||||
t.Fatalf("Second edit failed (lock not released?): %s", result2.Error)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,757 @@
|
||||
// Package edit provides AST-aware file editing capabilities.
|
||||
package edit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// Global regex cache for compiled patterns (thread-safe)
|
||||
var regexCache sync.Map // string -> *regexp.Regexp
|
||||
|
||||
// compileRegex compiles a regex pattern with caching for performance.
|
||||
func compileRegex(pattern string) (*regexp.Regexp, error) {
|
||||
// Check cache first
|
||||
if cached, ok := regexCache.Load(pattern); ok {
|
||||
return cached.(*regexp.Regexp), nil
|
||||
}
|
||||
|
||||
// Compile and cache
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
regexCache.Store(pattern, re)
|
||||
return re, nil
|
||||
}
|
||||
|
||||
// EditOperation defines the type of edit operation.
|
||||
type EditOperation string
|
||||
|
||||
const (
|
||||
EditReplace EditOperation = "replace"
|
||||
EditInsertBefore EditOperation = "insert_before"
|
||||
EditInsertAfter EditOperation = "insert_after"
|
||||
EditDelete EditOperation = "delete"
|
||||
)
|
||||
|
||||
// ASTEdit represents an AST-aware edit request.
|
||||
type ASTEdit struct {
|
||||
File string `json:"file"`
|
||||
Operation EditOperation `json:"operation"`
|
||||
NewContent string `json:"new_content,omitempty"`
|
||||
Selector ASTSelector `json:"selector"`
|
||||
}
|
||||
|
||||
// ASTSelector specifies how to find the target node.
|
||||
type ASTSelector struct {
|
||||
Kind string `json:"kind,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Pattern string `json:"pattern,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
TextPattern string `json:"text_pattern,omitempty"`
|
||||
AtLine int `json:"at_line,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
LineEnd int `json:"line_end,omitempty"`
|
||||
}
|
||||
|
||||
// EditResult contains the result of an edit operation.
|
||||
type EditResult struct {
|
||||
Diff string `json:"diff,omitempty"`
|
||||
OriginalContent string `json:"original_content,omitempty"`
|
||||
NewContent string `json:"new_content,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Success bool `json:"success"`
|
||||
Applied bool `json:"applied"`
|
||||
}
|
||||
|
||||
// Engine performs AST-aware edits.
|
||||
type Engine struct {
|
||||
registry *parser.Registry
|
||||
fileLocks sync.Map // map[string]*sync.Mutex for per-file locking
|
||||
}
|
||||
|
||||
// NewEngine creates a new edit engine.
|
||||
func NewEngine(registry *parser.Registry) *Engine {
|
||||
return &Engine{
|
||||
registry: registry,
|
||||
fileLocks: sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// lockFile acquires a lock for the specified file and returns an unlock function.
|
||||
// This prevents concurrent edits to the same file which could cause corruption.
|
||||
func (e *Engine) lockFile(filePath string) func() {
|
||||
// Get or create mutex for this file
|
||||
actual, _ := e.fileLocks.LoadOrStore(filePath, &sync.Mutex{})
|
||||
mu := actual.(*sync.Mutex)
|
||||
mu.Lock()
|
||||
return mu.Unlock
|
||||
}
|
||||
|
||||
// Preview generates a preview of an edit without applying it.
|
||||
func (e *Engine) Preview(ctx context.Context, edit *ASTEdit) (*EditResult, error) {
|
||||
return e.performEdit(ctx, edit, false)
|
||||
}
|
||||
|
||||
// Apply performs an edit and writes the result to disk.
|
||||
// Uses file locking to prevent concurrent edits to the same file.
|
||||
func (e *Engine) Apply(ctx context.Context, edit *ASTEdit) (*EditResult, error) {
|
||||
unlock := e.lockFile(edit.File)
|
||||
defer unlock()
|
||||
return e.performEdit(ctx, edit, true)
|
||||
}
|
||||
|
||||
// performEdit executes an edit operation.
|
||||
func (e *Engine) performEdit(ctx context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
|
||||
// Determine if we should use text mode
|
||||
useTextMode := e.shouldUseTextMode(edit)
|
||||
|
||||
if useTextMode {
|
||||
return e.performTextEdit(ctx, edit, apply)
|
||||
}
|
||||
return e.performASTEdit(ctx, edit, apply)
|
||||
}
|
||||
|
||||
// shouldUseTextMode determines if text-based editing should be used.
|
||||
func (e *Engine) shouldUseTextMode(edit *ASTEdit) bool {
|
||||
// Use text mode if text-specific selectors are provided
|
||||
if edit.Selector.Text != "" || edit.Selector.TextPattern != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Use text mode if line range is specified without AST selectors
|
||||
if edit.Selector.AtLine > 0 && edit.Selector.LineEnd > 0 &&
|
||||
edit.Selector.Kind == "" && edit.Selector.Name == "" && edit.Selector.Pattern == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Use text mode if language is not supported for AST
|
||||
lang := protocol.DetectLanguage(edit.File)
|
||||
return lang == protocol.LangUnknown
|
||||
}
|
||||
|
||||
// performASTEdit executes an AST-aware edit operation.
|
||||
func (e *Engine) performASTEdit(ctx context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
|
||||
// Validate operation
|
||||
if err := e.validateASTEdit(edit); err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Read file
|
||||
content, err := os.ReadFile(edit.File)
|
||||
if err != nil {
|
||||
structuredErr := errors.NewFileNotReadableError(edit.File, err)
|
||||
return &EditResult{Success: false, Error: structuredErr.Error()}, nil
|
||||
}
|
||||
|
||||
// Parse file
|
||||
parseResult, err := e.registry.Parse(ctx, edit.File, content)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Find target node
|
||||
node, err := e.resolveSelector(edit.Selector, parseResult.Tree, content)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Apply edit
|
||||
newContent, err := e.applyEdit(edit, node, content)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Validate new content (re-parse)
|
||||
_, err = e.registry.Parse(ctx, edit.File, newContent)
|
||||
if err != nil {
|
||||
structuredErr := errors.NewEditValidationError(edit.File, err)
|
||||
return &EditResult{
|
||||
Success: false,
|
||||
Error: structuredErr.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate diff
|
||||
diff := generateDiff(string(content), string(newContent), edit.File)
|
||||
|
||||
result := &EditResult{
|
||||
Success: true,
|
||||
Diff: diff,
|
||||
OriginalContent: string(content),
|
||||
NewContent: string(newContent),
|
||||
Applied: false,
|
||||
}
|
||||
|
||||
// Apply changes if requested
|
||||
if apply {
|
||||
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
|
||||
structuredErr := errors.NewFileNotWritableError(edit.File, err)
|
||||
return &EditResult{
|
||||
Success: false,
|
||||
Error: structuredErr.Error(),
|
||||
}, nil
|
||||
}
|
||||
result.Applied = true
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// performTextEdit executes a text-based edit operation for non-AST files.
|
||||
func (e *Engine) performTextEdit(_ context.Context, edit *ASTEdit, apply bool) (*EditResult, error) {
|
||||
// Validate operation
|
||||
if err := e.validateTextEdit(edit); err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Read file
|
||||
content, err := os.ReadFile(edit.File)
|
||||
if err != nil {
|
||||
structuredErr := errors.NewFileNotReadableError(edit.File, err)
|
||||
return &EditResult{Success: false, Error: structuredErr.Error()}, nil
|
||||
}
|
||||
|
||||
// Find the text selection (byte range)
|
||||
start, end, err := e.resolveTextSelector(edit.Selector, content)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Apply edit
|
||||
newContent, err := e.applyTextEditOperation(edit.Operation, content, start, end, edit.NewContent)
|
||||
if err != nil {
|
||||
return &EditResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
// Generate diff
|
||||
diff := generateDiff(string(content), string(newContent), edit.File)
|
||||
|
||||
result := &EditResult{
|
||||
Success: true,
|
||||
Diff: diff,
|
||||
OriginalContent: string(content),
|
||||
NewContent: string(newContent),
|
||||
Applied: false,
|
||||
}
|
||||
|
||||
// Apply changes if requested
|
||||
if apply {
|
||||
if err := os.WriteFile(edit.File, newContent, 0600); err != nil {
|
||||
structuredErr := errors.NewFileNotWritableError(edit.File, err)
|
||||
return &EditResult{
|
||||
Success: false,
|
||||
Error: structuredErr.Error(),
|
||||
}, nil
|
||||
}
|
||||
result.Applied = true
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// validateBaseEdit checks common edit request fields.
|
||||
func (e *Engine) validateBaseEdit(edit *ASTEdit) error {
|
||||
if edit.File == "" {
|
||||
return errors.NewInvalidEditError("file is required")
|
||||
}
|
||||
|
||||
if edit.Operation == "" {
|
||||
return errors.NewInvalidEditError("operation is required")
|
||||
}
|
||||
|
||||
// Validate operation type
|
||||
switch edit.Operation {
|
||||
case EditReplace, EditInsertBefore, EditInsertAfter:
|
||||
if edit.NewContent == "" {
|
||||
return errors.NewInvalidEditError(fmt.Sprintf("new_content is required for %s operation", edit.Operation))
|
||||
}
|
||||
case EditDelete:
|
||||
// new_content not required
|
||||
default:
|
||||
return errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", edit.Operation))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateASTEdit checks if an AST edit request is valid.
|
||||
func (e *Engine) validateASTEdit(edit *ASTEdit) error {
|
||||
if err := e.validateBaseEdit(edit); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate AST selector
|
||||
if edit.Selector.Kind == "" && edit.Selector.Name == "" && edit.Selector.Pattern == "" && edit.Selector.AtLine == 0 {
|
||||
return errors.NewInvalidEditError("AST selector must specify at least one of: kind, name, pattern, or at_line")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateTextEdit checks if a text edit request is valid.
|
||||
func (e *Engine) validateTextEdit(edit *ASTEdit) error {
|
||||
if err := e.validateBaseEdit(edit); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate text selector - need at least one text selection method
|
||||
hasTextSelector := edit.Selector.Text != "" ||
|
||||
edit.Selector.TextPattern != "" ||
|
||||
edit.Selector.AtLine > 0
|
||||
|
||||
if !hasTextSelector {
|
||||
return errors.NewInvalidEditError("text selector must specify at least one of: text, text_pattern, or at_line")
|
||||
}
|
||||
|
||||
// Validate regex pattern if provided (uses cached compilation)
|
||||
if edit.Selector.TextPattern != "" {
|
||||
if _, err := compileRegex(edit.Selector.TextPattern); err != nil {
|
||||
return errors.Wrap(errors.ErrInvalidEdit, "invalid text_pattern regex", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveSelector finds the target node based on the selector.
|
||||
func (e *Engine) resolveSelector(sel ASTSelector, tree *sitter.Tree, content []byte) (*sitter.Node, error) {
|
||||
if tree == nil {
|
||||
return nil, errors.NewNodeNotFoundError("no AST tree available")
|
||||
}
|
||||
|
||||
root := tree.RootNode()
|
||||
if root == nil {
|
||||
return nil, errors.NewNodeNotFoundError("empty AST tree")
|
||||
}
|
||||
|
||||
var matches []*sitter.Node
|
||||
|
||||
parser.WalkTree(root, func(n *sitter.Node) bool {
|
||||
if e.matchesSelector(sel, n, content) {
|
||||
matches = append(matches, n)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(matches) == 0 {
|
||||
selectorDesc := fmt.Sprintf("kind=%s name=%s pattern=%s line=%d", sel.Kind, sel.Name, sel.Pattern, sel.AtLine)
|
||||
return nil, errors.NewNodeNotFoundError(selectorDesc)
|
||||
}
|
||||
|
||||
// Use index to select specific match
|
||||
index := sel.Index
|
||||
if index < 0 || index >= len(matches) {
|
||||
return nil, errors.NewInvalidSelectionError(fmt.Sprintf("selector matched %d nodes, but index %d is out of range", len(matches), index))
|
||||
}
|
||||
|
||||
return matches[index], nil
|
||||
}
|
||||
|
||||
// matchesSelector checks if a node matches the selector criteria.
|
||||
func (e *Engine) matchesSelector(sel ASTSelector, n *sitter.Node, content []byte) bool {
|
||||
// Check kind
|
||||
if sel.Kind != "" && n.Type() != sel.Kind {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check name (look for identifier in the node)
|
||||
if sel.Name != "" {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
// Also try to find an identifier child
|
||||
found := false
|
||||
for i := 0; i < int(n.NamedChildCount()); i++ {
|
||||
child := n.NamedChild(i)
|
||||
if child != nil && child.Type() == "identifier" {
|
||||
if parser.GetNodeText(child, content) == sel.Name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
} else if parser.GetNodeText(nameNode, content) != sel.Name {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check line
|
||||
if sel.AtLine > 0 {
|
||||
startLine := int(n.StartPoint().Row) + 1
|
||||
endLine := int(n.EndPoint().Row) + 1
|
||||
if sel.AtLine < startLine || sel.AtLine > endLine {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern matching is handled separately (simplified here)
|
||||
if sel.Pattern != "" {
|
||||
nodeText := parser.GetNodeText(n, content)
|
||||
if !strings.Contains(nodeText, sel.Pattern) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// applyEdit applies the edit operation to the content.
|
||||
func (e *Engine) applyEdit(edit *ASTEdit, node *sitter.Node, content []byte) ([]byte, error) {
|
||||
startByte := node.StartByte()
|
||||
endByte := node.EndByte()
|
||||
|
||||
// Detect and preserve indentation
|
||||
indentation := detectIndentation(content, startByte)
|
||||
newContent := indentContent(edit.NewContent, indentation)
|
||||
|
||||
var result []byte
|
||||
|
||||
switch edit.Operation {
|
||||
case EditReplace:
|
||||
result = append(result, content[:startByte]...)
|
||||
result = append(result, []byte(newContent)...)
|
||||
result = append(result, content[endByte:]...)
|
||||
|
||||
case EditInsertBefore:
|
||||
result = append(result, content[:startByte]...)
|
||||
result = append(result, []byte(newContent)...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, content[startByte:]...)
|
||||
|
||||
case EditInsertAfter:
|
||||
result = append(result, content[:endByte]...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, []byte(newContent)...)
|
||||
result = append(result, content[endByte:]...)
|
||||
|
||||
case EditDelete:
|
||||
result = append(result, content[:startByte]...)
|
||||
result = append(result, content[endByte:]...)
|
||||
|
||||
default:
|
||||
return nil, errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", edit.Operation))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// detectIndentation detects the indentation at a given byte position.
|
||||
func detectIndentation(content []byte, bytePos uint32) string {
|
||||
// Find the start of the line
|
||||
lineStart := int(bytePos)
|
||||
for lineStart > 0 && content[lineStart-1] != '\n' {
|
||||
lineStart--
|
||||
}
|
||||
|
||||
// Extract leading whitespace
|
||||
var indent strings.Builder
|
||||
for i := lineStart; i < int(bytePos) && i < len(content); i++ {
|
||||
c := content[i]
|
||||
if c == ' ' || c == '\t' {
|
||||
indent.WriteByte(c)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return indent.String()
|
||||
}
|
||||
|
||||
// indentContent applies indentation to multi-line content.
|
||||
func indentContent(content string, indent string) string {
|
||||
if indent == "" {
|
||||
return content
|
||||
}
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
for i, line := range lines {
|
||||
if i > 0 && line != "" {
|
||||
lines[i] = indent + line
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// generateDiff creates a unified diff between original and modified content.
|
||||
// Uses Myers diff algorithm for accurate and readable diffs.
|
||||
func generateDiff(original, modified, filename string) string {
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(original, modified, false)
|
||||
|
||||
// Cleanup for readability
|
||||
diffs = dmp.DiffCleanupSemantic(diffs)
|
||||
|
||||
// Convert to unified diff format
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(fmt.Sprintf("--- %s\n", filename))
|
||||
buf.WriteString(fmt.Sprintf("+++ %s\n", filename))
|
||||
|
||||
// Group diffs into hunks
|
||||
lineNum := 1
|
||||
for _, diff := range diffs {
|
||||
lines := strings.Split(diff.Text, "\n")
|
||||
for i, line := range lines {
|
||||
// Skip empty last line from split
|
||||
if i == len(lines)-1 && line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
switch diff.Type {
|
||||
case diffmatchpatch.DiffDelete:
|
||||
buf.WriteString(fmt.Sprintf("-%s\n", line))
|
||||
case diffmatchpatch.DiffInsert:
|
||||
buf.WriteString(fmt.Sprintf("+%s\n", line))
|
||||
case diffmatchpatch.DiffEqual:
|
||||
buf.WriteString(fmt.Sprintf(" %s\n", line))
|
||||
lineNum++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// resolveTextSelector finds the byte range for a text-based selection.
|
||||
func (e *Engine) resolveTextSelector(sel ASTSelector, content []byte) (start, end int, err error) {
|
||||
switch {
|
||||
case sel.Text != "":
|
||||
return e.findExactText(content, sel.Text, sel.Index)
|
||||
case sel.TextPattern != "":
|
||||
return e.findRegexPattern(content, sel.TextPattern, sel.Index)
|
||||
case sel.AtLine > 0:
|
||||
return e.findLineRange(content, sel.AtLine, sel.LineEnd)
|
||||
default:
|
||||
return 0, 0, errors.NewInvalidEditError("text selector requires text, text_pattern, or at_line")
|
||||
}
|
||||
}
|
||||
|
||||
// findExactText finds an exact text match in content.
|
||||
func (e *Engine) findExactText(content []byte, text string, index int) (start, end int, err error) {
|
||||
if text == "" {
|
||||
return 0, 0, errors.NewInvalidEditError("text selector cannot be empty")
|
||||
}
|
||||
|
||||
textBytes := []byte(text)
|
||||
type match struct{ start, end int }
|
||||
var matches []match
|
||||
|
||||
offset := 0
|
||||
for {
|
||||
idx := bytes.Index(content[offset:], textBytes)
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
matches = append(matches, match{
|
||||
start: offset + idx,
|
||||
end: offset + idx + len(textBytes),
|
||||
})
|
||||
offset += idx + 1
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text not found: %q", truncateString(text, 50)))
|
||||
}
|
||||
|
||||
// If multiple matches and no index specified, require explicit selection
|
||||
if len(matches) > 1 && index == 0 {
|
||||
// Check if index was explicitly set to 0 or just defaulted
|
||||
// Since we can't distinguish, we'll allow index 0 but warn about multiple matches
|
||||
// Actually, let's be strict and require explicit index for multiple matches
|
||||
locations := make([]string, 0, min(len(matches), 5))
|
||||
for i, m := range matches {
|
||||
if i >= 5 {
|
||||
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
|
||||
break
|
||||
}
|
||||
line := countLines(content[:m.start]) + 1
|
||||
locations = append(locations, fmt.Sprintf("line %d", line))
|
||||
}
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("text matches %d locations (%s); use selector_index to specify which one (0-%d)",
|
||||
len(matches), strings.Join(locations, ", "), len(matches)-1))
|
||||
}
|
||||
|
||||
if index >= len(matches) {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
|
||||
}
|
||||
|
||||
return matches[index].start, matches[index].end, nil
|
||||
}
|
||||
|
||||
// findRegexPattern finds a regex pattern match in content.
|
||||
func (e *Engine) findRegexPattern(content []byte, pattern string, index int) (start, end int, err error) {
|
||||
re, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
return 0, 0, errors.Wrap(errors.ErrInvalidEdit, "invalid regex pattern", err)
|
||||
}
|
||||
|
||||
matches := re.FindAllIndex(content, -1)
|
||||
if len(matches) == 0 {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern not found: %q", truncateString(pattern, 50)))
|
||||
}
|
||||
|
||||
// If multiple matches and index is 0 (default), show error with locations
|
||||
if len(matches) > 1 && index == 0 {
|
||||
locations := make([]string, 0, min(len(matches), 5))
|
||||
for i, m := range matches {
|
||||
if i >= 5 {
|
||||
locations = append(locations, fmt.Sprintf("... and %d more", len(matches)-5))
|
||||
break
|
||||
}
|
||||
line := countLines(content[:m[0]]) + 1
|
||||
locations = append(locations, fmt.Sprintf("line %d", line))
|
||||
}
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("pattern matches %d locations (%s); use selector_index to specify which one (0-%d)",
|
||||
len(matches), strings.Join(locations, ", "), len(matches)-1))
|
||||
}
|
||||
|
||||
if index >= len(matches) {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("selector_index %d out of range (found %d matches)", index, len(matches)))
|
||||
}
|
||||
|
||||
return matches[index][0], matches[index][1], nil
|
||||
}
|
||||
|
||||
// findLineRange finds the byte range for a line range selection.
|
||||
func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, end int, err error) {
|
||||
if lineEnd == 0 {
|
||||
lineEnd = lineStart
|
||||
}
|
||||
|
||||
if lineStart < 1 {
|
||||
return 0, 0, errors.NewInvalidEditError(fmt.Sprintf("line number must be >= 1, got %d", lineStart))
|
||||
}
|
||||
|
||||
if lineEnd < lineStart {
|
||||
return 0, 0, errors.NewInvalidEditError(fmt.Sprintf("line_end (%d) must be >= line (%d)", lineEnd, lineStart))
|
||||
}
|
||||
|
||||
lines := bytes.Split(content, []byte("\n"))
|
||||
totalLines := len(lines)
|
||||
|
||||
// Convert to 0-indexed
|
||||
startIdx := lineStart - 1
|
||||
endIdx := lineEnd - 1
|
||||
|
||||
if startIdx >= totalLines {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("line %d out of range (file has %d lines)", lineStart, totalLines))
|
||||
}
|
||||
if endIdx >= totalLines {
|
||||
return 0, 0, errors.NewInvalidSelectionError(fmt.Sprintf("line_end %d out of range (file has %d lines)", lineEnd, totalLines))
|
||||
}
|
||||
|
||||
// Calculate byte positions
|
||||
start = 0
|
||||
for i := 0; i < startIdx; i++ {
|
||||
start += len(lines[i]) + 1 // +1 for newline
|
||||
}
|
||||
|
||||
end = start
|
||||
for i := startIdx; i <= endIdx; i++ {
|
||||
end += len(lines[i])
|
||||
if i < totalLines-1 {
|
||||
end += 1 // newline
|
||||
}
|
||||
}
|
||||
|
||||
return start, end, nil
|
||||
}
|
||||
|
||||
// applyTextEditOperation applies a text edit operation.
|
||||
func (e *Engine) applyTextEditOperation(op EditOperation, content []byte, start, end int, newContent string) ([]byte, error) {
|
||||
// Detect indentation at the selection point
|
||||
indentation := detectIndentationAtByte(content, start)
|
||||
indentedContent := indentContent(newContent, indentation)
|
||||
|
||||
var result []byte
|
||||
|
||||
switch op {
|
||||
case EditReplace:
|
||||
result = append(result, content[:start]...)
|
||||
result = append(result, []byte(indentedContent)...)
|
||||
result = append(result, content[end:]...)
|
||||
|
||||
case EditInsertBefore:
|
||||
result = append(result, content[:start]...)
|
||||
result = append(result, []byte(indentedContent)...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, content[start:]...)
|
||||
|
||||
case EditInsertAfter:
|
||||
result = append(result, content[:end]...)
|
||||
result = append(result, '\n')
|
||||
result = append(result, []byte(indentedContent)...)
|
||||
result = append(result, content[end:]...)
|
||||
|
||||
case EditDelete:
|
||||
result = append(result, content[:start]...)
|
||||
result = append(result, content[end:]...)
|
||||
|
||||
default:
|
||||
return nil, errors.NewInvalidEditError(fmt.Sprintf("unknown operation: %s", op))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// detectIndentationAtByte detects indentation at a byte position.
|
||||
func detectIndentationAtByte(content []byte, bytePos int) string {
|
||||
// Find the start of the line
|
||||
lineStart := bytePos
|
||||
for lineStart > 0 && content[lineStart-1] != '\n' {
|
||||
lineStart--
|
||||
}
|
||||
|
||||
// Extract leading whitespace
|
||||
var indent strings.Builder
|
||||
for i := lineStart; i < bytePos && i < len(content); i++ {
|
||||
c := content[i]
|
||||
if c == ' ' || c == '\t' {
|
||||
indent.WriteByte(c)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return indent.String()
|
||||
}
|
||||
|
||||
// truncateString truncates a string to maxLen with ellipsis.
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
// countLines counts the number of newlines in content.
|
||||
func countLines(content []byte) int {
|
||||
return bytes.Count(content, []byte("\n"))
|
||||
}
|
||||
|
||||
// ValidateLanguage checks if AST editing is supported for a file.
|
||||
// Returns nil for supported languages, error for unsupported.
|
||||
// Note: Text-based editing is always available regardless of this check.
|
||||
func ValidateLanguage(filename string) error {
|
||||
lang := protocol.DetectLanguage(filename)
|
||||
if lang == protocol.LangUnknown {
|
||||
return fmt.Errorf("unsupported file type for AST editing: %s (text-based editing is available)", filename)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,836 @@
|
||||
package edit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
)
|
||||
|
||||
func TestValidateEdit(t *testing.T) {
|
||||
e := NewEngine(parser.NewRegistry())
|
||||
|
||||
tests := []struct {
|
||||
edit *ASTEdit
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid replace",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid delete",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: EditDelete,
|
||||
Selector: ASTSelector{Name: "oldFunc"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing file",
|
||||
edit: &ASTEdit{
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing operation",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "replace without content",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown operation",
|
||||
edit: &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: "unknown",
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewFunc() {}",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := e.validateASTEdit(tt.edit)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSelector(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
content := []byte(`package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
|
||||
func Goodbye() {
|
||||
println("goodbye")
|
||||
}
|
||||
`)
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := registry.Parse(ctx, "test.go", content)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sel ASTSelector
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "by kind",
|
||||
sel: ASTSelector{Kind: "function_declaration"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "by name",
|
||||
sel: ASTSelector{Name: "Hello"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "by kind and name",
|
||||
sel: ASTSelector{Kind: "function_declaration", Name: "Goodbye"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "by line",
|
||||
sel: ASTSelector{AtLine: 3},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
sel: ASTSelector{Name: "NonExistent"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "index out of range",
|
||||
sel: ASTSelector{Kind: "function_declaration", Index: 10},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
node, err := e.resolveSelector(tt.sel, result.Tree, content)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if node == nil {
|
||||
t.Error("expected node")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEdit(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
content := []byte(`package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
`)
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := registry.Parse(ctx, "test.go", content)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
operation EditOperation
|
||||
newCode string
|
||||
wantIn string // substring that should be in result
|
||||
}{
|
||||
{
|
||||
name: "replace",
|
||||
operation: EditReplace,
|
||||
newCode: "func NewHello() {}",
|
||||
wantIn: "NewHello",
|
||||
},
|
||||
{
|
||||
name: "insert after",
|
||||
operation: EditInsertAfter,
|
||||
newCode: "func After() {}",
|
||||
wantIn: "After",
|
||||
},
|
||||
{
|
||||
name: "insert before",
|
||||
operation: EditInsertBefore,
|
||||
newCode: "func Before() {}",
|
||||
wantIn: "Before",
|
||||
},
|
||||
{
|
||||
name: "delete",
|
||||
operation: EditDelete,
|
||||
newCode: "",
|
||||
wantIn: "package main", // Should still have package declaration
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Find the function node
|
||||
node, err := e.resolveSelector(ASTSelector{Kind: "function_declaration"}, result.Tree, content)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve failed: %v", err)
|
||||
}
|
||||
|
||||
edit := &ASTEdit{
|
||||
File: "test.go",
|
||||
Operation: tt.operation,
|
||||
NewContent: tt.newCode,
|
||||
}
|
||||
|
||||
newContent, err := e.applyEdit(edit, node, content)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(newContent), tt.wantIn) {
|
||||
t.Errorf("result does not contain %q:\n%s", tt.wantIn, string(newContent))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreview(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewHello() {\n\tprintln(\"new hello\")\n}",
|
||||
}
|
||||
|
||||
result, err := e.Preview(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("preview failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("preview was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
if result.Applied {
|
||||
t.Error("preview should not apply changes")
|
||||
}
|
||||
|
||||
if result.Diff == "" {
|
||||
t.Error("expected diff in result")
|
||||
}
|
||||
|
||||
// Verify original file is unchanged
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if string(fileContent) != content {
|
||||
t.Error("original file was modified during preview")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyToFile(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.go")
|
||||
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Kind: "function_declaration"},
|
||||
NewContent: "func NewHello() {\n\tprintln(\"new hello\")\n}",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
if !result.Applied {
|
||||
t.Error("apply should set Applied=true")
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), "NewHello") {
|
||||
t.Error("file was not modified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectIndentation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
pos uint32
|
||||
}{
|
||||
{
|
||||
name: "no indent",
|
||||
content: "func main() {}",
|
||||
pos: 0,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "tab indent",
|
||||
content: "func main() {\n\tprintln(\"hello\")\n}",
|
||||
pos: 15,
|
||||
want: "\t",
|
||||
},
|
||||
{
|
||||
name: "space indent",
|
||||
content: "func main() {\n println(\"hello\")\n}",
|
||||
pos: 18,
|
||||
want: " ",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := detectIndentation([]byte(tt.content), tt.pos)
|
||||
if got != tt.want {
|
||||
t.Errorf("detectIndentation() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDiff(t *testing.T) {
|
||||
original := "line1\nline2\nline3"
|
||||
modified := "line1\nmodified\nline3"
|
||||
filename := "test.txt"
|
||||
|
||||
diff := generateDiff(original, modified, filename)
|
||||
|
||||
if !strings.Contains(diff, "---") {
|
||||
t.Error("diff should contain --- header")
|
||||
}
|
||||
if !strings.Contains(diff, "+++") {
|
||||
t.Error("diff should contain +++ header")
|
||||
}
|
||||
if !strings.Contains(diff, "-line2") {
|
||||
t.Error("diff should show removed line")
|
||||
}
|
||||
if !strings.Contains(diff, "+modified") {
|
||||
t.Error("diff should show added line")
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Text-based editing tests ====================
|
||||
|
||||
func TestTextEditWithExactText(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp markdown file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "README.md")
|
||||
|
||||
content := `# My Project
|
||||
|
||||
## Installation
|
||||
|
||||
Run the following command:
|
||||
|
||||
## Usage
|
||||
|
||||
See the docs.
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Text: "## Installation"},
|
||||
NewContent: "## Getting Started",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), "## Getting Started") {
|
||||
t.Error("file was not modified correctly")
|
||||
}
|
||||
if strings.Contains(string(fileContent), "## Installation") {
|
||||
t.Error("old text should be replaced")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditWithLineRange(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp config file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
content := `name: myapp
|
||||
version: 1.0.0
|
||||
database:
|
||||
host: localhost
|
||||
port: 5432
|
||||
logging:
|
||||
level: debug
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
AtLine: 3,
|
||||
LineEnd: 5,
|
||||
},
|
||||
NewContent: "database:\n host: production.db.example.com\n port: 5433",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), "production.db.example.com") {
|
||||
t.Error("file was not modified correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditWithRegexPattern(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp JSON file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "package.json")
|
||||
|
||||
content := `{
|
||||
"name": "my-package",
|
||||
"version": "1.0.0",
|
||||
"description": "A test package"
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{TextPattern: `"version":\s*"[^"]+"`},
|
||||
NewContent: `"version": "2.0.0"`,
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), `"version": "2.0.0"`) {
|
||||
t.Error("file was not modified correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditInsertAfter(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp env file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
content := `DATABASE_URL=postgres://localhost/mydb
|
||||
SECRET_KEY=abc123
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditInsertAfter,
|
||||
Selector: ASTSelector{Text: "DATABASE_URL=postgres://localhost/mydb"},
|
||||
NewContent: "REDIS_URL=redis://localhost:6379",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify file was modified
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
if !strings.Contains(string(fileContent), "REDIS_URL=redis://localhost:6379") {
|
||||
t.Error("file was not modified correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditMultipleMatchesError(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp file with repeated text
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
content := `TODO: fix this
|
||||
some code here
|
||||
TODO: also fix this
|
||||
more code
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Text: "TODO"},
|
||||
NewContent: "DONE",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
// Should fail because of multiple matches
|
||||
if result.Success {
|
||||
t.Error("expected error for multiple matches without index")
|
||||
}
|
||||
if !strings.Contains(result.Error, "matches") {
|
||||
t.Errorf("error should mention multiple matches: %s", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextEditWithIndex(t *testing.T) {
|
||||
registry := parser.NewRegistry()
|
||||
defer registry.Close()
|
||||
e := NewEngine(registry)
|
||||
|
||||
// Create a temp file with repeated text
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
content := `TODO: fix this
|
||||
some code here
|
||||
TODO: also fix this
|
||||
more code
|
||||
`
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
edit := &ASTEdit{
|
||||
File: tmpFile,
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{
|
||||
Text: "TODO",
|
||||
Index: 1, // Select second match
|
||||
},
|
||||
NewContent: "DONE",
|
||||
}
|
||||
|
||||
result, err := e.Apply(ctx, edit)
|
||||
if err != nil {
|
||||
t.Fatalf("apply failed: %v", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
t.Fatalf("apply was not successful: %s", result.Error)
|
||||
}
|
||||
|
||||
// Verify only second TODO was replaced
|
||||
fileContent, _ := os.ReadFile(tmpFile)
|
||||
contentStr := string(fileContent)
|
||||
if !strings.Contains(contentStr, "TODO: fix this") {
|
||||
t.Error("first TODO should not be replaced")
|
||||
}
|
||||
if !strings.Contains(contentStr, "DONE: also fix this") {
|
||||
t.Error("second TODO should be replaced")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTextEdit(t *testing.T) {
|
||||
e := NewEngine(parser.NewRegistry())
|
||||
|
||||
tests := []struct {
|
||||
edit *ASTEdit
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid text selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{Text: "some text"},
|
||||
NewContent: "new text",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid pattern selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{TextPattern: "\\d+"},
|
||||
NewContent: "replaced",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid line selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{AtLine: 5},
|
||||
NewContent: "new line",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty selector",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{},
|
||||
NewContent: "new text",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid regex pattern",
|
||||
edit: &ASTEdit{
|
||||
File: "test.md",
|
||||
Operation: EditReplace,
|
||||
Selector: ASTSelector{TextPattern: "[invalid"},
|
||||
NewContent: "new text",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := e.validateTextEdit(tt.edit)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindLineRange(t *testing.T) {
|
||||
e := NewEngine(parser.NewRegistry())
|
||||
|
||||
content := []byte("line1\nline2\nline3\nline4\nline5")
|
||||
|
||||
// Content: "line1\nline2\nline3\nline4\nline5" (no trailing newline)
|
||||
// Positions: line1=0-5, \n=5, line2=6-10, \n=11, line3=12-16, \n=17, line4=18-22, \n=23, line5=24-28
|
||||
tests := []struct {
|
||||
name string
|
||||
lineStart int
|
||||
lineEnd int
|
||||
wantStart int
|
||||
wantEnd int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "single line",
|
||||
lineStart: 2,
|
||||
lineEnd: 0, // defaults to lineStart
|
||||
wantStart: 6,
|
||||
wantEnd: 12, // includes trailing newline
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "range of lines",
|
||||
lineStart: 2,
|
||||
lineEnd: 4,
|
||||
wantStart: 6,
|
||||
wantEnd: 24, // through end of line4 including newline
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "first line",
|
||||
lineStart: 1,
|
||||
lineEnd: 1,
|
||||
wantStart: 0,
|
||||
wantEnd: 6, // includes trailing newline
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "line out of range",
|
||||
lineStart: 10,
|
||||
lineEnd: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid line number",
|
||||
lineStart: 0,
|
||||
lineEnd: 1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "end before start",
|
||||
lineStart: 3,
|
||||
lineEnd: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
start, end, err := e.findLineRange(content, tt.lineStart, tt.lineEnd)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if start != tt.wantStart {
|
||||
t.Errorf("start = %d, want %d", start, tt.wantStart)
|
||||
}
|
||||
if end != tt.wantEnd {
|
||||
t.Errorf("end = %d, want %d", end, tt.wantEnd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user