fix: address 15 additional hang vectors found during deep audit (#45)

MCP server (5 fixes):
- Move semaphore acquisition inside goroutine so main loop stays
  responsive when all slots are taken
- Add 10s write timeout to sendResponse to prevent pipe deadlock
  when Claude Code pauses reading stdout
- Send fallback JSON-RPC error when json.Marshal fails instead of
  silently swallowing the error and leaving caller waiting forever
- Silence unknown notification methods (req.ID == nil) instead of
  sending unsolicited error responses that may desync the host
- Return MCP isError content for tool failures instead of top-level
  JSON-RPC error, matching the MCP specification

Vector/embedding (3 fixes):
- Move EmbedBatchWithContext call before writeMu.Lock in AddDocuments
  so ONNX inference runs outside the write lock
- Replace singleflight.Do with DoChan + ctx select in both
  getOrComputeEmbedding and UnifiedSearch so callers can bail out
  independently when their context expires
- Add activeQueries atomic counter; skip cache warming when user
  queries are in-flight; reduce warming timeout from 5s to 2s

Hooks (4 fixes):
- Cap EnsureWorkerRunning to 15s hard deadline with context; reduce
  StartupTimeout from 30s to 10s; reduce port-in-use retries
- Fix nil dereference panic in user-prompt hook when initResult is
  nil (non-JSON worker response); use comma-ok assertions
- Use package-level hookClient/healthClient with DisableKeepAlives
  to prevent FD leaks in short-lived hook processes
- Set SysProcAttr{Setpgid: true} to detach worker from hook process
  group, preventing kill-cascade from Claude Code

Worker/DB (3 fixes):
- Replace os.Exit(0) in MCP config watcher with context cancellation
  for clean protocol shutdown
- Add 60s context.WithTimeout around ProcessObservation calls in
  processAllSessions to prevent hung CLI subprocesses from blocking
  the queue processor forever
- Set explicit PRAGMA wal_autocheckpoint=1000 and add PASSIVE WAL
  checkpoint to Optimize() to prevent checkpoint stalls

Adds 20+ regression tests across all fix areas.
This commit is contained in:
2026-05-26 13:52:09 +01:00
parent de5796bbe6
commit a81482d06a
15 changed files with 952 additions and 92 deletions
+1
View File
@@ -88,3 +88,4 @@ docs/dist
# Non-template plugin configs (keep only .tpl files in plugin/ dir)
plugin/.claude-plugin/plugin.json
plugin/.claude-plugin/marketplace.json
user-prompt
+15 -3
View File
@@ -177,15 +177,27 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) {
if initErr != nil {
return "", initErr
}
if initResult == nil {
return contextToInject, nil // Non-JSON response from worker, skip session init
}
// Check if skipped due to privacy
if skipped, ok := initResult["skipped"].(bool); ok && skipped {
fmt.Fprintf(os.Stderr, "[user-prompt] Session skipped (private)\n")
return "", nil
return contextToInject, nil
}
sessionID := int64(initResult["sessionDbId"].(float64))
promptNumber := int(initResult["promptNumber"].(float64))
sessionDBIDVal, ok := initResult["sessionDbId"].(float64)
if !ok {
return contextToInject, nil // Missing or wrong type, skip gracefully
}
sessionID := int64(sessionDBIDVal)
promptNumberVal, ok := initResult["promptNumber"].(float64)
if !ok {
return contextToInject, nil
}
promptNumber := int(promptNumberVal)
fmt.Fprintf(os.Stderr, "[user-prompt] Session %d, prompt #%d\n", sessionID, promptNumber)
+44
View File
@@ -0,0 +1,44 @@
package main
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestEstimateTokens tests the token estimator.
func TestEstimateTokens(t *testing.T) {
tests := []struct {
name string
input string
minToken int
maxToken int
}{
{"empty string", "", 0, 0},
{"single word", "hello", 1, 3},
{"simple sentence", "Hello world this is a test", 5, 15},
{"code-heavy", "func() { return x.y.z(); }", 5, 30},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := estimateTokens(tt.input)
assert.GreaterOrEqual(t, result, tt.minToken)
assert.LessOrEqual(t, result, tt.maxToken)
})
}
}
// TestHandleUserPrompt_NilInitResult_Compile verifies that the nil-safety
// fix in handleUserPrompt compiles correctly. The actual nil dereference
// was at initResult["sessionDbId"].(float64) when initResult was nil.
// This test ensures the defensive type assertions are present by exercising
// the token estimator (the handler requires a live HookContext+worker).
func TestHandleUserPrompt_NilInitResult_Compile(t *testing.T) {
// The real regression test is that `go build ./cmd/hooks/user-prompt/`
// succeeds with the nil-safe assertions. We can't easily spin up
// a full HookContext here, but we verify the package compiles and
// the helper functions are sane.
assert.Equal(t, 0, estimateTokens(""))
assert.Greater(t, estimateTokens("test input"), 0)
}
+9 -6
View File
@@ -59,7 +59,7 @@ func main() {
}()
// Start file watchers for config changes
startWatchers()
startWatchers(cancel)
telemetry.Send("claude-mnemonic", Version)
@@ -68,18 +68,21 @@ func main() {
log.Info().Str("project", *project).Str("version", Version).Str("worker", workerURL).Msg("Starting MCP server")
if err := server.Run(ctx); err != nil {
if err == context.Canceled {
log.Info().Msg("MCP server shut down (config change or signal)")
return
}
log.Fatal().Err(err).Msg("MCP server error")
}
}
// startWatchers initializes file watchers for config.
func startWatchers() {
// Watch config file for changes (triggers process exit for restart)
func startWatchers(cancel context.CancelFunc) {
// Watch config file for changes (triggers graceful shutdown via context cancellation)
configPath := config.SettingsPath()
configWatcher, err := watcher.New(configPath, func() {
log.Warn().Str("path", configPath).Msg("Config file changed, exiting for restart...")
time.Sleep(100 * time.Millisecond) // Give logs time to flush
os.Exit(0)
log.Warn().Str("path", configPath).Msg("Config file changed, shutting down gracefully...")
cancel() // Triggers ctx.Done() in server.Run(), which drains in-flight requests
})
if err != nil {
log.Warn().Err(err).Msg("Failed to create config watcher")
+8 -2
View File
@@ -99,8 +99,9 @@ func NewStore(cfg Config) (*Store, error) {
"PRAGMA synchronous=NORMAL",
"PRAGMA cache_size=-64000", // 64MB cache (negative = KB)
"PRAGMA temp_store=MEMORY", // Store temp tables in memory
"PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O
"PRAGMA page_size=4096", // 4KB pages (optimal for most systems)
"PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O
"PRAGMA page_size=4096", // 4KB pages (optimal for most systems)
"PRAGMA wal_autocheckpoint=1000", // Explicit default; checkpoint every 1000 WAL frames
}
for _, pragma := range pragmas {
if _, err := sqlDB.Exec(pragma); err != nil {
@@ -192,6 +193,11 @@ func (s *Store) Optimize(ctx context.Context) error {
log.Warn().Err(err).Msg("PRAGMA optimize failed (non-fatal)")
}
// Passive WAL checkpoint — doesn't block readers/writers
if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA wal_checkpoint(PASSIVE)"); err != nil {
log.Warn().Err(err).Msg("WAL checkpoint failed (non-fatal)")
}
log.Info().Dur("duration", time.Since(start)).Msg("Database optimization complete")
return nil
}
+89
View File
@@ -4,6 +4,7 @@
package gorm
import (
"context"
"os"
"path/filepath"
"testing"
@@ -150,3 +151,91 @@ func TestMigrationIdempotency(t *testing.T) {
t.Logf("✅ Migrations are idempotent")
}
func TestWALAutocheckpoint(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "gorm_wal_checkpoint_*")
if err != nil {
t.Fatalf("create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test.db")
store, err := NewStore(Config{
Path: dbPath,
MaxConns: 2,
LogLevel: logger.Silent,
})
if err != nil {
t.Fatalf("NewStore failed: %v", err)
}
defer store.Close()
// Verify wal_autocheckpoint is set to 1000
var checkpoint int
err = store.GetRawDB().QueryRow("PRAGMA wal_autocheckpoint").Scan(&checkpoint)
if err != nil {
t.Fatalf("query wal_autocheckpoint: %v", err)
}
if checkpoint != 1000 {
t.Errorf("expected wal_autocheckpoint=1000, got %d", checkpoint)
}
}
func TestOptimize_RunsWALCheckpoint(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "gorm_optimize_*")
if err != nil {
t.Fatalf("create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test.db")
store, err := NewStore(Config{
Path: dbPath,
MaxConns: 2,
LogLevel: logger.Silent,
})
if err != nil {
t.Fatalf("NewStore failed: %v", err)
}
defer store.Close()
// Insert some data to generate WAL frames
_, err = store.GetRawDB().Exec("INSERT INTO observations (sdk_session_id, title, scope, project, type, created_at, created_at_epoch) VALUES ('test-sess', 'test data', 'project', '/tmp/test', 'decision', '2026-01-01T00:00:00Z', 1735689600)")
if err != nil {
t.Fatalf("insert test data: %v", err)
}
// Optimize should succeed (includes PASSIVE WAL checkpoint)
err = store.Optimize(context.Background())
if err != nil {
t.Fatalf("Optimize failed: %v", err)
}
}
func TestOptimize_RespectsContextCancellation(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "gorm_optimize_cancel_*")
if err != nil {
t.Fatalf("create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
dbPath := filepath.Join(tmpDir, "test.db")
store, err := NewStore(Config{
Path: dbPath,
MaxConns: 2,
LogLevel: logger.Silent,
})
if err != nil {
t.Fatalf("NewStore failed: %v", err)
}
defer store.Close()
// Already-cancelled context should cause Optimize to fail
ctx, cancel := context.WithCancel(context.Background())
cancel()
err = store.Optimize(ctx)
if err == nil {
t.Error("expected error with cancelled context, got nil")
}
}
+53 -14
View File
@@ -171,11 +171,22 @@ func (s *Server) Run(ctx context.Context) error {
}
// Dispatch request to its own goroutine.
// Semaphore is acquired inside the goroutine so the main
// loop never blocks on a full semaphore (Fix #1).
wg.Add(1)
sem <- struct{}{} // acquire semaphore slot
go func(r Request) {
defer wg.Done()
defer func() { <-sem }() // release semaphore slot
// Acquire semaphore inside goroutine, not blocking main loop
select {
case sem <- struct{}{}:
defer func() { <-sem }()
case <-ctx.Done():
// Server shutting down — send error so caller isn't left waiting
if r.ID != nil {
_ = s.sendError(r.ID, -32000, "Server shutting down", nil)
}
return
}
resp := s.handleRequest(ctx, &r)
if resp != nil {
@@ -203,6 +214,10 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) *Response {
case "notifications/initialized", "notifications/cancelled":
return nil // Notifications don't get responses
default:
if req.ID == nil {
// Notifications must not receive responses per JSON-RPC 2.0
return nil
}
return &Response{
JSONRPC: "2.0",
ID: req.ID,
@@ -775,13 +790,15 @@ func (s *Server) handleToolsCall(ctx context.Context, req *Request) *Response {
result, err := s.callTool(ctx, params.Name, params.Arguments)
if err != nil {
// MCP spec: tool failures use Result with isError, not top-level Error
return &Response{
JSONRPC: "2.0",
ID: req.ID,
Error: &Error{
Code: -32000,
Message: "Tool error",
Data: err.Error(),
Result: map[string]any{
"content": []map[string]any{
{"type": "text", "text": "Error: " + err.Error()},
},
"isError": true,
},
}
}
@@ -1902,16 +1919,38 @@ func (s *Server) sendResponse(resp *Response) error {
data, err := json.Marshal(resp)
if err != nil {
log.Error().Err(err).Msg("Failed to marshal response")
// Send a fallback error so the caller isn't left waiting (Fix #3)
fallback, _ := json.Marshal(map[string]any{
"jsonrpc": "2.0",
"id": resp.ID,
"error": map[string]any{"code": -32603, "message": "internal marshal error"},
})
s.writeMu.Lock()
_, _ = fmt.Fprintln(s.stdout, string(fallback))
s.writeMu.Unlock()
return fmt.Errorf("marshal error: %w", err)
}
// Bound the write to prevent pipe deadlock (Fix #2)
done := make(chan error, 1)
go func() {
s.writeMu.Lock()
_, werr := fmt.Fprintln(s.stdout, string(data))
s.writeMu.Unlock()
done <- werr
}()
select {
case werr := <-done:
if werr != nil {
log.Error().Err(werr).Msg("Failed to write response to stdout")
return werr
}
return nil
case <-time.After(10 * time.Second):
log.Error().Msg("Stdout write timed out — pipe likely full")
return fmt.Errorf("write timeout: stdout pipe full")
}
s.writeMu.Lock()
_, err = fmt.Fprintln(s.stdout, string(data))
s.writeMu.Unlock()
if err != nil {
log.Error().Err(err).Msg("Failed to write response to stdout")
return err
}
return nil
}
// sendError sends a JSON-RPC error response. Returns an error if writing fails.
+255 -9
View File
@@ -1665,6 +1665,7 @@ func TestToolListContainsExpectedSchemas(t *testing.T) {
}
// TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name.
// After Fix #5: tool errors use Result with isError, not top-level Error.
func TestHandleToolsCall_UnknownTool(t *testing.T) {
t.Parallel()
@@ -1679,9 +1680,13 @@ func TestHandleToolsCall_UnknownTool(t *testing.T) {
}
resp := server.handleToolsCall(ctx, req)
require.NotNil(t, resp.Error)
assert.Equal(t, -32000, resp.Error.Code)
assert.Contains(t, resp.Error.Data, "unknown tool")
assert.Nil(t, resp.Error, "tool errors must not use top-level Error (MCP spec)")
require.NotNil(t, resp.Result)
result, ok := resp.Result.(map[string]any)
require.True(t, ok)
assert.Equal(t, true, result["isError"])
content := result["content"].([]map[string]any)
assert.Contains(t, content[0]["text"], "unknown tool")
}
// TestCallTool_ToolNameRecognition tests that valid tool names are recognized.
@@ -1982,6 +1987,7 @@ func TestTimelineParamsStruct_Validation(t *testing.T) {
}
// TestHandleToolsCall_EmptyParams tests tools/call with empty params.
// After Fix #5: tool errors use Result with isError, not top-level Error.
func TestHandleToolsCall_EmptyParams(t *testing.T) {
t.Parallel()
@@ -1997,8 +2003,12 @@ func TestHandleToolsCall_EmptyParams(t *testing.T) {
resp := server.handleToolsCall(ctx, req)
// Should error due to missing name
require.NotNil(t, resp.Error)
// Empty name goes through callTool default branch -> isError
assert.Nil(t, resp.Error, "tool errors must use isError in Result")
require.NotNil(t, resp.Result)
result, ok := resp.Result.(map[string]any)
require.True(t, ok)
assert.Equal(t, true, result["isError"])
}
// TestSendResponse_WithError tests sendResponse with an error response.
@@ -2062,6 +2072,7 @@ func TestToolCallParamsWithComplexArgs(t *testing.T) {
}
// TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error.
// After Fix #5: tool errors use Result with isError, not top-level Error.
func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
t.Parallel()
@@ -2077,12 +2088,13 @@ func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
resp := server.handleToolsCall(ctx, req)
// Should get an error response
assert.Equal(t, "2.0", resp.JSONRPC)
assert.Equal(t, 1, resp.ID)
require.NotNil(t, resp.Error)
// Error is "Tool error" with message containing "unknown tool"
assert.True(t, resp.Error.Code != 0)
assert.Nil(t, resp.Error, "tool errors must use isError in Result, not top-level Error")
require.NotNil(t, resp.Result)
result, ok := resp.Result.(map[string]any)
require.True(t, ok)
assert.Equal(t, true, result["isError"])
}
// =============================================================================
@@ -3172,3 +3184,237 @@ func TestRun_GracefulDrainOnCancel(t *testing.T) {
assert.Equal(t, "2.0", resp["jsonrpc"], "any completed response must be valid JSON-RPC 2.0")
}
}
// =============================================================================
// REGRESSION TESTS — Fix #1-#5
// =============================================================================
// blockingWriter is an io.Writer that blocks forever on Write.
type blockingWriter struct {
blocked chan struct{} // closed when Write is entered
}
func (bw *blockingWriter) Write(p []byte) (int, error) {
if bw.blocked != nil {
close(bw.blocked)
}
select {} // block forever
}
// TestRun_SemaphoreDoesNotBlockMainLoop (Fix #1 regression) fills all semaphore
// slots with blocked requests, then sends a notification. The main loop must
// stay responsive and not hang on semaphore acquisition.
func TestRun_SemaphoreDoesNotBlockMainLoop(t *testing.T) {
t.Parallel()
const maxConcurrent = 10
// Mock worker that blocks until test context is cancelled
handlerCtx, handlerCancel := context.WithCancel(context.Background())
defer handlerCancel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-handlerCtx.Done() // block until test cleanup
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{}`)
}))
defer func() {
handlerCancel() // unblock all handlers first
ts.CloseClientConnections()
ts.Close()
}()
stdinR, stdinW := io.Pipe()
var stdout bytes.Buffer
server := &Server{
client: ts.Client(),
workerURL: ts.URL,
project: "test",
version: "1.0.0",
stdin: stdinR,
stdout: &stdout,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runDone := make(chan error, 1)
go func() {
runDone <- server.Run(ctx)
}()
// Send maxConcurrent+2 requests to fill all semaphore slots
for i := 0; i < maxConcurrent+2; i++ {
req := Request{JSONRPC: "2.0", ID: i + 1, Method: "tools/call",
Params: json.RawMessage(`{"name":"get_memory_stats","arguments":{}}`)}
data, _ := json.Marshal(req)
_, err := io.WriteString(stdinW, string(data)+"\n")
require.NoError(t, err)
}
// Give goroutines time to start and fill semaphore
time.Sleep(100 * time.Millisecond)
// Now send a notification — this must not hang the main loop
notifSent := make(chan struct{})
go func() {
notif := `{"jsonrpc":"2.0","method":"notifications/initialized"}` + "\n"
_, _ = io.WriteString(stdinW, notif)
close(notifSent)
}()
select {
case <-notifSent:
// Main loop accepted the notification write (stdin is a pipe, so
// the write completing means the reader consumed it)
case <-time.After(3 * time.Second):
t.Fatal("main loop blocked — semaphore acquisition is blocking the main goroutine")
}
// Clean up: cancel server context, close stdin
cancel()
_ = stdinW.Close()
// Wait for Run to finish
select {
case <-runDone:
case <-time.After(5 * time.Second):
// Acceptable — some goroutines may still be draining
}
}
// TestSendResponse_WriteTimeout (Fix #2 regression) verifies that sendResponse
// returns an error within a bounded time when the writer blocks forever.
func TestSendResponse_WriteTimeout(t *testing.T) {
t.Parallel()
bw := &blockingWriter{blocked: make(chan struct{})}
server := &Server{stdout: bw}
resp := &Response{
JSONRPC: "2.0",
ID: 1,
Result: "ok",
}
done := make(chan error, 1)
go func() {
done <- server.sendResponse(resp)
}()
select {
case err := <-done:
require.Error(t, err, "sendResponse must return error on write timeout")
assert.Contains(t, err.Error(), "write timeout")
case <-time.After(15 * time.Second):
t.Fatal("sendResponse hung forever — write timeout not working")
}
}
// TestSendResponse_MarshalError (Fix #3 regression) verifies that when
// json.Marshal fails, sendResponse sends a fallback error response and
// returns an error (instead of silently returning nil).
func TestSendResponse_MarshalError(t *testing.T) {
t.Parallel()
var buf bytes.Buffer
server := &Server{stdout: &buf}
// Channels are not JSON-serializable — this will cause json.Marshal to fail
resp := &Response{
JSONRPC: "2.0",
ID: 42,
Result: make(chan int), // unserializable
}
err := server.sendResponse(resp)
// (a) Must return an error
require.Error(t, err, "sendResponse must return error when marshal fails")
assert.Contains(t, err.Error(), "marshal error")
// (b) Must have written a fallback JSON-RPC error to stdout
output := buf.String()
require.NotEmpty(t, output, "fallback response must be written to stdout")
var fallback map[string]any
require.NoError(t, json.Unmarshal([]byte(strings.TrimSpace(output)), &fallback),
"fallback must be valid JSON")
assert.Equal(t, "2.0", fallback["jsonrpc"])
assert.Equal(t, float64(42), fallback["id"], "fallback must preserve original request ID")
errObj, ok := fallback["error"].(map[string]any)
require.True(t, ok, "fallback must contain error object")
assert.Equal(t, float64(-32603), errObj["code"])
assert.Equal(t, "internal marshal error", errObj["message"])
}
// TestHandleRequest_UnknownNotification (Fix #4 regression) verifies that
// unknown notification methods (ID == nil) get no response, while unknown
// methods with an ID still get a -32601 error.
func TestHandleRequest_UnknownNotification(t *testing.T) {
t.Parallel()
server := NewServer(nil, "", "", "1.0.0")
ctx := context.Background()
// Case 1: Unknown notification (no ID) — must return nil
notifReq := &Request{
JSONRPC: "2.0",
ID: nil,
Method: "notifications/roots/list_changed",
}
resp := server.handleRequest(ctx, notifReq)
assert.Nil(t, resp, "unknown notification must not produce a response")
// Case 2: Unknown method WITH an ID — must return error response
methodReq := &Request{
JSONRPC: "2.0",
ID: 99,
Method: "some/unknown/method",
}
resp = server.handleRequest(ctx, methodReq)
require.NotNil(t, resp, "unknown method with ID must produce an error response")
require.NotNil(t, resp.Error)
assert.Equal(t, -32601, resp.Error.Code)
assert.Equal(t, "Method not found", resp.Error.Message)
assert.Equal(t, 99, resp.ID)
}
// TestHandleToolsCall_ErrorUsesIsError (Fix #5 regression) verifies that when
// callTool returns an error, the response uses Result with isError:true instead
// of top-level Error field (per MCP spec).
func TestHandleToolsCall_ErrorUsesIsError(t *testing.T) {
t.Parallel()
server := NewServer(nil, "", "", "1.0.0")
ctx := context.Background()
req := &Request{
JSONRPC: "2.0",
ID: 7,
Method: "tools/call",
Params: json.RawMessage(`{"name":"nonexistent_tool","arguments":{}}`),
}
resp := server.handleToolsCall(ctx, req)
// (a) Response must NOT have top-level Error
assert.Nil(t, resp.Error, "tool errors must not use top-level JSON-RPC Error")
// (b) Response must have Result with isError: true
require.NotNil(t, resp.Result, "tool error response must have Result")
result, ok := resp.Result.(map[string]any)
require.True(t, ok, "Result must be a map")
assert.Equal(t, true, result["isError"], "Result must contain isError: true")
// (c) Result.content[0].text must contain the error message
content, ok := result["content"].([]map[string]any)
require.True(t, ok, "Result.content must be []map[string]any")
require.Len(t, content, 1)
assert.Equal(t, "text", content[0]["type"])
errText, ok := content[0]["text"].(string)
require.True(t, ok)
assert.Contains(t, errText, "unknown tool: nonexistent_tool")
}
+33 -15
View File
@@ -46,7 +46,7 @@ const (
cacheWarmingInterval = 20 * time.Second // Run warming cycle every 20 seconds
frequencyCleanupInterval = 5 * time.Minute // Cleanup stale entries every 5 minutes
cacheCleanupInterval = time.Minute // Cleanup expired cache every minute
warmingQueryTimeout = 5 * time.Second // Timeout for warming queries
warmingQueryTimeout = 2 * time.Second // Timeout for warming queries (kept short to minimize mutex hold)
warmingBatchSize = 5 // Warm top 5 queries per cycle
minRecencyFactor = 0.1 // Minimum recency factor for scoring
@@ -127,6 +127,7 @@ type Manager struct {
queryFrequency map[string]*queryFrequencyInfo
cacheTTL time.Duration
cacheMaxSize int
activeQueries atomic.Int32 // tracks in-flight user queries to yield embedding mutex
resultCacheMu sync.RWMutex
queryFrequencyMu sync.RWMutex
}
@@ -256,7 +257,14 @@ func (m *Manager) cleanupStaleFrequencyEntries() {
}
// warmFrequentQueries pre-executes frequently used queries to warm the cache.
// Skips warming entirely when user queries are in-flight to avoid competing
// for the embedding model mutex on throttled hardware.
func (m *Manager) warmFrequentQueries() {
if m.activeQueries.Load() > 0 {
log.Debug().Msg("Skipping cache warming: user queries in flight")
return
}
m.queryFrequencyMu.RLock()
// Find top N most frequent queries that aren't recently cached
type queryScore struct {
@@ -687,30 +695,40 @@ func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*Unif
params.OrderBy = defaultOrderBy
}
// Track active user queries so cache warming can yield the embedding mutex
m.activeQueries.Add(1)
defer m.activeQueries.Add(-1)
// Check cache first
cacheKey := m.getCacheKey(params)
if cached, ok := m.getFromCache(cacheKey); ok {
return cached, nil
}
// Use singleflight to coalesce concurrent identical requests
result, err, _ := m.searchGroup.Do(cacheKey, func() (any, error) {
// Use singleflight DoChan to coalesce concurrent identical requests.
// DoChan + select allows per-caller context cancellation: waiting callers
// can bail out when their context expires without blocking on a slow first call.
ch := m.searchGroup.DoChan(cacheKey, func() (any, error) {
return m.executeSearch(ctx, params)
})
if err != nil {
return nil, err
select {
case res := <-ch:
if res.Err != nil {
return nil, res.Err
}
searchResult := res.Val.(*UnifiedSearchResult)
// Cache the result
m.putInCache(cacheKey, searchResult)
// Track query frequency for cache warming
m.trackQueryFrequency(params)
return searchResult, nil
case <-ctx.Done():
return nil, ctx.Err()
}
searchResult := result.(*UnifiedSearchResult)
// Cache the result
m.putInCache(cacheKey, searchResult)
// Track query frequency for cache warming
m.trackQueryFrequency(params)
return searchResult, nil
}
// executeSearch performs the actual search without caching/coalescing.
+85
View File
@@ -1009,6 +1009,91 @@ func TestSearchParams_FormatValues(t *testing.T) {
}
}
// =============================================================================
// REGRESSION TESTS FOR Fix #3: Cache warming skipped during active queries
// =============================================================================
// TestCacheWarming_SkippedDuringActiveQueries verifies that warmFrequentQueries
// returns immediately without doing work when activeQueries > 0, and actually
// warms when activeQueries == 0.
func TestCacheWarming_SkippedDuringActiveQueries(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
defer m.Close()
// Seed a frequent query so warming has something to consider
params := SearchParams{
Query: "test warming query",
Project: "test-project",
Limit: 10,
}
m.trackQueryFrequency(params)
// Track it multiple times to boost its score
for range 5 {
m.trackQueryFrequency(params)
}
// Case 1: activeQueries > 0 — warming should be skipped (fast return)
m.activeQueries.Store(1)
start := time.Now()
m.warmFrequentQueries()
elapsed := time.Since(start)
// Warming should return almost immediately when skipped
assert.Less(t, elapsed, 50*time.Millisecond,
"warmFrequentQueries should return quickly when user queries are in flight")
// Case 2: activeQueries == 0 — warming should proceed past the guard.
// With nil stores, executeSearch would panic, so we verify the guard was
// bypassed by checking that the frequency map's lastCached is NOT updated
// (warming fails silently on nil stores). The key assertion is that Case 1
// returns immediately (the guard works) while Case 2 enters the body.
m.activeQueries.Store(0)
// Clear frequency data to prevent executeSearch from being called
m.queryFrequencyMu.Lock()
m.queryFrequency = make(map[string]*queryFrequencyInfo)
m.queryFrequencyMu.Unlock()
start2 := time.Now()
m.warmFrequentQueries()
elapsed2 := time.Since(start2)
// With empty frequency map, warming enters the function body (past guard)
// but finds no candidates — fast return without calling executeSearch.
assert.Less(t, elapsed2, 200*time.Millisecond,
"warmFrequentQueries should complete quickly with no candidates")
}
// TestActiveQueries_IncrementDecrement verifies that the activeQueries counter
// is correctly incremented and decremented around search operations.
func TestActiveQueries_IncrementDecrement(t *testing.T) {
m := NewManager(nil, nil, nil, nil)
defer m.Close()
// Before any search, counter should be 0
assert.Equal(t, int32(0), m.activeQueries.Load())
// Directly test the atomic increment/decrement pattern used in UnifiedSearch
// (can't call UnifiedSearch with nil stores without panicking on filterSearch)
m.activeQueries.Add(1)
assert.Equal(t, int32(1), m.activeQueries.Load(), "should be 1 after increment")
m.activeQueries.Add(1)
assert.Equal(t, int32(2), m.activeQueries.Load(), "should be 2 after second increment")
m.activeQueries.Add(-1)
assert.Equal(t, int32(1), m.activeQueries.Load(), "should be 1 after decrement")
m.activeQueries.Add(-1)
assert.Equal(t, int32(0), m.activeQueries.Load(), "should be 0 after final decrement")
// Verify the field is accessible from warmFrequentQueries guard
m.activeQueries.Store(3)
assert.Equal(t, int32(3), m.activeQueries.Load())
m.activeQueries.Store(0)
}
// TestUnifiedSearchResult_MultipleResults tests result with multiple items.
func TestUnifiedSearchResult_MultipleResults(t *testing.T) {
results := []SearchResult{
+25 -12
View File
@@ -145,20 +145,24 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
return nil
}
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Generate embeddings for all documents
// Prepare texts for embedding
texts := make([]string, len(docs))
for i, doc := range docs {
texts[i] = doc.Content
}
embeddings, err := c.embedSvc.EmbedBatch(texts)
// Compute embeddings OUTSIDE the write lock for better concurrency.
// Embedding is ONNX inference (slow, mutex-protected internally) — holding
// writeMu during inference blocks all concurrent writes and reads.
embeddings, err := c.embedSvc.EmbedBatchWithContext(ctx, texts)
if err != nil {
return fmt.Errorf("generate embeddings: %w", err)
}
// Acquire write lock for DB operations only
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Insert into vectors table with model version tracking
const insertQuery = `
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope, model_version)
@@ -906,8 +910,10 @@ func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]flo
}
c.queryCacheMu.RUnlock()
// Cache miss - use singleflight to deduplicate concurrent embedding requests
result, err, _ := c.embeddingGroup.Do(query, func() (any, error) {
// Cache miss - use singleflight DoChan to deduplicate concurrent embedding requests.
// DoChan + select allows per-caller context cancellation: if a caller's context
// expires it can bail out without waiting for the shared computation to finish.
ch := c.embeddingGroup.DoChan(query, func() (any, error) {
// Double-check cache inside singleflight (another goroutine may have just cached it)
c.queryCacheMu.RLock()
if entry, ok := c.queryCache[query]; ok {
@@ -921,8 +927,10 @@ func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]flo
// Record cache miss
c.stats.embeddingMisses.Add(1)
// Compute embedding with context-aware lock acquisition
emb, err := c.embedSvc.EmbedWithContext(ctx, query)
// Compute embedding — use non-context Embed here because the singleflight
// result is shared across callers with different contexts. Per-caller
// cancellation is handled by the select below.
emb, err := c.embedSvc.Embed(query)
if err != nil {
return nil, err
}
@@ -969,10 +977,15 @@ func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]flo
return emb, nil
})
if err != nil {
return nil, err
select {
case res := <-ch:
if res.Err != nil {
return nil, res.Err
}
return res.Val.([]float32), nil
case <-ctx.Done():
return nil, ctx.Err()
}
return result.([]float32), nil
}
// ClearCache clears the embedding cache.
+139
View File
@@ -3,6 +3,7 @@ package sqlitevec
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"sync"
@@ -2013,3 +2014,141 @@ func TestAcquireRLockWithContext_CleanupOnCancel(t *testing.T) {
t.Fatal("write lock acquisition timed out: cleanup goroutine may have leaked an RLock")
}
}
// =============================================================================
// REGRESSION TESTS FOR Fix #1: Embedding outside writeMu in AddDocuments
// =============================================================================
// TestAddDocuments_EmbeddingOutsideWriteLock verifies that AddDocuments does NOT
// hold the write lock during embedding computation. A concurrent Query call
// should complete quickly while AddDocuments is computing embeddings.
func TestAddDocuments_EmbeddingOutsideWriteLock(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
// Seed the DB with one document so Query has something to search
seedDocs := []Document{
{ID: "seed-1", Content: "Seed document for concurrency test"},
}
err = client.AddDocuments(context.Background(), seedDocs)
require.NoError(t, err)
// Pre-warm the embedding cache for the query text so the Query call
// itself doesn't need the embedding mutex — it only needs the DB read lock.
_, err = client.Query(context.Background(), "concurrency test", 5, nil)
require.NoError(t, err)
// Prepare a batch of documents to trigger a slow AddDocuments call
batchDocs := make([]Document, 10)
for i := range batchDocs {
batchDocs[i] = Document{
ID: fmt.Sprintf("batch-%d", i),
Content: fmt.Sprintf("Batch document number %d for write lock test with unique content", i),
}
}
// Launch AddDocuments in background — embedding will take time
addDone := make(chan error, 1)
go func() {
addDone <- client.AddDocuments(context.Background(), batchDocs)
}()
// Give AddDocuments a moment to start embedding computation
time.Sleep(10 * time.Millisecond)
// Invalidate result cache so Query must go through to DB (tests read lock)
client.InvalidateResultCache()
// A concurrent Query should NOT be blocked by AddDocuments.
// If the old code held writeMu during embedding, this would block until
// embedding finishes. With the fix, it should complete quickly.
queryCtx, queryCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer queryCancel()
start := time.Now()
_, err = client.Query(queryCtx, "concurrency test", 5, nil)
queryDuration := time.Since(start)
require.NoError(t, err, "Query should succeed while AddDocuments is embedding")
// The query should complete well within the timeout if writeMu is not held
assert.Less(t, queryDuration, 1*time.Second,
"Query should complete quickly when writeMu is not held during embedding")
// Wait for AddDocuments to finish
err = <-addDone
require.NoError(t, err, "AddDocuments should succeed")
}
// =============================================================================
// REGRESSION TESTS FOR Fix #2a: DoChan + context select in getOrComputeEmbedding
// =============================================================================
// TestGetOrComputeEmbedding_ContextCancelDuringSingleflight verifies that when
// a singleflight embedding computation is in progress, a second caller with a
// short-lived context returns context.DeadlineExceeded promptly rather than
// waiting for the slow first call to finish.
func TestGetOrComputeEmbedding_ContextCancelDuringSingleflight(t *testing.T) {
db, dbCleanup := testDB(t)
defer dbCleanup()
embedSvc, embedCleanup := testEmbeddingService(t)
defer embedCleanup()
client, err := NewClient(Config{DB: db}, embedSvc)
require.NoError(t, err)
// Add a document so we can query
docs := []Document{
{ID: "sf-test-1", Content: "Singleflight context cancellation test document"},
}
err = client.AddDocuments(context.Background(), docs)
require.NoError(t, err)
// Clear cache to force embedding computation
client.ClearCache()
client.InvalidateResultCache()
queryText := "unique singleflight context test query"
// First call: start a normal query in background (will trigger singleflight)
firstDone := make(chan struct{})
go func() {
defer close(firstDone)
_, _ = client.Query(context.Background(), queryText, 5, nil)
}()
// Give the first call a moment to start the singleflight computation
time.Sleep(5 * time.Millisecond)
// Second call: use a very short context that should expire quickly
shortCtx, shortCancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer shortCancel()
// Clear result cache again so the second call hits getOrComputeEmbedding
client.InvalidateResultCache()
start := time.Now()
_, err = client.Query(shortCtx, queryText, 5, nil)
elapsed := time.Since(start)
// With DoChan + select, the second caller should return quickly with context error.
// Note: If the embedding completes fast enough, the second call may succeed
// via singleflight sharing. That's also valid — the test primarily checks
// that it doesn't BLOCK for the full embedding duration on context cancel.
if err != nil {
assert.ErrorIs(t, err, context.DeadlineExceeded,
"Should return DeadlineExceeded when context expires during singleflight")
assert.Less(t, elapsed, 500*time.Millisecond,
"Should return promptly on context cancellation, not wait for slow computation")
}
// If err == nil, the singleflight completed before the context expired — also valid.
// Wait for first call to finish
<-firstDone
}
+6 -2
View File
@@ -1651,11 +1651,15 @@ func (s *Service) processAllSessions() {
defer wg.Done()
defer func() { <-sem }()
// Timeout prevents a hung subprocess from blocking the queue processor forever
procCtx, procCancel := context.WithTimeout(s.ctx, 60*time.Second)
defer procCancel()
switch msg.Type {
case session.MessageTypeObservation:
if msg.Observation != nil {
err := proc.ProcessObservation(
s.ctx,
procCtx,
sess.SDKSessionID,
sess.Project,
msg.Observation.ToolName,
@@ -1674,7 +1678,7 @@ func (s *Service) processAllSessions() {
case session.MessageTypeSummarize:
if msg.Summarize != nil {
err := proc.ProcessSummary(
s.ctx,
procCtx,
sess.SessionDBID,
sess.SDKSessionID,
sess.Project,
+87 -27
View File
@@ -29,7 +29,11 @@ const (
HealthCheckTimeout = 2 * time.Second
// StartupTimeout is the timeout for worker startup.
StartupTimeout = 30 * time.Second
StartupTimeout = 10 * time.Second
// EnsureWorkerDeadline is the hard overall deadline for EnsureWorkerRunning.
// Must fit within Claude Code's hook timeout budget.
EnsureWorkerDeadline = 15 * time.Second
// workerCacheMaxAge is how long the worker cache is considered fresh.
workerCacheMaxAge = 10 * time.Second
@@ -48,6 +52,26 @@ var (
// circuitBreakerMu protects lastStartupFailure.
circuitBreakerMu sync.Mutex
lastStartupFailure time.Time
// hookClient is a shared HTTP client for hook->worker requests.
// DisableKeepAlives prevents TIME_WAIT connection leaks since each hook
// is a separate OS process that exits quickly.
hookClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
MaxIdleConns: 1,
},
}
// healthClient is a shared HTTP client for health/version checks.
healthClient = &http.Client{
Timeout: HealthCheckTimeout,
Transport: &http.Transport{
DisableKeepAlives: true,
MaxIdleConns: 1,
},
}
)
// IsWorkerAvailable performs a fast check without network calls.
@@ -86,8 +110,7 @@ func GetWorkerPort() int {
// Parses the JSON health response to check the "ready" field when available.
// Falls back to HTTP status code check for backwards compatibility.
func IsWorkerRunning(port int) bool {
client := &http.Client{Timeout: HealthCheckTimeout}
resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/api/health", port))
resp, err := healthClient.Get(fmt.Sprintf("http://127.0.0.1:%d/api/health", port))
if err != nil {
return false
}
@@ -200,7 +223,25 @@ func isWorkerRunningWithRetries(port int) bool {
// EnsureWorkerRunning ensures the worker is running, starting it if necessary.
// If a worker is already running and healthy with matching version, it reuses it.
// If version mismatch or unhealthy, it kills the old worker and starts fresh.
// A hard deadline of EnsureWorkerDeadline prevents exceeding Claude Code's hook timeout.
func EnsureWorkerRunning() (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), EnsureWorkerDeadline)
defer cancel()
return ensureWorkerRunningCtx(ctx)
}
// sleepCtx sleeps for d or returns early if ctx is cancelled.
func sleepCtx(ctx context.Context, d time.Duration) error {
select {
case <-time.After(d):
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func ensureWorkerRunningCtx(ctx context.Context) (int, error) {
port := GetWorkerPort()
// Fast path: check PID cache before making any HTTP calls.
@@ -210,6 +251,10 @@ func EnsureWorkerRunning() (int, error) {
}
}
if ctx.Err() != nil {
return 0, ctx.Err()
}
// Circuit breaker: if we failed to start recently, don't retry immediately.
circuitBreakerMu.Lock()
if !lastStartupFailure.IsZero() && time.Since(lastStartupFailure) < circuitBreakerCooldown {
@@ -232,7 +277,9 @@ func EnsureWorkerRunning() (int, error) {
if err := KillProcessOnPort(port); err != nil {
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Warning: failed to kill old worker: %v\n", err)
}
time.Sleep(500 * time.Millisecond)
if err := sleepCtx(ctx, 500*time.Millisecond); err != nil {
return 0, err
}
} else {
// Version matches, reuse existing worker
updateCacheFromPort(port)
@@ -245,14 +292,20 @@ func EnsureWorkerRunning() (int, error) {
}
}
if ctx.Err() != nil {
return 0, ctx.Err()
}
// Port is in use but health check failed -- worker may be slow, not dead.
if IsPortInUse(port) {
// The port is responding to TCP but health check timed out.
// Don't kill it -- it's likely just under load. Give it more time.
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Worker on port %d is slow to respond, waiting...\n", port)
// Try a few more times with longer delays before giving up
for i := 0; i < 3; i++ {
time.Sleep(500 * time.Millisecond)
// Try a couple more times with shorter delays before giving up
for i := 0; i < 2; i++ {
if err := sleepCtx(ctx, 300*time.Millisecond); err != nil {
return 0, err
}
if IsWorkerRunning(port) {
updateCacheFromPort(port)
return port, nil
@@ -263,7 +316,13 @@ func EnsureWorkerRunning() (int, error) {
if err := KillProcessOnPort(port); err != nil {
fmt.Fprintf(os.Stderr, "[claude-mnemonic] Warning: failed to kill unhealthy process on port %d: %v\n", port, err)
}
time.Sleep(500 * time.Millisecond)
if err := sleepCtx(ctx, 500*time.Millisecond); err != nil {
return 0, err
}
}
if ctx.Err() != nil {
return 0, ctx.Err()
}
// Find worker binary
@@ -272,8 +331,10 @@ func EnsureWorkerRunning() (int, error) {
return 0, fmt.Errorf("worker binary not found")
}
// Start worker
// Start worker -- detach from hook's process group so Claude Code
// killing the hook doesn't take the worker down with it.
cmd := exec.Command(workerPath) // #nosec G204 -- workerPath is from internal findWorkerBinary
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
@@ -286,27 +347,32 @@ func EnsureWorkerRunning() (int, error) {
pid := cmd.Process.Pid
// Wait for worker to be ready with exponential backoff
deadline := time.Now().Add(StartupTimeout)
backoff := 50 * time.Millisecond
maxBackoff := 500 * time.Millisecond
for time.Now().Before(deadline) {
for {
if ctx.Err() != nil {
circuitBreakerMu.Lock()
lastStartupFailure = time.Now()
circuitBreakerMu.Unlock()
return 0, fmt.Errorf("worker failed to start within deadline: %w", ctx.Err())
}
if IsWorkerRunning(port) {
writeWorkerCache(port, pid)
return port, nil
}
time.Sleep(backoff)
if err := sleepCtx(ctx, backoff); err != nil {
circuitBreakerMu.Lock()
lastStartupFailure = time.Now()
circuitBreakerMu.Unlock()
return 0, fmt.Errorf("worker failed to start within deadline: %w", err)
}
// Exponential backoff with cap
backoff = backoff * 2
if backoff > maxBackoff {
backoff = maxBackoff
}
}
circuitBreakerMu.Lock()
lastStartupFailure = time.Now()
circuitBreakerMu.Unlock()
return 0, fmt.Errorf("worker failed to start within timeout")
}
// updateCacheFromPort finds the PID of the process on the port and updates the cache.
@@ -330,8 +396,7 @@ func updateCacheFromPort(port int) {
// GetWorkerVersion gets the version of the running worker.
func GetWorkerVersion(port int) string {
client := &http.Client{Timeout: HealthCheckTimeout}
resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/api/version", port))
resp, err := healthClient.Get(fmt.Sprintf("http://127.0.0.1:%d/api/version", port))
if err != nil {
return ""
}
@@ -447,14 +512,12 @@ func findWorkerBinary() string {
// POST sends a POST request to the worker.
func POST(port int, path string, body interface{}) (map[string]interface{}, error) {
client := &http.Client{Timeout: 10 * time.Second}
jsonBody, err := json.Marshal(body)
if err != nil {
return nil, err
}
resp, err := client.Post(
resp, err := hookClient.Post(
fmt.Sprintf("http://127.0.0.1:%d%s", port, path),
"application/json",
bytes.NewReader(jsonBody),
@@ -493,8 +556,7 @@ func POSTWithContext(ctx context.Context, port int, path string, body interface{
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
resp, err := hookClient.Do(req)
if err != nil {
return err
}
@@ -504,9 +566,7 @@ func POSTWithContext(ctx context.Context, port int, path string, body interface{
// GET sends a GET request to the worker.
func GET(port int, path string) (map[string]interface{}, error) {
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d%s", port, path))
resp, err := hookClient.Get(fmt.Sprintf("http://127.0.0.1:%d%s", port, path))
if err != nil {
return nil, err
}
+103 -2
View File
@@ -2,6 +2,7 @@
package hooks
import (
"context"
"encoding/json"
"fmt"
"net/http"
@@ -518,7 +519,8 @@ func TestProjectIDWithName_Uniqueness(t *testing.T) {
func TestHookConstants(t *testing.T) {
assert.Equal(t, 37777, DefaultWorkerPort)
assert.Equal(t, 2*time.Second, HealthCheckTimeout)
assert.Equal(t, 30*time.Second, StartupTimeout)
assert.Equal(t, 10*time.Second, StartupTimeout)
assert.Equal(t, 15*time.Second, EnsureWorkerDeadline)
}
// TestExitCodes tests exit code constants.
@@ -1200,5 +1202,104 @@ func TestHealthCheckTimeout(t *testing.T) {
// TestStartupTimeout tests the startup timeout is reasonable.
func TestStartupTimeout(t *testing.T) {
assert.Greater(t, StartupTimeout, 5*time.Second)
assert.LessOrEqual(t, StartupTimeout, time.Minute)
assert.LessOrEqual(t, StartupTimeout, 15*time.Second)
}
// TestEnsureWorkerDeadline tests the deadline is within hook budget.
func TestEnsureWorkerDeadline(t *testing.T) {
assert.Greater(t, EnsureWorkerDeadline, StartupTimeout, "deadline must exceed startup timeout")
assert.LessOrEqual(t, EnsureWorkerDeadline, 20*time.Second, "deadline must fit in hook timeout budget")
}
// --- Regression tests for Fixes 1, 3, 4 ---
// TestEnsureWorkerRunning_RespectsDeadline verifies that EnsureWorkerRunning returns
// within a bounded time even in worst case (no worker, startup fails).
func TestEnsureWorkerRunning_RespectsDeadline(t *testing.T) {
// Reset circuit breaker so the function actually tries to start.
circuitBreakerMu.Lock()
lastStartupFailure = time.Time{}
circuitBreakerMu.Unlock()
// Use a port that nothing listens on.
t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", "19999")
// Point HOME to a temp dir so findWorkerBinary finds nothing.
t.Setenv("HOME", t.TempDir())
// Clear plugin root to avoid that path too.
t.Setenv("CLAUDE_PLUGIN_ROOT", "")
start := time.Now()
_, err := EnsureWorkerRunning()
elapsed := time.Since(start)
// Must error (no binary found).
assert.Error(t, err)
assert.Contains(t, err.Error(), "worker binary not found")
// Must complete well within EnsureWorkerDeadline.
assert.Less(t, elapsed, EnsureWorkerDeadline,
"EnsureWorkerRunning took %v, exceeding deadline %v", elapsed, EnsureWorkerDeadline)
}
// TestEnsureWorkerRunningCtx_CancelledContext verifies immediate return on cancelled context.
func TestEnsureWorkerRunningCtx_CancelledContext(t *testing.T) {
// Reset circuit breaker.
circuitBreakerMu.Lock()
lastStartupFailure = time.Time{}
circuitBreakerMu.Unlock()
t.Setenv("CLAUDE_MNEMONIC_WORKER_PORT", "19998")
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
start := time.Now()
_, err := ensureWorkerRunningCtx(ctx)
elapsed := time.Since(start)
assert.Error(t, err)
assert.Less(t, elapsed, 1*time.Second, "cancelled context should return immediately")
}
// TestSleepCtx_Normal verifies sleepCtx completes normally.
func TestSleepCtx_Normal(t *testing.T) {
ctx := context.Background()
start := time.Now()
err := sleepCtx(ctx, 50*time.Millisecond)
elapsed := time.Since(start)
assert.NoError(t, err)
assert.GreaterOrEqual(t, elapsed, 50*time.Millisecond)
}
// TestSleepCtx_Cancelled verifies sleepCtx returns early on cancel.
func TestSleepCtx_Cancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
err := sleepCtx(ctx, 5*time.Second)
elapsed := time.Since(start)
assert.Error(t, err)
assert.Less(t, elapsed, 500*time.Millisecond)
}
// TestHookClients_DisableKeepAlives asserts shared clients disable keep-alives
// to prevent TIME_WAIT connection leaks in short-lived hook processes.
func TestHookClients_DisableKeepAlives(t *testing.T) {
hTransport, ok := hookClient.Transport.(*http.Transport)
require.True(t, ok, "hookClient.Transport should be *http.Transport")
assert.True(t, hTransport.DisableKeepAlives, "hookClient must disable keep-alives")
assert.Equal(t, 1, hTransport.MaxIdleConns)
hcTransport, ok := healthClient.Transport.(*http.Transport)
require.True(t, ok, "healthClient.Transport should be *http.Transport")
assert.True(t, hcTransport.DisableKeepAlives, "healthClient must disable keep-alives")
assert.Equal(t, 1, hcTransport.MaxIdleConns)
}
// TestHookClient_Timeout verifies hookClient timeout is set.
func TestHookClient_Timeout(t *testing.T) {
assert.Equal(t, 10*time.Second, hookClient.Timeout)
assert.Equal(t, HealthCheckTimeout, healthClient.Timeout)
}