fix: prevent MCP server hanging by adding concurrency, timeouts, and context propagation (#45)

Root cause: synchronous MCP request processing combined with missing
context propagation to the embedding layer caused indefinite hangs when
ONNX inference was slow or the database was contended.

Changes:
- MCP server: dispatch each request in its own goroutine with semaphore
  (cap 10) and WaitGroup for clean shutdown drain
- Embedding: add context-aware mutex acquisition (acquireMutex) so
  callers can bail out instead of blocking forever on a stuck ONNX model
- Vector client: propagate context through getOrComputeEmbedding and
  replace bare RLock() calls with context-aware acquireRLockWithContext
- Worker handlers: add 15s request-scoped timeouts to all search/context
  handlers (handleSearchByPrompt, handleContextInject, handleFileContext,
  handleContextCount, handleGetObservations/Summaries/Prompts)
- Worker HTTP server: set WriteTimeout=60s (was 0); SSE endpoint extends
  deadline per-request via http.ResponseController

Fixes #45
This commit is contained in:
2026-05-26 12:34:36 +01:00
parent 56616d0616
commit 29d57857ff
7 changed files with 244 additions and 40 deletions
+48 -12
View File
@@ -271,7 +271,7 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
}
// Generate query embedding OUTSIDE the lock for better concurrency
queryEmb, err := c.getOrComputeEmbedding(query)
queryEmb, err := c.getOrComputeEmbedding(ctx, query)
if err != nil {
return nil, fmt.Errorf("embed query: %w", err)
}
@@ -282,8 +282,10 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
return nil, fmt.Errorf("serialize query embedding: %w", err)
}
// Now acquire read lock for the actual DB query
c.readMu.RLock()
// Acquire read lock with context awareness to prevent indefinite blocking
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return nil, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock()
// Build query with filters
@@ -485,7 +487,7 @@ func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, wh
// Combines results from different field types and deduplicates by document ID.
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
// Generate embedding once
queryEmb, err := c.getOrComputeEmbedding(query)
queryEmb, err := c.getOrComputeEmbedding(ctx, query)
if err != nil {
return nil, fmt.Errorf("embed query: %w", err)
}
@@ -496,7 +498,9 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
return nil, fmt.Errorf("serialize query embedding: %w", err)
}
c.readMu.RLock()
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return nil, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock()
// Query with field type aggregation - get best match per document
@@ -555,6 +559,28 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
return results, rows.Err()
}
// acquireRLockWithContext acquires a read lock on mu, respecting ctx cancellation.
// If ctx is cancelled while waiting for the lock, the goroutine that eventually
// acquires it will release it immediately to prevent leaks.
func acquireRLockWithContext(ctx context.Context, mu *sync.RWMutex) error {
acquired := make(chan struct{})
go func() {
mu.RLock()
close(acquired)
}()
select {
case <-acquired:
return nil
case <-ctx.Done():
// Goroutine may still acquire lock after ctx cancelled — must unlock
go func() {
<-acquired
mu.RUnlock()
}()
return ctx.Err()
}
}
// truncateString truncates a string to maxLen characters.
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
@@ -565,7 +591,9 @@ func truncateString(s string, maxLen int) string {
// Count returns the total number of vectors in the store.
func (c *Client) Count(ctx context.Context) (int64, error) {
c.readMu.RLock()
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return 0, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock()
var count int64
@@ -586,7 +614,9 @@ func (c *Client) ModelVersion() string {
// - The vectors table is empty
// - Any vectors have a different model_version than the current model
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
c.readMu.RLock()
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return false, ""
}
defer c.readMu.RUnlock()
currentModel := c.embedSvc.Version()
@@ -634,7 +664,9 @@ type StaleVectorInfo struct {
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
// This enables granular rebuild - only re-embedding documents that need updating.
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
c.readMu.RLock()
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return nil, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock()
currentModel := c.embedSvc.Version()
@@ -692,7 +724,9 @@ type VectorHealthStats struct {
// GetHealthStats returns comprehensive health statistics about the vector store.
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
c.readMu.RLock()
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return nil, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock()
stats := &VectorHealthStats{
@@ -856,7 +890,9 @@ func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error {
// getOrComputeEmbedding returns a cached embedding or computes a new one.
// Uses singleflight to prevent duplicate concurrent computations for the same query.
func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
// The context controls timeout on the embedding mutex acquisition -- if the ONNX model
// hangs under CGO, callers can bail out instead of blocking forever.
func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]float32, error) {
now := time.Now().UnixNano()
// Check cache first (read lock)
@@ -885,8 +921,8 @@ func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
// Record cache miss
c.stats.embeddingMisses.Add(1)
// Compute embedding
emb, err := c.embedSvc.Embed(query)
// Compute embedding with context-aware lock acquisition
emb, err := c.embedSvc.EmbedWithContext(ctx, query)
if err != nil {
return nil, err
}