This commit is contained in:
2026-01-18 18:40:26 +00:00
commit 185e73da47
51 changed files with 14073 additions and 0 deletions
+538
View File
@@ -0,0 +1,538 @@
// Package query implements a hybrid AST query language with pattern matching.
package query
import (
	"context"
	"fmt"
	"regexp"
	"sort"
	"strings"
	"sync"

	"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
	"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
	sitter "github.com/smacker/go-tree-sitter"
)
// regexCache memoizes compiled patterns keyed by their source text
// (string -> *regexp.Regexp). sync.Map keeps it safe for concurrent use.
var regexCache sync.Map

// compileRegex returns the compiled form of pattern, reusing a previously
// compiled instance when one is already cached.
//
// A pattern that fails to compile is reported to the caller and is never
// cached. When several goroutines compile the same pattern at the same time,
// LoadOrStore picks a single winner, so every caller ends up holding the same
// *regexp.Regexp.
func compileRegex(pattern string) (*regexp.Regexp, error) {
	entry, found := regexCache.Load(pattern)
	if !found {
		fresh, err := regexp.Compile(pattern)
		if err != nil {
			return nil, err
		}
		// Another goroutine may have stored its own copy in the meantime;
		// LoadOrStore hands back whichever instance got there first.
		entry, _ = regexCache.LoadOrStore(pattern, fresh)
	}
	return entry.(*regexp.Regexp), nil
}
// ASTQuery defines a query for matching AST patterns.
//
// Matcher.Match rejects a query whose Pattern is empty or whose Language is
// empty or equal to protocol.LangUnknown.
type ASTQuery struct {
Pattern string `json:"pattern"` // code pattern with $VAR placeholders
Language string `json:"language"` // required
// Filters narrows the set of nodes that matched Pattern; see QueryFilters.
Filters QueryFilters `json:"filters,omitempty"`
}
// QueryFilters provide additional filtering criteria.
//
// NOTE(review): passesFilters in this file evaluates only NameMatches,
// NameExact and KindIn. HasChild, HasParent, InFile and NotInFile are not read
// there — confirm whether they are applied elsewhere before relying on them.
type QueryFilters struct {
HasChild *ASTQuery `json:"has_child,omitempty"`
HasParent *ASTQuery `json:"has_parent,omitempty"`
// NameMatches is a regular expression tested against the node's name text.
NameMatches string `json:"name_matches,omitempty"`
// NameExact must equal the node's name text exactly.
NameExact string `json:"name_exact,omitempty"`
InFile string `json:"in_file,omitempty"`
NotInFile string `json:"not_in_file,omitempty"`
// KindIn lists acceptable node types; the node's type must equal one of them.
KindIn []string `json:"kind_in,omitempty"`
}
// MatchResult represents a single match from a query.
type MatchResult struct {
// Node is the matched syntax node. FormatResults tolerates a nil Node.
Node *sitter.Node
// Captures is keyed by placeholder name without the leading "$".
Captures map[string]CapturedNode
// File is the filename that was passed to Matcher.Match.
File string
// Text is the source text of Node.
Text string
// Location is Node's start point with one added to both row and column.
Location protocol.Location
}
// CapturedNode represents a captured node or nodes.
type CapturedNode struct {
// Text is the source text of the captured region.
Text string
// Nodes holds the captured node(s). Parameter captures store one entry per
// named child of the parameter list; other captures store a single node.
Nodes []*sitter.Node
}
// CaptureType indicates the type of capture.
type CaptureType int
// Placeholder kinds recognised by ParsePattern. CaptureSingle is the zero
// value of CaptureType.
const (
CaptureSingle CaptureType = iota // $NAME - single node
CaptureMultiple // $$$NAME - multiple nodes
CaptureWildcard // $_ - wildcard (don't capture)
)
// Capture represents a placeholder in a pattern.
type Capture struct {
// Name is the placeholder identifier without "$"; wildcards use "_".
Name string
// Type tells whether the placeholder is single, multiple or a wildcard.
Type CaptureType
// Position is the start offset of the placeholder as reported by the regexp
// match in ParsePattern (a byte index pointing at its leading "$").
Position int // position in the pattern
}
// ParsedPattern represents a parsed code pattern.
type ParsedPattern struct {
// Original is the pattern string exactly as supplied to ParsePattern.
Original string
// Template is Original with each placeholder replaced by a generated
// identifier of the form __multi_N__, __wild_N__ or __single_N__.
Template string
// Captures lists the placeholders found in Original.
Captures []Capture
}
// Matcher performs AST pattern matching.
type Matcher struct {
// registry is stored by NewMatcher. NOTE(review): the matching code visible
// in this file never reads it — confirm whether other files rely on it.
registry *parser.Registry
}
// NewMatcher builds a Matcher backed by the supplied parser registry.
// The registry pointer is stored as given; it is neither validated nor copied.
func NewMatcher(registry *parser.Registry) *Matcher {
	m := new(Matcher)
	m.registry = registry
	return m
}
// patternPlaceholderRe recognises every placeholder form in one left-to-right
// scan. Go's regexp tries alternatives in order, so at any given "$" the
// multi-node form wins over the wildcard, which wins over the single form:
//
//	group 1 set  -> $$$NAME (multi-node capture)
//	group 2 set  -> $NAME   (single-node capture)
//	neither set  -> $_      (wildcard)
//
// It is compiled once at package initialisation instead of on every call.
var patternPlaceholderRe = regexp.MustCompile(`\$\$\$([A-Za-z_][A-Za-z0-9_]*)|\$_|\$([A-Za-z_][A-Za-z0-9_]*)`)

