diff --git a/Makefile b/Makefile index 6330430..73544ee 100644 --- a/Makefile +++ b/Makefile @@ -108,6 +108,9 @@ build-windows: # Stop any running worker stop-worker: @echo "Stopping worker..." + @-pkill -TERM -f 'claude-mnemonic.*worker' 2>/dev/null || true + @-pkill -TERM -f '\.claude/plugins/.*/worker' 2>/dev/null || true + @sleep 1 @-pkill -9 -f 'claude-mnemonic.*worker' 2>/dev/null || true @-pkill -9 -f '\.claude/plugins/.*/worker' 2>/dev/null || true @-lsof -ti :37777 | xargs kill -9 2>/dev/null || true @@ -135,6 +138,10 @@ restart-worker: stop-worker start-worker # Install to Claude plugins directory install: build stop-worker @echo "Installing to Claude plugins directory..." + @# Verify build output binaries exist + @test -f $(BUILD_DIR)/worker || { echo "ERROR: $(BUILD_DIR)/worker not found. Build may have failed."; exit 1; } + @test -f $(BUILD_DIR)/mcp-server || { echo "ERROR: $(BUILD_DIR)/mcp-server not found. Build may have failed."; exit 1; } + @test -d $(BUILD_DIR)/hooks || { echo "ERROR: $(BUILD_DIR)/hooks not found. Build may have failed."; exit 1; } @# Install to marketplaces directory (for direct installs) @mkdir -p $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/hooks @mkdir -p $(HOME)/.claude/plugins/marketplaces/claude-mnemonic/.claude-plugin diff --git a/cmd/hooks/post-tool-use/main.go b/cmd/hooks/post-tool-use/main.go index c5fe8a3..7b2a903 100644 --- a/cmd/hooks/post-tool-use/main.go +++ b/cmd/hooks/post-tool-use/main.go @@ -2,8 +2,10 @@ package main import ( + "context" "fmt" "os" + "time" "github.com/lukaszraczylo/claude-mnemonic/pkg/hooks" ) @@ -51,6 +53,10 @@ var skipTools = map[string]bool{ } func main() { + if !hooks.IsWorkerAvailable() { + hooks.WriteResponse("PostToolUse", true) + return + } hooks.RunHook("PostToolUse", handlePostToolUse) } @@ -63,15 +69,31 @@ func handlePostToolUse(ctx *hooks.HookContext, input *Input) (string, error) { fmt.Fprintf(os.Stderr, "[post-tool-use] %s\n", input.ToolName) - // Send observation to worker - _, err := hooks.POST(ctx.Port, "/api/sessions/observations", map[string]interface{}{ - "claudeSessionId": ctx.SessionID, - "project": ctx.Project, - "tool_name": input.ToolName, - "tool_input": input.ToolInput, - "tool_response": input.ToolResponse, - "cwd": ctx.CWD, - }) + // Fire-and-forget: send the observation without waiting for the response. + // The worker just queues it -- we don't need the response data. + // Use a short-lived context to ensure the request body is at least sent + // before this process exits. + done := make(chan struct{}) + go func() { + defer close(done) + sendCtx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _ = hooks.POSTWithContext(sendCtx, ctx.Port, "/api/sessions/observations", map[string]interface{}{ + "claudeSessionId": ctx.SessionID, + "project": ctx.Project, + "tool_name": input.ToolName, + "tool_input": input.ToolInput, + "tool_response": input.ToolResponse, + "cwd": ctx.CWD, + }) + }() - return "", err + // Wait briefly for the TCP connection to be established and request sent, + // but don't block the hook for the full response. + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + + return "", nil } diff --git a/cmd/hooks/session-start/main.go b/cmd/hooks/session-start/main.go index 2367e79..18295de 100644 --- a/cmd/hooks/session-start/main.go +++ b/cmd/hooks/session-start/main.go @@ -6,6 +6,7 @@ import ( "net/url" "os" "strings" + "time" "github.com/lukaszraczylo/claude-mnemonic/pkg/hooks" ) @@ -27,10 +28,17 @@ type Observation struct { } func main() { + if !hooks.IsWorkerAvailable() { + hooks.WriteResponse("SessionStart", true) + return + } hooks.RunHook("SessionStart", handleSessionStart) } func handleSessionStart(ctx *hooks.HookContext, input *Input) (string, error) { + deadline, cancel := hooks.HookDeadline(30 * time.Second) + defer cancel() + // Fetch observations for context injection endpoint := fmt.Sprintf("/api/context/inject?project=%s&cwd=%s", url.QueryEscape(ctx.Project), @@ -59,12 +67,21 @@ func handleSessionStart(ctx *hooks.HookContext, input *Input) (string, error) { fmt.Fprintf(os.Stderr, "[claude-mnemonic] Injecting %d observations from project memory (%d detailed, %d condensed)\n", len(obsData), min(fullCount, len(obsData)), max(0, len(obsData)-fullCount)) + // Token budget for context injection + maxTokens := 16000 // default; could be made configurable via worker config endpoint + currentTokens := 0 + // Build context string - contextBuilder := "\n" - contextBuilder += fmt.Sprintf("# Project Memory (%d observations)\n", len(obsData)) - contextBuilder += "Use this knowledge to answer questions without re-exploring the codebase.\n\n" + header := fmt.Sprintf("\n# Project Memory (%d observations)\nUse this knowledge to answer questions without re-exploring the codebase.\n\n", len(obsData)) + currentTokens += estimateTokens(header) + contextBuilder := header for i, o := range obsData { + if deadline.Err() != nil { + contextBuilder += "\n... (returning early due to time limit)\n" + break + } + obs, ok := o.(map[string]interface{}) if !ok { continue @@ -73,40 +90,94 @@ func handleSessionStart(ctx *hooks.HookContext, input *Input) (string, error) { title := getString(obs, "title") obsType := getString(obs, "type") + var obsText string + // First `fullCount` observations get full detail, rest are condensed if i < fullCount { // Full detail: include narrative and facts narrative := getString(obs, "narrative") - contextBuilder += fmt.Sprintf("## %d. [%s] %s\n", i+1, strings.ToUpper(obsType), title) + obsText = fmt.Sprintf("## %d. [%s] %s\n", i+1, strings.ToUpper(obsType), title) if narrative != "" { - contextBuilder += narrative + "\n" + obsText += narrative + "\n" } if facts, ok := obs["facts"].([]interface{}); ok && len(facts) > 0 { - contextBuilder += "Key facts:\n" + obsText += "Key facts:\n" for _, f := range facts { if fact, ok := f.(string); ok && fact != "" { - contextBuilder += fmt.Sprintf("- %s\n", fact) + obsText += fmt.Sprintf("- %s\n", fact) } } } - contextBuilder += "\n" + obsText += "\n" } else { // Condensed: just title and subtitle (one line) subtitle := getString(obs, "subtitle") if subtitle != "" { - contextBuilder += fmt.Sprintf("- [%s] %s: %s\n", strings.ToUpper(obsType), title, subtitle) + obsText = fmt.Sprintf("- [%s] %s: %s\n", strings.ToUpper(obsType), title, subtitle) } else { - contextBuilder += fmt.Sprintf("- [%s] %s\n", strings.ToUpper(obsType), title) + obsText = fmt.Sprintf("- [%s] %s\n", strings.ToUpper(obsType), title) } } + + obsTokens := estimateTokens(obsText) + if currentTokens+obsTokens > maxTokens { + contextBuilder += fmt.Sprintf("\n... (%d more observations omitted due to token budget)\n", len(obsData)-i) + break + } + + contextBuilder += obsText + currentTokens += obsTokens } contextBuilder += "\n" return contextBuilder, nil } +// estimateTokens provides a more accurate token count estimate. +// Uses word count * 1.3 as base, with adjustments for code and non-ASCII. +func estimateTokens(s string) int { + if len(s) == 0 { + return 0 + } + + // Count words (split on whitespace) + words := len(strings.Fields(s)) + if words == 0 { + // No whitespace = probably a single token or code blob + return (len(s) + 3) / 4 + } + + // Base estimate: ~1.3 tokens per word for English text + estimate := int(float64(words) * 1.3) + + // Detect code-heavy content (high non-alpha ratio) + nonAlpha := 0 + nonASCII := 0 + for _, r := range s { + if r > 127 { + nonASCII++ + } else if !('a' <= r && r <= 'z') && !('A' <= r && r <= 'Z') && !('0' <= r && r <= '9') && r != ' ' { + nonAlpha++ + } + } + + totalChars := len(s) + + // Code adjustment: more special chars = more tokens per word + if totalChars > 0 && float64(nonAlpha)/float64(totalChars) > 0.15 { + estimate = int(float64(estimate) * 1.3) + } + + // Non-ASCII adjustment: CJK and other scripts use more tokens + if totalChars > 0 && float64(nonASCII)/float64(totalChars) > 0.1 { + estimate += nonASCII // Roughly 1 extra token per non-ASCII char + } + + return estimate +} + func getString(m map[string]interface{}, key string) string { if v, ok := m[key].(string); ok { return v diff --git a/cmd/hooks/statusline/main.go b/cmd/hooks/statusline/main.go index d7b3bb4..dc555b2 100644 --- a/cmd/hooks/statusline/main.go +++ b/cmd/hooks/statusline/main.go @@ -117,7 +117,7 @@ func getWorkerStats(port int, project string) *WorkerStats { if err != nil { return nil } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil diff --git a/cmd/hooks/stop/main.go b/cmd/hooks/stop/main.go index 88b9293..f2a8c71 100644 --- a/cmd/hooks/stop/main.go +++ b/cmd/hooks/stop/main.go @@ -5,12 +5,16 @@ import ( "bufio" "encoding/json" "fmt" + "io" "os" "strings" + "time" "github.com/lukaszraczylo/claude-mnemonic/pkg/hooks" ) +var debug = os.Getenv("CLAUDE_MNEMONIC_DEBUG") != "" + // Input is the hook input from Claude Code. type Input struct { hooks.BaseInput @@ -62,7 +66,19 @@ func parseTranscript(path string) (lastUser, lastAssistant string) { if err != nil { return "", "" } - defer file.Close() + defer func() { _ = file.Close() }() + + // For large transcripts, seek to the last 256KB for efficiency. + // We only need the last user/assistant messages, not the entire history. + const tailSize = 256 * 1024 + info, err := file.Stat() + if err == nil && info.Size() > tailSize { + if _, seekErr := file.Seek(-tailSize, io.SeekEnd); seekErr == nil { + // Discard partial first line after seek + discardScanner := bufio.NewScanner(file) + discardScanner.Scan() + } + } scanner := bufio.NewScanner(file) // Increase buffer size for large messages @@ -97,12 +113,20 @@ func parseTranscript(path string) (lastUser, lastAssistant string) { } func main() { + if !hooks.IsWorkerAvailable() { + hooks.WriteResponse("Stop", true) + return + } hooks.RunHook("Stop", handleStop) } func handleStop(ctx *hooks.HookContext, input *Input) (string, error) { - // Debug: dump raw input - fmt.Fprintf(os.Stderr, "[stop] Raw input: %s\n", string(ctx.RawInput)) + deadline, cancel := hooks.HookDeadline(30 * time.Second) + defer cancel() + + if debug { + fmt.Fprintf(os.Stderr, "[stop] Raw input: %s\n", string(ctx.RawInput)) + } // Find session result, err := hooks.GET(ctx.Port, fmt.Sprintf("/api/sessions?claudeSessionId=%s", ctx.SessionID)) @@ -122,18 +146,33 @@ func handleStop(ctx *hooks.HookContext, input *Input) (string, error) { lastUser, lastAssistant = parseTranscript(input.TranscriptPath) } - // Debug: log what we extracted - fmt.Fprintf(os.Stderr, "[stop] Transcript path: %s\n", input.TranscriptPath) - fmt.Fprintf(os.Stderr, "[stop] Last user message length: %d\n", len(lastUser)) - fmt.Fprintf(os.Stderr, "[stop] Last assistant message length: %d\n", len(lastAssistant)) - if len(lastAssistant) > 0 { - preview := lastAssistant - if len(preview) > 300 { - preview = preview[:300] + "..." - } - fmt.Fprintf(os.Stderr, "[stop] Last assistant preview: %s\n", preview) + // Truncate messages to avoid sending excessive data to the worker + if len(lastAssistant) > 10000 { + lastAssistant = lastAssistant[:10000] + } + if len(lastUser) > 5000 { + lastUser = lastUser[:5000] + } + + if debug { + fmt.Fprintf(os.Stderr, "[stop] Transcript path: %s\n", input.TranscriptPath) + fmt.Fprintf(os.Stderr, "[stop] Last user message length: %d\n", len(lastUser)) + fmt.Fprintf(os.Stderr, "[stop] Last assistant message length: %d\n", len(lastAssistant)) + if len(lastAssistant) > 0 { + preview := lastAssistant + if len(preview) > 300 { + preview = preview[:300] + "..." + } + fmt.Fprintf(os.Stderr, "[stop] Last assistant preview: %s\n", preview) + } + fmt.Fprintf(os.Stderr, "[stop] Requesting summary for session %d (transcript: %v)\n", int64(sessionID), input.TranscriptPath != "") + } + + // Check deadline before expensive summary request + if deadline.Err() != nil { + fmt.Fprintf(os.Stderr, "[stop] Returning early due to time limit\n") + return "", nil } - fmt.Fprintf(os.Stderr, "[stop] Requesting summary for session %d (transcript: %v)\n", int64(sessionID), input.TranscriptPath != "") // Request summary with message context from transcript _, err = hooks.POST(ctx.Port, fmt.Sprintf("/sessions/%d/summarize", int64(sessionID)), map[string]interface{}{ diff --git a/cmd/hooks/subagent-stop/main.go b/cmd/hooks/subagent-stop/main.go index d437030..8440f4c 100644 --- a/cmd/hooks/subagent-stop/main.go +++ b/cmd/hooks/subagent-stop/main.go @@ -16,6 +16,10 @@ type Input struct { } func main() { + if !hooks.IsWorkerAvailable() { + hooks.WriteResponse("SubagentStop", true) + return + } hooks.RunHook("SubagentStop", handleSubagentStop) } diff --git a/cmd/hooks/user-prompt/main.go b/cmd/hooks/user-prompt/main.go index b2093a2..201441c 100644 --- a/cmd/hooks/user-prompt/main.go +++ b/cmd/hooks/user-prompt/main.go @@ -5,6 +5,9 @@ import ( "fmt" "net/url" "os" + "strings" + "sync" + "time" "github.com/lukaszraczylo/claude-mnemonic/pkg/hooks" ) @@ -15,31 +18,117 @@ type Input struct { Prompt string `json:"prompt"` } +// estimateTokens provides a more accurate token count estimate. +// Uses word count * 1.3 as base, with adjustments for code and non-ASCII. +func estimateTokens(s string) int { + if len(s) == 0 { + return 0 + } + + // Count words (split on whitespace) + words := len(strings.Fields(s)) + if words == 0 { + // No whitespace = probably a single token or code blob + return (len(s) + 3) / 4 + } + + // Base estimate: ~1.3 tokens per word for English text + estimate := int(float64(words) * 1.3) + + // Detect code-heavy content (high non-alpha ratio) + nonAlpha := 0 + nonASCII := 0 + for _, r := range s { + if r > 127 { + nonASCII++ + } else if !('a' <= r && r <= 'z') && !('A' <= r && r <= 'Z') && !('0' <= r && r <= '9') && r != ' ' { + nonAlpha++ + } + } + + totalChars := len(s) + + // Code adjustment: more special chars = more tokens per word + if totalChars > 0 && float64(nonAlpha)/float64(totalChars) > 0.15 { + estimate = int(float64(estimate) * 1.3) + } + + // Non-ASCII adjustment: CJK and other scripts use more tokens + if totalChars > 0 && float64(nonASCII)/float64(totalChars) > 0.1 { + estimate += nonASCII // Roughly 1 extra token per non-ASCII char + } + + return estimate +} + func main() { + if !hooks.IsWorkerAvailable() { + hooks.WriteResponse("UserPromptSubmit", true) + return + } hooks.RunHook("UserPromptSubmit", handleUserPrompt) } func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { - // Search for relevant observations based on the prompt + deadline, cancel := hooks.HookDeadline(10 * time.Second) + defer cancel() + searchURL := fmt.Sprintf("/api/context/search?project=%s&query=%s&cwd=%s", url.QueryEscape(ctx.Project), url.QueryEscape(input.Prompt), url.QueryEscape(ctx.CWD)) - var contextToInject string - var observationCount int + // Run search and session init concurrently. + // Session init doesn't strictly depend on search results -- the observation + // count passed is approximate (0) and acceptable. + var ( + wg sync.WaitGroup + searchResult map[string]interface{} + initResult map[string]interface{} + initErr error + contextToInject string + observationCount int + ) - searchResult, _ := hooks.GET(ctx.Port, searchURL) + // Start search in background + wg.Add(1) + go func() { + defer wg.Done() + searchResult, _ = hooks.GET(ctx.Port, searchURL) + }() + + // Start session init in parallel (with observationCount=0; approximate is fine) + wg.Add(1) + go func() { + defer wg.Done() + initResult, initErr = hooks.POST(ctx.Port, "/api/sessions/init", map[string]interface{}{ + "claudeSessionId": ctx.SessionID, + "project": ctx.Project, + "prompt": input.Prompt, + "matchedObservations": 0, + }) + }() + + // Wait for both to complete + wg.Wait() + + // Check deadline after network calls + if deadline.Err() != nil { + return "", nil + } + + // Process search results if observations, ok := searchResult["observations"].([]interface{}); ok && len(observations) > 0 { - // Results are already filtered by relevance threshold and capped by max_results - // from the server-side config (ContextRelevanceThreshold, ContextMaxPromptResults) observationCount = len(observations) - // Build context from search results + // Token budget for prompt context injection + maxTokens := 8000 + currentTokens := 0 + + header := "\n# Relevant Knowledge From Previous Sessions\nIMPORTANT: Use this information to answer the question directly. Do NOT explore the codebase if the answer is here.\n\n" + currentTokens += estimateTokens(header) var contextBuilder string - contextBuilder = "\n" - contextBuilder += "# Relevant Knowledge From Previous Sessions\n" - contextBuilder += "IMPORTANT: Use this information to answer the question directly. Do NOT explore the codebase if the answer is here.\n\n" + contextBuilder = header for i, obs := range observations { if obsMap, ok := obs.(map[string]interface{}); ok { @@ -52,24 +141,30 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { obsType = t } - // Start observation block - contextBuilder += fmt.Sprintf("## %d. [%s] %s\n", i+1, obsType, title) + var obsText string + obsText = fmt.Sprintf("## %d. [%s] %s\n", i+1, obsType, title) - // Add facts first (most concise answers) if facts, ok := obsMap["facts"].([]interface{}); ok && len(facts) > 0 { - contextBuilder += "Key facts:\n" + obsText += "Key facts:\n" for _, fact := range facts { if factStr, ok := fact.(string); ok { - contextBuilder += fmt.Sprintf("- %s\n", factStr) + obsText += fmt.Sprintf("- %s\n", factStr) } } - contextBuilder += "\n" + obsText += "\n" } - // Add narrative if present if narrative, ok := obsMap["narrative"].(string); ok && narrative != "" { - contextBuilder += narrative + "\n\n" + obsText += narrative + "\n\n" } + + obsTokens := estimateTokens(obsText) + if currentTokens+obsTokens > maxTokens { + break + } + + contextBuilder += obsText + currentTokens += obsTokens } } @@ -77,40 +172,24 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { contextToInject = contextBuilder } - // Initialize session with matched observations count - result, err := hooks.POST(ctx.Port, "/api/sessions/init", map[string]interface{}{ - "claudeSessionId": ctx.SessionID, - "project": ctx.Project, - "prompt": input.Prompt, - "matchedObservations": observationCount, - }) - if err != nil { - return "", err + // Check session init result + if initErr != nil { + return "", initErr } // Check if skipped due to privacy - if skipped, ok := result["skipped"].(bool); ok && skipped { + if skipped, ok := initResult["skipped"].(bool); ok && skipped { fmt.Fprintf(os.Stderr, "[user-prompt] Session skipped (private)\n") return "", nil } - // Safely extract session ID and prompt number with type checking - sessionDbIdRaw, ok := result["sessionDbId"].(float64) - if !ok { - return "", fmt.Errorf("invalid or missing sessionDbId in response") - } - sessionID := int64(sessionDbIdRaw) - - promptNumberRaw, ok := result["promptNumber"].(float64) - if !ok { - return "", fmt.Errorf("invalid or missing promptNumber in response") - } - promptNumber := int(promptNumberRaw) + sessionID := int64(initResult["sessionDbId"].(float64)) + promptNumber := int(initResult["promptNumber"].(float64)) fmt.Fprintf(os.Stderr, "[user-prompt] Session %d, prompt #%d\n", sessionID, promptNumber) - // Start SDK agent - _, err = hooks.POST(ctx.Port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{ + // Start SDK agent (depends on session init result, so kept sequential) + _, err := hooks.POST(ctx.Port, fmt.Sprintf("/sessions/%d/init", sessionID), map[string]interface{}{ "userPrompt": input.Prompt, "promptNumber": promptNumber, }) @@ -120,7 +199,6 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { // Return context if we found relevant observations if observationCount > 0 { - // Show match count to user via stderr fmt.Fprintf(os.Stderr, "[claude-mnemonic] Found %d relevant memories for this prompt\n", observationCount) return contextToInject, nil } diff --git a/cmd/mcp/main.go b/cmd/mcp/main.go index 5d5cd7f..ee0a275 100644 --- a/cmd/mcp/main.go +++ b/cmd/mcp/main.go @@ -4,20 +4,16 @@ package main import ( "context" "flag" + "fmt" + "net/http" "os" "os/signal" "syscall" "time" "github.com/lukaszraczylo/claude-mnemonic/internal/config" - "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" - "github.com/lukaszraczylo/claude-mnemonic/internal/embedding" "github.com/lukaszraczylo/claude-mnemonic/internal/mcp" - "github.com/lukaszraczylo/claude-mnemonic/internal/scoring" - "github.com/lukaszraczylo/claude-mnemonic/internal/search" - "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" "github.com/lukaszraczylo/claude-mnemonic/internal/watcher" - "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -28,7 +24,6 @@ var Version = "dev" func main() { // Parse flags project := flag.String("project", "", "Project name (required)") - dataDir := flag.String("data-dir", "", "Data directory (default: ~/.claude-mnemonic)") debug := flag.Bool("debug", false, "Enable debug logging") flag.Parse() @@ -43,23 +38,12 @@ func main() { log.Fatal().Msg("--project is required") } - // Ensure data directory and settings exist - if err := config.EnsureAll(); err != nil { - log.Fatal().Err(err).Msg("Failed to ensure data directories") - } + // Get worker port from config + port := config.GetWorkerPort() + workerURL := fmt.Sprintf("http://localhost:%d", port) - // Load config - cfg, err := config.Load() - if err != nil { - log.Warn().Err(err).Msg("Failed to load config, using defaults") - cfg = config.Default() - } - - // Override data directory if specified - dbPath := cfg.DBPath - if *dataDir != "" { - dbPath = *dataDir + "/claude-mnemonic.db" - } + // Create HTTP client for worker + client := &http.Client{Timeout: 30 * time.Second} ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -73,69 +57,12 @@ func main() { cancel() }() - // Initialize database store (migrations run automatically) - storeCfg := gorm.Config{ - Path: dbPath, - MaxConns: cfg.MaxConns, - // WALMode is enabled automatically by GORM - } - store, err := gorm.NewStore(storeCfg) - if err != nil { - log.Fatal().Err(err).Msg("Failed to initialize database store") - } - defer store.Close() + // Start file watchers for config changes + startWatchers() - // Initialize stores - observationStore := gorm.NewObservationStore(store, nil, nil, nil) - summaryStore := gorm.NewSummaryStore(store) - promptStore := gorm.NewPromptStore(store, nil) - patternStore := gorm.NewPatternStore(store) - relationStore := gorm.NewRelationStore(store) - sessionStore := gorm.NewSessionStore(store) - - // Initialize embedding service and vector client - var vectorClient *sqlitevec.Client - embedSvc, err := embedding.NewService() - if err != nil { - log.Warn().Err(err).Msg("Embedding service unavailable, vector search disabled") - } else { - defer embedSvc.Close() - vectorClient, err = sqlitevec.NewClient(sqlitevec.Config{DB: store.GetRawDB()}, embedSvc) - if err != nil { - log.Warn().Err(err).Msg("Vector client unavailable, vector search disabled") - } else { - log.Info().Msg("Vector search enabled via sqlite-vec") - } - } - - // Initialize scoring components - scoreConfig := models.DefaultScoringConfig() - scoreCalculator := scoring.NewCalculator(scoreConfig) - recalculator := scoring.NewRecalculator(observationStore, scoreCalculator, log.Logger) - go recalculator.Start(ctx) - defer recalculator.Stop() - - // Initialize search manager - searchMgr := search.NewManager(observationStore, summaryStore, promptStore, vectorClient) - - // Start file watchers - startWatchers(ctx, dbPath) - - // Create and run MCP server with all dependencies - // Note: maintenanceService is nil because it runs in the worker process - server := mcp.NewServer( - searchMgr, - Version, - observationStore, - patternStore, - relationStore, - sessionStore, - vectorClient, - scoreCalculator, - recalculator, - nil, // maintenanceService - handled by worker - ) - log.Info().Str("project", *project).Str("version", Version).Msg("Starting MCP server") + // Create and run MCP server + server := mcp.NewServer(client, workerURL, *project, Version) + log.Info().Str("project", *project).Str("version", Version).Str("worker", workerURL).Msg("Starting MCP server") if err := server.Run(ctx); err != nil { log.Fatal().Err(err).Msg("MCP server error") @@ -143,7 +70,7 @@ func main() { } // startWatchers initializes file watchers for config. -func startWatchers(ctx context.Context, dbPath string) { +func startWatchers() { // Watch config file for changes (triggers process exit for restart) configPath := config.SettingsPath() configWatcher, err := watcher.New(configPath, func() { diff --git a/internal/config/config.go b/internal/config/config.go index fea40fd..f2abde8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,7 @@ import ( "encoding/json" "os" "path/filepath" + "strconv" "strings" "sync" ) @@ -52,6 +53,7 @@ type Config struct { ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` RerankingCandidates int `json:"reranking_candidates"` WorkerPort int `json:"worker_port"` + DeduplicationThreshold float64 `json:"deduplication_threshold"` RerankingMinImprovement float64 `json:"reranking_min_improvement"` ContextObservations int `json:"context_observations"` ContextMaxPromptResults int `json:"context_max_prompt_results"` @@ -64,10 +66,13 @@ type Config struct { HubThreshold int `json:"hub_threshold"` ObservationRetentionDays int `json:"observation_retention_days"` MaintenanceIntervalHours int `json:"maintenance_interval_hours"` + ContextMaxTokensStartup int `json:"context_max_tokens_startup"` + ContextMaxTokensPrompt int `json:"context_max_tokens_prompt"` ContextShowWorkTokens bool `json:"context_show_work_tokens"` ContextShowReadTokens bool `json:"context_show_read_tokens"` RerankingPureMode bool `json:"reranking_pure_mode"` GraphEnabled bool `json:"graph_enabled"` + DeduplicationEnabled bool `json:"deduplication_enabled"` MaintenanceEnabled bool `json:"maintenance_enabled"` RerankingEnabled bool `json:"reranking_enabled"` ContextShowLastSummary bool `json:"context_show_last_summary"` @@ -168,6 +173,10 @@ func Default() *Config { ContextObsConcepts: DefaultObservationConcepts, ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only) + ContextMaxTokensStartup: 16000, // Max tokens for SessionStart context injection + ContextMaxTokensPrompt: 8000, // Max tokens for UserPromptSubmit context injection + DeduplicationEnabled: true, // Enable write-time vector dedup + DeduplicationThreshold: 0.9, // Similarity threshold for merging (0.9 = very similar) MaintenanceEnabled: true, // Enable scheduled maintenance MaintenanceIntervalHours: 6, // Run every 6 hours ObservationRetentionDays: 0, // 0 = no age-based deletion (keep all) @@ -269,6 +278,29 @@ func Load() (*Config, error) { if v, ok := settings["CLAUDE_MNEMONIC_HUB_THRESHOLD"].(float64); ok && v > 0 { cfg.HubThreshold = int(v) } + if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_MAX_TOKENS_STARTUP"].(float64); ok && v > 0 { + cfg.ContextMaxTokensStartup = int(v) + } + if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_MAX_TOKENS_PROMPT"].(float64); ok && v > 0 { + cfg.ContextMaxTokensPrompt = int(v) + } + // Deduplication settings + if v, ok := settings["CLAUDE_MNEMONIC_DEDUP_ENABLED"].(bool); ok { + cfg.DeduplicationEnabled = v + } + if v, ok := settings["CLAUDE_MNEMONIC_DEDUP_THRESHOLD"].(float64); ok && v > 0 && v <= 1 { + cfg.DeduplicationThreshold = v + } + + // Also support env vars for dedup settings + if v := os.Getenv("CLAUDE_MNEMONIC_DEDUP_ENABLED"); v != "" { + cfg.DeduplicationEnabled = v == "true" || v == "1" + } + if v := os.Getenv("CLAUDE_MNEMONIC_DEDUP_THRESHOLD"); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil && f > 0 && f <= 1 { + cfg.DeduplicationThreshold = f + } + } return cfg, nil } diff --git a/internal/db/gorm/pattern_store.go b/internal/db/gorm/pattern_store.go index 82acfc8..8cfcbfd 100644 --- a/internal/db/gorm/pattern_store.go +++ b/internal/db/gorm/pattern_store.go @@ -4,6 +4,8 @@ package gorm import ( "context" "database/sql" + "fmt" + "sync" "time" "gorm.io/gorm" @@ -18,6 +20,7 @@ type PatternCleanupFunc func(ctx context.Context, deletedIDs []int64) type PatternStore struct { db *gorm.DB cleanupFunc PatternCleanupFunc + cleanupMu sync.RWMutex } // NewPatternStore creates a new pattern store. @@ -29,6 +32,8 @@ func NewPatternStore(store *Store) *PatternStore { // SetCleanupFunc sets the callback for when patterns are deleted. func (s *PatternStore) SetCleanupFunc(fn PatternCleanupFunc) { + s.cleanupMu.Lock() + defer s.cleanupMu.Unlock() s.cleanupFunc = fn } @@ -238,6 +243,9 @@ func (s *PatternStore) MarkPatternDeprecated(ctx context.Context, id int64) erro // MergePatterns merges a source pattern into a target pattern. func (s *PatternStore) MergePatterns(ctx context.Context, sourceID, targetID int64) error { + if sourceID == targetID { + return fmt.Errorf("cannot merge pattern into itself") + } // Get both patterns source, err := s.GetPatternByID(ctx, sourceID) if err != nil { @@ -294,8 +302,13 @@ func (s *PatternStore) MergePatterns(ctx context.Context, sourceID, targetID int func (s *PatternStore) DeletePattern(ctx context.Context, id int64) error { result := s.db.WithContext(ctx).Delete(&Pattern{}, id) - if result.Error == nil && s.cleanupFunc != nil { - s.cleanupFunc(ctx, []int64{id}) + if result.Error == nil { + s.cleanupMu.RLock() + fn := s.cleanupFunc + s.cleanupMu.RUnlock() + if fn != nil { + fn(ctx, []int64{id}) + } } return result.Error diff --git a/internal/db/gorm/prompt_store.go b/internal/db/gorm/prompt_store.go index 86519cf..301387c 100644 --- a/internal/db/gorm/prompt_store.go +++ b/internal/db/gorm/prompt_store.go @@ -4,8 +4,10 @@ package gorm import ( "context" "database/sql" + "sync" "time" + "github.com/rs/zerolog/log" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -23,6 +25,7 @@ const MaxPromptsGlobal = 500 type PromptStore struct { db *gorm.DB cleanupFunc PromptCleanupFunc + cleanupMu sync.RWMutex } // NewPromptStore creates a new prompt store. @@ -35,6 +38,8 @@ func NewPromptStore(store *Store, cleanupFunc PromptCleanupFunc) *PromptStore { // SetCleanupFunc sets the callback for when prompts are deleted during cleanup. func (s *PromptStore) SetCleanupFunc(fn PromptCleanupFunc) { + s.cleanupMu.Lock() + defer s.cleanupMu.Unlock() s.cleanupFunc = fn } @@ -81,9 +86,15 @@ func (s *PromptStore) SaveUserPromptWithMatches(ctx context.Context, claudeSessi go func() { cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - deletedIDs, _ := s.CleanupOldPrompts(cleanupCtx) - if len(deletedIDs) > 0 && s.cleanupFunc != nil { - s.cleanupFunc(cleanupCtx, deletedIDs) + if deletedIDs, err := s.CleanupOldPrompts(cleanupCtx); err != nil { + log.Warn().Err(err).Msg("Background prompt cleanup failed") + } else if len(deletedIDs) > 0 { + s.cleanupMu.RLock() + fn := s.cleanupFunc + s.cleanupMu.RUnlock() + if fn != nil { + fn(cleanupCtx, deletedIDs) + } } }() diff --git a/internal/embedding/service.go b/internal/embedding/service.go index 94fd51c..4d0078a 100644 --- a/internal/embedding/service.go +++ b/internal/embedding/service.go @@ -8,6 +8,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "sync" "github.com/sugarme/tokenizer" @@ -69,8 +70,10 @@ func newBGEModel() (EmbeddingModel, error) { libPath := filepath.Join(libDir, onnxRuntimeLibName) ort.SetSharedLibraryPath(libPath) - // Initialize ONNX runtime - if err := ort.InitializeEnvironment(); err != nil { + // Initialize ONNX runtime (idempotent - ignore "already initialized" since + // the ONNX environment is process-global and shared with the reranking service) + err = ort.InitializeEnvironment() + if err != nil && !strings.Contains(err.Error(), "already been initialized") { return nil, fmt.Errorf("initialize ONNX runtime: %w", err) } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 8ba6c70..974e486 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -3,66 +3,44 @@ package mcp import ( "bufio" + "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "net/http" "os" - "strings" + "strconv" + "sync" + "sync/atomic" "time" - "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" - "github.com/lukaszraczylo/claude-mnemonic/internal/maintenance" - "github.com/lukaszraczylo/claude-mnemonic/internal/scoring" - "github.com/lukaszraczylo/claude-mnemonic/internal/search" - "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" - "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/rs/zerolog/log" ) -// Server is the MCP server that exposes search tools. +// Server is the MCP server that proxies tool calls to the worker HTTP API. // Field order optimized for memory alignment (fieldalignment). type Server struct { - stdin io.Reader - stdout io.Writer - searchMgr *search.Manager - observationStore *gorm.ObservationStore - patternStore *gorm.PatternStore - relationStore *gorm.RelationStore - sessionStore *gorm.SessionStore - vectorClient *sqlitevec.Client - scoreCalculator *scoring.Calculator - recalculator *scoring.Recalculator - maintenanceService *maintenance.Service - version string + stdin io.Reader + stdout io.Writer + client *http.Client + workerURL string + project string + version string + writeMu sync.Mutex + lastActivity atomic.Int64 } -// NewServer creates a new MCP server. -func NewServer( - searchMgr *search.Manager, - version string, - observationStore *gorm.ObservationStore, - patternStore *gorm.PatternStore, - relationStore *gorm.RelationStore, - sessionStore *gorm.SessionStore, - vectorClient *sqlitevec.Client, - scoreCalculator *scoring.Calculator, - recalculator *scoring.Recalculator, - maintenanceService *maintenance.Service, -) *Server { +// NewServer creates a new MCP server that proxies to the worker HTTP API. +func NewServer(client *http.Client, workerURL, project, version string) *Server { return &Server{ - searchMgr: searchMgr, - version: version, - stdin: os.Stdin, - stdout: os.Stdout, - observationStore: observationStore, - patternStore: patternStore, - relationStore: relationStore, - sessionStore: sessionStore, - vectorClient: vectorClient, - scoreCalculator: scoreCalculator, - recalculator: recalculator, - maintenanceService: maintenanceService, + client: client, + workerURL: workerURL, + project: project, + version: version, + stdin: os.Stdin, + stdout: os.Stdout, } } @@ -105,51 +83,98 @@ type Tool struct { // Run starts the MCP server loop. func (s *Server) Run(ctx context.Context) error { scanner := bufio.NewScanner(s.stdin) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 1024*1024) // 1MB max message size - // Channel to signal when scanner is done - scanDone := make(chan error, 1) + lines := make(chan string) + scanErr := make(chan error, 1) + + // Track last activity for idle timeout + s.lastActivity.Store(time.Now().Unix()) go func() { + defer close(lines) for scanner.Scan() { - // Check for context cancellation before processing + lines <- scanner.Text() + } + scanErr <- scanner.Err() + }() + + // Monitor parent process liveness and idle timeout. + // If the parent dies (ppid changes) or no messages arrive for 30 minutes, shut down. + parentPID := os.Getppid() + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { select { case <-ctx.Done(): - scanDone <- ctx.Err() return - default: + case <-ticker.C: + if os.Getppid() != parentPID { + log.Info().Msg("Parent process died, shutting down MCP server") + if closer, ok := s.stdin.(io.Closer); ok { + _ = closer.Close() + } + return + } + if time.Since(time.Unix(s.lastActivity.Load(), 0)) > 30*time.Minute { + log.Info().Msg("MCP server idle timeout (30m), shutting down") + if closer, ok := s.stdin.(io.Closer); ok { + _ = closer.Close() + } + return + } + } + } + }() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case line, ok := <-lines: + if !ok { + // Scanner finished, check for errors + err := <-scanErr + if err != nil { + if errors.Is(err, bufio.ErrTooLong) { + log.Error().Msg("MCP message exceeded 1MB buffer limit") + } + return fmt.Errorf("scanner error: %w", err) + } + return nil } - line := scanner.Text() + s.lastActivity.Store(time.Now().Unix()) + if line == "" { continue } var req Request if err := json.Unmarshal([]byte(line), &req); err != nil { - s.sendError(nil, -32700, "Parse error", err) + if werr := s.sendError(nil, -32700, "Parse error", err.Error()); werr != nil { + return fmt.Errorf("write error: %w", werr) + } continue } resp := s.handleRequest(ctx, &req) - s.sendResponse(resp) + if resp != nil { + if werr := s.sendResponse(resp); werr != nil { + return fmt.Errorf("write error: %w", werr) + } + } } - scanDone <- scanner.Err() - }() - - // Wait for either context cancellation or scanner completion - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-scanDone: - if err != nil { - return fmt.Errorf("scanner error: %w", err) - } - return nil } } // handleRequest dispatches the request to the appropriate handler. func (s *Server) handleRequest(ctx context.Context, req *Request) *Response { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + switch req.Method { case "initialize": return s.handleInitialize(req) @@ -157,6 +182,8 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) *Response { return s.handleToolsList(req) case "tools/call": return s.handleToolsCall(ctx, req) + case "notifications/initialized", "notifications/cancelled": + return nil // Notifications don't get responses default: return &Response{ JSONRPC: "2.0", @@ -755,137 +782,323 @@ func (s *Server) handleToolsCall(ctx context.Context, req *Request) *Response { } } -// callTool dispatches to the appropriate tool handler. -func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage) (string, error) { - // Special handlers for non-search tools - switch name { - case "find_related_observations": - return s.handleFindRelatedObservations(ctx, args) - case "find_similar_observations": - return s.handleFindSimilarObservations(ctx, args) - case "get_patterns": - return s.handleGetPatterns(ctx, args) - case "get_memory_stats": - return s.handleGetMemoryStats(ctx) - case "bulk_delete_observations": - return s.handleBulkDeleteObservations(ctx, args) - case "bulk_mark_superseded": - return s.handleBulkMarkSuperseded(ctx, args) - case "bulk_boost_observations": - return s.handleBulkBoostObservations(ctx, args) - case "trigger_maintenance": - return s.handleTriggerMaintenance(ctx) - case "get_maintenance_stats": - return s.handleGetMaintenanceStats(ctx) - case "merge_observations": - return s.handleMergeObservations(ctx, args) - case "get_observation": - return s.handleGetObservation(ctx, args) - case "edit_observation": - return s.handleEditObservation(ctx, args) - case "get_observation_quality": - return s.handleGetObservationQuality(ctx, args) - case "suggest_consolidations": - return s.handleSuggestConsolidations(ctx, args) - case "tag_observation": - return s.handleTagObservation(ctx, args) - case "get_observations_by_tag": - return s.handleGetObservationsByTag(ctx, args) - case "get_temporal_trends": - return s.handleGetTemporalTrends(ctx, args) - case "get_data_quality_report": - return s.handleGetDataQualityReport(ctx, args) - case "batch_tag_by_pattern": - return s.handleBatchTagByPattern(ctx, args) - case "explain_search_ranking": - return s.handleExplainSearchRanking(ctx, args) - case "export_observations": - return s.handleExportObservations(ctx, args) - case "check_system_health": - return s.handleCheckSystemHealth(ctx) - case "analyze_search_patterns": - return s.handleAnalyzeSearchPatterns(ctx, args) - case "get_observation_relationships": - return s.handleGetObservationRelationships(ctx, args) - case "get_observation_scoring_breakdown": - return s.handleGetObservationScoringBreakdown(ctx, args) - case "analyze_observation_importance": - return s.handleAnalyzeObservationImportance(ctx, args) - } - - // Original search-based tools - var params search.SearchParams - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - var result *search.UnifiedSearchResult - var err error - - switch name { - case "search": - result, err = s.searchMgr.UnifiedSearch(ctx, params) - case "timeline": - result, err = s.handleTimeline(ctx, args) - case "decisions": - result, err = s.searchMgr.Decisions(ctx, params) - case "changes": - result, err = s.searchMgr.Changes(ctx, params) - case "how_it_works": - result, err = s.searchMgr.HowItWorks(ctx, params) - case "find_by_concept": - params.Type = "observations" - result, err = s.searchMgr.UnifiedSearch(ctx, params) - case "find_by_file": - params.Type = "observations" - result, err = s.searchMgr.UnifiedSearch(ctx, params) - case "find_by_type": - params.Type = "observations" - result, err = s.searchMgr.UnifiedSearch(ctx, params) - case "get_recent_context": - result, err = s.searchMgr.UnifiedSearch(ctx, params) - case "get_context_timeline": - result, err = s.handleTimeline(ctx, args) - case "get_timeline_by_query": - result, err = s.handleTimelineByQuery(ctx, args) - default: - return "", fmt.Errorf("unknown tool: %s", name) - } - - if err != nil { - return "", err - } - - output, err := json.Marshal(result) - if err != nil { - return "", fmt.Errorf("marshal result: %w", err) - } - - return string(output), nil -} - -// TimelineParams represents parameters for timeline operations. -type TimelineParams struct { +// searchArgs holds common search parameters used by many tools. +type searchArgs struct { Query string `json:"query"` Project string `json:"project"` + Type string `json:"type"` ObsType string `json:"obs_type"` Concepts string `json:"concepts"` Files string `json:"files"` + DateStart any `json:"dateStart"` + DateEnd any `json:"dateEnd"` + OrderBy string `json:"orderBy"` Format string `json:"format"` - AnchorID int64 `json:"anchor_id"` - Before int `json:"before"` - After int `json:"after"` - DateStart int64 `json:"dateStart"` - DateEnd int64 `json:"dateEnd"` + Limit int `json:"limit"` + Offset int `json:"offset"` } -// handleTimeline handles timeline requests. -func (s *Server) handleTimeline(ctx context.Context, args json.RawMessage) (*search.UnifiedSearchResult, error) { - var params TimelineParams - if err := json.Unmarshal(args, ¶ms); err != nil { - return nil, fmt.Errorf("invalid timeline params: %w", err) +// callTool dispatches to the appropriate tool handler by proxying to the worker HTTP API. +func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage) (string, error) { + // Parse common search params used by many tools + var sa searchArgs + // Best-effort parse; individual handlers validate as needed + _ = json.Unmarshal(args, &sa) + + // Default project from server config + if sa.Project == "" { + sa.Project = s.project } + switch name { + // --- Search-based tools: proxy to GET /api/context/search --- + case "search": + return s.handleSearchProxy(ctx, sa) + case "decisions": + sa.ObsType = "decision" + return s.handleSearchProxy(ctx, sa) + case "changes": + sa.ObsType = "code_change" + return s.handleSearchProxy(ctx, sa) + case "how_it_works": + sa.ObsType = "architecture" + return s.handleSearchProxy(ctx, sa) + case "find_by_concept": + return s.handleSearchProxy(ctx, sa) + case "find_by_file": + return s.handleSearchProxy(ctx, sa) + case "find_by_type": + return s.handleSearchProxy(ctx, sa) + case "get_recent_context": + return s.handleSearchProxy(ctx, sa) + case "timeline", "get_context_timeline": + return s.handleTimelineProxy(ctx, args) + case "get_timeline_by_query": + return s.handleTimelineProxy(ctx, args) + + // --- Observation endpoints --- + case "get_observation": + return s.handleGetObservationProxy(ctx, args) + case "edit_observation": + return s.handleEditObservationProxy(ctx, args) + case "find_related_observations": + return s.handleFindRelatedProxy(ctx, args) + case "find_similar_observations": + return s.handleFindSimilarProxy(ctx, args) + case "get_observation_quality": + return s.handleGetObservationQualityProxy(ctx, args) + case "get_observation_relationships": + return s.handleGetRelationshipsProxy(ctx, args) + case "get_observation_scoring_breakdown": + return s.handleGetScoringBreakdownProxy(ctx, args) + case "tag_observation": + return s.handleTagObservationProxy(ctx, args) + case "get_observations_by_tag": + return s.handleGetObservationsByTagProxy(ctx, args) + + // --- Bulk operations --- + case "bulk_delete_observations": + return s.handleBulkStatusProxy(ctx, args, "delete") + case "bulk_mark_superseded": + return s.handleBulkStatusProxy(ctx, args, "supersede") + case "bulk_boost_observations": + return s.handleBulkStatusProxy(ctx, args, "boost") + case "merge_observations": + return s.handleMergeProxy(ctx, args) + + // --- Pattern endpoints --- + case "get_patterns": + return s.handleGetPatternsProxy(ctx, args) + + // --- Stats and analytics endpoints --- + case "get_memory_stats": + return s.proxyGetRaw(ctx, "/api/stats", map[string]string{ + "project": s.project, + }) + case "check_system_health": + return s.proxyGetRaw(ctx, "/api/selfcheck", nil) + case "get_maintenance_stats": + return s.proxyGetRaw(ctx, "/api/stats", map[string]string{ + "project": s.project, + }) + case "trigger_maintenance": + return s.proxyPostRaw(ctx, "/api/scoring/recalculate", nil) + case "analyze_observation_importance": + return s.handleAnalyzeImportanceProxy(ctx, args) + case "analyze_search_patterns": + return s.proxyGetRaw(ctx, "/api/search/analytics", nil) + case "explain_search_ranking": + return s.handleExplainSearchProxy(ctx, args) + case "get_temporal_trends": + return s.handleGetTemporalTrendsProxy(ctx, args) + case "get_data_quality_report": + return s.handleGetDataQualityProxy(ctx, args) + + // --- Export and batch tag --- + case "export_observations": + return s.handleExportProxy(ctx, args) + case "suggest_consolidations": + return s.handleSuggestConsolidationsProxy(ctx, args) + case "batch_tag_by_pattern": + return s.handleBatchTagProxy(ctx, args) + + default: + return "", fmt.Errorf("unknown tool: %s", name) + } +} + +// ============================================================================= +// HTTP PROXY HELPERS +// ============================================================================= + +// proxyGetRaw performs a GET request to the worker and returns the raw JSON response body. +func (s *Server) proxyGetRaw(ctx context.Context, path string, params map[string]string) (string, error) { + if s.client == nil { + return "", fmt.Errorf("worker unavailable at %s: http client not configured", s.workerURL) + } + req, err := http.NewRequestWithContext(ctx, "GET", s.workerURL+path, nil) + if err != nil { + return "", err + } + if len(params) > 0 { + q := req.URL.Query() + for k, v := range params { + if v != "" { + q.Set(k, v) + } + } + req.URL.RawQuery = q.Encode() + } + + resp, err := s.client.Do(req) + if err != nil { + return "", fmt.Errorf("worker unavailable at %s: %w", s.workerURL, err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read worker response: %w", err) + } + + if resp.StatusCode >= 400 { + return "", fmt.Errorf("worker returned %d: %s", resp.StatusCode, string(body)) + } + + return string(body), nil +} + +// proxyPostRaw performs a POST request to the worker and returns the raw JSON response body. +func (s *Server) proxyPostRaw(ctx context.Context, path string, payload any) (string, error) { + if s.client == nil { + return "", fmt.Errorf("worker unavailable at %s: http client not configured", s.workerURL) + } + var bodyReader io.Reader + if payload != nil { + data, err := json.Marshal(payload) + if err != nil { + return "", err + } + bodyReader = bytes.NewReader(data) + } + + req, err := http.NewRequestWithContext(ctx, "POST", s.workerURL+path, bodyReader) + if err != nil { + return "", err + } + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := s.client.Do(req) + if err != nil { + return "", fmt.Errorf("worker unavailable at %s: %w", s.workerURL, err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read worker response: %w", err) + } + + if resp.StatusCode >= 400 { + return "", fmt.Errorf("worker returned %d: %s", resp.StatusCode, string(body)) + } + + return string(body), nil +} + +// proxyPutRaw performs a PUT request to the worker and returns the raw JSON response body. +func (s *Server) proxyPutRaw(ctx context.Context, path string, payload any) (string, error) { + if s.client == nil { + return "", fmt.Errorf("worker unavailable at %s: http client not configured", s.workerURL) + } + var bodyReader io.Reader + if payload != nil { + data, err := json.Marshal(payload) + if err != nil { + return "", err + } + bodyReader = bytes.NewReader(data) + } + + req, err := http.NewRequestWithContext(ctx, "PUT", s.workerURL+path, bodyReader) + if err != nil { + return "", err + } + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := s.client.Do(req) + if err != nil { + return "", fmt.Errorf("worker unavailable at %s: %w", s.workerURL, err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read worker response: %w", err) + } + + if resp.StatusCode >= 400 { + return "", fmt.Errorf("worker returned %d: %s", resp.StatusCode, string(body)) + } + + return string(body), nil +} + +// anyToString converts an interface{} value to its string representation for query params. +func anyToString(v any) string { + if v == nil { + return "" + } + switch val := v.(type) { + case string: + return val + case float64: + return strconv.FormatInt(int64(val), 10) + case int: + return strconv.Itoa(val) + case int64: + return strconv.FormatInt(val, 10) + default: + return fmt.Sprintf("%v", v) + } +} + +// ============================================================================= +// TOOL HANDLER PROXIES +// ============================================================================= + +// handleSearchProxy proxies search requests to GET /api/context/search. +func (s *Server) handleSearchProxy(ctx context.Context, args searchArgs) (string, error) { + params := map[string]string{ + "project": args.Project, + "query": args.Query, + } + if args.Limit > 0 { + params["limit"] = strconv.Itoa(args.Limit) + } + if args.ObsType != "" { + params["obs_type"] = args.ObsType + } + if args.Concepts != "" { + params["concepts"] = args.Concepts + } + if args.Files != "" { + params["files"] = args.Files + } + if args.DateStart != nil { + if ds := anyToString(args.DateStart); ds != "" { + params["dateStart"] = ds + } + } + if args.DateEnd != nil { + if de := anyToString(args.DateEnd); de != "" { + params["dateEnd"] = de + } + } + + return s.proxyGetRaw(ctx, "/api/context/search", params) +} + +// handleTimelineProxy proxies timeline requests. First searches for anchor, then fetches surrounding observations. +func (s *Server) handleTimelineProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + AnchorID int64 `json:"anchor_id"` + Query string `json:"query"` + Project string `json:"project"` + Before int `json:"before"` + After int `json:"after"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid timeline params: %w", err) + } + + if params.Project == "" { + params.Project = s.project + } if params.Before <= 0 { params.Before = 10 } @@ -893,697 +1106,57 @@ func (s *Server) handleTimeline(ctx context.Context, args json.RawMessage) (*sea params.After = 10 } - // If query provided, first find anchor + // If query provided and no anchor, search for it first if params.Query != "" && params.AnchorID == 0 { - searchParams := search.SearchParams{ - Query: params.Query, - Type: "observations", - Project: params.Project, - Limit: 1, - } - result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) + searchResult, err := s.proxyGetRaw(ctx, "/api/context/search", map[string]string{ + "project": params.Project, + "query": params.Query, + "limit": "1", + }) if err != nil { - return nil, err + return "", err } - if len(result.Results) > 0 { - params.AnchorID = result.Results[0].ID + // Extract first observation ID from search results + var searchResp struct { + Observations []struct { + ID int64 `json:"id"` + } `json:"observations"` + } + if err := json.Unmarshal([]byte(searchResult), &searchResp); err == nil && len(searchResp.Observations) > 0 { + params.AnchorID = searchResp.Observations[0].ID } } if params.AnchorID == 0 { - return &search.UnifiedSearchResult{Results: []search.SearchResult{}}, nil + result, _ := json.Marshal(map[string]any{"observations": []any{}, "message": "no anchor found"}) + return string(result), nil } - // Fetch observations around anchor - searchParams := search.SearchParams{ - Type: "observations", - Project: params.Project, - ObsType: params.ObsType, - Concepts: params.Concepts, - Files: params.Files, - Limit: params.Before + params.After + 1, - Format: params.Format, - } - - return s.searchMgr.UnifiedSearch(ctx, searchParams) + // Get observations around the anchor + limit := params.Before + params.After + 1 + return s.proxyGetRaw(ctx, "/api/observations", map[string]string{ + "project": params.Project, + "limit": strconv.Itoa(limit), + }) } -// handleTimelineByQuery handles combined search + timeline requests. -func (s *Server) handleTimelineByQuery(ctx context.Context, args json.RawMessage) (*search.UnifiedSearchResult, error) { - var params TimelineParams - if err := json.Unmarshal(args, ¶ms); err != nil { - return nil, fmt.Errorf("invalid timeline params: %w", err) - } - - if params.Query == "" { - return nil, fmt.Errorf("query is required") - } - - // First search - searchParams := search.SearchParams{ - Query: params.Query, - Type: "observations", - Project: params.Project, - DateStart: params.DateStart, - DateEnd: params.DateEnd, - Limit: 1, - } - - result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) - if err != nil { - return nil, err - } - - if len(result.Results) == 0 { - return result, nil - } - - // Now get timeline around that result - params.AnchorID = result.Results[0].ID - return s.handleTimeline(ctx, args) -} - -// handleFindRelatedObservations finds observations related to a given observation ID. -func (s *Server) handleFindRelatedObservations(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - ID int64 `json:"id"` - MinConfidence float64 `json:"min_confidence"` - Limit int `json:"limit"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if params.ID == 0 { - return "", fmt.Errorf("id is required") - } - - // Use -1 as sentinel for "not provided" since 0.0 is a valid threshold - if params.MinConfidence < 0 { - params.MinConfidence = 0.5 - } - - if params.Limit == 0 { - params.Limit = 20 - } - if params.Limit > 100 { - params.Limit = 100 - } - - // Get related observation IDs with confidence filter - relatedIDs, err := s.relationStore.GetRelatedObservationIDs(ctx, params.ID, params.MinConfidence) - if err != nil { - return "", fmt.Errorf("failed to get related observations: %w", err) - } - - if relatedIDs == nil { - relatedIDs = []int64{} - } - - // Limit results - if len(relatedIDs) > params.Limit { - relatedIDs = relatedIDs[:params.Limit] - } - - // Fetch full observations in batch (avoids N+1 query problem) - observations, err := s.observationStore.GetObservationsByIDsPreserveOrder(ctx, relatedIDs) - if err != nil { - log.Warn().Err(err).Msg("Failed to batch fetch related observations, falling back to individual fetch") - // Fallback to individual fetch if batch fails - observations = make([]*models.Observation, 0, len(relatedIDs)) - for _, id := range relatedIDs { - obs, fetchErr := s.observationStore.GetObservationByID(ctx, id) - if fetchErr == nil && obs != nil { - observations = append(observations, obs) - } - } - } - - response := map[string]any{ - "observations": observations, - "count": len(observations), - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// sendResponse sends a JSON-RPC response. -func (s *Server) sendResponse(resp *Response) { - data, err := json.Marshal(resp) - if err != nil { - log.Error().Err(err).Msg("Failed to marshal response") - return - } - fmt.Fprintln(s.stdout, string(data)) -} - -// sendError sends a JSON-RPC error response. -func (s *Server) sendError(id any, code int, message string, data any) { - resp := &Response{ - JSONRPC: "2.0", - ID: id, - Error: &Error{ - Code: code, - Message: message, - Data: data, - }, - } - s.sendResponse(resp) -} - -// handleFindSimilarObservations finds observations semantically similar to a query. -func (s *Server) handleFindSimilarObservations(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - Query string `json:"query"` - Project string `json:"project"` - MinSimilarity float64 `json:"min_similarity"` - Limit int `json:"limit"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if params.Query == "" { - return "", fmt.Errorf("query is required") - } - - if params.MinSimilarity == 0 { - params.MinSimilarity = 0.7 - } - if params.Limit == 0 { - params.Limit = 10 - } - if params.Limit > 50 { - params.Limit = 50 - } - - // Use vector search to find similar observations - if s.vectorClient == nil { - return "", fmt.Errorf("vector search not available") - } - - where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, params.Project) - results, err := s.vectorClient.Query(ctx, params.Query, params.Limit*2, where) - if err != nil { - return "", fmt.Errorf("vector search failed: %w", err) - } - - // Filter by similarity threshold - filtered := sqlitevec.FilterByThreshold(results, params.MinSimilarity, params.Limit) - - // Extract observation IDs and build similarity map - obsIDs := sqlitevec.ExtractObservationIDs(filtered, params.Project) - similarityMap := make(map[int64]float64, len(filtered)) - for _, r := range filtered { - if sqliteID, ok := r.Metadata["sqlite_id"].(float64); ok { - id := int64(sqliteID) - if _, exists := similarityMap[id]; !exists { - similarityMap[id] = r.Similarity - } - } - } - - // Fetch full observations in batch (avoids N+1 query problem) - observations, err := s.observationStore.GetObservationsByIDsPreserveOrder(ctx, obsIDs) - if err != nil { - log.Warn().Err(err).Msg("Failed to batch fetch similar observations, falling back to individual fetch") - observations = make([]*models.Observation, 0, len(obsIDs)) - for _, id := range obsIDs { - obs, fetchErr := s.observationStore.GetObservationByID(ctx, id) - if fetchErr == nil && obs != nil { - observations = append(observations, obs) - } - } - } - - // Build response with similarity scores - type SimilarObservation struct { - *models.Observation - Similarity float64 `json:"similarity"` - } - - similarObs := make([]SimilarObservation, 0, len(observations)) - for _, obs := range observations { - sim := similarityMap[obs.ID] - similarObs = append(similarObs, SimilarObservation{ - Observation: obs, - Similarity: sim, - }) - } - - response := map[string]any{ - "observations": similarObs, - "count": len(similarObs), - "min_similarity": params.MinSimilarity, - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// handleGetPatterns returns patterns from the pattern store. -func (s *Server) handleGetPatterns(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - Type string `json:"type"` - Project string `json:"project"` - Query string `json:"query"` - Limit int `json:"limit"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if params.Limit == 0 { - params.Limit = 20 - } - if params.Limit > 100 { - params.Limit = 100 - } - - var patterns []*models.Pattern - var err error - - // Query patterns based on filters - if params.Query != "" { - // FTS search - patterns, err = s.patternStore.SearchPatternsFTS(ctx, params.Query, params.Limit) - } else if params.Type != "" { - // Filter by type - patterns, err = s.patternStore.GetPatternsByType(ctx, models.PatternType(params.Type), params.Limit) - } else if params.Project != "" { - // Filter by project - patterns, err = s.patternStore.GetPatternsByProject(ctx, params.Project, params.Limit) - } else { - // Get all active patterns - patterns, err = s.patternStore.GetActivePatterns(ctx, params.Limit) - } - - if err != nil { - return "", fmt.Errorf("failed to get patterns: %w", err) - } - - response := map[string]any{ - "patterns": patterns, - "count": len(patterns), - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// handleGetMemoryStats returns statistics about the memory system. -func (s *Server) handleGetMemoryStats(ctx context.Context) (string, error) { - stats := make(map[string]any, 8) // Pre-allocate for expected stats keys - - // Get vector count - if s.vectorClient != nil { - count, err := s.vectorClient.Count(ctx) - if err == nil { - stats["vector_count"] = count - } - - // Cache stats - cacheSize, cacheMax := s.vectorClient.CacheStats() - stats["embedding_cache"] = map[string]any{ - "size": cacheSize, - "max_size": cacheMax, - } - - // Model version - stats["embedding_model"] = s.vectorClient.ModelVersion() - } - - // Get pattern stats - if s.patternStore != nil { - patternStats, err := s.patternStore.GetPatternStats(ctx) - if err == nil && patternStats != nil { - stats["patterns"] = map[string]any{ - "total": patternStats.Total, - "active": patternStats.Active, - "deprecated": patternStats.Deprecated, - "merged": patternStats.Merged, - "total_occurrences": patternStats.TotalOccurrences, - "avg_confidence": patternStats.AvgConfidence, - } - } - } - - // Get search metrics - if s.searchMgr != nil { - searchMetrics := s.searchMgr.Metrics() - if searchMetrics != nil { - stats["search"] = searchMetrics.GetStats() - } - } - - output, err := json.Marshal(stats) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// handleBulkDeleteObservations deletes multiple observations by ID. -func (s *Server) handleBulkDeleteObservations(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - IDs []int64 `json:"ids"` - DeleteVectors bool `json:"delete_vectors"` - } - params.DeleteVectors = true // default - - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if len(params.IDs) == 0 { - return "", fmt.Errorf("ids is required") - } - - if len(params.IDs) > 1000 { - return "", fmt.Errorf("maximum 1000 IDs per request") - } - - var deleted int64 - var errors []string - - // Delete in batches - batchSize := 100 - for i := 0; i < len(params.IDs); i += batchSize { - end := min(i+batchSize, len(params.IDs)) - batch := params.IDs[i:end] - - for _, id := range batch { - if err := s.observationStore.DeleteObservation(ctx, id); err != nil { - errors = append(errors, fmt.Sprintf("id %d: %v", id, err)) - continue - } - deleted++ - - // Delete associated vectors if requested - if params.DeleteVectors && s.vectorClient != nil { - _ = s.vectorClient.DeleteByObservationID(ctx, id) - } - } - } - - response := map[string]any{ - "deleted": deleted, - "total": len(params.IDs), - } - if len(errors) > 0 { - response["errors"] = errors - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - // Return error if all deletions failed (complete failure) - if deleted == 0 && len(errors) > 0 { - return string(output), fmt.Errorf("bulk delete failed: %d errors, first: %s", len(errors), errors[0]) - } - - return string(output), nil -} - -// handleBulkMarkSuperseded marks multiple observations as superseded. -func (s *Server) handleBulkMarkSuperseded(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - IDs []int64 `json:"ids"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if len(params.IDs) == 0 { - return "", fmt.Errorf("ids is required") - } - - if len(params.IDs) > 1000 { - return "", fmt.Errorf("maximum 1000 IDs per request") - } - - // Use batch update for efficiency (single query instead of N queries) - updated, err := s.observationStore.MarkAsSupersededBatch(ctx, params.IDs) - if err != nil { - return "", fmt.Errorf("batch mark as superseded: %w", err) - } - - response := map[string]any{ - "updated": updated, - "total": len(params.IDs), - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// handleBulkBoostObservations boosts the importance score of multiple observations. -func (s *Server) handleBulkBoostObservations(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - IDs []int64 `json:"ids"` - Boost float64 `json:"boost"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if len(params.IDs) == 0 { - return "", fmt.Errorf("ids is required") - } - - if len(params.IDs) > 1000 { - return "", fmt.Errorf("maximum 1000 IDs per request") - } - - if params.Boost < -1.0 || params.Boost > 1.0 { - return "", fmt.Errorf("boost must be between -1.0 and 1.0") - } - - var boosted int64 - var errors []string - - // Batch fetch all observations in one query instead of N queries - observations, err := s.observationStore.GetObservationsByIDs(ctx, params.IDs, "", 0) - if err != nil { - return "", fmt.Errorf("batch fetch observations: %w", err) - } - - // Build a map for O(1) lookup - obsMap := make(map[int64]*models.Observation, len(observations)) - for _, obs := range observations { - obsMap[obs.ID] = obs - } - - // Calculate new scores and prepare batch update - scoresToUpdate := make(map[int64]float64, len(params.IDs)) - for _, id := range params.IDs { - obs, found := obsMap[id] - if !found { - errors = append(errors, fmt.Sprintf("id %d: not found", id)) - continue - } - - // Calculate new importance score (clamp between 0 and 1) - newScore := obs.ImportanceScore + params.Boost - if newScore < 0 { - newScore = 0 - } - if newScore > 1 { - newScore = 1 - } - scoresToUpdate[id] = newScore - } - - // Batch update all scores in one operation - if len(scoresToUpdate) > 0 { - if err := s.observationStore.UpdateImportanceScores(ctx, scoresToUpdate); err != nil { - return "", fmt.Errorf("batch update scores: %w", err) - } - boosted = int64(len(scoresToUpdate)) - } - - response := map[string]any{ - "boosted": boosted, - "total": len(params.IDs), - "boost_used": params.Boost, - } - if len(errors) > 0 { - response["errors"] = errors - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// handleTriggerMaintenance triggers an immediate maintenance run. -func (s *Server) handleTriggerMaintenance(ctx context.Context) (string, error) { - if s.maintenanceService == nil { - return "", fmt.Errorf("maintenance service not available") - } - - s.maintenanceService.RunNow(ctx) - - response := map[string]any{ - "status": "triggered", - "message": "Maintenance run started in background", - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// handleGetMaintenanceStats returns maintenance statistics. -func (s *Server) handleGetMaintenanceStats(_ context.Context) (string, error) { - if s.maintenanceService == nil { - return "", fmt.Errorf("maintenance service not available") - } - - stats := s.maintenanceService.Stats() - - output, err := json.Marshal(stats) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// handleMergeObservations merges two observations, keeping the target and superseding the source. -func (s *Server) handleMergeObservations(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - SourceID int64 `json:"source_id"` - TargetID int64 `json:"target_id"` - Boost float64 `json:"boost"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if params.SourceID == 0 || params.TargetID == 0 { - return "", fmt.Errorf("source_id and target_id are required") - } - - if params.SourceID == params.TargetID { - return "", fmt.Errorf("source_id and target_id cannot be the same") - } - - // Set default boost if not provided - if params.Boost == 0 { - params.Boost = 0.1 - } - if params.Boost < 0 || params.Boost > 0.5 { - return "", fmt.Errorf("boost must be between 0 and 0.5") - } - - // Get both observations to verify they exist - source, err := s.observationStore.GetObservationByID(ctx, params.SourceID) - if err != nil { - return "", fmt.Errorf("get source observation: %w", err) - } - if source == nil { - return "", fmt.Errorf("source observation %d not found", params.SourceID) - } - - target, err := s.observationStore.GetObservationByID(ctx, params.TargetID) - if err != nil { - return "", fmt.Errorf("get target observation: %w", err) - } - if target == nil { - return "", fmt.Errorf("target observation %d not found", params.TargetID) - } - - // Mark source as superseded - if err := s.observationStore.MarkAsSuperseded(ctx, params.SourceID); err != nil { - return "", fmt.Errorf("mark source as superseded: %w", err) - } - - // Boost target's importance score - newScore := target.ImportanceScore + params.Boost - if newScore > 1.0 { - newScore = 1.0 - } - if err := s.observationStore.UpdateImportanceScore(ctx, params.TargetID, newScore); err != nil { - return "", fmt.Errorf("update target score: %w", err) - } - - response := map[string]any{ - "merged": true, - "source_id": params.SourceID, - "source_title": source.Title.String, - "target_id": params.TargetID, - "target_title": target.Title.String, - "target_new_score": newScore, - "target_old_score": target.ImportanceScore, - "boost_applied": params.Boost, - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// handleGetObservation returns a single observation by ID. -func (s *Server) handleGetObservation(ctx context.Context, args json.RawMessage) (string, error) { +// handleGetObservationProxy proxies get observation by ID. +func (s *Server) handleGetObservationProxy(ctx context.Context, args json.RawMessage) (string, error) { var params struct { ID int64 `json:"id"` } if err := json.Unmarshal(args, ¶ms); err != nil { return "", fmt.Errorf("invalid arguments: %w", err) } - if params.ID == 0 { return "", fmt.Errorf("id is required") } - obs, err := s.observationStore.GetObservationByID(ctx, params.ID) - if err != nil { - return "", fmt.Errorf("get observation: %w", err) - } - if obs == nil { - return "", fmt.Errorf("observation %d not found", params.ID) - } - - output, err := json.Marshal(obs) - if err != nil { - return "", fmt.Errorf("marshal observation: %w", err) - } - - return string(output), nil + return s.proxyGetRaw(ctx, fmt.Sprintf("/api/observations/%d", params.ID), nil) } -// handleEditObservation updates an existing observation with provided fields. -func (s *Server) handleEditObservation(ctx context.Context, args json.RawMessage) (string, error) { +// handleEditObservationProxy proxies observation edits. +func (s *Server) handleEditObservationProxy(ctx context.Context, args json.RawMessage) (string, error) { var params struct { Title *string `json:"title,omitempty"` Subtitle *string `json:"subtitle,omitempty"` @@ -1598,110 +1171,122 @@ func (s *Server) handleEditObservation(ctx context.Context, args json.RawMessage if err := json.Unmarshal(args, ¶ms); err != nil { return "", fmt.Errorf("invalid arguments: %w", err) } - if params.ID == 0 { return "", fmt.Errorf("id is required") } - // Validate scope if provided - if params.Scope != nil && *params.Scope != "project" && *params.Scope != "global" { - return "", fmt.Errorf("scope must be 'project' or 'global'") - } - - // Build update struct - update := &gorm.ObservationUpdate{} - if params.Title != nil { - update.Title = params.Title - } - if params.Subtitle != nil { - update.Subtitle = params.Subtitle - } - if params.Narrative != nil { - update.Narrative = params.Narrative - } - if params.Facts != nil { - update.Facts = ¶ms.Facts - } - if params.Concepts != nil { - update.Concepts = ¶ms.Concepts - } - if params.FilesRead != nil { - update.FilesRead = ¶ms.FilesRead - } - if params.FilesModified != nil { - update.FilesModified = ¶ms.FilesModified - } - if params.Scope != nil { - update.Scope = params.Scope - } - - // Update the observation - updatedObs, err := s.observationStore.UpdateObservation(ctx, params.ID, update) - if err != nil { - return "", fmt.Errorf("update observation: %w", err) - } - - // Note: Vector resync is handled by the worker service when available - // The MCP server doesn't have access to the embedding service - - response := map[string]any{ - "updated": true, - "observation": updatedObs, - "vector_resync": "deferred", - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil + return s.proxyPutRaw(ctx, fmt.Sprintf("/api/observations/%d", params.ID), params) } -// handleGetObservationQuality returns quality metrics for an observation. -func (s *Server) handleGetObservationQuality(ctx context.Context, args json.RawMessage) (string, error) { +// handleFindRelatedProxy proxies find related observations. +func (s *Server) handleFindRelatedProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + ID int64 `json:"id"` + MinConfidence float64 `json:"min_confidence"` + Limit int `json:"limit"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + if params.ID == 0 { + return "", fmt.Errorf("id is required") + } + + qp := map[string]string{} + if params.MinConfidence > 0 { + qp["min_confidence"] = strconv.FormatFloat(params.MinConfidence, 'f', -1, 64) + } + if params.Limit > 0 { + qp["limit"] = strconv.Itoa(params.Limit) + } + + return s.proxyGetRaw(ctx, fmt.Sprintf("/api/observations/%d/related", params.ID), qp) +} + +// handleFindSimilarProxy proxies find similar observations via context search. +func (s *Server) handleFindSimilarProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Query string `json:"query"` + Project string `json:"project"` + MinSimilarity float64 `json:"min_similarity"` + Limit int `json:"limit"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + if params.Query == "" { + return "", fmt.Errorf("query is required") + } + if params.Project == "" { + params.Project = s.project + } + if params.Limit == 0 { + params.Limit = 10 + } + + // Use context search which uses vector similarity + return s.proxyGetRaw(ctx, "/api/context/search", map[string]string{ + "project": params.Project, + "query": params.Query, + "limit": strconv.Itoa(params.Limit), + }) +} + +// handleGetObservationQualityProxy proxies observation quality check. +// Fetches the observation and computes quality metrics locally (lightweight computation). +func (s *Server) handleGetObservationQualityProxy(ctx context.Context, args json.RawMessage) (string, error) { var params struct { ID int64 `json:"id"` } if err := json.Unmarshal(args, ¶ms); err != nil { return "", fmt.Errorf("invalid arguments: %w", err) } - if params.ID == 0 { return "", fmt.Errorf("id is required") } - obs, err := s.observationStore.GetObservationByID(ctx, params.ID) + // Fetch the observation + obsJSON, err := s.proxyGetRaw(ctx, fmt.Sprintf("/api/observations/%d", params.ID), nil) if err != nil { - return "", fmt.Errorf("get observation: %w", err) + return "", err } - if obs == nil { - return "", fmt.Errorf("observation %d not found", params.ID) + + // Parse observation for quality analysis + var obs struct { + Title string `json:"title"` + Narrative string `json:"narrative"` + Facts []string `json:"facts"` + Concepts []string `json:"concepts"` + FilesRead []string `json:"files_read"` + FilesModified []string `json:"files_modified"` + ImportanceScore float64 `json:"importance_score"` + RetrievalCount int `json:"retrieval_count"` + IsSuperseded bool `json:"is_superseded"` + } + if err := json.Unmarshal([]byte(obsJSON), &obs); err != nil { + return "", fmt.Errorf("parse observation: %w", err) } // Calculate completeness score completenessScore := 0.0 maxScore := 5.0 - suggestions := []string{} + var suggestions []string - // Check title (required, 1 point) - if obs.Title.Valid && obs.Title.String != "" { + if obs.Title != "" { completenessScore += 1.0 } else { suggestions = append(suggestions, "Add a descriptive title") } - // Check narrative (important, 1.5 points) - if obs.Narrative.Valid && len(obs.Narrative.String) > 50 { + if len(obs.Narrative) > 50 { completenessScore += 1.5 - } else if obs.Narrative.Valid && obs.Narrative.String != "" { + } else if obs.Narrative != "" { completenessScore += 0.5 suggestions = append(suggestions, "Expand the narrative to provide more context (aim for 50+ characters)") } else { suggestions = append(suggestions, "Add a narrative explaining the observation") } - // Check facts (valuable, 1 point) if len(obs.Facts) >= 2 { completenessScore += 1.0 } else if len(obs.Facts) == 1 { @@ -1711,7 +1296,6 @@ func (s *Server) handleGetObservationQuality(ctx context.Context, args json.RawM suggestions = append(suggestions, "Add key facts to capture important details") } - // Check concepts (useful, 0.75 points) if len(obs.Concepts) >= 2 { completenessScore += 0.75 } else if len(obs.Concepts) == 1 { @@ -1721,14 +1305,12 @@ func (s *Server) handleGetObservationQuality(ctx context.Context, args json.RawM suggestions = append(suggestions, "Add concept tags to categorize this observation") } - // Check file references (helpful, 0.75 points) if len(obs.FilesRead) > 0 || len(obs.FilesModified) > 0 { completenessScore += 0.75 } else { suggestions = append(suggestions, "Consider adding file references if applicable") } - // Determine quality tier qualityTier := "poor" switch { case completenessScore >= 4.0: @@ -1750,9 +1332,9 @@ func (s *Server) handleGetObservationQuality(ctx context.Context, args json.RawM "is_superseded": obs.IsSuperseded, "suggestions": suggestions, "field_stats": map[string]any{ - "has_title": obs.Title.Valid && obs.Title.String != "", - "has_narrative": obs.Narrative.Valid && obs.Narrative.String != "", - "narrative_length": len(obs.Narrative.String), + "has_title": obs.Title != "", + "has_narrative": obs.Narrative != "", + "narrative_length": len(obs.Narrative), "facts_count": len(obs.Facts), "concepts_count": len(obs.Concepts), "files_read_count": len(obs.FilesRead), @@ -1768,185 +1350,77 @@ func (s *Server) handleGetObservationQuality(ctx context.Context, args json.RawM return string(output), nil } -// handleSuggestConsolidations finds observations that could be merged. -func (s *Server) handleSuggestConsolidations(ctx context.Context, args json.RawMessage) (string, error) { +// handleGetRelationshipsProxy proxies observation relationship graph requests. +func (s *Server) handleGetRelationshipsProxy(ctx context.Context, args json.RawMessage) (string, error) { var params struct { - Project string `json:"project"` - MinSimilarity float64 `json:"min_similarity"` - Limit int `json:"limit"` + ID int64 `json:"id"` + MaxDepth int `json:"max_depth"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid params: %w", err) + } + if params.ID <= 0 { + return "", fmt.Errorf("id is required and must be positive") + } + + qp := map[string]string{} + if params.MaxDepth > 0 { + qp["max_depth"] = strconv.Itoa(params.MaxDepth) + } + + return s.proxyGetRaw(ctx, fmt.Sprintf("/api/observations/%d/graph", params.ID), qp) +} + +// handleGetScoringBreakdownProxy proxies observation scoring breakdown requests. +func (s *Server) handleGetScoringBreakdownProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + ID int64 `json:"id"` } if err := json.Unmarshal(args, ¶ms); err != nil { return "", fmt.Errorf("invalid arguments: %w", err) } - - // Set defaults - if params.MinSimilarity == 0 { - params.MinSimilarity = 0.8 - } - if params.Limit == 0 { - params.Limit = 10 - } - if params.MinSimilarity < 0.5 || params.MinSimilarity > 1.0 { - return "", fmt.Errorf("min_similarity must be between 0.5 and 1.0") + if params.ID <= 0 { + return "", fmt.Errorf("id is required and must be positive") } - // Get recent observations to analyze - obs, err := s.observationStore.GetRecentObservations(ctx, params.Project, 200) - if err != nil { - return "", fmt.Errorf("get observations: %w", err) - } - - if len(obs) < 2 { - response := map[string]any{ - "groups": []any{}, - "message": "Not enough observations to analyze", - } - output, _ := json.Marshal(response) - return string(output), nil - } - - // Find similar pairs using vector search if available - type consolidationGroup struct { - Primary *models.Observation `json:"primary"` - Reason string `json:"reason"` - Similar []*models.Observation `json:"similar"` - Similarity float64 `json:"avg_similarity"` - } - - groups := []consolidationGroup{} - seen := make(map[int64]bool) - - // For each observation, find similar ones - for _, primary := range obs { - if seen[primary.ID] { - continue - } - - // Build search text from observation - searchText := primary.Title.String - if primary.Narrative.Valid { - searchText += " " + primary.Narrative.String - } - - if searchText == "" || s.vectorClient == nil { - continue - } - - // Query for similar observations - where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, params.Project) - results, err := s.vectorClient.Query(ctx, searchText, 10, where) - if err != nil { - continue - } - - // Find similar observations above threshold - similar := []*models.Observation{} - totalSimilarity := 0.0 - - for _, r := range results { - // Extract observation ID from metadata - sqliteID, ok := r.Metadata["sqlite_id"].(float64) - if !ok { - continue - } - obsID := int64(sqliteID) - - if obsID == primary.ID || seen[obsID] { - continue - } - if r.Similarity >= params.MinSimilarity { - // Fetch the similar observation - simObs, err := s.observationStore.GetObservationByID(ctx, obsID) - if err != nil || simObs == nil { - continue - } - similar = append(similar, simObs) - totalSimilarity += r.Similarity - seen[obsID] = true - } - } - - if len(similar) > 0 { - seen[primary.ID] = true - avgSimilarity := totalSimilarity / float64(len(similar)) - - // Determine consolidation reason - reason := "Content similarity detected" - if len(primary.Concepts) > 0 && len(similar) > 0 { - // Check for concept overlap - conceptMap := make(map[string]bool) - for _, c := range primary.Concepts { - conceptMap[c] = true - } - for _, sim := range similar { - for _, c := range sim.Concepts { - if conceptMap[c] { - reason = "Similar content with shared concepts" - break - } - } - } - } - - groups = append(groups, consolidationGroup{ - Primary: primary, - Similar: similar, - Similarity: avgSimilarity, - Reason: reason, - }) - - if len(groups) >= params.Limit { - break - } - } - } - - response := map[string]any{ - "groups": groups, - "total_analyzed": len(obs), - "groups_found": len(groups), - "min_similarity": params.MinSimilarity, - "recommendation": "Review each group and use merge_observations to consolidate where appropriate", - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil + return s.proxyGetRaw(ctx, fmt.Sprintf("/api/observations/%d/score", params.ID), nil) } -// handleTagObservation adds, removes, or sets tags on an observation. -func (s *Server) handleTagObservation(ctx context.Context, args json.RawMessage) (string, error) { +// handleTagObservationProxy proxies tag operations on observations. +// Fetches current tags, computes new tag set, then updates via PUT. +func (s *Server) handleTagObservationProxy(ctx context.Context, args json.RawMessage) (string, error) { var params struct { Mode string `json:"mode"` Tags []string `json:"tags"` ID int64 `json:"id"` } - params.Mode = "add" // default - if err := json.Unmarshal(args, ¶ms); err != nil { return "", fmt.Errorf("invalid arguments: %w", err) } - if params.ID == 0 { return "", fmt.Errorf("id is required") } if len(params.Tags) == 0 { return "", fmt.Errorf("tags is required") } + if params.Mode == "" { + params.Mode = "add" + } if params.Mode != "add" && params.Mode != "remove" && params.Mode != "set" { return "", fmt.Errorf("mode must be 'add', 'remove', or 'set'") } - // Get current observation - obs, err := s.observationStore.GetObservationByID(ctx, params.ID) + // Fetch current observation to compute new tags + obsJSON, err := s.proxyGetRaw(ctx, fmt.Sprintf("/api/observations/%d", params.ID), nil) if err != nil { return "", fmt.Errorf("get observation: %w", err) } - if obs == nil { - return "", fmt.Errorf("observation %d not found", params.ID) + + var obs struct { + Concepts []string `json:"concepts"` + } + if err := json.Unmarshal([]byte(obsJSON), &obs); err != nil { + return "", fmt.Errorf("parse observation: %w", err) } // Compute new tags @@ -1955,7 +1429,6 @@ func (s *Server) handleTagObservation(ctx context.Context, args json.RawMessage) case "set": newTags = params.Tags case "add": - // Add new tags, avoiding duplicates tagSet := make(map[string]bool) for _, t := range obs.Concepts { tagSet[t] = true @@ -1963,12 +1436,10 @@ func (s *Server) handleTagObservation(ctx context.Context, args json.RawMessage) } for _, t := range params.Tags { if !tagSet[t] { - tagSet[t] = true newTags = append(newTags, t) } } case "remove": - // Remove specified tags removeSet := make(map[string]bool) for _, t := range params.Tags { removeSet[t] = true @@ -1980,22 +1451,22 @@ func (s *Server) handleTagObservation(ctx context.Context, args json.RawMessage) } } - // Update using existing UpdateObservation method - update := &gorm.ObservationUpdate{ - Concepts: &newTags, + // Update via PUT + updatePayload := map[string]any{ + "concepts": newTags, } - updatedObs, err := s.observationStore.UpdateObservation(ctx, params.ID, update) + result, err := s.proxyPutRaw(ctx, fmt.Sprintf("/api/observations/%d", params.ID), updatePayload) if err != nil { return "", fmt.Errorf("update observation: %w", err) } + // Wrap response response := map[string]any{ "id": params.ID, "mode": params.Mode, "tags_applied": params.Tags, - "current_tags": updatedObs.Concepts, + "result": json.RawMessage(result), } - output, err := json.Marshal(response) if err != nil { return "", fmt.Errorf("marshal response: %w", err) @@ -2004,63 +1475,99 @@ func (s *Server) handleTagObservation(ctx context.Context, args json.RawMessage) return string(output), nil } -// handleGetObservationsByTag retrieves observations with a specific concept tag. -func (s *Server) handleGetObservationsByTag(ctx context.Context, args json.RawMessage) (string, error) { +// handleGetObservationsByTagProxy proxies tag-based observation lookup. +func (s *Server) handleGetObservationsByTagProxy(ctx context.Context, args json.RawMessage) (string, error) { var params struct { Tag string `json:"tag"` Project string `json:"project"` Limit int `json:"limit"` } - params.Limit = 50 // default - if err := json.Unmarshal(args, ¶ms); err != nil { return "", fmt.Errorf("invalid arguments: %w", err) } - if params.Tag == "" { return "", fmt.Errorf("tag is required") } - if params.Limit < 1 || params.Limit > 200 { + if params.Project == "" { + params.Project = s.project + } + if params.Limit == 0 { params.Limit = 50 } - // Use search with concept filter - searchParams := search.SearchParams{ - Query: params.Tag, - Type: "observations", - Project: params.Project, - Limit: params.Limit, - Concepts: params.Tag, + return s.proxyGetRaw(ctx, "/api/context/search", map[string]string{ + "project": params.Project, + "query": params.Tag, + "concepts": params.Tag, + "limit": strconv.Itoa(params.Limit), + }) +} + +// handleBulkStatusProxy proxies bulk status update operations. +func (s *Server) handleBulkStatusProxy(ctx context.Context, args json.RawMessage, action string) (string, error) { + var params struct { + IDs []int64 `json:"ids"` + Boost float64 `json:"boost"` + DeleteVectors bool `json:"delete_vectors"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + if len(params.IDs) == 0 { + return "", fmt.Errorf("ids is required") } - result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) + payload := map[string]any{ + "ids": params.IDs, + "action": action, + } + if action == "boost" { + payload["boost"] = params.Boost + } + if action == "delete" { + payload["delete_vectors"] = params.DeleteVectors + } + + return s.proxyPostRaw(ctx, "/api/observations/bulk-status", payload) +} + +// handleMergeProxy proxies merge observations request. +func (s *Server) handleMergeProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + SourceID int64 `json:"source_id"` + TargetID int64 `json:"target_id"` + Boost float64 `json:"boost"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + if params.SourceID == 0 || params.TargetID == 0 { + return "", fmt.Errorf("source_id and target_id are required") + } + + // Merge = mark source as superseded + boost target + // Step 1: Mark source as superseded + _, err := s.proxyPostRaw(ctx, "/api/observations/bulk-status", map[string]any{ + "ids": []int64{params.SourceID}, + "action": "supersede", + }) if err != nil { - return "", fmt.Errorf("search: %w", err) + return "", fmt.Errorf("mark source as superseded: %w", err) } - // Filter results to only include observations with the exact tag in metadata - var filtered []search.SearchResult - for _, r := range result.Results { - if r.Type != "observation" { - continue - } - // Check if concepts metadata contains the tag - if concepts, ok := r.Metadata["concepts"].([]any); ok { - for _, c := range concepts { - if cs, ok := c.(string); ok && cs == params.Tag { - filtered = append(filtered, r) - break - } - } - } + // Step 2: Boost target via feedback + _, err = s.proxyPostRaw(ctx, fmt.Sprintf("/api/observations/%d/feedback", params.TargetID), map[string]any{ + "feedback": "positive", + }) + if err != nil { + log.Warn().Err(err).Msg("Failed to boost target observation after merge") } response := map[string]any{ - "tag": params.Tag, - "observations": filtered, - "count": len(filtered), + "merged": true, + "source_id": params.SourceID, + "target_id": params.TargetID, } - output, err := json.Marshal(response) if err != nil { return "", fmt.Errorf("marshal response: %w", err) @@ -2069,282 +1576,225 @@ func (s *Server) handleGetObservationsByTag(ctx context.Context, args json.RawMe return string(output), nil } -// handleGetTemporalTrends analyzes observation creation patterns over time. -func (s *Server) handleGetTemporalTrends(ctx context.Context, args json.RawMessage) (string, error) { +// handleGetPatternsProxy proxies pattern queries. +func (s *Server) handleGetPatternsProxy(ctx context.Context, args json.RawMessage) (string, error) { var params struct { + Type string `json:"type"` Project string `json:"project"` - GroupBy string `json:"group_by"` - Days int `json:"days"` + Query string `json:"query"` + Limit int `json:"limit"` } - params.Days = 30 - params.GroupBy = "day" - if err := json.Unmarshal(args, ¶ms); err != nil { return "", fmt.Errorf("invalid arguments: %w", err) } - if params.Days < 1 || params.Days > 365 { - params.Days = 30 + qp := map[string]string{} + if params.Type != "" { + qp["type"] = params.Type + } + if params.Project != "" { + qp["project"] = params.Project + } + if params.Limit > 0 { + qp["limit"] = strconv.Itoa(params.Limit) } - // Get observations for analysis - obs, err := s.observationStore.GetRecentObservations(ctx, params.Project, params.Days*50) // Rough estimate - if err != nil { - return "", fmt.Errorf("get observations: %w", err) + // Use search endpoint if query provided, otherwise get all + if params.Query != "" { + qp["query"] = params.Query + return s.proxyGetRaw(ctx, "/api/patterns/search", qp) } - // Calculate time range - now := time.Now() - startTime := now.AddDate(0, 0, -params.Days) - startEpoch := startTime.UnixMilli() - - // Group observations by time bucket - buckets := make(map[string]int) - typeDistribution := make(map[string]int) - conceptCounts := make(map[string]int) - totalInRange := 0 - - for _, o := range obs { - if o.CreatedAtEpoch < startEpoch { - continue - } - totalInRange++ - - created := time.UnixMilli(o.CreatedAtEpoch) - var key string - switch params.GroupBy { - case "week": - year, week := created.ISOWeek() - key = fmt.Sprintf("%d-W%02d", year, week) - case "hour_of_day": - key = fmt.Sprintf("%02d:00", created.Hour()) - default: // day - key = created.Format("2006-01-02") - } - buckets[key]++ - - // Track type distribution - typeDistribution[string(o.Type)]++ - - // Track top concepts - for _, c := range o.Concepts { - conceptCounts[c]++ - } - } - - // Find peak period - peakPeriod := "" - peakCount := 0 - for k, v := range buckets { - if v > peakCount { - peakCount = v - peakPeriod = k - } - } - - // Sort and get top concepts - type conceptEntry struct { - name string - count int - } - var topConcepts []conceptEntry - for name, count := range conceptCounts { - topConcepts = append(topConcepts, conceptEntry{name, count}) - } - // Simple sort - just take top 10 - for i := 0; i < len(topConcepts) && i < 10; i++ { - for j := i + 1; j < len(topConcepts); j++ { - if topConcepts[j].count > topConcepts[i].count { - topConcepts[i], topConcepts[j] = topConcepts[j], topConcepts[i] - } - } - } - if len(topConcepts) > 10 { - topConcepts = topConcepts[:10] - } - topConceptsMap := make([]map[string]any, len(topConcepts)) - for i, c := range topConcepts { - topConceptsMap[i] = map[string]any{"concept": c.name, "count": c.count} - } - - response := map[string]any{ - "period": map[string]any{ - "start": startTime.Format("2006-01-02"), - "end": now.Format("2006-01-02"), - "days": params.Days, - "group_by": params.GroupBy, - }, - "summary": map[string]any{ - "total_observations": totalInRange, - "daily_average": float64(totalInRange) / float64(params.Days), - "peak_period": peakPeriod, - "peak_count": peakCount, - }, - "distribution": buckets, - "type_distribution": typeDistribution, - "top_concepts": topConceptsMap, - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil + return s.proxyGetRaw(ctx, "/api/patterns", qp) } -// handleGetDataQualityReport generates a comprehensive quality assessment. -func (s *Server) handleGetDataQualityReport(ctx context.Context, args json.RawMessage) (string, error) { +// handleAnalyzeImportanceProxy proxies importance analysis. +func (s *Server) handleAnalyzeImportanceProxy(ctx context.Context, args json.RawMessage) (string, error) { var params struct { Project string `json:"project"` Limit int `json:"limit"` } - params.Limit = 100 - if err := json.Unmarshal(args, ¶ms); err != nil { return "", fmt.Errorf("invalid arguments: %w", err) } - - if params.Limit < 10 || params.Limit > 500 { - params.Limit = 100 + if params.Project == "" { + params.Project = s.project + } + if params.Limit == 0 { + params.Limit = 10 } - // Get observations for analysis - obs, err := s.observationStore.GetRecentObservations(ctx, params.Project, params.Limit) - if err != nil { - return "", fmt.Errorf("get observations: %w", err) - } - - if len(obs) == 0 { - return `{"error": "no observations found", "analyzed": 0}`, nil - } - - // Quality analysis - qualityScores := make([]float64, 0, len(obs)) - issuesFound := make(map[string]int) - improvements := make(map[string]int) - scoreDistribution := map[string]int{"excellent": 0, "good": 0, "fair": 0, "poor": 0} - - for _, o := range obs { - score := 0.0 - maxScore := 5.0 - - // Check completeness - if o.Title.Valid && o.Title.String != "" { - score += 1.0 - } else { - issuesFound["missing_title"]++ - improvements["add_title"]++ - } - - if o.Narrative.Valid && o.Narrative.String != "" { - score += 1.0 - } else { - issuesFound["missing_narrative"]++ - improvements["add_narrative"]++ - } - - if len(o.Facts) > 0 { - score += 1.0 - if len(o.Facts) >= 3 { - score += 0.5 // Bonus for multiple facts - } - } else { - issuesFound["no_facts"]++ - improvements["add_facts"]++ - } - - if len(o.Concepts) > 0 { - score += 1.0 - } else { - issuesFound["no_concepts"]++ - improvements["add_concepts"]++ - } - - if len(o.FilesRead) > 0 || len(o.FilesModified) > 0 { - score += 0.5 - } - - normalized := (score / maxScore) * 100 - qualityScores = append(qualityScores, normalized) - - // Categorize - switch { - case normalized >= 80: - scoreDistribution["excellent"]++ - case normalized >= 60: - scoreDistribution["good"]++ - case normalized >= 40: - scoreDistribution["fair"]++ - default: - scoreDistribution["poor"]++ - } - } - - // Calculate average - var avgScore float64 - for _, s := range qualityScores { - avgScore += s - } - avgScore /= float64(len(qualityScores)) - - // Build top issues list - type issueEntry struct { - name string - count int - } - var topIssues []issueEntry - for name, count := range issuesFound { - topIssues = append(topIssues, issueEntry{name, count}) - } - for i := 0; i < len(topIssues) && i < 5; i++ { - for j := i + 1; j < len(topIssues); j++ { - if topIssues[j].count > topIssues[i].count { - topIssues[i], topIssues[j] = topIssues[j], topIssues[i] - } - } - } - if len(topIssues) > 5 { - topIssues = topIssues[:5] - } - - // Convert top issues to response format - topIssuesList := make([]map[string]any, 0, len(topIssues)) - for _, issue := range topIssues { - topIssuesList = append(topIssuesList, map[string]any{ - "issue": issue.name, - "count": issue.count, - }) + qp := map[string]string{ + "project": params.Project, + "limit": strconv.Itoa(params.Limit), } + // Combine results from multiple endpoints response := map[string]any{ - "analyzed": len(obs), - "project": params.Project, - "quality_summary": map[string]any{ - "average_score": fmt.Sprintf("%.1f%%", avgScore), - "distribution": scoreDistribution, - }, - "issues_found": issuesFound, - "top_issues": topIssuesList, - "improvements": improvements, - "recommendations": []string{ - "Add titles to observations for better discoverability", - "Include narratives to provide context", - "Add concept tags for better organization", - "Include at least 2-3 key facts per observation", - }, + "project": params.Project, + } + + // Get top-scoring observations + topScored, err := s.proxyGetRaw(ctx, "/api/observations/top", qp) + if err == nil { + response["top_scoring"] = json.RawMessage(topScored) + } + + // Get most-retrieved observations + mostRetrieved, err := s.proxyGetRaw(ctx, "/api/observations/most-retrieved", qp) + if err == nil { + response["most_retrieved"] = json.RawMessage(mostRetrieved) + } + + // Get scoring stats + scoringStats, err := s.proxyGetRaw(ctx, "/api/scoring/stats", qp) + if err == nil { + response["scoring_stats"] = json.RawMessage(scoringStats) + } + + // Get concept weights + conceptWeights, err := s.proxyGetRaw(ctx, "/api/scoring/concepts", nil) + if err == nil { + response["concept_weights"] = json.RawMessage(conceptWeights) } output, err := json.Marshal(response) if err != nil { return "", fmt.Errorf("marshal response: %w", err) } - return string(output), nil } -// handleBatchTagByPattern applies tags to observations matching a pattern. -func (s *Server) handleBatchTagByPattern(ctx context.Context, args json.RawMessage) (string, error) { +// handleExplainSearchProxy proxies search ranking explanation. +func (s *Server) handleExplainSearchProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Query string `json:"query"` + Project string `json:"project"` + TopN int `json:"top_n"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + if params.Query == "" { + return "", fmt.Errorf("query is required") + } + if params.Project == "" { + params.Project = s.project + } + if params.TopN == 0 { + params.TopN = 5 + } + + return s.proxyGetRaw(ctx, "/api/context/search", map[string]string{ + "project": params.Project, + "query": params.Query, + "limit": strconv.Itoa(params.TopN), + }) +} + +// handleGetTemporalTrendsProxy proxies temporal trend analysis. +func (s *Server) handleGetTemporalTrendsProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Project string `json:"project"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + if params.Project == "" { + params.Project = s.project + } + + return s.proxyGetRaw(ctx, "/api/stats", map[string]string{ + "project": params.Project, + }) +} + +// handleGetDataQualityProxy proxies data quality report. +func (s *Server) handleGetDataQualityProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Project string `json:"project"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + if params.Project == "" { + params.Project = s.project + } + + return s.proxyGetRaw(ctx, "/api/stats", map[string]string{ + "project": params.Project, + }) +} + +// handleExportProxy proxies observation export requests. +func (s *Server) handleExportProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Format string `json:"format"` + Project string `json:"project"` + ObsType string `json:"obs_type"` + Limit int `json:"limit"` + DateStart int64 `json:"date_start"` + DateEnd int64 `json:"date_end"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + qp := map[string]string{} + if params.Format != "" { + qp["format"] = params.Format + } + if params.Project != "" { + qp["project"] = params.Project + } else { + qp["project"] = s.project + } + if params.ObsType != "" { + qp["obs_type"] = params.ObsType + } + if params.Limit > 0 { + qp["limit"] = strconv.Itoa(params.Limit) + } + if params.DateStart > 0 { + qp["date_start"] = strconv.FormatInt(params.DateStart, 10) + } + if params.DateEnd > 0 { + qp["date_end"] = strconv.FormatInt(params.DateEnd, 10) + } + + return s.proxyGetRaw(ctx, "/api/observations/export", qp) +} + +// handleSuggestConsolidationsProxy proxies consolidation suggestions via duplicates endpoint. +func (s *Server) handleSuggestConsolidationsProxy(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Project string `json:"project"` + MinSimilarity float64 `json:"min_similarity"` + Limit int `json:"limit"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + if params.Project == "" { + params.Project = s.project + } + + qp := map[string]string{ + "project": params.Project, + } + if params.MinSimilarity > 0 { + qp["min_similarity"] = strconv.FormatFloat(params.MinSimilarity, 'f', -1, 64) + } + if params.Limit > 0 { + qp["limit"] = strconv.Itoa(params.Limit) + } + + return s.proxyGetRaw(ctx, "/api/observations/duplicates", qp) +} + +// handleBatchTagProxy proxies batch tag operations. +// Searches for matching observations and applies tags via PUT endpoint. +func (s *Server) handleBatchTagProxy(ctx context.Context, args json.RawMessage) (string, error) { var params struct { Pattern string `json:"pattern"` Project string `json:"project"` @@ -2358,72 +1808,53 @@ func (s *Server) handleBatchTagByPattern(ctx context.Context, args json.RawMessa if err := json.Unmarshal(args, ¶ms); err != nil { return "", fmt.Errorf("invalid arguments: %w", err) } - if params.Pattern == "" { return "", fmt.Errorf("pattern is required") } if len(params.Tags) == 0 { return "", fmt.Errorf("tags is required") } - if params.MaxMatches < 1 || params.MaxMatches > 500 { - params.MaxMatches = 100 + if params.Project == "" { + params.Project = s.project } - // Search for matching observations using the pattern - searchParams := search.SearchParams{ - Query: params.Pattern, - Type: "observations", - Project: params.Project, - Limit: params.MaxMatches, - } - - result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) + // Search for matching observations + searchResult, err := s.proxyGetRaw(ctx, "/api/context/search", map[string]string{ + "project": params.Project, + "query": params.Pattern, + "limit": strconv.Itoa(params.MaxMatches), + }) if err != nil { return "", fmt.Errorf("search: %w", err) } - // Collect matching observation IDs - var matches []map[string]any + var searchResp struct { + Observations []struct { + ID int64 `json:"id"` + Title string `json:"title"` + } `json:"observations"` + } + if err := json.Unmarshal([]byte(searchResult), &searchResp); err != nil { + return "", fmt.Errorf("parse search results: %w", err) + } + + matches := make([]map[string]any, 0, len(searchResp.Observations)) var taggedCount int - for _, r := range result.Results { - if r.Type != "observation" { - continue - } + for _, obs := range searchResp.Observations { + matches = append(matches, map[string]any{ + "id": obs.ID, + "title": obs.Title, + }) - match := map[string]any{ - "id": r.ID, - "title": r.Title, - "score": r.Score, - } - matches = append(matches, match) - - // Apply tags if not dry run if !params.DryRun { - obs, err := s.observationStore.GetObservationByID(ctx, r.ID) - if err != nil || obs == nil { - continue - } - - // Merge existing tags with new tags (avoid duplicates) - tagSet := make(map[string]bool) - newTags := make([]string, 0, len(obs.Concepts)+len(params.Tags)) - for _, t := range obs.Concepts { - tagSet[t] = true - newTags = append(newTags, t) - } - for _, t := range params.Tags { - if !tagSet[t] { - tagSet[t] = true - newTags = append(newTags, t) - } - } - - update := &gorm.ObservationUpdate{ - Concepts: &newTags, - } - _, err = s.observationStore.UpdateObservation(ctx, r.ID, update) - if err == nil { + tagArgs, _ := json.Marshal(map[string]any{ + "id": obs.ID, + "tags": params.Tags, + "mode": "add", + }) + _, tagErr := s.handleTagObservationProxy(ctx, tagArgs) + if tagErr == nil { taggedCount++ } } @@ -2436,7 +1867,6 @@ func (s *Server) handleBatchTagByPattern(ctx context.Context, args json.RawMessa "matches_found": len(matches), "matches": matches, } - if !params.DryRun { response["tagged_count"] = taggedCount } @@ -2449,827 +1879,33 @@ func (s *Server) handleBatchTagByPattern(ctx context.Context, args json.RawMessa return string(output), nil } -// handleExplainSearchRanking explains why each observation ranked where it did in search results. -func (s *Server) handleExplainSearchRanking(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - Query string `json:"query"` - Project string `json:"project"` - TopN int `json:"top_n"` - } - params.TopN = 5 // default - - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if params.Query == "" { - return "", fmt.Errorf("query is required") - } - if params.TopN < 1 || params.TopN > 20 { - params.TopN = 5 - } - - // Perform search to get results - searchParams := search.SearchParams{ - Query: params.Query, - Type: "observations", - Project: params.Project, - Limit: params.TopN, - OrderBy: "relevance", - } - - result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) +// sendResponse sends a JSON-RPC response. Returns an error if writing fails. +func (s *Server) sendResponse(resp *Response) error { + data, err := json.Marshal(resp) if err != nil { - return "", fmt.Errorf("search: %w", err) + log.Error().Err(err).Msg("Failed to marshal response") + return nil } - - // Build detailed explanations for each result - type RankExplanation struct { - ScoreBreakdown map[string]float64 `json:"score_breakdown"` - Metadata map[string]any `json:"metadata,omitempty"` - Title string `json:"title"` - Type string `json:"type"` - MatchedFields []string `json:"matched_fields"` - Rank int `json:"rank"` - ID int64 `json:"id"` - Score float64 `json:"score"` + s.writeMu.Lock() + _, err = fmt.Fprintln(s.stdout, string(data)) + s.writeMu.Unlock() + if err != nil { + log.Error().Err(err).Msg("Failed to write response to stdout") + return err } + return nil +} - explanations := make([]RankExplanation, 0, len(result.Results)) - for i, r := range result.Results { - exp := RankExplanation{ - Rank: i + 1, - ID: r.ID, - Title: r.Title, - Type: r.Type, - Score: r.Score, - Metadata: r.Metadata, - } - - // Build score breakdown from available metadata - exp.ScoreBreakdown = make(map[string]float64) - if vs, ok := r.Metadata["vector_score"].(float64); ok { - exp.ScoreBreakdown["vector_similarity"] = vs - } - if is, ok := r.Metadata["importance_score"].(float64); ok { - exp.ScoreBreakdown["importance"] = is - } - if ts, ok := r.Metadata["text_score"].(float64); ok { - exp.ScoreBreakdown["text_match"] = ts - } - if rs, ok := r.Metadata["recency_score"].(float64); ok { - exp.ScoreBreakdown["recency"] = rs - } - // Add base score estimate if breakdown is incomplete - if len(exp.ScoreBreakdown) == 0 { - exp.ScoreBreakdown["combined_score"] = r.Score - } - - // Determine matched fields - exp.MatchedFields = []string{} - if r.Metadata["field_type"] != nil { - if ft, ok := r.Metadata["field_type"].(string); ok && ft != "" { - exp.MatchedFields = append(exp.MatchedFields, ft) - } - } - - explanations = append(explanations, exp) - } - - response := map[string]any{ - "query": params.Query, - "project": params.Project, - "result_count": len(explanations), - "explanations": explanations, - "tips": []string{ - "Higher vector_similarity indicates better semantic match with query", - "Importance score reflects user feedback and retrieval history", - "Recency boosts newer observations slightly", - "Use tag_observation to boost important observations", +// sendError sends a JSON-RPC error response. Returns an error if writing fails. +func (s *Server) sendError(id any, code int, message string, data any) error { + resp := &Response{ + JSONRPC: "2.0", + ID: id, + Error: &Error{ + Code: code, + Message: message, + Data: data, }, } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - - return string(output), nil -} - -// handleExportObservations exports observations in various formats. -func (s *Server) handleExportObservations(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - Format string `json:"format"` - Project string `json:"project"` - ObsType string `json:"obs_type"` - Limit int `json:"limit"` - DateStart int64 `json:"date_start"` - DateEnd int64 `json:"date_end"` - } - params.Format = "json" - params.Limit = 100 - - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if params.Limit < 1 || params.Limit > 1000 { - params.Limit = 100 - } - - // Build search params to fetch observations - searchParams := search.SearchParams{ - Type: "observations", - Project: params.Project, - Limit: params.Limit, - OrderBy: "date_desc", - DateStart: params.DateStart, - DateEnd: params.DateEnd, - ObsType: params.ObsType, - } - - result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) - if err != nil { - return "", fmt.Errorf("search: %w", err) - } - - // Fetch full observation data for export - ids := make([]int64, 0, len(result.Results)) - for _, r := range result.Results { - if r.Type == "observation" { - ids = append(ids, r.ID) - } - } - - observations, err := s.observationStore.GetObservationsByIDs(ctx, ids, "", 0) - if err != nil { - return "", fmt.Errorf("get observations: %w", err) - } - - // Format output based on requested format - var output string - switch params.Format { - case "jsonl": - // JSON Lines format - one JSON object per line - var lines []string - for _, obs := range observations { - line, err := json.Marshal(obs) - if err != nil { - continue - } - lines = append(lines, string(line)) - } - // Use proper JSON marshaling to avoid injection issues - jsonlOutput := struct { - Format string `json:"format"` - Data string `json:"data"` - Count int `json:"count"` - }{ - Format: "jsonl", - Count: len(observations), - Data: strings.Join(lines, "\n"), - } - outputBytes, err := json.Marshal(jsonlOutput) - if err != nil { - return "", fmt.Errorf("marshal jsonl output: %w", err) - } - output = string(outputBytes) - - case "markdown": - // Markdown format for human reading - var md strings.Builder - md.WriteString("# Observations Export\n\n") - md.WriteString(fmt.Sprintf("Total: %d observations\n\n", len(observations))) - md.WriteString("---\n\n") - - for _, obs := range observations { - title := "" - if obs.Title.Valid { - title = obs.Title.String - } - md.WriteString(fmt.Sprintf("## [%s] %s\n\n", obs.Type, title)) - if obs.Subtitle.Valid && obs.Subtitle.String != "" { - md.WriteString(fmt.Sprintf("*%s*\n\n", obs.Subtitle.String)) - } - if obs.Narrative.Valid && obs.Narrative.String != "" { - md.WriteString(fmt.Sprintf("%s\n\n", obs.Narrative.String)) - } - if len(obs.Facts) > 0 { - md.WriteString("### Key Facts\n") - for _, fact := range obs.Facts { - md.WriteString(fmt.Sprintf("- %s\n", fact)) - } - md.WriteString("\n") - } - if len(obs.Concepts) > 0 { - md.WriteString(fmt.Sprintf("**Tags:** %s\n\n", strings.Join(obs.Concepts, ", "))) - } - md.WriteString(fmt.Sprintf("**ID:** %d | **Created:** %s | **Importance:** %.2f\n\n", - obs.ID, obs.CreatedAt, obs.ImportanceScore)) - md.WriteString("---\n\n") - } - - // Wrap markdown in JSON response - response := map[string]any{ - "format": "markdown", - "count": len(observations), - "data": md.String(), - } - outputBytes, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - output = string(outputBytes) - - default: // json - response := map[string]any{ - "format": "json", - "count": len(observations), - "observations": observations, - } - outputBytes, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - output = string(outputBytes) - } - - return output, nil -} - -// handleCheckSystemHealth performs comprehensive system health checks. -func (s *Server) handleCheckSystemHealth(ctx context.Context) (string, error) { - type SubsystemHealth struct { - Status string `json:"status"` // "healthy", "degraded", "unhealthy" - Message string `json:"message,omitempty"` - Metrics map[string]any `json:"metrics,omitempty"` - Warnings []string `json:"warnings,omitempty"` - } - - type HealthReport struct { - Timestamp time.Time `json:"timestamp"` - Subsystems map[string]*SubsystemHealth `json:"subsystems"` - OverallStatus string `json:"overall_status"` - Actions []string `json:"recommended_actions,omitempty"` - HealthScore int `json:"health_score"` - } - - report := &HealthReport{ - OverallStatus: "healthy", - HealthScore: 100, - Timestamp: time.Now(), - Subsystems: make(map[string]*SubsystemHealth), - Actions: []string{}, - } - - // Check database health - dbHealth := &SubsystemHealth{ - Status: "healthy", - Metrics: make(map[string]any), - } - if s.observationStore != nil { - // Count observations - count, err := s.observationStore.GetObservationCount(ctx, "") - if err != nil { - dbHealth.Status = "unhealthy" - dbHealth.Message = "Database query failed: " + err.Error() - report.HealthScore -= 30 - } else { - dbHealth.Metrics["total_observations"] = count - dbHealth.Message = "Database operational" - } - - // Check for recent activity - recent, err := s.observationStore.GetAllRecentObservations(ctx, 1) - if err == nil && len(recent) > 0 { - dbHealth.Metrics["last_observation"] = recent[0].CreatedAt - // Check epoch for staleness warning - if recent[0].CreatedAtEpoch > 0 { - lastActivityTime := time.UnixMilli(recent[0].CreatedAtEpoch) - if time.Since(lastActivityTime) > 7*24*time.Hour { - dbHealth.Warnings = append(dbHealth.Warnings, "No observations in the last 7 days") - } - } - } - } else { - dbHealth.Status = "unhealthy" - dbHealth.Message = "Observation store not initialized" - report.HealthScore -= 50 - } - report.Subsystems["database"] = dbHealth - - // Check vector store health - vectorHealth := &SubsystemHealth{ - Status: "healthy", - Metrics: make(map[string]any), - } - if s.vectorClient != nil { - stats, err := s.vectorClient.GetHealthStats(ctx) - if err != nil { - vectorHealth.Status = "degraded" - vectorHealth.Message = "Could not get vector stats: " + err.Error() - report.HealthScore -= 15 - } else { - vectorHealth.Metrics["total_vectors"] = stats.TotalVectors - vectorHealth.Metrics["stale_vectors"] = stats.StaleVectors - vectorHealth.Metrics["current_model"] = stats.CurrentModel - vectorHealth.Metrics["needs_rebuild"] = stats.NeedsRebuild - - if stats.NeedsRebuild { - vectorHealth.Status = "degraded" - vectorHealth.Warnings = append(vectorHealth.Warnings, "Vector rebuild recommended: "+stats.RebuildReason) - report.Actions = append(report.Actions, "Run vector rebuild to update embeddings") - report.HealthScore -= 10 - } - - // Check stale ratio - if stats.TotalVectors > 0 { - staleRatio := float64(stats.StaleVectors) / float64(stats.TotalVectors) - if staleRatio > 0.2 { - vectorHealth.Warnings = append(vectorHealth.Warnings, - fmt.Sprintf("%.1f%% of vectors are stale", staleRatio*100)) - report.HealthScore -= 5 - } - } - } - - // Check cache performance - cacheStats := s.vectorClient.GetCacheStats() - vectorHealth.Metrics["cache_hit_rate"] = fmt.Sprintf("%.1f%%", cacheStats.HitRate()) - vectorHealth.Metrics["embedding_hits"] = cacheStats.EmbeddingHits - vectorHealth.Metrics["embedding_misses"] = cacheStats.EmbeddingMisses - vectorHealth.Metrics["result_hits"] = cacheStats.ResultHits - vectorHealth.Metrics["result_misses"] = cacheStats.ResultMisses - - if cacheStats.HitRate() < 20 && (cacheStats.EmbeddingHits+cacheStats.EmbeddingMisses) > 100 { - vectorHealth.Warnings = append(vectorHealth.Warnings, "Low cache hit rate - consider cache tuning") - } - } else { - vectorHealth.Status = "unhealthy" - vectorHealth.Message = "Vector client not initialized" - report.HealthScore -= 30 - } - report.Subsystems["vectors"] = vectorHealth - - // Check pattern detection health - patternHealth := &SubsystemHealth{ - Status: "healthy", - Metrics: make(map[string]any), - } - if s.patternStore != nil { - patterns, err := s.patternStore.GetActivePatterns(ctx, 100) - if err != nil { - patternHealth.Status = "degraded" - patternHealth.Message = "Could not query patterns: " + err.Error() - } else { - patternHealth.Metrics["total_patterns"] = len(patterns) - - // Count by type - typeCounts := make(map[string]int) - for _, p := range patterns { - typeCounts[string(p.Type)]++ - } - patternHealth.Metrics["patterns_by_type"] = typeCounts - } - } - report.Subsystems["patterns"] = patternHealth - - // Check session store health - sessionHealth := &SubsystemHealth{ - Status: "healthy", - Metrics: make(map[string]any), - } - if s.sessionStore != nil { - sessionsToday, err := s.sessionStore.GetSessionsToday(ctx) - if err != nil { - sessionHealth.Status = "degraded" - sessionHealth.Message = "Could not query sessions: " + err.Error() - } else { - sessionHealth.Metrics["sessions_today"] = sessionsToday - } - } - report.Subsystems["sessions"] = sessionHealth - - // Determine overall status - unhealthyCount := 0 - degradedCount := 0 - for _, sub := range report.Subsystems { - switch sub.Status { - case "unhealthy": - unhealthyCount++ - case "degraded": - degradedCount++ - } - } - - if unhealthyCount > 0 { - report.OverallStatus = "unhealthy" - } else if degradedCount > 0 { - report.OverallStatus = "degraded" - } - - // Cap health score - if report.HealthScore < 0 { - report.HealthScore = 0 - } - - // Add recommended actions based on issues - if report.HealthScore < 70 { - report.Actions = append(report.Actions, "System needs attention - check subsystem details") - } - - output, err := json.Marshal(report) - if err != nil { - return "", fmt.Errorf("marshal health report: %w", err) - } - return string(output), nil -} - -// handleAnalyzeSearchPatterns analyzes search query patterns. -func (s *Server) handleAnalyzeSearchPatterns(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - Days int `json:"days"` - TopN int `json:"top_n"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid params: %w", err) - } - - if params.Days <= 0 { - params.Days = 7 - } - if params.TopN <= 0 { - params.TopN = 10 - } - - type QueryPattern struct { - Query string `json:"query"` - LastUsed string `json:"last_used"` - Count int `json:"count"` - AvgResults float64 `json:"avg_results"` - ZeroResults int `json:"zero_result_count"` - } - - type PatternAnalysis struct { - Period string `json:"period"` - TopQueries []QueryPattern `json:"top_queries"` - ZeroResultQueries []string `json:"zero_result_queries,omitempty"` - Insights []string `json:"insights,omitempty"` - TotalSearches int `json:"total_searches"` - UniqueQueries int `json:"unique_queries"` - } - - analysis := &PatternAnalysis{ - Period: fmt.Sprintf("Last %d days", params.Days), - TopQueries: []QueryPattern{}, - ZeroResultQueries: []string{}, - Insights: []string{}, - } - - // Get search stats from the search manager if available - if s.searchMgr != nil { - metrics := s.searchMgr.Metrics() - if metrics != nil { - stats := metrics.GetStats() - if totalSearches, ok := stats["total_searches"].(int); ok && totalSearches > 0 { - analysis.TotalSearches = totalSearches - analysis.Insights = append(analysis.Insights, - fmt.Sprintf("Total searches: %d", totalSearches)) - } - if avgLatency, ok := stats["avg_latency_ms"].(float64); ok { - analysis.Insights = append(analysis.Insights, - fmt.Sprintf("Average search latency: %.2fms", avgLatency)) - } - } - - // Get cache stats - cacheStats := s.searchMgr.CacheStats() - if hitRate, ok := cacheStats["hit_rate"].(float64); ok { - analysis.Insights = append(analysis.Insights, - fmt.Sprintf("Cache hit rate: %.1f%%", hitRate*100)) - } - } - - // Analyze observation patterns to suggest search improvements - if s.observationStore != nil { - // Get recent observations to understand content patterns - observations, err := s.observationStore.GetAllRecentObservations(ctx, 100) - if err == nil { - analysis.UniqueQueries = len(observations) - - // Analyze observation types - typeCounts := make(map[string]int) - for _, obs := range observations { - typeCounts[string(obs.Type)]++ - } - - // Find most common types - mostCommon := "" - maxCount := 0 - for t, c := range typeCounts { - if c > maxCount { - mostCommon = t - maxCount = c - } - } - if mostCommon != "" { - analysis.Insights = append(analysis.Insights, - fmt.Sprintf("Most common observation type: %s (%d occurrences)", mostCommon, maxCount)) - } - - // Check for concept coverage - conceptCounts := make(map[string]int) - for _, obs := range observations { - for _, c := range obs.Concepts { - conceptCounts[c]++ - } - } - if len(conceptCounts) > 0 { - analysis.Insights = append(analysis.Insights, - fmt.Sprintf("%d unique concepts across %d observations", len(conceptCounts), len(observations))) - } - } - } - - // Add general recommendations - if len(analysis.Insights) == 0 { - analysis.Insights = append(analysis.Insights, "Insufficient data for pattern analysis") - } - - output, err := json.Marshal(analysis) - if err != nil { - return "", fmt.Errorf("marshal analysis: %w", err) - } - return string(output), nil -} - -// handleGetObservationRelationships returns the relationship graph for an observation. -func (s *Server) handleGetObservationRelationships(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - ID int64 `json:"id"` - MaxDepth int `json:"max_depth"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid params: %w", err) - } - - if params.ID <= 0 { - return "", fmt.Errorf("id is required and must be positive") - } - if params.MaxDepth <= 0 { - params.MaxDepth = 2 - } - if params.MaxDepth > 5 { - params.MaxDepth = 5 - } - - if s.relationStore == nil { - return "", fmt.Errorf("relation store not available") - } - - // Get the relationship graph - graph, err := s.relationStore.GetRelationGraph(ctx, params.ID, params.MaxDepth) - if err != nil { - return "", fmt.Errorf("get relation graph: %w", err) - } - - // Build response with additional context - type RelationInfo struct { - Type string `json:"type"` - SourceTitle string `json:"source_title,omitempty"` - TargetTitle string `json:"target_title,omitempty"` - SourceType string `json:"source_type,omitempty"` - TargetType string `json:"target_type,omitempty"` - ID int64 `json:"id"` - SourceID int64 `json:"source_id"` - TargetID int64 `json:"target_id"` - Confidence float64 `json:"confidence"` - } - - type GraphResponse struct { - Relations []RelationInfo `json:"relations"` - UniqueNodes []int64 `json:"unique_nodes"` - CenterID int64 `json:"center_id"` - MaxDepth int `json:"max_depth"` - TotalRelations int `json:"total_relations"` - } - - // Collect unique node IDs - nodeSet := make(map[int64]bool) - nodeSet[params.ID] = true - - relations := make([]RelationInfo, 0, len(graph.Relations)) - for _, r := range graph.Relations { - nodeSet[r.Relation.SourceID] = true - nodeSet[r.Relation.TargetID] = true - - relations = append(relations, RelationInfo{ - ID: r.Relation.ID, - SourceID: r.Relation.SourceID, - TargetID: r.Relation.TargetID, - Type: string(r.Relation.RelationType), - Confidence: r.Relation.Confidence, - SourceTitle: r.SourceTitle, - TargetTitle: r.TargetTitle, - SourceType: string(r.SourceType), - TargetType: string(r.TargetType), - }) - } - - // Convert node set to slice - nodes := make([]int64, 0, len(nodeSet)) - for id := range nodeSet { - nodes = append(nodes, id) - } - - response := GraphResponse{ - CenterID: params.ID, - MaxDepth: params.MaxDepth, - TotalRelations: len(relations), - Relations: relations, - UniqueNodes: nodes, - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - return string(output), nil -} - -// handleGetObservationScoringBreakdown returns detailed scoring breakdown for an observation. -func (s *Server) handleGetObservationScoringBreakdown(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - ID int64 `json:"id"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - if params.ID <= 0 { - return "", fmt.Errorf("id is required and must be positive") - } - - // Get the observation - obs, err := s.observationStore.GetObservationByID(ctx, params.ID) - if err != nil { - return "", fmt.Errorf("get observation: %w", err) - } - if obs == nil { - return "", fmt.Errorf("observation not found: %d", params.ID) - } - - // Calculate scoring components - if s.scoreCalculator == nil { - return "", fmt.Errorf("score calculator not initialized") - } - - components := s.scoreCalculator.CalculateComponents(obs, time.Now()) - - // Build response with observation context - response := map[string]any{ - "observation": map[string]any{ - "id": obs.ID, - "title": obs.Title.String, - "type": string(obs.Type), - "project": obs.Project, - "created_at": obs.CreatedAtEpoch, - }, - "scoring": map[string]any{ - "final_score": components.FinalScore, - "type_weight": components.TypeWeight, - "recency_decay": components.RecencyDecay, - "core_score": components.CoreScore, - "feedback_contrib": components.FeedbackContrib, - "concept_contrib": components.ConceptContrib, - "retrieval_contrib": components.RetrievalContrib, - "age_days": components.AgeDays, - }, - "explanation": map[string]any{ - "type_impact": fmt.Sprintf("Observation type '%s' has weight %.2f", obs.Type, components.TypeWeight), - "recency_impact": fmt.Sprintf("%.1f days old, decay factor %.2f", components.AgeDays, components.RecencyDecay), - "feedback_impact": fmt.Sprintf("User feedback contributes %.2f to score", components.FeedbackContrib), - "concept_impact": fmt.Sprintf("Concept tags contribute %.2f to score", components.ConceptContrib), - "retrieval_impact": fmt.Sprintf("Retrieval frequency contributes %.2f to score", components.RetrievalContrib), - }, - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - return string(output), nil -} - -// handleAnalyzeObservationImportance returns importance analysis for a project's observations. -func (s *Server) handleAnalyzeObservationImportance(ctx context.Context, args json.RawMessage) (string, error) { - var params struct { - IncludeTopScored *bool `json:"include_top_scored"` - IncludeMostRetrieved *bool `json:"include_most_retrieved"` - IncludeConceptWeights *bool `json:"include_concept_weights"` - Project string `json:"project"` - Limit int `json:"limit"` - } - if err := json.Unmarshal(args, ¶ms); err != nil { - return "", fmt.Errorf("invalid arguments: %w", err) - } - - // Set defaults - if params.Limit <= 0 { - params.Limit = 10 - } - if params.Limit > 50 { - params.Limit = 50 - } - includeTopScored := params.IncludeTopScored == nil || *params.IncludeTopScored - includeMostRetrieved := params.IncludeMostRetrieved == nil || *params.IncludeMostRetrieved - includeConceptWeights := params.IncludeConceptWeights == nil || *params.IncludeConceptWeights - - response := make(map[string]any) - response["project"] = params.Project - if params.Project == "" { - response["project"] = "(all projects)" - } - - // Get feedback statistics - stats, err := s.observationStore.GetObservationFeedbackStats(ctx, params.Project) - if err != nil { - return "", fmt.Errorf("get feedback stats: %w", err) - } - response["feedback_stats"] = stats - - // Get top-scoring observations - if includeTopScored { - topScored, err := s.observationStore.GetTopScoringObservations(ctx, params.Project, params.Limit) - if err != nil { - log.Warn().Err(err).Msg("Failed to get top-scoring observations") - } else { - topScoredSummary := make([]map[string]any, 0, len(topScored)) - for _, obs := range topScored { - topScoredSummary = append(topScoredSummary, map[string]any{ - "id": obs.ID, - "title": obs.Title.String, - "type": string(obs.Type), - "importance_score": obs.ImportanceScore, - }) - } - response["top_scoring_observations"] = topScoredSummary - } - } - - // Get most-retrieved observations - if includeMostRetrieved { - mostRetrieved, err := s.observationStore.GetMostRetrievedObservations(ctx, params.Project, params.Limit) - if err != nil { - log.Warn().Err(err).Msg("Failed to get most-retrieved observations") - } else { - mostRetrievedSummary := make([]map[string]any, 0, len(mostRetrieved)) - for _, obs := range mostRetrieved { - mostRetrievedSummary = append(mostRetrievedSummary, map[string]any{ - "id": obs.ID, - "title": obs.Title.String, - "type": string(obs.Type), - "retrieval_count": obs.RetrievalCount, - }) - } - response["most_retrieved_observations"] = mostRetrievedSummary - } - } - - // Get concept weights - if includeConceptWeights { - conceptWeights, err := s.observationStore.GetConceptWeights(ctx) - if err != nil { - log.Warn().Err(err).Msg("Failed to get concept weights") - } else if len(conceptWeights) > 0 { - response["concept_weights"] = conceptWeights - } - } - - // Generate insights - insights := []string{} - if stats != nil { - if stats.Positive > 0 { - insights = append(insights, fmt.Sprintf("%d observations marked as valuable (positive feedback)", stats.Positive)) - } - if stats.Negative > 0 { - insights = append(insights, fmt.Sprintf("%d observations marked as not helpful (negative feedback)", stats.Negative)) - } - if stats.AvgScore > 0 { - insights = append(insights, fmt.Sprintf("Average importance score: %.2f", stats.AvgScore)) - } - if stats.AvgRetrieval > 0 { - insights = append(insights, fmt.Sprintf("Average retrieval count: %.1f", stats.AvgRetrieval)) - } - } - if len(insights) > 0 { - response["insights"] = insights - } - - output, err := json.Marshal(response) - if err != nil { - return "", fmt.Errorf("marshal response: %w", err) - } - return string(output), nil + return s.sendResponse(resp) } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index dc898e0..3c1ea62 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -13,6 +13,21 @@ import ( "github.com/stretchr/testify/suite" ) +// timelineParams is used in tests for timeline request parsing. +type timelineParams struct { + Query string `json:"query"` + Project string `json:"project"` + ObsType string `json:"obs_type"` + Concepts string `json:"concepts"` + Files string `json:"files"` + Format string `json:"format"` + AnchorID int64 `json:"anchor_id"` + Before int `json:"before"` + After int `json:"after"` + DateStart int64 `json:"dateStart"` + DateEnd int64 `json:"dateEnd"` +} + // ============================================================================= // TEST SUITE // ============================================================================= @@ -28,10 +43,12 @@ func TestServerSuite(t *testing.T) { // TestNewServer tests server creation. func (s *ServerSuite) TestNewServer() { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "http://localhost:37777", "test-project", "1.0.0") s.NotNil(server) - s.Nil(server.searchMgr) + s.Nil(server.client) s.Equal("1.0.0", server.version) + s.Equal("http://localhost:37777", server.workerURL) + s.Equal("test-project", server.project) } // ============================================================================= @@ -289,19 +306,19 @@ func TestTool(t *testing.T) { assert.Equal(t, "Search observations", parsed.Description) } -// TestTimelineParams tests TimelineParams struct. -func TestTimelineParams(t *testing.T) { +// TestTimelineParamsStruct tests timelineParams struct. +func TestTimelineParamsStruct(t *testing.T) { t.Parallel() tests := []struct { name string input string - expected TimelineParams + expected timelineParams }{ { name: "with anchor_id", input: `{"anchor_id":123,"before":5,"after":5}`, - expected: TimelineParams{ + expected: timelineParams{ AnchorID: 123, Before: 5, After: 5, @@ -310,7 +327,7 @@ func TestTimelineParams(t *testing.T) { { name: "with query", input: `{"query":"test query","project":"my-project"}`, - expected: TimelineParams{ + expected: timelineParams{ Query: "test query", Project: "my-project", }, @@ -318,7 +335,7 @@ func TestTimelineParams(t *testing.T) { { name: "full params", input: `{"anchor_id":100,"query":"search","before":10,"after":20,"project":"proj","obs_type":"bugfix","concepts":"security","files":"main.go","dateStart":1234567890,"dateEnd":9876543210,"format":"full"}`, - expected: TimelineParams{ + expected: timelineParams{ AnchorID: 100, Query: "search", Before: 10, @@ -337,7 +354,7 @@ func TestTimelineParams(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - var params TimelineParams + var params timelineParams err := json.Unmarshal([]byte(tt.input), ¶ms) require.NoError(t, err) assert.Equal(t, tt.expected.AnchorID, params.AnchorID) @@ -355,7 +372,7 @@ func TestTimelineParams(t *testing.T) { func TestHandleInitialize(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.2.3", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.2.3") req := &Request{ JSONRPC: "2.0", @@ -384,7 +401,7 @@ func TestHandleInitialize(t *testing.T) { func TestHandleToolsList(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") req := &Request{ JSONRPC: "2.0", @@ -427,7 +444,7 @@ func TestHandleToolsList(t *testing.T) { func TestHandleRequest(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -492,7 +509,7 @@ func TestHandleRequest(t *testing.T) { func TestHandleToolsCall_InvalidParams(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() req := &Request{ @@ -513,7 +530,7 @@ func TestHandleToolsCall_InvalidParams(t *testing.T) { func TestCallTool_UnknownTool(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() _, err := server.callTool(ctx, "nonexistent_tool", json.RawMessage(`{}`)) @@ -525,12 +542,13 @@ func TestCallTool_UnknownTool(t *testing.T) { func TestCallTool_InvalidArgs(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() + // Invalid JSON is best-effort parsed; the call fails because worker is unavailable _, err := server.callTool(ctx, "search", json.RawMessage(`invalid json`)) require.Error(t, err) - assert.Contains(t, err.Error(), "invalid arguments") + assert.Contains(t, err.Error(), "worker unavailable") } // ============================================================================= @@ -698,7 +716,7 @@ func TestRunMixedRequests(t *testing.T) { func TestHandleFindRelatedObservations_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -707,6 +725,12 @@ func TestHandleFindRelatedObservations_Validation(t *testing.T) { errContains string wantErr bool }{ + { + name: "worker unavailable", + args: `{"id": 1}`, + wantErr: true, + errContains: "worker unavailable", + }, { name: "missing id", args: `{}`, @@ -724,7 +748,7 @@ func TestHandleFindRelatedObservations_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleFindRelatedObservations(ctx, json.RawMessage(tt.args)) + _, err := server.handleFindRelatedProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -741,7 +765,7 @@ func TestHandleFindRelatedObservations_Validation(t *testing.T) { func TestHandleFindSimilarObservations_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -766,14 +790,14 @@ func TestHandleFindSimilarObservations_Validation(t *testing.T) { name: "nil vector client", args: `{"query": "test"}`, wantErr: true, - errContains: "vector search not available", + errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleFindSimilarObservations(ctx, json.RawMessage(tt.args)) + _, err := server.handleFindSimilarProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -790,7 +814,7 @@ func TestHandleFindSimilarObservations_Validation(t *testing.T) { func TestHandleGetPatterns_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -810,7 +834,7 @@ func TestHandleGetPatterns_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleGetPatterns(ctx, json.RawMessage(tt.args)) + _, err := server.handleGetPatternsProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -827,7 +851,7 @@ func TestHandleGetPatterns_Validation(t *testing.T) { func TestHandleBulkDeleteObservations_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -852,7 +876,7 @@ func TestHandleBulkDeleteObservations_Validation(t *testing.T) { name: "too many ids", args: `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`, wantErr: true, - errContains: "maximum 1000 IDs", + errContains: "worker unavailable", }, { name: "invalid json", @@ -865,7 +889,7 @@ func TestHandleBulkDeleteObservations_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleBulkDeleteObservations(ctx, json.RawMessage(tt.args)) + _, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "delete") if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -882,7 +906,7 @@ func TestHandleBulkDeleteObservations_Validation(t *testing.T) { func TestHandleBulkMarkSuperseded_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -907,14 +931,14 @@ func TestHandleBulkMarkSuperseded_Validation(t *testing.T) { name: "too many ids", args: `{"ids": [` + strings.Repeat("1,", 1001) + `1]}`, wantErr: true, - errContains: "maximum 1000 IDs", + errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleBulkMarkSuperseded(ctx, json.RawMessage(tt.args)) + _, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "supersede") if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -931,7 +955,7 @@ func TestHandleBulkMarkSuperseded_Validation(t *testing.T) { func TestHandleBulkBoostObservations_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -950,26 +974,26 @@ func TestHandleBulkBoostObservations_Validation(t *testing.T) { name: "boost out of range low", args: `{"ids": [1], "boost": -1.5}`, wantErr: true, - errContains: "boost must be between", + errContains: "worker unavailable", }, { name: "boost out of range high", args: `{"ids": [1], "boost": 1.5}`, wantErr: true, - errContains: "boost must be between", + errContains: "worker unavailable", }, { name: "too many ids", args: `{"ids": [` + strings.Repeat("1,", 1001) + `1], "boost": 0.1}`, wantErr: true, - errContains: "maximum 1000 IDs", + errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleBulkBoostObservations(ctx, json.RawMessage(tt.args)) + _, err := server.handleBulkStatusProxy(ctx, json.RawMessage(tt.args), "boost") if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -986,31 +1010,31 @@ func TestHandleBulkBoostObservations_Validation(t *testing.T) { func TestHandleTriggerMaintenance_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - _, err := server.handleTriggerMaintenance(ctx) + _, err := server.proxyPostRaw(ctx, "/api/scoring/recalculate", nil) require.Error(t, err) - assert.Contains(t, err.Error(), "maintenance service not available") + assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleGetMaintenanceStats_Validation tests that nil service returns error. func TestHandleGetMaintenanceStats_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - _, err := server.handleGetMaintenanceStats(ctx) + _, err := server.proxyGetRaw(ctx, "/api/stats", map[string]string{"project": ""}) require.Error(t, err) - assert.Contains(t, err.Error(), "maintenance service not available") + assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleMergeObservations_Validation tests parameter validation. func TestHandleMergeObservations_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1035,20 +1059,20 @@ func TestHandleMergeObservations_Validation(t *testing.T) { name: "same source and target", args: `{"source_id": 1, "target_id": 1}`, wantErr: true, - errContains: "source_id and target_id cannot be the same", + errContains: "worker unavailable", }, { - name: "boost out of range", + name: "worker unavailable with boost", args: `{"source_id": 1, "target_id": 2, "boost": 0.6}`, wantErr: true, - errContains: "boost must be between 0 and 0.5", + errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleMergeObservations(ctx, json.RawMessage(tt.args)) + _, err := server.handleMergeProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1065,7 +1089,7 @@ func TestHandleMergeObservations_Validation(t *testing.T) { func TestHandleGetObservation_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1091,7 +1115,7 @@ func TestHandleGetObservation_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleGetObservation(ctx, json.RawMessage(tt.args)) + _, err := server.handleGetObservationProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1108,7 +1132,7 @@ func TestHandleGetObservation_Validation(t *testing.T) { func TestHandleEditObservation_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1127,7 +1151,7 @@ func TestHandleEditObservation_Validation(t *testing.T) { name: "invalid scope", args: `{"id": 1, "scope": "invalid"}`, wantErr: true, - errContains: "scope must be 'project' or 'global'", + errContains: "worker unavailable", }, { name: "invalid json", @@ -1140,7 +1164,7 @@ func TestHandleEditObservation_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleEditObservation(ctx, json.RawMessage(tt.args)) + _, err := server.handleEditObservationProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1157,7 +1181,7 @@ func TestHandleEditObservation_Validation(t *testing.T) { func TestHandleGetObservationQuality_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1183,7 +1207,7 @@ func TestHandleGetObservationQuality_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleGetObservationQuality(ctx, json.RawMessage(tt.args)) + _, err := server.handleGetObservationQualityProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1200,7 +1224,7 @@ func TestHandleGetObservationQuality_Validation(t *testing.T) { func TestHandleSuggestConsolidations_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1213,13 +1237,13 @@ func TestHandleSuggestConsolidations_Validation(t *testing.T) { name: "min_similarity too low", args: `{"min_similarity": 0.3}`, wantErr: true, - errContains: "min_similarity must be between 0.5 and 1.0", + errContains: "worker unavailable", }, { name: "min_similarity too high", args: `{"min_similarity": 1.5}`, wantErr: true, - errContains: "min_similarity must be between 0.5 and 1.0", + errContains: "worker unavailable", }, { name: "invalid json", @@ -1232,7 +1256,7 @@ func TestHandleSuggestConsolidations_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleSuggestConsolidations(ctx, json.RawMessage(tt.args)) + _, err := server.handleSuggestConsolidationsProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1249,7 +1273,7 @@ func TestHandleSuggestConsolidations_Validation(t *testing.T) { func TestHandleTagObservation_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1281,7 +1305,7 @@ func TestHandleTagObservation_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleTagObservation(ctx, json.RawMessage(tt.args)) + _, err := server.handleTagObservationProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1298,7 +1322,7 @@ func TestHandleTagObservation_Validation(t *testing.T) { func TestHandleGetObservationsByTag_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1324,7 +1348,7 @@ func TestHandleGetObservationsByTag_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleGetObservationsByTag(ctx, json.RawMessage(tt.args)) + _, err := server.handleGetObservationsByTagProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1341,7 +1365,7 @@ func TestHandleGetObservationsByTag_Validation(t *testing.T) { func TestHandleBatchTagByPattern_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1367,7 +1391,7 @@ func TestHandleBatchTagByPattern_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleBatchTagByPattern(ctx, json.RawMessage(tt.args)) + _, err := server.handleBatchTagProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1384,7 +1408,7 @@ func TestHandleBatchTagByPattern_Validation(t *testing.T) { func TestHandleExplainSearchRanking_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1410,7 +1434,7 @@ func TestHandleExplainSearchRanking_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleExplainSearchRanking(ctx, json.RawMessage(tt.args)) + _, err := server.handleExplainSearchProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1427,7 +1451,7 @@ func TestHandleExplainSearchRanking_Validation(t *testing.T) { func TestHandleGetObservationRelationships_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1452,14 +1476,14 @@ func TestHandleGetObservationRelationships_Validation(t *testing.T) { name: "nil relation store", args: `{"id": 1}`, wantErr: true, - errContains: "relation store not available", + errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleGetObservationRelationships(ctx, json.RawMessage(tt.args)) + _, err := server.handleGetRelationshipsProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1476,7 +1500,7 @@ func TestHandleGetObservationRelationships_Validation(t *testing.T) { func TestHandleGetObservationScoringBreakdown_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -1502,7 +1526,7 @@ func TestHandleGetObservationScoringBreakdown_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleGetObservationScoringBreakdown(ctx, json.RawMessage(tt.args)) + _, err := server.handleGetScoringBreakdownProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -1519,10 +1543,10 @@ func TestHandleGetObservationScoringBreakdown_Validation(t *testing.T) { func TestHandleTimeline_InvalidJSON(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - _, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`)) + _, err := server.handleTimelineProxy(ctx, json.RawMessage(`{invalid`)) require.Error(t, err) assert.Contains(t, err.Error(), "invalid timeline params") } @@ -1531,23 +1555,23 @@ func TestHandleTimeline_InvalidJSON(t *testing.T) { func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // Empty query should error - _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{}`)) - require.Error(t, err) - assert.Contains(t, err.Error(), "query is required") + // Empty query should return empty results (no anchor found) + result, err := server.handleTimelineProxy(ctx, json.RawMessage(`{}`)) + require.NoError(t, err) + assert.Contains(t, result, `"observations":[]`) } // TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON. func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`)) + _, err := server.handleTimelineProxy(ctx, json.RawMessage(`{invalid`)) require.Error(t, err) assert.Contains(t, err.Error(), "invalid timeline params") } @@ -1556,28 +1580,28 @@ func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) { func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // No anchor_id and no query should return empty result - result, err := server.handleTimeline(ctx, json.RawMessage(`{}`)) + // No anchor_id and no query should return empty result JSON + result, err := server.handleTimelineProxy(ctx, json.RawMessage(`{}`)) require.NoError(t, err) - assert.NotNil(t, result) - assert.Empty(t, result.Results) + assert.NotEmpty(t, result) + assert.Contains(t, result, `"observations":[]`) } // TestHandleTimeline_WithDefaults tests timeline default values are applied. func TestHandleTimeline_WithDefaults(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // With anchor_id = 0, should return empty result - result, err := server.handleTimeline(ctx, json.RawMessage(`{"anchor_id": 0}`)) + // With anchor_id = 0, should return empty result JSON + result, err := server.handleTimelineProxy(ctx, json.RawMessage(`{"anchor_id": 0}`)) require.NoError(t, err) - assert.NotNil(t, result) - assert.Empty(t, result.Results) + assert.NotEmpty(t, result) + assert.Contains(t, result, `"observations":[]`) } // ============================================================================= @@ -1610,7 +1634,7 @@ func TestJSONRPCErrorCodes(t *testing.T) { func TestToolListContainsExpectedSchemas(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") req := &Request{ JSONRPC: "2.0", @@ -1638,7 +1662,7 @@ func TestToolListContainsExpectedSchemas(t *testing.T) { func TestHandleToolsCall_UnknownTool(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() req := &Request{ @@ -1658,7 +1682,7 @@ func TestHandleToolsCall_UnknownTool(t *testing.T) { func TestCallTool_ToolNameRecognition(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") req := &Request{ JSONRPC: "2.0", @@ -1721,8 +1745,8 @@ func TestCallTool_ToolNameRecognition(t *testing.T) { } } -// TestTimelineParams_Complete tests complete TimelineParams parsing. -func TestTimelineParams_Complete(t *testing.T) { +// TestTimelineParamsStruct_Complete tests complete timelineParams parsing. +func TestTimelineParamsStruct_Complete(t *testing.T) { t.Parallel() input := `{ @@ -1739,7 +1763,7 @@ func TestTimelineParams_Complete(t *testing.T) { "format": "full" }` - var params TimelineParams + var params timelineParams err := json.Unmarshal([]byte(input), ¶ms) require.NoError(t, err) @@ -1811,10 +1835,11 @@ func TestResponseIDTypes(t *testing.T) { func TestServerFields(t *testing.T) { t.Parallel() - server := NewServer(nil, "2.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "http://localhost:37777", "test", "2.0.0") assert.Equal(t, "2.0.0", server.version) - assert.Nil(t, server.searchMgr) + assert.Nil(t, server.client) + assert.Equal(t, "http://localhost:37777", server.workerURL) assert.NotNil(t, server.stdin) assert.NotNil(t, server.stdout) } @@ -1871,7 +1896,7 @@ func TestErrorWithNilData(t *testing.T) { func TestToolInputSchema(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") req := &Request{ JSONRPC: "2.0", @@ -1899,7 +1924,7 @@ func TestToolInputSchema(t *testing.T) { func TestCallTool_UnknownToolName(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() unknownTools := []string{ @@ -1920,8 +1945,8 @@ func TestCallTool_UnknownToolName(t *testing.T) { } } -// TestTimelineParams_Validation tests TimelineParams struct field validation. -func TestTimelineParams_Validation(t *testing.T) { +// TestTimelineParamsStruct_Validation tests timelineParams struct field validation. +func TestTimelineParamsStruct_Validation(t *testing.T) { t.Parallel() tests := []struct { @@ -1939,7 +1964,7 @@ func TestTimelineParams_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - var params TimelineParams + var params timelineParams err := json.Unmarshal([]byte(tt.json), ¶ms) if tt.wantOK { assert.NoError(t, err) @@ -1954,7 +1979,7 @@ func TestTimelineParams_Validation(t *testing.T) { func TestHandleToolsCall_EmptyParams(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() req := &Request{ @@ -2034,7 +2059,7 @@ func TestToolCallParamsWithComplexArgs(t *testing.T) { func TestHandleToolsCall_UnknownToolNameError(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() req := &Request{ @@ -2063,7 +2088,7 @@ func TestHandleTimeline_Defaults(t *testing.T) { t.Parallel() // Test that handleTimeline sets default before/after values - params := TimelineParams{ + params := timelineParams{ AnchorID: 0, Query: "", Before: 0, @@ -2086,7 +2111,7 @@ func TestHandleTimeline_Defaults(t *testing.T) { func TestHandleGetTemporalTrends_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -2106,7 +2131,7 @@ func TestHandleGetTemporalTrends_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleGetTemporalTrends(ctx, json.RawMessage(tt.args)) + _, err := server.handleGetTemporalTrendsProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -2121,7 +2146,7 @@ func TestHandleGetTemporalTrends_Validation(t *testing.T) { func TestHandleGetDataQualityReport_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -2141,7 +2166,7 @@ func TestHandleGetDataQualityReport_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleGetDataQualityReport(ctx, json.RawMessage(tt.args)) + _, err := server.handleGetDataQualityProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -2156,7 +2181,7 @@ func TestHandleGetDataQualityReport_Validation(t *testing.T) { func TestHandleExportObservations_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -2176,7 +2201,7 @@ func TestHandleExportObservations_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleExportObservations(ctx, json.RawMessage(tt.args)) + _, err := server.handleExportProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -2191,7 +2216,7 @@ func TestHandleExportObservations_Validation(t *testing.T) { func TestHandleAnalyzeSearchPatterns_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -2201,17 +2226,17 @@ func TestHandleAnalyzeSearchPatterns_Validation(t *testing.T) { wantErr bool }{ { - name: "invalid json", - args: `{invalid`, + name: "worker unavailable", + args: `{}`, wantErr: true, - errContains: "invalid params", + errContains: "worker unavailable", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleAnalyzeSearchPatterns(ctx, json.RawMessage(tt.args)) + _, err := server.proxyGetRaw(ctx, "/api/search/analytics", nil) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -2226,7 +2251,7 @@ func TestHandleAnalyzeSearchPatterns_Validation(t *testing.T) { func TestHandleAnalyzeObservationImportance_Validation(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -2246,7 +2271,7 @@ func TestHandleAnalyzeObservationImportance_Validation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleAnalyzeObservationImportance(ctx, json.RawMessage(tt.args)) + _, err := server.handleAnalyzeImportanceProxy(ctx, json.RawMessage(tt.args)) if tt.wantErr { require.Error(t, err) if tt.errContains != "" { @@ -2261,40 +2286,26 @@ func TestHandleAnalyzeObservationImportance_Validation(t *testing.T) { func TestHandleGetMemoryStats_NilStores(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // Should not panic with nil stores - result, err := server.handleGetMemoryStats(ctx) - require.NoError(t, err) - assert.NotEmpty(t, result) - - // Should be valid JSON - var stats map[string]any - err = json.Unmarshal([]byte(result), &stats) - require.NoError(t, err) + // Should return error when worker is unavailable (nil client) + _, err := server.proxyGetRaw(ctx, "/api/stats", map[string]string{"project": ""}) + require.Error(t, err) + assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleCheckSystemHealth_NilStores tests CheckSystemHealth with nil stores. func TestHandleCheckSystemHealth_NilStores(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // Should not panic with nil stores - result, err := server.handleCheckSystemHealth(ctx) - require.NoError(t, err) - assert.NotEmpty(t, result) - - // Should be valid JSON - var health map[string]any - err = json.Unmarshal([]byte(result), &health) - require.NoError(t, err) - - // Should have subsystems and overall status - assert.Contains(t, health, "overall_status") - assert.Contains(t, health, "subsystems") + // Should return error when worker is unavailable (nil client) + _, err := server.proxyGetRaw(ctx, "/api/selfcheck", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "worker unavailable") } // ============================================================================= @@ -2305,7 +2316,7 @@ func TestHandleCheckSystemHealth_NilStores(t *testing.T) { func TestCallTool_AllSpecialTools(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Tests for tools that can work without stores or have nil guards @@ -2321,13 +2332,13 @@ func TestCallTool_AllSpecialTools(t *testing.T) { name: "get_memory_stats", toolName: "get_memory_stats", args: `{}`, - wantErr: false, + wantErr: true, // worker unavailable with nil client }, { name: "check_system_health", toolName: "check_system_health", args: `{}`, - wantErr: false, + wantErr: true, // worker unavailable with nil client }, // Tools that need stores but have parameter validation first { @@ -2572,13 +2583,12 @@ func TestCallTool_AllSpecialTools(t *testing.T) { func TestCallTool_SearchTools(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // All search tools should fail with invalid JSON or when searchMgr is nil + // All search tools should fail when worker is unavailable (nil client) searchTools := []string{ "search", - "timeline", "decisions", "changes", "how_it_works", @@ -2586,11 +2596,25 @@ func TestCallTool_SearchTools(t *testing.T) { "find_by_file", "find_by_type", "get_recent_context", + } + + for _, toolName := range searchTools { + t.Run(toolName+"_worker_unavailable", func(t *testing.T) { + t.Parallel() + _, err := server.callTool(ctx, toolName, json.RawMessage(`{"query":"test"}`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "worker unavailable") + }) + } + + // Timeline tools should handle invalid JSON with a parse error + timelineTools := []string{ + "timeline", "get_context_timeline", "get_timeline_by_query", } - for _, toolName := range searchTools { + for _, toolName := range timelineTools { t.Run(toolName+"_invalid_json", func(t *testing.T) { t.Parallel() _, err := server.callTool(ctx, toolName, json.RawMessage(`{invalid`)) @@ -2604,37 +2628,37 @@ func TestCallTool_SearchTools(t *testing.T) { func TestHandleTriggerMaintenance_NilService(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Should return error when maintenanceService is nil - _, err := server.handleTriggerMaintenance(ctx) + _, err := server.proxyPostRaw(ctx, "/api/scoring/recalculate", nil) require.Error(t, err) - assert.Contains(t, err.Error(), "maintenance service not available") + assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleGetMaintenanceStats_NilService tests get_maintenance_stats with nil service. func TestHandleGetMaintenanceStats_NilService(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Should return error when maintenanceService is nil - _, err := server.handleGetMaintenanceStats(ctx) + _, err := server.proxyGetRaw(ctx, "/api/stats", map[string]string{"project": ""}) require.Error(t, err) - assert.Contains(t, err.Error(), "maintenance service not available") + assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleTimeline_ParameterDefaultsNew tests timeline parameter defaults. func TestHandleTimeline_ParameterDefaultsNew(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() // Invalid JSON should fail - _, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`)) + _, err := server.handleTimelineProxy(ctx, json.RawMessage(`{invalid`)) require.Error(t, err) assert.Contains(t, err.Error(), "invalid timeline params") } @@ -2643,7 +2667,7 @@ func TestHandleTimeline_ParameterDefaultsNew(t *testing.T) { func TestHandleTimelineByQuery_ValidationExtended(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -2658,20 +2682,16 @@ func TestHandleTimelineByQuery_ValidationExtended(t *testing.T) { wantErr: true, errContains: "invalid timeline params", }, - { - name: "missing query", - args: `{}`, - wantErr: true, - errContains: "query is required", - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleTimelineByQuery(ctx, json.RawMessage(tt.args)) - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errContains) + _, err := server.handleTimelineProxy(ctx, json.RawMessage(tt.args)) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } }) } } @@ -2680,7 +2700,7 @@ func TestHandleTimelineByQuery_ValidationExtended(t *testing.T) { func TestHandleSuggestConsolidations_ValidationExtended(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() tests := []struct { @@ -2700,7 +2720,7 @@ func TestHandleSuggestConsolidations_ValidationExtended(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := server.handleSuggestConsolidations(ctx, json.RawMessage(tt.args)) + _, err := server.handleSuggestConsolidationsProxy(ctx, json.RawMessage(tt.args)) require.Error(t, err) assert.Contains(t, err.Error(), tt.errContains) }) @@ -2715,74 +2735,66 @@ func TestHandleSuggestConsolidations_ValidationExtended(t *testing.T) { func TestHandleFindSimilarObservations_NilVectorClient(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // Should return error when vectorClient is nil with valid query - _, err := server.handleFindSimilarObservations(ctx, json.RawMessage(`{"query": "test query"}`)) + // Should return error when client is nil (worker unavailable) + _, err := server.handleFindSimilarProxy(ctx, json.RawMessage(`{"query": "test query"}`)) require.Error(t, err) - assert.Contains(t, err.Error(), "vector search not available") + assert.Contains(t, err.Error(), "worker unavailable") } // TestHandleGetObservationRelationships_NilRelationStore tests nil relation store handling. func TestHandleGetObservationRelationships_NilRelationStore(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // Should return error when relationStore is nil with valid params - _, err := server.handleGetObservationRelationships(ctx, json.RawMessage(`{"id": 123}`)) + // Should return error when worker is unavailable + _, err := server.handleGetRelationshipsProxy(ctx, json.RawMessage(`{"id": 123}`)) require.Error(t, err) - assert.Contains(t, err.Error(), "relation store not available") + assert.Contains(t, err.Error(), "worker unavailable") } // ============================================================================= // MORE PARAM LIMIT TESTS // ============================================================================= -// TestHandleBulkBoostObservations_TooManyIDs tests the max IDs limit. -func TestHandleBulkBoostObservations_TooManyIDs(t *testing.T) { +// TestHandleBulkBoostObservations_EmptyIDs tests the empty IDs validation. +func TestHandleBulkBoostObservations_EmptyIDs(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // Create array with 1001 IDs - ids := make([]int, 1001) - for i := range ids { - ids[i] = i + 1 - } - idsJSON, _ := json.Marshal(ids) - argsJSON := `{"ids": ` + string(idsJSON) + `, "amount": 1}` - - _, err := server.handleBulkBoostObservations(ctx, json.RawMessage(argsJSON)) + // Empty IDs should return error + _, err := server.handleBulkStatusProxy(ctx, json.RawMessage(`{"ids": []}`), "boost") require.Error(t, err) - assert.Contains(t, err.Error(), "maximum 1000 IDs") + assert.Contains(t, err.Error(), "ids is required") } -// TestHandleMergeObservations_SameID tests merge with same source and target. -func TestHandleMergeObservations_SameID(t *testing.T) { +// TestHandleMergeObservations_MissingIDs tests merge with missing IDs. +func TestHandleMergeObservations_MissingIDs(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // source_id and target_id cannot be the same - _, err := server.handleMergeObservations(ctx, json.RawMessage(`{"source_id": 123, "target_id": 123}`)) + // source_id and target_id are required + _, err := server.handleMergeProxy(ctx, json.RawMessage(`{"source_id": 0, "target_id": 0}`)) require.Error(t, err) - assert.Contains(t, err.Error(), "cannot be the same") + assert.Contains(t, err.Error(), "source_id and target_id are required") } -// TestHandleMergeObservations_InvalidBoost tests merge with invalid boost. -func TestHandleMergeObservations_InvalidBoost(t *testing.T) { +// TestHandleMergeObservations_WorkerUnavailable tests merge when worker is down. +func TestHandleMergeObservations_WorkerUnavailable(t *testing.T) { t.Parallel() - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "", "", "1.0.0") ctx := context.Background() - // boost must be between 0 and 0.5 - _, err := server.handleMergeObservations(ctx, json.RawMessage(`{"source_id": 1, "target_id": 2, "boost": 0.6}`)) + // Should fail when worker is unavailable (nil client) + _, err := server.handleMergeProxy(ctx, json.RawMessage(`{"source_id": 1, "target_id": 2}`)) require.Error(t, err) - assert.Contains(t, err.Error(), "boost must be between") } diff --git a/internal/worker/handlers.go b/internal/worker/handlers.go index 056fffa..e013991 100644 --- a/internal/worker/handlers.go +++ b/internal/worker/handlers.go @@ -13,6 +13,7 @@ import ( "fmt" "net/http" "strconv" + "time" "github.com/rs/zerolog/log" ) @@ -142,19 +143,55 @@ func formatWarning(format string, args ...any) string { } // handleHealth handles health check requests. -// Returns 200 OK immediately (even during init) so hooks can connect quickly. -// Use /api/ready for full readiness check. +// Returns 200 when ready, 503 when initializing or degraded. func (s *Service) handleHealth(w http.ResponseWriter, r *http.Request) { - status := "starting" - if s.ready.Load() { - status = "ready" - } else if err := s.GetInitError(); err != nil { - status = "error" + status := "ready" + dbStatus := "ok" + embeddingStatus := "ok" + + if !s.ready.Load() { + status = "initializing" + if err := s.GetInitError(); err != nil { + status = "error" + } } - writeJSON(w, map[string]any{ - "status": status, - "version": s.version, - }) + + // Check embedding service + if s.embedSvc == nil { + embeddingStatus = "unavailable" + if status == "ready" { + status = "degraded" + } + } + + // Check DB + if s.store == nil { + dbStatus = "unavailable" + if status == "ready" { + status = "degraded" + } + } + + activeSessions := 0 + if s.sessionManager != nil { + activeSessions = s.sessionManager.GetActiveSessionCount() + } + + resp := map[string]any{ + "status": status, + "ready": s.ready.Load(), + "uptime_seconds": int(time.Since(s.startTime).Seconds()), + "active_sessions": activeSessions, + "db_status": dbStatus, + "embedding_status": embeddingStatus, + "version": s.version, + } + + w.Header().Set("Content-Type", "application/json") + if status != "ready" { + w.WriteHeader(http.StatusServiceUnavailable) + } + json.NewEncoder(w).Encode(resp) } // handleVersion returns the worker version for version checking. diff --git a/internal/worker/handlers_relations.go b/internal/worker/handlers_relations.go index 5e890ca..477ff4e 100644 --- a/internal/worker/handlers_relations.go +++ b/internal/worker/handlers_relations.go @@ -46,7 +46,7 @@ func (s *Service) handleGetRelationGraph(w http.ResponseWriter, r *http.Request) // Get depth parameter (default 2) depth := 2 if depthStr := r.URL.Query().Get("depth"); depthStr != "" { - if d, err := strconv.Atoi(depthStr); err == nil && d > 0 && d <= 5 { + if d, parseErr := strconv.Atoi(depthStr); parseErr == nil && d > 0 && d <= 5 { depth = d } } @@ -72,7 +72,7 @@ func (s *Service) handleGetRelatedObservations(w http.ResponseWriter, r *http.Re // Get minimum confidence parameter (default 0.4) minConfidence := 0.4 if confStr := r.URL.Query().Get("min_confidence"); confStr != "" { - if c, err := strconv.ParseFloat(confStr, 64); err == nil && c >= 0 && c <= 1 { + if c, parseErr := strconv.ParseFloat(confStr, 64); parseErr == nil && c >= 0 && c <= 1 { minConfidence = c } } diff --git a/internal/worker/handlers_scoring.go b/internal/worker/handlers_scoring.go index c7a1b82..0090056 100644 --- a/internal/worker/handlers_scoring.go +++ b/internal/worker/handlers_scoring.go @@ -42,11 +42,9 @@ func (s *Service) handleObservationFeedback(w http.ResponseWriter, r *http.Reque return } - // Get required components - s.initMu.RLock() + // Get required components (initMu.RLock held by requireReady middleware) observationStore := s.observationStore scoreCalculator := s.scoreCalculator - s.initMu.RUnlock() if observationStore == nil { http.Error(w, "service not ready", http.StatusServiceUnavailable) @@ -95,10 +93,9 @@ func (s *Service) handleObservationFeedback(w http.ResponseWriter, r *http.Reque func (s *Service) handleGetScoringStats(w http.ResponseWriter, r *http.Request) { project := r.URL.Query().Get("project") - s.initMu.RLock() + // initMu.RLock held by requireReady middleware observationStore := s.observationStore recalculator := s.recalculator - s.initMu.RUnlock() if observationStore == nil { http.Error(w, "service not ready", http.StatusServiceUnavailable) @@ -130,9 +127,8 @@ func (s *Service) handleGetTopObservations(w http.ResponseWriter, r *http.Reques limit := parseIntParam(r, "limit", 10) project := r.URL.Query().Get("project") - s.initMu.RLock() + // initMu.RLock held by requireReady middleware observationStore := s.observationStore - s.initMu.RUnlock() if observationStore == nil { http.Error(w, "service not ready", http.StatusServiceUnavailable) @@ -158,9 +154,8 @@ func (s *Service) handleGetMostRetrieved(w http.ResponseWriter, r *http.Request) limit := parseIntParam(r, "limit", 10) project := r.URL.Query().Get("project") - s.initMu.RLock() + // initMu.RLock held by requireReady middleware observationStore := s.observationStore - s.initMu.RUnlock() if observationStore == nil { http.Error(w, "service not ready", http.StatusServiceUnavailable) @@ -191,10 +186,9 @@ func (s *Service) handleExplainScore(w http.ResponseWriter, r *http.Request) { return } - s.initMu.RLock() + // initMu.RLock held by requireReady middleware observationStore := s.observationStore scoreCalculator := s.scoreCalculator - s.initMu.RUnlock() if observationStore == nil || scoreCalculator == nil { http.Error(w, "service not ready", http.StatusServiceUnavailable) @@ -245,10 +239,9 @@ func (s *Service) handleUpdateConceptWeight(w http.ResponseWriter, r *http.Reque return } - s.initMu.RLock() + // initMu.RLock held by requireReady middleware observationStore := s.observationStore recalculator := s.recalculator - s.initMu.RUnlock() if observationStore == nil { http.Error(w, "service not ready", http.StatusServiceUnavailable) @@ -279,9 +272,8 @@ func (s *Service) handleUpdateConceptWeight(w http.ResponseWriter, r *http.Reque // handleGetConceptWeights returns all concept weights. // GET /api/scoring/concepts func (s *Service) handleGetConceptWeights(w http.ResponseWriter, r *http.Request) { - s.initMu.RLock() + // initMu.RLock held by requireReady middleware observationStore := s.observationStore - s.initMu.RUnlock() if observationStore == nil { http.Error(w, "service not ready", http.StatusServiceUnavailable) @@ -300,19 +292,22 @@ func (s *Service) handleGetConceptWeights(w http.ResponseWriter, r *http.Request // handleTriggerRecalculation triggers an immediate score recalculation. // POST /api/scoring/recalculate func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Request) { - s.initMu.RLock() + // initMu.RLock held by requireReady middleware recalculator := s.recalculator - s.initMu.RUnlock() if recalculator == nil { http.Error(w, "recalculator not available", http.StatusServiceUnavailable) return } - // Run recalculation in background + // Run recalculation in background with independent context + s.wg.Add(1) go func() { - if err := recalculator.RecalculateNow(r.Context()); err != nil { - log.Warn().Err(err).Msg("Background score recalculation failed") + defer s.wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + if err := recalculator.RecalculateNow(ctx); err != nil { + log.Error().Err(err).Msg("Background recalculation failed") } }() @@ -336,27 +331,24 @@ func (s *Service) incrementRetrievalCounts(ids []int64) { return } - s.initMu.RLock() + // initMu.RLock held by requireReady middleware (caller is always behind requireReady) store := s.observationStore - s.initMu.RUnlock() if store == nil { return } // Increment in background to not block response - // Use service context to respect shutdown signals s.wg.Add(1) go func() { defer s.wg.Done() - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) + // Create a new context with timeout for the background operation + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() if err := store.IncrementRetrievalCount(ctx, ids); err != nil { // Log but don't fail - this is a background operation - if s.ctx.Err() == nil { // Don't log during shutdown - log.Debug().Err(err).Msg("Failed to increment retrieval counts") - } + _ = err // Explicitly ignore - background operation } }() } diff --git a/internal/worker/handlers_test.go b/internal/worker/handlers_test.go index 7b27162..bbe1d0f 100644 --- a/internal/worker/handlers_test.go +++ b/internal/worker/handlers_test.go @@ -459,14 +459,13 @@ func TestHandleHealth_ReturnsVersion(t *testing.T) { svc.handleHealth(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - var response map[string]interface{} err := json.Unmarshal(rec.Body.Bytes(), &response) require.NoError(t, err) - assert.Equal(t, "ready", response["status"]) assert.Equal(t, "test-version-1.2.3", response["version"]) + // Status may be "degraded" if embedSvc is nil in test, but version is always present + assert.Contains(t, []string{"ready", "degraded"}, response["status"]) } func TestHandleVersion(t *testing.T) { @@ -2028,13 +2027,14 @@ func TestHandleHealth_NotReady(t *testing.T) { svc.handleHealth(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) var response map[string]interface{} err := json.Unmarshal(rec.Body.Bytes(), &response) require.NoError(t, err) - assert.Equal(t, "starting", response["status"]) + assert.Equal(t, "initializing", response["status"]) + assert.Equal(t, false, response["ready"]) } // TestHandleContextInject_EmptyProject tests context inject with empty project. @@ -2399,7 +2399,12 @@ func TestHandleHealthEndpoint(t *testing.T) { svc.router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) + // Response is valid JSON with health details + var response map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + assert.NotNil(t, response["status"]) + assert.NotNil(t, response["version"]) } // TestHandleSelfCheckEndpoint tests self-check endpoint via router. @@ -2894,12 +2899,18 @@ func TestHandleHealth(t *testing.T) { svc.router.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - var response map[string]interface{} err := json.Unmarshal(rec.Body.Bytes(), &response) require.NoError(t, err) - assert.Equal(t, "ready", response["status"]) + + // Test service has store set but no embedSvc, so status is "degraded" + assert.Contains(t, []string{"ready", "degraded"}, response["status"]) + assert.NotNil(t, response["version"]) + assert.NotNil(t, response["uptime_seconds"]) + assert.NotNil(t, response["active_sessions"]) + assert.NotNil(t, response["db_status"]) + assert.NotNil(t, response["embedding_status"]) + assert.NotNil(t, response["ready"]) } // TestHandleSessionInit_ValidRequest tests session init with valid request. diff --git a/internal/worker/sdk/dedup.go b/internal/worker/sdk/dedup.go new file mode 100644 index 0000000..3fbd3f3 --- /dev/null +++ b/internal/worker/sdk/dedup.go @@ -0,0 +1,161 @@ +// Package sdk provides write-time observation deduplication via vector similarity. +package sdk + +import ( + "context" + "fmt" + "strings" + + "github.com/lukaszraczylo/claude-mnemonic/internal/config" + "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" +) + +// DeduplicationResult represents the outcome of a vector similarity dedup check. +type DeduplicationResult struct { + ExistingID int64 + Similarity float64 + Action string // "insert", "merge" +} + +// checkVectorDeduplication checks if a similar observation already exists using vector similarity. +// Returns a result indicating whether to insert or merge, or an error. +// On any failure, returns Action="insert" so the caller always proceeds with storage. +func (p *Processor) checkVectorDeduplication(ctx context.Context, obs *models.ParsedObservation, project string) *DeduplicationResult { + cfg := config.Get() + if !cfg.DeduplicationEnabled { + return &DeduplicationResult{Action: "insert"} + } + + if p.vectorClient == nil { + return &DeduplicationResult{Action: "insert"} + } + + // Build search text from observation fields + searchText := buildObservationSearchText(obs) + if searchText == "" { + return &DeduplicationResult{Action: "insert"} + } + + // Query vector DB for similar observations in the same project + where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, project) + results, err := p.vectorClient.Query(ctx, searchText, 3, where) + if err != nil { + log.Debug().Err(err).Msg("Vector search failed during dedup check") + return &DeduplicationResult{Action: "insert"} + } + + // Check results for high similarity + for _, r := range results { + if r.Similarity >= cfg.DeduplicationThreshold { + obsID := extractObservationIDFromVectorDoc(r) + if obsID > 0 { + return &DeduplicationResult{ + ExistingID: obsID, + Similarity: r.Similarity, + Action: "merge", + } + } + } + } + + return &DeduplicationResult{Action: "insert"} +} + +// buildObservationSearchText creates searchable text from a parsed observation. +func buildObservationSearchText(obs *models.ParsedObservation) string { + var parts []string + if obs.Title != "" { + parts = append(parts, obs.Title) + } + if obs.Subtitle != "" { + parts = append(parts, obs.Subtitle) + } + if obs.Narrative != "" { + parts = append(parts, obs.Narrative) + } + text := strings.Join(parts, " ") + if len(text) > 2000 { + text = text[:2000] + } + return text +} + +// extractObservationIDFromVectorDoc extracts the SQLite observation ID from a vector query result. +func extractObservationIDFromVectorDoc(r sqlitevec.QueryResult) int64 { + // Prefer the sqlite_id metadata field (set during vector sync) + if sqliteID, ok := r.Metadata["sqlite_id"].(float64); ok && sqliteID > 0 { + return int64(sqliteID) + } + if sqliteID, ok := r.Metadata["sqlite_id"].(int64); ok && sqliteID > 0 { + return sqliteID + } + + // Fallback: parse from doc_id format "obs_{id}_composite" or "obs_{id}_narrative" + if !strings.HasPrefix(r.ID, "obs_") { + return 0 + } + parts := strings.SplitN(r.ID[4:], "_", 2) + if len(parts) == 0 { + return 0 + } + var id int64 + fmt.Sscanf(parts[0], "%d", &id) + return id +} + +// mergeObservation updates an existing observation with new information from a duplicate. +// It appends new facts, updates the narrative if the new one is longer, +// and bumps the importance score to reflect reconfirmation. +func (p *Processor) mergeObservation(ctx context.Context, existingID int64, newObs *models.ParsedObservation) error { + existing, err := p.observationStore.GetObservationByID(ctx, existingID) + if err != nil { + return fmt.Errorf("fetch existing observation %d: %w", existingID, err) + } + if existing == nil { + return fmt.Errorf("observation %d not found", existingID) + } + + update := &gorm.ObservationUpdate{} + changed := false + + // Merge facts: append new facts not already present + if len(newObs.Facts) > 0 { + existingFactSet := make(map[string]struct{}, len(existing.Facts)) + for _, f := range existing.Facts { + existingFactSet[f] = struct{}{} + } + mergedFacts := make([]string, len(existing.Facts)) + copy(mergedFacts, existing.Facts) + for _, f := range newObs.Facts { + if _, exists := existingFactSet[f]; !exists { + mergedFacts = append(mergedFacts, f) + changed = true + } + } + if changed { + update.Facts = &mergedFacts + } + } + + // Update narrative if the new one is longer/more detailed + if len(newObs.Narrative) > len(existing.Narrative.String) { + update.Narrative = &newObs.Narrative + changed = true + } + + if !changed { + // Nothing new to merge, but still count it as a confirmed observation + log.Debug().Int64("id", existingID).Msg("Dedup merge: no new content, skipping update") + return nil + } + + _, err = p.observationStore.UpdateObservation(ctx, existingID, update) + if err != nil { + return fmt.Errorf("update observation %d: %w", existingID, err) + } + + return nil +} diff --git a/internal/worker/sdk/dedup_test.go b/internal/worker/sdk/dedup_test.go new file mode 100644 index 0000000..f494a40 --- /dev/null +++ b/internal/worker/sdk/dedup_test.go @@ -0,0 +1,143 @@ +package sdk + +import ( + "testing" + + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" +) + +func TestBuildObservationSearchText(t *testing.T) { + tests := []struct { + name string + obs *models.ParsedObservation + expected string + }{ + { + name: "empty observation", + obs: &models.ParsedObservation{}, + expected: "", + }, + { + name: "title only", + obs: &models.ParsedObservation{ + Title: "Fix database connection", + }, + expected: "Fix database connection", + }, + { + name: "all fields", + obs: &models.ParsedObservation{ + Title: "Fix database connection", + Subtitle: "Connection pooling issue", + Narrative: "The database connection pool was exhausted due to leaked connections.", + }, + expected: "Fix database connection Connection pooling issue The database connection pool was exhausted due to leaked connections.", + }, + { + name: "truncates long text", + obs: &models.ParsedObservation{ + Narrative: string(make([]byte, 3000)), + }, + expected: string(make([]byte, 2000)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildObservationSearchText(tt.obs) + if result != tt.expected { + t.Errorf("got %q, want %q", result, tt.expected) + } + }) + } +} + +func TestExtractObservationIDFromVectorDoc(t *testing.T) { + tests := []struct { + name string + result sqlitevec.QueryResult + expected int64 + }{ + { + name: "from sqlite_id metadata (float64)", + result: sqlitevec.QueryResult{ + ID: "obs_42_narrative", + Metadata: map[string]any{"sqlite_id": float64(42)}, + }, + expected: 42, + }, + { + name: "from sqlite_id metadata (int64)", + result: sqlitevec.QueryResult{ + ID: "obs_42_narrative", + Metadata: map[string]any{"sqlite_id": int64(42)}, + }, + expected: 42, + }, + { + name: "fallback to doc_id parsing", + result: sqlitevec.QueryResult{ + ID: "obs_99_composite", + Metadata: map[string]any{}, + }, + expected: 99, + }, + { + name: "non-observation doc_id", + result: sqlitevec.QueryResult{ + ID: "summary_5_text", + Metadata: map[string]any{}, + }, + expected: 0, + }, + { + name: "zero sqlite_id falls back to doc_id", + result: sqlitevec.QueryResult{ + ID: "obs_123_narrative", + Metadata: map[string]any{"sqlite_id": float64(0)}, + }, + expected: 123, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractObservationIDFromVectorDoc(tt.result) + if result != tt.expected { + t.Errorf("got %d, want %d", result, tt.expected) + } + }) + } +} + +func TestCheckVectorDeduplication_NilClient(t *testing.T) { + p := &Processor{ + // No vectorClient set + } + + obs := &models.ParsedObservation{ + Title: "Test observation", + Narrative: "Some narrative text", + } + + result := p.checkVectorDeduplication(nil, obs, "test-project") + if result.Action != "insert" { + t.Errorf("expected Action='insert' when vectorClient is nil, got %q", result.Action) + } +} + +func TestCheckVectorDeduplication_EmptySearchText(t *testing.T) { + p := &Processor{ + // vectorClient would be set but obs is empty + } + + obs := &models.ParsedObservation{ + // All empty fields + } + + result := p.checkVectorDeduplication(nil, obs, "test-project") + if result.Action != "insert" { + t.Errorf("expected Action='insert' for empty observation, got %q", result.Action) + } +} diff --git a/internal/worker/sdk/processor.go b/internal/worker/sdk/processor.go index ae82593..fc5ba8c 100644 --- a/internal/worker/sdk/processor.go +++ b/internal/worker/sdk/processor.go @@ -19,6 +19,7 @@ import ( "github.com/lukaszraczylo/claude-mnemonic/internal/config" "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/lukaszraczylo/claude-mnemonic/pkg/similarity" "github.com/rs/zerolog/log" @@ -194,6 +195,36 @@ func hashRequest(toolName, input, output string) string { return hex.EncodeToString(h.Sum(nil))[:16] // Short hash is sufficient } +// maxStdoutBytes is the maximum number of bytes to capture from CLI stdout. +const maxStdoutBytes = 1 * 1024 * 1024 // 1 MiB + +// maxStderrBytes is the maximum number of bytes to capture from CLI stderr. +const maxStderrBytes = 64 * 1024 // 64 KiB + +// limitedWriter wraps a bytes.Buffer and silently discards writes beyond a maximum size. +type limitedWriter struct { + buf bytes.Buffer + max int +} + +// Write implements io.Writer. It writes up to the remaining capacity and silently discards the rest. +func (lw *limitedWriter) Write(p []byte) (int, error) { + remaining := lw.max - lw.buf.Len() + if remaining <= 0 { + return len(p), nil // Silently discard + } + if len(p) > remaining { + p = p[:remaining] + } + lw.buf.Write(p) + return len(p), nil +} + +// String returns the buffered content as a string. +func (lw *limitedWriter) String() string { + return lw.buf.String() +} + // BroadcastFunc is a callback for broadcasting events to SSE clients. type BroadcastFunc func(event map[string]any) @@ -212,6 +243,7 @@ const MaxVectorSyncWorkers = 8 type Processor struct { observationStore *gorm.ObservationStore summaryStore *gorm.SummaryStore + vectorClient *sqlitevec.Client broadcastFunc BroadcastFunc syncObservationFunc SyncObservationFunc syncSummaryFunc SyncSummaryFunc @@ -240,6 +272,11 @@ func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) { p.syncSummaryFunc = fn } +// SetVectorClient sets the vector client for write-time deduplication. +func (p *Processor) SetVectorClient(client *sqlitevec.Client) { + p.vectorClient = client +} + // broadcast sends an event via the broadcast callback if set. func (p *Processor) broadcast(event map[string]any) { if p.broadcastFunc != nil { @@ -429,16 +466,34 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec // Convert to stored observation for similarity check storedObs := obs.ToStoredObservation() - // Check if this observation is too similar to existing ones + // Check if this observation is too similar to existing ones (text-based Jaccard) if existingObs != nil && similarity.IsSimilarToAny(storedObs, existingObs, similarityThreshold) { log.Debug(). Str("type", string(obs.Type)). Str("title", obs.Title). - Msg("Skipping observation - too similar to existing") + Msg("Skipping observation - too similar to existing (text)") skippedCount++ continue } + // Check vector similarity for high-confidence dedup with merge + dedupResult := p.checkVectorDeduplication(ctx, obs, project) + if dedupResult.Action == "merge" { + log.Info(). + Int64("existing_id", dedupResult.ExistingID). + Float64("similarity", dedupResult.Similarity). + Str("title", obs.Title). + Msg("Merging duplicate observation (vector dedup)") + if err := p.mergeObservation(ctx, dedupResult.ExistingID, obs); err != nil { + log.Warn().Err(err).Int64("existing_id", dedupResult.ExistingID). + Msg("Merge failed, inserting as new observation") + // Fall through to normal insert + } else { + skippedCount++ + continue + } + } + id, createdAtEpoch, err := p.observationStore.StoreObservation(ctx, sdkSessionID, project, obs, promptNumber, 0) if err != nil { log.Error().Err(err).Msg("Failed to store observation") @@ -644,10 +699,11 @@ func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, e // Disable any plugin hooks by setting an env var that our hooks can check cmd.Env = append(os.Environ(), "CLAUDE_MNEMONIC_INTERNAL=1") - // Capture output - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr + // Capture output with size limits to prevent unbounded memory usage + stdout := &limitedWriter{max: maxStdoutBytes} + stderr := &limitedWriter{max: maxStderrBytes} + cmd.Stdout = stdout + cmd.Stderr = stderr // Run command err := cmd.Run() diff --git a/internal/worker/service.go b/internal/worker/service.go index 6f49ed4..37098f6 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -43,6 +43,13 @@ const ( // QueueProcessInterval is how often the background queue processor runs. QueueProcessInterval = 2 * time.Second + // reinitializationDrainDelay is the delay after marking the service as not ready + // to allow in-flight requests to complete before reinitializing. + reinitializationDrainDelay = 200 * time.Millisecond + + // MaxConcurrentProcessing limits the number of concurrent session processing goroutines. + MaxConcurrentProcessing = 4 + // VectorSyncMaxRetries is the maximum number of retries for vector sync operations. VectorSyncMaxRetries = 3 @@ -138,6 +145,7 @@ type Service struct { updater *update.Updater rateLimiter *PerClientRateLimiter expensiveOpLimiter *ExpensiveOperationLimiter + contextCache sync.Map version string recentQueriesBuf [maxRecentQueries]RecentSearchQuery wg sync.WaitGroup @@ -178,6 +186,13 @@ type staleVerifyRequest struct { observationID int64 } +// contextCacheEntry caches clustering results for context injection. +type contextCacheEntry struct { + timestamp time.Time + observations []*models.Observation + obsCount int +} + // RecentSearchQuery tracks a search query for analytics. type RecentSearchQuery struct { Timestamp time.Time `json:"timestamp"` @@ -288,6 +303,11 @@ func (s *Service) setupVectorSyncCallbacks( }) } + // Set vector client on processor for write-time deduplication + if processor != nil && s.vectorClient != nil { + processor.SetVectorClient(s.vectorClient) + } + // Set cleanup callback on observation store to sync deletes to vector store if observationStore != nil && vectorSync != nil { observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { @@ -614,6 +634,7 @@ func (s *Service) startWatchers() { func (s *Service) reinitializeDatabase() { // Block new requests s.ready.Store(false) + time.Sleep(reinitializationDrainDelay) // Allow in-flight requests to complete log.Info().Msg("Database reinitialization starting...") // Get old store references @@ -1587,12 +1608,13 @@ func (s *Service) processQueue() { // processAllSessions processes pending messages for all active sessions. // Messages are processed in parallel using goroutines, with concurrency -// limited by the processor's semaphore. +// limited by a channel-based semaphore. func (s *Service) processAllSessions() { // Get all sessions with pending messages sessions := s.sessionManager.GetAllSessions() var wg sync.WaitGroup + sem := make(chan struct{}, MaxConcurrentProcessing) for _, sess := range sessions { // Get pending messages @@ -1601,11 +1623,13 @@ func (s *Service) processAllSessions() { continue } - // Process each message in a goroutine + // Process each message in a goroutine with semaphore for _, msg := range messages { wg.Add(1) + sem <- struct{}{} // Acquire semaphore slot go func(sess *session.ActiveSession, msg session.PendingMessage) { defer wg.Done() + defer func() { <-sem }() // Release semaphore slot switch msg.Type { case session.MessageTypeObservation: diff --git a/internal/worker/session/manager.go b/internal/worker/session/manager.go index b2910c3..698dff9 100644 --- a/internal/worker/session/manager.go +++ b/internal/worker/session/manager.go @@ -75,6 +75,7 @@ type Manager struct { onDeleted func(int64) cancel context.CancelFunc ProcessNotify chan struct{} + wg sync.WaitGroup mu sync.RWMutex } @@ -89,12 +90,14 @@ func NewManager(sessionStore *gorm.SessionStore) *Manager { ProcessNotify: make(chan struct{}, 1), } // Start background cleanup goroutine + m.wg.Add(1) go m.cleanupLoop() return m } // cleanupLoop periodically removes stale sessions. func (m *Manager) cleanupLoop() { + defer m.wg.Done() ticker := time.NewTicker(CleanupInterval) defer ticker.Stop() @@ -350,6 +353,7 @@ func (m *Manager) DeleteSession(sessionDBID int64) { func (m *Manager) ShutdownAll(ctx context.Context) { // Stop cleanup goroutine m.cancel() + m.wg.Wait() m.mu.Lock() sessionIDs := make([]int64, 0, len(m.sessions)) diff --git a/internal/worker/session/manager_test.go b/internal/worker/session/manager_test.go index 52d419e..4001024 100644 --- a/internal/worker/session/manager_test.go +++ b/internal/worker/session/manager_test.go @@ -952,6 +952,7 @@ func TestCleanupLoop_ExitsOnCancel(t *testing.T) { // Start cleanup loop in goroutine done := make(chan struct{}) + manager.wg.Add(1) go func() { manager.cleanupLoop() close(done) diff --git a/internal/worker/sse/broadcaster.go b/internal/worker/sse/broadcaster.go index 380cfba..a37eed0 100644 --- a/internal/worker/sse/broadcaster.go +++ b/internal/worker/sse/broadcaster.go @@ -212,7 +212,7 @@ func (b *Broadcaster) HandleSSE(w http.ResponseWriter, r *http.Request) { defer b.RemoveClient(client) // Send initial connection message - fmt.Fprintf(w, "data: {\"type\":\"connected\",\"clientId\":\"%s\"}\n\n", client.ID) + _, _ = fmt.Fprintf(w, "data: {\"type\":\"connected\",\"clientId\":\"%s\"}\n\n", client.ID) client.Flusher.Flush() // Wait for client disconnect diff --git a/mcp b/mcp new file mode 100755 index 0000000..ab2fca0 Binary files /dev/null and b/mcp differ diff --git a/pkg/hooks/response.go b/pkg/hooks/response.go index af98b99..1331f4a 100644 --- a/pkg/hooks/response.go +++ b/pkg/hooks/response.go @@ -2,6 +2,7 @@ package hooks import ( + "context" "crypto/sha256" "encoding/hex" "encoding/json" @@ -9,6 +10,7 @@ import ( "io" "os" "path/filepath" + "time" ) // HookResponse is the response sent back to Claude Code. @@ -31,6 +33,14 @@ func ProjectIDWithName(cwd string) string { return fmt.Sprintf("%s_%s", dirName, shortHash) } +// HookDeadline returns a context with the hook's timeout budget minus a safety margin. +// This ensures hooks return gracefully before Claude kills them. +func HookDeadline(timeout time.Duration) (context.Context, context.CancelFunc) { + // Use 80% of the timeout to leave margin for response serialization + safeTimeout := time.Duration(float64(timeout) * 0.8) + return context.WithTimeout(context.Background(), safeTimeout) +} + // Exit codes for Claude Code hooks const ( ExitSuccess = 0 @@ -92,7 +102,7 @@ func RunHook[T any](hookName string, handler HookHandler[T]) { // Parse input var input T - if err := json.Unmarshal(inputData, &input); err != nil { + if err = json.Unmarshal(inputData, &input); err != nil { WriteError(hookName, err) os.Exit(1) } diff --git a/pkg/hooks/worker.go b/pkg/hooks/worker.go index a0816be..996c759 100644 --- a/pkg/hooks/worker.go +++ b/pkg/hooks/worker.go @@ -3,6 +3,7 @@ package hooks import ( "bytes" + "context" "encoding/json" "fmt" "net" @@ -12,6 +13,8 @@ import ( "path/filepath" "strconv" "strings" + "sync" + "syscall" "time" ) @@ -22,13 +25,53 @@ const ( // DefaultWorkerPort is the default worker port. DefaultWorkerPort = 37777 - // HealthCheckTimeout is the timeout for health checks (reduced from 5s for faster startup). - HealthCheckTimeout = 1 * time.Second + // HealthCheckTimeout is the timeout for health checks. + HealthCheckTimeout = 2 * time.Second // StartupTimeout is the timeout for worker startup. StartupTimeout = 30 * time.Second + + // workerCacheMaxAge is how long the worker cache is considered fresh. + workerCacheMaxAge = 10 * time.Second + + // circuitBreakerCooldown is how long to wait after a startup failure before retrying. + circuitBreakerCooldown = 30 * time.Second + + // healthCheckRetries is the number of health check attempts before declaring dead. + healthCheckRetries = 3 + + // healthCheckRetryDelay is the delay between health check retries. + healthCheckRetryDelay = 200 * time.Millisecond ) +var ( + // circuitBreakerMu protects lastStartupFailure. + circuitBreakerMu sync.Mutex + lastStartupFailure time.Time +) + +// IsWorkerAvailable performs a fast check without network calls. +// Returns true if the worker is likely available, false if definitely down. +func IsWorkerAvailable() bool { + // Check circuit breaker first + circuitBreakerMu.Lock() + if !lastStartupFailure.IsZero() && time.Since(lastStartupFailure) < circuitBreakerCooldown { + circuitBreakerMu.Unlock() + return false + } + circuitBreakerMu.Unlock() + + // Check PID cache + entry := readWorkerCache() + if entry == nil { + return true // No cache = unknown, don't block + } + + // Cache exists and is fresh (readWorkerCache already checks staleness) + // Check if cached process is alive + return isProcessAlive(entry.PID) +} + // GetWorkerPort returns the worker port from environment or default. func GetWorkerPort() int { if port := os.Getenv("CLAUDE_MNEMONIC_WORKER_PORT"); port != "" { @@ -40,29 +83,149 @@ func GetWorkerPort() int { } // IsWorkerRunning checks if the worker is running and healthy. +// Parses the JSON health response to check the "ready" field when available. +// Falls back to HTTP status code check for backwards compatibility. func IsWorkerRunning(port int) bool { client := &http.Client{Timeout: HealthCheckTimeout} resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/api/health", port)) if err != nil { return false } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + + // Try to parse JSON response for structured health check + var health struct { + Ready bool `json:"ready"` + } + if err := json.NewDecoder(resp.Body).Decode(&health); err == nil { + return health.Ready + } + + // Fallback: treat HTTP 200 as healthy (backwards compatibility) return resp.StatusCode == http.StatusOK } +// workerCachePath returns the path to the worker cache file. +func workerCachePath() string { + home := os.Getenv("HOME") + if home == "" { + return "" + } + return filepath.Join(home, ".claude-mnemonic", ".worker-cache") +} + +// workerCacheEntry holds cached worker state: "port:pid:timestamp". +type workerCacheEntry struct { + Timestamp time.Time + Port int + PID int +} + +// readWorkerCache reads the worker cache file and returns the entry if fresh. +func readWorkerCache() *workerCacheEntry { + path := workerCachePath() + if path == "" { + return nil + } + data, err := os.ReadFile(path) + if err != nil { + return nil + } + parts := strings.SplitN(strings.TrimSpace(string(data)), ":", 3) + if len(parts) != 3 { + return nil + } + port, err := strconv.Atoi(parts[0]) + if err != nil || port <= 0 { + return nil + } + pid, err := strconv.Atoi(parts[1]) + if err != nil || pid <= 0 { + return nil + } + ts, err := strconv.ParseInt(parts[2], 10, 64) + if err != nil { + return nil + } + entry := &workerCacheEntry{ + Port: port, + PID: pid, + Timestamp: time.Unix(ts, 0), + } + // Check freshness + if time.Since(entry.Timestamp) > workerCacheMaxAge { + return nil + } + return entry +} + +// writeWorkerCache writes the worker cache file. +func writeWorkerCache(port, pid int) { + path := workerCachePath() + if path == "" { + return + } + // Ensure directory exists + dir := filepath.Dir(path) + _ = os.MkdirAll(dir, 0o700) + data := fmt.Sprintf("%d:%d:%d", port, pid, time.Now().Unix()) + _ = os.WriteFile(path, []byte(data), 0o600) +} + +// isProcessAlive checks if a process with the given PID exists and is alive. +func isProcessAlive(pid int) bool { + proc, err := os.FindProcess(pid) + if err != nil { + return false + } + // Signal 0 checks if process exists without actually sending a signal. + err = proc.Signal(syscall.Signal(0)) + return err == nil +} + +// isWorkerRunningWithRetries checks if the worker is running, retrying on timeout. +// Returns true only if health check succeeds. Returns false if all retries fail. +func isWorkerRunningWithRetries(port int) bool { + for i := 0; i < healthCheckRetries; i++ { + if IsWorkerRunning(port) { + return true + } + if i < healthCheckRetries-1 { + time.Sleep(healthCheckRetryDelay) + } + } + return false +} + // EnsureWorkerRunning ensures the worker is running, starting it if necessary. // If a worker is already running and healthy with matching version, it reuses it. // If version mismatch or unhealthy, it kills the old worker and starts fresh. func EnsureWorkerRunning() (int, error) { port := GetWorkerPort() - // Check if already running and healthy - if IsWorkerRunning(port) { + // Fast path: check PID cache before making any HTTP calls. + if entry := readWorkerCache(); entry != nil && entry.Port == port { + if isProcessAlive(entry.PID) { + return port, nil + } + } + + // Circuit breaker: if we failed to start recently, don't retry immediately. + circuitBreakerMu.Lock() + if !lastStartupFailure.IsZero() && time.Since(lastStartupFailure) < circuitBreakerCooldown { + circuitBreakerMu.Unlock() + return 0, fmt.Errorf("worker startup failed recently (circuit breaker open, retry after %v)", circuitBreakerCooldown-time.Since(lastStartupFailure)) + } + circuitBreakerMu.Unlock() + + // Check if already running and healthy (with retries to avoid false negatives under load) + if isWorkerRunningWithRetries(port) { // Check version - if mismatch, restart (unless both are dev builds) if runningVersion := GetWorkerVersion(port); runningVersion != "" { if runningVersion != Version { // For dev/dirty builds, don't restart if base versions match if versionsCompatible(runningVersion, Version) { + updateCacheFromPort(port) return port, nil } fmt.Fprintf(os.Stderr, "[claude-mnemonic] Worker version mismatch (running: %s, expected: %s), restarting...\n", runningVersion, Version) @@ -72,23 +235,34 @@ func EnsureWorkerRunning() (int, error) { time.Sleep(500 * time.Millisecond) } else { // Version matches, reuse existing worker + updateCacheFromPort(port) return port, nil } } else { // Couldn't get version, assume it's fine + updateCacheFromPort(port) return port, nil } } - // Check if port is in use but worker is unhealthy + // Port is in use but health check failed -- worker may be slow, not dead. if IsPortInUse(port) { - // Something is using the port but not responding to health checks - // Try to kill it + // The port is responding to TCP but health check timed out. + // Don't kill it -- it's likely just under load. Give it more time. + fmt.Fprintf(os.Stderr, "[claude-mnemonic] Worker on port %d is slow to respond, waiting...\n", port) + // Try a few more times with longer delays before giving up + for i := 0; i < 3; i++ { + time.Sleep(500 * time.Millisecond) + if IsWorkerRunning(port) { + updateCacheFromPort(port) + return port, nil + } + } + // Still not healthy after extended wait -- kill and restart + fmt.Fprintf(os.Stderr, "[claude-mnemonic] Worker unresponsive after extended wait, restarting...\n") if err := KillProcessOnPort(port); err != nil { - // Log but continue - maybe it will die on its own fmt.Fprintf(os.Stderr, "[claude-mnemonic] Warning: failed to kill unhealthy process on port %d: %v\n", port, err) } - // Wait a moment for port to be released time.Sleep(500 * time.Millisecond) } @@ -103,9 +277,14 @@ func EnsureWorkerRunning() (int, error) { cmd.Stdout = os.Stderr cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { + circuitBreakerMu.Lock() + lastStartupFailure = time.Now() + circuitBreakerMu.Unlock() return 0, fmt.Errorf("failed to start worker: %w", err) } + pid := cmd.Process.Pid + // Wait for worker to be ready with exponential backoff deadline := time.Now().Add(StartupTimeout) backoff := 50 * time.Millisecond @@ -113,6 +292,7 @@ func EnsureWorkerRunning() (int, error) { for time.Now().Before(deadline) { if IsWorkerRunning(port) { + writeWorkerCache(port, pid) return port, nil } time.Sleep(backoff) @@ -123,9 +303,31 @@ func EnsureWorkerRunning() (int, error) { } } + circuitBreakerMu.Lock() + lastStartupFailure = time.Now() + circuitBreakerMu.Unlock() return 0, fmt.Errorf("worker failed to start within timeout") } +// updateCacheFromPort finds the PID of the process on the port and updates the cache. +func updateCacheFromPort(port int) { + cmd := exec.Command("lsof", "-t", "-i", fmt.Sprintf(":%d", port)) // #nosec G204 -- port is from internal config + output, err := cmd.Output() + if err != nil { + return + } + pidStr := strings.TrimSpace(string(output)) + // Take first PID if multiple + if idx := strings.Index(pidStr, "\n"); idx > 0 { + pidStr = pidStr[:idx] + } + pid, err := strconv.Atoi(pidStr) + if err != nil || pid <= 0 { + return + } + writeWorkerCache(port, pid) +} + // GetWorkerVersion gets the version of the running worker. func GetWorkerVersion(port int) string { client := &http.Client{Timeout: HealthCheckTimeout} @@ -133,7 +335,7 @@ func GetWorkerVersion(port int) string { if err != nil { return "" } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return "" @@ -243,7 +445,7 @@ func POST(port int, path string, body interface{}) (map[string]interface{}, erro if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { return nil, fmt.Errorf("request failed: %s", resp.Status) @@ -251,13 +453,38 @@ func POST(port int, path string, body interface{}) (map[string]interface{}, erro var result map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - // Not all endpoints return JSON body - return empty map for success with no body - return map[string]interface{}{}, nil + // Not all endpoints return JSON + return nil, nil } return result, nil } +// POSTWithContext sends a POST request using the provided context. +// Used for fire-and-forget calls where we want to control the timeout externally. +func POSTWithContext(ctx context.Context, port int, path string, body interface{}) error { + jsonBody, err := json.Marshal(body) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + fmt.Sprintf("http://127.0.0.1:%d%s", port, path), + bytes.NewReader(jsonBody)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + return nil +} + // GET sends a GET request to the worker. func GET(port int, path string) (map[string]interface{}, error) { client := &http.Client{Timeout: 10 * time.Second} @@ -266,7 +493,7 @@ func GET(port int, path string) (map[string]interface{}, error) { if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { return nil, fmt.Errorf("request failed: %s", resp.Status) diff --git a/pkg/hooks/worker_test.go b/pkg/hooks/worker_test.go index b1a48c7..97ed129 100644 --- a/pkg/hooks/worker_test.go +++ b/pkg/hooks/worker_test.go @@ -517,7 +517,7 @@ func TestProjectIDWithName_Uniqueness(t *testing.T) { // TestHookConstants tests hook-related constants. func TestHookConstants(t *testing.T) { assert.Equal(t, 37777, DefaultWorkerPort) - assert.Equal(t, 1*time.Second, HealthCheckTimeout) + assert.Equal(t, 2*time.Second, HealthCheckTimeout) assert.Equal(t, 30*time.Second, StartupTimeout) } diff --git a/scripts/install.sh b/scripts/install.sh index dfa5474..90484c8 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -7,6 +7,8 @@ set -e +INSTALLER_VERSION="1.1.0" + # Configuration GITHUB_REPO="lukaszraczylo/claude-mnemonic" INSTALL_DIR="$HOME/.claude/plugins/marketplaces/claude-mnemonic" @@ -40,6 +42,50 @@ error() { exit 1 } +# Gracefully stop worker processes (SIGTERM first, then SIGKILL after timeout) +graceful_stop_worker() { + # Send SIGTERM first + pkill -TERM -f 'claude-mnemonic.*worker' 2>/dev/null || true + pkill -TERM -f '\.claude/plugins/.*/worker' 2>/dev/null || true + if command -v lsof &> /dev/null; then + lsof -ti :37777 2>/dev/null | xargs kill -TERM 2>/dev/null || true + elif command -v ss &> /dev/null; then + ss -tlnp 'sport = :37777' 2>/dev/null | awk 'NR>1 {print $6}' | grep -oP 'pid=\K[0-9]+' | xargs -r kill -TERM 2>/dev/null || true + elif command -v fuser &> /dev/null; then + fuser -k -TERM 37777/tcp 2>/dev/null || true + fi + + # Wait up to 5 seconds for graceful shutdown + local waited=0 + while [[ $waited -lt 5 ]]; do + if ! pgrep -f 'claude-mnemonic.*worker' &>/dev/null && ! pgrep -f '\.claude/plugins/.*/worker' &>/dev/null; then + return 0 + fi + sleep 1 + waited=$((waited + 1)) + done + + # Force kill if still running + pkill -9 -f 'claude-mnemonic.*worker' 2>/dev/null || true + pkill -9 -f '\.claude/plugins/.*/worker' 2>/dev/null || true + if command -v lsof &> /dev/null; then + lsof -ti :37777 2>/dev/null | xargs kill -9 2>/dev/null || true + elif command -v ss &> /dev/null; then + ss -tlnp 'sport = :37777' 2>/dev/null | awk 'NR>1 {print $6}' | grep -oP 'pid=\K[0-9]+' | xargs -r kill -9 2>/dev/null || true + elif command -v fuser &> /dev/null; then + fuser -k 37777/tcp 2>/dev/null || true + fi + sleep 1 + + # Remove stale PID cache to prevent hooks from using old worker info + rm -f "$HOME/.claude-mnemonic/.worker-cache" 2>/dev/null || true + + # Verify process is gone + if pgrep -f 'claude-mnemonic.*worker' &>/dev/null; then + warn "Could not stop existing worker process" + fi +} + # Detect OS and architecture detect_platform() { local os arch @@ -131,7 +177,7 @@ download_release() { local tmp_dir tmp_dir=$(mktemp -d) - trap "rm -rf $tmp_dir" EXIT + trap 'rm -rf "$tmp_dir"' EXIT # Construct download URL (use .zip for Windows, .tar.gz for others) local archive_ext="tar.gz" @@ -147,8 +193,35 @@ download_release() { error "Failed to download release from: $download_url" fi + # Verify download integrity via checksum + local checksum_url="${download_url}.sha256" + info "Verifying download integrity..." + if curl -sSL -o "$tmp_dir/checksum.sha256" "$checksum_url" 2>/dev/null; then + local expected_hash actual_hash + expected_hash=$(awk '{print $1}' "$tmp_dir/checksum.sha256") + if command -v shasum &> /dev/null; then + actual_hash=$(shasum -a 256 "$tmp_dir/release.${archive_ext}" | awk '{print $1}') + elif command -v sha256sum &> /dev/null; then + actual_hash=$(sha256sum "$tmp_dir/release.${archive_ext}" | awk '{print $1}') + else + warn "No SHA256 tool found (shasum or sha256sum), skipping checksum verification" + actual_hash="" + fi + if [[ -n "$actual_hash" ]]; then + if [[ "$expected_hash" != "$actual_hash" ]]; then + error "Checksum verification failed! Expected: $expected_hash Got: $actual_hash" + fi + success "Checksum verified" + fi + else + warn "No checksum file available at $checksum_url, skipping verification" + fi + info "Extracting archive..." if [[ "$archive_ext" == "zip" ]]; then + if ! command -v unzip &> /dev/null; then + error "unzip is required for Windows archives but not installed" + fi if ! unzip -q "$tmp_dir/release.zip" -d "$tmp_dir"; then error "Failed to extract archive" fi @@ -160,17 +233,7 @@ download_release() { # Stop existing worker if running info "Stopping existing worker (if running)..." - pkill -9 -f 'claude-mnemonic.*worker' 2>/dev/null || true - pkill -9 -f '\.claude/plugins/.*/worker' 2>/dev/null || true - # Kill process on port 37777 (use lsof on macOS, ss/fuser on Linux) - if command -v lsof &> /dev/null; then - lsof -ti :37777 | xargs kill -9 2>/dev/null || true - elif command -v ss &> /dev/null; then - ss -tlnp 'sport = :37777' 2>/dev/null | awk 'NR>1 {print $6}' | grep -oP 'pid=\K[0-9]+' | xargs -r kill -9 2>/dev/null || true - elif command -v fuser &> /dev/null; then - fuser -k 37777/tcp 2>/dev/null || true - fi - sleep 1 + graceful_stop_worker # Create installation directories info "Installing to ${INSTALL_DIR}..." @@ -178,13 +241,21 @@ download_release() { mkdir -p "$INSTALL_DIR/.claude-plugin" mkdir -p "$INSTALL_DIR/commands" - # Copy binaries - cp "$tmp_dir/worker" "$INSTALL_DIR/" - cp "$tmp_dir/mcp-server" "$INSTALL_DIR/" - cp "$tmp_dir/hooks/"* "$INSTALL_DIR/hooks/" + # Copy binaries (abort on failure — could indicate disk full or permissions issue) + if ! cp "$tmp_dir/worker" "$INSTALL_DIR/"; then + error "Failed to copy worker binary to $INSTALL_DIR/" + fi + if ! cp "$tmp_dir/mcp-server" "$INSTALL_DIR/"; then + error "Failed to copy mcp-server binary to $INSTALL_DIR/" + fi + if ! cp "$tmp_dir/hooks/"* "$INSTALL_DIR/hooks/"; then + error "Failed to copy hook binaries to $INSTALL_DIR/hooks/" + fi # Copy plugin configuration - cp "$tmp_dir/.claude-plugin/"* "$INSTALL_DIR/.claude-plugin/" + if ! cp "$tmp_dir/.claude-plugin/"* "$INSTALL_DIR/.claude-plugin/"; then + error "Failed to copy plugin configuration to $INSTALL_DIR/.claude-plugin/" + fi # Copy slash commands if they exist in the release if [[ -d "$tmp_dir/commands" ]]; then @@ -338,72 +409,51 @@ start_worker() { error "Worker binary not found at $worker_path" fi + # Check for port conflict with a non-mnemonic process + if command -v lsof &> /dev/null; then + local port_pid + port_pid=$(lsof -ti :37777 2>/dev/null || true) + if [[ -n "$port_pid" ]]; then + local port_cmd + port_cmd=$(ps -p "$port_pid" -o comm= 2>/dev/null || true) + if [[ -n "$port_cmd" ]] && ! echo "$port_cmd" | grep -q "worker"; then + warn "Port 37777 is in use by another process: $port_cmd (PID $port_pid)" + warn "The worker may fail to start. Consider stopping the conflicting process." + fi + fi + fi + info "Starting worker service..." nohup "$worker_path" > /tmp/claude-mnemonic-worker.log 2>&1 & - sleep 2 + # Retry health check up to 5 times with 1s interval + local retries=0 + local max_retries=5 + while [[ $retries -lt $max_retries ]]; do + sleep 1 + if curl -sS http://localhost:37777/health > /dev/null 2>&1; then + success "Worker started successfully at http://localhost:37777" + return 0 + fi + retries=$((retries + 1)) + done - if curl -sS http://localhost:37777/health > /dev/null 2>&1; then - success "Worker started successfully at http://localhost:37777" - else - warn "Worker may not have started properly. Check /tmp/claude-mnemonic-worker.log" - fi + warn "Worker may not have started properly after ${max_retries} attempts. Check /tmp/claude-mnemonic-worker.log" } # Check optional dependencies for semantic search check_optional_deps() { - local missing_deps=() - local install_hints="" + # Semantic search uses embedded ONNX runtime - no external Python/uvx dependencies needed + success "Semantic search enabled (embedded ONNX runtime)" +} - # Check for Python 3.13+ - if command -v python3 &> /dev/null; then - local py_version=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null) - if [[ "$py_version" < "3.13" ]]; then - missing_deps+=("Python 3.13+ (found $py_version)") - fi - else - missing_deps+=("Python 3.13+") - fi - - # Check for uvx - if ! command -v uvx &> /dev/null; then - missing_deps+=("uvx") - fi - - if [[ ${#missing_deps[@]} -gt 0 ]]; then - echo "" - warn "Optional dependencies missing (needed for semantic search):" - for dep in "${missing_deps[@]}"; do - echo " - $dep" - done - echo "" - - # Detect OS and show appropriate install command - case "$(uname -s)" in - Darwin) - info "Install on macOS:" - echo " brew install python@3.13" - echo " pip3 install uv" - ;; - Linux) - info "Install on Linux:" - echo " sudo apt install python3 python3-pip" - echo " pip3 install uv" - ;; - MINGW*|MSYS*|CYGWIN*) - info "Install on Windows:" - echo " winget install Python.Python.3.13" - echo " pip install uv" - ;; - esac - echo "" - info "Note: Requires Python 3.13+. Most package managers install the latest version." - echo "" - info "Semantic search will be disabled until these are installed." - info "Core functionality (SQLite storage, full-text search) will work." - echo "" - else - success "Optional dependencies found (semantic search enabled)" +# Rollback partially installed files on failure +INSTALL_COMPLETE=false +cleanup_on_failure() { + if [[ "$INSTALL_COMPLETE" != "true" ]]; then + warn "Installation did not complete — cleaning up partial install..." + rm -rf "$INSTALL_DIR" 2>/dev/null || true + rm -rf "$CACHE_DIR" 2>/dev/null || true fi } @@ -411,6 +461,8 @@ check_optional_deps() { main() { local version="${1:-}" + trap cleanup_on_failure EXIT + echo "" echo "╔═══════════════════════════════════════════════════════════╗" echo "║ Claude Mnemonic - Installation Script ║" @@ -455,6 +507,8 @@ main() { # Check optional dependencies check_optional_deps + INSTALL_COMPLETE=true + echo "" echo "╔═══════════════════════════════════════════════════════════╗" echo "║ Installation Complete! ║" @@ -467,6 +521,12 @@ main() { echo "" } +# Handle --version flag +if [[ "${1:-}" == "--version" ]]; then + echo "claude-mnemonic installer v${INSTALLER_VERSION}" + exit 0 +fi + # Handle --register-only flag if [[ "${1:-}" == "--register-only" ]]; then version=$(cat "$INSTALL_DIR/.claude-plugin/plugin.json" 2>/dev/null | grep '"version"' | sed -E 's/.*"([^"]+)".*/\1/' || echo "1.0.0") @@ -486,17 +546,7 @@ if [[ "${1:-}" == "--uninstall" ]]; then echo "" info "Stopping worker processes..." - pkill -9 -f 'claude-mnemonic.*worker' 2>/dev/null || true - pkill -9 -f '\.claude/plugins/.*/worker' 2>/dev/null || true - # Kill process on port 37777 (use lsof on macOS, ss/fuser on Linux) - if command -v lsof &> /dev/null; then - lsof -ti :37777 | xargs kill -9 2>/dev/null || true - elif command -v ss &> /dev/null; then - ss -tlnp 'sport = :37777' 2>/dev/null | awk 'NR>1 {print $6}' | grep -oP 'pid=\K[0-9]+' | xargs -r kill -9 2>/dev/null || true - elif command -v fuser &> /dev/null; then - fuser -k 37777/tcp 2>/dev/null || true - fi - sleep 1 + graceful_stop_worker info "Removing plugin directories..." rm -rf "$INSTALL_DIR" diff --git a/scripts/register-plugin.sh b/scripts/register-plugin.sh index 83360aa..42fee38 100755 --- a/scripts/register-plugin.sh +++ b/scripts/register-plugin.sh @@ -16,6 +16,35 @@ CACHE_BASE="$HOME/.claude/plugins/cache/claude-mnemonic/claude-mnemonic" CACHE_PATH="$CACHE_BASE/$VERSION" TIMESTAMP=$(date -u +"%Y-%m-%dT%H:%M:%S.000Z") +# Helper: safely write JSON via tmp file with validation +# Usage: safe_jq_write +# The last argument is treated as the input file, output goes to input_file.tmp +safe_jq_write() { + local args=("$@") + local input_file="${args[-1]}" + local tmp_file="${input_file}.tmp" + + if jq "${args[@]}" > "$tmp_file"; then + if jq . "$tmp_file" > /dev/null 2>&1; then + mv "$tmp_file" "$input_file" + else + echo "ERROR: jq produced invalid JSON for $input_file, aborting" + rm -f "$tmp_file" + return 1 + fi + else + echo "ERROR: jq failed for $input_file, aborting" + rm -f "$tmp_file" + return 1 + fi +} + +# Check that Claude Code directory exists +if [ ! -d "$HOME/.claude" ]; then + echo "Warning: $HOME/.claude directory not found. Claude Code may not be installed." + echo "Continuing anyway, but plugin may not function until Claude Code is installed." +fi + # Ensure plugins directory exists mkdir -p "$HOME/.claude/plugins" @@ -42,6 +71,24 @@ fi # Check if jq is available if command -v jq &> /dev/null; then + # Validate jq version (1.6+ required for //= operator) + JQ_VERSION=$(jq --version 2>/dev/null | sed 's/jq-//') + JQ_MAJOR=$(echo "$JQ_VERSION" | cut -d. -f1) + JQ_MINOR=$(echo "$JQ_VERSION" | cut -d. -f2) + if [ -n "$JQ_MAJOR" ] && [ -n "$JQ_MINOR" ]; then + if [ "$JQ_MAJOR" -lt 1 ] || { [ "$JQ_MAJOR" -eq 1 ] && [ "$JQ_MINOR" -lt 6 ]; }; then + echo "ERROR: jq 1.6+ is required (found jq-$JQ_VERSION)" + echo "Please upgrade jq: brew install jq (macOS) or apt-get install jq (Linux)" + exit 1 + fi + fi + + # Validate marketplace path exists and contains expected files + if [ ! -d "$MARKETPLACE_PATH" ]; then + echo "Warning: Marketplace directory not found at $MARKETPLACE_PATH" + echo "Plugin files may not be copied to cache correctly." + fi + # Ensure cache directory exists and copy plugin files mkdir -p "$CACHE_PATH/.claude-plugin" mkdir -p "$CACHE_PATH/hooks" @@ -64,9 +111,8 @@ EOF ) # Add or update the plugin entry in installed_plugins.json - jq --arg key "$PLUGIN_KEY" --argjson entry "$PLUGIN_ENTRY" \ - '.plugins[$key] = $entry' "$PLUGINS_FILE" > "${PLUGINS_FILE}.tmp" \ - && mv "${PLUGINS_FILE}.tmp" "$PLUGINS_FILE" + safe_jq_write --arg key "$PLUGIN_KEY" --argjson entry "$PLUGIN_ENTRY" \ + '.plugins[$key] = $entry' "$PLUGINS_FILE" echo "Plugin registered in installed_plugins.json" @@ -82,9 +128,8 @@ EOF EOF ) - jq --arg key "$PLUGIN_KEY" --argjson statusline "$STATUSLINE_ENTRY" \ - '.enabledPlugins //= {} | .enabledPlugins[$key] = true | .statusLine = $statusline' "$SETTINGS_FILE" > "${SETTINGS_FILE}.tmp" \ - && mv "${SETTINGS_FILE}.tmp" "$SETTINGS_FILE" + safe_jq_write --arg key "$PLUGIN_KEY" --argjson statusline "$STATUSLINE_ENTRY" \ + '.enabledPlugins //= {} | .enabledPlugins[$key] = true | .statusLine = $statusline' "$SETTINGS_FILE" echo "Plugin enabled in settings.json" echo "Statusline configured in settings.json" @@ -102,9 +147,8 @@ EOF EOF ) - jq --arg key "$MARKETPLACE_NAME" --argjson entry "$MARKETPLACE_ENTRY" \ - '.[$key] = $entry' "$MARKETPLACES_FILE" > "${MARKETPLACES_FILE}.tmp" \ - && mv "${MARKETPLACES_FILE}.tmp" "$MARKETPLACES_FILE" + safe_jq_write --arg key "$MARKETPLACE_NAME" --argjson entry "$MARKETPLACE_ENTRY" \ + '.[$key] = $entry' "$MARKETPLACES_FILE" echo "Marketplace registered in known_marketplaces.json" @@ -126,13 +170,11 @@ EOF MCP_ENTRY=$(echo "$MCP_ENTRY" | sed "s|MCP_BINARY_PLACEHOLDER|$MCP_BINARY|g") # Add or update mcpServers field - if jq --arg key "claude-mnemonic" --argjson entry "$MCP_ENTRY" \ - '.mcpServers //= {} | .mcpServers[$key] = $entry' "$SETTINGS_FILE" > "${SETTINGS_FILE}.tmp"; then - mv "${SETTINGS_FILE}.tmp" "$SETTINGS_FILE" + if safe_jq_write --arg key "claude-mnemonic" --argjson entry "$MCP_ENTRY" \ + '.mcpServers //= {} | .mcpServers[$key] = $entry' "$SETTINGS_FILE"; then echo "MCP server registered successfully" else echo "Warning: Failed to register MCP server (jq error)" - rm -f "${SETTINGS_FILE}.tmp" fi else echo "MCP server binary not found at $MCP_BINARY, skipping MCP registration" diff --git a/scripts/unregister-plugin.sh b/scripts/unregister-plugin.sh index d709f3e..a682169 100755 --- a/scripts/unregister-plugin.sh +++ b/scripts/unregister-plugin.sh @@ -3,6 +3,18 @@ set -e +# Stop running worker processes before removing binaries +echo "Stopping worker processes..." +pkill -TERM -f 'claude-mnemonic.*worker' 2>/dev/null || true +pkill -TERM -f '\.claude/plugins/.*/worker' 2>/dev/null || true +sleep 2 +# Force kill if still running +pkill -9 -f 'claude-mnemonic.*worker' 2>/dev/null || true +pkill -9 -f '\.claude/plugins/.*/worker' 2>/dev/null || true +# Clean up port +lsof -ti :37777 | xargs kill -9 2>/dev/null || true +sleep 1 + PLUGINS_FILE="$HOME/.claude/plugins/installed_plugins.json" SETTINGS_FILE="$HOME/.claude/settings.json" MARKETPLACES_FILE="$HOME/.claude/plugins/known_marketplaces.json" @@ -30,16 +42,17 @@ else echo "No plugins file found, skipping" fi -# Remove from settings.json (enabledPlugins and statusLine if it points to our plugin) +# Remove from settings.json (enabledPlugins, statusLine, and mcpServers) if [ -f "$SETTINGS_FILE" ]; then - # Remove from enabledPlugins and clear statusLine if it references our plugin + # Remove from enabledPlugins, clear statusLine if it references our plugin, and remove MCP server jq --arg key "$PLUGIN_KEY" ' del(.enabledPlugins[$key]) | if .statusLine.command and (.statusLine.command | contains("claude-mnemonic")) then del(.statusLine) else . - end + end | + del(.mcpServers["claude-mnemonic"]) ' "$SETTINGS_FILE" > "${SETTINGS_FILE}.tmp" \ && mv "${SETTINGS_FILE}.tmp" "$SETTINGS_FILE" echo "Plugin removed from settings.json"