mirror of
https://github.com/lukaszraczylo/filepuff-mcp.git
synced 2026-06-14 02:51:27 +00:00
Ho hum.
This commit is contained in:
@@ -0,0 +1,538 @@
|
||||
// Package query implements a hybrid AST query language with pattern matching.
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// Global regex cache for compiled patterns (thread-safe)
|
||||
var regexCache sync.Map // string -> *regexp.Regexp
|
||||
|
||||
// compileRegex compiles a regex pattern with caching for performance.
|
||||
// Cached patterns avoid repeated compilation overhead (10-50x speedup).
|
||||
// Thread-safe: uses LoadOrStore to prevent race conditions.
|
||||
func compileRegex(pattern string) (*regexp.Regexp, error) {
|
||||
// Check cache first
|
||||
if cached, ok := regexCache.Load(pattern); ok {
|
||||
return cached.(*regexp.Regexp), nil
|
||||
}
|
||||
|
||||
// Compile regex
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try to store - if another goroutine already stored it, use theirs
|
||||
// This prevents race conditions where multiple goroutines compile the same pattern
|
||||
actual, _ := regexCache.LoadOrStore(pattern, re)
|
||||
return actual.(*regexp.Regexp), nil
|
||||
}
|
||||
|
||||
// ASTQuery defines a query for matching AST patterns.
|
||||
type ASTQuery struct {
|
||||
Pattern string `json:"pattern"` // code pattern with $VAR placeholders
|
||||
Language string `json:"language"` // required
|
||||
Filters QueryFilters `json:"filters,omitempty"`
|
||||
}
|
||||
|
||||
// QueryFilters provide additional filtering criteria.
|
||||
type QueryFilters struct {
|
||||
HasChild *ASTQuery `json:"has_child,omitempty"`
|
||||
HasParent *ASTQuery `json:"has_parent,omitempty"`
|
||||
NameMatches string `json:"name_matches,omitempty"`
|
||||
NameExact string `json:"name_exact,omitempty"`
|
||||
InFile string `json:"in_file,omitempty"`
|
||||
NotInFile string `json:"not_in_file,omitempty"`
|
||||
KindIn []string `json:"kind_in,omitempty"`
|
||||
}
|
||||
|
||||
// MatchResult represents a single match from a query.
|
||||
type MatchResult struct {
|
||||
Node *sitter.Node
|
||||
Captures map[string]CapturedNode
|
||||
File string
|
||||
Text string
|
||||
Location protocol.Location
|
||||
}
|
||||
|
||||
// CapturedNode represents a captured node or nodes.
|
||||
type CapturedNode struct {
|
||||
Text string
|
||||
Nodes []*sitter.Node
|
||||
}
|
||||
|
||||
// CaptureType indicates the type of capture.
|
||||
type CaptureType int
|
||||
|
||||
const (
|
||||
CaptureSingle CaptureType = iota // $NAME - single node
|
||||
CaptureMultiple // $$$NAME - multiple nodes
|
||||
CaptureWildcard // $_ - wildcard (don't capture)
|
||||
)
|
||||
|
||||
// Capture represents a placeholder in a pattern.
|
||||
type Capture struct {
|
||||
Name string
|
||||
Type CaptureType
|
||||
Position int // position in the pattern
|
||||
}
|
||||
|
||||
// ParsedPattern represents a parsed code pattern.
|
||||
type ParsedPattern struct {
|
||||
Original string
|
||||
Template string
|
||||
Captures []Capture
|
||||
}
|
||||
|
||||
// Matcher performs AST pattern matching.
|
||||
type Matcher struct {
|
||||
registry *parser.Registry
|
||||
}
|
||||
|
||||
// NewMatcher creates a new pattern matcher.
|
||||
func NewMatcher(registry *parser.Registry) *Matcher {
|
||||
return &Matcher{registry: registry}
|
||||
}
|
||||
|
||||
// ParsePattern parses a pattern string and extracts captures.
|
||||
func ParsePattern(pattern string) (*ParsedPattern, error) {
|
||||
if pattern == "" {
|
||||
return nil, fmt.Errorf("empty pattern")
|
||||
}
|
||||
|
||||
var captures []Capture
|
||||
template := pattern
|
||||
captureID := 0
|
||||
|
||||
// Find all captures: $$$ (multi), $_ (wildcard), $NAME (single)
|
||||
// Order matters: check $$$ first
|
||||
multiRe := regexp.MustCompile(`\$\$\$([A-Za-z_][A-Za-z0-9_]*)`)
|
||||
wildcardRe := regexp.MustCompile(`\$_`)
|
||||
singleRe := regexp.MustCompile(`\$([A-Za-z_][A-Za-z0-9_]*)`)
|
||||
|
||||
// Extract multi-node captures ($$$NAME)
|
||||
for _, match := range multiRe.FindAllStringSubmatchIndex(pattern, -1) {
|
||||
name := pattern[match[2]:match[3]]
|
||||
captures = append(captures, Capture{
|
||||
Name: name,
|
||||
Type: CaptureMultiple,
|
||||
Position: match[0],
|
||||
})
|
||||
}
|
||||
|
||||
// Replace multi-captures with placeholder identifiers
|
||||
template = multiRe.ReplaceAllStringFunc(template, func(s string) string {
|
||||
captureID++
|
||||
return fmt.Sprintf("__multi_%d__", captureID)
|
||||
})
|
||||
|
||||
// Extract wildcards ($_)
|
||||
for _, match := range wildcardRe.FindAllStringIndex(pattern, -1) {
|
||||
captures = append(captures, Capture{
|
||||
Name: "_",
|
||||
Type: CaptureWildcard,
|
||||
Position: match[0],
|
||||
})
|
||||
}
|
||||
|
||||
// Replace wildcards with placeholder identifiers
|
||||
template = wildcardRe.ReplaceAllStringFunc(template, func(s string) string {
|
||||
captureID++
|
||||
return fmt.Sprintf("__wild_%d__", captureID)
|
||||
})
|
||||
|
||||
// Extract single-node captures ($NAME) - exclude those that are part of $$$NAME
|
||||
// Check which $NAME patterns are not preceded by $$
|
||||
remaining := template
|
||||
for _, match := range singleRe.FindAllStringSubmatchIndex(remaining, -1) {
|
||||
name := remaining[match[2]:match[3]]
|
||||
// Skip if this looks like our placeholder
|
||||
if strings.HasPrefix(name, "_multi_") || strings.HasPrefix(name, "_wild_") {
|
||||
continue
|
||||
}
|
||||
captures = append(captures, Capture{
|
||||
Name: name,
|
||||
Type: CaptureSingle,
|
||||
Position: match[0],
|
||||
})
|
||||
}
|
||||
|
||||
// Replace single captures with placeholder identifiers
|
||||
template = singleRe.ReplaceAllStringFunc(template, func(s string) string {
|
||||
name := strings.TrimPrefix(s, "$")
|
||||
if strings.HasPrefix(name, "_multi_") || strings.HasPrefix(name, "_wild_") {
|
||||
return s // keep our placeholders as is
|
||||
}
|
||||
captureID++
|
||||
return fmt.Sprintf("__single_%d__", captureID)
|
||||
})
|
||||
|
||||
return &ParsedPattern{
|
||||
Original: pattern,
|
||||
Captures: captures,
|
||||
Template: template,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Match executes a query against a parsed tree.
|
||||
func (m *Matcher) Match(ctx context.Context, query *ASTQuery, tree *sitter.Tree, content []byte, filename string) ([]MatchResult, error) {
|
||||
if query.Pattern == "" {
|
||||
return nil, fmt.Errorf("query pattern is required")
|
||||
}
|
||||
|
||||
lang := protocol.Language(query.Language)
|
||||
if lang == "" || lang == protocol.LangUnknown {
|
||||
return nil, fmt.Errorf("valid language is required")
|
||||
}
|
||||
|
||||
// Parse the pattern
|
||||
parsed, err := ParsePattern(query.Pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid pattern: %w", err)
|
||||
}
|
||||
|
||||
var results []MatchResult
|
||||
|
||||
// Walk the tree and find matches
|
||||
root := tree.RootNode()
|
||||
if root == nil {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
parser.WalkTree(root, func(n *sitter.Node) bool {
|
||||
// Check for context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
default:
|
||||
}
|
||||
|
||||
// Try to match this node against the pattern
|
||||
if matched, captures := matchNode(n, parsed, content); matched {
|
||||
// Apply filters
|
||||
if !passesFilters(n, query.Filters, content) {
|
||||
return true // continue walking
|
||||
}
|
||||
|
||||
startPoint := n.StartPoint()
|
||||
results = append(results, MatchResult{
|
||||
Node: n,
|
||||
Captures: captures,
|
||||
File: filename,
|
||||
Location: protocol.Location{
|
||||
Line: int(startPoint.Row) + 1,
|
||||
Column: int(startPoint.Column) + 1,
|
||||
},
|
||||
Text: parser.GetNodeText(n, content),
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// matchNode attempts to match a node against a parsed pattern.
|
||||
// This is a simplified matcher that looks for structural similarity.
|
||||
func matchNode(node *sitter.Node, pattern *ParsedPattern, content []byte) (bool, map[string]CapturedNode) {
|
||||
if node == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
captures := make(map[string]CapturedNode)
|
||||
|
||||
// Use pattern keyword matching as a heuristic to find matching nodes
|
||||
// A full implementation would parse both pattern and node and compare AST structure
|
||||
matched := matchPatternHeuristic(node, pattern, content, captures)
|
||||
|
||||
return matched, captures
|
||||
}
|
||||
|
||||
// matchPatternHeuristic uses heuristics to match patterns.
|
||||
// This is a simplified implementation that matches based on node type and structure.
|
||||
func matchPatternHeuristic(node *sitter.Node, pattern *ParsedPattern, content []byte, captures map[string]CapturedNode) bool {
|
||||
patternLower := strings.ToLower(pattern.Original)
|
||||
nodeType := node.Type()
|
||||
|
||||
// Match function patterns
|
||||
if strings.Contains(patternLower, "func ") || strings.Contains(patternLower, "function ") {
|
||||
if nodeType != "function_declaration" && nodeType != "method_declaration" && nodeType != "function_definition" {
|
||||
return false
|
||||
}
|
||||
extractFunctionCaptures(node, pattern.Captures, content, captures)
|
||||
return true
|
||||
}
|
||||
|
||||
// Match class patterns
|
||||
if strings.Contains(patternLower, "class ") {
|
||||
if nodeType != "class_declaration" && nodeType != "class_definition" {
|
||||
return false
|
||||
}
|
||||
extractClassCaptures(node, pattern.Captures, content, captures)
|
||||
return true
|
||||
}
|
||||
|
||||
// Match struct patterns (Go, C, C++)
|
||||
if strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ") && strings.Contains(patternLower, "struct") {
|
||||
if nodeType != "type_declaration" && nodeType != "struct_specifier" {
|
||||
return false
|
||||
}
|
||||
extractStructCaptures(node, pattern.Captures, content, captures)
|
||||
return true
|
||||
}
|
||||
|
||||
// Match interface patterns (Go, TypeScript)
|
||||
if strings.Contains(patternLower, "interface ") {
|
||||
if nodeType != "interface_declaration" && nodeType != "type_declaration" {
|
||||
return false
|
||||
}
|
||||
extractInterfaceCaptures(node, pattern.Captures, content, captures)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// extractFunctionCaptures extracts captures from a function node.
|
||||
func extractFunctionCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||
for _, cap := range capturesDef {
|
||||
switch cap.Name {
|
||||
case "NAME", "name":
|
||||
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
case "ARGS", "args", "PARAMS", "params":
|
||||
if paramsNode := node.ChildByFieldName("parameters"); paramsNode != nil {
|
||||
var paramNodes []*sitter.Node
|
||||
for i := 0; i < int(paramsNode.NamedChildCount()); i++ {
|
||||
paramNodes = append(paramNodes, paramsNode.NamedChild(i))
|
||||
}
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: paramNodes,
|
||||
Text: parser.GetNodeText(paramsNode, content),
|
||||
}
|
||||
}
|
||||
case "BODY", "body":
|
||||
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{bodyNode},
|
||||
Text: parser.GetNodeText(bodyNode, content),
|
||||
}
|
||||
}
|
||||
case "RETURN", "return", "RESULT", "result":
|
||||
if resultNode := node.ChildByFieldName("result"); resultNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{resultNode},
|
||||
Text: parser.GetNodeText(resultNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractClassCaptures extracts captures from a class node.
|
||||
func extractClassCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||
for _, cap := range capturesDef {
|
||||
switch cap.Name {
|
||||
case "NAME", "name":
|
||||
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
case "BODY", "body":
|
||||
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{bodyNode},
|
||||
Text: parser.GetNodeText(bodyNode, content),
|
||||
}
|
||||
}
|
||||
case "EXTENDS", "extends", "SUPERCLASS", "superclass":
|
||||
if extendsNode := node.ChildByFieldName("superclass"); extendsNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{extendsNode},
|
||||
Text: parser.GetNodeText(extendsNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractStructCaptures extracts captures from a struct node.
|
||||
func extractStructCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||
for _, cap := range capturesDef {
|
||||
switch cap.Name {
|
||||
case "NAME", "name":
|
||||
// For Go type_declaration, we need to look at the type_spec child
|
||||
if node.Type() == "type_declaration" {
|
||||
for i := 0; i < int(node.NamedChildCount()); i++ {
|
||||
child := node.NamedChild(i)
|
||||
if child != nil && child.Type() == "type_spec" {
|
||||
if nameNode := child.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
case "FIELDS", "fields", "BODY", "body":
|
||||
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{bodyNode},
|
||||
Text: parser.GetNodeText(bodyNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractInterfaceCaptures extracts captures from an interface node.
|
||||
func extractInterfaceCaptures(node *sitter.Node, capturesDef []Capture, content []byte, captures map[string]CapturedNode) {
|
||||
for _, cap := range capturesDef {
|
||||
switch cap.Name {
|
||||
case "NAME", "name":
|
||||
if nameNode := node.ChildByFieldName("name"); nameNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{nameNode},
|
||||
Text: parser.GetNodeText(nameNode, content),
|
||||
}
|
||||
}
|
||||
case "BODY", "body", "METHODS", "methods":
|
||||
if bodyNode := node.ChildByFieldName("body"); bodyNode != nil {
|
||||
captures[cap.Name] = CapturedNode{
|
||||
Nodes: []*sitter.Node{bodyNode},
|
||||
Text: parser.GetNodeText(bodyNode, content),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// passesFilters checks if a node passes all the specified filters.
|
||||
func passesFilters(node *sitter.Node, filters QueryFilters, content []byte) bool {
|
||||
// Name regex filter (uses cached compilation)
|
||||
if filters.NameMatches != "" {
|
||||
nameNode := node.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return false
|
||||
}
|
||||
name := parser.GetNodeText(nameNode, content)
|
||||
re, err := compileRegex(filters.NameMatches)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !re.MatchString(name) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Exact name filter
|
||||
if filters.NameExact != "" {
|
||||
nameNode := node.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return false
|
||||
}
|
||||
name := parser.GetNodeText(nameNode, content)
|
||||
if name != filters.NameExact {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Kind filter
|
||||
if len(filters.KindIn) > 0 {
|
||||
nodeType := node.Type()
|
||||
found := false
|
||||
for _, kind := range filters.KindIn {
|
||||
if nodeType == kind {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// FormatResults formats match results for display.
|
||||
func FormatResults(results []MatchResult, maxResults int) string {
|
||||
if len(results) == 0 {
|
||||
return "No matches found."
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Found %d match(es):\n\n", len(results)))
|
||||
|
||||
displayCount := len(results)
|
||||
truncated := false
|
||||
if maxResults > 0 && displayCount > maxResults {
|
||||
displayCount = maxResults
|
||||
truncated = true
|
||||
}
|
||||
|
||||
for i := 0; i < displayCount; i++ {
|
||||
r := results[i]
|
||||
nodeType := "unknown"
|
||||
if r.Node != nil {
|
||||
nodeType = r.Node.Type()
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("**%s:%d** (%s)\n", r.File, r.Location.Line, nodeType))
|
||||
|
||||
// Truncate very long text
|
||||
text := r.Text
|
||||
if len(text) > 500 {
|
||||
text = text[:500] + "..."
|
||||
}
|
||||
sb.WriteString("```\n")
|
||||
sb.WriteString(text)
|
||||
sb.WriteString("\n```\n")
|
||||
|
||||
// Show captures
|
||||
if len(r.Captures) > 0 {
|
||||
sb.WriteString("Captures: ")
|
||||
first := true
|
||||
for name, cap := range r.Captures {
|
||||
if !first {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
first = false
|
||||
capText := cap.Text
|
||||
if len(capText) > 50 {
|
||||
capText = capText[:50] + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("$%s=%s", name, capText))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
if truncated {
|
||||
sb.WriteString(fmt.Sprintf("... and %d more matches (truncated)\n", len(results)-maxResults))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,559 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestParsePattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
captureNames []string
|
||||
captureTypes []CaptureType
|
||||
wantCaptures int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty pattern",
|
||||
pattern: "",
|
||||
wantErr: true,
|
||||
wantCaptures: 0,
|
||||
},
|
||||
{
|
||||
name: "single capture",
|
||||
pattern: "func $NAME() {}",
|
||||
wantErr: false,
|
||||
wantCaptures: 1,
|
||||
captureNames: []string{"NAME"},
|
||||
captureTypes: []CaptureType{CaptureSingle},
|
||||
},
|
||||
{
|
||||
name: "multiple single captures",
|
||||
pattern: "func $NAME($ARGS) $RETURN",
|
||||
wantErr: false,
|
||||
wantCaptures: 3,
|
||||
captureNames: []string{"NAME", "ARGS", "RETURN"},
|
||||
captureTypes: []CaptureType{CaptureSingle, CaptureSingle, CaptureSingle},
|
||||
},
|
||||
{
|
||||
name: "multi-node capture",
|
||||
pattern: "func $NAME($$$ARGS) { $$$BODY }",
|
||||
wantErr: false,
|
||||
wantCaptures: 3,
|
||||
captureNames: []string{"ARGS", "BODY", "NAME"},
|
||||
captureTypes: []CaptureType{CaptureMultiple, CaptureMultiple, CaptureSingle},
|
||||
},
|
||||
{
|
||||
name: "wildcard capture",
|
||||
pattern: "func $NAME($_) {}",
|
||||
wantErr: false,
|
||||
wantCaptures: 2,
|
||||
captureNames: []string{"NAME", "_"},
|
||||
captureTypes: []CaptureType{CaptureSingle, CaptureWildcard},
|
||||
},
|
||||
{
|
||||
name: "no captures",
|
||||
pattern: "func main() {}",
|
||||
wantErr: false,
|
||||
wantCaptures: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsed, err := ParsePattern(tt.pattern)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(parsed.Captures) != tt.wantCaptures {
|
||||
t.Errorf("expected %d captures, got %d", tt.wantCaptures, len(parsed.Captures))
|
||||
}
|
||||
|
||||
// Check capture names (order may vary)
|
||||
if tt.captureNames != nil {
|
||||
captureMap := make(map[string]CaptureType)
|
||||
for _, cap := range parsed.Captures {
|
||||
captureMap[cap.Name] = cap.Type
|
||||
}
|
||||
|
||||
for i, name := range tt.captureNames {
|
||||
if _, ok := captureMap[name]; !ok {
|
||||
t.Errorf("expected capture %s not found", name)
|
||||
}
|
||||
if captureMap[name] != tt.captureTypes[i] {
|
||||
t.Errorf("capture %s: expected type %v, got %v", name, tt.captureTypes[i], captureMap[name])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchGoFunctions(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `package main
|
||||
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
|
||||
func Greet(name string) error {
|
||||
println("hello", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
Port int
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantMatches int
|
||||
}{
|
||||
{
|
||||
name: "match all functions",
|
||||
query: &ASTQuery{
|
||||
Pattern: "func $NAME($$$ARGS) { $$$BODY }",
|
||||
Language: "go",
|
||||
},
|
||||
wantMatches: 3, // Hello, Greet, Start
|
||||
},
|
||||
{
|
||||
name: "match functions starting with H",
|
||||
query: &ASTQuery{
|
||||
Pattern: "func $NAME() {}",
|
||||
Language: "go",
|
||||
Filters: QueryFilters{
|
||||
NameMatches: "^H",
|
||||
},
|
||||
},
|
||||
wantMatches: 1, // Hello
|
||||
},
|
||||
{
|
||||
name: "match specific function",
|
||||
query: &ASTQuery{
|
||||
Pattern: "func $NAME() {}",
|
||||
Language: "go",
|
||||
Filters: QueryFilters{
|
||||
NameExact: "Hello",
|
||||
},
|
||||
},
|
||||
wantMatches: 1, // Hello
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != tt.wantMatches {
|
||||
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||
for i, r := range results {
|
||||
t.Logf("match %d: %s at line %d", i, r.Node.Type(), r.Location.Line)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchGoStructs(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `package main
|
||||
|
||||
type Server struct {
|
||||
Port int
|
||||
Host string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Timeout int
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Log(msg string)
|
||||
}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantMinimum int
|
||||
}{
|
||||
{
|
||||
name: "match all structs",
|
||||
query: &ASTQuery{
|
||||
Pattern: "type $NAME struct { $$$FIELDS }",
|
||||
Language: "go",
|
||||
},
|
||||
wantMinimum: 2, // Server, Config (may also match interface as type_declaration)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) < tt.wantMinimum {
|
||||
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchJSFunctions(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `
|
||||
function greet(name) {
|
||||
console.log("Hello, " + name);
|
||||
}
|
||||
|
||||
function sayHello() {
|
||||
console.log("Hello!");
|
||||
}
|
||||
|
||||
class User {
|
||||
constructor(name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
getName() {
|
||||
return this.name;
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.js", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantMatches int
|
||||
}{
|
||||
{
|
||||
name: "match all functions",
|
||||
query: &ASTQuery{
|
||||
Pattern: "function $NAME($$$ARGS) { $$$BODY }",
|
||||
Language: "javascript",
|
||||
},
|
||||
wantMatches: 2, // greet, sayHello
|
||||
},
|
||||
{
|
||||
name: "match classes",
|
||||
query: &ASTQuery{
|
||||
Pattern: "class $NAME { $$$BODY }",
|
||||
Language: "javascript",
|
||||
},
|
||||
wantMatches: 1, // User
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.js")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != tt.wantMatches {
|
||||
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchPythonSymbols(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `
|
||||
def greet(name):
|
||||
print(f"Hello, {name}")
|
||||
|
||||
def calculate(a, b):
|
||||
return a + b
|
||||
|
||||
class User:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.py", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantMinimum int
|
||||
}{
|
||||
{
|
||||
name: "match classes",
|
||||
query: &ASTQuery{
|
||||
Pattern: "class $NAME: $$$BODY",
|
||||
Language: "python",
|
||||
},
|
||||
wantMinimum: 1, // User
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.py")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) < tt.wantMinimum {
|
||||
t.Errorf("expected at least %d matches, got %d", tt.wantMinimum, len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryFilters(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
|
||||
content := `package main
|
||||
|
||||
func HelloWorld() {}
|
||||
func helloWorld() {}
|
||||
func GoodbyeWorld() {}
|
||||
func Main() {}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filters QueryFilters
|
||||
wantMatches int
|
||||
}{
|
||||
{
|
||||
name: "regex filter - starts with H",
|
||||
filters: QueryFilters{
|
||||
NameMatches: "^[Hh]ello",
|
||||
},
|
||||
wantMatches: 2, // HelloWorld, helloWorld
|
||||
},
|
||||
{
|
||||
name: "exact name filter",
|
||||
filters: QueryFilters{
|
||||
NameExact: "Main",
|
||||
},
|
||||
wantMatches: 1, // Main
|
||||
},
|
||||
{
|
||||
name: "kind filter",
|
||||
filters: QueryFilters{
|
||||
KindIn: []string{"function_declaration"},
|
||||
},
|
||||
wantMatches: 4, // all functions
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
query := &ASTQuery{
|
||||
Pattern: "func $NAME() {}",
|
||||
Language: "go",
|
||||
Filters: tt.filters,
|
||||
}
|
||||
|
||||
results, err := matcher.Match(ctx, query, result.Tree, []byte(content), "test.go")
|
||||
if err != nil {
|
||||
t.Fatalf("match failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != tt.wantMatches {
|
||||
t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results))
|
||||
for _, r := range results {
|
||||
if nameNode := r.Node.ChildByFieldName("name"); nameNode != nil {
|
||||
t.Logf("matched: %s", parser.GetNodeText(nameNode, []byte(content)))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatResults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
results []MatchResult
|
||||
maxResults int
|
||||
wantEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "empty results",
|
||||
results: []MatchResult{},
|
||||
maxResults: 100,
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "single result",
|
||||
results: []MatchResult{
|
||||
{
|
||||
File: "test.go",
|
||||
Location: protocol.Location{Line: 10, Column: 1},
|
||||
Text: "func Hello() {}",
|
||||
Captures: map[string]CapturedNode{
|
||||
"NAME": {Text: "Hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
maxResults: 100,
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "truncated results",
|
||||
results: []MatchResult{
|
||||
{File: "a.go", Location: protocol.Location{Line: 1}, Text: "func A() {}"},
|
||||
{File: "b.go", Location: protocol.Location{Line: 1}, Text: "func B() {}"},
|
||||
{File: "c.go", Location: protocol.Location{Line: 1}, Text: "func C() {}"},
|
||||
},
|
||||
maxResults: 2,
|
||||
wantEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
output := FormatResults(tt.results, tt.maxResults)
|
||||
|
||||
if tt.wantEmpty {
|
||||
if output != "No matches found." {
|
||||
t.Errorf("expected 'No matches found.', got: %s", output)
|
||||
}
|
||||
} else {
|
||||
if output == "No matches found." {
|
||||
t.Error("expected results, got 'No matches found.'")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryValidation(t *testing.T) {
|
||||
reg := parser.NewRegistry()
|
||||
defer reg.Close()
|
||||
|
||||
matcher := NewMatcher(reg)
|
||||
ctx := context.Background()
|
||||
|
||||
// Parse some valid content
|
||||
content := `package main
|
||||
func main() {}
|
||||
`
|
||||
result, err := reg.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
query *ASTQuery
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty pattern",
|
||||
query: &ASTQuery{Pattern: "", Language: "go"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing language",
|
||||
query: &ASTQuery{Pattern: "func $NAME() {}", Language: ""},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown language",
|
||||
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "unknown"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid query",
|
||||
query: &ASTQuery{Pattern: "func $NAME() {}", Language: "go"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := matcher.Match(ctx, tt.query, result.Tree, []byte(content), "test.go")
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCompileRegexCaching tests that regex compilation is cached.
|
||||
func TestCompileRegexCaching(t *testing.T) {
|
||||
// Clear cache before test
|
||||
regexCache = sync.Map{}
|
||||
|
||||
pattern := `^test_\w+$`
|
||||
|
||||
// First compilation
|
||||
re1, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("First compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Second compilation should return cached version
|
||||
re2, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("Second compile failed: %v", err)
|
||||
}
|
||||
|
||||
// Should be the exact same object
|
||||
if re1 != re2 {
|
||||
t.Error("Expected cached regex to be reused, got different objects")
|
||||
}
|
||||
|
||||
// Verify it's in the cache
|
||||
cached, ok := regexCache.Load(pattern)
|
||||
if !ok {
|
||||
t.Error("Pattern not found in cache")
|
||||
}
|
||||
|
||||
if cached.(*regexp.Regexp) != re1 {
|
||||
t.Error("Cached regex doesn't match returned regex")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompileRegexConcurrent tests concurrent regex compilation.
|
||||
func TestCompileRegexConcurrent(t *testing.T) {
|
||||
// Clear cache before test
|
||||
regexCache = sync.Map{}
|
||||
|
||||
pattern := `[a-z]+_\d+`
|
||||
const numGoroutines = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
results := make([]*regexp.Regexp, numGoroutines)
|
||||
errors := make(chan error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
re, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
results[i] = re
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
for err := range errors {
|
||||
t.Errorf("Concurrent compile failed: %v", err)
|
||||
}
|
||||
|
||||
// All results should be the same object (cached)
|
||||
for i := 1; i < numGoroutines; i++ {
|
||||
if results[i] != results[0] {
|
||||
t.Errorf("Result %d is different from result 0 (cache not working)", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompileRegexInvalidPattern tests error handling for invalid patterns.
|
||||
func TestCompileRegexInvalidPattern(t *testing.T) {
|
||||
// Clear cache before test
|
||||
regexCache = sync.Map{}
|
||||
|
||||
invalidPattern := `[invalid(`
|
||||
|
||||
_, err := compileRegex(invalidPattern)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid pattern, got nil")
|
||||
}
|
||||
|
||||
// Invalid patterns should not be cached
|
||||
_, ok := regexCache.Load(invalidPattern)
|
||||
if ok {
|
||||
t.Error("Invalid pattern should not be cached")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompileRegexMultiplePatterns tests that different patterns are cached separately.
|
||||
func TestCompileRegexMultiplePatterns(t *testing.T) {
|
||||
// Clear cache before test
|
||||
regexCache = sync.Map{}
|
||||
|
||||
patterns := []string{
|
||||
`^test_\w+$`,
|
||||
`^\d{4}-\d{2}-\d{2}$`,
|
||||
`^[A-Z][a-z]+$`,
|
||||
`\b\w+@\w+\.\w+\b`,
|
||||
}
|
||||
|
||||
compiled := make([]*regexp.Regexp, len(patterns))
|
||||
|
||||
// Compile all patterns
|
||||
for i, pattern := range patterns {
|
||||
re, err := compileRegex(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("Compile failed for pattern %s: %v", pattern, err)
|
||||
}
|
||||
compiled[i] = re
|
||||
}
|
||||
|
||||
// Verify all are cached
|
||||
for i, pattern := range patterns {
|
||||
cached, ok := regexCache.Load(pattern)
|
||||
if !ok {
|
||||
t.Errorf("Pattern %s not in cache", pattern)
|
||||
}
|
||||
|
||||
if cached.(*regexp.Regexp) != compiled[i] {
|
||||
t.Errorf("Cached regex for %s doesn't match compiled version", pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// All should be different objects
|
||||
for i := 0; i < len(compiled); i++ {
|
||||
for j := i + 1; j < len(compiled); j++ {
|
||||
if compiled[i] == compiled[j] {
|
||||
t.Errorf("Pattern %d and %d have same regex object", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompileRegex_Uncached benchmarks regex compilation without caching.
|
||||
func BenchmarkCompileRegex_Uncached(b *testing.B) {
|
||||
pattern := `^\w+_[0-9]{3,5}_[a-zA-Z]+$`
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = regexp.Compile(pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompileRegex_Cached benchmarks regex compilation with caching.
|
||||
func BenchmarkCompileRegex_Cached(b *testing.B) {
|
||||
// Clear cache
|
||||
regexCache = sync.Map{}
|
||||
|
||||
pattern := `^\w+_[0-9]{3,5}_[a-zA-Z]+$`
|
||||
|
||||
// Pre-populate cache
|
||||
_, _ = compileRegex(pattern)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = compileRegex(pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompileRegex_MixedPatterns benchmarks realistic workload with multiple patterns.
|
||||
func BenchmarkCompileRegex_MixedPatterns(b *testing.B) {
|
||||
// Clear cache
|
||||
regexCache = sync.Map{}
|
||||
|
||||
patterns := []string{
|
||||
`^test_\w+$`,
|
||||
`^\d{4}-\d{2}-\d{2}$`,
|
||||
`^[A-Z][a-z]+$`,
|
||||
`\b\w+@\w+\.\w+\b`,
|
||||
`^func\s+\w+\(`,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simulate realistic access pattern
|
||||
pattern := patterns[i%len(patterns)]
|
||||
_, _ = compileRegex(pattern)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user