// ParsePattern parses a pattern string and extracts captures.
//
// Three placeholder forms are recognised: $$$NAME (CaptureMultiple), $_
// (CaptureWildcard, stored under the name "_") and $NAME (CaptureSingle).
// Because the wildcard is tried before the single form, "$_foo" is read as a
// wildcard followed by the literal text "foo".
//
// The returned ParsedPattern contains:
//   - Original: the input, unchanged.
//   - Template: the input with each placeholder replaced by __multi_N__,
//     __wild_N__ or __single_N__. N counts multi-node captures first, then
//     wildcards, then single captures.
//   - Captures: all multi-node captures, then wildcards, then single captures,
//     each group in order of appearance. Position is always the byte offset of
//     the placeholder's leading "$" within the original pattern. The slice is
//     nil when the pattern has no placeholders.
//
// An empty pattern yields an error.
func ParsePattern(pattern string) (*ParsedPattern, error) {
	if pattern == "" {
		return nil, fmt.Errorf("empty pattern")
	}

	// A single scan over the untouched pattern means every placeholder is
	// seen exactly once and every offset refers to the caller's string.
	matches := patternPlaceholderRe.FindAllStringSubmatchIndex(pattern, -1)

	var multi, wild, single []Capture
	for _, m := range matches {
		switch {
		case m[2] >= 0:
			multi = append(multi, Capture{
				Name:     pattern[m[2]:m[3]],
				Type:     CaptureMultiple,
				Position: m[0],
			})
		case m[4] >= 0:
			single = append(single, Capture{
				Name:     pattern[m[4]:m[5]],
				Type:     CaptureSingle,
				Position: m[0],
			})
		default:
			wild = append(wild, Capture{
				Name:     "_",
				Type:     CaptureWildcard,
				Position: m[0],
			})
		}
	}

	// Placeholder numbers are handed out per kind (multi, then wildcard, then
	// single) so that generated templates keep their established numbering.
	multiID := 0
	wildID := len(multi)
	singleID := len(multi) + len(wild)

	var tmpl strings.Builder
	tmpl.Grow(len(pattern))
	prev := 0
	for _, m := range matches {
		tmpl.WriteString(pattern[prev:m[0]])
		switch {
		case m[2] >= 0:
			multiID++
			fmt.Fprintf(&tmpl, "__multi_%d__", multiID)
		case m[4] >= 0:
			singleID++
			fmt.Fprintf(&tmpl, "__single_%d__", singleID)
		default:
			wildID++
			fmt.Fprintf(&tmpl, "__wild_%d__", wildID)
		}
		prev = m[1]
	}
	tmpl.WriteString(pattern[prev:])

	var captures []Capture
	if len(matches) > 0 {
		captures = make([]Capture, 0, len(matches))
		captures = append(captures, multi...)
		captures = append(captures, wild...)
		captures = append(captures, single...)
	}

	return &ParsedPattern{
		Original: pattern,
		Captures: captures,
		Template: tmpl.String(),
	}, nil
}
// Match executes a query against a parsed tree.
//
// Parameters:
//   - ctx: checked before every node; once it is done the walk stops doing
//     work and Match returns ctx's error instead of a partial result set.
//   - query: must be non-nil with a non-empty Pattern and a Language that is
//     neither empty nor protocol.LangUnknown.
//   - tree: the parsed syntax tree for content; must be non-nil.
//   - content: the source bytes the tree was parsed from.
//   - filename: copied verbatim into every MatchResult.File.
//
// Every node visited by parser.WalkTree is tested with matchNode and then
// with passesFilters; survivors are returned in visitation order. A tree
// without a root node yields an empty result and no error.
//
// Errors are returned for a missing query or pattern, an invalid language, an
// unparsable pattern, a name_matches filter that is not a valid regular
// expression, a nil tree, and a cancelled or expired context.
func (m *Matcher) Match(ctx context.Context, query *ASTQuery, tree *sitter.Tree, content []byte, filename string) ([]MatchResult, error) {
	if query == nil || query.Pattern == "" {
		return nil, fmt.Errorf("query pattern is required")
	}

	lang := protocol.Language(query.Language)
	if lang == "" || lang == protocol.LangUnknown {
		return nil, fmt.Errorf("valid language is required")
	}

	parsed, err := ParsePattern(query.Pattern)
	if err != nil {
		return nil, fmt.Errorf("invalid pattern: %w", err)
	}

	// passesFilters can only answer yes/no, so a malformed expression would
	// otherwise be indistinguishable from "nothing matched". Report it here.
	if expr := query.Filters.NameMatches; expr != "" {
		if _, err := compileRegex(expr); err != nil {
			return nil, fmt.Errorf("invalid name_matches filter %q: %w", expr, err)
		}
	}

	if tree == nil {
		return nil, fmt.Errorf("parsed tree is required")
	}

	var results []MatchResult

	root := tree.RootNode()
	if root == nil {
		return results, nil
	}

	// walkErr remembers the first cancellation seen so that every later
	// callback invocation bails out immediately, whatever WalkTree does with
	// a false return value.
	var walkErr error
	parser.WalkTree(root, func(n *sitter.Node) bool {
		if walkErr != nil {
			return false
		}
		if err := ctx.Err(); err != nil {
			walkErr = err
			return false
		}

		matched, captures := matchNode(n, parsed, content)
		if !matched || !passesFilters(n, query.Filters, content) {
			return true // keep walking
		}

		start := n.StartPoint()
		results = append(results, MatchResult{
			Node:     n,
			Captures: captures,
			File:     filename,
			Location: protocol.Location{
				Line:   int(start.Row) + 1,
				Column: int(start.Column) + 1,
			},
			Text: parser.GetNodeText(n, content),
		})
		return true
	})

	if walkErr != nil {
		return nil, walkErr
	}
	return results, nil
}
// matchNode reports whether node matches the parsed pattern and, if so, which
// captures were extracted from it.
//
// A nil node never matches and yields a nil capture map. For any other node a
// fresh map is always returned, even when the match fails. The decision itself
// is delegated to matchPatternHeuristic, which compares keywords and node
// types rather than full AST structure.
func matchNode(node *sitter.Node, pattern *ParsedPattern, content []byte) (bool, map[string]CapturedNode) {
	if node == nil {
		return false, nil
	}

	found := map[string]CapturedNode{}
	ok := matchPatternHeuristic(node, pattern, content, found)
	return ok, found
}
// matchPatternHeuristic decides whether node matches pattern by looking for a
// declaration keyword in the lower-cased pattern text and comparing the node's
// type against the node types associated with that keyword.
//
// Keyword groups are tried in a fixed order — function, class, struct,
// interface — and only the first group whose keyword occurs in the pattern is
// considered. When the node type fits, the corresponding extract*Captures
// helper fills captures and the function returns true. A pattern containing
// none of the keywords never matches.
func matchPatternHeuristic(node *sitter.Node, pattern *ParsedPattern, content []byte, captures map[string]CapturedNode) bool {
	lowered := strings.ToLower(pattern.Original)
	mentions := func(keyword string) bool {
		return strings.Contains(lowered, keyword)
	}
	kind := node.Type()

	switch {
	case mentions("func ") || mentions("function "):
		// Function-like declarations.
		switch kind {
		case "function_declaration", "method_declaration", "function_definition":
			extractFunctionCaptures(node, pattern.Captures, content, captures)
			return true
		}
		return false

	case mentions("class "):
		switch kind {
		case "class_declaration", "class_definition":
			extractClassCaptures(node, pattern.Captures, content, captures)
			return true
		}
		return false

	case mentions("struct ") || (mentions("type ") && mentions("struct")):
		// Structs (Go, C, C++).
		switch kind {
		case "type_declaration", "struct_specifier":
			extractStructCaptures(node, pattern.Captures, content, captures)
			return true
		}
		return false

	case mentions("interface "):
		// Interfaces (Go, TypeScript).
		switch kind {
		case "interface_declaration", "type_declaration":
			extractInterfaceCaptures(node, pattern.Captures, content, captures)
			return true
		}
		return false
	}

	return false
}
// extractFunctionCaptures fills captures with the parts of a function node
// requested by capturesDef.
//
// Recognised placeholder names (upper- or lower-case only) and the node field
// each one reads:
//
//	NAME                -> "name"
//	ARGS / PARAMS       -> "parameters" (Nodes holds each named child)
//	BODY                -> "body"
//	RETURN / RESULT     -> "result"
//
// A placeholder whose field is absent on the node is simply left out of the
// map; unrecognised placeholder names are ignored.
func extractFunctionCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
	// record stores the node found under field as a single-node capture.
	record := func(key, field string) {
		child := node.ChildByFieldName(field)
		if child == nil {
			return
		}
		captures[key] = CapturedNode{
			Nodes: []*sitter.Node{child},
			Text:  parser.GetNodeText(child, content),
		}
	}

	for _, def := range capturesDef {
		switch key := def.Name; key {
		case "NAME", "name":
			record(key, "name")
		case "BODY", "body":
			record(key, "body")
		case "RETURN", "return", "RESULT", "result":
			record(key, "result")
		case "ARGS", "args", "PARAMS", "params":
			params := node.ChildByFieldName("parameters")
			if params == nil {
				continue
			}
			// Text covers the whole parameter list; Nodes lists its members.
			var members []*sitter.Node
			for idx, total := 0, int(params.NamedChildCount()); idx < total; idx++ {
				members = append(members, params.NamedChild(idx))
			}
			captures[key] = CapturedNode{
				Nodes: members,
				Text:  parser.GetNodeText(params, content),
			}
		}
	}
}
// extractClassCaptures fills captures with the parts of a class node requested
// by capturesDef.
//
// NAME reads the "name" field, BODY reads "body", and EXTENDS / SUPERCLASS read
// "superclass" (each accepted in upper- or lower-case only). Placeholders with
// other names, or whose field is missing on the node, produce no entry.
func extractClassCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
	for _, def := range capturesDef {
		// Translate the placeholder name into the node field it refers to.
		var field string
		switch def.Name {
		case "NAME", "name":
			field = "name"
		case "BODY", "body":
			field = "body"
		case "EXTENDS", "extends", "SUPERCLASS", "superclass":
			field = "superclass"
		default:
			continue
		}

		target := node.ChildByFieldName(field)
		if target == nil {
			continue
		}
		captures[def.Name] = CapturedNode{
			Nodes: []*sitter.Node{target},
			Text:  parser.GetNodeText(target, content),
		}
	}
}
// extractStructCaptures fills captures with the parts of a struct node
// requested by capturesDef.
//
// NAME reads the node's "name" field, except for type_declaration nodes, where
// the name is taken from the last type_spec child that has one.
//
// FIELDS / BODY read the node's "body" field. When that field is absent on a
// type_declaration, the field list is looked up inside the same type_spec that
// supplied the name, so that name and fields always describe the same type.
//
// Placeholders with other names, or whose target cannot be found, produce no
// entry in captures.
func extractStructCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
	store := func(key string, n *sitter.Node) {
		captures[key] = CapturedNode{
			Nodes: []*sitter.Node{n},
			Text:  parser.GetNodeText(n, content),
		}
	}

	isTypeDecl := node.Type() == "type_declaration"
	var spec *sitter.Node
	if isTypeDecl {
		spec = lastNamedTypeSpec(node)
	}

	for _, def := range capturesDef {
		switch def.Name {
		case "NAME", "name":
			if isTypeDecl {
				if spec != nil {
					store(def.Name, spec.ChildByFieldName("name"))
				}
				continue
			}
			if nameNode := node.ChildByFieldName("name"); nameNode != nil {
				store(def.Name, nameNode)
			}
		case "FIELDS", "fields", "BODY", "body":
			if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
				store(def.Name, bodyNode)
				continue
			}
			if spec == nil {
				continue
			}
			if fields := structFieldList(spec); fields != nil {
				store(def.Name, fields)
			}
		}
	}
}

