Files
filepuff-mcp/internal/server/handlers_ast.go
T
2026-02-22 15:24:48 +00:00

244 lines
5.9 KiB
Go

// Package server implements the MCP server for file operations.
package server
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/lukaszraczylo/mcp-filepuff/internal/parser"
"github.com/lukaszraczylo/mcp-filepuff/internal/query"
"github.com/lukaszraczylo/mcp-filepuff/pkg/protocol"
"github.com/mark3labs/mcp-go/mcp"
)
// handleASTQuery handles the ast_query tool.
//
// Required arguments are "pattern" and "language". Optional arguments are
// "name_matches", "name_exact", "kind_in", "max_results" (default 100) and
// "paths" (default: the configured workspace root). Every allowed path is
// walked, each file whose name ends with one of the language's extensions is
// parsed and matched, and the collected matches are returned as formatted
// text.
//
// Failures are reported as tool-level error results; the returned Go error
// is always nil. Files that cannot be read, parsed or matched are skipped so
// that a single bad file does not abort the whole query.
func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	// Acquire semaphore to limit concurrent queries (prevents CPU exhaustion).
	select {
	case s.querySem <- struct{}{}:
		defer func() { <-s.querySem }()
	case <-ctx.Done():
		return mcp.NewToolResultError("request cancelled"), nil
	}

	pattern, err := request.RequireString("pattern")
	if err != nil {
		return mcp.NewToolResultError("pattern is required"), nil
	}
	language, err := request.RequireString("language")
	if err != nil {
		return mcp.NewToolResultError("language is required"), nil
	}

	// Build query.
	astQuery := &query.ASTQuery{
		Pattern:  pattern,
		Language: language,
		Filters: query.QueryFilters{
			NameMatches: request.GetString("name_matches", ""),
			NameExact:   request.GetString("name_exact", ""),
			KindIn:      request.GetStringSlice("kind_in", nil),
		},
	}
	maxResults := request.GetInt("max_results", 100)

	// Default to workspace root if no paths specified.
	paths := request.GetStringSlice("paths", nil)
	if len(paths) == 0 {
		paths = []string{s.cfg.WorkspaceRoot}
	}

	// Find files to search based on language.
	exts := languageToExtensions(language)
	if len(exts) == 0 {
		return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir, rust)", language)), nil
	}

	var allResults []query.MatchResult
	// limitReached reports whether enough matches have been collected.
	// A non-positive max_results means "no limit".
	limitReached := func() bool {
		return maxResults > 0 && len(allResults) >= maxResults
	}

	for _, searchPath := range paths {
		// Do not walk (and parse files in) further paths once the limit is hit.
		if limitReached() {
			break
		}
		// Validate path is within workspace.
		if !s.cfg.IsPathAllowed(searchPath) {
			s.logger.Warn("skipping disallowed path", "path", searchPath)
			continue
		}

		walkErr := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
			// Check for context cancellation.
			select {
			case <-ctx.Done():
				return ctx.Err()
			default:
			}
			if err != nil {
				return nil // Skip entries that could not be stat'ed or listed.
			}
			if info.IsDir() {
				// Skip hidden directories, but never the walk root itself:
				// a root such as "." or a dot-named workspace was explicitly
				// requested and must still be searched.
				if path != searchPath && strings.HasPrefix(info.Name(), ".") {
					return filepath.SkipDir
				}
				return nil
			}

			// Check file extension matches language.
			matched := false
			for _, ext := range exts {
				if strings.HasSuffix(path, ext) {
					matched = true
					break
				}
			}
			if !matched {
				return nil
			}

			// Reject oversized regular files before loading them into memory.
			if info.Mode().IsRegular() && info.Size() > s.cfg.MaxFileSize {
				return nil
			}
			content, err := os.ReadFile(path)
			if err != nil {
				return nil // Skip unreadable files.
			}
			// Re-check after reading: Walk reports symlinks via Lstat, so the
			// size seen above may not be the size of the data actually read.
			if int64(len(content)) > s.cfg.MaxFileSize {
				return nil // Skip large files.
			}

			result, err := s.parser.Parse(ctx, path, content)
			if err != nil {
				return nil // Skip unparseable files.
			}
			matches, err := s.matcher.Match(ctx, astQuery, result.Tree, content, path)
			if err != nil {
				return nil // Skip on error.
			}
			allResults = append(allResults, matches...)

			// Stop if we have enough results.
			if limitReached() {
				return filepath.SkipAll
			}
			return nil
		})

		// A cancelled request is reported the same way as cancellation while
		// waiting for the semaphore, instead of walking the remaining paths
		// and returning partial results as if they were complete.
		if ctx.Err() != nil {
			return mcp.NewToolResultError("request cancelled"), nil
		}
		if walkErr != nil {
			s.logger.Warn("error walking path", "path", searchPath, "error", walkErr)
		}
	}

	// Format and return results.
	output := query.FormatResults(allResults, maxResults)
	return mcp.NewToolResultText(output), nil
}
// generateASTSummary generates a summary of symbols in the file.
//
// The file is parsed and its symbols extracted; the result is a header line
// with the file path, line count and detected language, followed by one line
// per symbol (kind label, name and line number). The path is shown relative
// to the workspace root when the file lies inside it, otherwise as given.
//
// An empty string is returned when parsing fails or no symbols are found, so
// callers can simply omit the summary.
func (s *Server) generateASTSummary(ctx context.Context, path string, content []byte) string {
	result, err := s.parser.Parse(ctx, path, content)
	if err != nil {
		return "" // Silently skip AST if parsing fails.
	}

	lang := protocol.DetectLanguage(path)
	symbols := parser.ExtractSymbols(result.Tree, content, lang, path)
	if len(symbols) == 0 {
		return ""
	}

	// Prefer a workspace-relative path for display.
	relPath := path
	if absPath, err := filepath.Abs(path); err == nil {
		if rel, err := filepath.Rel(s.cfg.WorkspaceRoot, absPath); err == nil {
			// Only a leading ".." path component means the file is outside
			// the workspace; a name that merely starts with two dots (for
			// example "..cache/x.go") is still inside it.
			outside := rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator))
			if !outside {
				relPath = rel
			}
		}
	}

	var sb strings.Builder
	fmt.Fprintf(&sb, "**%s** (%d lines, %s)\n\n", relPath, len(splitLines(string(content))), lang)
	sb.WriteString("Symbols:\n")
	for _, sym := range symbols {
		fmt.Fprintf(&sb, " %s %s L%d\n", symbolKindIcon(sym.Kind), sym.Name, sym.Location.Line)
	}
	return sb.String()
}
// symbolKindIcon maps a symbol kind to the short text label used as its
// prefix in symbol listings. Kinds without a dedicated label yield "sym".
func symbolKindIcon(kind protocol.SymbolKind) string {
	// Start from the fallback label and override it for known kinds.
	label := "sym"
	switch kind {
	case protocol.SymbolFunction:
		label = "func"
	case protocol.SymbolMethod:
		label = "meth"
	case protocol.SymbolClass:
		label = "class"
	case protocol.SymbolStruct:
		label = "struct"
	case protocol.SymbolInterface:
		label = "iface"
	case protocol.SymbolVariable:
		label = "var"
	case protocol.SymbolConstant:
		label = "const"
	case protocol.SymbolType:
		label = "type"
	case protocol.SymbolField:
		label = "field"
	case protocol.SymbolProperty:
		label = "prop"
	case protocol.SymbolModule:
		label = "mod"
	case protocol.SymbolPackage:
		label = "pkg"
	case protocol.SymbolEnum:
		label = "enum"
	case protocol.SymbolTrait:
		label = "trait"
	}
	return label
}
// languageToExtensions returns the file extensions (each including the
// leading dot) associated with a language name. The lookup ignores letter
// case, and "c++" is accepted as an alias for "cpp". A nil slice is returned
// for a language that is not recognised.
func languageToExtensions(language string) []string {
	// Stays nil unless one of the known languages below matches.
	var extensions []string
	switch normalized := strings.ToLower(language); normalized {
	case "go":
		extensions = []string{".go"}
	case "typescript":
		extensions = []string{".ts"}
	case "javascript":
		extensions = []string{".js"}
	case "python":
		extensions = []string{".py"}
	case "c":
		extensions = []string{".c"}
	case "cpp", "c++":
		extensions = []string{".cpp"}
	case "html":
		extensions = []string{".html", ".htm"}
	case "vue":
		extensions = []string{".vue"}
	case "elixir":
		extensions = []string{".ex", ".exs"}
	case "rust":
		extensions = []string{".rs"}
	}
	return extensions
}