This commit is contained in:
2026-01-18 18:40:26 +00:00
commit 185e73da47
51 changed files with 14073 additions and 0 deletions
+190
View File
@@ -0,0 +1,190 @@
package parser
import (
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// FindNodeAtPosition finds the most specific node at the given position.
//
// line and col are 1-indexed. The function returns nil when tree is nil, when
// the tree has no root node, when line or col is smaller than 1, or when no
// node contains the position.
func FindNodeAtPosition(tree *sitter.Tree, line, col int) *sitter.Node {
	if tree == nil {
		return nil
	}
	// Positions are 1-indexed. A smaller value would wrap around in the
	// uint32 conversion below and address a bogus (huge) row or column, which
	// can still fall inside a multi-line node and yield a wrong result.
	if line < 1 || col < 1 {
		return nil
	}

	root := tree.RootNode()
	if root == nil {
		return nil
	}

	// Convert to the 0-indexed coordinates used by Tree-sitter.
	point := sitter.Point{
		Row:    uint32(line - 1), // #nosec G115 - line numbers are bounded by file size
		Column: uint32(col - 1),  // #nosec G115 - column numbers are bounded by line length
	}
	return findNodeAtPoint(root, point)
}
// findNodeAtPoint descends from node and returns the deepest node whose span
// contains point. It returns nil when node is nil or does not contain point.
func findNodeAtPoint(node *sitter.Node, point sitter.Point) *sitter.Node {
	if node == nil || !pointInRange(point, node.StartPoint(), node.EndPoint()) {
		return nil
	}

	// Children are probed in order; the first child that contains the point
	// decides the result. A nil child simply yields nil and is skipped.
	childCount := int(node.ChildCount())
	for idx := 0; idx < childCount; idx++ {
		if hit := findNodeAtPoint(node.Child(idx), point); hit != nil {
			return hit
		}
	}

	// None of the children contains the point, so this node is the best match.
	return node
}
// pointInRange reports whether point lies in the half-open span [start, end):
// the start position is included, the end position is not.
func pointInRange(point, start, end sitter.Point) bool {
	beforeStart := point.Row < start.Row ||
		(point.Row == start.Row && point.Column < start.Column)
	atOrPastEnd := point.Row > end.Row ||
		(point.Row == end.Row && point.Column >= end.Column)
	return !beforeStart && !atOrPastEnd
}
// FindParentOfKind returns the closest ancestor of node whose type equals
// kind. The node itself is never considered. It returns nil when node is nil
// or no ancestor matches.
func FindParentOfKind(node *sitter.Node, kind string) *sitter.Node {
	if node == nil {
		return nil
	}
	for ancestor := node.Parent(); ancestor != nil; ancestor = ancestor.Parent() {
		if ancestor.Type() == kind {
			return ancestor
		}
	}
	return nil
}
// GetNodeText returns the slice of content covered by node as a string.
// It returns "" when node is nil or when the node's byte range does not fit
// inside content.
func GetNodeText(node *sitter.Node, content []byte) string {
	if node == nil {
		return ""
	}
	from, to := node.StartByte(), node.EndByte()
	size := len(content)
	if int(from) >= size || int(to) > size {
		return ""
	}
	return string(content[from:to])
}
// WalkTree visits node and its descendants in pre-order, calling fn for each
// node.
//
// When fn returns false the children of that node are not visited. The
// traversal of nodes already scheduled by the callers up the recursion (the
// remaining siblings of that node and of its ancestors) still continues, so
// returning false prunes a subtree rather than aborting the whole walk.
// A nil node is ignored.
func WalkTree(node *sitter.Node, fn func(*sitter.Node) bool) {
	if node == nil || !fn(node) {
		return
	}
	for idx := 0; idx < int(node.ChildCount()); idx++ {
		WalkTree(node.Child(idx), fn)
	}
}
// FindNodesByKind returns every node under root (root included) whose type
// equals kind, in the order WalkTree visits them. The result is nil when
// nothing matches.
func FindNodesByKind(root *sitter.Node, kind string) []*sitter.Node {
	var matches []*sitter.Node
	collect := func(candidate *sitter.Node) bool {
		if candidate.Type() == kind {
			matches = append(matches, candidate)
		}
		// Always descend: matches may be nested inside other matches.
		return true
	}
	WalkTree(root, collect)
	return matches
}
// FindNamedChildren collects the named (non-anonymous) direct children of
// node, skipping any nil entries. It returns nil for a nil node or when there
// are no named children.
func FindNamedChildren(node *sitter.Node) []*sitter.Node {
	if node == nil {
		return nil
	}
	var named []*sitter.Node
	for idx := 0; idx < int(node.NamedChildCount()); idx++ {
		child := node.NamedChild(idx)
		if child == nil {
			continue
		}
		named = append(named, child)
	}
	return named
}
// GetChildByFieldName is a nil-safe wrapper around node.ChildByFieldName: it
// returns nil for a nil node and the lookup result otherwise.
func GetChildByFieldName(node *sitter.Node, fieldName string) *sitter.Node {
	if node != nil {
		return node.ChildByFieldName(fieldName)
	}
	return nil
}
// NodeLocation returns the 1-indexed start position of node in filename.
// A nil node yields the zero Location.
func NodeLocation(node *sitter.Node, filename string) protocol.Location {
	var loc protocol.Location
	if node == nil {
		return loc
	}
	start := node.StartPoint()
	loc.File = filename
	// Tree-sitter points are 0-indexed; locations are reported 1-indexed.
	loc.Line = int(start.Row) + 1
	loc.Column = int(start.Column) + 1
	return loc
}
// NodeRange returns the 1-indexed start and end positions of node in
// filename. A nil node yields the zero Range.
func NodeRange(node *sitter.Node, filename string) protocol.Range {
	if node == nil {
		return protocol.Range{}
	}

	// toLocation converts a 0-indexed Tree-sitter point to a 1-indexed location.
	toLocation := func(p sitter.Point) protocol.Location {
		return protocol.Location{
			File:   filename,
			Line:   int(p.Row) + 1,
			Column: int(p.Column) + 1,
		}
	}

	return protocol.Range{
		Start: toLocation(node.StartPoint()),
		End:   toLocation(node.EndPoint()),
	}
}
+140
View File
@@ -0,0 +1,140 @@
package parser
import (
"context"
"fmt"
"testing"
)
// TestLRUCacheEviction parses more distinct sources than the cache is
// expected to hold and checks that the cache does not grow past its limit.
//
// Every iteration uses the same filename with different content.
// NOTE(review): the limit of 100 entries is taken from the registry's
// configuration, which is not visible from this test — confirm it matches.
func TestLRUCacheEviction(t *testing.T) {
	registry := NewRegistry()
	// Release the registry's resources, as the other parser tests do.
	defer registry.Close()
	ctx := context.Background()

	// Parse 101 distinct Go sources (cache size is 100).
	for i := 0; i < 101; i++ {
		content := []byte(fmt.Sprintf("package main\n\nfunc test%d() {}\n", i))
		filename := "test.go"
		if _, err := registry.Parse(ctx, filename, content); err != nil {
			t.Fatalf("Parse failed for iteration %d: %v", i, err)
		}
	}

	// The LRU cache should have evicted the oldest entry, keeping the size
	// capped at 100.
	if cacheLen := registry.cache.Len(); cacheLen > 100 {
		t.Errorf("Cache size %d exceeds max size 100", cacheLen)
	}
}
// TestCacheHit tests that parsing identical content twice returns the same
// tree object, i.e. that the second call is served from the cache.
func TestCacheHit(t *testing.T) {
	registry := NewRegistry()
	// Release the registry's resources, as the other parser tests do.
	defer registry.Close()
	ctx := context.Background()

	content := []byte("package main\n\nfunc test() {}\n")
	filename := "test.go"

	// First parse populates the cache.
	result1, err := registry.Parse(ctx, filename, content)
	if err != nil {
		t.Fatalf("First parse failed: %v", err)
	}

	// Second parse should use the cache.
	result2, err := registry.Parse(ctx, filename, content)
	if err != nil {
		t.Fatalf("Second parse failed: %v", err)
	}

	// Pointer equality proves the tree was reused rather than rebuilt.
	if result1.Tree != result2.Tree {
		t.Error("Expected cached tree to be reused, but got different tree objects")
	}
}
// TestContentHashCollisionResistance checks that contentHash yields different
// values for a handful of pairs of different inputs.
func TestContentHashCollisionResistance(t *testing.T) {
	pairs := []struct {
		name     string
		content1 []byte
		content2 []byte
	}{
		{"different content", []byte("package main"), []byte("package test")},
		{"same prefix different suffix", []byte("package main\nfunc a() {}"), []byte("package main\nfunc b() {}")},
		{"different length", []byte("short"), []byte("much longer content here")},
	}

	for _, pair := range pairs {
		t.Run(pair.name, func(t *testing.T) {
			left, right := contentHash(pair.content1), contentHash(pair.content2)
			if left != right {
				return
			}
			t.Errorf("Hash collision: %s == %s for different content", left, right)
		})
	}
}
// TestContentHashConsistency checks that hashing the same content three times
// gives three equal results.
func TestContentHashConsistency(t *testing.T) {
	content := []byte("package main\n\nfunc test() {}\n")

	first := contentHash(content)
	second := contentHash(content)
	third := contentHash(content)

	allEqual := first == second && second == third
	if !allEqual {
		t.Errorf("Hash inconsistency: %s, %s, %s", first, second, third)
	}
}
// BenchmarkContentHash_xxHash measures contentHash on a 10KB buffer filled
// with a repeating 0..255 byte pattern.
func BenchmarkContentHash_xxHash(b *testing.B) {
	// Typical file content size (10KB).
	const size = 10 * 1024
	content := make([]byte, size)
	for idx := range content {
		content[idx] = byte(idx % 256)
	}

	// Exclude the buffer setup from the measurement.
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		_ = contentHash(content)
	}
}
// BenchmarkCacheHitRate measures Registry.Parse when a small, fixed set of
// sources is parsed over and over, so that most calls can be served from the
// cache.
func BenchmarkCacheHitRate(b *testing.B) {
	registry := NewRegistry()
	// Release the registry's resources, as the parser tests do.
	defer registry.Close()
	ctx := context.Background()

	// A set of common files that get parsed repeatedly.
	files := [][]byte{
		[]byte("package main\n\nfunc main() {}\n"),
		[]byte("package test\n\nimport \"testing\"\n"),
		[]byte("package util\n\nfunc helper() string { return \"\" }\n"),
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Cycle through the files to simulate an access pattern with cache hits.
		content := files[i%len(files)]
		// A failing parse would make the timings meaningless, so stop instead
		// of silently measuring the error path.
		if _, err := registry.Parse(ctx, "test.go", content); err != nil {
			b.Fatalf("Parse failed for iteration %d: %v", i, err)
		}
	}
}
+550
View File
@@ -0,0 +1,550 @@
// Package parser provides documentation extraction for multiple languages.
package parser
import (
"regexp"
"strings"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// DocComment represents an extracted documentation comment.
type DocComment struct {
// Tags maps a tag name (for example "param") to its value. The JSDoc and
// Doxygen extractors fill it; the other extractors leave it nil.
Tags map[string]string
// Text is the comment content with the comment markers removed.
Text string
// Raw is the comment text exactly as it appears in the source.
Raw string
// Style identifies the comment syntax that was found.
Style CommentStyle
// StartLine is the 1-indexed line on which the comment starts.
StartLine int
// EndLine is the 1-indexed line on which the comment ends.
EndLine int
}
// CommentStyle indicates the type of comment.
// It is a string type whose known values are the CommentStyle* constants.
type CommentStyle string
// Known comment styles. The trailing comment on each constant shows the
// comment syntax the style stands for.
const (
CommentStyleLine CommentStyle = "line" // // comment
CommentStyleBlock CommentStyle = "block" // /* comment */
CommentStyleJSDoc CommentStyle = "jsdoc" // /** comment */
CommentStyleDoxygen CommentStyle = "doxygen" // /** comment */ or /// comment
CommentStyleDocstring CommentStyle = "docstring" // """comment""" or '''comment'''
CommentStyleHash CommentStyle = "hash" // # comment (Python)
)
// ExtractDocComment returns the documentation comment attached to n, using
// the extraction rules of lang. It returns nil when n is nil, when lang is
// not one of Go, TypeScript, JavaScript, Python, C or C++, or when the
// language-specific extractor finds no documentation.
func ExtractDocComment(n *sitter.Node, content []byte, lang protocol.Language) *DocComment {
	if n == nil {
		return nil
	}

	// Pick the language-specific extractor; unsupported languages leave it nil.
	var extract func(*sitter.Node, []byte) *DocComment
	switch lang {
	case protocol.LangGo:
		extract = extractGoDocComment
	case protocol.LangTypeScript, protocol.LangJavaScript:
		extract = extractJSDocComment
	case protocol.LangPython:
		extract = extractPythonDocComment
	case protocol.LangC, protocol.LangCpp:
		extract = extractCDocComment
	}

	if extract == nil {
		return nil
	}
	return extract(n, content)
}
// extractGoDocComment builds a DocComment from the // or /* */ comments that
// directly precede n. It returns nil when there are no such comments or when
// all of them are empty after cleaning.
func extractGoDocComment(n *sitter.Node, content []byte) *DocComment {
	nodes := collectPrecedingComments(n, content, []string{"comment"})
	if len(nodes) == 0 {
		return nil
	}

	rawTexts := make([]string, 0, len(nodes))
	var cleanedTexts []string
	for _, node := range nodes {
		original := GetNodeText(node, content)
		rawTexts = append(rawTexts, original)
		if stripped := cleanGoComment(original); stripped != "" {
			cleanedTexts = append(cleanedTexts, stripped)
		}
	}
	if len(cleanedTexts) == 0 {
		return nil
	}

	// The comment spans from the first collected node to the last one;
	// Tree-sitter rows are 0-indexed, reported lines are 1-indexed.
	first, last := nodes[0], nodes[len(nodes)-1]
	return &DocComment{
		Tags:      nil, // Go doesn't use JSDoc-style tags
		Text:      strings.Join(cleanedTexts, "\n"),
		Raw:       strings.Join(rawTexts, "\n"),
		Style:     detectCommentStyle(rawTexts[0]),
		StartLine: int(first.StartPoint().Row) + 1,
		EndLine:   int(last.EndPoint().Row) + 1,
	}
}
// extractJSDocComment extracts the documentation comment that precedes n in
// JavaScript/TypeScript source.
//
// If any of the preceding comments is a JSDoc block (starts with "/**"), the
// one closest to n is parsed with parseJSDoc and returned together with its
// tags. Otherwise all preceding comments are cleaned and joined; the style is
// then CommentStyleBlock when the first comment is a /* */ block and
// CommentStyleLine otherwise. It returns nil when there is nothing to report.
func extractJSDocComment(n *sitter.Node, content []byte) *DocComment {
	comments := collectPrecedingComments(n, content, []string{"comment"})
	if len(comments) == 0 {
		return nil
	}

	// Prefer a JSDoc block, scanning from the comment nearest the node.
	for i := len(comments) - 1; i >= 0; i-- {
		candidate := comments[i]
		text := GetNodeText(candidate, content)
		if !strings.HasPrefix(strings.TrimSpace(text), "/**") {
			continue
		}
		cleaned, tags := parseJSDoc(text)
		return &DocComment{
			Text:      cleaned,
			Raw:       text,
			Style:     CommentStyleJSDoc,
			Tags:      tags,
			StartLine: int(candidate.StartPoint().Row) + 1,
			EndLine:   int(candidate.EndPoint().Row) + 1,
		}
	}

	// Fall back to regular comments.
	var parts []string
	raw := make([]string, 0, len(comments))
	for _, c := range comments {
		text := GetNodeText(c, content)
		raw = append(raw, text)
		if cleaned := cleanJSComment(text); cleaned != "" {
			parts = append(parts, cleaned)
		}
	}
	if len(parts) == 0 {
		return nil
	}

	// Report block comments as such instead of labelling everything "line".
	// A "/**" prefix cannot occur here: it is handled by the JSDoc path above.
	style := CommentStyleLine
	if strings.HasPrefix(strings.TrimSpace(raw[0]), "/*") {
		style = CommentStyleBlock
	}

	return &DocComment{
		Text:      strings.Join(parts, "\n"),
		Raw:       strings.Join(raw, "\n"),
		Style:     style,
		Tags:      nil,
		StartLine: int(comments[0].StartPoint().Row) + 1,
		EndLine:   int(comments[len(comments)-1].EndPoint().Row) + 1,
	}
}
// extractPythonDocComment extracts the documentation of a Python node.
//
// A docstring wins: when n has a "body" field whose first named child is an
// expression statement wrapping a string, that string is returned with style
// CommentStyleDocstring. Otherwise the # comments directly preceding n are
// cleaned and joined with style CommentStyleHash. Nodes without a body (which
// cannot carry a docstring) still get the # comment lookup. It returns nil
// when neither form of documentation is found.
func extractPythonDocComment(n *sitter.Node, content []byte) *DocComment {
	// Python docstrings are inside the body, not before the definition.
	if body := n.ChildByFieldName("body"); body != nil && body.NamedChildCount() > 0 {
		first := body.NamedChild(0)
		if first != nil && first.Type() == "expression_statement" && first.NamedChildCount() > 0 {
			if expr := first.NamedChild(0); expr != nil && expr.Type() == "string" {
				text := GetNodeText(expr, content)
				return &DocComment{
					Text:      cleanPythonDocstring(text),
					Raw:       text,
					Style:     CommentStyleDocstring,
					Tags:      nil,
					StartLine: int(expr.StartPoint().Row) + 1,
					EndLine:   int(expr.EndPoint().Row) + 1,
				}
			}
		}
	}

	// No docstring: look for # comments before the definition.
	comments := collectPrecedingComments(n, content, []string{"comment"})
	if len(comments) == 0 {
		return nil
	}

	var parts []string
	raw := make([]string, 0, len(comments))
	for _, c := range comments {
		text := GetNodeText(c, content)
		raw = append(raw, text)
		// Strip the leading "#" marker and surrounding whitespace.
		cleaned := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(text), "#"))
		if cleaned != "" {
			parts = append(parts, cleaned)
		}
	}
	if len(parts) == 0 {
		return nil
	}

	return &DocComment{
		Text:      strings.Join(parts, "\n"),
		Raw:       strings.Join(raw, "\n"),
		Style:     CommentStyleHash,
		Tags:      nil,
		StartLine: int(comments[0].StartPoint().Row) + 1,
		EndLine:   int(comments[len(comments)-1].EndPoint().Row) + 1,
	}
}
// extractCDocComment extracts the documentation comment that precedes n in
// C/C++ source.
//
// A Doxygen comment (starting with "/**", "///" or "//!") is preferred: the
// one closest to n is parsed with parseDoxygen. Without one, all preceding
// comments are cleaned and joined, and the style is detected from the first
// of them. It returns nil when there is nothing to report.
func extractCDocComment(n *sitter.Node, content []byte) *DocComment {
	nodes := collectPrecedingComments(n, content, []string{"comment"})
	if len(nodes) == 0 {
		return nil
	}

	isDoxygen := func(text string) bool {
		t := strings.TrimSpace(text)
		return strings.HasPrefix(t, "/**") || strings.HasPrefix(t, "///") || strings.HasPrefix(t, "//!")
	}

	// Scan from the comment nearest the declaration backwards.
	for idx := len(nodes) - 1; idx >= 0; idx-- {
		candidate := nodes[idx]
		original := GetNodeText(candidate, content)
		if !isDoxygen(original) {
			continue
		}
		description, tags := parseDoxygen(original)
		return &DocComment{
			Text:      description,
			Raw:       original,
			Style:     CommentStyleDoxygen,
			Tags:      tags,
			StartLine: int(candidate.StartPoint().Row) + 1,
			EndLine:   int(candidate.EndPoint().Row) + 1,
		}
	}

	// Fall back to regular comments.
	rawTexts := make([]string, 0, len(nodes))
	var cleanedTexts []string
	for _, node := range nodes {
		original := GetNodeText(node, content)
		rawTexts = append(rawTexts, original)
		if stripped := cleanCComment(original); stripped != "" {
			cleanedTexts = append(cleanedTexts, stripped)
		}
	}
	if len(cleanedTexts) == 0 {
		return nil
	}

	first, last := nodes[0], nodes[len(nodes)-1]
	return &DocComment{
		Text:      strings.Join(cleanedTexts, "\n"),
		Raw:       strings.Join(rawTexts, "\n"),
		Style:     detectCommentStyle(rawTexts[0]),
		Tags:      nil,
		StartLine: int(first.StartPoint().Row) + 1,
		EndLine:   int(last.EndPoint().Row) + 1,
	}
}
// collectPrecedingComments returns, in source order, the run of comment nodes
// that sits directly before n among its siblings.
//
// A sibling counts as a comment when its type is listed in commentTypes. The
// run ends at the first non-comment sibling, or when more than one row
// separates a comment's end from the start of whatever follows it (i.e. there
// is a blank line in between). The second parameter is unused.
func collectPrecedingComments(n *sitter.Node, _ []byte, commentTypes []string) []*sitter.Node {
	isCommentType := func(kind string) bool {
		for _, candidate := range commentTypes {
			if candidate == kind {
				return true
			}
		}
		return false
	}

	var found []*sitter.Node
	sibling := n.PrevSibling()
	// boundary is the first row of the item that follows the sibling being
	// examined: initially n itself, later the previously accepted comment.
	boundary := int(n.StartPoint().Row)
	for ; sibling != nil && isCommentType(sibling.Type()); sibling = sibling.PrevSibling() {
		if boundary-int(sibling.EndPoint().Row) > 1 {
			break
		}
		found = append(found, sibling)
		boundary = int(sibling.StartPoint().Row)
	}

	// The siblings were gathered nearest-first; flip them into source order.
	for lo, hi := 0, len(found)-1; lo < hi; lo, hi = lo+1, hi-1 {
		found[lo], found[hi] = found[hi], found[lo]
	}
	return found
}
// detectCommentStyle classifies a comment by its leading marker, ignoring
// surrounding whitespace. Longer markers are tested before their shorter
// prefixes ("/**" before "/*", "///" before "//"). Text with no known marker
// is reported as CommentStyleLine.
func detectCommentStyle(comment string) CommentStyle {
	text := strings.TrimSpace(comment)
	startsWith := func(marker string) bool {
		return strings.HasPrefix(text, marker)
	}

	switch {
	case startsWith("/**"):
		return CommentStyleJSDoc
	case startsWith("///"), startsWith("//!"):
		return CommentStyleDoxygen
	case startsWith("/*"):
		return CommentStyleBlock
	case startsWith("//"):
		return CommentStyleLine
	case startsWith("#"):
		return CommentStyleHash
	case startsWith(`"""`), startsWith(`'''`):
		return CommentStyleDocstring
	default:
		return CommentStyleLine
	}
}
// cleanGoComment removes the comment markers from a single Go comment.
//
// A "//" comment loses its marker and surrounding whitespace. A comment that
// both starts with "/*" and ends with "*/" loses the delimiters and is passed
// through cleanBlockComment. Anything else is returned whitespace-trimmed.
func cleanGoComment(comment string) string {
	trimmed := strings.TrimSpace(comment)

	switch {
	case strings.HasPrefix(trimmed, "//"):
		return strings.TrimSpace(trimmed[len("//"):])
	case strings.HasPrefix(trimmed, "/*") && strings.HasSuffix(trimmed, "*/"):
		inner := strings.TrimSuffix(strings.TrimPrefix(trimmed, "/*"), "*/")
		return cleanBlockComment(inner)
	default:
		return trimmed
	}
}
// cleanJSComment removes the comment markers from a JavaScript/TypeScript
// comment. The rules are the same as for Go, so it delegates to
// cleanGoComment.
func cleanJSComment(comment string) string {
	cleaned := cleanGoComment(comment)
	return cleaned
}
// cleanCComment removes the comment markers from a C/C++ comment. The rules
// are the same as for Go, so it delegates to cleanGoComment.
func cleanCComment(comment string) string {
	cleaned := cleanGoComment(comment)
	return cleaned
}
// cleanBlockComment tidies the inside of a block comment (delimiters already
// removed). Each line is trimmed, loses one leading "*" if present, and is
// trimmed again; empty lines at the start and end are then dropped, while
// empty lines in the middle are kept.
func cleanBlockComment(comment string) string {
	rows := strings.Split(comment, "\n")
	for idx, row := range rows {
		row = strings.TrimPrefix(strings.TrimSpace(row), "*")
		rows[idx] = strings.TrimSpace(row)
	}

	// Narrow [first, last) until it no longer starts or ends with an empty row.
	first, last := 0, len(rows)
	for first < last && rows[first] == "" {
		first++
	}
	for last > first && rows[last-1] == "" {
		last--
	}
	return strings.Join(rows[first:last], "\n")
}
// jsDocTagPattern matches a JSDoc tag line such as "@param name text" and
// captures the tag name and the rest of the line. It is compiled once instead
// of on every parseJSDoc call.
var jsDocTagPattern = regexp.MustCompile(`^\s*\*?\s*@(\w+)\s*(.*)$`)

