From 9217bf35f3d3ac8041f2516fe7b45c0d8ab8645c Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sun, 11 Jan 2026 01:32:35 +0000 Subject: [PATCH] fix(security): improve JSON output safety and path traversal protection - [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler - [x] Remove escapeJSONString helper function in favor of standard JSON marshaling - [x] Add safeResolvePath function to validate paths and prevent directory traversal - [x] Apply path traversal validation in captureFileMtimes operations - [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation --- internal/mcp/server.go | 27 +++++++++++---------- internal/worker/sdk/processor.go | 40 +++++++++++++++++++++++++++----- internal/worker/service.go | 7 +++++- 3 files changed, 55 insertions(+), 19 deletions(-) diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 5864be4..8ba6c70 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -2620,8 +2620,21 @@ func (s *Server) handleExportObservations(ctx context.Context, args json.RawMess } lines = append(lines, string(line)) } - output = fmt.Sprintf(`{"format":"jsonl","count":%d,"data":"%s"}`, - len(observations), escapeJSONString(strings.Join(lines, "\n"))) + // Use proper JSON marshaling to avoid injection issues + jsonlOutput := struct { + Format string `json:"format"` + Data string `json:"data"` + Count int `json:"count"` + }{ + Format: "jsonl", + Count: len(observations), + Data: strings.Join(lines, "\n"), + } + outputBytes, err := json.Marshal(jsonlOutput) + if err != nil { + return "", fmt.Errorf("marshal jsonl output: %w", err) + } + output = string(outputBytes) case "markdown": // Markdown format for human reading @@ -2685,16 +2698,6 @@ func (s *Server) handleExportObservations(ctx context.Context, args json.RawMess return output, nil } -// escapeJSONString escapes a string for use in JSON. -func escapeJSONString(s string) string { - b, _ := json.Marshal(s) - // Remove surrounding quotes - if len(b) >= 2 { - return string(b[1 : len(b)-1]) - } - return s -} - // handleCheckSystemHealth performs comprehensive system health checks. func (s *Server) handleCheckSystemHealth(ctx context.Context) (string, error) { type SubsystemHealth struct { diff --git a/internal/worker/sdk/processor.go b/internal/worker/sdk/processor.go index c83d805..acc6ae7 100644 --- a/internal/worker/sdk/processor.go +++ b/internal/worker/sdk/processor.go @@ -782,6 +782,32 @@ func toJSONString(v any) string { return string(b) } +// safeResolvePath resolves a path relative to cwd and validates it doesn't escape the cwd directory. +// Returns the resolved absolute path and true if valid, or empty string and false if path traversal detected. +func safeResolvePath(path, cwd string) (string, bool) { + if filepath.IsAbs(path) { + // Absolute paths are allowed as-is + return filepath.Clean(path), true + } + if cwd == "" { + return filepath.Clean(path), true + } + + // Clean the cwd first + cleanCwd := filepath.Clean(cwd) + + // Join and clean the path + absPath := filepath.Clean(filepath.Join(cleanCwd, path)) + + // Verify the resolved path is still within cwd (prevents path traversal via ..) + // Use HasPrefix on cleaned paths to detect escapes + if !strings.HasPrefix(absPath, cleanCwd+string(filepath.Separator)) && absPath != cleanCwd { + return "", false + } + + return absPath, true +} + // captureFileMtimes captures current modification times for tracked files. // Returns a map of absolute file paths to their mtime in epoch milliseconds. // For large file lists (>10 files), uses parallel stat calls for better performance. @@ -809,9 +835,10 @@ func captureFileMtimesSequential(paths map[string]struct{}, cwd string) map[stri mtimes := make(map[string]int64, len(paths)) for path := range paths { - absPath := path - if !filepath.IsAbs(path) && cwd != "" { - absPath = filepath.Join(cwd, path) + absPath, ok := safeResolvePath(path, cwd) + if !ok { + // Skip paths that attempt directory traversal + continue } info, err := os.Stat(absPath) @@ -841,9 +868,10 @@ func captureFileMtimesParallel(paths map[string]struct{}, cwd string) map[string sem <- struct{}{} // Acquire defer func() { <-sem }() // Release - absPath := p - if !filepath.IsAbs(p) && cwd != "" { - absPath = filepath.Join(cwd, p) + absPath, ok := safeResolvePath(p, cwd) + if !ok { + // Skip paths that attempt directory traversal + return } info, err := os.Stat(absPath) diff --git a/internal/worker/service.go b/internal/worker/service.go index 2088ee1..6db9a8f 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -1443,7 +1443,12 @@ func (s *Service) getRecentSearchQueries(project string, limit int) []RecentSear } // Filter by project (iterate from newest to oldest) - result := make([]RecentSearchQuery, 0, limit) + // Cap capacity to maxRecentQueries to prevent excessive allocation from user input + capacity := limit + if capacity > maxRecentQueries { + capacity = maxRecentQueries + } + result := make([]RecentSearchQuery, 0, capacity) for i := 0; i < s.recentQueriesLen; i++ { idx := (s.recentQueriesHead + i) % maxRecentQueries q := s.recentQueriesBuf[idx]