Files
filepuff-mcp/internal/query/query_test.go
T
2026-01-18 18:40:26 +00:00

560 lines
11 KiB
Go

package query
import (
"context"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// TestParsePattern verifies that ParsePattern rejects an empty pattern and
// extracts the expected named captures, with their capture types, from
// single ($NAME), multi-node ($$$NAME) and wildcard ($_) metavariables.
func TestParsePattern(t *testing.T) {
	tests := []struct {
		name         string
		pattern      string
		captureNames []string      // expected capture names (order-insensitive)
		captureTypes []CaptureType // parallel to captureNames
		wantCaptures int
		wantErr      bool
	}{
		{
			name:         "empty pattern",
			pattern:      "",
			wantErr:      true,
			wantCaptures: 0,
		},
		{
			name:         "single capture",
			pattern:      "func $NAME() {}",
			wantErr:      false,
			wantCaptures: 1,
			captureNames: []string{"NAME"},
			captureTypes: []CaptureType{CaptureSingle},
		},
		{
			name:         "multiple single captures",
			pattern:      "func $NAME($ARGS) $RETURN",
			wantErr:      false,
			wantCaptures: 3,
			captureNames: []string{"NAME", "ARGS", "RETURN"},
			captureTypes: []CaptureType{CaptureSingle, CaptureSingle, CaptureSingle},
		},
		{
			name:         "multi-node capture",
			pattern:      "func $NAME($$$ARGS) { $$$BODY }",
			wantErr:      false,
			wantCaptures: 3,
			captureNames: []string{"ARGS", "BODY", "NAME"},
			captureTypes: []CaptureType{CaptureMultiple, CaptureMultiple, CaptureSingle},
		},
		{
			name:         "wildcard capture",
			pattern:      "func $NAME($_) {}",
			wantErr:      false,
			wantCaptures: 2,
			captureNames: []string{"NAME", "_"},
			captureTypes: []CaptureType{CaptureSingle, CaptureWildcard},
		},
		{
			name:         "no captures",
			pattern:      "func main() {}",
			wantErr:      false,
			wantCaptures: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// captureNames and captureTypes are parallel slices; a mismatch is a
			// bug in the table itself, so fail clearly instead of panicking on an
			// out-of-range index further down.
			if len(tt.captureNames) != len(tt.captureTypes) {
				t.Fatalf("bad test case: %d capture names but %d capture types",
					len(tt.captureNames), len(tt.captureTypes))
			}
			parsed, err := ParsePattern(tt.pattern)
			if tt.wantErr {
				if err == nil {
					t.Error("expected error")
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if len(parsed.Captures) != tt.wantCaptures {
				t.Errorf("expected %d captures, got %d", tt.wantCaptures, len(parsed.Captures))
			}
			// Captures are compared by name because their order may vary.
			captureMap := make(map[string]CaptureType, len(parsed.Captures))
			for _, c := range parsed.Captures {
				captureMap[c.Name] = c.Type
			}
			for i, name := range tt.captureNames {
				got, ok := captureMap[name]
				if !ok {
					// Skip the type check: comparing against the zero value
					// would only add a misleading second error.
					t.Errorf("expected capture %s not found", name)
					continue
				}
				if got != tt.captureTypes[i] {
					t.Errorf("capture %s: expected type %v, got %v", name, tt.captureTypes[i], got)
				}
			}
		})
	}
}
// TestMatchGoFunctions runs function-shaped patterns, optionally narrowed by
// name filters, against a small Go file that declares two plain functions, a
// struct type and one method, and checks how many matches come back.
func TestMatchGoFunctions(t *testing.T) {
	registry := parser.NewRegistry()
	defer registry.Close()
	m := NewMatcher(registry)

	const content = `package main
func Hello() {
println("hello")
}
func Greet(name string) error {
println("hello", name)
return nil
}
type Server struct {
Port int
}
func (s *Server) Start() error {
return nil
}
`
	ctx := context.Background()
	parsed, err := registry.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	cases := []struct {
		name        string
		query       *ASTQuery
		wantMatches int
	}{
		{
			name:        "match all functions",
			query:       &ASTQuery{Pattern: "func $NAME($$$ARGS) { $$$BODY }", Language: "go"},
			wantMatches: 3, // Hello, Greet, Start
		},
		{
			name: "match functions starting with H",
			query: &ASTQuery{
				Pattern:  "func $NAME() {}",
				Language: "go",
				Filters:  QueryFilters{NameMatches: "^H"},
			},
			wantMatches: 1, // Hello
		},
		{
			name: "match specific function",
			query: &ASTQuery{
				Pattern:  "func $NAME() {}",
				Language: "go",
				Filters:  QueryFilters{NameExact: "Hello"},
			},
			wantMatches: 1, // Hello
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := m.Match(ctx, tc.query, parsed.Tree, []byte(content), "test.go")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}
			if n := len(got); n != tc.wantMatches {
				t.Errorf("expected %d matches, got %d", tc.wantMatches, n)
				// Dump what did match to make the mismatch easy to diagnose.
				for i := range got {
					t.Logf("match %d: %s at line %d", i, got[i].Node.Type(), got[i].Location.Line)
				}
			}
		})
	}
}
// TestMatchGoStructs checks that a struct-declaration pattern finds at least
// the two struct types in a Go file that also declares an interface.
func TestMatchGoStructs(t *testing.T) {
	registry := parser.NewRegistry()
	defer registry.Close()
	m := NewMatcher(registry)

	const content = `package main
type Server struct {
Port int
Host string
}
type Config struct {
Timeout int
}
type Logger interface {
Log(msg string)
}
`
	ctx := context.Background()
	parsed, err := registry.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	cases := []struct {
		name    string
		query   *ASTQuery
		atLeast int
	}{
		{
			name:  "match all structs",
			query: &ASTQuery{Pattern: "type $NAME struct { $$$FIELDS }", Language: "go"},
			// Server and Config. The interface may also match as a
			// type_declaration, so this is a lower bound, not an exact count.
			atLeast: 2,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := m.Match(ctx, tc.query, parsed.Tree, []byte(content), "test.go")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}
			if n := len(got); n < tc.atLeast {
				t.Errorf("expected at least %d matches, got %d", tc.atLeast, n)
			}
		})
	}
}
// TestMatchJSFunctions checks function and class patterns against a small
// JavaScript file with two top-level function declarations and one class.
// The class's constructor and method are not expected to match the
// `function` pattern.
func TestMatchJSFunctions(t *testing.T) {
	registry := parser.NewRegistry()
	defer registry.Close()
	m := NewMatcher(registry)

	const content = `
function greet(name) {
console.log("Hello, " + name);
}
function sayHello() {
console.log("Hello!");
}
class User {
constructor(name) {
this.name = name;
}
getName() {
return this.name;
}
}
`
	ctx := context.Background()
	parsed, err := registry.Parse(ctx, "test.js", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	cases := []struct {
		name        string
		query       *ASTQuery
		wantMatches int
	}{
		{
			name:        "match all functions",
			query:       &ASTQuery{Pattern: "function $NAME($$$ARGS) { $$$BODY }", Language: "javascript"},
			wantMatches: 2, // greet, sayHello
		},
		{
			name:        "match classes",
			query:       &ASTQuery{Pattern: "class $NAME { $$$BODY }", Language: "javascript"},
			wantMatches: 1, // User
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := m.Match(ctx, tc.query, parsed.Tree, []byte(content), "test.js")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}
			if n := len(got); n != tc.wantMatches {
				t.Errorf("expected %d matches, got %d", tc.wantMatches, n)
			}
		})
	}
}
// TestMatchPythonSymbols checks that a class pattern finds the User class in a
// small Python module containing two functions and one class with methods.
func TestMatchPythonSymbols(t *testing.T) {
	reg := parser.NewRegistry()
	defer reg.Close()
	matcher := NewMatcher(reg)
	// The fixture must be syntactically valid Python. Python blocks are
	// delimited by indentation, so unindented bodies would most likely leave
	// the parser producing error nodes. The test would then no longer
	// exercise matching against a well-formed class definition.
	content := `
def greet(name):
    print(f"Hello, {name}")


def calculate(a, b):
    return a + b


class User:
    def __init__(self, name):
        self.name = name

    def get_name(self):
        return self.name
`
	ctx := context.Background()
	result, err := reg.Parse(ctx, "test.py", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	tests := []struct {
		query       *ASTQuery
		name        string
		wantMinimum int
	}{
		{
			name: "match classes",
			query: &ASTQuery{
				Pattern:  "class $NAME: $$$BODY",
				Language: "python",
			},
			wantMinimum: 1, // User
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.py")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}
			if len(results) < tt.wantMinimum {
				t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
			}
		})
	}
}
// TestQueryFilters applies the regex-name, exact-name and node-kind filters to
// a single zero-argument function pattern over four Go functions. It checks
// how many matches survive each filter.
func TestQueryFilters(t *testing.T) {
	registry := parser.NewRegistry()
	defer registry.Close()
	m := NewMatcher(registry)

	const content = `package main
func HelloWorld() {}
func helloWorld() {}
func GoodbyeWorld() {}
func Main() {}
`
	ctx := context.Background()
	parsed, err := registry.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	cases := []struct {
		name        string
		filters     QueryFilters
		wantMatches int
	}{
		{
			name:        "regex filter - starts with H",
			filters:     QueryFilters{NameMatches: "^[Hh]ello"},
			wantMatches: 2, // HelloWorld, helloWorld
		},
		{
			name:        "exact name filter",
			filters:     QueryFilters{NameExact: "Main"},
			wantMatches: 1, // Main
		},
		{
			name:        "kind filter",
			filters:     QueryFilters{KindIn: []string{"function_declaration"}},
			wantMatches: 4, // every function in the fixture
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			q := &ASTQuery{Pattern: "func $NAME() {}", Language: "go", Filters: tc.filters}
			got, err := m.Match(ctx, q, parsed.Tree, []byte(content), "test.go")
			if err != nil {
				t.Fatalf("match failed: %v", err)
			}
			if n := len(got); n != tc.wantMatches {
				t.Errorf("expected %d matches, got %d", tc.wantMatches, n)
				// Log the "name" field of every match to aid diagnosis.
				for i := range got {
					nameNode := got[i].Node.ChildByFieldName("name")
					if nameNode == nil {
						continue
					}
					t.Logf("matched: %s", parser.GetNodeText(nameNode, []byte(content)))
				}
			}
		})
	}
}
// TestFormatResults checks that FormatResults returns the "No matches found."
// sentinel exactly when the result list is empty. It expects other output for
// one result, and for more results than maxResults.
func TestFormatResults(t *testing.T) {
	const noMatches = "No matches found."

	cases := []struct {
		name       string
		results    []MatchResult
		maxResults int
		wantEmpty  bool
	}{
		{
			name:       "empty results",
			results:    []MatchResult{},
			maxResults: 100,
			wantEmpty:  true,
		},
		{
			name: "single result",
			results: []MatchResult{{
				File:     "test.go",
				Location: protocol.Location{Line: 10, Column: 1},
				Text:     "func Hello() {}",
				Captures: map[string]CapturedNode{"NAME": {Text: "Hello"}},
			}},
			maxResults: 100,
		},
		{
			name: "truncated results",
			results: []MatchResult{
				{File: "a.go", Location: protocol.Location{Line: 1}, Text: "func A() {}"},
				{File: "b.go", Location: protocol.Location{Line: 1}, Text: "func B() {}"},
				{File: "c.go", Location: protocol.Location{Line: 1}, Text: "func C() {}"},
			},
			maxResults: 2,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := FormatResults(tc.results, tc.maxResults)
			switch {
			case tc.wantEmpty && got != noMatches:
				t.Errorf("expected 'No matches found.', got: %s", got)
			case !tc.wantEmpty && got == noMatches:
				t.Error("expected results, got 'No matches found.'")
			}
		})
	}
}
// TestQueryValidation checks that Match returns an error for an empty pattern,
// a missing language or an unknown language. It also checks that Match
// accepts a well-formed Go query.
func TestQueryValidation(t *testing.T) {
	registry := parser.NewRegistry()
	defer registry.Close()
	m := NewMatcher(registry)
	ctx := context.Background()

	// A minimal, valid Go file gives Match a real tree to work with.
	const content = `package main
func main() {}
`
	parsed, err := registry.Parse(ctx, "test.go", []byte(content))
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	cases := []struct {
		name    string
		query   *ASTQuery
		wantErr bool
	}{
		{name: "empty pattern", query: &ASTQuery{Pattern: "", Language: "go"}, wantErr: true},
		{name: "missing language", query: &ASTQuery{Pattern: "func $NAME() {}", Language: ""}, wantErr: true},
		{name: "unknown language", query: &ASTQuery{Pattern: "func $NAME() {}", Language: "unknown"}, wantErr: true},
		{name: "valid query", query: &ASTQuery{Pattern: "func $NAME() {}", Language: "go"}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := m.Match(ctx, tc.query, parsed.Tree, []byte(content), "test.go")
			switch {
			case tc.wantErr && err == nil:
				t.Error("expected error")
			case !tc.wantErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}