// lastNamedTypeSpec returns the last type_spec child of decl that carries a
// "name" field, or nil when there is none.
func lastNamedTypeSpec(decl *sitter.Node) *sitter.Node {
	var last *sitter.Node
	for i, total := 0, int(decl.NamedChildCount()); i < total; i++ {
		child := decl.NamedChild(i)
		if child == nil || child.Type() != "type_spec" {
			continue
		}
		if child.ChildByFieldName("name") != nil {
			last = child
		}
	}
	return last
}

// structFieldList returns the field list of the struct declared by spec, or
// nil when spec does not declare a struct.
//
// NOTE(review): assumes the Go grammar exposes the declared type through
// type_spec's "type" field as a struct_type node containing a
// field_declaration_list child — confirm against the grammar in use. If the
// list node is not found, the struct_type node itself is returned.
func structFieldList(spec *sitter.Node) *sitter.Node {
	typeNode := spec.ChildByFieldName("type")
	if typeNode == nil || typeNode.Type() != "struct_type" {
		return nil
	}
	for i, total := 0, int(typeNode.NamedChildCount()); i < total; i++ {
		if child := typeNode.NamedChild(i); child != nil && child.Type() == "field_declaration_list" {
			return child
		}
	}
	return typeNode
}
// extractInterfaceCaptures fills captures with the parts of an interface node
// requested by capturesDef.
//
// NAME reads the node's "name" field. When the node has no such field and is a
// type_declaration, the name is taken from its type_spec children instead (the
// last one with a name wins), mirroring how extractStructCaptures resolves
// names for the same node type.
//
// BODY / METHODS read the node's "body" field.
//
// Placeholders with other names, or whose target cannot be found, produce no
// entry in captures.
func extractInterfaceCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
	store := func(key string, n *sitter.Node) {
		captures[key] = CapturedNode{
			Nodes: []*sitter.Node{n},
			Text:  parser.GetNodeText(n, content),
		}
	}

	for _, def := range capturesDef {
		switch def.Name {
		case "NAME", "name":
			if nameNode := node.ChildByFieldName("name"); nameNode != nil {
				store(def.Name, nameNode)
				continue
			}
			if node.Type() != "type_declaration" {
				continue
			}
			// The declaration itself carries no name; look one level down.
			for i, total := 0, int(node.NamedChildCount()); i < total; i++ {
				spec := node.NamedChild(i)
				if spec == nil || spec.Type() != "type_spec" {
					continue
				}
				if nameNode := spec.ChildByFieldName("name"); nameNode != nil {
					store(def.Name, nameNode)
				}
			}
		case "BODY", "body", "METHODS", "methods":
			if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
				store(def.Name, bodyNode)
			}
		}
	}
}
// passesFilters checks if a node passes all the specified filters.
//
// Name filters (NameMatches, NameExact) are evaluated against the node's
// candidate names, see filterNameNodes. The node passes when at least one
// candidate satisfies every name filter that is set. A node without any
// candidate name fails as soon as a name filter is set, and so does every node
// when NameMatches is not a valid regular expression.
//
// KindIn, when non-empty, requires the node's type to equal one of its entries.
//
// Filters that are left at their zero value are not applied. The remaining
// QueryFilters fields are not consulted here.
func passesFilters(node *sitter.Node, filters QueryFilters, content []byte) bool {
	if filters.NameMatches != "" || filters.NameExact != "" {
		var nameRe *regexp.Regexp
		if filters.NameMatches != "" {
			compiled, err := compileRegex(filters.NameMatches) // cached compilation
			if err != nil {
				return false
			}
			nameRe = compiled
		}

		accepted := false
		for _, nameNode := range filterNameNodes(node) {
			name := parser.GetNodeText(nameNode, content)
			if nameRe != nil && !nameRe.MatchString(name) {
				continue
			}
			if filters.NameExact != "" && name != filters.NameExact {
				continue
			}
			accepted = true
			break
		}
		if !accepted {
			return false
		}
	}

	if len(filters.KindIn) > 0 {
		nodeType := node.Type()
		found := false
		for _, kind := range filters.KindIn {
			if nodeType == kind {
				found = true
				break
			}
		}
		if !found {
			return false
		}
	}

	return true
}

