mirror of
https://github.com/lukaszraczylo/filepuff-mcp.git
synced 2026-06-10 22:59:01 +00:00
Ho hum.
This commit is contained in:
@@ -0,0 +1,190 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// FindNodeAtPosition finds the node at the given line and column.
|
||||
func FindNodeAtPosition(tree *sitter.Tree, line, col int) *sitter.Node {
|
||||
if tree == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
root := tree.RootNode()
|
||||
if root == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert to 0-indexed
|
||||
point := sitter.Point{
|
||||
Row: uint32(line - 1), // #nosec G115 - line numbers are bounded by file size
|
||||
Column: uint32(col - 1), // #nosec G115 - column numbers are bounded by line length
|
||||
}
|
||||
|
||||
return findNodeAtPoint(root, point)
|
||||
}
|
||||
|
||||
// findNodeAtPoint recursively finds the smallest node containing the point.
|
||||
func findNodeAtPoint(node *sitter.Node, point sitter.Point) *sitter.Node {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
startPoint := node.StartPoint()
|
||||
endPoint := node.EndPoint()
|
||||
|
||||
// Check if point is within this node
|
||||
if !pointInRange(point, startPoint, endPoint) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to find a more specific child node
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
child := node.Child(i)
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
if result := findNodeAtPoint(child, point); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// No child contains the point, return this node
|
||||
return node
|
||||
}
|
||||
|
||||
// pointInRange checks if a point is within a range.
|
||||
func pointInRange(point, start, end sitter.Point) bool {
|
||||
// Before start?
|
||||
if point.Row < start.Row || (point.Row == start.Row && point.Column < start.Column) {
|
||||
return false
|
||||
}
|
||||
// After end?
|
||||
if point.Row > end.Row || (point.Row == end.Row && point.Column >= end.Column) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// FindParentOfKind finds the nearest ancestor of the given node type.
|
||||
func FindParentOfKind(node *sitter.Node, kind string) *sitter.Node {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
current := node.Parent()
|
||||
for current != nil {
|
||||
if current.Type() == kind {
|
||||
return current
|
||||
}
|
||||
current = current.Parent()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNodeText returns the text content of a node.
|
||||
func GetNodeText(node *sitter.Node, content []byte) string {
|
||||
if node == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
start := node.StartByte()
|
||||
end := node.EndByte()
|
||||
|
||||
if int(start) >= len(content) || int(end) > len(content) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(content[start:end])
|
||||
}
|
||||
|
||||
// WalkTree walks the tree calling fn for each node.
|
||||
// If fn returns false, the walk stops.
|
||||
func WalkTree(node *sitter.Node, fn func(*sitter.Node) bool) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !fn(node) {
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < int(node.ChildCount()); i++ {
|
||||
WalkTree(node.Child(i), fn)
|
||||
}
|
||||
}
|
||||
|
||||
// FindNodesByKind finds all nodes of a given kind.
|
||||
func FindNodesByKind(root *sitter.Node, kind string) []*sitter.Node {
|
||||
var nodes []*sitter.Node
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
if n.Type() == kind {
|
||||
nodes = append(nodes, n)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
// FindNamedChildren returns all named (non-anonymous) children of a node.
|
||||
func FindNamedChildren(node *sitter.Node) []*sitter.Node {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var children []*sitter.Node
|
||||
for i := 0; i < int(node.NamedChildCount()); i++ {
|
||||
if child := node.NamedChild(i); child != nil {
|
||||
children = append(children, child)
|
||||
}
|
||||
}
|
||||
return children
|
||||
}
|
||||
|
||||
// GetChildByFieldName returns the child node with the given field name.
|
||||
func GetChildByFieldName(node *sitter.Node, fieldName string) *sitter.Node {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return node.ChildByFieldName(fieldName)
|
||||
}
|
||||
|
||||
// NodeLocation returns the location of a node.
|
||||
func NodeLocation(node *sitter.Node, filename string) protocol.Location {
|
||||
if node == nil {
|
||||
return protocol.Location{}
|
||||
}
|
||||
|
||||
startPoint := node.StartPoint()
|
||||
return protocol.Location{
|
||||
File: filename,
|
||||
Line: int(startPoint.Row) + 1,
|
||||
Column: int(startPoint.Column) + 1,
|
||||
}
|
||||
}
|
||||
|
||||
// NodeRange returns the range of a node.
|
||||
func NodeRange(node *sitter.Node, filename string) protocol.Range {
|
||||
if node == nil {
|
||||
return protocol.Range{}
|
||||
}
|
||||
|
||||
startPoint := node.StartPoint()
|
||||
endPoint := node.EndPoint()
|
||||
|
||||
return protocol.Range{
|
||||
Start: protocol.Location{
|
||||
File: filename,
|
||||
Line: int(startPoint.Row) + 1,
|
||||
Column: int(startPoint.Column) + 1,
|
||||
},
|
||||
End: protocol.Location{
|
||||
File: filename,
|
||||
Line: int(endPoint.Row) + 1,
|
||||
Column: int(endPoint.Column) + 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLRUCacheEviction tests that the LRU cache properly evicts old entries.
|
||||
func TestLRUCacheEviction(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 101 unique Go files (cache size is 100)
|
||||
for i := 0; i < 101; i++ {
|
||||
content := []byte(fmt.Sprintf("package main\n\nfunc test%d() {}\n", i))
|
||||
filename := "test.go"
|
||||
|
||||
_, err := registry.Parse(ctx, filename, content)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed for iteration %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// The LRU cache should have evicted the oldest entry
|
||||
// Verify cache size is capped at 100
|
||||
cacheLen := registry.cache.Len()
|
||||
if cacheLen > 100 {
|
||||
t.Errorf("Cache size %d exceeds max size 100", cacheLen)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheHit tests that repeated parsing of the same content uses cache.
|
||||
func TestCacheHit(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
ctx := context.Background()
|
||||
|
||||
content := []byte("package main\n\nfunc test() {}\n")
|
||||
filename := "test.go"
|
||||
|
||||
// First parse
|
||||
result1, err := registry.Parse(ctx, filename, content)
|
||||
if err != nil {
|
||||
t.Fatalf("First parse failed: %v", err)
|
||||
}
|
||||
|
||||
// Second parse should use cache
|
||||
result2, err := registry.Parse(ctx, filename, content)
|
||||
if err != nil {
|
||||
t.Fatalf("Second parse failed: %v", err)
|
||||
}
|
||||
|
||||
// The tree should be the same object (cached)
|
||||
if result1.Tree != result2.Tree {
|
||||
t.Error("Expected cached tree to be reused, but got different tree objects")
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentHashCollisionResistance tests that different content produces different hashes.
|
||||
func TestContentHashCollisionResistance(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
content1 []byte
|
||||
content2 []byte
|
||||
}{
|
||||
{
|
||||
name: "different content",
|
||||
content1: []byte("package main"),
|
||||
content2: []byte("package test"),
|
||||
},
|
||||
{
|
||||
name: "same prefix different suffix",
|
||||
content1: []byte("package main\nfunc a() {}"),
|
||||
content2: []byte("package main\nfunc b() {}"),
|
||||
},
|
||||
{
|
||||
name: "different length",
|
||||
content1: []byte("short"),
|
||||
content2: []byte("much longer content here"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hash1 := contentHash(tc.content1)
|
||||
hash2 := contentHash(tc.content2)
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Errorf("Hash collision: %s == %s for different content", hash1, hash2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentHashConsistency tests that the same content always produces the same hash.
|
||||
func TestContentHashConsistency(t *testing.T) {
|
||||
content := []byte("package main\n\nfunc test() {}\n")
|
||||
|
||||
hash1 := contentHash(content)
|
||||
hash2 := contentHash(content)
|
||||
hash3 := contentHash(content)
|
||||
|
||||
if hash1 != hash2 || hash2 != hash3 {
|
||||
t.Errorf("Hash inconsistency: %s, %s, %s", hash1, hash2, hash3)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkContentHash_xxHash benchmarks the xxHash implementation.
|
||||
func BenchmarkContentHash_xxHash(b *testing.B) {
|
||||
// Typical file content size (10KB)
|
||||
content := make([]byte, 10*1024)
|
||||
for i := range content {
|
||||
content[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = contentHash(content)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCacheHitRate benchmarks cache performance with realistic workload.
|
||||
func BenchmarkCacheHitRate(b *testing.B) {
|
||||
registry := NewRegistry()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a set of common files that get parsed repeatedly
|
||||
files := [][]byte{
|
||||
[]byte("package main\n\nfunc main() {}\n"),
|
||||
[]byte("package test\n\nimport \"testing\"\n"),
|
||||
[]byte("package util\n\nfunc helper() string { return \"\" }\n"),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simulate realistic access pattern with cache hits
|
||||
content := files[i%len(files)]
|
||||
_, _ = registry.Parse(ctx, "test.go", content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,550 @@
|
||||
// Package parser provides documentation extraction for multiple languages.
|
||||
package parser
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// DocComment represents an extracted documentation comment.
|
||||
type DocComment struct {
|
||||
Tags map[string]string
|
||||
Text string
|
||||
Raw string
|
||||
Style CommentStyle
|
||||
StartLine int
|
||||
EndLine int
|
||||
}
|
||||
|
||||
// CommentStyle indicates the type of comment.
|
||||
type CommentStyle string
|
||||
|
||||
const (
|
||||
CommentStyleLine CommentStyle = "line" // // comment
|
||||
CommentStyleBlock CommentStyle = "block" // /* comment */
|
||||
CommentStyleJSDoc CommentStyle = "jsdoc" // /** comment */
|
||||
CommentStyleDoxygen CommentStyle = "doxygen" // /** comment */ or /// comment
|
||||
CommentStyleDocstring CommentStyle = "docstring" // """comment""" or '''comment'''
|
||||
CommentStyleHash CommentStyle = "hash" // # comment (Python)
|
||||
)
|
||||
|
||||
// ExtractDocComment extracts the documentation comment for a node.
|
||||
func ExtractDocComment(n *sitter.Node, content []byte, lang protocol.Language) *DocComment {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch lang {
|
||||
case protocol.LangGo:
|
||||
return extractGoDocComment(n, content)
|
||||
case protocol.LangTypeScript, protocol.LangJavaScript:
|
||||
return extractJSDocComment(n, content)
|
||||
case protocol.LangPython:
|
||||
return extractPythonDocComment(n, content)
|
||||
case protocol.LangC, protocol.LangCpp:
|
||||
return extractCDocComment(n, content)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// extractGoDocComment extracts Go documentation comments.
|
||||
// Go uses // or /* */ comments immediately preceding a declaration.
|
||||
func extractGoDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parts []string
|
||||
var raw []string
|
||||
startLine := -1
|
||||
endLine := -1
|
||||
|
||||
for _, c := range comments {
|
||||
text := GetNodeText(c, content)
|
||||
raw = append(raw, text)
|
||||
|
||||
if startLine == -1 {
|
||||
startLine = int(c.StartPoint().Row) + 1
|
||||
}
|
||||
endLine = int(c.EndPoint().Row) + 1
|
||||
|
||||
cleaned := cleanGoComment(text)
|
||||
if cleaned != "" {
|
||||
parts = append(parts, cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DocComment{
|
||||
Text: strings.Join(parts, "\n"),
|
||||
Raw: strings.Join(raw, "\n"),
|
||||
Style: detectCommentStyle(raw[0]),
|
||||
Tags: nil, // Go doesn't use JSDoc-style tags
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
}
|
||||
}
|
||||
|
||||
// extractJSDocComment extracts JSDoc-style documentation comments.
|
||||
func extractJSDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// JSDoc prefers the last comment block if it's a JSDoc comment
|
||||
var jsDocComment *sitter.Node
|
||||
for i := len(comments) - 1; i >= 0; i-- {
|
||||
text := GetNodeText(comments[i], content)
|
||||
if strings.HasPrefix(strings.TrimSpace(text), "/**") {
|
||||
jsDocComment = comments[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if jsDocComment != nil {
|
||||
text := GetNodeText(jsDocComment, content)
|
||||
cleaned, tags := parseJSDoc(text)
|
||||
return &DocComment{
|
||||
Text: cleaned,
|
||||
Raw: text,
|
||||
Style: CommentStyleJSDoc,
|
||||
Tags: tags,
|
||||
StartLine: int(jsDocComment.StartPoint().Row) + 1,
|
||||
EndLine: int(jsDocComment.EndPoint().Row) + 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to regular comments
|
||||
var parts []string
|
||||
var raw []string
|
||||
startLine := -1
|
||||
endLine := -1
|
||||
|
||||
for _, c := range comments {
|
||||
text := GetNodeText(c, content)
|
||||
raw = append(raw, text)
|
||||
|
||||
if startLine == -1 {
|
||||
startLine = int(c.StartPoint().Row) + 1
|
||||
}
|
||||
endLine = int(c.EndPoint().Row) + 1
|
||||
|
||||
cleaned := cleanJSComment(text)
|
||||
if cleaned != "" {
|
||||
parts = append(parts, cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DocComment{
|
||||
Text: strings.Join(parts, "\n"),
|
||||
Raw: strings.Join(raw, "\n"),
|
||||
Style: CommentStyleLine,
|
||||
Tags: nil,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
}
|
||||
}
|
||||
|
||||
// extractPythonDocComment extracts Python docstrings.
|
||||
// Python docstrings are triple-quoted strings inside the function/class body.
|
||||
func extractPythonDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||
// Python docstrings are inside the body, not before
|
||||
body := n.ChildByFieldName("body")
|
||||
if body == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// First statement should be the docstring if present
|
||||
if body.NamedChildCount() > 0 {
|
||||
first := body.NamedChild(0)
|
||||
if first != nil && first.Type() == "expression_statement" {
|
||||
if first.NamedChildCount() > 0 {
|
||||
expr := first.NamedChild(0)
|
||||
if expr != nil && expr.Type() == "string" {
|
||||
text := GetNodeText(expr, content)
|
||||
cleaned := cleanPythonDocstring(text)
|
||||
return &DocComment{
|
||||
Text: cleaned,
|
||||
Raw: text,
|
||||
Style: CommentStyleDocstring,
|
||||
Tags: nil,
|
||||
StartLine: int(expr.StartPoint().Row) + 1,
|
||||
EndLine: int(expr.EndPoint().Row) + 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check for # comments before the definition
|
||||
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parts []string
|
||||
var raw []string
|
||||
startLine := -1
|
||||
endLine := -1
|
||||
|
||||
for _, c := range comments {
|
||||
text := GetNodeText(c, content)
|
||||
raw = append(raw, text)
|
||||
|
||||
if startLine == -1 {
|
||||
startLine = int(c.StartPoint().Row) + 1
|
||||
}
|
||||
endLine = int(c.EndPoint().Row) + 1
|
||||
|
||||
// Clean # comment
|
||||
cleaned := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(text), "#"))
|
||||
if cleaned != "" {
|
||||
parts = append(parts, cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DocComment{
|
||||
Text: strings.Join(parts, "\n"),
|
||||
Raw: strings.Join(raw, "\n"),
|
||||
Style: CommentStyleHash,
|
||||
Tags: nil,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
}
|
||||
}
|
||||
|
||||
// extractCDocComment extracts C/C++ documentation comments (Doxygen style).
|
||||
func extractCDocComment(n *sitter.Node, content []byte) *DocComment {
|
||||
comments := collectPrecedingComments(n, content, []string{"comment"})
|
||||
if len(comments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Look for Doxygen-style comment
|
||||
var doxyComment *sitter.Node
|
||||
for i := len(comments) - 1; i >= 0; i-- {
|
||||
text := GetNodeText(comments[i], content)
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if strings.HasPrefix(trimmed, "/**") || strings.HasPrefix(trimmed, "///") || strings.HasPrefix(trimmed, "//!") {
|
||||
doxyComment = comments[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if doxyComment != nil {
|
||||
text := GetNodeText(doxyComment, content)
|
||||
cleaned, tags := parseDoxygen(text)
|
||||
return &DocComment{
|
||||
Text: cleaned,
|
||||
Raw: text,
|
||||
Style: CommentStyleDoxygen,
|
||||
Tags: tags,
|
||||
StartLine: int(doxyComment.StartPoint().Row) + 1,
|
||||
EndLine: int(doxyComment.EndPoint().Row) + 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to regular comments
|
||||
var parts []string
|
||||
var raw []string
|
||||
startLine := -1
|
||||
endLine := -1
|
||||
|
||||
for _, c := range comments {
|
||||
text := GetNodeText(c, content)
|
||||
raw = append(raw, text)
|
||||
|
||||
if startLine == -1 {
|
||||
startLine = int(c.StartPoint().Row) + 1
|
||||
}
|
||||
endLine = int(c.EndPoint().Row) + 1
|
||||
|
||||
cleaned := cleanCComment(text)
|
||||
if cleaned != "" {
|
||||
parts = append(parts, cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DocComment{
|
||||
Text: strings.Join(parts, "\n"),
|
||||
Raw: strings.Join(raw, "\n"),
|
||||
Style: detectCommentStyle(raw[0]),
|
||||
Tags: nil,
|
||||
StartLine: startLine,
|
||||
EndLine: endLine,
|
||||
}
|
||||
}
|
||||
|
||||
// collectPrecedingComments collects all comment nodes immediately before a node.
|
||||
func collectPrecedingComments(n *sitter.Node, _ []byte, commentTypes []string) []*sitter.Node {
|
||||
var comments []*sitter.Node
|
||||
|
||||
// Walk backwards through siblings
|
||||
prev := n.PrevSibling()
|
||||
lastCommentLine := int(n.StartPoint().Row)
|
||||
|
||||
for prev != nil {
|
||||
isComment := false
|
||||
nodeType := prev.Type()
|
||||
for _, ct := range commentTypes {
|
||||
if nodeType == ct {
|
||||
isComment = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isComment {
|
||||
break
|
||||
}
|
||||
|
||||
commentEndLine := int(prev.EndPoint().Row)
|
||||
|
||||
// Check if there's a blank line gap
|
||||
if lastCommentLine-commentEndLine > 1 {
|
||||
break
|
||||
}
|
||||
|
||||
comments = append([]*sitter.Node{prev}, comments...)
|
||||
lastCommentLine = int(prev.StartPoint().Row)
|
||||
prev = prev.PrevSibling()
|
||||
}
|
||||
|
||||
return comments
|
||||
}
|
||||
|
||||
// detectCommentStyle determines the style of a comment.
|
||||
func detectCommentStyle(comment string) CommentStyle {
|
||||
trimmed := strings.TrimSpace(comment)
|
||||
if strings.HasPrefix(trimmed, "/**") {
|
||||
return CommentStyleJSDoc
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "///") || strings.HasPrefix(trimmed, "//!") {
|
||||
return CommentStyleDoxygen
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "/*") {
|
||||
return CommentStyleBlock
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "//") {
|
||||
return CommentStyleLine
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
return CommentStyleHash
|
||||
}
|
||||
if strings.HasPrefix(trimmed, `"""`) || strings.HasPrefix(trimmed, `'''`) {
|
||||
return CommentStyleDocstring
|
||||
}
|
||||
return CommentStyleLine
|
||||
}
|
||||
|
||||
// cleanGoComment cleans a Go comment.
|
||||
func cleanGoComment(comment string) string {
|
||||
comment = strings.TrimSpace(comment)
|
||||
|
||||
// Handle // comments
|
||||
if after, found := strings.CutPrefix(comment, "//"); found {
|
||||
return strings.TrimSpace(after)
|
||||
}
|
||||
|
||||
// Handle /* */ comments
|
||||
if strings.HasPrefix(comment, "/*") && strings.HasSuffix(comment, "*/") {
|
||||
comment = strings.TrimPrefix(comment, "/*")
|
||||
comment = strings.TrimSuffix(comment, "*/")
|
||||
return cleanBlockComment(comment)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(comment)
|
||||
}
|
||||
|
||||
// cleanJSComment cleans a JavaScript/TypeScript comment.
|
||||
func cleanJSComment(comment string) string {
|
||||
return cleanGoComment(comment) // Same rules
|
||||
}
|
||||
|
||||
// cleanCComment cleans a C/C++ comment.
|
||||
func cleanCComment(comment string) string {
|
||||
return cleanGoComment(comment) // Same rules
|
||||
}
|
||||
|
||||
// cleanBlockComment cleans the content of a block comment.
|
||||
func cleanBlockComment(comment string) string {
|
||||
lines := strings.Split(comment, "\n")
|
||||
var cleaned []string
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
// Remove leading * from each line (common in block comments)
|
||||
line = strings.TrimPrefix(line, "*")
|
||||
line = strings.TrimSpace(line)
|
||||
cleaned = append(cleaned, line)
|
||||
}
|
||||
|
||||
// Remove empty leading/trailing lines
|
||||
for len(cleaned) > 0 && cleaned[0] == "" {
|
||||
cleaned = cleaned[1:]
|
||||
}
|
||||
for len(cleaned) > 0 && cleaned[len(cleaned)-1] == "" {
|
||||
cleaned = cleaned[:len(cleaned)-1]
|
||||
}
|
||||
|
||||
return strings.Join(cleaned, "\n")
|
||||
}
|
||||
|
||||
// parseJSDoc parses a JSDoc comment and extracts tags.
|
||||
func parseJSDoc(comment string) (string, map[string]string) {
|
||||
comment = strings.TrimSpace(comment)
|
||||
|
||||
// Remove /** and */
|
||||
comment = strings.TrimPrefix(comment, "/**")
|
||||
comment = strings.TrimSuffix(comment, "*/")
|
||||
|
||||
lines := strings.Split(comment, "\n")
|
||||
var descLines []string
|
||||
tags := make(map[string]string)
|
||||
|
||||
// Regex for JSDoc tags
|
||||
tagPattern := regexp.MustCompile(`^\s*\*?\s*@(\w+)\s*(.*)$`)
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.TrimPrefix(line, "*")
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if matches := tagPattern.FindStringSubmatch(line); matches != nil {
|
||||
tagName := matches[1]
|
||||
tagValue := strings.TrimSpace(matches[2])
|
||||
if existing, ok := tags[tagName]; ok {
|
||||
tags[tagName] = existing + "\n" + tagValue
|
||||
} else {
|
||||
tags[tagName] = tagValue
|
||||
}
|
||||
} else if line != "" {
|
||||
descLines = append(descLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(descLines, "\n"), tags
|
||||
}
|
||||
|
||||
// parseDoxygen parses a Doxygen comment and extracts tags.
|
||||
func parseDoxygen(comment string) (string, map[string]string) {
|
||||
comment = strings.TrimSpace(comment)
|
||||
|
||||
// Handle /// and //! style comments
|
||||
comment = strings.TrimPrefix(comment, "///")
|
||||
comment = strings.TrimPrefix(comment, "//!")
|
||||
|
||||
// Handle /** */ style comments
|
||||
comment = strings.TrimPrefix(comment, "/**")
|
||||
comment = strings.TrimSuffix(comment, "*/")
|
||||
|
||||
lines := strings.Split(comment, "\n")
|
||||
var descLines []string
|
||||
tags := make(map[string]string)
|
||||
|
||||
// Regex for Doxygen tags (@param, @return, \param, \return, etc.)
|
||||
tagPattern := regexp.MustCompile(`^\s*\*?\s*[@\\](\w+)\s*(.*)$`)
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.TrimPrefix(line, "*")
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if matches := tagPattern.FindStringSubmatch(line); matches != nil {
|
||||
tagName := matches[1]
|
||||
tagValue := strings.TrimSpace(matches[2])
|
||||
if existing, ok := tags[tagName]; ok {
|
||||
tags[tagName] = existing + "\n" + tagValue
|
||||
} else {
|
||||
tags[tagName] = tagValue
|
||||
}
|
||||
} else if line != "" {
|
||||
descLines = append(descLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(descLines, "\n"), tags
|
||||
}
|
||||
|
||||
// FormatDocComment formats a DocComment for display.
|
||||
func FormatDocComment(doc *DocComment) string {
|
||||
if doc == nil || doc.Text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(doc.Text)
|
||||
|
||||
if len(doc.Tags) > 0 {
|
||||
sb.WriteString("\n\n")
|
||||
// Order: description, params, returns, other
|
||||
paramOrder := []string{"param", "parameter", "arg", "argument"}
|
||||
returnOrder := []string{"return", "returns", "retval"}
|
||||
|
||||
// Write params first
|
||||
for _, tagName := range paramOrder {
|
||||
if val, ok := doc.Tags[tagName]; ok {
|
||||
for _, line := range strings.Split(val, "\n") {
|
||||
sb.WriteString("@" + tagName + " " + line + "\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write returns
|
||||
for _, tagName := range returnOrder {
|
||||
if val, ok := doc.Tags[tagName]; ok {
|
||||
sb.WriteString("@" + tagName + " " + val + "\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Write remaining tags
|
||||
written := make(map[string]bool)
|
||||
for _, t := range paramOrder {
|
||||
written[t] = true
|
||||
}
|
||||
for _, t := range returnOrder {
|
||||
written[t] = true
|
||||
}
|
||||
|
||||
for tagName, val := range doc.Tags {
|
||||
if !written[tagName] {
|
||||
sb.WriteString("@" + tagName + " " + val + "\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// cleanPythonDocstring cleans a Python docstring.
|
||||
func cleanPythonDocstring(doc string) string {
|
||||
doc = strings.TrimSpace(doc)
|
||||
|
||||
// Remove triple quotes
|
||||
doc = strings.TrimPrefix(doc, `"""`)
|
||||
doc = strings.TrimSuffix(doc, `"""`)
|
||||
doc = strings.TrimPrefix(doc, `'''`)
|
||||
doc = strings.TrimSuffix(doc, `'''`)
|
||||
|
||||
return strings.TrimSpace(doc)
|
||||
}
|
||||
@@ -0,0 +1,630 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
func TestExtractGoDocComment(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
nodeKind string
|
||||
wantText string
|
||||
wantStyle CommentStyle
|
||||
}{
|
||||
{
|
||||
name: "single line comment",
|
||||
code: `package main
|
||||
|
||||
// Hello says hello
|
||||
func Hello() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "Hello says hello",
|
||||
wantStyle: CommentStyleLine,
|
||||
},
|
||||
{
|
||||
name: "multi-line comments",
|
||||
code: `package main
|
||||
|
||||
// This is a function
|
||||
// that does something
|
||||
// important
|
||||
func DoSomething() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "This is a function\nthat does something\nimportant",
|
||||
wantStyle: CommentStyleLine,
|
||||
},
|
||||
{
|
||||
name: "block comment",
|
||||
code: `package main
|
||||
|
||||
/* This is a block comment
|
||||
describing the function */
|
||||
func BlockCommented() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "This is a block comment\ndescribing the function",
|
||||
wantStyle: CommentStyleBlock,
|
||||
},
|
||||
{
|
||||
name: "doc comment with asterisks",
|
||||
code: `package main
|
||||
|
||||
/*
|
||||
* This is a properly formatted
|
||||
* block comment with asterisks
|
||||
*/
|
||||
func FormattedBlock() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "This is a properly formatted\nblock comment with asterisks",
|
||||
wantStyle: CommentStyleBlock,
|
||||
},
|
||||
{
|
||||
name: "no comment",
|
||||
code: `package main
|
||||
|
||||
func NoComment() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.Parse(context.Background(), "test.go", []byte(tt.code))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
// Find the target node
|
||||
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||
if targetNode == nil {
|
||||
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||
}
|
||||
|
||||
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangGo)
|
||||
|
||||
if tt.wantText == "" {
|
||||
if doc != nil && doc.Text != "" {
|
||||
t.Errorf("expected no doc, got %q", doc.Text)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||
}
|
||||
|
||||
if doc.Style != tt.wantStyle {
|
||||
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractJSDocComment(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
wantTags map[string]string
|
||||
name string
|
||||
code string
|
||||
nodeKind string
|
||||
wantText string
|
||||
wantStyle CommentStyle
|
||||
}{
|
||||
{
|
||||
name: "JSDoc comment",
|
||||
code: `/**
|
||||
* Adds two numbers together.
|
||||
* @param a The first number
|
||||
* @param b The second number
|
||||
* @returns The sum of a and b
|
||||
*/
|
||||
function add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "Adds two numbers together.",
|
||||
wantStyle: CommentStyleJSDoc,
|
||||
wantTags: map[string]string{
|
||||
"param": "a The first number\nb The second number",
|
||||
"returns": "The sum of a and b",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple line comment",
|
||||
code: `// This is a simple function
|
||||
function simple() {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "This is a simple function",
|
||||
wantStyle: CommentStyleLine,
|
||||
},
|
||||
{
|
||||
name: "JSDoc with types",
|
||||
code: `/**
|
||||
* @param {string} name - The name
|
||||
* @returns {boolean} True if valid
|
||||
*/
|
||||
function validate(name) {}
|
||||
`,
|
||||
nodeKind: "function_declaration",
|
||||
wantText: "",
|
||||
wantStyle: CommentStyleJSDoc,
|
||||
wantTags: map[string]string{
|
||||
"param": "{string} name - The name",
|
||||
"returns": "{boolean} True if valid",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.Parse(context.Background(), "test.js", []byte(tt.code))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||
if targetNode == nil {
|
||||
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||
}
|
||||
|
||||
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangJavaScript)
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||
}
|
||||
|
||||
if doc.Style != tt.wantStyle {
|
||||
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||
}
|
||||
|
||||
if tt.wantTags != nil {
|
||||
for k, want := range tt.wantTags {
|
||||
if got := doc.Tags[k]; got != want {
|
||||
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPythonDocComment(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
nodeKind string
|
||||
wantText string
|
||||
wantStyle CommentStyle
|
||||
}{
|
||||
{
|
||||
name: "docstring",
|
||||
code: `def greet(name):
|
||||
"""Greet a person by name."""
|
||||
print(f"Hello, {name}!")
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Greet a person by name.",
|
||||
wantStyle: CommentStyleDocstring,
|
||||
},
|
||||
{
|
||||
name: "multi-line docstring",
|
||||
code: `def calculate(x, y):
|
||||
"""
|
||||
Calculate the sum of two numbers.
|
||||
|
||||
Args:
|
||||
x: First number
|
||||
y: Second number
|
||||
|
||||
Returns:
|
||||
The sum of x and y
|
||||
"""
|
||||
return x + y
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Calculate the sum of two numbers.\n\n Args:\n x: First number\n y: Second number\n\n Returns:\n The sum of x and y",
|
||||
wantStyle: CommentStyleDocstring,
|
||||
},
|
||||
{
|
||||
name: "class docstring",
|
||||
code: `class MyClass:
|
||||
"""This is a class description."""
|
||||
pass
|
||||
`,
|
||||
nodeKind: "class_definition",
|
||||
wantText: "This is a class description.",
|
||||
wantStyle: CommentStyleDocstring,
|
||||
},
|
||||
{
|
||||
name: "single quote docstring",
|
||||
code: `def func():
|
||||
'''Single quote docstring'''
|
||||
pass
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Single quote docstring",
|
||||
wantStyle: CommentStyleDocstring,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.Parse(context.Background(), "test.py", []byte(tt.code))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||
if targetNode == nil {
|
||||
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||
}
|
||||
|
||||
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangPython)
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||
}
|
||||
|
||||
if doc.Style != tt.wantStyle {
|
||||
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCDocComment(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
wantTags map[string]string
|
||||
name string
|
||||
code string
|
||||
nodeKind string
|
||||
wantText string
|
||||
wantStyle CommentStyle
|
||||
}{
|
||||
{
|
||||
name: "Doxygen block comment",
|
||||
code: `/**
|
||||
* Adds two numbers.
|
||||
* @param a First number
|
||||
* @param b Second number
|
||||
* @return Sum of a and b
|
||||
*/
|
||||
int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Adds two numbers.",
|
||||
wantStyle: CommentStyleDoxygen,
|
||||
wantTags: map[string]string{
|
||||
"param": "a First number\nb Second number",
|
||||
"return": "Sum of a and b",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "regular block comment",
|
||||
code: `/* This is a regular comment */
|
||||
int regular() { return 0; }
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "This is a regular comment",
|
||||
wantStyle: CommentStyleBlock,
|
||||
},
|
||||
{
|
||||
name: "line comment",
|
||||
code: `// Simple function
|
||||
int simple() { return 1; }
|
||||
`,
|
||||
nodeKind: "function_definition",
|
||||
wantText: "Simple function",
|
||||
wantStyle: CommentStyleLine,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.Parse(context.Background(), "test.c", []byte(tt.code))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
|
||||
if targetNode == nil {
|
||||
t.Fatalf("could not find node of type %s", tt.nodeKind)
|
||||
}
|
||||
|
||||
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangC)
|
||||
|
||||
if doc == nil {
|
||||
t.Fatal("expected doc, got nil")
|
||||
}
|
||||
|
||||
if doc.Text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
|
||||
}
|
||||
|
||||
if doc.Style != tt.wantStyle {
|
||||
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
|
||||
}
|
||||
|
||||
if tt.wantTags != nil {
|
||||
for k, want := range tt.wantTags {
|
||||
if got := doc.Tags[k]; got != want {
|
||||
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSDoc(t *testing.T) {
|
||||
tests := []struct {
|
||||
wantTags map[string]string
|
||||
name string
|
||||
input string
|
||||
wantText string
|
||||
}{
|
||||
{
|
||||
name: "complete jsdoc",
|
||||
input: `/**
|
||||
* This is a description.
|
||||
* Multiple lines.
|
||||
* @param {string} name The name
|
||||
* @returns {boolean} Result
|
||||
*/`,
|
||||
wantText: "This is a description.\nMultiple lines.",
|
||||
wantTags: map[string]string{
|
||||
"param": "{string} name The name",
|
||||
"returns": "{boolean} Result",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty jsdoc",
|
||||
input: `/** */`,
|
||||
wantText: "",
|
||||
wantTags: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "only description",
|
||||
input: `/** Simple description */`,
|
||||
wantText: "Simple description",
|
||||
wantTags: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
text, tags := parseJSDoc(tt.input)
|
||||
|
||||
if text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
|
||||
}
|
||||
|
||||
if len(tags) != len(tt.wantTags) {
|
||||
t.Errorf("tag count mismatch: got %d, want %d", len(tags), len(tt.wantTags))
|
||||
}
|
||||
|
||||
for k, want := range tt.wantTags {
|
||||
if got := tags[k]; got != want {
|
||||
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDoxygen(t *testing.T) {
|
||||
tests := []struct {
|
||||
wantTags map[string]string
|
||||
name string
|
||||
input string
|
||||
wantText string
|
||||
}{
|
||||
{
|
||||
name: "doxygen with @ tags",
|
||||
input: `/**
|
||||
* Brief description.
|
||||
* @param x Value
|
||||
* @return Result
|
||||
*/`,
|
||||
wantText: "Brief description.",
|
||||
wantTags: map[string]string{
|
||||
"param": "x Value",
|
||||
"return": "Result",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "doxygen with backslash tags",
|
||||
input: `/**
|
||||
* Description.
|
||||
* \param y Input
|
||||
* \retval Output value
|
||||
*/`,
|
||||
wantText: "Description.",
|
||||
wantTags: map[string]string{
|
||||
"param": "y Input",
|
||||
"retval": "Output value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "triple slash",
|
||||
input: `/// Simple description`,
|
||||
wantText: "Simple description",
|
||||
wantTags: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
text, tags := parseDoxygen(tt.input)
|
||||
|
||||
if text != tt.wantText {
|
||||
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
|
||||
}
|
||||
|
||||
for k, want := range tt.wantTags {
|
||||
if got := tags[k]; got != want {
|
||||
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatDocComment(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
doc *DocComment
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "with tags",
|
||||
doc: &DocComment{
|
||||
Text: "This is a function.",
|
||||
Tags: map[string]string{
|
||||
"param": "x The value",
|
||||
"returns": "The result",
|
||||
},
|
||||
},
|
||||
want: "This is a function.\n\n@param x The value\n@returns The result",
|
||||
},
|
||||
{
|
||||
name: "no tags",
|
||||
doc: &DocComment{
|
||||
Text: "Simple description.",
|
||||
Tags: nil,
|
||||
},
|
||||
want: "Simple description.",
|
||||
},
|
||||
{
|
||||
name: "nil doc",
|
||||
doc: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty text",
|
||||
doc: &DocComment{
|
||||
Text: "",
|
||||
Tags: nil,
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FormatDocComment(tt.doc)
|
||||
if got != tt.want {
|
||||
t.Errorf("mismatch:\ngot: %q\nwant: %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectCommentStyle(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want CommentStyle
|
||||
}{
|
||||
{"/** JSDoc */", CommentStyleJSDoc},
|
||||
{"/// Doxygen", CommentStyleDoxygen},
|
||||
{"//! Doxygen", CommentStyleDoxygen},
|
||||
{"/* block */", CommentStyleBlock},
|
||||
{"// line", CommentStyleLine},
|
||||
{"# hash", CommentStyleHash},
|
||||
{`"""docstring"""`, CommentStyleDocstring},
|
||||
{`'''docstring'''`, CommentStyleDocstring},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := detectCommentStyle(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// findNodeByKind finds the first node of the given kind.
|
||||
func findNodeByKind(root *sitter.Node, nodeType string) *sitter.Node {
|
||||
if root == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result *sitter.Node
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
if n.Type() == nodeType {
|
||||
result = n
|
||||
return false // stop walking
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func TestCleanBlockComment(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
input: "\n * Line 1\n * Line 2\n ",
|
||||
want: "Line 1\nLine 2",
|
||||
},
|
||||
{
|
||||
input: "Simple",
|
||||
want: "Simple",
|
||||
},
|
||||
{
|
||||
input: "\n\nWith blank lines\n\n",
|
||||
want: "With blank lines",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input[:min(10, len(tt.input))], func(t *testing.T) {
|
||||
got := cleanBlockComment(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("got %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
// Package parser provides Tree-sitter based parsing for multiple languages.
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
"github.com/smacker/go-tree-sitter/c"
|
||||
"github.com/smacker/go-tree-sitter/cpp"
|
||||
"github.com/smacker/go-tree-sitter/golang"
|
||||
"github.com/smacker/go-tree-sitter/html"
|
||||
"github.com/smacker/go-tree-sitter/javascript"
|
||||
"github.com/smacker/go-tree-sitter/python"
|
||||
"github.com/smacker/go-tree-sitter/typescript/typescript"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
// MaxFileSize is the maximum file size we'll parse (10MB).
|
||||
const MaxFileSize = 10 * 1024 * 1024
|
||||
|
||||
// Registry manages Tree-sitter parsers for different languages.
|
||||
type Registry struct {
|
||||
parsers map[protocol.Language]*sitter.Parser
|
||||
cache *lru.Cache[string, *CachedTree]
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// CachedTree stores a parsed tree with its metadata.
|
||||
// Content is not stored to reduce memory usage.
|
||||
type CachedTree struct {
|
||||
Tree *sitter.Tree
|
||||
Language protocol.Language
|
||||
}
|
||||
|
||||
// ParseResult contains the result of parsing a file.
|
||||
type ParseResult struct {
|
||||
Tree *sitter.Tree
|
||||
Language protocol.Language
|
||||
Errors []SyntaxError
|
||||
Content []byte
|
||||
}
|
||||
|
||||
// SyntaxError represents a syntax error found during parsing.
|
||||
type SyntaxError struct {
|
||||
Message string
|
||||
NodeType string
|
||||
Location protocol.Location
|
||||
}
|
||||
|
||||
// NewRegistry creates a new parser registry.
|
||||
func NewRegistry() *Registry {
|
||||
// Create LRU cache with capacity of 100 trees
|
||||
cache, err := lru.New[string, *CachedTree](100)
|
||||
if err != nil {
|
||||
// LRU.New only errors if size <= 0, which won't happen here
|
||||
panic(fmt.Sprintf("failed to create LRU cache: %v", err))
|
||||
}
|
||||
|
||||
return &Registry{
|
||||
parsers: make(map[protocol.Language]*sitter.Parser),
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
// getLanguage returns the Tree-sitter language for a given language.
|
||||
func getLanguage(lang protocol.Language) (*sitter.Language, error) {
|
||||
switch lang {
|
||||
case protocol.LangGo:
|
||||
return golang.GetLanguage(), nil
|
||||
case protocol.LangTypeScript:
|
||||
return typescript.GetLanguage(), nil
|
||||
case protocol.LangJavaScript:
|
||||
return javascript.GetLanguage(), nil
|
||||
case protocol.LangPython:
|
||||
return python.GetLanguage(), nil
|
||||
case protocol.LangC:
|
||||
return c.GetLanguage(), nil
|
||||
case protocol.LangCpp:
|
||||
return cpp.GetLanguage(), nil
|
||||
case protocol.LangHTML:
|
||||
return html.GetLanguage(), nil
|
||||
case protocol.LangVue:
|
||||
// Vue SFC files use HTML-like template syntax, so we use the HTML parser
|
||||
return html.GetLanguage(), nil
|
||||
default:
|
||||
return nil, errors.New(errors.ErrInvalidLanguage, fmt.Sprintf("language %s is not supported", lang)).
|
||||
WithContext("language", string(lang)).
|
||||
WithRemediation("Supported languages: Go, TypeScript, JavaScript, Python, C, C++, HTML, Vue")
|
||||
}
|
||||
}
|
||||
|
||||
// GetParser returns a parser for the given language.
|
||||
func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
|
||||
r.mu.RLock()
|
||||
if p, ok := r.parsers[lang]; ok {
|
||||
r.mu.RUnlock()
|
||||
return p, nil
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
// Create new parser
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if p, ok := r.parsers[lang]; ok {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
sitterLang, err := getLanguage(lang)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parser := sitter.NewParser()
|
||||
parser.SetLanguage(sitterLang)
|
||||
r.parsers[lang] = parser
|
||||
|
||||
return parser, nil
|
||||
}
|
||||
|
||||
// Parse parses the given content for the specified language.
|
||||
func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||
// Check file size
|
||||
if len(content) > MaxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||
}
|
||||
|
||||
// Detect binary files
|
||||
if isBinary(content) {
|
||||
return nil, errors.New(errors.ErrParseFailed, "binary file detected").
|
||||
WithContext("file", filename).
|
||||
WithRemediation("This appears to be a binary file and cannot be parsed as source code")
|
||||
}
|
||||
|
||||
// Detect language
|
||||
lang := protocol.DetectLanguage(filename)
|
||||
if lang == protocol.LangUnknown {
|
||||
return nil, errors.New(errors.ErrInvalidLanguage, "could not detect language from filename").
|
||||
WithContext("file", filename).
|
||||
WithRemediation("Ensure file has a recognized extension (e.g., .go, .ts, .py, .c, .cpp, .html, .vue, .json, .yaml)")
|
||||
}
|
||||
|
||||
// Handle YAML and JSON separately (they don't use tree-sitter)
|
||||
switch lang {
|
||||
case protocol.LangYAML:
|
||||
return r.ParseYAML(ctx, filename, content)
|
||||
case protocol.LangJSON:
|
||||
return r.ParseJSON(ctx, filename, content)
|
||||
}
|
||||
|
||||
// Check cache (LRU cache is thread-safe)
|
||||
hash := contentHash(content)
|
||||
if cached, ok := r.cache.Get(hash); ok && cached.Language == lang {
|
||||
errors := extractErrors(cached.Tree.RootNode(), content)
|
||||
return &ParseResult{
|
||||
Tree: cached.Tree,
|
||||
Language: lang,
|
||||
Errors: errors,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get parser
|
||||
parser, err := r.GetParser(lang)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse content - tree-sitter parsers are not thread-safe,
|
||||
// so we need to hold the lock during parsing
|
||||
r.mu.Lock()
|
||||
tree, err := parser.ParseCtx(ctx, nil, content)
|
||||
r.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.NewParseError(string(lang), filename, err)
|
||||
}
|
||||
|
||||
// Extract syntax errors
|
||||
errors := extractErrors(tree.RootNode(), content)
|
||||
|
||||
// Cache result (LRU cache handles eviction automatically)
|
||||
r.cache.Add(hash, &CachedTree{
|
||||
Tree: tree,
|
||||
Language: lang,
|
||||
})
|
||||
|
||||
return &ParseResult{
|
||||
Tree: tree,
|
||||
Language: lang,
|
||||
Errors: errors,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractErrors finds all error nodes in the tree.
|
||||
func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
|
||||
var errors []SyntaxError
|
||||
|
||||
var walk func(n *sitter.Node)
|
||||
walk = func(n *sitter.Node) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if n.IsError() || n.IsMissing() {
|
||||
startPoint := n.StartPoint()
|
||||
nodeType := "ERROR"
|
||||
if n.IsMissing() {
|
||||
nodeType = "MISSING"
|
||||
}
|
||||
|
||||
errors = append(errors, SyntaxError{
|
||||
Location: protocol.Location{
|
||||
Line: int(startPoint.Row) + 1,
|
||||
Column: int(startPoint.Column) + 1,
|
||||
},
|
||||
Message: fmt.Sprintf("syntax error: unexpected %s", n.Type()),
|
||||
NodeType: nodeType,
|
||||
})
|
||||
}
|
||||
|
||||
for i := 0; i < int(n.ChildCount()); i++ {
|
||||
walk(n.Child(i))
|
||||
}
|
||||
}
|
||||
|
||||
walk(node)
|
||||
return errors
|
||||
}
|
||||
|
||||
// contentHash returns a fast hash of the content for caching.
|
||||
// Uses xxHash which is 5-10x faster than SHA256 for non-cryptographic purposes.
|
||||
func contentHash(content []byte) string {
|
||||
h := xxhash.Sum64(content)
|
||||
return fmt.Sprintf("%016x", h)
|
||||
}
|
||||
|
||||
// isBinary checks if content appears to be binary.
|
||||
func isBinary(content []byte) bool {
|
||||
// Check first 8000 bytes for null bytes
|
||||
checkLen := min(8000, len(content))
|
||||
|
||||
for i := range checkLen {
|
||||
if content[i] == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Close closes all parsers and clears the cache.
|
||||
func (r *Registry) Close() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for _, p := range r.parsers {
|
||||
p.Close()
|
||||
}
|
||||
r.parsers = make(map[protocol.Language]*sitter.Parser)
|
||||
|
||||
// Purge LRU cache
|
||||
r.cache.Purge()
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
if r == nil {
|
||||
t.Fatal("expected non-nil registry")
|
||||
}
|
||||
defer r.Close()
|
||||
}
|
||||
|
||||
func TestGetParser(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
tests := []struct {
|
||||
lang protocol.Language
|
||||
wantErr bool
|
||||
}{
|
||||
{protocol.LangGo, false},
|
||||
{protocol.LangTypeScript, false},
|
||||
{protocol.LangJavaScript, false},
|
||||
{protocol.LangPython, false},
|
||||
{protocol.LangC, false},
|
||||
{protocol.LangCpp, false},
|
||||
{protocol.LangHTML, false},
|
||||
{protocol.LangVue, false},
|
||||
{protocol.LangUnknown, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.lang), func(t *testing.T) {
|
||||
parser, err := r.GetParser(tt.lang)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if parser == nil {
|
||||
t.Error("expected non-nil parser")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
content string
|
||||
wantLang protocol.Language
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "go file",
|
||||
filename: "test.go",
|
||||
content: "package main\n\nfunc main() {}\n",
|
||||
wantLang: protocol.LangGo,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "typescript file",
|
||||
filename: "test.ts",
|
||||
content: "function hello(): void {}\n",
|
||||
wantLang: protocol.LangTypeScript,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "react tsx file",
|
||||
filename: "Component.tsx",
|
||||
content: `import React from 'react';\n\nexport const Button: React.FC = () => <button className="btn">Click</button>;`,
|
||||
wantLang: protocol.LangTypeScript,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "react jsx file",
|
||||
filename: "Component.jsx",
|
||||
content: `import React from 'react';\n\nexport const Button = () => <button className="btn">Click</button>;`,
|
||||
wantLang: protocol.LangJavaScript,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "python file",
|
||||
filename: "test.py",
|
||||
content: "def hello():\n pass\n",
|
||||
wantLang: protocol.LangPython,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "html file",
|
||||
filename: "test.html",
|
||||
content: `<!DOCTYPE html><html><head><title>Test</title></head><body><h1 class="text-xl">Hello</h1></body></html>`,
|
||||
wantLang: protocol.LangHTML,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "vue file",
|
||||
filename: "Component.vue",
|
||||
content: `<template><div class="container"><h1>{{ title }}</h1></div></template>`,
|
||||
wantLang: protocol.LangVue,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unknown file",
|
||||
filename: "test.txt",
|
||||
content: "hello world",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, tt.filename, []byte(tt.content))
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Language != tt.wantLang {
|
||||
t.Errorf("expected language %s, got %s", tt.wantLang, result.Language)
|
||||
}
|
||||
|
||||
if result.Tree == nil {
|
||||
t.Error("expected non-nil tree")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWithSyntaxErrors(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
// Invalid Go code
|
||||
content := "package main\n\nfunc main( {}\n" // Missing closing paren
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := r.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should have parsed (tree-sitter is error-tolerant)
|
||||
if result.Tree == nil {
|
||||
t.Error("expected non-nil tree")
|
||||
}
|
||||
|
||||
// Should have detected errors
|
||||
if len(result.Errors) == 0 {
|
||||
t.Error("expected syntax errors to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBinary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content []byte
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "text file",
|
||||
content: []byte("hello world"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "binary with null byte",
|
||||
content: []byte{0x68, 0x65, 0x6c, 0x00, 0x6f},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
content: []byte{},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isBinary(tt.content); got != tt.want {
|
||||
t.Errorf("isBinary() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaching(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := []byte("package main\n\nfunc main() {}\n")
|
||||
ctx := context.Background()
|
||||
|
||||
// Parse once
|
||||
result1, err := r.Parse(ctx, "test.go", content)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
// Parse again with same content
|
||||
result2, err := r.Parse(ctx, "test.go", content)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
// Should return cached tree (same pointer)
|
||||
if result1.Tree != result2.Tree {
|
||||
t.Error("expected cached tree to be returned")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,474 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// ExtractSymbols extracts symbols from a parsed tree.
|
||||
func ExtractSymbols(tree *sitter.Tree, content []byte, lang protocol.Language, filename string) []protocol.Symbol {
|
||||
if tree == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
root := tree.RootNode()
|
||||
if root == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch lang {
|
||||
case protocol.LangGo:
|
||||
return extractGoSymbols(root, content, filename)
|
||||
case protocol.LangTypeScript, protocol.LangJavaScript:
|
||||
return extractJSSymbols(root, content, filename)
|
||||
case protocol.LangPython:
|
||||
return extractPythonSymbols(root, content, filename)
|
||||
case protocol.LangC, protocol.LangCpp:
|
||||
return extractCSymbols(root, content, filename)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// extractGoSymbols extracts symbols from Go code.
|
||||
func extractGoSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
var symbol *protocol.Symbol
|
||||
|
||||
switch n.Type() {
|
||||
case "function_declaration":
|
||||
symbol = extractGoFunction(n, content, filename)
|
||||
case "method_declaration":
|
||||
symbol = extractGoMethod(n, content, filename)
|
||||
case "type_declaration":
|
||||
symbol = extractGoType(n, content, filename)
|
||||
case "const_declaration", "var_declaration":
|
||||
syms := extractGoVarConst(n, content, filename)
|
||||
symbols = append(symbols, syms...)
|
||||
return true
|
||||
}
|
||||
|
||||
if symbol != nil {
|
||||
if doc := ExtractDocComment(n, content, protocol.LangGo); doc != nil {
|
||||
symbol.Doc = FormatDocComment(doc)
|
||||
}
|
||||
symbols = append(symbols, *symbol)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractGoFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolFunction,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractGoMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get receiver type
|
||||
receiver := n.ChildByFieldName("receiver")
|
||||
receiverType := ""
|
||||
if receiver != nil {
|
||||
// Find the type in the receiver
|
||||
WalkTree(receiver, func(node *sitter.Node) bool {
|
||||
if node.Type() == "type_identifier" {
|
||||
receiverType = GetNodeText(node, content)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
name := GetNodeText(nameNode, content)
|
||||
if receiverType != "" {
|
||||
name = "(" + receiverType + ")." + name
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: name,
|
||||
Kind: protocol.SymbolMethod,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractGoType(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
// Find type_spec child
|
||||
for i := 0; i < int(n.NamedChildCount()); i++ {
|
||||
child := n.NamedChild(i)
|
||||
if child != nil && child.Type() == "type_spec" {
|
||||
nameNode := child.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
kind := protocol.SymbolType
|
||||
typeNode := child.ChildByFieldName("type")
|
||||
if typeNode != nil {
|
||||
switch typeNode.Type() {
|
||||
case "struct_type":
|
||||
kind = protocol.SymbolStruct
|
||||
case "interface_type":
|
||||
kind = protocol.SymbolInterface
|
||||
}
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: kind,
|
||||
Location: NodeLocation(child, filename),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractGoVarConst(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
kind := protocol.SymbolVariable
|
||||
if n.Type() == "const_declaration" {
|
||||
kind = protocol.SymbolConstant
|
||||
}
|
||||
|
||||
WalkTree(n, func(node *sitter.Node) bool {
|
||||
if node.Type() == "const_spec" || node.Type() == "var_spec" {
|
||||
nameNode := node.ChildByFieldName("name")
|
||||
if nameNode != nil {
|
||||
symbols = append(symbols, protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: kind,
|
||||
Location: NodeLocation(node, filename),
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
// extractJSSymbols extracts symbols from JavaScript/TypeScript code.
|
||||
func extractJSSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
var symbol *protocol.Symbol
|
||||
|
||||
switch n.Type() {
|
||||
case "function_declaration":
|
||||
symbol = extractJSFunction(n, content, filename)
|
||||
case "class_declaration":
|
||||
symbol = extractJSClass(n, content, filename)
|
||||
case "method_definition":
|
||||
symbol = extractJSMethod(n, content, filename)
|
||||
case "lexical_declaration", "variable_declaration":
|
||||
syms := extractJSVariable(n, content, filename)
|
||||
symbols = append(symbols, syms...)
|
||||
return true
|
||||
case "interface_declaration":
|
||||
symbol = extractTSInterface(n, content, filename)
|
||||
case "type_alias_declaration":
|
||||
symbol = extractTSTypeAlias(n, content, filename)
|
||||
}
|
||||
|
||||
if symbol != nil {
|
||||
if doc := ExtractDocComment(n, content, protocol.LangJavaScript); doc != nil {
|
||||
symbol.Doc = FormatDocComment(doc)
|
||||
}
|
||||
symbols = append(symbols, *symbol)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractJSFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolFunction,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractJSClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolClass,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractJSMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolMethod,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractJSVariable(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(n, func(node *sitter.Node) bool {
|
||||
if node.Type() == "variable_declarator" {
|
||||
nameNode := node.ChildByFieldName("name")
|
||||
if nameNode != nil {
|
||||
symbols = append(symbols, protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolVariable,
|
||||
Location: NodeLocation(node, filename),
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractTSInterface(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolInterface,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractTSTypeAlias(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolType,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
// extractPythonSymbols extracts symbols from Python code.
|
||||
func extractPythonSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
var symbol *protocol.Symbol
|
||||
|
||||
switch n.Type() {
|
||||
case "function_definition":
|
||||
symbol = extractPythonFunction(n, content, filename)
|
||||
case "class_definition":
|
||||
symbol = extractPythonClass(n, content, filename)
|
||||
}
|
||||
|
||||
if symbol != nil {
|
||||
if doc := ExtractDocComment(n, content, protocol.LangPython); doc != nil {
|
||||
symbol.Doc = FormatDocComment(doc)
|
||||
}
|
||||
symbols = append(symbols, *symbol)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractPythonFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if this is a method (inside a class)
|
||||
parent := n.Parent()
|
||||
kind := protocol.SymbolFunction
|
||||
if parent != nil && parent.Type() == "block" {
|
||||
grandparent := parent.Parent()
|
||||
if grandparent != nil && grandparent.Type() == "class_definition" {
|
||||
kind = protocol.SymbolMethod
|
||||
}
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: kind,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractPythonClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolClass,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
// extractCSymbols extracts symbols from C/C++ code.
|
||||
func extractCSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
|
||||
var symbols []protocol.Symbol
|
||||
|
||||
WalkTree(root, func(n *sitter.Node) bool {
|
||||
var symbol *protocol.Symbol
|
||||
|
||||
switch n.Type() {
|
||||
case "function_definition":
|
||||
symbol = extractCFunction(n, content, filename)
|
||||
case "struct_specifier":
|
||||
symbol = extractCStruct(n, content, filename)
|
||||
case "class_specifier":
|
||||
symbol = extractCppClass(n, content, filename)
|
||||
case "declaration":
|
||||
// Could be function declaration or variable
|
||||
if hasFunctionDeclarator(n) {
|
||||
symbol = extractCFunctionDecl(n, content, filename)
|
||||
}
|
||||
}
|
||||
|
||||
if symbol != nil {
|
||||
if doc := ExtractDocComment(n, content, protocol.LangC); doc != nil {
|
||||
symbol.Doc = FormatDocComment(doc)
|
||||
}
|
||||
symbols = append(symbols, *symbol)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return symbols
|
||||
}
|
||||
|
||||
func extractCFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
declarator := n.ChildByFieldName("declarator")
|
||||
if declarator == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find the function name within the declarator
|
||||
var name string
|
||||
WalkTree(declarator, func(node *sitter.Node) bool {
|
||||
if node.Type() == "identifier" {
|
||||
name = GetNodeText(node, content)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: name,
|
||||
Kind: protocol.SymbolFunction,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractCStruct(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolStruct,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractCppClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
nameNode := n.ChildByFieldName("name")
|
||||
if nameNode == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: GetNodeText(nameNode, content),
|
||||
Kind: protocol.SymbolClass,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func extractCFunctionDecl(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
|
||||
declarator := n.ChildByFieldName("declarator")
|
||||
if declarator == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var name string
|
||||
WalkTree(declarator, func(node *sitter.Node) bool {
|
||||
if node.Type() == "identifier" {
|
||||
name = GetNodeText(node, content)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &protocol.Symbol{
|
||||
Name: name,
|
||||
Kind: protocol.SymbolFunction,
|
||||
Location: NodeLocation(n, filename),
|
||||
}
|
||||
}
|
||||
|
||||
func hasFunctionDeclarator(n *sitter.Node) bool {
|
||||
found := false
|
||||
WalkTree(n, func(node *sitter.Node) bool {
|
||||
if node.Type() == "function_declarator" {
|
||||
found = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestExtractGoSymbols(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := `package main
|
||||
|
||||
// Hello prints a greeting
|
||||
func Hello() {
|
||||
println("hello")
|
||||
}
|
||||
|
||||
// Server handles requests
|
||||
type Server struct {
|
||||
Port int
|
||||
}
|
||||
|
||||
// Start starts the server
|
||||
func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const MaxConnections = 100
|
||||
var globalVar = "test"
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, "test.go", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangGo, "test.go")
|
||||
|
||||
expectedSymbols := map[string]protocol.SymbolKind{
|
||||
"Hello": protocol.SymbolFunction,
|
||||
"Server": protocol.SymbolStruct,
|
||||
"(Server).Start": protocol.SymbolMethod,
|
||||
"MaxConnections": protocol.SymbolConstant,
|
||||
"globalVar": protocol.SymbolVariable,
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
for _, sym := range symbols {
|
||||
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||
found[sym.Name] = true
|
||||
if sym.Kind != expectedKind {
|
||||
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name := range expectedSymbols {
|
||||
if !found[name] {
|
||||
t.Errorf("expected to find symbol %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractJSSymbols(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := `
|
||||
function greet(name) {
|
||||
console.log("Hello, " + name);
|
||||
}
|
||||
|
||||
class User {
|
||||
constructor(name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
getName() {
|
||||
return this.name;
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_USERS = 100;
|
||||
let currentUser = null;
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, "test.js", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangJavaScript, "test.js")
|
||||
|
||||
expectedSymbols := map[string]protocol.SymbolKind{
|
||||
"greet": protocol.SymbolFunction,
|
||||
"User": protocol.SymbolClass,
|
||||
"MAX_USERS": protocol.SymbolVariable,
|
||||
"currentUser": protocol.SymbolVariable,
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
for _, sym := range symbols {
|
||||
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||
found[sym.Name] = true
|
||||
if sym.Kind != expectedKind {
|
||||
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name := range expectedSymbols {
|
||||
if !found[name] {
|
||||
t.Errorf("expected to find symbol %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPythonSymbols(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := `
|
||||
def greet(name):
|
||||
"""Greet a person by name."""
|
||||
print(f"Hello, {name}")
|
||||
|
||||
class User:
|
||||
"""Represents a user."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, "test.py", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangPython, "test.py")
|
||||
|
||||
expectedSymbols := map[string]protocol.SymbolKind{
|
||||
"greet": protocol.SymbolFunction,
|
||||
"User": protocol.SymbolClass,
|
||||
"__init__": protocol.SymbolMethod,
|
||||
"get_name": protocol.SymbolMethod,
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
for _, sym := range symbols {
|
||||
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||
found[sym.Name] = true
|
||||
if sym.Kind != expectedKind {
|
||||
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name := range expectedSymbols {
|
||||
if !found[name] {
|
||||
t.Errorf("expected to find symbol %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCSymbols(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer r.Close()
|
||||
|
||||
content := `
|
||||
#include <stdio.h>
|
||||
|
||||
struct Point {
|
||||
int x;
|
||||
int y;
|
||||
};
|
||||
|
||||
void print_point(struct Point p) {
|
||||
printf("(%d, %d)\n", p.x, p.y);
|
||||
}
|
||||
|
||||
int main() {
|
||||
struct Point p = {1, 2};
|
||||
print_point(p);
|
||||
return 0;
|
||||
}
|
||||
`
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := r.Parse(ctx, "test.c", []byte(content))
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangC, "test.c")
|
||||
|
||||
// Note: C symbol extraction is complex, checking for at least main and Point
|
||||
expectedSymbols := map[string]protocol.SymbolKind{
|
||||
"Point": protocol.SymbolStruct,
|
||||
"main": protocol.SymbolFunction,
|
||||
}
|
||||
|
||||
found := make(map[string]bool)
|
||||
for _, sym := range symbols {
|
||||
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
|
||||
found[sym.Name] = true
|
||||
if sym.Kind != expectedKind {
|
||||
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name := range expectedSymbols {
|
||||
if !found[name] {
|
||||
t.Errorf("expected to find symbol %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
// Package parser provides YAML and JSON parsing with AST support.
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
sitter "github.com/smacker/go-tree-sitter"
|
||||
)
|
||||
|
||||
// YAMLNode wraps yaml.Node to provide tree-sitter-like interface
|
||||
type YAMLNode struct {
|
||||
*yaml.Node
|
||||
Content []byte
|
||||
}
|
||||
|
||||
// JSONNode represents a JSON AST node
|
||||
type JSONNode struct {
|
||||
Value any
|
||||
Type string
|
||||
Children []*JSONNode
|
||||
Line int
|
||||
Column int
|
||||
}
|
||||
|
||||
// ParseYAML parses YAML content and returns a tree-sitter-compatible result
|
||||
func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||
// Check file size
|
||||
if len(content) > MaxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||
}
|
||||
|
||||
// Parse YAML
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(content, &root); err != nil {
|
||||
return nil, errors.NewParseError("yaml", filename, err)
|
||||
}
|
||||
|
||||
// Extract syntax errors from YAML parse
|
||||
syntaxErrors := extractYAMLErrors()
|
||||
|
||||
// Create a pseudo tree-sitter tree for compatibility
|
||||
// We'll use nil for the tree since YAML doesn't use tree-sitter
|
||||
return &ParseResult{
|
||||
Tree: nil, // YAML uses yaml.Node instead
|
||||
Language: protocol.LangYAML,
|
||||
Errors: syntaxErrors,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseJSON parses JSON content and returns a tree-sitter-compatible result
|
||||
func (r *Registry) ParseJSON(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
|
||||
// Check file size
|
||||
if len(content) > MaxFileSize {
|
||||
return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize)
|
||||
}
|
||||
|
||||
// Parse JSON to validate syntax
|
||||
var jsonData any
|
||||
if err := json.Unmarshal(content, &jsonData); err != nil {
|
||||
return nil, errors.NewParseError("json", filename, err)
|
||||
}
|
||||
|
||||
// JSON parsing succeeded, no syntax errors
|
||||
return &ParseResult{
|
||||
Tree: nil, // JSON uses native Go structures
|
||||
Language: protocol.LangJSON,
|
||||
Errors: []SyntaxError{},
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractYAMLErrors extracts errors from YAML nodes
|
||||
func extractYAMLErrors() []SyntaxError {
|
||||
// YAML parser already validates during unmarshal
|
||||
// If we got here, there are no syntax errors
|
||||
// However, we could add semantic validation here in the future
|
||||
return []SyntaxError{}
|
||||
}
|
||||
|
||||
// WalkYAML walks a YAML AST and calls fn for each node
|
||||
func WalkYAML(node *yaml.Node, fn func(*yaml.Node) bool) {
|
||||
if node == nil || !fn(node) {
|
||||
return
|
||||
}
|
||||
|
||||
for _, child := range node.Content {
|
||||
WalkYAML(child, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// GetYAMLNodeText returns the text representation of a YAML node
|
||||
func GetYAMLNodeText(node *yaml.Node) string {
|
||||
if node == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch node.Kind {
|
||||
case yaml.DocumentNode:
|
||||
if len(node.Content) > 0 {
|
||||
return GetYAMLNodeText(node.Content[0])
|
||||
}
|
||||
return ""
|
||||
case yaml.MappingNode:
|
||||
return node.Value
|
||||
case yaml.SequenceNode:
|
||||
return node.Value
|
||||
case yaml.ScalarNode:
|
||||
return node.Value
|
||||
case yaml.AliasNode:
|
||||
return node.Value
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// GetYAMLNodeLocation returns the location of a YAML node
|
||||
func GetYAMLNodeLocation(node *yaml.Node) protocol.Location {
|
||||
if node == nil {
|
||||
return protocol.Location{Line: 1, Column: 1}
|
||||
}
|
||||
|
||||
return protocol.Location{
|
||||
Line: node.Line,
|
||||
Column: node.Column,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryYAML performs a simple query on YAML content
|
||||
// Example: "$.metadata.name" to find the name field in metadata
|
||||
func QueryYAML(content []byte, query string) ([]*yaml.Node, error) {
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(content, &root); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML: %w", err)
|
||||
}
|
||||
|
||||
// Simple path-based query implementation
|
||||
// This is a basic implementation - can be extended with more sophisticated queries
|
||||
var results []*yaml.Node
|
||||
|
||||
WalkYAML(&root, func(node *yaml.Node) bool {
|
||||
if node.Value == query || node.Tag == query {
|
||||
results = append(results, node)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// QueryJSON performs a simple query on JSON content
|
||||
func QueryJSON(content []byte, query string) ([]any, error) {
|
||||
var data any
|
||||
if err := json.Unmarshal(content, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON: %w", err)
|
||||
}
|
||||
|
||||
// Basic implementation - can be extended with JSONPath support
|
||||
var results []any
|
||||
|
||||
// For now, just validate that it's valid JSON
|
||||
results = append(results, data)
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// ValidateYAML validates YAML content without parsing to full AST
|
||||
func ValidateYAML(content []byte) error {
|
||||
var node yaml.Node
|
||||
if err := yaml.Unmarshal(content, &node); err != nil {
|
||||
return fmt.Errorf("YAML validation failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateJSON validates JSON content
|
||||
func ValidateJSON(content []byte) error {
|
||||
var data any
|
||||
if err := json.Unmarshal(content, &data); err != nil {
|
||||
return fmt.Errorf("JSON validation failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToSitterTree is a placeholder that returns nil for YAML/JSON
|
||||
// These formats don't use tree-sitter, but we keep this for interface compatibility
|
||||
func (yn *YAMLNode) ToSitterTree() *sitter.Tree {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestParseYAML(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid simple YAML",
|
||||
content: `name: test
|
||||
version: 1.0.0
|
||||
enabled: true`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid nested YAML",
|
||||
content: `metadata:
|
||||
name: test-app
|
||||
namespace: default
|
||||
spec:
|
||||
replicas: 3
|
||||
selector:
|
||||
matchLabels:
|
||||
app: test`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid list YAML",
|
||||
content: `items:
|
||||
- name: item1
|
||||
value: 100
|
||||
- name: item2
|
||||
value: 200`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid YAML - bad syntax",
|
||||
content: `name: test\n bad: indent\n wrong: [unclosed`,
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.ParseYAML(context.Background(), "test.yaml", []byte(tt.content))
|
||||
|
||||
if tt.shouldError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Error("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if result.Language != protocol.LangYAML {
|
||||
t.Errorf("expected language YAML, got %s", result.Language)
|
||||
}
|
||||
|
||||
if len(result.Errors) > 0 {
|
||||
t.Errorf("expected no syntax errors, got %d", len(result.Errors))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSON(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid simple JSON",
|
||||
content: `{"name": "test", "version": "1.0.0", "enabled": true}`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid nested JSON",
|
||||
content: `{
|
||||
"metadata": {
|
||||
"name": "test-app",
|
||||
"namespace": "default"
|
||||
},
|
||||
"spec": {
|
||||
"replicas": 3,
|
||||
"selector": {
|
||||
"matchLabels": {
|
||||
"app": "test"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "valid array JSON",
|
||||
content: `[{"name": "item1", "value": 100}, {"name": "item2", "value": 200}]`,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON - unclosed brace",
|
||||
content: `{"name": "test", "value": 100`,
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON - trailing comma",
|
||||
content: `{"name": "test", "value": 100,}`,
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := registry.ParseJSON(context.Background(), "test.json", []byte(tt.content))
|
||||
|
||||
if tt.shouldError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Error("expected result but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if result.Language != protocol.LangJSON {
|
||||
t.Errorf("expected language JSON, got %s", result.Language)
|
||||
}
|
||||
|
||||
if len(result.Errors) > 0 {
|
||||
t.Errorf("expected no syntax errors, got %d", len(result.Errors))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryParse_YAML_JSON(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
defer registry.Close()
|
||||
|
||||
yamlContent := []byte(`name: test
|
||||
version: 1.0.0`)
|
||||
|
||||
jsonContent := []byte(`{"name": "test", "version": "1.0.0"}`)
|
||||
|
||||
// Test YAML through main Parse method
|
||||
yamlResult, err := registry.Parse(context.Background(), "config.yaml", yamlContent)
|
||||
if err != nil {
|
||||
t.Errorf("failed to parse YAML: %v", err)
|
||||
}
|
||||
if yamlResult.Language != protocol.LangYAML {
|
||||
t.Errorf("expected YAML language, got %s", yamlResult.Language)
|
||||
}
|
||||
|
||||
// Test JSON through main Parse method
|
||||
jsonResult, err := registry.Parse(context.Background(), "config.json", jsonContent)
|
||||
if err != nil {
|
||||
t.Errorf("failed to parse JSON: %v", err)
|
||||
}
|
||||
if jsonResult.Language != protocol.LangJSON {
|
||||
t.Errorf("expected JSON language, got %s", jsonResult.Language)
|
||||
}
|
||||
|
||||
// Test .yml extension
|
||||
ymlResult, err := registry.Parse(context.Background(), "config.yml", yamlContent)
|
||||
if err != nil {
|
||||
t.Errorf("failed to parse .yml: %v", err)
|
||||
}
|
||||
if ymlResult.Language != protocol.LangYAML {
|
||||
t.Errorf("expected YAML language for .yml extension, got %s", ymlResult.Language)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalkYAML(t *testing.T) {
|
||||
content := []byte(`metadata:
|
||||
name: test
|
||||
labels:
|
||||
app: myapp
|
||||
env: prod`)
|
||||
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(content, &root); err != nil {
|
||||
t.Fatalf("failed to parse YAML: %v", err)
|
||||
}
|
||||
|
||||
nodeCount := 0
|
||||
WalkYAML(&root, func(node *yaml.Node) bool {
|
||||
nodeCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if nodeCount == 0 {
|
||||
t.Error("expected to visit nodes, but count is 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateYAML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content []byte
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid YAML",
|
||||
content: []byte("name: test\nvalue: 100"),
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid YAML",
|
||||
content: []byte("name: test\n bad:\n[unclosed"),
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateYAML(tt.content)
|
||||
if (err != nil) != tt.shouldError {
|
||||
t.Errorf("ValidateYAML() error = %v, shouldError = %v", err, tt.shouldError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content []byte
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid JSON",
|
||||
content: []byte(`{"name": "test", "value": 100}`),
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
content: []byte(`{"name": "test", "value": 100`),
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateJSON(tt.content)
|
||||
if (err != nil) != tt.shouldError {
|
||||
t.Errorf("ValidateJSON() error = %v, shouldError = %v", err, tt.shouldError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user