mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-09 23:59:40 +00:00
fix: prevent MCP server hanging by adding concurrency, timeouts, and context propagation (#45)
Root cause: synchronous MCP request processing, combined with missing context propagation to the embedding layer, caused indefinite hangs when ONNX inference was slow or the database was contended.

Changes:
- MCP server: dispatch each request in its own goroutine, bounded by a semaphore (cap 10), with a WaitGroup so shutdown drains in-flight requests cleanly
- Embedding: add context-aware mutex acquisition (acquireMutex) so callers can bail out instead of blocking forever on a stuck ONNX model
- Vector client: propagate context through getOrComputeEmbedding and replace bare RLock() calls with the context-aware acquireRLockWithContext
- Worker handlers: add 15s request-scoped timeouts to all search/context handlers (handleSearchByPrompt, handleContextInject, handleFileContext, handleContextCount, handleGetObservations/Summaries/Prompts)
- Worker HTTP server: set WriteTimeout=60s (was 0); the SSE endpoint extends its deadline per request via http.ResponseController

Fixes #45
This commit is contained in:
@@ -271,7 +271,7 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
||||
}
|
||||
|
||||
// Generate query embedding OUTSIDE the lock for better concurrency
|
||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
||||
queryEmb, err := c.getOrComputeEmbedding(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
@@ -282,8 +282,10 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||
}
|
||||
|
||||
// Now acquire read lock for the actual DB query
|
||||
c.readMu.RLock()
|
||||
// Acquire read lock with context awareness to prevent indefinite blocking
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
// Build query with filters
|
||||
@@ -485,7 +487,7 @@ func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, wh
|
||||
// Combines results from different field types and deduplicates by document ID.
|
||||
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
|
||||
// Generate embedding once
|
||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
||||
queryEmb, err := c.getOrComputeEmbedding(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
@@ -496,7 +498,9 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
|
||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||
}
|
||||
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
// Query with field type aggregation - get best match per document
|
||||
@@ -555,6 +559,28 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// acquireRLockWithContext acquires a read lock on mu, giving up when ctx is
// cancelled or its deadline passes.
//
// On a nil return the caller holds the read lock and must call mu.RUnlock.
// On a non-nil return (always ctx.Err()) the caller does NOT hold the lock
// and must not unlock it.
//
// sync.RWMutex has no cancellable lock operation, so a contended acquisition
// is performed by a helper goroutine. If the caller gives up first, that
// goroutine stays blocked in RLock until the lock becomes available and then
// releases it immediately, so an abandoned wait never leaks a held read lock.
func acquireRLockWithContext(ctx context.Context, mu *sync.RWMutex) error {
	// A context that is already done must never take the lock; checking up
	// front also avoids queueing a pointless acquisition.
	if err := ctx.Err(); err != nil {
		return err
	}

	// Fast path: when the lock is free, take it without spawning a goroutine.
	// TryRLock fails while a writer holds or is waiting for the lock, so
	// writer priority is preserved exactly as with a plain RLock.
	if mu.TryRLock() {
		return nil
	}

	// Slow path: a helper goroutine blocks in RLock on the caller's behalf.
	// Ownership of the acquired lock is handed over through the unbuffered
	// channel handoff; closing abandoned tells the helper nobody is waiting.
	handoff := make(chan struct{})
	abandoned := make(chan struct{})

	go func() {
		mu.RLock()
		select {
		case handoff <- struct{}{}:
			// The caller received the hand-off and now owns the read lock.
		case <-abandoned:
			// The caller gave up; release the lock it never received.
			mu.RUnlock()
		}
	}()

	select {
	case <-handoff:
		return nil
	case <-ctx.Done():
		// The send on handoff can only complete if this goroutine receives
		// it, so after this point the helper is guaranteed to take the
		// abandoned branch and unlock.
		close(abandoned)
		return ctx.Err()
	}
}
|
||||
|
||||
// truncateString truncates a string to maxLen characters.
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
@@ -565,7 +591,9 @@ func truncateString(s string, maxLen int) string {
|
||||
|
||||
// Count returns the total number of vectors in the store.
|
||||
func (c *Client) Count(ctx context.Context) (int64, error) {
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return 0, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
var count int64
|
||||
@@ -586,7 +614,9 @@ func (c *Client) ModelVersion() string {
|
||||
// - The vectors table is empty
|
||||
// - Any vectors have a different model_version than the current model
|
||||
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return false, ""
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
@@ -634,7 +664,9 @@ type StaleVectorInfo struct {
|
||||
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
|
||||
// This enables granular rebuild - only re-embedding documents that need updating.
|
||||
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
@@ -692,7 +724,9 @@ type VectorHealthStats struct {
|
||||
|
||||
// GetHealthStats returns comprehensive health statistics about the vector store.
|
||||
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
stats := &VectorHealthStats{
|
||||
@@ -856,7 +890,9 @@ func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error {
|
||||
|
||||
// getOrComputeEmbedding returns a cached embedding or computes a new one.
|
||||
// Uses singleflight to prevent duplicate concurrent computations for the same query.
|
||||
func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
|
||||
// The context controls timeout on the embedding mutex acquisition -- if the ONNX model
|
||||
// hangs under CGO, callers can bail out instead of blocking forever.
|
||||
func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]float32, error) {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
// Check cache first (read lock)
|
||||
@@ -885,8 +921,8 @@ func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
|
||||
// Record cache miss
|
||||
c.stats.embeddingMisses.Add(1)
|
||||
|
||||
// Compute embedding
|
||||
emb, err := c.embedSvc.Embed(query)
|
||||
// Compute embedding with context-aware lock acquisition
|
||||
emb, err := c.embedSvc.EmbedWithContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user