mirror of
https://github.com/lukaszraczylo/filepuff-mcp.git
synced 2026-06-05 22:23:50 +00:00
560 lines
11 KiB
Go
560 lines
11 KiB
Go
package query
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
|
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
|
)
|
|
|
|
// TestParsePattern checks that ParsePattern rejects an empty pattern and
// extracts the expected captures (name and CaptureType) from patterns using
// single ($NAME), multi-node ($$$NAME) and wildcard ($_) placeholders.
func TestParsePattern(t *testing.T) {
	tests := []struct {
		name         string
		pattern      string
		captureNames []string
		captureTypes []CaptureType // parallel to captureNames (same index)
		wantCaptures int
		wantErr      bool
	}{
		{
			name:         "empty pattern",
			pattern:      "",
			wantErr:      true,
			wantCaptures: 0,
		},
		{
			name:         "single capture",
			pattern:      "func $NAME() {}",
			wantErr:      false,
			wantCaptures: 1,
			captureNames: []string{"NAME"},
			captureTypes: []CaptureType{CaptureSingle},
		},
		{
			name:         "multiple single captures",
			pattern:      "func $NAME($ARGS) $RETURN",
			wantErr:      false,
			wantCaptures: 3,
			captureNames: []string{"NAME", "ARGS", "RETURN"},
			captureTypes: []CaptureType{CaptureSingle, CaptureSingle, CaptureSingle},
		},
		{
			name:         "multi-node capture",
			pattern:      "func $NAME($$$ARGS) { $$$BODY }",
			wantErr:      false,
			wantCaptures: 3,
			captureNames: []string{"ARGS", "BODY", "NAME"},
			captureTypes: []CaptureType{CaptureMultiple, CaptureMultiple, CaptureSingle},
		},
		{
			name:         "wildcard capture",
			pattern:      "func $NAME($_) {}",
			wantErr:      false,
			wantCaptures: 2,
			captureNames: []string{"NAME", "_"},
			captureTypes: []CaptureType{CaptureSingle, CaptureWildcard},
		},
		{
			name:         "no captures",
			pattern:      "func main() {}",
			wantErr:      false,
			wantCaptures: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// captureTypes is indexed in lockstep with captureNames below; a
			// mismatched table entry would otherwise panic with an index error.
			if len(tt.captureTypes) != len(tt.captureNames) {
				t.Fatalf("test table: %d capture names but %d capture types",
					len(tt.captureNames), len(tt.captureTypes))
			}

			parsed, err := ParsePattern(tt.pattern)
			if tt.wantErr {
				if err == nil {
					t.Error("expected error")
				}
				return
			}

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if len(parsed.Captures) != tt.wantCaptures {
				t.Errorf("expected %d captures, got %d", tt.wantCaptures, len(parsed.Captures))
			}

			// Capture order is not part of the contract, so compare via a
			// name -> type lookup rather than positionally.
			captureMap := make(map[string]CaptureType, len(parsed.Captures))
			for _, c := range parsed.Captures {
				captureMap[c.Name] = c.Type
			}

			for i, name := range tt.captureNames {
				got, ok := captureMap[name]
				if !ok {
					// Skip the type check: comparing the zero value would only
					// add a misleading second failure for the same capture.
					t.Errorf("expected capture %s not found", name)
					continue
				}
				if got != tt.captureTypes[i] {
					t.Errorf("capture %s: expected type %v, got %v", name, tt.captureTypes[i], got)
				}
			}
		})
	}
}
|
|
|
|
// TestMatchGoFunctions parses a small Go file containing two plain
// functions, a struct type and a method. It then runs function patterns
// against the resulting tree and checks the match count, optionally
// narrowed by name filters (regex or exact name).
func TestMatchGoFunctions(t *testing.T) {
	reg := parser.NewRegistry()
	defer reg.Close()

	matcher := NewMatcher(reg)

	content := `package main

func Hello() {
	println("hello")
}

func Greet(name string) error {
	println("hello", name)
	return nil
}

type Server struct {
	Port int
}

func (s *Server) Start() error {
	return nil
}
`

	ctx := context.Background()
	parsed, err := reg.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	cases := []struct {
		query       *ASTQuery
		name        string
		wantMatches int
	}{
		{
			name: "match all functions",
			query: &ASTQuery{
				Pattern:  "func $NAME($$$ARGS) { $$$BODY }",
				Language: "go",
			},
			wantMatches: 3, // Hello, Greet, Start
		},
		{
			name: "match functions starting with H",
			query: &ASTQuery{
				Pattern:  "func $NAME() {}",
				Language: "go",
				Filters:  QueryFilters{NameMatches: "^H"},
			},
			wantMatches: 1, // Hello
		},
		{
			name: "match specific function",
			query: &ASTQuery{
				Pattern:  "func $NAME() {}",
				Language: "go",
				Filters:  QueryFilters{NameExact: "Hello"},
			},
			wantMatches: 1, // Hello
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			matches, err := matcher.Match(ctx, tc.query, parsed.Tree, []byte(content), "test.go")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}

			got := len(matches)
			if got == tc.wantMatches {
				return
			}

			t.Errorf("expected %d matches, got %d", tc.wantMatches, got)
			// Dump each match (node type and line) so a count mismatch can be
			// diagnosed straight from the test output.
			for idx, m := range matches {
				t.Logf("match %d: %s at line %d", idx, m.Node.Type(), m.Location.Line)
			}
		})
	}
}
|
|
|
|
// TestMatchGoStructs runs a struct-declaration pattern over a Go file that
// declares two structs and one interface. Only a lower bound is asserted:
// per the case comment, the matcher may also report the interface's type
// declaration.
func TestMatchGoStructs(t *testing.T) {
	reg := parser.NewRegistry()
	defer reg.Close()

	matcher := NewMatcher(reg)

	content := `package main

type Server struct {
	Port int
	Host string
}

type Config struct {
	Timeout int
}

type Logger interface {
	Log(msg string)
}
`

	ctx := context.Background()
	result, err := reg.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	tests := []struct {
		query       *ASTQuery
		name        string
		wantMinimum int
	}{
		{
			name: "match all structs",
			query: &ASTQuery{
				Pattern:  "type $NAME struct { $$$FIELDS }",
				Language: "go",
			},
			wantMinimum: 2, // Server, Config (may also match interface as type_declaration)
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}

			if len(results) < tt.wantMinimum {
				t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
				// Log whatever did match so a shortfall can be diagnosed from
				// the test output alone.
				for i, r := range results {
					t.Logf("match %d: %s at line %d", i, r.Node.Type(), r.Location.Line)
				}
			}
		})
	}
}
|
|
|
|
// TestMatchJSFunctions parses a JavaScript file with two function
// declarations and one class (whose constructor and method must not count as
// top-level functions). It checks exact match counts for a
// function-declaration pattern and for a class pattern.
func TestMatchJSFunctions(t *testing.T) {
	reg := parser.NewRegistry()
	defer reg.Close()

	matcher := NewMatcher(reg)

	content := `
function greet(name) {
    console.log("Hello, " + name);
}

function sayHello() {
    console.log("Hello!");
}

class User {
    constructor(name) {
        this.name = name;
    }

    getName() {
        return this.name;
    }
}
`

	ctx := context.Background()
	result, err := reg.Parse(ctx, "test.js", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	tests := []struct {
		query       *ASTQuery
		name        string
		wantMatches int
	}{
		{
			name: "match all functions",
			query: &ASTQuery{
				Pattern:  "function $NAME($$$ARGS) { $$$BODY }",
				Language: "javascript",
			},
			wantMatches: 2, // greet, sayHello
		},
		{
			name: "match classes",
			query: &ASTQuery{
				Pattern:  "class $NAME { $$$BODY }",
				Language: "javascript",
			},
			wantMatches: 1, // User
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.js")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}

			if len(results) != tt.wantMatches {
				t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
				// Log what matched so over- or under-matching (e.g. class
				// methods counted as functions) is visible in the output.
				for i, r := range results {
					t.Logf("match %d: %s at line %d", i, r.Node.Type(), r.Location.Line)
				}
			}
		})
	}
}
|
|
|
|
// TestMatchPythonSymbols parses a Python file with two module-level
// functions and one class, then asserts that a class pattern finds at least
// the one class.
func TestMatchPythonSymbols(t *testing.T) {
	reg := parser.NewRegistry()
	defer reg.Close()

	matcher := NewMatcher(reg)

	// Indentation inside this fixture is significant: it is Python source.
	content := `
def greet(name):
    print(f"Hello, {name}")

def calculate(a, b):
    return a + b

class User:
    def __init__(self, name):
        self.name = name

    def get_name(self):
        return self.name
`

	ctx := context.Background()
	result, err := reg.Parse(ctx, "test.py", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	tests := []struct {
		query       *ASTQuery
		name        string
		wantMinimum int
	}{
		{
			name: "match classes",
			query: &ASTQuery{
				Pattern:  "class $NAME: $$$BODY",
				Language: "python",
			},
			wantMinimum: 1, // User
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.py")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}

			if len(results) < tt.wantMinimum {
				t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
				// Log whatever did match so a shortfall can be diagnosed from
				// the test output alone.
				for i, r := range results {
					t.Logf("match %d: %s at line %d", i, r.Node.Type(), r.Location.Line)
				}
			}
		})
	}
}
|
|
|
|
// TestQueryFilters runs one zero-argument function pattern under different
// QueryFilters (name regex, exact name, node kind). It checks how many of
// the four declared functions survive each filter.
func TestQueryFilters(t *testing.T) {
	reg := parser.NewRegistry()
	defer reg.Close()

	matcher := NewMatcher(reg)

	content := `package main

func HelloWorld() {}
func helloWorld() {}
func GoodbyeWorld() {}
func Main() {}
`

	ctx := context.Background()
	parsed, err := reg.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	cases := []struct {
		name        string
		filters     QueryFilters
		wantMatches int
	}{
		{
			name:        "regex filter - starts with H",
			filters:     QueryFilters{NameMatches: "^[Hh]ello"},
			wantMatches: 2, // HelloWorld, helloWorld
		},
		{
			name:        "exact name filter",
			filters:     QueryFilters{NameExact: "Main"},
			wantMatches: 1, // Main
		},
		{
			name:        "kind filter",
			filters:     QueryFilters{KindIn: []string{"function_declaration"}},
			wantMatches: 4, // all functions
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			q := &ASTQuery{
				Pattern:  "func $NAME() {}",
				Language: "go",
				Filters:  tc.filters,
			}

			matches, err := matcher.Match(ctx, q, parsed.Tree, []byte(content), "test.go")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}

			got := len(matches)
			if got == tc.wantMatches {
				return
			}

			t.Errorf("expected %d matches, got %d", tc.wantMatches, got)
			// Print the name of every matched declaration that has a "name"
			// field, to show which functions the filter let through.
			for _, m := range matches {
				ident := m.Node.ChildByFieldName("name")
				if ident == nil {
					continue
				}
				t.Logf("matched: %s", parser.GetNodeText(ident, []byte(content)))
			}
		})
	}
}
|
|
|
|
func TestFormatResults(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
results []MatchResult
|
|
maxResults int
|
|
wantEmpty bool
|
|
}{
|
|
{
|
|
name: "empty results",
|
|
results: []MatchResult{},
|
|
maxResults: 100,
|
|
wantEmpty: true,
|
|
},
|
|
{
|
|
name: "single result",
|
|
results: []MatchResult{
|
|
{
|
|
File: "test.go",
|
|
Location: protocol.Location{Line: 10, Column: 1},
|
|
Text: "func Hello() {}",
|
|
Captures: map[string]CapturedNode{
|
|
"NAME": {Text: "Hello"},
|
|
},
|
|
},
|
|
},
|
|
maxResults: 100,
|
|
wantEmpty: false,
|
|
},
|
|
{
|
|
name: "truncated results",
|
|
results: []MatchResult{
|
|
{File: "a.go", Location: protocol.Location{Line: 1}, Text: "func A() {}"},
|
|
{File: "b.go", Location: protocol.Location{Line: 1}, Text: "func B() {}"},
|
|
{File: "c.go", Location: protocol.Location{Line: 1}, Text: "func C() {}"},
|
|
},
|
|
maxResults: 2,
|
|
wantEmpty: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
output := FormatResults(tt.results, tt.maxResults)
|
|
|
|
if tt.wantEmpty {
|
|
if output != "No matches found." {
|
|
t.Errorf("expected 'No matches found.', got: %s", output)
|
|
}
|
|
} else {
|
|
if output == "No matches found." {
|
|
t.Error("expected results, got 'No matches found.'")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestQueryValidation checks that Match returns an error for a query with an
// empty pattern, a missing language, or an unrecognised language, and
// succeeds for a well-formed Go query.
func TestQueryValidation(t *testing.T) {
	reg := parser.NewRegistry()
	defer reg.Close()

	matcher := NewMatcher(reg)
	ctx := context.Background()

	// A minimal valid Go file, so Match always has a real tree to work with
	// and any error comes from the query itself.
	content := `package main
func main() {}
`
	parsed, err := reg.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	cases := []struct {
		query   *ASTQuery
		name    string
		wantErr bool
	}{
		{name: "empty pattern", query: &ASTQuery{Pattern: "", Language: "go"}, wantErr: true},
		{name: "missing language", query: &ASTQuery{Pattern: "func $NAME() {}", Language: ""}, wantErr: true},
		{name: "unknown language", query: &ASTQuery{Pattern: "func $NAME() {}", Language: "unknown"}, wantErr: true},
		{name: "valid query", query: &ASTQuery{Pattern: "func $NAME() {}", Language: "go"}, wantErr: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := matcher.Match(ctx, tc.query, parsed.Tree, []byte(content), "test.go")

			switch {
			case tc.wantErr && err == nil:
				t.Error("expected error")
			case !tc.wantErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}
|