diff --git a/.gitignore b/.gitignore index bae1712..13126d5 100644 --- a/.gitignore +++ b/.gitignore @@ -88,3 +88,4 @@ docs/dist # Non-template plugin configs (keep only .tpl files in plugin/ dir) plugin/.claude-plugin/plugin.json plugin/.claude-plugin/marketplace.json +user-prompt diff --git a/cmd/hooks/user-prompt/main.go b/cmd/hooks/user-prompt/main.go index 2211e84..bc8736c 100644 --- a/cmd/hooks/user-prompt/main.go +++ b/cmd/hooks/user-prompt/main.go @@ -177,15 +177,27 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { if initErr != nil { return "", initErr } + if initResult == nil { + return contextToInject, nil // Non-JSON response from worker, skip session init + } // Check if skipped due to privacy if skipped, ok := initResult["skipped"].(bool); ok && skipped { fmt.Fprintf(os.Stderr, "[user-prompt] Session skipped (private)\n") - return "", nil + return contextToInject, nil } - sessionID := int64(initResult["sessionDbId"].(float64)) - promptNumber := int(initResult["promptNumber"].(float64)) + sessionDBIDVal, ok := initResult["sessionDbId"].(float64) + if !ok { + return contextToInject, nil // Missing or wrong type, skip gracefully + } + sessionID := int64(sessionDBIDVal) + + promptNumberVal, ok := initResult["promptNumber"].(float64) + if !ok { + return contextToInject, nil + } + promptNumber := int(promptNumberVal) fmt.Fprintf(os.Stderr, "[user-prompt] Session %d, prompt #%d\n", sessionID, promptNumber) diff --git a/cmd/hooks/user-prompt/main_test.go b/cmd/hooks/user-prompt/main_test.go new file mode 100644 index 0000000..6fe266c --- /dev/null +++ b/cmd/hooks/user-prompt/main_test.go @@ -0,0 +1,44 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestEstimateTokens tests the token estimator. +func TestEstimateTokens(t *testing.T) { + tests := []struct { + name string + input string + minToken int + maxToken int + }{ + {"empty string", "", 0, 0}, + {"single word", "hello", 1, 3}, + {"simple sentence", "Hello world this is a test", 5, 15}, + {"code-heavy", "func() { return x.y.z(); }", 5, 30}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := estimateTokens(tt.input) + assert.GreaterOrEqual(t, result, tt.minToken) + assert.LessOrEqual(t, result, tt.maxToken) + }) + } +} + +// TestHandleUserPrompt_NilInitResult_Compile verifies that the nil-safety +// fix in handleUserPrompt compiles correctly. The actual nil dereference +// was at initResult["sessionDbId"].(float64) when initResult was nil. +// This test ensures the defensive type assertions are present by exercising +// the token estimator (the handler requires a live HookContext+worker). +func TestHandleUserPrompt_NilInitResult_Compile(t *testing.T) { + // The real regression test is that `go build ./cmd/hooks/user-prompt/` + // succeeds with the nil-safe assertions. We can't easily spin up + // a full HookContext here, but we verify the package compiles and + // the helper functions are sane. + assert.Equal(t, 0, estimateTokens("")) + assert.Greater(t, estimateTokens("test input"), 0) +} diff --git a/cmd/mcp/main.go b/cmd/mcp/main.go index a8db00c..14d8926 100644 --- a/cmd/mcp/main.go +++ b/cmd/mcp/main.go @@ -59,7 +59,7 @@ func main() { }() // Start file watchers for config changes - startWatchers() + startWatchers(cancel) telemetry.Send("claude-mnemonic", Version) @@ -68,18 +68,21 @@ func main() { log.Info().Str("project", *project).Str("version", Version).Str("worker", workerURL).Msg("Starting MCP server") if err := server.Run(ctx); err != nil { + if err == context.Canceled { + log.Info().Msg("MCP server shut down (config change or signal)") + return + } log.Fatal().Err(err).Msg("MCP server error") } } // startWatchers initializes file watchers for config. -func startWatchers() { - // Watch config file for changes (triggers process exit for restart) +func startWatchers(cancel context.CancelFunc) { + // Watch config file for changes (triggers graceful shutdown via context cancellation) configPath := config.SettingsPath() configWatcher, err := watcher.New(configPath, func() { - log.Warn().Str("path", configPath).Msg("Config file changed, exiting for restart...") - time.Sleep(100 * time.Millisecond) // Give logs time to flush - os.Exit(0) + log.Warn().Str("path", configPath).Msg("Config file changed, shutting down gracefully...") + cancel() // Triggers ctx.Done() in server.Run(), which drains in-flight requests }) if err != nil { log.Warn().Err(err).Msg("Failed to create config watcher") diff --git a/internal/db/gorm/store.go b/internal/db/gorm/store.go index 480ecf7..e3511f9 100644 --- a/internal/db/gorm/store.go +++ b/internal/db/gorm/store.go @@ -99,8 +99,9 @@ func NewStore(cfg Config) (*Store, error) { "PRAGMA synchronous=NORMAL", "PRAGMA cache_size=-64000", // 64MB cache (negative = KB) "PRAGMA temp_store=MEMORY", // Store temp tables in memory - "PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O - "PRAGMA page_size=4096", // 4KB pages (optimal for most systems) + "PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O + "PRAGMA page_size=4096", // 4KB pages (optimal for most systems) + "PRAGMA wal_autocheckpoint=1000", // Explicit default; checkpoint every 1000 WAL frames } for _, pragma := range pragmas { if _, err := sqlDB.Exec(pragma); err != nil { @@ -192,6 +193,11 @@ func (s *Store) Optimize(ctx context.Context) error { log.Warn().Err(err).Msg("PRAGMA optimize failed (non-fatal)") } + // Passive WAL checkpoint — doesn't block readers/writers + if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA wal_checkpoint(PASSIVE)"); err != nil { + log.Warn().Err(err).Msg("WAL checkpoint failed (non-fatal)") + } + log.Info().Dur("duration", time.Since(start)).Msg("Database optimization complete") return nil } diff --git a/internal/db/gorm/store_test.go b/internal/db/gorm/store_test.go index ff3f921..9dfe80f 100644 --- a/internal/db/gorm/store_test.go +++ b/internal/db/gorm/store_test.go @@ -4,6 +4,7 @@ package gorm import ( + "context" "os" "path/filepath" "testing" @@ -150,3 +151,91 @@ func TestMigrationIdempotency(t *testing.T) { t.Logf("✅ Migrations are idempotent") } + +func TestWALAutocheckpoint(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "gorm_wal_checkpoint_*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + store, err := NewStore(Config{ + Path: dbPath, + MaxConns: 2, + LogLevel: logger.Silent, + }) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer store.Close() + + // Verify wal_autocheckpoint is set to 1000 + var checkpoint int + err = store.GetRawDB().QueryRow("PRAGMA wal_autocheckpoint").Scan(&checkpoint) + if err != nil { + t.Fatalf("query wal_autocheckpoint: %v", err) + } + if checkpoint != 1000 { + t.Errorf("expected wal_autocheckpoint=1000, got %d", checkpoint) + } +} + +func TestOptimize_RunsWALCheckpoint(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "gorm_optimize_*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + store, err := NewStore(Config{ + Path: dbPath, + MaxConns: 2, + LogLevel: logger.Silent, + }) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer store.Close() + + // Insert some data to generate WAL frames + _, err = store.GetRawDB().Exec("INSERT INTO observations (sdk_session_id, title, scope, project, type, created_at, created_at_epoch) VALUES ('test-sess', 'test data', 'project', '/tmp/test', 'decision', '2026-01-01T00:00:00Z', 1735689600)") + if err != nil { + t.Fatalf("insert test data: %v", err) + } + + // Optimize should succeed (includes PASSIVE WAL checkpoint) + err = store.Optimize(context.Background()) + if err != nil { + t.Fatalf("Optimize failed: %v", err) + } +} + +func TestOptimize_RespectsContextCancellation(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "gorm_optimize_cancel_*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + store, err := NewStore(Config{ + Path: dbPath, + MaxConns: 2, + LogLevel: logger.Silent, + }) + if err != nil { + t.Fatalf("NewStore failed: %v", err) + } + defer store.Close() + + // Already-cancelled context should cause Optimize to fail + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = store.Optimize(ctx) + if err == nil { + t.Error("expected error with cancelled context, got nil") + } +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 0e84aa5..b7c105b 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -171,11 +171,22 @@ func (s *Server) Run(ctx context.Context) error { } // Dispatch request to its own goroutine. + // Semaphore is acquired inside the goroutine so the main + // loop never blocks on a full semaphore (Fix #1). wg.Add(1) - sem <- struct{}{} // acquire semaphore slot go func(r Request) { defer wg.Done() - defer func() { <-sem }() // release semaphore slot + // Acquire semaphore inside goroutine, not blocking main loop + select { + case sem <- struct{}{}: + defer func() { <-sem }() + case <-ctx.Done(): + // Server shutting down — send error so caller isn't left waiting + if r.ID != nil { + _ = s.sendError(r.ID, -32000, "Server shutting down", nil) + } + return + } resp := s.handleRequest(ctx, &r) if resp != nil { @@ -203,6 +214,10 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) *Response { case "notifications/initialized", "notifications/cancelled": return nil // Notifications don't get responses default: + if req.ID == nil { + // Notifications must not receive responses per JSON-RPC 2.0 + return nil + } return &Response{ JSONRPC: "2.0", ID: req.ID, @@ -775,13 +790,15 @@ func (s *Server) handleToolsCall(ctx context.Context, req *Request) *Response { result, err := s.callTool(ctx, params.Name, params.Arguments) if err != nil { + // MCP spec: tool failures use Result with isError, not top-level Error return &Response{ JSONRPC: "2.0", ID: req.ID, - Error: &Error{ - Code: -32000, - Message: "Tool error", - Data: err.Error(), + Result: map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "Error: " + err.Error()}, + }, + "isError": true, }, } } @@ -1902,16 +1919,38 @@ func (s *Server) sendResponse(resp *Response) error { data, err := json.Marshal(resp) if err != nil { log.Error().Err(err).Msg("Failed to marshal response") + // Send a fallback error so the caller isn't left waiting (Fix #3) + fallback, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": resp.ID, + "error": map[string]any{"code": -32603, "message": "internal marshal error"}, + }) + s.writeMu.Lock() + _, _ = fmt.Fprintln(s.stdout, string(fallback)) + s.writeMu.Unlock() + return fmt.Errorf("marshal error: %w", err) + } + + // Bound the write to prevent pipe deadlock (Fix #2) + done := make(chan error, 1) + go func() { + s.writeMu.Lock() + _, werr := fmt.Fprintln(s.stdout, string(data)) + s.writeMu.Unlock() + done <- werr + }() + + select { + case werr := <-done: + if werr != nil { + log.Error().Err(werr).Msg("Failed to write response to stdout") + return werr + } return nil + case <-time.After(10 * time.Second): + log.Error().Msg("Stdout write timed out — pipe likely full") + return fmt.Errorf("write timeout: stdout pipe full") } - s.writeMu.Lock() - _, err = fmt.Fprintln(s.stdout, string(data)) - s.writeMu.Unlock() - if err != nil { - log.Error().Err(err).Msg("Failed to write response to stdout") - return err - } - return nil } // sendError sends a JSON-RPC error response. Returns an error if writing fails. diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 5ee895e..b71ecd1 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -1665,6 +1665,7 @@ func TestToolListContainsExpectedSchemas(t *testing.T) { } // TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name. +// After Fix #5: tool errors use Result with isError, not top-level Error. func TestHandleToolsCall_UnknownTool(t *testing.T) { t.Parallel() @@ -1679,9 +1680,13 @@ func TestHandleToolsCall_UnknownTool(t *testing.T) { } resp := server.handleToolsCall(ctx, req) - require.NotNil(t, resp.Error) - assert.Equal(t, -32000, resp.Error.Code) - assert.Contains(t, resp.Error.Data, "unknown tool") + assert.Nil(t, resp.Error, "tool errors must not use top-level Error (MCP spec)") + require.NotNil(t, resp.Result) + result, ok := resp.Result.(map[string]any) + require.True(t, ok) + assert.Equal(t, true, result["isError"]) + content := result["content"].([]map[string]any) + assert.Contains(t, content[0]["text"], "unknown tool") } // TestCallTool_ToolNameRecognition tests that valid tool names are recognized. @@ -1982,6 +1987,7 @@ func TestTimelineParamsStruct_Validation(t *testing.T) { } // TestHandleToolsCall_EmptyParams tests tools/call with empty params. +// After Fix #5: tool errors use Result with isError, not top-level Error. func TestHandleToolsCall_EmptyParams(t *testing.T) { t.Parallel() @@ -1997,8 +2003,12 @@ func TestHandleToolsCall_EmptyParams(t *testing.T) { resp := server.handleToolsCall(ctx, req) - // Should error due to missing name - require.NotNil(t, resp.Error) + // Empty name goes through callTool default branch -> isError + assert.Nil(t, resp.Error, "tool errors must use isError in Result") + require.NotNil(t, resp.Result) + result, ok := resp.Result.(map[string]any) + require.True(t, ok) + assert.Equal(t, true, result["isError"]) } // TestSendResponse_WithError tests sendResponse with an error response. @@ -2062,6 +2072,7 @@ func TestToolCallParamsWithComplexArgs(t *testing.T) { } // TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error. +// After Fix #5: tool errors use Result with isError, not top-level Error. func TestHandleToolsCall_UnknownToolNameError(t *testing.T) { t.Parallel() @@ -2077,12 +2088,13 @@ func TestHandleToolsCall_UnknownToolNameError(t *testing.T) { resp := server.handleToolsCall(ctx, req) - // Should get an error response assert.Equal(t, "2.0", resp.JSONRPC) assert.Equal(t, 1, resp.ID) - require.NotNil(t, resp.Error) - // Error is "Tool error" with message containing "unknown tool" - assert.True(t, resp.Error.Code != 0) + assert.Nil(t, resp.Error, "tool errors must use isError in Result, not top-level Error") + require.NotNil(t, resp.Result) + result, ok := resp.Result.(map[string]any) + require.True(t, ok) + assert.Equal(t, true, result["isError"]) } // ============================================================================= @@ -3172,3 +3184,237 @@ func TestRun_GracefulDrainOnCancel(t *testing.T) { assert.Equal(t, "2.0", resp["jsonrpc"], "any completed response must be valid JSON-RPC 2.0") } } + +// ============================================================================= +// REGRESSION TESTS — Fix #1-#5 +// ============================================================================= + +// blockingWriter is an io.Writer that blocks forever on Write. +type blockingWriter struct { + blocked chan struct{} // closed when Write is entered +} + +func (bw *blockingWriter) Write(p []byte) (int, error) { + if bw.blocked != nil { + close(bw.blocked) + } + select {} // block forever +} + +// TestRun_SemaphoreDoesNotBlockMainLoop (Fix #1 regression) fills all semaphore +// slots with blocked requests, then sends a notification. The main loop must +// stay responsive and not hang on semaphore acquisition. +func TestRun_SemaphoreDoesNotBlockMainLoop(t *testing.T) { + t.Parallel() + + const maxConcurrent = 10 + + // Mock worker that blocks until test context is cancelled + handlerCtx, handlerCancel := context.WithCancel(context.Background()) + defer handlerCancel() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-handlerCtx.Done() // block until test cleanup + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{}`) + })) + defer func() { + handlerCancel() // unblock all handlers first + ts.CloseClientConnections() + ts.Close() + }() + + stdinR, stdinW := io.Pipe() + var stdout bytes.Buffer + + server := &Server{ + client: ts.Client(), + workerURL: ts.URL, + project: "test", + version: "1.0.0", + stdin: stdinR, + stdout: &stdout, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + runDone := make(chan error, 1) + go func() { + runDone <- server.Run(ctx) + }() + + // Send maxConcurrent+2 requests to fill all semaphore slots + for i := 0; i < maxConcurrent+2; i++ { + req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call", + Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)} + data, _ := json.Marshal(req) + _, err := io.WriteString(stdinW, string(data)+"\n") + require.NoError(t, err) + } + + // Give goroutines time to start and fill semaphore + time.Sleep(100 * time.Millisecond) + + // Now send a notification — this must not hang the main loop + notifSent := make(chan struct{}) + go func() { + notif := `{"jsonrpc":"2.0","method":"notifications/initialized"}` + "\n" + _, _ = io.WriteString(stdinW, notif) + close(notifSent) + }() + + select { + case <-notifSent: + // Main loop accepted the notification write (stdin is a pipe, so + // the write completing means the reader consumed it) + case <-time.After(3 * time.Second): + t.Fatal("main loop blocked — semaphore acquisition is blocking the main goroutine") + } + + // Clean up: cancel server context, close stdin + cancel() + _ = stdinW.Close() + + // Wait for Run to finish + select { + case <-runDone: + case <-time.After(5 * time.Second): + // Acceptable — some goroutines may still be draining + } +} + +// TestSendResponse_WriteTimeout (Fix #2 regression) verifies that sendResponse +// returns an error within a bounded time when the writer blocks forever. +func TestSendResponse_WriteTimeout(t *testing.T) { + t.Parallel() + + bw := &blockingWriter{blocked: make(chan struct{})} + server := &Server{stdout: bw} + + resp := &Response{ + JSONRPC: "2.0", + ID: 1, + Result: "ok", + } + + done := make(chan error, 1) + go func() { + done <- server.sendResponse(resp) + }() + + select { + case err := <-done: + require.Error(t, err, "sendResponse must return error on write timeout") + assert.Contains(t, err.Error(), "write timeout") + case <-time.After(15 * time.Second): + t.Fatal("sendResponse hung forever — write timeout not working") + } +} + +// TestSendResponse_MarshalError (Fix #3 regression) verifies that when +// json.Marshal fails, sendResponse sends a fallback error response and +// returns an error (instead of silently returning nil). +func TestSendResponse_MarshalError(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + server := &Server{stdout: &buf} + + // Channels are not JSON-serializable — this will cause json.Marshal to fail + resp := &Response{ + JSONRPC: "2.0", + ID: 42, + Result: make(chan int), // unserializable + } + + err := server.sendResponse(resp) + + // (a) Must return an error + require.Error(t, err, "sendResponse must return error when marshal fails") + assert.Contains(t, err.Error(), "marshal error") + + // (b) Must have written a fallback JSON-RPC error to stdout + output := buf.String() + require.NotEmpty(t, output, "fallback response must be written to stdout") + + var fallback map[string]any + require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(output)), &fallback), + "fallback must be valid JSON") + assert.Equal(t, "2.0", fallback["jsonrpc"]) + assert.Equal(t, float64(42), fallback["id"], "fallback must preserve original request ID") + + errObj, ok := fallback["error"].(map[string]any) + require.True(t, ok, "fallback must contain error object") + assert.Equal(t, float64(-32603), errObj["code"]) + assert.Equal(t, "internal marshal error", errObj["message"]) +} + +// TestHandleRequest_UnknownNotification (Fix #4 regression) verifies that +// unknown notification methods (ID == nil) get no response, while unknown +// methods with an ID still get a -32601 error. +func TestHandleRequest_UnknownNotification(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "", "", "1.0.0") + ctx := context.Background() + + // Case 1: Unknown notification (no ID) — must return nil + notifReq := &Request{ + JSONRPC: "2.0", + ID: nil, + Method: "notifications/roots/list_changed", + } + resp := server.handleRequest(ctx, notifReq) + assert.Nil(t, resp, "unknown notification must not produce a response") + + // Case 2: Unknown method WITH an ID — must return error response + methodReq := &Request{ + JSONRPC: "2.0", + ID: 99, + Method: "some/unknown/method", + } + resp = server.handleRequest(ctx, methodReq) + require.NotNil(t, resp, "unknown method with ID must produce an error response") + require.NotNil(t, resp.Error) + assert.Equal(t, -32601, resp.Error.Code) + assert.Equal(t, "Method not found", resp.Error.Message) + assert.Equal(t, 99, resp.ID) +} + +// TestHandleToolsCall_ErrorUsesIsError (Fix #5 regression) verifies that when +// callTool returns an error, the response uses Result with isError:true instead +// of top-level Error field (per MCP spec). +func TestHandleToolsCall_ErrorUsesIsError(t *testing.T) { + t.Parallel() + + server := NewServer(nil, "", "", "1.0.0") + ctx := context.Background() + + req := &Request{ + JSONRPC: "2.0", + ID: 7, + Method: "tools/call", + Params: json.RawMessage(`{"name":"nonexistent_tool","arguments":{}}`), + } + + resp := server.handleToolsCall(ctx, req) + + // (a) Response must NOT have top-level Error + assert.Nil(t, resp.Error, "tool errors must not use top-level JSON-RPC Error") + + // (b) Response must have Result with isError: true + require.NotNil(t, resp.Result, "tool error response must have Result") + result, ok := resp.Result.(map[string]any) + require.True(t, ok, "Result must be a map") + assert.Equal(t, true, result["isError"], "Result must contain isError: true") + + // (c) Result.content[0].text must contain the error message + content, ok := result["content"].([]map[string]any) + require.True(t, ok, "Result.content must be []map[string]any") + require.Len(t, content, 1) + assert.Equal(t, "text", content[0]["type"]) + errText, ok := content[0]["text"].(string) + require.True(t, ok) + assert.Contains(t, errText, "unknown tool: nonexistent_tool") +} diff --git a/internal/search/manager.go b/internal/search/manager.go index 5896556..2c5b473 100644 --- a/internal/search/manager.go +++ b/internal/search/manager.go @@ -46,7 +46,7 @@ const ( cacheWarmingInterval = 20 * time.Second // Run warming cycle every 20 seconds frequencyCleanupInterval = 5 * time.Minute // Cleanup stale entries every 5 minutes cacheCleanupInterval = time.Minute // Cleanup expired cache every minute - warmingQueryTimeout = 5 * time.Second // Timeout for warming queries + warmingQueryTimeout = 2 * time.Second // Timeout for warming queries (kept short to minimize mutex hold) warmingBatchSize = 5 // Warm top 5 queries per cycle minRecencyFactor = 0.1 // Minimum recency factor for scoring @@ -127,6 +127,7 @@ type Manager struct { queryFrequency map[string]*queryFrequencyInfo cacheTTL time.Duration cacheMaxSize int + activeQueries atomic.Int32 // tracks in-flight user queries to yield embedding mutex resultCacheMu sync.RWMutex queryFrequencyMu sync.RWMutex } @@ -256,7 +257,14 @@ func (m *Manager) cleanupStaleFrequencyEntries() { } // warmFrequentQueries pre-executes frequently used queries to warm the cache. +// Skips warming entirely when user queries are in-flight to avoid competing +// for the embedding model mutex on throttled hardware. func (m *Manager) warmFrequentQueries() { + if m.activeQueries.Load() > 0 { + log.Debug().Msg("Skipping cache warming: user queries in flight") + return + } + m.queryFrequencyMu.RLock() // Find top N most frequent queries that aren't recently cached type queryScore struct { @@ -687,30 +695,40 @@ func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*Unif params.OrderBy = defaultOrderBy } + // Track active user queries so cache warming can yield the embedding mutex + m.activeQueries.Add(1) + defer m.activeQueries.Add(-1) + // Check cache first cacheKey := m.getCacheKey(params) if cached, ok := m.getFromCache(cacheKey); ok { return cached, nil } - // Use singleflight to coalesce concurrent identical requests - result, err, _ := m.searchGroup.Do(cacheKey, func() (any, error) { + // Use singleflight DoChan to coalesce concurrent identical requests. + // DoChan + select allows per-caller context cancellation: waiting callers + // can bail out when their context expires without blocking on a slow first call. + ch := m.searchGroup.DoChan(cacheKey, func() (any, error) { return m.executeSearch(ctx, params) }) - if err != nil { - return nil, err + select { + case res := <-ch: + if res.Err != nil { + return nil, res.Err + } + searchResult := res.Val.(*UnifiedSearchResult) + + // Cache the result + m.putInCache(cacheKey, searchResult) + + // Track query frequency for cache warming + m.trackQueryFrequency(params) + + return searchResult, nil + case <-ctx.Done(): + return nil, ctx.Err() } - - searchResult := result.(*UnifiedSearchResult) - - // Cache the result - m.putInCache(cacheKey, searchResult) - - // Track query frequency for cache warming - m.trackQueryFrequency(params) - - return searchResult, nil } // executeSearch performs the actual search without caching/coalescing. diff --git a/internal/search/manager_test.go b/internal/search/manager_test.go index f531858..8fa2cc7 100644 --- a/internal/search/manager_test.go +++ b/internal/search/manager_test.go @@ -1009,6 +1009,91 @@ func TestSearchParams_FormatValues(t *testing.T) { } } +// ============================================================================= +// REGRESSION TESTS FOR Fix #3: Cache warming skipped during active queries +// ============================================================================= + +// TestCacheWarming_SkippedDuringActiveQueries verifies that warmFrequentQueries +// returns immediately without doing work when activeQueries > 0, and actually +// warms when activeQueries == 0. +func TestCacheWarming_SkippedDuringActiveQueries(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + defer m.Close() + + // Seed a frequent query so warming has something to consider + params := SearchParams{ + Query: "test warming query", + Project: "test-project", + Limit: 10, + } + m.trackQueryFrequency(params) + // Track it multiple times to boost its score + for range 5 { + m.trackQueryFrequency(params) + } + + // Case 1: activeQueries > 0 — warming should be skipped (fast return) + m.activeQueries.Store(1) + + start := time.Now() + m.warmFrequentQueries() + elapsed := time.Since(start) + + // Warming should return almost immediately when skipped + assert.Less(t, elapsed, 50*time.Millisecond, + "warmFrequentQueries should return quickly when user queries are in flight") + + // Case 2: activeQueries == 0 — warming should proceed past the guard. + // With nil stores, executeSearch would panic, so we verify the guard was + // bypassed by checking that the frequency map's lastCached is NOT updated + // (warming fails silently on nil stores). The key assertion is that Case 1 + // returns immediately (the guard works) while Case 2 enters the body. + m.activeQueries.Store(0) + + // Clear frequency data to prevent executeSearch from being called + m.queryFrequencyMu.Lock() + m.queryFrequency = make(map[string]*queryFrequencyInfo) + m.queryFrequencyMu.Unlock() + + start2 := time.Now() + m.warmFrequentQueries() + elapsed2 := time.Since(start2) + + // With empty frequency map, warming enters the function body (past guard) + // but finds no candidates — fast return without calling executeSearch. + assert.Less(t, elapsed2, 200*time.Millisecond, + "warmFrequentQueries should complete quickly with no candidates") +} + +// TestActiveQueries_IncrementDecrement verifies that the activeQueries counter +// is correctly incremented and decremented around search operations. +func TestActiveQueries_IncrementDecrement(t *testing.T) { + m := NewManager(nil, nil, nil, nil) + defer m.Close() + + // Before any search, counter should be 0 + assert.Equal(t, int32(0), m.activeQueries.Load()) + + // Directly test the atomic increment/decrement pattern used in UnifiedSearch + // (can't call UnifiedSearch with nil stores without panicking on filterSearch) + m.activeQueries.Add(1) + assert.Equal(t, int32(1), m.activeQueries.Load(), "should be 1 after increment") + + m.activeQueries.Add(1) + assert.Equal(t, int32(2), m.activeQueries.Load(), "should be 2 after second increment") + + m.activeQueries.Add(-1) + assert.Equal(t, int32(1), m.activeQueries.Load(), "should be 1 after decrement") + + m.activeQueries.Add(-1) + assert.Equal(t, int32(0), m.activeQueries.Load(), "should be 0 after final decrement") + + // Verify the field is accessible from warmFrequentQueries guard + m.activeQueries.Store(3) + assert.Equal(t, int32(3), m.activeQueries.Load()) + m.activeQueries.Store(0) +} + // TestUnifiedSearchResult_MultipleResults tests result with multiple items. func TestUnifiedSearchResult_MultipleResults(t *testing.T) { results := []SearchResult{ diff --git a/internal/vector/sqlitevec/client.go b/internal/vector/sqlitevec/client.go index d75fb28..a233c38 100644 --- a/internal/vector/sqlitevec/client.go +++ b/internal/vector/sqlitevec/client.go @@ -145,20 +145,24 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error { return nil } - c.writeMu.Lock() - defer c.writeMu.Unlock() - - // Generate embeddings for all documents + // Prepare texts for embedding texts := make([]string, len(docs)) for i, doc := range docs { texts[i] = doc.Content } - embeddings, err := c.embedSvc.EmbedBatch(texts) + // Compute embeddings OUTSIDE the write lock for better concurrency. + // Embedding is ONNX inference (slow, mutex-protected internally) — holding + // writeMu during inference blocks all concurrent writes and reads. + embeddings, err := c.embedSvc.EmbedBatchWithContext(ctx, texts) if err != nil { return fmt.Errorf("generate embeddings: %w", err) } + // Acquire write lock for DB operations only + c.writeMu.Lock() + defer c.writeMu.Unlock() + // Insert into vectors table with model version tracking const insertQuery = ` INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope, model_version) @@ -906,8 +910,10 @@ func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]flo } c.queryCacheMu.RUnlock() - // Cache miss - use singleflight to deduplicate concurrent embedding requests - result, err, _ := c.embeddingGroup.Do(query, func() (any, error) { + // Cache miss - use singleflight DoChan to deduplicate concurrent embedding requests. + // DoChan + select allows per-caller context cancellation: if a caller's context + // expires it can bail out without waiting for the shared computation to finish. + ch := c.embeddingGroup.DoChan(query, func() (any, error) { // Double-check cache inside singleflight (another goroutine may have just cached it) c.queryCacheMu.RLock() if entry, ok := c.queryCache[query]; ok { @@ -921,8 +927,10 @@ func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]flo // Record cache miss c.stats.embeddingMisses.Add(1) - // Compute embedding with context-aware lock acquisition - emb, err := c.embedSvc.EmbedWithContext(ctx, query) + // Compute embedding — use non-context Embed here because the singleflight + // result is shared across callers with different contexts. Per-caller + // cancellation is handled by the select below. + emb, err := c.embedSvc.Embed(query) if err != nil { return nil, err } @@ -969,10 +977,15 @@ func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]flo return emb, nil }) - if err != nil { - return nil, err + select { + case res := <-ch: + if res.Err != nil { + return nil, res.Err + } + return res.Val.([]float32), nil + case <-ctx.Done(): + return nil, ctx.Err() } - return result.([]float32), nil } // ClearCache clears the embedding cache. diff --git a/internal/vector/sqlitevec/client_test.go b/internal/vector/sqlitevec/client_test.go index e29b374..7b1ba53 100644 --- a/internal/vector/sqlitevec/client_test.go +++ b/internal/vector/sqlitevec/client_test.go @@ -3,6 +3,7 @@ package sqlitevec import ( "context" "database/sql" + "fmt" "os" "path/filepath" "sync" @@ -2013,3 +2014,141 @@ func TestAcquireRLockWithContext_CleanupOnCancel(t *testing.T) { t.Fatal("write lock acquisition timed out: cleanup goroutine may have leaked an RLock") } } + +// ============================================================================= +// REGRESSION TESTS FOR Fix #1: Embedding outside writeMu in AddDocuments +// ============================================================================= + +// TestAddDocuments_EmbeddingOutsideWriteLock verifies that AddDocuments does NOT +// hold the write lock during embedding computation. A concurrent Query call +// should complete quickly while AddDocuments is computing embeddings. +func TestAddDocuments_EmbeddingOutsideWriteLock(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Seed the DB with one document so Query has something to search + seedDocs := []Document{ + {ID: "seed-1", Content: "Seed document for concurrency test"}, + } + err = client.AddDocuments(context.Background(), seedDocs) + require.NoError(t, err) + + // Pre-warm the embedding cache for the query text so the Query call + // itself doesn't need the embedding mutex — it only needs the DB read lock. + _, err = client.Query(context.Background(), "concurrency test", 5, nil) + require.NoError(t, err) + + // Prepare a batch of documents to trigger a slow AddDocuments call + batchDocs := make([]Document, 10) + for i := range batchDocs { + batchDocs[i] = Document{ + ID: fmt.Sprintf("batch-%d", i), + Content: fmt.Sprintf("Batch document number %d for write lock test with unique content", i), + } + } + + // Launch AddDocuments in background — embedding will take time + addDone := make(chan error, 1) + go func() { + addDone <- client.AddDocuments(context.Background(), batchDocs) + }() + + // Give AddDocuments a moment to start embedding computation + time.Sleep(10 * time.Millisecond) + + // Invalidate result cache so Query must go through to DB (tests read lock) + client.InvalidateResultCache() + + // A concurrent Query should NOT be blocked by AddDocuments. + // If the old code held writeMu during embedding, this would block until + // embedding finishes. With the fix, it should complete quickly. + queryCtx, queryCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer queryCancel() + + start := time.Now() + _, err = client.Query(queryCtx, "concurrency test", 5, nil) + queryDuration := time.Since(start) + + require.NoError(t, err, "Query should succeed while AddDocuments is embedding") + // The query should complete well within the timeout if writeMu is not held + assert.Less(t, queryDuration, 1*time.Second, + "Query should complete quickly when writeMu is not held during embedding") + + // Wait for AddDocuments to finish + err = <-addDone + require.NoError(t, err, "AddDocuments should succeed") +} + +// ============================================================================= +// REGRESSION TESTS FOR Fix #2a: DoChan + context select in getOrComputeEmbedding +// ============================================================================= + +// TestGetOrComputeEmbedding_ContextCancelDuringSingleflight verifies that when +// a singleflight embedding computation is in progress, a second caller with a +// short-lived context returns context.DeadlineExceeded promptly rather than +// waiting for the slow first call to finish. +func TestGetOrComputeEmbedding_ContextCancelDuringSingleflight(t *testing.T) { + db, dbCleanup := testDB(t) + defer dbCleanup() + + embedSvc, embedCleanup := testEmbeddingService(t) + defer embedCleanup() + + client, err := NewClient(Config{DB: db}, embedSvc) + require.NoError(t, err) + + // Add a document so we can query + docs := []Document{ + {ID: "sf-test-1", Content: "Singleflight context cancellation test document"}, + } + err = client.AddDocuments(context.Background(), docs) + require.NoError(t, err) + + // Clear cache to force embedding computation + client.ClearCache() + client.InvalidateResultCache() + + queryText := "unique singleflight context test query" + + // First call: start a normal query in background (will trigger singleflight) + firstDone := make(chan struct{}) + go func() { + defer close(firstDone) + _, _ = client.Query(context.Background(), queryText, 5, nil) + }() + + // Give the first call a moment to start the singleflight computation + time.Sleep(5 * time.Millisecond) + + // Second call: use a very short context that should expire quickly + shortCtx, shortCancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer shortCancel() + + // Clear result cache again so the second call hits getOrComputeEmbedding + client.InvalidateResultCache() + + start := time.Now() + _, err = client.Query(shortCtx, queryText, 5, nil) + elapsed := time.Since(start) + + // With DoChan + select, the second caller should return quickly with context error. + // Note: If the embedding completes fast enough, the second call may succeed + // via singleflight sharing. That's also valid — the test primarily checks + // that it doesn't BLOCK for the full embedding duration on context cancel. + if err != nil { + assert.ErrorIs(t, err, context.DeadlineExceeded, + "Should return DeadlineExceeded when context expires during singleflight") + assert.Less(t, elapsed, 500*time.Millisecond, + "Should return promptly on context cancellation, not wait for slow computation") + } + // If err == nil, the singleflight completed before the context expired — also valid. + + // Wait for first call to finish + <-firstDone +} diff --git a/internal/worker/service.go b/internal/worker/service.go index 24f6f94..c854d25 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -1651,11 +1651,15 @@ func (s *Service) processAllSessions() { defer wg.Done() defer func() { <-sem }() + // Timeout prevents a hung subprocess from blocking the queue processor forever + procCtx, procCancel := context.WithTimeout(s.ctx, 60*time.Second) + defer procCancel() + switch msg.Type { case session.MessageTypeObservation: if msg.Observation != nil { err := proc.ProcessObservation( - s.ctx, + procCtx, sess.SDKSessionID, sess.Project, msg.Observation.ToolName, @@ -1674,7 +1678,7 @@ func (s *Service) processAllSessions() { case session.MessageTypeSummarize: if msg.Summarize != nil { err := proc.ProcessSummary( - s.ctx, + procCtx, sess.SessionDBID, sess.SDKSessionID, sess.Project, diff --git a/pkg/hooks/worker.go b/pkg/hooks/worker.go index 44e8cdc..36e73c6 100644 --- a/pkg/hooks/worker.go +++ b/pkg/hooks/worker.go @@ -29,7 +29,11 @@ const ( HealthCheckTimeout = 2 * time.Second // StartupTimeout is the timeout for worker startup. - StartupTimeout = 30 * time.Second + StartupTimeout = 10 * time.Second + + // EnsureWorkerDeadline is the hard overall deadline for EnsureWorkerRunning. + // Must fit within Claude Code's hook timeout budget. + EnsureWorkerDeadline = 15 * time.Second // workerCacheMaxAge is how long the worker cache is considered fresh. workerCacheMaxAge = 10 * time.Second @@ -48,6 +52,26 @@ var ( // circuitBreakerMu protects lastStartupFailure. circuitBreakerMu sync.Mutex lastStartupFailure time.Time + + // hookClient is a shared HTTP client for hook->worker requests. + // DisableKeepAlives prevents TIME_WAIT connection leaks since each hook + // is a separate OS process that exits quickly. + hookClient = &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + MaxIdleConns: 1, + }, + } + + // healthClient is a shared HTTP client for health/version checks. + healthClient = &http.Client{ + Timeout: HealthCheckTimeout, + Transport: &http.Transport{ + DisableKeepAlives: true, + MaxIdleConns: 1, + }, + } ) // IsWorkerAvailable performs a fast check without network calls. @@ -86,8 +110,7 @@ func GetWorkerPort() int { // Parses the JSON health response to check the "ready" field when available. // Falls back to HTTP status code check for backwards compatibility. func IsWorkerRunning(port int) bool { - client := &http.Client{Timeout: HealthCheckTimeout} - resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/api/health", port)) + resp, err := healthClient.Get(fmt.Sprintf("http://127.0.0.1:%d/api/health", port)) if err != nil { return false } @@ -200,7 +223,25 @@ func isWorkerRunningWithRetries(port int) bool { // EnsureWorkerRunning ensures the worker is running, starting it if necessary. // If a worker is already running and healthy with matching version, it reuses it. // If version mismatch or unhealthy, it kills the old worker and starts fresh. +// A hard deadline of EnsureWorkerDeadline prevents exceeding Claude Code's hook timeout. func EnsureWorkerRunning() (int, error) { + ctx, cancel := context.WithTimeout(context.Background(), EnsureWorkerDeadline) + defer cancel() + + return ensureWorkerRunningCtx(ctx) +} + +// sleepCtx sleeps for d or returns early if ctx is cancelled. +func sleepCtx(ctx context.Context, d time.Duration) error { + select { + case <-time.After(d): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func ensureWorkerRunningCtx(ctx context.Context) (int, error) { port := GetWorkerPort() // Fast path: check PID cache before making any HTTP calls. @@ -210,6 +251,10 @@ func EnsureWorkerRunning() (int, error) { } } + if ctx.Err() != nil { + return 0, ctx.Err() + } + // Circuit breaker: if we failed to start recently, don't retry immediately. circuitBreakerMu.Lock() if !lastStartupFailure.IsZero() && time.Since(lastStartupFailure) < circuitBreakerCooldown { @@ -232,7 +277,9 @@ func EnsureWorkerRunning() (int, error) { if err := KillProcessOnPort(port); err != nil { fmt.Fprintf(os.Stderr, "[claude-mnemonic] Warning: failed to kill old worker: %v\n", err) } - time.Sleep(500 * time.Millisecond) + if err := sleepCtx(ctx, 500*time.Millisecond); err != nil { + return 0, err + } } else { // Version matches, reuse existing worker updateCacheFromPort(port) @@ -245,14 +292,20 @@ func EnsureWorkerRunning() (int, error) { } } + if ctx.Err() != nil { + return 0, ctx.Err() + } + // Port is in use but health check failed -- worker may be slow, not dead. if IsPortInUse(port) { // The port is responding to TCP but health check timed out. // Don't kill it -- it's likely just under load. Give it more time. fmt.Fprintf(os.Stderr, "[claude-mnemonic] Worker on port %d is slow to respond, waiting...\n", port) - // Try a few more times with longer delays before giving up - for i := 0; i < 3; i++ { - time.Sleep(500 * time.Millisecond) + // Try a couple more times with shorter delays before giving up + for i := 0; i < 2; i++ { + if err := sleepCtx(ctx, 300*time.Millisecond); err != nil { + return 0, err + } if IsWorkerRunning(port) { updateCacheFromPort(port) return port, nil @@ -263,7 +316,13 @@ func EnsureWorkerRunning() (int, error) { if err := KillProcessOnPort(port); err != nil { fmt.Fprintf(os.Stderr, "[claude-mnemonic] Warning: failed to kill unhealthy process on port %d: %v\n", port, err) } - time.Sleep(500 * time.Millisecond) + if err := sleepCtx(ctx, 500*time.Millisecond); err != nil { + return 0, err + } + } + + if ctx.Err() != nil { + return 0, ctx.Err() } // Find worker binary @@ -272,8 +331,10 @@ func EnsureWorkerRunning() (int, error) { return 0, fmt.Errorf("worker binary not found") } - // Start worker + // Start worker -- detach from hook's process group so Claude Code + // killing the hook doesn't take the worker down with it. cmd := exec.Command(workerPath) // #nosec G204 -- workerPath is from internal findWorkerBinary + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} cmd.Stdout = os.Stderr cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { @@ -286,27 +347,32 @@ func EnsureWorkerRunning() (int, error) { pid := cmd.Process.Pid // Wait for worker to be ready with exponential backoff - deadline := time.Now().Add(StartupTimeout) backoff := 50 * time.Millisecond maxBackoff := 500 * time.Millisecond - for time.Now().Before(deadline) { + for { + if ctx.Err() != nil { + circuitBreakerMu.Lock() + lastStartupFailure = time.Now() + circuitBreakerMu.Unlock() + return 0, fmt.Errorf("worker failed to start within deadline: %w", ctx.Err()) + } if IsWorkerRunning(port) { writeWorkerCache(port, pid) return port, nil } - time.Sleep(backoff) + if err := sleepCtx(ctx, backoff); err != nil { + circuitBreakerMu.Lock() + lastStartupFailure = time.Now() + circuitBreakerMu.Unlock() + return 0, fmt.Errorf("worker failed to start within deadline: %w", err) + } // Exponential backoff with cap backoff = backoff * 2 if backoff > maxBackoff { backoff = maxBackoff } } - - circuitBreakerMu.Lock() - lastStartupFailure = time.Now() - circuitBreakerMu.Unlock() - return 0, fmt.Errorf("worker failed to start within timeout") } // updateCacheFromPort finds the PID of the process on the port and updates the cache. @@ -330,8 +396,7 @@ func updateCacheFromPort(port int) { // GetWorkerVersion gets the version of the running worker. func GetWorkerVersion(port int) string { - client := &http.Client{Timeout: HealthCheckTimeout} - resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/api/version", port)) + resp, err := healthClient.Get(fmt.Sprintf("http://127.0.0.1:%d/api/version", port)) if err != nil { return "" } @@ -447,14 +512,12 @@ func findWorkerBinary() string { // POST sends a POST request to the worker. func POST(port int, path string, body interface{}) (map[string]interface{}, error) { - client := &http.Client{Timeout: 10 * time.Second} - jsonBody, err := json.Marshal(body) if err != nil { return nil, err } - resp, err := client.Post( + resp, err := hookClient.Post( fmt.Sprintf("http://127.0.0.1:%d%s", port, path), "application/json", bytes.NewReader(jsonBody), @@ -493,8 +556,7 @@ func POSTWithContext(ctx context.Context, port int, path string, body interface{ } req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) + resp, err := hookClient.Do(req) if err != nil { return err } @@ -504,9 +566,7 @@ func POSTWithContext(ctx context.Context, port int, path string, body interface{ // GET sends a GET request to the worker. func GET(port int, path string) (map[string]interface{}, error) { - client := &http.Client{Timeout: 10 * time.Second} - - resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d%s", port, path)) + resp, err := hookClient.Get(fmt.Sprintf("http://127.0.0.1:%d%s", port, path)) if err != nil { return nil, err } diff --git a/pkg/hooks/worker_test.go b/pkg/hooks/worker_test.go index 97ed129..8adee07 100644 --- a/pkg/hooks/worker_test.go +++ b/pkg/hooks/worker_test.go @@ -2,6 +2,7 @@ package hooks import ( + "context" "encoding/json" "fmt" "net/http" @@ -518,7 +519,8 @@ func TestProjectIDWithName_Uniqueness(t *testing.T) { func TestHookConstants(t *testing.T) { assert.Equal(t, 37777, DefaultWorkerPort) assert.Equal(t, 2*time.Second, HealthCheckTimeout) - assert.Equal(t, 30*time.Second, StartupTimeout) + assert.Equal(t, 10*time.Second, StartupTimeout) + assert.Equal(t, 15*time.Second, EnsureWorkerDeadline) } // TestExitCodes tests exit code constants. @@ -1200,5 +1202,104 @@ func TestHealthCheckTimeout(t *testing.T) { // TestStartupTimeout tests the startup timeout is reasonable. func TestStartupTimeout(t *testing.T) { assert.Greater(t, StartupTimeout, 5*time.Second) - assert.LessOrEqual(t, StartupTimeout, time.Minute) + assert.LessOrEqual(t, StartupTimeout, 15*time.Second) +} + +// TestEnsureWorkerDeadline tests the deadline is within hook budget. +func TestEnsureWorkerDeadline(t *testing.T) { + assert.Greater(t, EnsureWorkerDeadline, StartupTimeout, "deadline must exceed startup timeout") + assert.LessOrEqual(t, EnsureWorkerDeadline, 20*time.Second, "deadline must fit in hook timeout budget") +} + +// --- Regression tests for Fixes 1, 3, 4 --- + +// TestEnsureWorkerRunning_RespectsDeadline verifies that EnsureWorkerRunning returns +// within a bounded time even in worst case (no worker, startup fails). +func TestEnsureWorkerRunning_RespectsDeadline(t *testing.T) { + // Reset circuit breaker so the function actually tries to start. + circuitBreakerMu.Lock() + lastStartupFailure = time.Time{} + circuitBreakerMu.Unlock() + + // Use a port that nothing listens on. + t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", "19999") + // Point HOME to a temp dir so findWorkerBinary finds nothing. + t.Setenv("HOME", t.TempDir()) + // Clear plugin root to avoid that path too. + t.Setenv("CLAUDE_PLUGIN_ROOT", "") + + start := time.Now() + _, err := EnsureWorkerRunning() + elapsed := time.Since(start) + + // Must error (no binary found). + assert.Error(t, err) + assert.Contains(t, err.Error(), "worker binary not found") + // Must complete well within EnsureWorkerDeadline. + assert.Less(t, elapsed, EnsureWorkerDeadline, + "EnsureWorkerRunning took %v, exceeding deadline %v", elapsed, EnsureWorkerDeadline) +} + +// TestEnsureWorkerRunningCtx_CancelledContext verifies immediate return on cancelled context. +func TestEnsureWorkerRunningCtx_CancelledContext(t *testing.T) { + // Reset circuit breaker. + circuitBreakerMu.Lock() + lastStartupFailure = time.Time{} + circuitBreakerMu.Unlock() + + t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", "19998") + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + start := time.Now() + _, err := ensureWorkerRunningCtx(ctx) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.Less(t, elapsed, 1*time.Second, "cancelled context should return immediately") +} + +// TestSleepCtx_Normal verifies sleepCtx completes normally. +func TestSleepCtx_Normal(t *testing.T) { + ctx := context.Background() + start := time.Now() + err := sleepCtx(ctx, 50*time.Millisecond) + elapsed := time.Since(start) + + assert.NoError(t, err) + assert.GreaterOrEqual(t, elapsed, 50*time.Millisecond) +} + +// TestSleepCtx_Cancelled verifies sleepCtx returns early on cancel. +func TestSleepCtx_Cancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + err := sleepCtx(ctx, 5*time.Second) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.Less(t, elapsed, 500*time.Millisecond) +} + +// TestHookClients_DisableKeepAlives asserts shared clients disable keep-alives +// to prevent TIME_WAIT connection leaks in short-lived hook processes. +func TestHookClients_DisableKeepAlives(t *testing.T) { + hTransport, ok := hookClient.Transport.(*http.Transport) + require.True(t, ok, "hookClient.Transport should be *http.Transport") + assert.True(t, hTransport.DisableKeepAlives, "hookClient must disable keep-alives") + assert.Equal(t, 1, hTransport.MaxIdleConns) + + hcTransport, ok := healthClient.Transport.(*http.Transport) + require.True(t, ok, "healthClient.Transport should be *http.Transport") + assert.True(t, hcTransport.DisableKeepAlives, "healthClient must disable keep-alives") + assert.Equal(t, 1, hcTransport.MaxIdleConns) +} + +// TestHookClient_Timeout verifies hookClient timeout is set. +func TestHookClient_Timeout(t *testing.T) { + assert.Equal(t, 10*time.Second, hookClient.Timeout) + assert.Equal(t, HealthCheckTimeout, healthClient.Timeout) }