// parseJSDoc parses a JSDoc comment and extracts its description and tags.
//
// The first return value is the description: every non-empty line that does
// not belong to a tag, joined with "\n". The second maps each tag name to its
// value; when a tag occurs several times the values are joined with "\n".
// A non-empty line that directly follows a tag line (no blank line between
// them) is treated as the continuation of that tag and appended to its value
// with a single space. The returned map is never nil.
func parseJSDoc(comment string) (string, map[string]string) {
	comment = strings.TrimSpace(comment)
	// Remove /** and */
	comment = strings.TrimPrefix(comment, "/**")
	comment = strings.TrimSuffix(comment, "*/")

	var descLines []string
	tags := make(map[string]string)
	// currentTag is the tag that continuation lines belong to; "" means that
	// plain lines are part of the description.
	currentTag := ""

	for _, line := range strings.Split(comment, "\n") {
		// Drop indentation and the decorative leading "*".
		line = strings.TrimSpace(line)
		line = strings.TrimPrefix(line, "*")
		line = strings.TrimSpace(line)

		if matches := jsDocTagPattern.FindStringSubmatch(line); matches != nil {
			currentTag = matches[1]
			tagValue := strings.TrimSpace(matches[2])
			if existing, ok := tags[currentTag]; ok {
				tags[currentTag] = existing + "\n" + tagValue
			} else {
				tags[currentTag] = tagValue
			}
			continue
		}

		switch {
		case line == "":
			// A blank line ends the current tag's text.
			currentTag = ""
		case currentTag != "":
			tags[currentTag] = strings.TrimSpace(tags[currentTag] + " " + line)
		default:
			descLines = append(descLines, line)
		}
	}

	return strings.Join(descLines, "\n"), tags
}
// doxygenTagPattern matches a Doxygen command line introduced by "@" or "\"
// (for example "@param x text" or "\return text") and captures the command
// name and the rest of the line. It is compiled once instead of on every
// parseDoxygen call.
var doxygenTagPattern = regexp.MustCompile(`^\s*\*?\s*[@\\](\w+)\s*(.*)$`)

