From 982c2c8b44928a154d807c444f11ac6dab995cac Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sun, 22 Feb 2026 14:03:54 +0000 Subject: [PATCH] fixup! Update, bugfixes on diff and edit handling --- internal/config/config.go | 19 +++- internal/config/config_test.go | 57 +++++++++- internal/config/security_test.go | 2 +- internal/edit/edit.go | 30 ++++-- internal/edit/edit_test.go | 7 +- internal/lsp/client.go | 23 ++++ internal/lsp/manager.go | 69 +++++++----- internal/lsp/manager_test.go | 26 ++++- internal/parser/cache_test.go | 14 +-- internal/parser/parser.go | 70 +----------- internal/parser/parser_bench_test.go | 5 +- internal/parser/yaml_json.go | 8 +- internal/query/query.go | 14 ++- internal/query/query_test.go | 149 ++++++++++++++++++++++++++ internal/search/search.go | 21 ++++ internal/search/search_test.go | 32 ++++-- internal/server/handlers_ast.go | 39 ++++--- internal/server/handlers_edit.go | 10 ++ internal/server/handlers_file.go | 31 ++++-- internal/server/handlers_lsp.go | 10 +- internal/server/server.go | 2 +- internal/server/server_test.go | 153 +++++++++++++++++++++++++++ internal/util/regex_cache.go | 58 +++++----- 23 files changed, 655 insertions(+), 194 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 904d7e9..38f27da 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -69,13 +69,26 @@ func Load(workspaceRoot string) (*Config, error) { cfg.WorkspaceRoot = cwd } - // Try to load from config file in workspace root + // Try to load from config file in workspace root. + // Save WorkspaceRoot before loading config file so it cannot be overridden. + savedRoot := cfg.WorkspaceRoot configPath := filepath.Join(cfg.WorkspaceRoot, ".mcp-filepuff.json") if data, err := os.ReadFile(configPath); err == nil { if err := json.Unmarshal(data, cfg); err != nil { return nil, err } } + // Restore WorkspaceRoot — config file must not override path guards. + cfg.WorkspaceRoot = savedRoot + + // Clamp size limits to prevent config file from requesting excessive memory. + const maxAllowedSize int64 = 100 * 1024 * 1024 // 100 MB + if cfg.MaxFileSize > maxAllowedSize { + cfg.MaxFileSize = maxAllowedSize + } + if cfg.MaxParseSize > maxAllowedSize { + cfg.MaxParseSize = maxAllowedSize + } // Override from environment variables cfg.loadFromEnv() @@ -173,8 +186,8 @@ func (c *Config) IsPathAllowed(path string) bool { // Check if the path is within workspace (doesn't start with ..) // This prevents both "../" attacks and symlink bypasses - // Also reject empty relative path (which means it's the workspace root itself) - return rel != "." && !strings.HasPrefix(rel, "..") + // The workspace root itself (rel == ".") is a valid, allowed path + return !strings.HasPrefix(rel, "..") } // Validate validates the configuration and returns an error if invalid. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 592b931..6abd86b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -389,8 +389,8 @@ func TestIsPathAllowedEdgeCases(t *testing.T) { { name: "workspace_root_itself", path: tmpDir, - allowed: false, - desc: "workspace root itself should not be allowed", + allowed: true, + desc: "workspace root itself should be allowed", }, { name: "dot_relative", @@ -546,6 +546,59 @@ func TestConfigFileLoadingErrors(t *testing.T) { } } +// TestIsPathAllowed_SymlinkOutsideWorkspace verifies that symlinks pointing +// outside the workspace are rejected (T-01). +func TestIsPathAllowed_SymlinkOutsideWorkspace(t *testing.T) { + // Create two separate temp dirs: one as workspace, one as outside target + workspace, err := os.MkdirTemp("", "mcp-workspace-*") + if err != nil { + t.Fatalf("failed to create workspace dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(workspace) }) + + outside, err := os.MkdirTemp("", "mcp-outside-*") + if err != nil { + t.Fatalf("failed to create outside dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(outside) }) + + // Create a file outside the workspace + outsideFile := filepath.Join(outside, "secret.txt") + if err := os.WriteFile(outsideFile, []byte("secret"), 0o600); err != nil { + t.Fatalf("failed to write outside file: %v", err) + } + + // Create a symlink inside the workspace pointing outside + symlinkPath := filepath.Join(workspace, "escape-link") + if err := os.Symlink(outsideFile, symlinkPath); err != nil { + t.Skip("symlink creation not supported on this system") + } + + cfg := Default() + cfg.WorkspaceRoot = workspace + + // The symlink resolves to a file outside workspace — must be rejected + if cfg.IsPathAllowed(symlinkPath) { + t.Error("symlink pointing outside workspace should NOT be allowed") + } + + // Direct access to the outside file should also be rejected + if cfg.IsPathAllowed(outsideFile) { + t.Error("file outside workspace should NOT be allowed") + } + + // File inside workspace should still be allowed + insideFile := filepath.Join(workspace, "safe.txt") + if !cfg.IsPathAllowed(insideFile) { + t.Error("file inside workspace should be allowed") + } + + // Workspace root itself should be allowed (C-08 fix) + if !cfg.IsPathAllowed(workspace) { + t.Error("workspace root itself should be allowed") + } +} + // Helper function to check if a string contains a substring. func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(substr) == 0 || diff --git a/internal/config/security_test.go b/internal/config/security_test.go index 5bdfe58..a01c6ce 100644 --- a/internal/config/security_test.go +++ b/internal/config/security_test.go @@ -126,7 +126,7 @@ func TestIsPathAllowed_BasicCases(t *testing.T) { { name: "workspace root itself", path: tmpDir, - expected: false, // Empty relative path + expected: true, // Workspace root is a valid, allowed path (needed for ast_query) }, } diff --git a/internal/edit/edit.go b/internal/edit/edit.go index 531bddf..571b2ff 100644 --- a/internal/edit/edit.go +++ b/internal/edit/edit.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "os" + "slices" "strings" "sync" @@ -60,6 +61,7 @@ type EditResult struct { // Engine performs AST-aware edits. type Engine struct { registry *parser.Registry + dmp *diffmatchpatch.DiffMatchPatch fileLocks sync.Map // map[string]*sync.Mutex for per-file locking } @@ -67,6 +69,7 @@ type Engine struct { func NewEngine(registry *parser.Registry) *Engine { return &Engine{ registry: registry, + dmp: diffmatchpatch.New(), fileLocks: sync.Map{}, } } @@ -166,7 +169,7 @@ func (e *Engine) performASTEdit(ctx context.Context, edit *ASTEdit, apply bool) } // Generate diff - diff := generateDiff(string(content), string(newContent), edit.File) + diff := e.generateDiff(string(content), string(newContent), edit.File) result := &EditResult{ Success: true, @@ -225,7 +228,7 @@ func (e *Engine) performTextEdit(_ context.Context, edit *ASTEdit, apply bool) ( } // Generate diff - diff := generateDiff(string(content), string(newContent), edit.File) + diff := e.generateDiff(string(content), string(newContent), edit.File) result := &EditResult{ Success: true, @@ -369,17 +372,18 @@ func sortBySpecificity(nodes []*sitter.Node) []*sitter.Node { return nodes } - // Sort by specificity: named nodes first, then by size (smallest first) result := make([]*sitter.Node, len(nodes)) copy(result, nodes) - for i := 0; i < len(result)-1; i++ { - for j := i + 1; j < len(result); j++ { - if shouldPrefer(result[j], result[i]) { - result[i], result[j] = result[j], result[i] - } + slices.SortFunc(result, func(a, b *sitter.Node) int { + if shouldPrefer(a, b) { + return -1 } - } + if shouldPrefer(b, a) { + return 1 + } + return 0 + }) return result } @@ -566,8 +570,8 @@ func indentContent(content string, indent string) string { // generateDiff creates a unified diff between original and modified content. // Uses line-level Myers diff algorithm for accurate and readable diffs. -func generateDiff(original, modified, filename string) string { - dmp := diffmatchpatch.New() +func (e *Engine) generateDiff(original, modified, filename string) string { + dmp := e.dmp // Use line-level diffing: encode each line as a single character, // diff the encoded strings, then decode back to real lines. @@ -692,6 +696,10 @@ func (e *Engine) findLineRange(content []byte, lineStart, lineEnd int) (start, e } lines := bytes.Split(content, []byte("\n")) + // Trim phantom empty element from trailing newline + if len(lines) > 0 && len(lines[len(lines)-1]) == 0 { + lines = lines[:len(lines)-1] + } totalLines := len(lines) // Convert to 0-indexed diff --git a/internal/edit/edit_test.go b/internal/edit/edit_test.go index af5fef5..8137256 100644 --- a/internal/edit/edit_test.go +++ b/internal/edit/edit_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/lukaszraczylo/mcp-filepuff/internal/parser" + "github.com/sergi/go-diff/diffmatchpatch" sitter "github.com/smacker/go-tree-sitter" ) @@ -394,7 +395,7 @@ func TestGenerateDiff(t *testing.T) { modified := "line1\nmodified\nline3" filename := "test.txt" - diff := generateDiff(original, modified, filename) + diff := (&Engine{dmp: diffmatchpatch.New()}).generateDiff(original, modified, filename) if !strings.Contains(diff, "---") { t.Error("diff should contain --- header") @@ -420,7 +421,7 @@ func TestGenerateDiffLineLevelAccuracy(t *testing.T) { original := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello\")\n}\n" modified := "package main\n\nfunc hello() {\n\tfmt.Println(\"hello world\")\n}\n" - diff := generateDiff(original, modified, "test.go") + diff := (&Engine{dmp: diffmatchpatch.New()}).generateDiff(original, modified, "test.go") // The diff must show whole-line removals and additions if !strings.Contains(diff, "-\tfmt.Println(\"hello\")\n") { @@ -453,7 +454,7 @@ func TestGenerateDiffNoPhantomChanges(t *testing.T) { original := "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\n" modified := "line1\nREPLACED\nline3\nline4\nline5\nline6\nline7\nline8\n" - diff := generateDiff(original, modified, "test.txt") + diff := (&Engine{dmp: diffmatchpatch.New()}).generateDiff(original, modified, "test.txt") // Count changed lines (excluding headers) addCount := 0 diff --git a/internal/lsp/client.go b/internal/lsp/client.go index eecf8a3..5874ac6 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -268,7 +268,11 @@ func (c *Client) send(msg interface{}) error { } // readLoop reads and dispatches messages from the server. +// On exit (for any reason), it drains all pending Call waiters with a +// synthetic error so that goroutines blocked in Call are unblocked. func (c *Client) readLoop() { + defer c.drainPending() + reader := bufio.NewReader(c.stdout) for { @@ -329,6 +333,25 @@ func (c *Client) readLoop() { } } +// drainPending sends a synthetic error response to every pending Call waiter +// so that goroutines blocked in Call are unblocked when readLoop exits. +func (c *Client) drainPending() { + c.mu.Lock() + defer c.mu.Unlock() + + for id, ch := range c.pending { + ch <- &Response{ + JSONRPC: "2.0", + ID: id, + Error: &ResponseError{ + Code: -32603, // InternalError + Message: "LSP client readLoop terminated", + }, + } + delete(c.pending, id) + } +} + // IsRunning returns whether the client is running. func (c *Client) IsRunning() bool { c.runningMu.RLock() diff --git a/internal/lsp/manager.go b/internal/lsp/manager.go index 8a7a7a5..5494553 100644 --- a/internal/lsp/manager.go +++ b/internal/lsp/manager.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "net/url" "os" "os/exec" "path/filepath" @@ -28,6 +29,9 @@ const ( ) // Manager manages LSP servers for different languages. +// +// Lock ordering: m.mu must always be acquired before srv.mu. +// Never acquire m.mu while holding srv.mu. type Manager struct { servers map[protocol.Language]*ManagedServer logger *slog.Logger @@ -36,6 +40,7 @@ type Manager struct { timeout time.Duration idleTimeout time.Duration mu sync.RWMutex + closeOnce sync.Once stopped bool } @@ -163,9 +168,10 @@ func (m *Manager) GetServer(ctx context.Context, lang protocol.Language) (*Manag WithRemediation("Only whitelisted LSP server binaries are allowed for security reasons") } - // Create command + // Create command — use exec.Command (not CommandContext) so the LSP + // subprocess is not killed when the request-scoped context expires. args := append(config.Command[1:], config.Args...) - cmd := exec.CommandContext(ctx, cmdPath, args...) + cmd := exec.Command(cmdPath, args...) cmd.Env = os.Environ() cmd.Dir = m.workspaceRoot // Create client @@ -412,9 +418,13 @@ func (m *Manager) References(ctx context.Context, file string, line, col int, in } // ensureDocumentOpen opens a document if not already open. -func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, file string) error { +// It reads the file content outside the lock (to avoid holding the lock during I/O), +// then holds srv.mu for the entire check-and-send sequence to prevent duplicate didOpen +// notifications from concurrent goroutines. +func (m *Manager) ensureDocumentOpen(_ context.Context, srv *ManagedServer, file string) error { uri := fileToURI(file) + // Quick check under lock — common fast path. srv.mu.Lock() if _, ok := srv.openDocs[uri]; ok { srv.mu.Unlock() @@ -422,15 +432,23 @@ func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, fi } srv.mu.Unlock() - // Read file content + // Read file content outside the lock to avoid holding it during I/O. content, err := os.ReadFile(file) if err != nil { return fmt.Errorf("failed to read file: %w", err) } - // Get language ID langID := languageToLSPID(srv.language) + // Re-acquire lock and re-check to prevent TOCTOU race: two goroutines could + // both pass the fast-path check above and both try to send didOpen. + srv.mu.Lock() + defer srv.mu.Unlock() + + if _, ok := srv.openDocs[uri]; ok { + return nil + } + params := DidOpenTextDocumentParams{ TextDocument: TextDocumentItem{ URI: uri, @@ -444,10 +462,7 @@ func (m *Manager) ensureDocumentOpen(ctx context.Context, srv *ManagedServer, fi return fmt.Errorf("didOpen failed: %w", err) } - srv.mu.Lock() srv.openDocs[uri] = 1 - srv.mu.Unlock() - return nil } @@ -519,26 +534,28 @@ func (m *Manager) reapIdleServers() { } } -// Close shuts down all LSP servers. +// Close shuts down all LSP servers. It is safe to call multiple times. func (m *Manager) Close() error { - close(m.stopReaper) + m.closeOnce.Do(func() { + close(m.stopReaper) - m.mu.Lock() - defer m.mu.Unlock() + m.mu.Lock() + defer m.mu.Unlock() - m.stopped = true + m.stopped = true - for lang, srv := range m.servers { - m.logger.Info("shutting down LSP server", "language", lang) - // Try graceful shutdown - ctx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout) - _, _ = srv.client.Call(ctx, "shutdown", nil) - cancel() - _ = srv.client.Notify("exit", nil) - _ = srv.client.Close() - } + for lang, srv := range m.servers { + m.logger.Info("shutting down LSP server", "language", lang) + // Try graceful shutdown + ctx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout) + _, _ = srv.client.Call(ctx, "shutdown", nil) + cancel() + _ = srv.client.Notify("exit", nil) + _ = srv.client.Close() + } - m.servers = make(map[protocol.Language]*ManagedServer) + m.servers = make(map[protocol.Language]*ManagedServer) + }) return nil } @@ -553,13 +570,13 @@ func (m *Manager) IsAvailable(lang protocol.Language) bool { return err == nil } -// fileToURI converts a file path to a file URI. +// fileToURI converts a file path to a properly percent-encoded file URI. func fileToURI(file string) string { absPath, err := filepath.Abs(file) if err != nil { - return "file://" + file + absPath = file } - return "file://" + absPath + return (&url.URL{Scheme: "file", Path: absPath}).String() } // URIToFile converts a file URI to a file path. diff --git a/internal/lsp/manager_test.go b/internal/lsp/manager_test.go index d2c68d5..e13cee8 100644 --- a/internal/lsp/manager_test.go +++ b/internal/lsp/manager_test.go @@ -189,10 +189,30 @@ func TestManagerGracefulShutdown(t *testing.T) { if !manager.stopped { t.Error("manager should be marked as stopped after Close()") } +} - // Note: We don't test multiple Close() calls because the implementation - // closes the stopReaper channel which can't be closed twice. - // In production, Close() should only be called once during shutdown. +// TestManagerDoubleClose verifies that calling Close() twice does not panic (T-05, C-02). +func TestManagerDoubleClose(t *testing.T) { + tmpDir := t.TempDir() + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + + manager := NewManager(tmpDir, logger) + + // First close should succeed + err := manager.Close() + if err != nil { + t.Errorf("first Close() returned error: %v", err) + } + + // Second close must not panic (C-02 fix wraps close in sync.Once) + err = manager.Close() + if err != nil { + t.Errorf("second Close() returned error: %v", err) + } + + if !manager.stopped { + t.Error("manager should be marked as stopped after double Close()") + } } // TestManagerIdleReaper tests the idle server cleanup mechanism. diff --git a/internal/parser/cache_test.go b/internal/parser/cache_test.go index 79e790f..928380d 100644 --- a/internal/parser/cache_test.go +++ b/internal/parser/cache_test.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "testing" + + "github.com/cespare/xxhash/v2" ) // TestLRUCacheEviction tests that the LRU cache properly evicts old entries. @@ -82,8 +84,8 @@ func TestContentHashCollisionResistance(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - hash1 := contentHash(tc.content1) - hash2 := contentHash(tc.content2) + hash1 := fmt.Sprintf("%016x", xxhash.Sum64(tc.content1)) + hash2 := fmt.Sprintf("%016x", xxhash.Sum64(tc.content2)) if hash1 == hash2 { t.Errorf("Hash collision: %s == %s for different content", hash1, hash2) @@ -96,9 +98,9 @@ func TestContentHashCollisionResistance(t *testing.T) { func TestContentHashConsistency(t *testing.T) { content := []byte("package main\n\nfunc test() {}\n") - hash1 := contentHash(content) - hash2 := contentHash(content) - hash3 := contentHash(content) + hash1 := fmt.Sprintf("%016x", xxhash.Sum64(content)) + hash2 := fmt.Sprintf("%016x", xxhash.Sum64(content)) + hash3 := fmt.Sprintf("%016x", xxhash.Sum64(content)) if hash1 != hash2 || hash2 != hash3 { t.Errorf("Hash inconsistency: %s, %s, %s", hash1, hash2, hash3) @@ -115,7 +117,7 @@ func BenchmarkContentHash_xxHash(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = contentHash(content) + _ = fmt.Sprintf("%016x", xxhash.Sum64(content)) } } diff --git a/internal/parser/parser.go b/internal/parser/parser.go index b2790c0..359abb8 100644 --- a/internal/parser/parser.go +++ b/internal/parser/parser.go @@ -24,9 +24,8 @@ import ( "github.com/lukaszraczylo/mcp-filepuff/pkg/protocol" ) -// MaxFileSize is the default maximum file size we'll parse (10MB). -// Deprecated: Use Registry.maxParseSize instead. -const MaxFileSize = 10 * 1024 * 1024 +// maxFileSize is the default maximum file size we'll parse (10MB). +const maxFileSize = 10 * 1024 * 1024 // Registry manages Tree-sitter parsers for different languages. type Registry struct { @@ -69,18 +68,6 @@ type SyntaxError struct { Location protocol.Location } -// CacheStatsResult contains cache statistics. -type CacheStatsResult struct { - Hits int64 `json:"hits"` - Misses int64 `json:"misses"` - HitRate float64 `json:"hit_rate"` - Size int `json:"size"` - TotalParseTime int64 `json:"total_parse_time_ns"` - ParseCount int64 `json:"parse_count"` - AvgParseTime int64 `json:"avg_parse_time_ns"` - LastParseTime int64 `json:"last_parse_time_ns"` -} - // NewRegistry creates a new parser registry with the default max parse size. // For custom max parse size, use NewRegistryWithSize. func NewRegistry() *Registry { @@ -98,7 +85,7 @@ func NewRegistryWithSize(maxParseSize int64) *Registry { } if maxParseSize <= 0 { - maxParseSize = MaxFileSize + maxParseSize = maxFileSize } return &Registry{ @@ -266,50 +253,6 @@ func (r *Registry) Parse(ctx context.Context, filename string, content []byte) ( }, nil } -// CacheStats returns cache hit/miss statistics. -func (r *Registry) CacheStats() (hits, misses int64) { - return r.cacheHits.Load(), r.cacheMisses.Load() -} - -// CacheStatsDetailed returns detailed cache and parse statistics. -func (r *Registry) CacheStatsDetailed() CacheStatsResult { - hits := r.cacheHits.Load() - misses := r.cacheMisses.Load() - totalParseTime := r.totalParseTime.Load() - parseCount := r.parseCount.Load() - - var hitRate float64 - total := hits + misses - if total > 0 { - hitRate = float64(hits) / float64(total) - } - - var avgParseTime int64 - if parseCount > 0 { - avgParseTime = totalParseTime / parseCount - } - - return CacheStatsResult{ - Hits: hits, - Misses: misses, - HitRate: hitRate, - Size: r.cache.Len(), - TotalParseTime: totalParseTime, - ParseCount: parseCount, - AvgParseTime: avgParseTime, - LastParseTime: r.lastParseDuration.Load(), - } -} - -// ResetStats resets all cache and parse statistics. -func (r *Registry) ResetStats() { - r.cacheHits.Store(0) - r.cacheMisses.Store(0) - r.totalParseTime.Store(0) - r.parseCount.Store(0) - r.lastParseDuration.Store(0) -} - // extractErrors finds all error nodes in the tree. func extractErrors(node *sitter.Node, _ []byte) []SyntaxError { var errors []SyntaxError @@ -346,13 +289,6 @@ func extractErrors(node *sitter.Node, _ []byte) []SyntaxError { return errors } -// contentHash returns a fast hash of the content for caching. -// Uses xxHash which is 5-10x faster than SHA256 for non-cryptographic purposes. -func contentHash(content []byte) string { - h := xxhash.Sum64(content) - return fmt.Sprintf("%016x", h) -} - // isBinary checks if content appears to be binary. func isBinary(content []byte) bool { // Check first 8000 bytes for null bytes diff --git a/internal/parser/parser_bench_test.go b/internal/parser/parser_bench_test.go index cf9ece9..6590b85 100644 --- a/internal/parser/parser_bench_test.go +++ b/internal/parser/parser_bench_test.go @@ -2,8 +2,11 @@ package parser import ( "context" + "fmt" "strings" "testing" + + "github.com/cespare/xxhash/v2" ) // BenchmarkParse benchmarks parsing files of various sizes. @@ -194,7 +197,7 @@ func BenchmarkContentHash(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - _ = contentHash(content) + _ = fmt.Sprintf("%016x", xxhash.Sum64(content)) } }) } diff --git a/internal/parser/yaml_json.go b/internal/parser/yaml_json.go index 88a5532..053a5f8 100644 --- a/internal/parser/yaml_json.go +++ b/internal/parser/yaml_json.go @@ -31,8 +31,8 @@ type JSONNode struct { // ParseYAML parses YAML content and returns a tree-sitter-compatible result func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byte) (*ParseResult, error) { // Check file size - if len(content) > MaxFileSize { - return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize) + if len(content) > maxFileSize { + return nil, errors.NewFileTooLarge(filename, int64(len(content)), maxFileSize) } // Parse YAML @@ -57,8 +57,8 @@ func (r *Registry) ParseYAML(ctx context.Context, filename string, content []byt // ParseJSON parses JSON content and returns a tree-sitter-compatible result func (r *Registry) ParseJSON(ctx context.Context, filename string, content []byte) (*ParseResult, error) { // Check file size - if len(content) > MaxFileSize { - return nil, errors.NewFileTooLarge(filename, int64(len(content)), MaxFileSize) + if len(content) > maxFileSize { + return nil, errors.NewFileTooLarge(filename, int64(len(content)), maxFileSize) } // Parse JSON to validate syntax diff --git a/internal/query/query.go b/internal/query/query.go index bbdca27..dc931df 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -22,13 +22,11 @@ type ASTQuery struct { // QueryFilters provide additional filtering criteria. type QueryFilters struct { - HasChild *ASTQuery `json:"has_child,omitempty"` - HasParent *ASTQuery `json:"has_parent,omitempty"` - NameMatches string `json:"name_matches,omitempty"` - NameExact string `json:"name_exact,omitempty"` - InFile string `json:"in_file,omitempty"` - NotInFile string `json:"not_in_file,omitempty"` - KindIn []string `json:"kind_in,omitempty"` + NameMatches string `json:"name_matches,omitempty"` + NameExact string `json:"name_exact,omitempty"` + InFile string `json:"in_file,omitempty"` + NotInFile string `json:"not_in_file,omitempty"` + KindIn []string `json:"kind_in,omitempty"` } // MatchResult represents a single match from a query. @@ -259,7 +257,7 @@ func matchPatternHeuristic(node *sitter.Node, pattern *ParsedPattern, content [] } // Match struct patterns (Go, C, C++) - if strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ") && strings.Contains(patternLower, "struct") { + if (strings.Contains(patternLower, "struct ") || strings.Contains(patternLower, "type ")) && strings.Contains(patternLower, "struct") { if nodeType != "type_declaration" && nodeType != "struct_specifier" { return false } diff --git a/internal/query/query_test.go b/internal/query/query_test.go index e38980b..6543cb1 100644 --- a/internal/query/query_test.go +++ b/internal/query/query_test.go @@ -499,6 +499,155 @@ func TestFormatResults(t *testing.T) { } } +// TestMatchStructOperatorPrecedence verifies the C-07 operator precedence fix. +// Before the fix, patterns like "struct Foo" would match because +// strings.Contains(p, "struct ") short-circuited the entire condition. +// After the fix, both "struct" must be present for the struct branch to match. +func TestMatchStructOperatorPrecedence(t *testing.T) { + reg := parser.NewRegistry() + defer reg.Close() + + matcher := NewMatcher(reg) + + content := `package main + +type Server struct { + Port int +} + +func main() {} +` + + ctx := context.Background() + result, err := reg.Parse(ctx, "test.go", []byte(content)) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + tests := []struct { + name string + pattern string + wantMatches int + }{ + { + name: "type struct pattern should match", + pattern: "type $NAME struct { $$$FIELDS }", + wantMatches: 1, // Server + }, + { + name: "struct keyword alone should match", + pattern: "struct $NAME { $$$FIELDS }", + wantMatches: 1, // Server + }, + { + name: "func pattern should not match struct branch", + pattern: "func $NAME() {}", + wantMatches: 1, // main (matches function branch, not struct) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := &ASTQuery{ + Pattern: tt.pattern, + Language: "go", + } + results, err := matcher.Match(ctx, query, result.Tree, []byte(content), "test.go") + if err != nil { + t.Fatalf("match failed: %v", err) + } + if len(results) != tt.wantMatches { + t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results)) + for i, r := range results { + t.Logf("match %d: type=%s, text=%q", i, r.Node.Type(), truncateForLog(r.Text, 80)) + } + } + }) + } +} + +func truncateForLog(s string, max int) string { + if len(s) <= max { + return s + } + return s[:max] + "..." +} + +// TestPassesFilters_AllBranches tests passesFilters for each filter type. +func TestPassesFilters_AllBranches(t *testing.T) { + reg := parser.NewRegistry() + defer reg.Close() + + content := `package main + +func Alpha() {} +func Beta() {} +func Gamma() {} +` + + ctx := context.Background() + result, err := reg.Parse(ctx, "test.go", []byte(content)) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + matcher := NewMatcher(reg) + + tests := []struct { + name string + filters QueryFilters + wantMatches int + }{ + { + name: "no filters matches all", + filters: QueryFilters{}, + wantMatches: 3, + }, + { + name: "name_exact filter", + filters: QueryFilters{NameExact: "Alpha"}, + wantMatches: 1, + }, + { + name: "name_matches regex filter", + filters: QueryFilters{NameMatches: "^[AB]"}, + wantMatches: 2, + }, + { + name: "kind_in filter", + filters: QueryFilters{KindIn: []string{"function_declaration"}}, + wantMatches: 3, + }, + { + name: "kind_in filter excludes non-matching kinds", + filters: QueryFilters{KindIn: []string{"class_declaration"}}, + wantMatches: 0, + }, + { + name: "combined name_exact and kind_in", + filters: QueryFilters{NameExact: "Beta", KindIn: []string{"function_declaration"}}, + wantMatches: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := &ASTQuery{ + Pattern: "func $NAME() {}", + Language: "go", + Filters: tt.filters, + } + results, err := matcher.Match(ctx, query, result.Tree, []byte(content), "test.go") + if err != nil { + t.Fatalf("match failed: %v", err) + } + if len(results) != tt.wantMatches { + t.Errorf("expected %d matches, got %d", tt.wantMatches, len(results)) + } + }) + } +} + func TestQueryValidation(t *testing.T) { reg := parser.NewRegistry() defer reg.Close() diff --git a/internal/search/search.go b/internal/search/search.go index eea7cdb..903681b 100644 --- a/internal/search/search.go +++ b/internal/search/search.go @@ -131,6 +131,11 @@ func (s *Searcher) Search(ctx context.Context, req *Request) (*SearchResults, er WithRemediation("Provide a non-empty search pattern") } + // Validate that at least one provided path is allowed + if err := s.validatePaths(req.Paths); err != nil { + return nil, err + } + // Build ripgrep command args := s.buildArgs(req) @@ -238,6 +243,22 @@ func (s *Searcher) buildArgs(req *Request) []string { return args } +// validatePaths checks that at least one caller-provided path is allowed. +// Returns an error if paths were provided but none passed IsPathAllowed. +func (s *Searcher) validatePaths(paths []string) error { + if len(paths) == 0 { + return nil // no explicit paths — will default to workspace root + } + for _, p := range paths { + if s.cfg.IsPathAllowed(p) { + return nil + } + } + return errors.New(errors.ErrPathNotAllowed, "all provided search paths are outside the workspace root"). + WithContext("paths", fmt.Sprintf("%v", paths)). + WithRemediation("Provide paths within the workspace root") +} + // parseOutput parses ripgrep JSON output. func (s *Searcher) parseOutput(output *bytes.Buffer, maxResults int) (*SearchResults, error) { results := &SearchResults{ diff --git a/internal/search/search_test.go b/internal/search/search_test.go index ed58528..73f2a06 100644 --- a/internal/search/search_test.go +++ b/internal/search/search_test.go @@ -43,9 +43,10 @@ func TestBuildArgs(t *testing.T) { } tests := []struct { - name string - req *Request - expected []string + name string + req *Request + expected []string + notExpected []string // T-06: verify absence of unexpected flags }{ { name: "basic search", @@ -54,7 +55,8 @@ func TestBuildArgs(t *testing.T) { ContextLines: 2, Regex: true, }, - expected: []string{"--json", "--context=2", "--", "test", "."}, + expected: []string{"--json", "--context=2", "--", "test", "."}, + notExpected: []string{"--ignore-case", "--fixed-strings", "--max-total-count=0"}, }, { name: "ignore case", @@ -63,7 +65,8 @@ func TestBuildArgs(t *testing.T) { IgnoreCase: true, Regex: true, }, - expected: []string{"--json", "--ignore-case", "--", "test", "."}, + expected: []string{"--json", "--ignore-case", "--", "test", "."}, + notExpected: []string{"--fixed-strings"}, }, { name: "fixed strings", @@ -71,7 +74,8 @@ func TestBuildArgs(t *testing.T) { Pattern: "test", Regex: false, }, - expected: []string{"--json", "--fixed-strings", "--", "test", "."}, + expected: []string{"--json", "--fixed-strings", "--", "test", "."}, + notExpected: []string{"--ignore-case"}, }, { name: "with file types", @@ -80,7 +84,8 @@ func TestBuildArgs(t *testing.T) { FileTypes: []string{"go", "ts"}, Regex: true, }, - expected: []string{"--json", "--type", "go", "--type", "ts", "--", "test", "."}, + expected: []string{"--json", "--type", "go", "--type", "ts", "--", "test", "."}, + notExpected: []string{"--ignore-case", "--fixed-strings"}, }, { name: "with max results", @@ -89,7 +94,8 @@ func TestBuildArgs(t *testing.T) { MaxResults: 10, Regex: true, }, - expected: []string{"--json", "--max-total-count=10", "--", "test", "."}, + expected: []string{"--json", "--max-total-count=10", "--", "test", "."}, + notExpected: []string{"--ignore-case", "--fixed-strings"}, }, } @@ -110,6 +116,16 @@ func TestBuildArgs(t *testing.T) { t.Errorf("expected arg %q not found in %v", exp, args) } } + + // T-06: Check that unexpected args are absent + for _, notExp := range tt.notExpected { + for _, arg := range args { + if arg == notExp { + t.Errorf("unexpected arg %q found in %v", notExp, args) + break + } + } + } }) } } diff --git a/internal/server/handlers_ast.go b/internal/server/handlers_ast.go index 160195e..92f9f33 100644 --- a/internal/server/handlers_ast.go +++ b/internal/server/handlers_ast.go @@ -54,9 +54,9 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest } // Find files to search based on language - ext := languageToExtension(language) - if ext == "" { - return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s", language)), nil + exts := languageToExtensions(language) + if len(exts) == 0 { + return mcp.NewToolResultError(fmt.Sprintf("unsupported language: %s (supported: go, typescript, javascript, python, c, cpp, html, vue, elixir)", language)), nil } var allResults []query.MatchResult @@ -89,7 +89,14 @@ func (s *Server) handleASTQuery(ctx context.Context, request mcp.CallToolRequest } // Check file extension matches language - if !strings.HasSuffix(path, ext) { + matched := false + for _, ext := range exts { + if strings.HasSuffix(path, ext) { + matched = true + break + } + } + if !matched { return nil } @@ -203,24 +210,28 @@ func symbolKindIcon(kind protocol.SymbolKind) string { } } -// languageToExtension maps language names to file extensions. -func languageToExtension(language string) string { +// languageToExtensions maps language names to file extensions. +func languageToExtensions(language string) []string { switch strings.ToLower(language) { case "go": - return ".go" + return []string{".go"} case "typescript": - return ".ts" + return []string{".ts"} case "javascript": - return ".js" + return []string{".js"} case "python": - return ".py" + return []string{".py"} case "c": - return ".c" + return []string{".c"} case "cpp", "c++": - return ".cpp" + return []string{".cpp"} + case "html": + return []string{".html", ".htm"} + case "vue": + return []string{".vue"} case "elixir": - return ".ex" + return []string{".ex", ".exs"} default: - return "" + return nil } } diff --git a/internal/server/handlers_edit.go b/internal/server/handlers_edit.go index edd702d..2549f88 100644 --- a/internal/server/handlers_edit.go +++ b/internal/server/handlers_edit.go @@ -33,6 +33,16 @@ func (s *Server) handleEdit(ctx context.Context, request mcp.CallToolRequest, ap return mcp.NewToolResultError("operation is required"), nil } + // Validate operation against known values + switch edit.EditOperation(operation) { + case edit.EditReplace, edit.EditInsertBefore, edit.EditInsertAfter, edit.EditDelete: + // valid + default: + return mcp.NewToolResultError(fmt.Sprintf( + "invalid operation %q: must be one of: replace, insert_before, insert_after, delete", operation, + )), nil + } + // Validate path if !s.cfg.IsPathAllowed(file) { return mcp.NewToolResultError("file is outside workspace root"), nil diff --git a/internal/server/handlers_file.go b/internal/server/handlers_file.go index c63bf79..d7df5b9 100644 --- a/internal/server/handlers_file.go +++ b/internal/server/handlers_file.go @@ -81,8 +81,8 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest return mcp.NewToolResultError("path is outside workspace root"), nil } - // Read file - content, err := os.ReadFile(path) + // Check file size before reading to avoid loading huge files into memory + info, err := os.Stat(path) if err != nil { if os.IsNotExist(err) { return mcp.NewToolResultError(fmt.Sprintf("file not found: %s", path)), nil @@ -90,13 +90,21 @@ func (s *Server) handleFileRead(ctx context.Context, request mcp.CallToolRequest if os.IsPermission(err) { return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil } - s.logger.Warn("file read error", "path", path, "error", err) - return mcp.NewToolResultError("error reading file"), nil + s.logger.Warn("file stat error", "path", path, "error", err) + return mcp.NewToolResultError("error accessing file"), nil + } + if info.Size() > s.cfg.MaxFileSize { + return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", info.Size(), s.cfg.MaxFileSize)), nil } - // Check file size - if int64(len(content)) > s.cfg.MaxFileSize { - return mcp.NewToolResultError(fmt.Sprintf("file too large (%d bytes, max %d)", len(content), s.cfg.MaxFileSize)), nil + // Read file + content, err := os.ReadFile(path) + if err != nil { + if os.IsPermission(err) { + return mcp.NewToolResultError(fmt.Sprintf("permission denied: %s", path)), nil + } + s.logger.Warn("file read error", "path", path, "error", err) + return mcp.NewToolResultError("error reading file"), nil } // Handle line range @@ -167,13 +175,18 @@ func splitLines(s string) []string { const largeSizeThreshold = 1024 * 1024 // 1MB if len(s) > largeSizeThreshold { - // Use scanner for large files + // Use scanner for large files with increased buffer for long lines scanner := bufio.NewScanner(strings.NewReader(s)) + scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), 1024*1024) // up to 1MB per line var lines []string for scanner.Scan() { lines = append(lines, scanner.Text()) } - // Handle potential error and add empty line if string ended with newline + if err := scanner.Err(); err != nil { + // If scanning fails (e.g. line exceeds buffer), fall back to strings.Split + return strings.Split(s, "\n") + } + // Add empty line if string ended with newline if len(s) > 0 && s[len(s)-1] == '\n' { lines = append(lines, "") } diff --git a/internal/server/handlers_lsp.go b/internal/server/handlers_lsp.go index abd9c7d..4067d08 100644 --- a/internal/server/handlers_lsp.go +++ b/internal/server/handlers_lsp.go @@ -117,7 +117,7 @@ func (s *Server) handleFindDefinition(ctx context.Context, request mcp.CallToolR output.WriteString(fmt.Sprintf("**%s:%d:%d**\n", filePath, loc.Range.Start.Line+1, loc.Range.Start.Character+1)) // Try to read a preview snippet - preview := readFilePreview(filePath, loc.Range.Start.Line+1, 3) + preview := s.readFilePreview(filePath, loc.Range.Start.Line+1, 3) if preview != "" { output.WriteString("```\n") output.WriteString(preview) @@ -184,7 +184,13 @@ func (s *Server) handleFindReferences(ctx context.Context, request mcp.CallToolR } // readFilePreview reads a few lines from a file around the given line. -func readFilePreview(file string, line, contextLines int) string { +// It validates that the file path is within the allowed workspace before reading. +func (s *Server) readFilePreview(file string, line, contextLines int) string { + if !s.cfg.IsPathAllowed(file) { + s.logger.Warn("readFilePreview: path not allowed", "path", file) + return "" + } + content, err := os.ReadFile(file) if err != nil { return "" diff --git a/internal/server/server.go b/internal/server/server.go index 4ae29ea..df3852e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -169,7 +169,7 @@ func (s *Server) registerTools() { ), mcp.WithString("language", mcp.Required(), - mcp.Description("Target language: go, typescript, javascript, python, c, cpp"), + mcp.Description("Target language: go, typescript, javascript, python, c, cpp, html, vue, elixir"), ), mcp.WithArray("paths", mcp.Description("Paths to search in (defaults to workspace root)"), diff --git a/internal/server/server_test.go b/internal/server/server_test.go index ff15327..2df72af 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -5,6 +5,7 @@ import ( "log/slog" "os" "path/filepath" + "strings" "testing" "github.com/lukaszraczylo/mcp-filepuff/internal/config" @@ -381,3 +382,155 @@ func Hello() { t.Error("handleEdit(apply) should modify the file") } } + +// TestHandleFileReadMaxFileSize verifies that handleFileRead rejects files +// exceeding MaxFileSize via os.Stat before loading them into memory (T-03, S-01). +func TestHandleFileReadMaxFileSize(t *testing.T) { + tmpDir := t.TempDir() + + // Create a test file + testFile := filepath.Join(tmpDir, "big.txt") + content := make([]byte, 1024) // 1KB file + for i := range content { + content[i] = 'A' + } + if err := os.WriteFile(testFile, content, 0600); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + // Set MaxFileSize smaller than the file + cfg := &config.Config{ + WorkspaceRoot: tmpDir, + EnableLSP: false, + MaxFileSize: 512, // 512 bytes — smaller than our 1KB file + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + ctx := context.Background() + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{ + "path": testFile, + } + + result, err := srv.handleFileRead(ctx, req) + if err != nil { + t.Fatalf("handleFileRead() returned Go error: %v", err) + } + + // The result should indicate an error (file too large) + if result == nil { + t.Fatal("handleFileRead() returned nil result") + } + if !result.IsError { + t.Error("expected IsError=true for file exceeding MaxFileSize") + } + contents := result.Content + if len(contents) == 0 { + t.Fatal("expected non-empty content with error message") + } + textContent, ok := contents[0].(mcp.TextContent) + if !ok { + t.Fatal("expected text content") + } + if !strings.Contains(textContent.Text, "too large") { + t.Errorf("expected 'too large' error message, got: %s", textContent.Text) + } +} + +// TestSplitLinesLargeFile tests the splitLines function with a large file (>1MB) +// to exercise the bufio.Scanner path including the scanner.Err() check (T-07, C-05). +func TestSplitLinesLargeFile(t *testing.T) { + // Build a string >1MB with known line count + lineCount := 20000 + var sb strings.Builder + for i := 0; i < lineCount; i++ { + sb.WriteString(strings.Repeat("x", 60)) + sb.WriteByte('\n') + } + largeContent := sb.String() + + // Verify it's large enough to trigger the scanner path + if len(largeContent) <= 1024*1024 { + t.Fatalf("test content too small: %d bytes, need >1MB", len(largeContent)) + } + + lines := splitLines(largeContent) + + // String ends with \n, so splitLines adds an empty trailing element + // (matching the behavior of strings.Split for the small-file path) + expectedLines := lineCount + 1 // lineCount lines + 1 trailing empty + if len(lines) != expectedLines { + t.Errorf("splitLines returned %d lines, expected %d", len(lines), expectedLines) + } + + // Check first and last actual lines + if lines[0] != strings.Repeat("x", 60) { + t.Errorf("first line mismatch: got %q", lines[0][:20]) + } + if lines[lineCount-1] != strings.Repeat("x", 60) { + t.Errorf("last content line mismatch: got %q", lines[lineCount-1][:20]) + } + if lines[lineCount] != "" { + t.Errorf("expected empty trailing line, got %q", lines[lineCount]) + } +} + +// TestSplitLinesLargeFileNoTrailingNewline verifies splitLines for large files +// without a trailing newline. +func TestSplitLinesLargeFileNoTrailingNewline(t *testing.T) { + lineCount := 20000 + var sb strings.Builder + for i := 0; i < lineCount; i++ { + if i > 0 { + sb.WriteByte('\n') + } + sb.WriteString(strings.Repeat("y", 60)) + } + largeContent := sb.String() + + if len(largeContent) <= 1024*1024 { + t.Fatalf("test content too small: %d bytes", len(largeContent)) + } + + lines := splitLines(largeContent) + if len(lines) != lineCount { + t.Errorf("splitLines returned %d lines, expected %d", len(lines), lineCount) + } +} + +// TestSplitLinesLongLine verifies the scanner gracefully handles very long lines +// (up to the 1MB buffer limit set in C-05 fix). +func TestSplitLinesLongLine(t *testing.T) { + // Create content with one very long line (500KB) embedded in a >1MB file + shortLines := strings.Repeat("short line content\n", 60000) // ~60KB * ~1 = ~1.08MB + longLine := strings.Repeat("L", 500*1024) // 500KB line + largeContent := shortLines + longLine + "\n" + + if len(largeContent) <= 1024*1024 { + t.Fatalf("test content too small: %d bytes", len(largeContent)) + } + + lines := splitLines(largeContent) + + // Should not crash and should return some lines + if len(lines) == 0 { + t.Fatal("splitLines returned no lines for valid content") + } + + // The long line should be present somewhere in the output + foundLong := false + for _, line := range lines { + if len(line) >= 500*1024 { + foundLong = true + break + } + } + if !foundLong { + t.Error("the 500KB long line was not found in splitLines output") + } +} diff --git a/internal/util/regex_cache.go b/internal/util/regex_cache.go index e0e08bf..2c92406 100644 --- a/internal/util/regex_cache.go +++ b/internal/util/regex_cache.go @@ -5,7 +5,6 @@ import ( "fmt" "regexp" "sync" - "sync/atomic" ) const ( @@ -19,10 +18,11 @@ const ( ) // regexCache is a global thread-safe cache for compiled regular expressions. -// Caching regex compilation provides 10-50x speedup for repeated patterns. +// Uses sync.RWMutex with a regular map so that ClearRegexCache can atomically +// clear the map and reset the count in a single lock acquisition. var ( - regexCache sync.Map // string -> *regexp.Regexp - cacheSize atomic.Int64 + cacheMu sync.RWMutex + regexCache = make(map[string]*regexp.Regexp) ) // RegexError represents an error during regex compilation or validation. @@ -62,7 +62,7 @@ func ValidatePattern(pattern string) error { } // CompileRegex compiles a regex pattern with caching and validation for security. -// Thread-safe: uses LoadOrStore to prevent race conditions. +// Thread-safe: uses RWMutex to prevent race conditions. // Returns the compiled regex or an error if the pattern is invalid or unsafe. func CompileRegex(pattern string) (*regexp.Regexp, error) { // Validate pattern first @@ -70,12 +70,15 @@ func CompileRegex(pattern string) (*regexp.Regexp, error) { return nil, err } - // Check cache first - if cached, ok := regexCache.Load(pattern); ok { - return cached.(*regexp.Regexp), nil + // Check cache first (read lock) + cacheMu.RLock() + if cached, ok := regexCache[pattern]; ok { + cacheMu.RUnlock() + return cached, nil } + cacheMu.RUnlock() - // Compile regex + // Compile regex outside the lock to avoid holding it during compilation re, err := regexp.Compile(pattern) if err != nil { return nil, &RegexError{ @@ -85,18 +88,22 @@ func CompileRegex(pattern string) (*regexp.Regexp, error) { } } - // Check cache size and clear if too large - if cacheSize.Load() >= MaxCacheSize { - ClearRegexCache() + // Write lock to store in cache + cacheMu.Lock() + // Re-check in case another goroutine stored it while we were compiling + if cached, ok := regexCache[pattern]; ok { + cacheMu.Unlock() + return cached, nil } - // Try to store - if another goroutine already stored it, use theirs - // This prevents race conditions where multiple goroutines compile the same pattern - actual, loaded := regexCache.LoadOrStore(pattern, re) - if !loaded { - cacheSize.Add(1) + // Check cache size and clear if too large + if len(regexCache) >= MaxCacheSize { + regexCache = make(map[string]*regexp.Regexp) } - return actual.(*regexp.Regexp), nil + + regexCache[pattern] = re + cacheMu.Unlock() + return re, nil } // CompileRegexUncached compiles a regex pattern without caching. @@ -118,18 +125,19 @@ func CompileRegexUncached(pattern string) (*regexp.Regexp, error) { } // ClearRegexCache clears all cached compiled regular expressions. -// Useful for testing or when memory usage needs to be reduced. +// Atomically replaces the map under a single write lock. func ClearRegexCache() { - regexCache.Range(func(key, _ interface{}) bool { - regexCache.Delete(key) - return true - }) - cacheSize.Store(0) + cacheMu.Lock() + regexCache = make(map[string]*regexp.Regexp) + cacheMu.Unlock() } // CacheStats returns the current number of cached patterns. func CacheStats() int64 { - return cacheSize.Load() + cacheMu.RLock() + n := int64(len(regexCache)) + cacheMu.RUnlock() + return n } // truncatePattern truncates a pattern for display in error messages.