// filterNameNodes returns the nodes whose text counts as the name of node for
// filtering purposes: the node's own "name" field when present, otherwise —
// for type_declaration nodes, whose name sits on their type_spec children —
// the "name" field of every type_spec child. The result is empty when no name
// can be found.
func filterNameNodes(node *sitter.Node) []*sitter.Node {
	if direct := node.ChildByFieldName("name"); direct != nil {
		return []*sitter.Node{direct}
	}
	if node.Type() != "type_declaration" {
		return nil
	}

	var names []*sitter.Node
	for i, total := 0, int(node.NamedChildCount()); i < total; i++ {
		spec := node.NamedChild(i)
		if spec == nil || spec.Type() != "type_spec" {
			continue
		}
		if nameNode := spec.ChildByFieldName("name"); nameNode != nil {
			names = append(names, nameNode)
		}
	}
	return names
}
// FormatResults formats match results for display.
//
// The output starts with the total number of matches, followed by one entry
// per displayed result: a "**file:line** (node type)" heading, the matched
// text in a fenced block (shortened to roughly 500 bytes), and, when present,
// the captures as "$name=text" pairs (each text shortened to roughly 50
// bytes). Captures are listed in ascending name order so that the same input
// always renders the same output.
//
// At most maxResults entries are shown when maxResults is positive; a trailing
// line then states how many were left out. A non-positive maxResults shows
// everything. An empty results slice yields "No matches found.".
func FormatResults(results []MatchResult, maxResults int) string {
	if len(results) == 0 {
		return "No matches found."
	}

	var sb strings.Builder
	fmt.Fprintf(&sb, "Found %d match(es):\n\n", len(results))

	shown := len(results)
	if maxResults > 0 && shown > maxResults {
		shown = maxResults
	}

	for i := 0; i < shown; i++ {
		r := &results[i]

		nodeType := "unknown"
		if r.Node != nil {
			nodeType = r.Node.Type()
		}
		fmt.Fprintf(&sb, "**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType)

		sb.WriteString("```\n")
		sb.WriteString(truncateForDisplay(r.Text, 500))
		sb.WriteString("\n```\n")

		if len(r.Captures) > 0 {
			// Map iteration order is random; sort the names for stable output.
			names := make([]string, 0, len(r.Captures))
			for name := range r.Captures {
				names = append(names, name)
			}
			sort.Strings(names)

			sb.WriteString("Captures: ")
			for j, name := range names {
				if j > 0 {
					sb.WriteString(", ")
				}
				fmt.Fprintf(&sb, "$%s=%s", name, truncateForDisplay(r.Captures[name].Text, 50))
			}
			sb.WriteString("\n")
		}
		sb.WriteString("\n")
	}

	if hidden := len(results) - shown; hidden > 0 {
		fmt.Fprintf(&sb, "... and %d more matches (truncated)\n", hidden)
	}

	return sb.String()
}

// truncateForDisplay returns s unchanged when it is at most limit bytes long.
// Otherwise it cuts s at the last character boundary that is not beyond limit
// and appends "...", so a multi-byte UTF-8 character is never split in half.
func truncateForDisplay(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	cut := 0
	// Ranging over a string yields the byte offset at which each character
	// starts, which is exactly the set of safe cut points.
	for i := range s {
		if i > limit {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
+559
View File
@@ -0,0 +1,559 @@
package query
import (
"context"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// TestParsePattern checks, for a range of patterns, that ParsePattern reports
// an error only for the empty pattern and otherwise returns the expected number
// of captures with the expected name/type pairs. The order of the captures in
// the result is not asserted.
func TestParsePattern(t *testing.T) {
	type testCase struct {
		name         string
		pattern      string
		captureNames []string
		captureTypes []CaptureType // parallel to captureNames
		wantCaptures int
		wantErr      bool
	}

	cases := []testCase{
		{name: "empty pattern", pattern: "", wantErr: true},
		{
			name:         "single capture",
			pattern:      "func $NAME() {}",
			wantCaptures: 1,
			captureNames: []string{"NAME"},
			captureTypes: []CaptureType{CaptureSingle},
		},
		{
			name:         "multiple single captures",
			pattern:      "func $NAME($ARGS) $RETURN",
			wantCaptures: 3,
			captureNames: []string{"NAME", "ARGS", "RETURN"},
			captureTypes: []CaptureType{CaptureSingle, CaptureSingle, CaptureSingle},
		},
		{
			name:         "multi-node capture",
			pattern:      "func $NAME($$$ARGS) { $$$BODY }",
			wantCaptures: 3,
			captureNames: []string{"ARGS", "BODY", "NAME"},
			captureTypes: []CaptureType{CaptureMultiple, CaptureMultiple, CaptureSingle},
		},
		{
			name:         "wildcard capture",
			pattern:      "func $NAME($_) {}",
			wantCaptures: 2,
			captureNames: []string{"NAME", "_"},
			captureTypes: []CaptureType{CaptureSingle, CaptureWildcard},
		},
		{name: "no captures", pattern: "func main() {}"},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			parsed, err := ParsePattern(tc.pattern)

			if tc.wantErr {
				if err == nil {
					t.Error("expected error")
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if got := len(parsed.Captures); got != tc.wantCaptures {
				t.Errorf("expected %d captures, got %d", tc.wantCaptures, got)
			}
			if tc.captureNames == nil {
				return
			}

			// Index the actual captures by name so order does not matter.
			byName := make(map[string]CaptureType)
			for _, c := range parsed.Captures {
				byName[c.Name] = c.Type
			}
			for idx, want := range tc.captureNames {
				gotType, present := byName[want]
				if !present {
					t.Errorf("expected capture %s not found", want)
				}
				if wantType := tc.captureTypes[idx]; gotType != wantType {
					t.Errorf("capture %s: expected type %v, got %v", want, wantType, gotType)
				}
			}
		})
	}
}
// TestMatchGoFunctions parses a small Go source containing two functions, one
// struct and one method, then checks how many nodes each function query
// returns, with and without name filters.
func TestMatchGoFunctions(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
// Fixture: Hello and Greet are plain functions, Start is a method on Server.
content := `package main
func Hello() {
println("hello")
}
func Greet(name string) error {
println("hello", name)
return nil
}
type Server struct {
Port int
}
func (s *Server) Start() error {
return nil
}
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantMatches int
}{
{
name: "match all functions",
query: &ASTQuery{
Pattern: "func $NAME($$$ARGS) { $$$BODY }",
Language: "go",
},
wantMatches: 3, // Hello, Greet, Start
},
{
// The regex filter narrows the function matches to names beginning with "H".
name: "match functions starting with H",
query: &ASTQuery{
Pattern: "func $NAME() {}",
Language: "go",
Filters: QueryFilters{
NameMatches: "^H",
},
},
wantMatches: 1, // Hello
},
{
name: "match specific function",
query: &ASTQuery{
Pattern: "func $NAME() {}",
Language: "go",
Filters: QueryFilters{
NameExact: "Hello",
},
},
wantMatches: 1, // Hello
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
if err != nil {
t.Fatalf("match failed: %v", err)
}
if len(results) != tt.wantMatches {
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
// Log what was matched to make a count mismatch easier to diagnose.
for i, r := range results {
t.Logf("match %d: %s at line %d", i, r.Node.Type(), r.Location.Line)
}
}
})
}
}
// TestMatchGoStructs parses a Go source with two structs and one interface and
// checks that a struct query returns at least the two structs. Only a lower
// bound is asserted because the interface may be reported as well.
func TestMatchGoStructs(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
// Fixture: Server and Config are structs, Logger is an interface.
content := `package main
type Server struct {
Port int
Host string
}
type Config struct {
Timeout int
}
type Logger interface {
Log(msg string)
}
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantMinimum int
}{
{
name: "match all structs",
query: &ASTQuery{
Pattern: "type $NAME struct { $$$FIELDS }",
Language: "go",
},
wantMinimum: 2, // Server, Config (may also match interface as type_declaration)
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
if err != nil {
t.Fatalf("match failed: %v", err)
}
// Lower bound only; see the note on wantMinimum above.
if len(results) < tt.wantMinimum {
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
}
})
}
}
// TestMatchJSFunctions parses a JavaScript source with two function
// declarations and one class and checks the exact number of nodes returned by
// a function query and by a class query.
func TestMatchJSFunctions(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
// Fixture: greet and sayHello are function declarations; User is a class
// whose members are not expected to count as function matches.
content := `
function greet(name) {
console.log("Hello, " + name);
}
function sayHello() {
console.log("Hello!");
}
class User {
constructor(name) {
this.name = name;
}
getName() {
return this.name;
}
}
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.js", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantMatches int
}{
{
name: "match all functions",
query: &ASTQuery{
Pattern: "function $NAME($$$ARGS) { $$$BODY }",
Language: "javascript",
},
wantMatches: 2, // greet, sayHello
},
{
name: "match classes",
query: &ASTQuery{
Pattern: "class $NAME { $$$BODY }",
Language: "javascript",
},
wantMatches: 1, // User
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.js")
if err != nil {
t.Fatalf("match failed: %v", err)
}
if len(results) != tt.wantMatches {
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
}
})
}
}
// TestMatchPythonSymbols parses a Python source with two functions and one
// class and checks that a class query returns at least one match.
func TestMatchPythonSymbols(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
// Fixture: greet and calculate are functions; User is a class with two methods.
content := `
def greet(name):
print(f"Hello, {name}")
def calculate(a, b):
return a + b
class User:
def __init__(self, name):
self.name = name
def get_name(self):
return self.name
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.py", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantMinimum int
}{
{
name: "match classes",
query: &ASTQuery{
Pattern: "class $NAME: $$$BODY",
Language: "python",
},
wantMinimum: 1, // User
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.py")
if err != nil {
t.Fatalf("match failed: %v", err)
}
// Only a lower bound is asserted.
if len(results) < tt.wantMinimum {
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
}
})
}
}
// TestQueryFilters runs the same function pattern against a Go source with four
// functions and checks how many of them survive each kind of filter: a name
// regex, an exact name and a node-kind list.
func TestQueryFilters(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
// Fixture: four empty functions whose names differ in case and prefix.
content := `package main
func HelloWorld() {}
func helloWorld() {}
func GoodbyeWorld() {}
func Main() {}
`
ctx := context.Background()
result, err := reg.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
name string
filters QueryFilters
wantMatches int
}{
{
name: "regex filter - starts with H",
filters: QueryFilters{
NameMatches: "^[Hh]ello",
},
wantMatches: 2, // HelloWorld, helloWorld
},
{
name: "exact name filter",
filters: QueryFilters{
NameExact: "Main",
},
wantMatches: 1, // Main
},
{
name: "kind filter",
filters: QueryFilters{
KindIn: []string{"function_declaration"},
},
wantMatches: 4, // all functions
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Every case uses the same pattern; only the filters vary.
query := &ASTQuery{
Pattern: "func $NAME() {}",
Language: "go",
Filters: tt.filters,
}
results, err := matcher.Match(ctx, query, result.Tree, []byte(content), "test.go")
if err != nil {
t.Fatalf("match failed: %v", err)
}
if len(results) != tt.wantMatches {
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
// Log the names that did match to make the mismatch easier to diagnose.
for _, r := range results {
if nameNode := r.Node.ChildByFieldName("name"); nameNode != nil {
t.Logf("matched: %s", parser.GetNodeText(nameNode, []byte(content)))
}
}
}
})
}
}
// TestFormatResults checks that FormatResults returns the fixed "no matches"
// message for an empty result set and something else for non-empty sets,
// including one that exceeds maxResults.
func TestFormatResults(t *testing.T) {
	const noMatches = "No matches found."

	cases := []struct {
		name       string
		results    []MatchResult
		maxResults int
		wantEmpty  bool
	}{
		{
			name:       "empty results",
			results:    []MatchResult{},
			maxResults: 100,
			wantEmpty:  true,
		},
		{
			name: "single result",
			results: []MatchResult{{
				File:     "test.go",
				Location: protocol.Location{Line: 10, Column: 1},
				Text:     "func Hello() {}",
				Captures: map[string]CapturedNode{
					"NAME": {Text: "Hello"},
				},
			}},
			maxResults: 100,
		},
		{
			name: "truncated results",
			results: []MatchResult{
				{File: "a.go", Location: protocol.Location{Line: 1}, Text: "func A() {}"},
				{File: "b.go", Location: protocol.Location{Line: 1}, Text: "func B() {}"},
				{File: "c.go", Location: protocol.Location{Line: 1}, Text: "func C() {}"},
			},
			maxResults: 2,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got := FormatResults(tc.results, tc.maxResults)
			isEmpty := got == noMatches

			switch {
			case tc.wantEmpty && !isEmpty:
				t.Errorf("expected 'No matches found.', got: %s", got)
			case !tc.wantEmpty && isEmpty:
				t.Error("expected results, got 'No matches found.'")
			}
		})
	}
}
// TestQueryValidation checks that Matcher.Match rejects queries with an empty
// pattern, a missing language or an unknown language, and accepts a query that
// has both.
func TestQueryValidation(t *testing.T) {
reg := parser.NewRegistry()
defer reg.Close()
matcher := NewMatcher(reg)
ctx := context.Background()
// Parse some valid content
content := `package main
func main() {}
`
result, err := reg.Parse(ctx, "test.go", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
tests := []struct {
query *ASTQuery
name string
wantErr bool
}{
{
name: "empty pattern",
query: &ASTQuery{Pattern: "", Language: "go"},
wantErr: true,
},
{
name: "missing language",
query: &ASTQuery{Pattern: "func $NAME() {}", Language: ""},
wantErr: true,
},
{
name: "unknown language",
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "unknown"},
wantErr: true,
},
{
name: "valid query",
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "go"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Only the error outcome is inspected; the matches themselves are ignored.
_, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
if tt.wantErr {
if err == nil {
t.Error("expected error")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
})
}
}
+198
View File
@@ -0,0 +1,198 @@
package query
import (
"regexp"
"sync"
"testing"
)
// TestCompileRegexCaching tests that regex compilation is cached: compiling
// the same pattern twice must return the very same *regexp.Regexp, and that
// instance must be the one stored in regexCache.
func TestCompileRegexCaching(t *testing.T) {
	// Start from an empty cache so earlier tests cannot influence the result.
	regexCache = sync.Map{}

	pattern := `^test_\w+$`

	re1, err := compileRegex(pattern)
	if err != nil {
		t.Fatalf("First compile failed: %v", err)
	}

	// The second call must be served from the cache.
	re2, err := compileRegex(pattern)
	if err != nil {
		t.Fatalf("Second compile failed: %v", err)
	}

	if re1 != re2 {
		t.Error("Expected cached regex to be reused, got different objects")
	}

	cached, ok := regexCache.Load(pattern)
	if !ok {
		// Stop here: asserting the type of a missing entry would panic and
		// hide the actual failure.
		t.Fatal("Pattern not found in cache")
	}
	if cached.(*regexp.Regexp) != re1 {
		t.Error("Cached regex doesn't match returned regex")
	}
}
// TestCompileRegexConcurrent compiles one pattern from many goroutines at once
// and verifies that no call fails and that every goroutine received the same
// *regexp.Regexp instance.
func TestCompileRegexConcurrent(t *testing.T) {
	// Start from an empty cache so the goroutines really race to fill it.
	regexCache = sync.Map{}

	const (
		pattern       = `[a-z]+_\d+`
		numGoroutines = 100
	)

	var (
		wg      sync.WaitGroup
		results = make([]*regexp.Regexp, numGoroutines)
		errCh   = make(chan error, numGoroutines)
	)

	for slot := range results {
		wg.Add(1)
		// Each goroutine writes only to its own slot, so no locking is needed.
		go func(slot int) {
			defer wg.Done()
			re, err := compileRegex(pattern)
			if err != nil {
				errCh <- err
				return
			}
			results[slot] = re
		}(slot)
	}

	wg.Wait()
	close(errCh)

	for err := range errCh {
		t.Errorf("Concurrent compile failed: %v", err)
	}

	// Every slot must hold the instance that slot 0 holds.
	reference := results[0]
	for offset, re := range results[1:] {
		if re != reference {
			t.Errorf("Result %d is different from result 0 (cache not working)", offset+1)
		}
	}
}
// TestCompileRegexInvalidPattern verifies that an uncompilable pattern is
// reported as an error and leaves no entry behind in regexCache.
func TestCompileRegexInvalidPattern(t *testing.T) {
	// Start from an empty cache.
	regexCache = sync.Map{}

	const invalidPattern = `[invalid(`

	if _, err := compileRegex(invalidPattern); err == nil {
		t.Error("Expected error for invalid pattern, got nil")
	}

	if _, cached := regexCache.Load(invalidPattern); cached {
		t.Error("Invalid pattern should not be cached")
	}
}
// TestCompileRegexMultiplePatterns tests that different patterns are cached
// separately: each pattern gets its own cache entry holding exactly the
// instance that compileRegex returned, and no two patterns share an instance.
func TestCompileRegexMultiplePatterns(t *testing.T) {
	// Start from an empty cache.
	regexCache = sync.Map{}

	patterns := []string{
		`^test_\w+$`,
		`^\d{4}-\d{2}-\d{2}$`,
		`^[A-Z][a-z]+$`,
		`\b\w+@\w+\.\w+\b`,
	}

	compiled := make([]*regexp.Regexp, len(patterns))
	for i, pattern := range patterns {
		re, err := compileRegex(pattern)
		if err != nil {
			t.Fatalf("Compile failed for pattern %s: %v", pattern, err)
		}
		compiled[i] = re
	}

	// Every pattern must be present in the cache under its own key.
	for i, pattern := range patterns {
		cached, ok := regexCache.Load(pattern)
		if !ok {
			t.Errorf("Pattern %s not in cache", pattern)
			// Skip the identity check: asserting the type of a missing entry
			// would panic and hide the remaining results.
			continue
		}
		if cached.(*regexp.Regexp) != compiled[i] {
			t.Errorf("Cached regex for %s doesn't match compiled version", pattern)
		}
	}

	// Distinct patterns must not share a compiled instance.
	for i := 0; i < len(compiled); i++ {
		for j := i + 1; j < len(compiled); j++ {
			if compiled[i] == compiled[j] {
				t.Errorf("Pattern %d and %d have same regex object", i, j)
			}
		}
	}
}
// BenchmarkCompileRegex_Uncached measures plain regexp.Compile on every
// iteration, serving as the baseline for the cached variants.
func BenchmarkCompileRegex_Uncached(b *testing.B) {
	const pattern = `^\w+_[0-9]{3,5}_[a-zA-Z]+$`

	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		_, _ = regexp.Compile(pattern)
	}
}
// BenchmarkCompileRegex_Cached measures compileRegex when the pattern is
// already in the cache, so every timed iteration is a cache hit.
func BenchmarkCompileRegex_Cached(b *testing.B) {
	regexCache = sync.Map{} // start from an empty cache

	const pattern = `^\w+_[0-9]{3,5}_[a-zA-Z]+$`

	// Warm the cache outside the timed region.
	_, _ = compileRegex(pattern)

	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		_, _ = compileRegex(pattern)
	}
}
// BenchmarkCompileRegex_MixedPatterns measures compileRegex while cycling
// through a small fixed set of patterns. The cache starts empty, so the first
// pass over the set compiles each pattern and later passes hit the cache.
func BenchmarkCompileRegex_MixedPatterns(b *testing.B) {
	regexCache = sync.Map{} // start from an empty cache

	patterns := [...]string{
		`^test_\w+$`,
		`^\d{4}-\d{2}-\d{2}$`,
		`^[A-Z][a-z]+$`,
		`\b\w+@\w+\.\w+\b`,
		`^func\s+\w+\(`,
	}

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		// Round-robin over the set.
		next := patterns[iter%len(patterns)]
		_, _ = compileRegex(next)
	}
}