// parseDoxygen parses a Doxygen comment ("/** ... */", "///" or "//!") and
// extracts its description and tags.
//
// The first return value is the description: every non-empty line that does
// not belong to a tag, joined with "\n". The second maps each tag name to its
// value; when a tag occurs several times the values are joined with "\n".
// A non-empty line that directly follows a tag line (no blank line between
// them) is treated as the continuation of that tag and appended to its value
// with a single space. The returned map is never nil.
func parseDoxygen(comment string) (string, map[string]string) {
	comment = strings.TrimSpace(comment)

	// Handle /// and //! style comments
	comment = strings.TrimPrefix(comment, "///")
	comment = strings.TrimPrefix(comment, "//!")

	// Handle /** */ style comments
	comment = strings.TrimPrefix(comment, "/**")
	comment = strings.TrimSuffix(comment, "*/")

	var descLines []string
	tags := make(map[string]string)
	// currentTag is the tag that continuation lines belong to; "" means that
	// plain lines are part of the description.
	currentTag := ""

	for _, line := range strings.Split(comment, "\n") {
		// Drop indentation and the decorative leading "*".
		line = strings.TrimSpace(line)
		line = strings.TrimPrefix(line, "*")
		line = strings.TrimSpace(line)

		if matches := doxygenTagPattern.FindStringSubmatch(line); matches != nil {
			currentTag = matches[1]
			tagValue := strings.TrimSpace(matches[2])
			if existing, ok := tags[currentTag]; ok {
				tags[currentTag] = existing + "\n" + tagValue
			} else {
				tags[currentTag] = tagValue
			}
			continue
		}

		switch {
		case line == "":
			// A blank line ends the current tag's paragraph.
			currentTag = ""
		case currentTag != "":
			tags[currentTag] = strings.TrimSpace(tags[currentTag] + " " + line)
		default:
			descLines = append(descLines, line)
		}
	}

	return strings.Join(descLines, "\n"), tags
}
// FormatDocComment formats a DocComment for display.
//
// The result is the comment text, followed (after a blank line) by one
// "@tag value" line per tag value. Parameter tags come first, then return
// tags, then all remaining tags in alphabetical order so that the output is
// the same on every call. A tag value containing several lines (produced when
// a tag occurs more than once) yields one "@tag" line per value line.
// It returns "" when doc is nil or has no text.
func FormatDocComment(doc *DocComment) string {
	if doc == nil || doc.Text == "" {
		return ""
	}

	var sb strings.Builder
	sb.WriteString(doc.Text)

	if len(doc.Tags) > 0 {
		sb.WriteString("\n\n")

		// writeTag emits one "@name line" row per line of value.
		writeTag := func(name, value string) {
			for _, line := range strings.Split(value, "\n") {
				sb.WriteString("@" + name + " " + line + "\n")
			}
		}

		// Order: description, params, returns, other
		priority := []string{
			"param", "parameter", "arg", "argument",
			"return", "returns", "retval",
		}
		written := make(map[string]bool, len(priority))
		for _, tagName := range priority {
			written[tagName] = true
			if val, ok := doc.Tags[tagName]; ok {
				writeTag(tagName, val)
			}
		}

		// Collect the remaining tag names and order them: map iteration order
		// is unspecified, which would make the output differ between calls.
		remaining := make([]string, 0, len(doc.Tags))
		for tagName := range doc.Tags {
			if !written[tagName] {
				remaining = append(remaining, tagName)
			}
		}
		// Insertion sort: the list is tiny and this keeps the function free of
		// additional imports.
		for i := 1; i < len(remaining); i++ {
			for j := i; j > 0 && remaining[j] < remaining[j-1]; j-- {
				remaining[j], remaining[j-1] = remaining[j-1], remaining[j]
			}
		}
		for _, tagName := range remaining {
			writeTag(tagName, doc.Tags[tagName])
		}
	}

	return strings.TrimSpace(sb.String())
}
// cleanPythonDocstring strips the quoting from a Python string literal used
// as a docstring and returns its whitespace-trimmed content.
//
// It accepts triple-quoted ("""...""" or '''...''') and single-quoted
// ("..." or '...') literals, optionally preceded by string prefix letters
// such as r, u, b or f (for example r"""raw doc"""). The opening and closing
// delimiter must be the same. Input that is not a well-formed literal is
// handled leniently: any leading or trailing triple quotes are removed.
func cleanPythonDocstring(doc string) string {
	doc = strings.TrimSpace(doc)

	// Skip string prefix letters so that r"""...""" and friends are recognised.
	literal := strings.TrimLeft(doc, "rRuUbBfF")
	for _, quote := range []string{`"""`, `'''`, `"`, `'`} {
		if !strings.HasPrefix(literal, quote) {
			continue
		}
		if len(literal) >= 2*len(quote) && strings.HasSuffix(literal, quote) {
			return strings.TrimSpace(literal[len(quote) : len(literal)-len(quote)])
		}
		// Opening delimiter without a matching close: not a complete literal.
		break
	}

	// Lenient fallback for malformed input: remove whatever triple quotes are
	// present at either end.
	doc = strings.TrimPrefix(doc, `"""`)
	doc = strings.TrimSuffix(doc, `"""`)
	doc = strings.TrimPrefix(doc, `'''`)
	doc = strings.TrimSuffix(doc, `'''`)
	return strings.TrimSpace(doc)
}
+630
View File
@@ -0,0 +1,630 @@
package parser
import (
"context"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// TestExtractGoDocComment parses small Go sources, locates a node of the
// requested kind and checks the text and style that ExtractDocComment reports
// for it. A case with an empty wantText expects no documentation at all.
func TestExtractGoDocComment(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
// Each case is a complete Go source held in a raw string; its content is
// the fixture and must not be reformatted.
tests := []struct {
name string
code string
nodeKind string
wantText string
wantStyle CommentStyle
}{
{
name: "single line comment",
code: `package main
// Hello says hello
func Hello() {}
`,
nodeKind: "function_declaration",
wantText: "Hello says hello",
wantStyle: CommentStyleLine,
},
{
name: "multi-line comments",
code: `package main
// This is a function
// that does something
// important
func DoSomething() {}
`,
nodeKind: "function_declaration",
wantText: "This is a function\nthat does something\nimportant",
wantStyle: CommentStyleLine,
},
{
name: "block comment",
code: `package main
/* This is a block comment
describing the function */
func BlockCommented() {}
`,
nodeKind: "function_declaration",
wantText: "This is a block comment\ndescribing the function",
wantStyle: CommentStyleBlock,
},
{
name: "doc comment with asterisks",
code: `package main
/*
* This is a properly formatted
* block comment with asterisks
*/
func FormattedBlock() {}
`,
nodeKind: "function_declaration",
wantText: "This is a properly formatted\nblock comment with asterisks",
wantStyle: CommentStyleBlock,
},
{
name: "no comment",
code: `package main
func NoComment() {}
`,
nodeKind: "function_declaration",
wantText: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.Parse(context.Background(), "test.go", []byte(tt.code))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
// Find the target node
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
if targetNode == nil {
t.Fatalf("could not find node of type %s", tt.nodeKind)
}
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangGo)
// An empty wantText means "no documentation expected": both a nil
// result and a result with empty text are accepted.
if tt.wantText == "" {
if doc != nil && doc.Text != "" {
t.Errorf("expected no doc, got %q", doc.Text)
}
return
}
if doc == nil {
t.Fatal("expected doc, got nil")
}
if doc.Text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
}
if doc.Style != tt.wantStyle {
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
}
})
}
}
// TestExtractJSDocComment parses small JavaScript sources and checks the
// text, style and (where given) tags that ExtractDocComment reports for the
// first node of the requested kind.
func TestExtractJSDocComment(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
// Each case is a complete JavaScript source held in a raw string; its
// content is the fixture and must not be reformatted.
tests := []struct {
wantTags map[string]string
name string
code string
nodeKind string
wantText string
wantStyle CommentStyle
}{
{
name: "JSDoc comment",
code: `/**
* Adds two numbers together.
* @param a The first number
* @param b The second number
* @returns The sum of a and b
*/
function add(a, b) {
return a + b;
}
`,
nodeKind: "function_declaration",
wantText: "Adds two numbers together.",
wantStyle: CommentStyleJSDoc,
wantTags: map[string]string{
"param": "a The first number\nb The second number",
"returns": "The sum of a and b",
},
},
{
name: "simple line comment",
code: `// This is a simple function
function simple() {}
`,
nodeKind: "function_declaration",
wantText: "This is a simple function",
wantStyle: CommentStyleLine,
},
{
name: "JSDoc with types",
code: `/**
* @param {string} name - The name
* @returns {boolean} True if valid
*/
function validate(name) {}
`,
nodeKind: "function_declaration",
wantText: "",
wantStyle: CommentStyleJSDoc,
wantTags: map[string]string{
"param": "{string} name - The name",
"returns": "{boolean} True if valid",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.Parse(context.Background(), "test.js", []byte(tt.code))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
if targetNode == nil {
t.Fatalf("could not find node of type %s", tt.nodeKind)
}
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangJavaScript)
if doc == nil {
t.Fatal("expected doc, got nil")
}
if doc.Text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
}
if doc.Style != tt.wantStyle {
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
}
// Only the expected tags are compared; extra tags in the result are
// not treated as a failure.
if tt.wantTags != nil {
for k, want := range tt.wantTags {
if got := doc.Tags[k]; got != want {
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
}
}
}
})
}
}
// TestExtractPythonDocComment parses small Python sources and checks the
// docstring text and style that ExtractDocComment reports for the first node
// of the requested kind.
func TestExtractPythonDocComment(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
// Each case is a complete Python source held in a raw string. Whitespace
// inside these strings is significant to the parser and to wantText, so
// the fixtures must not be reformatted.
tests := []struct {
name string
code string
nodeKind string
wantText string
wantStyle CommentStyle
}{
{
name: "docstring",
code: `def greet(name):
"""Greet a person by name."""
print(f"Hello, {name}!")
`,
nodeKind: "function_definition",
wantText: "Greet a person by name.",
wantStyle: CommentStyleDocstring,
},
{
name: "multi-line docstring",
code: `def calculate(x, y):
"""
Calculate the sum of two numbers.
Args:
x: First number
y: Second number
Returns:
The sum of x and y
"""
return x + y
`,
nodeKind: "function_definition",
wantText: "Calculate the sum of two numbers.\n\n Args:\n x: First number\n y: Second number\n\n Returns:\n The sum of x and y",
wantStyle: CommentStyleDocstring,
},
{
name: "class docstring",
code: `class MyClass:
"""This is a class description."""
pass
`,
nodeKind: "class_definition",
wantText: "This is a class description.",
wantStyle: CommentStyleDocstring,
},
{
name: "single quote docstring",
code: `def func():
'''Single quote docstring'''
pass
`,
nodeKind: "function_definition",
wantText: "Single quote docstring",
wantStyle: CommentStyleDocstring,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.Parse(context.Background(), "test.py", []byte(tt.code))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
if targetNode == nil {
t.Fatalf("could not find node of type %s", tt.nodeKind)
}
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangPython)
if doc == nil {
t.Fatal("expected doc, got nil")
}
if doc.Text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
}
if doc.Style != tt.wantStyle {
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
}
})
}
}
// TestExtractCDocComment parses small C sources and checks the text, style
// and (where given) tags that ExtractDocComment reports for the first node of
// the requested kind.
func TestExtractCDocComment(t *testing.T) {
registry := NewRegistry()
defer registry.Close()
// Each case is a complete C source held in a raw string; its content is
// the fixture and must not be reformatted.
tests := []struct {
wantTags map[string]string
name string
code string
nodeKind string
wantText string
wantStyle CommentStyle
}{
{
name: "Doxygen block comment",
code: `/**
* Adds two numbers.
* @param a First number
* @param b Second number
* @return Sum of a and b
*/
int add(int a, int b) {
return a + b;
}
`,
nodeKind: "function_definition",
wantText: "Adds two numbers.",
wantStyle: CommentStyleDoxygen,
wantTags: map[string]string{
"param": "a First number\nb Second number",
"return": "Sum of a and b",
},
},
{
name: "regular block comment",
code: `/* This is a regular comment */
int regular() { return 0; }
`,
nodeKind: "function_definition",
wantText: "This is a regular comment",
wantStyle: CommentStyleBlock,
},
{
name: "line comment",
code: `// Simple function
int simple() { return 1; }
`,
nodeKind: "function_definition",
wantText: "Simple function",
wantStyle: CommentStyleLine,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := registry.Parse(context.Background(), "test.c", []byte(tt.code))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
targetNode := findNodeByKind(result.Tree.RootNode(), tt.nodeKind)
if targetNode == nil {
t.Fatalf("could not find node of type %s", tt.nodeKind)
}
doc := ExtractDocComment(targetNode, []byte(tt.code), protocol.LangC)
if doc == nil {
t.Fatal("expected doc, got nil")
}
if doc.Text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", doc.Text, tt.wantText)
}
if doc.Style != tt.wantStyle {
t.Errorf("style mismatch: got %v, want %v", doc.Style, tt.wantStyle)
}
// Only the expected tags are compared; extra tags in the result are
// not treated as a failure.
if tt.wantTags != nil {
for k, want := range tt.wantTags {
if got := doc.Tags[k]; got != want {
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
}
}
}
})
}
}
// TestParseJSDoc feeds raw JSDoc comments to parseJSDoc and checks the
// returned description, the number of tags and each expected tag value.
func TestParseJSDoc(t *testing.T) {
// The inputs are raw strings holding complete comments; their content is
// the fixture and must not be reformatted.
tests := []struct {
wantTags map[string]string
name string
input string
wantText string
}{
{
name: "complete jsdoc",
input: `/**
* This is a description.
* Multiple lines.
* @param {string} name The name
* @returns {boolean} Result
*/`,
wantText: "This is a description.\nMultiple lines.",
wantTags: map[string]string{
"param": "{string} name The name",
"returns": "{boolean} Result",
},
},
{
name: "empty jsdoc",
input: `/** */`,
wantText: "",
wantTags: map[string]string{},
},
{
name: "only description",
input: `/** Simple description */`,
wantText: "Simple description",
wantTags: map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
text, tags := parseJSDoc(tt.input)
if text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
}
// The count check catches unexpected extra tags, which the per-key
// comparison below would not notice.
if len(tags) != len(tt.wantTags) {
t.Errorf("tag count mismatch: got %d, want %d", len(tags), len(tt.wantTags))
}
for k, want := range tt.wantTags {
if got := tags[k]; got != want {
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
}
}
})
}
}
// TestParseDoxygen feeds raw Doxygen comments ("@" tags, "\" tags and a
// "///" line) to parseDoxygen and checks the returned description and each
// expected tag value. Unlike TestParseJSDoc it does not check the tag count,
// so unexpected extra tags go unnoticed.
func TestParseDoxygen(t *testing.T) {
// The inputs are raw strings holding complete comments; their content is
// the fixture and must not be reformatted.
tests := []struct {
wantTags map[string]string
name string
input string
wantText string
}{
{
name: "doxygen with @ tags",
input: `/**
* Brief description.
* @param x Value
* @return Result
*/`,
wantText: "Brief description.",
wantTags: map[string]string{
"param": "x Value",
"return": "Result",
},
},
{
name: "doxygen with backslash tags",
input: `/**
* Description.
* \param y Input
* \retval Output value
*/`,
wantText: "Description.",
wantTags: map[string]string{
"param": "y Input",
"retval": "Output value",
},
},
{
name: "triple slash",
input: `/// Simple description`,
wantText: "Simple description",
wantTags: map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
text, tags := parseDoxygen(tt.input)
if text != tt.wantText {
t.Errorf("text mismatch:\ngot: %q\nwant: %q", text, tt.wantText)
}
for k, want := range tt.wantTags {
if got := tags[k]; got != want {
t.Errorf("tag %q mismatch:\ngot: %q\nwant: %q", k, got, want)
}
}
})
}
}
// TestFormatDocComment checks the display string produced by
// FormatDocComment for a comment with tags, a comment without tags, a nil
// comment and a comment with empty text.
func TestFormatDocComment(t *testing.T) {
	withTags := &DocComment{
		Text: "This is a function.",
		Tags: map[string]string{
			"param":   "x The value",
			"returns": "The result",
		},
	}
	withoutTags := &DocComment{Text: "Simple description.", Tags: nil}
	emptyText := &DocComment{Text: "", Tags: nil}

	cases := []struct {
		name string
		doc  *DocComment
		want string
	}{
		{"with tags", withTags, "This is a function.\n\n@param x The value\n@returns The result"},
		{"no tags", withoutTags, "Simple description."},
		{"nil doc", nil, ""},
		{"empty text", emptyText, ""},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := FormatDocComment(tc.doc); got != tc.want {
				t.Errorf("mismatch:\ngot: %q\nwant: %q", got, tc.want)
			}
		})
	}
}
// TestDetectCommentStyle checks that detectCommentStyle maps one sample of
// each supported comment syntax to the expected style.
func TestDetectCommentStyle(t *testing.T) {
	cases := []struct {
		input string
		want  CommentStyle
	}{
		{input: "/** JSDoc */", want: CommentStyleJSDoc},
		{input: "/// Doxygen", want: CommentStyleDoxygen},
		{input: "//! Doxygen", want: CommentStyleDoxygen},
		{input: "/* block */", want: CommentStyleBlock},
		{input: "// line", want: CommentStyleLine},
		{input: "# hash", want: CommentStyleHash},
		{input: `"""docstring"""`, want: CommentStyleDocstring},
		{input: `'''docstring'''`, want: CommentStyleDocstring},
	}

	for _, tc := range cases {
		// The sample itself doubles as the subtest name.
		t.Run(tc.input, func(t *testing.T) {
			if got := detectCommentStyle(tc.input); got != tc.want {
				t.Errorf("got %v, want %v", got, tc.want)
			}
		})
	}
}
// findNodeByKind searches the tree under root for a node whose type equals
// nodeType and returns it, or nil when root is nil or nothing matches.
//
// The callback returns false on a match. NOTE(review): whether that halts the
// whole traversal (making this the first match) or only skips the matched
// node's subtree depends on WalkTree - confirm against its implementation.
func findNodeByKind(root *sitter.Node, nodeType string) *sitter.Node {
	if root == nil {
		return nil
	}

	var match *sitter.Node
	WalkTree(root, func(candidate *sitter.Node) bool {
		if candidate.Type() != nodeType {
			return true
		}
		match = candidate
		return false
	})
	return match
}
// TestCleanBlockComment checks that cleanBlockComment strips leading "*"
// decoration and surrounding blank lines from a block comment body.
func TestCleanBlockComment(t *testing.T) {
	cases := []struct {
		input string
		want  string
	}{
		{input: "\n * Line 1\n * Line 2\n ", want: "Line 1\nLine 2"},
		{input: "Simple", want: "Simple"},
		{input: "\n\nWith blank lines\n\n", want: "With blank lines"},
	}

	for _, tc := range cases {
		// The subtest name is at most the first ten bytes of the input.
		label := tc.input[:min(10, len(tc.input))]
		t.Run(label, func(t *testing.T) {
			if got := cleanBlockComment(tc.input); got != tc.want {
				t.Errorf("got %q, want %q", got, tc.want)
			}
		})
	}
}
+271
View File
@@ -0,0 +1,271 @@
// Package parser provides Tree-sitter based parsing for multiple languages.
package parser
import (
"context"
"fmt"
"sync"
"github.com/cespare/xxhash/v2"
lru "github.com/hashicorp/golang-lru/v2"
sitter "github.com/smacker/go-tree-sitter"
"github.com/smacker/go-tree-sitter/c"
"github.com/smacker/go-tree-sitter/cpp"
"github.com/smacker/go-tree-sitter/golang"
"github.com/smacker/go-tree-sitter/html"
"github.com/smacker/go-tree-sitter/javascript"
"github.com/smacker/go-tree-sitter/python"
"github.com/smacker/go-tree-sitter/typescript/typescript"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// MaxFileSize is the largest content size, in bytes, that will be parsed
// (10 MiB); larger inputs are rejected before any parsing happens.
const MaxFileSize = 10 << 20
// Registry manages Tree-sitter parsers for different languages.
// It creates one parser per language on first request and keeps an LRU cache
// of parsed trees keyed by a hash of the parsed content.
type Registry struct {
// parsers holds the parser created for each language by GetParser.
parsers map[protocol.Language]*sitter.Parser
// cache maps a content hash to the tree last parsed for that content.
cache *lru.Cache[string, *CachedTree]
// mu guards parsers and is also held by Parse while a parser is running.
mu sync.RWMutex
}
// CachedTree stores a parsed tree with its metadata.
// Content is not stored to reduce memory usage.
type CachedTree struct {
// Tree is the parsed syntax tree.
Tree *sitter.Tree
// Language is the language the tree was parsed as; Parse compares it with
// the requested language before reusing a cached entry.
Language protocol.Language
}
// ParseResult contains the result of parsing a file.
type ParseResult struct {
// Tree is the Tree-sitter syntax tree; it is nil for YAML and JSON results.
Tree *sitter.Tree
// Language is the language the content was parsed as.
Language protocol.Language
// Errors lists the syntax errors found in the tree, if any.
Errors []SyntaxError
// Content is the byte slice that was passed to the parser (not a copy).
Content []byte
}
// SyntaxError represents a syntax error found during parsing.
type SyntaxError struct {
// Message is a human-readable description of the problem.
Message string
// NodeType is "ERROR" or "MISSING" for errors produced by extractErrors.
NodeType string
// Location is the 1-based line and column where the offending node starts.
Location protocol.Location
}
// NewRegistry returns an empty Registry with no parsers created yet and an
// LRU tree cache that holds up to 100 entries.
//
// It panics if the cache cannot be constructed; lru.New only fails for a
// non-positive size, so this is not expected to happen.
func NewRegistry() *Registry {
	const cacheCapacity = 100

	treeCache, err := lru.New[string, *CachedTree](cacheCapacity)
	if err != nil {
		panic(fmt.Sprintf("failed to create LRU cache: %v", err))
	}

	reg := &Registry{cache: treeCache}
	reg.parsers = make(map[protocol.Language]*sitter.Parser)
	return reg
}
// getLanguage maps a protocol language to its Tree-sitter grammar.
// It returns an ErrInvalidLanguage error for any language without a grammar
// in this switch.
func getLanguage(lang protocol.Language) (*sitter.Language, error) {
	switch lang {
	case protocol.LangGo:
		return golang.GetLanguage(), nil
	case protocol.LangTypeScript:
		return typescript.GetLanguage(), nil
	case protocol.LangJavaScript:
		return javascript.GetLanguage(), nil
	case protocol.LangPython:
		return python.GetLanguage(), nil
	case protocol.LangC:
		return c.GetLanguage(), nil
	case protocol.LangCpp:
		return cpp.GetLanguage(), nil
	case protocol.LangHTML, protocol.LangVue:
		// Vue SFC files use HTML-like template syntax, so they share the
		// HTML grammar.
		return html.GetLanguage(), nil
	}

	unsupported := errors.New(errors.ErrInvalidLanguage, fmt.Sprintf("language %s is not supported", lang))
	return nil, unsupported.
		WithContext("language", string(lang)).
		WithRemediation("Supported languages: Go, TypeScript, JavaScript, Python, C, C++, HTML, Vue")
}
// GetParser returns the parser for lang, creating and storing it on first
// use. It returns an error when lang has no Tree-sitter grammar.
func (r *Registry) GetParser(lang protocol.Language) (*sitter.Parser, error) {
	// Fast path: look the parser up under the read lock.
	r.mu.RLock()
	existing, found := r.parsers[lang]
	r.mu.RUnlock()
	if found {
		return existing, nil
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	// Another caller may have created the parser between the two locks.
	if existing, found = r.parsers[lang]; found {
		return existing, nil
	}

	grammar, err := getLanguage(lang)
	if err != nil {
		return nil, err
	}

	created := sitter.NewParser()
	created.SetLanguage(grammar)
	r.parsers[lang] = created
	return created, nil
}
// Parse parses content as the language detected from filename.
//
// It rejects content larger than MaxFileSize, content that looks binary, and
// filenames whose language cannot be detected. YAML and JSON are delegated to
// ParseYAML and ParseJSON. For every other language the tree is looked up in
// the cache by content hash and parsed with Tree-sitter on a miss.
//
// Syntax errors do not fail the call; they are reported in ParseResult.Errors.
func (r *Registry) Parse(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
	if size := len(content); size > MaxFileSize {
		return nil, errors.NewFileTooLarge(filename, int64(size), MaxFileSize)
	}

	if isBinary(content) {
		return nil, errors.New(errors.ErrParseFailed, "binary file detected").
			WithContext("file", filename).
			WithRemediation("This appears to be a binary file and cannot be parsed as source code")
	}

	lang := protocol.DetectLanguage(filename)
	switch lang {
	case protocol.LangUnknown:
		return nil, errors.New(errors.ErrInvalidLanguage, "could not detect language from filename").
			WithContext("file", filename).
			WithRemediation("Ensure file has a recognized extension (e.g., .go, .ts, .py, .c, .cpp, .html, .vue, .json, .yaml)")
	case protocol.LangYAML:
		// YAML and JSON do not go through Tree-sitter.
		return r.ParseYAML(ctx, filename, content)
	case protocol.LangJSON:
		return r.ParseJSON(ctx, filename, content)
	}

	// The cache is keyed by content only, so a hit is accepted only when it
	// was parsed as the same language. The LRU cache is thread-safe.
	key := contentHash(content)
	if entry, hit := r.cache.Get(key); hit && entry.Language == lang {
		cachedErrs := extractErrors(entry.Tree.RootNode(), content)
		return &ParseResult{
			Tree:     entry.Tree,
			Language: lang,
			Errors:   cachedErrs,
			Content:  content,
		}, nil
	}

	langParser, err := r.GetParser(lang)
	if err != nil {
		return nil, err
	}

	// Tree-sitter parsers are not thread-safe, so parsing runs under the lock.
	r.mu.Lock()
	tree, err := langParser.ParseCtx(ctx, nil, content)
	r.mu.Unlock()
	if err != nil {
		return nil, errors.NewParseError(string(lang), filename, err)
	}

	syntaxErrs := extractErrors(tree.RootNode(), content)

	// Eviction is handled by the LRU cache itself.
	r.cache.Add(key, &CachedTree{Tree: tree, Language: lang})

	return &ParseResult{
		Tree:     tree,
		Language: lang,
		Errors:   syntaxErrs,
		Content:  content,
	}, nil
}
// extractErrors walks the tree rooted at node and returns one SyntaxError for
// every ERROR or MISSING node, in depth-first order. It returns nil when node
// is nil or the tree has no such nodes. The second parameter is unused.
//
// A MISSING node stands for a token the parser inserted to recover, so its
// message reads "missing <token>"; an ERROR node is reported as "unexpected".
// Locations are converted from Tree-sitter's 0-based points to 1-based
// line/column.
func extractErrors(node *sitter.Node, _ []byte) []SyntaxError {
	var found []SyntaxError

	var visit func(n *sitter.Node)
	visit = func(n *sitter.Node) {
		if n == nil {
			return
		}

		// MISSING takes precedence when a node reports both flags.
		var nodeType, verb string
		switch {
		case n.IsMissing():
			nodeType, verb = "MISSING", "missing"
		case n.IsError():
			nodeType, verb = "ERROR", "unexpected"
		}

		if nodeType != "" {
			start := n.StartPoint()
			found = append(found, SyntaxError{
				Location: protocol.Location{
					Line:   int(start.Row) + 1,
					Column: int(start.Column) + 1,
				},
				Message:  fmt.Sprintf("syntax error: %s %s", verb, n.Type()),
				NodeType: nodeType,
			})
		}

		// Keep descending: error nodes can contain further error nodes.
		childCount := int(n.ChildCount())
		for i := 0; i < childCount; i++ {
			visit(n.Child(i))
		}
	}

	visit(node)
	return found
}
// contentHash returns the cache key for content: its 64-bit xxHash rendered
// as 16 zero-padded lowercase hex digits. The hash is non-cryptographic and
// is used only for cache lookups.
func contentHash(content []byte) string {
	return fmt.Sprintf("%016x", xxhash.Sum64(content))
}
// isBinary reports whether content looks like binary data, defined here as
// containing a NUL byte within its first 8000 bytes. Empty content is not
// binary.
func isBinary(content []byte) bool {
	window := content[:min(8000, len(content))]
	for _, b := range window {
		if b == 0 {
			return true
		}
	}
	return false
}
// Close closes every parser the registry created, resets the parser map and
// purges the tree cache. It holds the write lock for the whole operation.
func (r *Registry) Close() {
	r.mu.Lock()
	defer r.mu.Unlock()

	for lang := range r.parsers {
		r.parsers[lang].Close()
	}
	// Start over with an empty map so closed parsers are never handed out.
	r.parsers = map[protocol.Language]*sitter.Parser{}

	r.cache.Purge()
}
+230
View File
@@ -0,0 +1,230 @@
package parser
import (
"context"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// TestNewRegistry checks that NewRegistry returns a usable, non-nil registry.
func TestNewRegistry(t *testing.T) {
	registry := NewRegistry()
	if registry == nil {
		t.Fatal("expected non-nil registry")
	}
	// Deferred only after the nil check so Close never runs on nil.
	defer registry.Close()
}
// TestGetParser checks that GetParser yields a parser for every supported
// language and an error for protocol.LangUnknown.
func TestGetParser(t *testing.T) {
	reg := NewRegistry()
	defer reg.Close()

	cases := []struct {
		lang    protocol.Language
		wantErr bool
	}{
		{lang: protocol.LangGo},
		{lang: protocol.LangTypeScript},
		{lang: protocol.LangJavaScript},
		{lang: protocol.LangPython},
		{lang: protocol.LangC},
		{lang: protocol.LangCpp},
		{lang: protocol.LangHTML},
		{lang: protocol.LangVue},
		{lang: protocol.LangUnknown, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(string(tc.lang), func(t *testing.T) {
			got, err := reg.GetParser(tc.lang)
			if tc.wantErr {
				if err == nil {
					t.Error("expected error")
				}
				return
			}
			if err != nil {
				t.Errorf("unexpected error: %v", err)
			}
			if got == nil {
				t.Error("expected non-nil parser")
			}
		})
	}
}
// TestParse checks that Parse detects the language from the file extension,
// returns a non-nil tree for each supported language, and fails for an
// unrecognized extension.
func TestParse(t *testing.T) {
	reg := NewRegistry()
	defer reg.Close()

	cases := []struct {
		name     string
		filename string
		content  string
		wantLang protocol.Language
		wantErr  bool
	}{
		{
			name:     "go file",
			filename: "test.go",
			content:  "package main\n\nfunc main() {}\n",
			wantLang: protocol.LangGo,
		},
		{
			name:     "typescript file",
			filename: "test.ts",
			content:  "function hello(): void {}\n",
			wantLang: protocol.LangTypeScript,
		},
		{
			// Interpreted string: in a raw string the \n sequences would be
			// literal backslash-n characters instead of line breaks.
			name:     "react tsx file",
			filename: "Component.tsx",
			content:  "import React from 'react';\n\nexport const Button: React.FC = () => <button className=\"btn\">Click</button>;",
			wantLang: protocol.LangTypeScript,
		},
		{
			name:     "react jsx file",
			filename: "Component.jsx",
			content:  "import React from 'react';\n\nexport const Button = () => <button className=\"btn\">Click</button>;",
			wantLang: protocol.LangJavaScript,
		},
		{
			name:     "python file",
			filename: "test.py",
			content:  "def hello():\n pass\n",
			wantLang: protocol.LangPython,
		},
		{
			name:     "html file",
			filename: "test.html",
			content:  `<!DOCTYPE html><html><head><title>Test</title></head><body><h1 class="text-xl">Hello</h1></body></html>`,
			wantLang: protocol.LangHTML,
		},
		{
			name:     "vue file",
			filename: "Component.vue",
			content:  `<template><div class="container"><h1>{{ title }}</h1></div></template>`,
			wantLang: protocol.LangVue,
		},
		{
			name:     "unknown file",
			filename: "test.txt",
			content:  "hello world",
			wantErr:  true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			result, err := reg.Parse(context.Background(), tc.filename, []byte(tc.content))
			if tc.wantErr {
				if err == nil {
					t.Error("expected error")
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if result.Language != tc.wantLang {
				t.Errorf("expected language %s, got %s", tc.wantLang, result.Language)
			}
			if result.Tree == nil {
				t.Error("expected non-nil tree")
			}
		})
	}
}
// TestParseWithSyntaxErrors checks that Parse succeeds on malformed Go code,
// still returns a tree, and reports at least one syntax error.
func TestParseWithSyntaxErrors(t *testing.T) {
	reg := NewRegistry()
	defer reg.Close()

	// The parameter list is never closed.
	const broken = "package main\n\nfunc main( {}\n"

	result, err := reg.Parse(context.Background(), "test.go", []byte(broken))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Tree-sitter is error-tolerant, so a tree is expected regardless.
	if result.Tree == nil {
		t.Error("expected non-nil tree")
	}
	if len(result.Errors) == 0 {
		t.Error("expected syntax errors to be detected")
	}
}
// TestIsBinary checks isBinary against plain text, content containing a NUL
// byte, and empty content.
func TestIsBinary(t *testing.T) {
	cases := []struct {
		name    string
		content []byte
		want    bool
	}{
		{name: "text file", content: []byte("hello world"), want: false},
		{name: "binary with null byte", content: []byte{0x68, 0x65, 0x6c, 0x00, 0x6f}, want: true},
		{name: "empty file", content: []byte{}, want: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := isBinary(tc.content)
			if got != tc.want {
				t.Errorf("isBinary() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestCaching checks that parsing identical content twice returns the same
// tree pointer, i.e. the second call is served from the cache.
func TestCaching(t *testing.T) {
	reg := NewRegistry()
	defer reg.Close()

	src := []byte("package main\n\nfunc main() {}\n")
	ctx := context.Background()

	first, err := reg.Parse(ctx, "test.go", src)
	if err != nil {
		t.Fatalf("first parse failed: %v", err)
	}

	second, err := reg.Parse(ctx, "test.go", src)
	if err != nil {
		t.Fatalf("second parse failed: %v", err)
	}

	// Pointer equality proves the tree was reused rather than re-parsed.
	if first.Tree != second.Tree {
		t.Error("expected cached tree to be returned")
	}
}
+474
View File
@@ -0,0 +1,474 @@
package parser
import (
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// ExtractSymbols returns the symbols declared in tree, using the extractor
// that matches lang. It returns nil when tree or its root node is nil, and
// for languages that have no extractor (anything other than Go, TypeScript,
// JavaScript, Python, C and C++).
func ExtractSymbols(tree *sitter.Tree, content []byte, lang protocol.Language, filename string) []protocol.Symbol {
	if tree == nil {
		return nil
	}
	root := tree.RootNode()
	if root == nil {
		return nil
	}

	var extract func(*sitter.Node, []byte, string) []protocol.Symbol
	switch lang {
	case protocol.LangGo:
		extract = extractGoSymbols
	case protocol.LangTypeScript, protocol.LangJavaScript:
		extract = extractJSSymbols
	case protocol.LangPython:
		extract = extractPythonSymbols
	case protocol.LangC, protocol.LangCpp:
		extract = extractCSymbols
	default:
		return nil
	}
	return extract(root, content, filename)
}
// extractGoSymbols collects functions, methods, types, constants and
// variables from a Go syntax tree. Functions, methods and types get their doc
// comment attached; constants and variables are added without one.
func extractGoSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
	var symbols []protocol.Symbol

	WalkTree(root, func(n *sitter.Node) bool {
		var sym *protocol.Symbol

		switch n.Type() {
		case "const_declaration", "var_declaration":
			// One declaration can introduce several symbols.
			symbols = append(symbols, extractGoVarConst(n, content, filename)...)
			return true
		case "function_declaration":
			sym = extractGoFunction(n, content, filename)
		case "method_declaration":
			sym = extractGoMethod(n, content, filename)
		case "type_declaration":
			sym = extractGoType(n, content, filename)
		}

		if sym == nil {
			return true
		}
		if doc := ExtractDocComment(n, content, protocol.LangGo); doc != nil {
			sym.Doc = FormatDocComment(doc)
		}
		symbols = append(symbols, *sym)
		return true
	})

	return symbols
}
// extractGoFunction builds a function symbol from a function_declaration
// node, or returns nil when the node has no name field.
func extractGoFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	if ident := n.ChildByFieldName("name"); ident != nil {
		return &protocol.Symbol{
			Name:     GetNodeText(ident, content),
			Kind:     protocol.SymbolFunction,
			Location: NodeLocation(n, filename),
		}
	}
	return nil
}
// extractGoMethod builds a method symbol from a method_declaration node, or
// returns nil when the node has no name field.
//
// When a type_identifier is found inside the receiver, the symbol is named
// "(Recv).Method"; the pointer star is not part of the name, so a receiver
// written as *Server yields "(Server).Start".
func extractGoMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	nameNode := n.ChildByFieldName("name")
	if nameNode == nil {
		return nil
	}

	var recvType string
	if recv := n.ChildByFieldName("receiver"); recv != nil {
		WalkTree(recv, func(cur *sitter.Node) bool {
			if cur.Type() != "type_identifier" {
				return true
			}
			recvType = GetNodeText(cur, content)
			return false
		})
	}

	qualified := GetNodeText(nameNode, content)
	if recvType != "" {
		qualified = "(" + recvType + ")." + qualified
	}

	return &protocol.Symbol{
		Name:     qualified,
		Kind:     protocol.SymbolMethod,
		Location: NodeLocation(n, filename),
	}
}
// extractGoType builds a symbol from the first named type_spec child of a
// type_declaration node. The kind is struct or interface when the spec's type
// is a struct_type or interface_type, and the generic type kind otherwise.
//
// Only one symbol is returned, so later specs in a grouped declaration are
// not reported. It returns nil when no named type_spec is present.
func extractGoType(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	total := int(n.NamedChildCount())
	for idx := 0; idx < total; idx++ {
		spec := n.NamedChild(idx)
		if spec == nil || spec.Type() != "type_spec" {
			continue
		}

		nameNode := spec.ChildByFieldName("name")
		if nameNode == nil {
			continue
		}

		symKind := protocol.SymbolType
		if underlying := spec.ChildByFieldName("type"); underlying != nil {
			switch underlying.Type() {
			case "struct_type":
				symKind = protocol.SymbolStruct
			case "interface_type":
				symKind = protocol.SymbolInterface
			}
		}

		// The location is that of the spec, not the whole declaration.
		return &protocol.Symbol{
			Name:     GetNodeText(nameNode, content),
			Kind:     symKind,
			Location: NodeLocation(spec, filename),
		}
	}
	return nil
}
// extractGoVarConst returns one symbol per const_spec or var_spec found under
// a const_declaration or var_declaration node. The kind follows the
// declaration: constant for const_declaration, variable otherwise. Specs
// without a name field are skipped.
func extractGoVarConst(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
	symKind := protocol.SymbolVariable
	if n.Type() == "const_declaration" {
		symKind = protocol.SymbolConstant
	}

	var collected []protocol.Symbol
	WalkTree(n, func(cur *sitter.Node) bool {
		switch cur.Type() {
		case "const_spec", "var_spec":
			if ident := cur.ChildByFieldName("name"); ident != nil {
				collected = append(collected, protocol.Symbol{
					Name:     GetNodeText(ident, content),
					Kind:     symKind,
					Location: NodeLocation(cur, filename),
				})
			}
		}
		return true
	})
	return collected
}
// extractJSSymbols collects functions, classes, methods, variables,
// interfaces and type aliases from a JavaScript or TypeScript syntax tree.
// Doc comments are looked up with protocol.LangJavaScript for both languages;
// variables are added without a doc comment.
func extractJSSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
	var symbols []protocol.Symbol

	WalkTree(root, func(n *sitter.Node) bool {
		var sym *protocol.Symbol

		switch n.Type() {
		case "lexical_declaration", "variable_declaration":
			// A single declaration may declare several variables.
			symbols = append(symbols, extractJSVariable(n, content, filename)...)
			return true
		case "function_declaration":
			sym = extractJSFunction(n, content, filename)
		case "class_declaration":
			sym = extractJSClass(n, content, filename)
		case "method_definition":
			sym = extractJSMethod(n, content, filename)
		case "interface_declaration":
			sym = extractTSInterface(n, content, filename)
		case "type_alias_declaration":
			sym = extractTSTypeAlias(n, content, filename)
		}

		if sym == nil {
			return true
		}
		if doc := ExtractDocComment(n, content, protocol.LangJavaScript); doc != nil {
			sym.Doc = FormatDocComment(doc)
		}
		symbols = append(symbols, *sym)
		return true
	})

	return symbols
}
// extractJSFunction builds a function symbol from a function_declaration
// node, or returns nil when the node has no name field.
func extractJSFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	ident := n.ChildByFieldName("name")
	if ident == nil {
		return nil
	}

	sym := protocol.Symbol{
		Name:     GetNodeText(ident, content),
		Kind:     protocol.SymbolFunction,
		Location: NodeLocation(n, filename),
	}
	return &sym
}
// extractJSClass builds a class symbol from a class_declaration node, or
// returns nil when the node has no name field.
func extractJSClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	if ident := n.ChildByFieldName("name"); ident != nil {
		return &protocol.Symbol{
			Name:     GetNodeText(ident, content),
			Kind:     protocol.SymbolClass,
			Location: NodeLocation(n, filename),
		}
	}
	return nil
}
// extractJSMethod builds a method symbol from a method_definition node, or
// returns nil when the node has no name field. The name is the bare method
// name; the enclosing class is not included.
func extractJSMethod(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	ident := n.ChildByFieldName("name")
	if ident == nil {
		return nil
	}

	sym := protocol.Symbol{
		Name:     GetNodeText(ident, content),
		Kind:     protocol.SymbolMethod,
		Location: NodeLocation(n, filename),
	}
	return &sym
}
// extractJSVariable returns one variable symbol per variable_declarator found
// under a lexical_declaration or variable_declaration node. Declarators
// without a name field are skipped. const, let and var all map to the
// variable kind.
func extractJSVariable(n *sitter.Node, content []byte, filename string) []protocol.Symbol {
	var collected []protocol.Symbol

	WalkTree(n, func(cur *sitter.Node) bool {
		if cur.Type() != "variable_declarator" {
			return true
		}
		if ident := cur.ChildByFieldName("name"); ident != nil {
			collected = append(collected, protocol.Symbol{
				Name:     GetNodeText(ident, content),
				Kind:     protocol.SymbolVariable,
				Location: NodeLocation(cur, filename),
			})
		}
		return true
	})

	return collected
}
// extractTSInterface builds an interface symbol from an
// interface_declaration node, or returns nil when the node has no name field.
func extractTSInterface(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	if ident := n.ChildByFieldName("name"); ident != nil {
		return &protocol.Symbol{
			Name:     GetNodeText(ident, content),
			Kind:     protocol.SymbolInterface,
			Location: NodeLocation(n, filename),
		}
	}
	return nil
}
// extractTSTypeAlias builds a type symbol from a type_alias_declaration node,
// or returns nil when the node has no name field.
func extractTSTypeAlias(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	ident := n.ChildByFieldName("name")
	if ident == nil {
		return nil
	}

	sym := protocol.Symbol{
		Name:     GetNodeText(ident, content),
		Kind:     protocol.SymbolType,
		Location: NodeLocation(n, filename),
	}
	return &sym
}
// extractPythonSymbols collects function and class definitions from a Python
// syntax tree and attaches each symbol's doc comment when one is found.
func extractPythonSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
	var symbols []protocol.Symbol

	WalkTree(root, func(n *sitter.Node) bool {
		var sym *protocol.Symbol

		switch n.Type() {
		case "function_definition":
			sym = extractPythonFunction(n, content, filename)
		case "class_definition":
			sym = extractPythonClass(n, content, filename)
		}
		if sym == nil {
			return true
		}

		if doc := ExtractDocComment(n, content, protocol.LangPython); doc != nil {
			sym.Doc = FormatDocComment(doc)
		}
		symbols = append(symbols, *sym)
		return true
	})

	return symbols
}
// extractPythonFunction builds a symbol from a function_definition node, or
// returns nil when the node has no name field.
//
// The symbol is a method when the definition sits directly in the body block
// of a class_definition, and a function otherwise. A decorated definition is
// wrapped in a decorated_definition node, so that wrapper is skipped first;
// without this, methods carrying decorators such as @staticmethod or
// @property would be classified as plain functions.
func extractPythonFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	nameNode := n.ChildByFieldName("name")
	if nameNode == nil {
		return nil
	}

	container := n.Parent()
	if container != nil && container.Type() == "decorated_definition" {
		container = container.Parent()
	}

	kind := protocol.SymbolFunction
	if container != nil && container.Type() == "block" {
		if owner := container.Parent(); owner != nil && owner.Type() == "class_definition" {
			kind = protocol.SymbolMethod
		}
	}

	// The location stays that of the function_definition itself, excluding
	// any decorators.
	return &protocol.Symbol{
		Name:     GetNodeText(nameNode, content),
		Kind:     kind,
		Location: NodeLocation(n, filename),
	}
}
// extractPythonClass builds a class symbol from a class_definition node, or
// returns nil when the node has no name field.
func extractPythonClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	if ident := n.ChildByFieldName("name"); ident != nil {
		return &protocol.Symbol{
			Name:     GetNodeText(ident, content),
			Kind:     protocol.SymbolClass,
			Location: NodeLocation(n, filename),
		}
	}
	return nil
}
// extractCSymbols collects function definitions, function declarations,
// structs and classes from a C or C++ syntax tree. Doc comments are looked up
// with protocol.LangC for both languages.
func extractCSymbols(root *sitter.Node, content []byte, filename string) []protocol.Symbol {
	var symbols []protocol.Symbol

	WalkTree(root, func(n *sitter.Node) bool {
		var sym *protocol.Symbol

		switch n.Type() {
		case "function_definition":
			sym = extractCFunction(n, content, filename)
		case "struct_specifier":
			sym = extractCStruct(n, content, filename)
		case "class_specifier":
			sym = extractCppClass(n, content, filename)
		case "declaration":
			// A declaration is only of interest when it declares a
			// function; plain variable declarations are ignored.
			if hasFunctionDeclarator(n) {
				sym = extractCFunctionDecl(n, content, filename)
			}
		}
		if sym == nil {
			return true
		}

		if doc := ExtractDocComment(n, content, protocol.LangC); doc != nil {
			sym.Doc = FormatDocComment(doc)
		}
		symbols = append(symbols, *sym)
		return true
	})

	return symbols
}
// extractCFunction builds a function symbol from a function_definition node.
// The name is taken from an identifier found inside the declarator field. It
// returns nil when there is no declarator or no identifier text.
func extractCFunction(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	decl := n.ChildByFieldName("declarator")
	if decl == nil {
		return nil
	}

	var fnName string
	WalkTree(decl, func(cur *sitter.Node) bool {
		if cur.Type() != "identifier" {
			return true
		}
		fnName = GetNodeText(cur, content)
		return false
	})

	if fnName == "" {
		return nil
	}
	return &protocol.Symbol{
		Name:     fnName,
		Kind:     protocol.SymbolFunction,
		Location: NodeLocation(n, filename),
	}
}
// extractCStruct builds a struct symbol from a struct_specifier node, or
// returns nil when the node has no name field (e.g. an anonymous struct).
func extractCStruct(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	ident := n.ChildByFieldName("name")
	if ident == nil {
		return nil
	}

	sym := protocol.Symbol{
		Name:     GetNodeText(ident, content),
		Kind:     protocol.SymbolStruct,
		Location: NodeLocation(n, filename),
	}
	return &sym
}
// extractCppClass builds a class symbol from a class_specifier node, or
// returns nil when the node has no name field.
func extractCppClass(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	if ident := n.ChildByFieldName("name"); ident != nil {
		return &protocol.Symbol{
			Name:     GetNodeText(ident, content),
			Kind:     protocol.SymbolClass,
			Location: NodeLocation(n, filename),
		}
	}
	return nil
}
// extractCFunctionDecl builds a function symbol from a declaration node that
// declares a function (a prototype). The name is taken from an identifier
// found inside the declarator field. It returns nil when there is no
// declarator or no identifier text.
func extractCFunctionDecl(n *sitter.Node, content []byte, filename string) *protocol.Symbol {
	decl := n.ChildByFieldName("declarator")
	if decl == nil {
		return nil
	}

	protoName := ""
	WalkTree(decl, func(cur *sitter.Node) bool {
		isIdent := cur.Type() == "identifier"
		if isIdent {
			protoName = GetNodeText(cur, content)
		}
		return !isIdent
	})

	if protoName != "" {
		return &protocol.Symbol{
			Name:     protoName,
			Kind:     protocol.SymbolFunction,
			Location: NodeLocation(n, filename),
		}
	}
	return nil
}
// hasFunctionDeclarator reports whether n or any node beneath it visited by
// WalkTree is a function_declarator.
func hasFunctionDeclarator(n *sitter.Node) bool {
	var seen bool
	WalkTree(n, func(cur *sitter.Node) bool {
		if cur.Type() != "function_declarator" {
			return true
		}
		seen = true
		return false
	})
	return seen
}
+226
View File
@@ -0,0 +1,226 @@
package parser
import (
"context"
"testing"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// TestExtractGoSymbols parses a small Go file and checks that the function,
// struct, method, constant and variable it declares are all extracted with
// the expected kinds. Symbols not listed in the expectations are ignored.
func TestExtractGoSymbols(t *testing.T) {
	reg := NewRegistry()
	defer reg.Close()

	content := `package main
// Hello prints a greeting
func Hello() {
println("hello")
}
// Server handles requests
type Server struct {
Port int
}
// Start starts the server
func (s *Server) Start() error {
return nil
}
const MaxConnections = 100
var globalVar = "test"
`
	src := []byte(content)

	result, err := reg.Parse(context.Background(), "test.go", src)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	symbols := ExtractSymbols(result.Tree, src, protocol.LangGo, "test.go")

	want := map[string]protocol.SymbolKind{
		"Hello":          protocol.SymbolFunction,
		"Server":         protocol.SymbolStruct,
		"(Server).Start": protocol.SymbolMethod,
		"MaxConnections": protocol.SymbolConstant,
		"globalVar":      protocol.SymbolVariable,
	}

	seen := make(map[string]bool)
	for _, sym := range symbols {
		wantKind, tracked := want[sym.Name]
		if !tracked {
			continue
		}
		seen[sym.Name] = true
		if sym.Kind != wantKind {
			t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, wantKind, sym.Kind)
		}
	}

	for name := range want {
		if !seen[name] {
			t.Errorf("expected to find symbol %s", name)
		}
	}
}
// TestExtractJSSymbols parses a small JavaScript file and checks that the
// function, class and both top-level variables are extracted with the
// expected kinds. Symbols not listed in the expectations are ignored.
func TestExtractJSSymbols(t *testing.T) {
	reg := NewRegistry()
	defer reg.Close()

	content := `
function greet(name) {
console.log("Hello, " + name);
}
class User {
constructor(name) {
this.name = name;
}
getName() {
return this.name;
}
}
const MAX_USERS = 100;
let currentUser = null;
`
	src := []byte(content)

	result, err := reg.Parse(context.Background(), "test.js", src)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	symbols := ExtractSymbols(result.Tree, src, protocol.LangJavaScript, "test.js")

	want := map[string]protocol.SymbolKind{
		"greet":       protocol.SymbolFunction,
		"User":        protocol.SymbolClass,
		"MAX_USERS":   protocol.SymbolVariable,
		"currentUser": protocol.SymbolVariable,
	}

	seen := make(map[string]bool)
	for _, sym := range symbols {
		wantKind, tracked := want[sym.Name]
		if !tracked {
			continue
		}
		seen[sym.Name] = true
		if sym.Kind != wantKind {
			t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, wantKind, sym.Kind)
		}
	}

	for name := range want {
		if !seen[name] {
			t.Errorf("expected to find symbol %s", name)
		}
	}
}
// TestExtractPythonSymbols parses a small Python module and checks that the
// top-level function and class are extracted, and that the two functions
// belonging to the class are reported as methods.
func TestExtractPythonSymbols(t *testing.T) {
r := NewRegistry()
defer r.Close()
// Fixture: one module-level function and one class with two methods.
content := `
def greet(name):
"""Greet a person by name."""
print(f"Hello, {name}")
class User:
"""Represents a user."""
def __init__(self, name):
self.name = name
def get_name(self):
return self.name
`
ctx := context.Background()
result, err := r.Parse(ctx, "test.py", []byte(content))
if err != nil {
t.Fatalf("parse failed: %v", err)
}
symbols := ExtractSymbols(result.Tree, []byte(content), protocol.LangPython, "test.py")
// Expected name -> kind; extracted symbols not listed here are ignored.
expectedSymbols := map[string]protocol.SymbolKind{
"greet": protocol.SymbolFunction,
"User": protocol.SymbolClass,
"__init__": protocol.SymbolMethod,
"get_name": protocol.SymbolMethod,
}
// found records which expected names were actually seen.
found := make(map[string]bool)
for _, sym := range symbols {
if expectedKind, ok := expectedSymbols[sym.Name]; ok {
found[sym.Name] = true
if sym.Kind != expectedKind {
t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, expectedKind, sym.Kind)
}
}
}
// Every expected symbol must have been extracted at least once.
for name := range expectedSymbols {
if !found[name] {
t.Errorf("expected to find symbol %s", name)
}
}
}
// TestExtractCSymbols parses a small C file and checks that the Point struct
// and the main function are extracted with the expected kinds. Other
// extracted symbols are deliberately not asserted on.
func TestExtractCSymbols(t *testing.T) {
	reg := NewRegistry()
	defer reg.Close()

	content := `
#include <stdio.h>
struct Point {
int x;
int y;
};
void print_point(struct Point p) {
printf("(%d, %d)\n", p.x, p.y);
}
int main() {
struct Point p = {1, 2};
print_point(p);
return 0;
}
`
	src := []byte(content)

	result, err := reg.Parse(context.Background(), "test.c", src)
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}

	symbols := ExtractSymbols(result.Tree, src, protocol.LangC, "test.c")

	// C symbol extraction is complex, so only main and Point are required.
	want := map[string]protocol.SymbolKind{
		"Point": protocol.SymbolStruct,
		"main":  protocol.SymbolFunction,
	}

	seen := make(map[string]bool)
	for _, sym := range symbols {
		wantKind, tracked := want[sym.Name]
		if !tracked {
			continue
		}
		seen[sym.Name] = true
		if sym.Kind != wantKind {
			t.Errorf("symbol %s: expected kind %s, got %s", sym.Name, wantKind, sym.Kind)
		}
	}

	for name := range want {
		if !seen[name] {
			t.Errorf("expected to find symbol %s", name)
		}
	}
}
+195
View File
@@ -0,0 +1,195 @@
// Package parser provides YAML and JSON parsing with AST support.
package parser
import (
"context"
"encoding/json"
"fmt"
"gopkg.in/yaml.v3"
"github.com/lukaszraczylo/mcp-filepuff/pkg/errors"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
sitter "github.com/smacker/go-tree-sitter"
)
// YAMLNode wraps yaml.Node to provide tree-sitter-like interface
type YAMLNode struct {
// The embedded node supplies the YAML structure (kind, value, position).
*yaml.Node
// Content shadows the embedded node's own Content field (its child nodes);
// reach the children through the explicit Node.Content selector.
// NOTE(review): presumably the raw source bytes the node was parsed from -
// nothing in the visible code sets it; confirm against callers.
Content []byte
}
// JSONNode represents a JSON AST node
// NOTE(review): no code in the visible part of this file constructs or reads
// this type, so the field notes below are assumptions to confirm.
type JSONNode struct {
// Value presumably holds the decoded Go value for this node.
Value any
// Type presumably names the JSON kind of the node.
Type string
// Children presumably holds nested nodes for objects and arrays.
Children []*JSONNode
// Line and Column presumably give the node's position in the source;
// whether they are 0- or 1-based is not established here.
Line int
Column int
}
// ParseYAML validates content as YAML and returns a ParseResult for it.
//
// YAML does not go through Tree-sitter, so the result's Tree is always nil
// and its Errors slice is empty; malformed YAML is reported through the
// returned error instead. Content larger than MaxFileSize is rejected. ctx is
// currently unused.
func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
	if size := len(content); size > MaxFileSize {
		return nil, errors.NewFileTooLarge(filename, int64(size), MaxFileSize)
	}

	// The decoded document is only used to prove the input is well-formed.
	var doc yaml.Node
	if err := yaml.Unmarshal(content, &doc); err != nil {
		return nil, errors.NewParseError("yaml", filename, err)
	}

	result := &ParseResult{
		Language: protocol.LangYAML,
		Errors:   extractYAMLErrors(),
		Content:  content,
	}
	return result, nil
}
// ParseJSON validates content as JSON and returns a ParseResult for it.
//
// JSON does not go through Tree-sitter, so the result's Tree is always nil
// and its Errors slice is empty; malformed JSON is reported through the
// returned error instead. Content larger than MaxFileSize is rejected. ctx is
// currently unused.
func (r *Registry) ParseJSON(ctx context.Context, filename string, content []byte) (*ParseResult, error) {
	if size := len(content); size > MaxFileSize {
		return nil, errors.NewFileTooLarge(filename, int64(size), MaxFileSize)
	}

	// Decoding is done purely to validate the syntax.
	var decoded any
	if err := json.Unmarshal(content, &decoded); err != nil {
		return nil, errors.NewParseError("json", filename, err)
	}

	result := &ParseResult{
		Language: protocol.LangJSON,
		Errors:   []SyntaxError{},
		Content:  content,
	}
	return result, nil
}
// extractYAMLErrors returns the syntax errors of a successfully parsed YAML
// document, which is always an empty, non-nil slice: a document that reaches
// this point has already passed yaml.Unmarshal. It is a hook where semantic
// validation could be added later.
func extractYAMLErrors() []SyntaxError {
	return make([]SyntaxError, 0)
}
// WalkYAML visits node and its descendants depth-first, parent before
// children. When fn returns false for a node, that node's children are not
// visited; the walk of its siblings continues. A nil node is ignored.
func WalkYAML(node *yaml.Node, fn func(*yaml.Node) bool) {
	if node == nil {
		return
	}
	if descend := fn(node); !descend {
		return
	}
	for _, kid := range node.Content {
		WalkYAML(kid, fn)
	}
}
// GetYAMLNodeText returns the Value of a YAML node. For a document node it
// returns the text of the first child, or "" when the document is empty. It
// returns "" for a nil node or an unrecognized kind.
func GetYAMLNodeText(node *yaml.Node) string {
	if node == nil {
		return ""
	}

	switch node.Kind {
	case yaml.DocumentNode:
		if len(node.Content) == 0 {
			return ""
		}
		return GetYAMLNodeText(node.Content[0])
	case yaml.MappingNode, yaml.SequenceNode, yaml.ScalarNode, yaml.AliasNode:
		return node.Value
	}
	return ""
}
// GetYAMLNodeLocation returns the line and column recorded on a YAML node,
// or line 1, column 1 when node is nil.
func GetYAMLNodeLocation(node *yaml.Node) protocol.Location {
	loc := protocol.Location{Line: 1, Column: 1}
	if node != nil {
		loc.Line, loc.Column = node.Line, node.Column
	}
	return loc
}
// QueryYAML parses content and returns every node whose Value or Tag equals
// query exactly, in document order.
//
// No path evaluation is performed: a query such as "$.metadata.name" is
// compared literally against node values and tags, so it matches only a node
// holding that exact string. Use the bare key or value (e.g. "name") to find
// nodes. It returns an error when content is not valid YAML.
func QueryYAML(content []byte, query string) ([]*yaml.Node, error) {
	var doc yaml.Node
	if err := yaml.Unmarshal(content, &doc); err != nil {
		return nil, fmt.Errorf("failed to parse YAML: %w", err)
	}

	var matches []*yaml.Node
	WalkYAML(&doc, func(cur *yaml.Node) bool {
		switch query {
		case cur.Value, cur.Tag:
			matches = append(matches, cur)
		}
		return true
	})
	return matches, nil
}
// QueryJSON decodes content and returns the whole decoded document as the
// single element of the result slice. The query argument is currently
// ignored; this is a placeholder for future JSONPath support. It returns an
// error when content is not valid JSON.
func QueryJSON(content []byte, query string) ([]any, error) {
	var decoded any
	err := json.Unmarshal(content, &decoded)
	if err != nil {
		return nil, fmt.Errorf("failed to parse JSON: %w", err)
	}
	return []any{decoded}, nil
}
// ValidateYAML reports whether content is well-formed YAML. It returns nil on
// success and a wrapped decoding error otherwise. The decoded document is
// discarded.
func ValidateYAML(content []byte) error {
	var discard yaml.Node
	err := yaml.Unmarshal(content, &discard)
	if err == nil {
		return nil
	}
	return fmt.Errorf("YAML validation failed: %w", err)
}
// ValidateJSON reports whether content is well-formed JSON. It returns nil on
// success and a wrapped decoding error otherwise. The decoded value is
// discarded.
func ValidateJSON(content []byte) error {
	var discard any
	err := json.Unmarshal(content, &discard)
	if err == nil {
		return nil
	}
	return fmt.Errorf("JSON validation failed: %w", err)
}
// ToSitterTree is a placeholder for YAML/JSON nodes: it unconditionally
// returns nil and never reads the receiver.
// These formats don't use tree-sitter; the method is kept for interface
// compatibility, so callers always receive a nil tree from it.
func (yn *YAMLNode) ToSitterTree() *sitter.Tree {
return nil
}
+283
View File
@@ -0,0 +1,283 @@
package parser
import (
"context"
"testing"
"gopkg.in/yaml.v3"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
)
// TestParseYAML checks Registry.ParseYAML against well-formed and malformed
// documents: valid input must yield a non-nil result tagged protocol.LangYAML
// with no recorded syntax errors, and malformed input must return an error.
//
// Fixtures are interpreted string literals with explicit "\n" and explicit
// indentation. YAML is whitespace-sensitive, so the fixtures must not depend
// on how the surrounding Go source happens to be indented, and a "\n" inside
// a raw (backtick) string is a literal backslash followed by 'n', not a line
// break.
func TestParseYAML(t *testing.T) {
	registry := NewRegistry()
	defer registry.Close()

	tests := []struct {
		name        string
		content     string
		shouldError bool
	}{
		{
			name:        "valid simple YAML",
			content:     "name: test\nversion: 1.0.0\nenabled: true",
			shouldError: false,
		},
		{
			name: "valid nested YAML",
			content: "metadata:\n" +
				"  name: test-app\n" +
				"  namespace: default\n" +
				"spec:\n" +
				"  replicas: 3\n" +
				"  selector:\n" +
				"    matchLabels:\n" +
				"      app: test",
			shouldError: false,
		},
		{
			name: "valid list YAML",
			content: "items:\n" +
				"  - name: item1\n" +
				"    value: 100\n" +
				"  - name: item2\n" +
				"    value: 200",
			shouldError: false,
		},
		{
			name: "invalid YAML - bad syntax",
			// Real line breaks: a stray indented key under a scalar value
			// followed by an unterminated flow sequence.
			content:     "name: test\n bad: indent\n wrong: [unclosed",
			shouldError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := registry.ParseYAML(context.Background(), "test.yaml", []byte(tt.content))

			if tt.shouldError {
				if err == nil {
					t.Error("expected error but got none")
				}
				return
			}

			// Stop at the first failed precondition so the checks below
			// never dereference a nil result.
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if result == nil {
				t.Fatal("expected result but got nil")
			}

			if result.Language != protocol.LangYAML {
				t.Errorf("expected language YAML, got %s", result.Language)
			}
			if len(result.Errors) > 0 {
				t.Errorf("expected no syntax errors, got %d", len(result.Errors))
			}
		})
	}
}
// TestParseJSON checks Registry.ParseJSON against well-formed and malformed
// documents: valid input must yield a non-nil result tagged protocol.LangJSON
// with no recorded syntax errors, and malformed input must return an error.
func TestParseJSON(t *testing.T) {
	registry := NewRegistry()
	defer registry.Close()

	type jsonCase struct {
		name        string
		content     string
		shouldError bool
	}

	cases := []jsonCase{
		{name: "valid simple JSON", content: `{"name": "test", "version": "1.0.0", "enabled": true}`},
		{name: "valid nested JSON", content: `{
"metadata": {
"name": "test-app",
"namespace": "default"
},
"spec": {
"replicas": 3,
"selector": {
"matchLabels": {
"app": "test"
}
}
}
}`},
		{name: "valid array JSON", content: `[{"name": "item1", "value": 100}, {"name": "item2", "value": 200}]`},
		{name: "invalid JSON - unclosed brace", content: `{"name": "test", "value": 100`, shouldError: true},
		{name: "invalid JSON - trailing comma", content: `{"name": "test", "value": 100,}`, shouldError: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			result, err := registry.ParseJSON(context.Background(), "test.json", []byte(tc.content))

			// Preconditions: each arm ends the subtest early.
			switch {
			case tc.shouldError && err == nil:
				t.Error("expected error but got none")
				return
			case tc.shouldError:
				// The expected failure happened; nothing more to inspect.
				return
			case err != nil:
				t.Errorf("unexpected error: %v", err)
				return
			case result == nil:
				t.Error("expected result but got nil")
				return
			}

			if result.Language != protocol.LangJSON {
				t.Errorf("expected language JSON, got %s", result.Language)
			}
			if n := len(result.Errors); n > 0 {
				t.Errorf("expected no syntax errors, got %d", n)
			}
		})
	}
}
// TestRegistryParse_YAML_JSON checks that Registry.Parse tags its result with
// the expected language for the .yaml, .json and .yml file names.
//
// Each file name runs as its own subtest and aborts with Fatalf when Parse
// returns an error, so the Language assertion never inspects the result of a
// failed parse and one failing case does not hide the others.
func TestRegistryParse_YAML_JSON(t *testing.T) {
	registry := NewRegistry()
	defer registry.Close()

	ctx := context.Background()
	yamlContent := []byte("name: test\nversion: 1.0.0")
	jsonContent := []byte(`{"name": "test", "version": "1.0.0"}`)

	// YAML through the main Parse method.
	t.Run("yaml extension", func(t *testing.T) {
		yamlResult, err := registry.Parse(ctx, "config.yaml", yamlContent)
		if err != nil {
			t.Fatalf("failed to parse YAML: %v", err)
		}
		if yamlResult.Language != protocol.LangYAML {
			t.Errorf("expected YAML language, got %s", yamlResult.Language)
		}
	})

	// JSON through the main Parse method.
	t.Run("json extension", func(t *testing.T) {
		jsonResult, err := registry.Parse(ctx, "config.json", jsonContent)
		if err != nil {
			t.Fatalf("failed to parse JSON: %v", err)
		}
		if jsonResult.Language != protocol.LangJSON {
			t.Errorf("expected JSON language, got %s", jsonResult.Language)
		}
	})

	// The short .yml extension must also be recognised as YAML.
	t.Run("yml extension", func(t *testing.T) {
		ymlResult, err := registry.Parse(ctx, "config.yml", yamlContent)
		if err != nil {
			t.Fatalf("failed to parse .yml: %v", err)
		}
		if ymlResult.Language != protocol.LangYAML {
			t.Errorf("expected YAML language for .yml extension, got %s", ymlResult.Language)
		}
	})
}
// TestWalkYAML is a smoke test: it parses a small document and checks that
// WalkYAML invokes the visitor at least once.
func TestWalkYAML(t *testing.T) {
	const doc = `metadata:
name: test
labels:
app: myapp
env: prod`

	var root yaml.Node
	err := yaml.Unmarshal([]byte(doc), &root)
	if err != nil {
		t.Fatalf("failed to parse YAML: %v", err)
	}

	// Count every callback invocation; always returning true mirrors a
	// visitor that never asks to stop.
	visited := 0
	countVisit := func(*yaml.Node) bool {
		visited++
		return true
	}
	WalkYAML(&root, countVisit)

	if visited < 1 {
		t.Error("expected to visit nodes, but count is 0")
	}
}
// TestValidateYAML checks that ValidateYAML returns an error exactly when the
// fixture is marked as malformed.
func TestValidateYAML(t *testing.T) {
	type yamlCase struct {
		name        string
		content     []byte
		shouldError bool
	}

	cases := []yamlCase{
		{name: "valid YAML", content: []byte("name: test\nvalue: 100")},
		{name: "invalid YAML", content: []byte("name: test\n bad:\n[unclosed"), shouldError: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateYAML(tc.content)
			gotError := err != nil
			if gotError != tc.shouldError {
				t.Errorf("ValidateYAML() error = %v, shouldError = %v", err, tc.shouldError)
			}
		})
	}
}
// TestValidateJSON checks that ValidateJSON returns an error exactly when the
// fixture is marked as malformed.
func TestValidateJSON(t *testing.T) {
	for _, tc := range []struct {
		name        string
		content     []byte
		shouldError bool
	}{
		{"valid JSON", []byte(`{"name": "test", "value": 100}`), false},
		{"invalid JSON", []byte(`{"name": "test", "value": 100`), true},
	} {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateJSON(tc.content)
			// Fail when "an error occurred" disagrees with the expectation.
			if failed := err != nil; failed != tc.shouldError {
				t.Errorf("ValidateJSON() error = %v, shouldError = %v", err, tc.shouldError)
			}
		})
	}
}