fix: prevent MCP server hanging by adding concurrency, timeouts, and context propagation (#45)

Root cause: synchronous MCP request processing combined with missing
context propagation to the embedding layer caused indefinite hangs when
ONNX inference was slow or the database was contended.

Changes:
- MCP server: dispatch each request in its own goroutine, bounded by a
  semaphore (cap 10) and tracked by a WaitGroup so that shutdown can
  drain in-flight requests cleanly
- Embedding: add context-aware mutex acquisition (acquireMutex) so
  callers can bail out instead of blocking forever on a stuck ONNX model
- Vector client: propagate context through getOrComputeEmbedding and
  replace bare RLock() calls with context-aware acquireRLockWithContext
- Worker handlers: add 15s request-scoped timeouts to all search/context
  handlers (handleSearchByPrompt, handleContextInject, handleFileContext,
  handleContextCount, handleGetObservations, handleGetSummaries,
  handleGetPrompts)
- Worker HTTP server: set WriteTimeout=60s (was 0); the SSE endpoint
  clears the write deadline per request via http.ResponseController
  because its connections are long-lived

Fixes #45
This commit is contained in:
2026-05-26 12:34:36 +01:00
parent 56616d0616
commit 29d57857ff
7 changed files with 244 additions and 40 deletions
+22 -8
View File
@@ -23,6 +23,10 @@ import (
// IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return.
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
cwd := r.URL.Query().Get("cwd")
@@ -54,7 +58,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
var expandedQueries []expansion.ExpandedQuery
var detectedIntent string
if s.queryExpander != nil {
expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second)
expandCtx, expandCancel := context.WithTimeout(ctx, 5*time.Second)
cfg := expansion.DefaultConfig()
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
@@ -83,7 +87,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
var vectorErrors int
for _, eq := range expandedQueries {
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
vectorResults, vecErr := s.vectorClient.Query(ctx, eq.Query, limit*2, where)
if vecErr != nil {
vectorErrors++
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
@@ -125,7 +129,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
if len(obsIDs) > 0 {
// Fetch full observations from SQLite
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", limit)
if err == nil {
usedVector = true
}
@@ -138,11 +142,11 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
if vectorSearchFailed {
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
}
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
observations, err = s.observationStore.SearchObservationsFTS(ctx, query, project, limit)
if err != nil {
// FTS might fail if query has special chars, try without
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
observations, err = s.observationStore.GetRecentObservations(ctx, project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@@ -354,7 +358,9 @@ func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
}
// Search for observations related to each file in parallel
ctx := r.Context()
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, ctxCancel := context.WithTimeout(r.Context(), 15*time.Second)
defer ctxCancel()
// Check if vector search is available
if s.vectorClient == nil || !s.vectorClient.IsConnected() {
@@ -562,6 +568,10 @@ func splitCamelCase(s string) string {
// IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return.
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
@@ -592,7 +602,7 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
}
// Get recent observations
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
observations, err := s.observationStore.GetRecentObservations(ctx, project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@@ -658,13 +668,17 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
// handleContextCount returns the count of observations for a project.
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
return
}
count, err := s.getCachedObservationCount(r.Context(), project)
count, err := s.getCachedObservationCount(ctx, project)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
+25 -12
View File
@@ -2,6 +2,7 @@
package worker
import (
"context"
"encoding/json"
"net/http"
"runtime"
@@ -20,6 +21,10 @@ import (
// Supports optional query parameter for semantic search via sqlite-vec.
// Supports pagination via limit and offset query parameters.
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
@@ -38,11 +43,11 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
// Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where)
vectorResults, vecErr := s.vectorClient.Query(ctx, query, pagination.Limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
if len(obsIDs) > 0 {
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit)
observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", pagination.Limit)
if err == nil {
usedVector = true
total = int64(len(observations)) // Vector search doesn't have total, use returned count
@@ -55,10 +60,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
if !usedVector {
if project != "" {
// Strict project filtering for dashboard - only observations from this project
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset)
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, pagination.Limit, pagination.Offset)
} else {
// All projects
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset)
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, pagination.Limit, pagination.Offset)
}
}
@@ -90,6 +95,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
// handleGetSummaries returns recent summaries.
// Supports optional query parameter for semantic search via sqlite-vec.
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
@@ -107,11 +116,11 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
// Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
if len(summaryIDs) > 0 {
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit)
summaries, err = s.summaryStore.GetSummariesByIDs(ctx, summaryIDs, "date_desc", limit)
if err == nil {
usedVector = true
}
@@ -122,9 +131,9 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
// Fall back to SQLite if vector search not used
if !usedVector {
if project != "" {
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
summaries, err = s.summaryStore.GetRecentSummaries(ctx, project, limit)
} else {
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
summaries, err = s.summaryStore.GetAllRecentSummaries(ctx, limit)
}
}
@@ -143,6 +152,10 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
// handleGetPrompts returns recent user prompts.
// Supports optional query parameter for semantic search via sqlite-vec.
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
@@ -160,11 +173,11 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
// Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
if len(promptIDs) > 0 {
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit)
prompts, err = s.promptStore.GetPromptsByIDs(ctx, promptIDs, "date_desc", limit)
if err == nil {
usedVector = true
}
@@ -175,9 +188,9 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
// Fall back to SQLite if vector search not used
if !usedVector {
if project != "" {
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
prompts, err = s.promptStore.GetRecentUserPromptsByProject(ctx, project, limit)
} else {
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
prompts, err = s.promptStore.GetAllRecentUserPrompts(ctx, limit)
}
}
+7 -2
View File
@@ -1246,7 +1246,12 @@ func (s *Service) setupRoutes() {
s.router.Get("/api/selfcheck", s.handleSelfCheck)
// SSE endpoint (works before DB is ready)
s.router.Get("/api/events", s.sseBroadcaster.HandleSSE)
// Wrap with middleware that disables write deadline since SSE connections are long-lived
s.router.Get("/api/events", func(w http.ResponseWriter, r *http.Request) {
rc := http.NewResponseController(w)
_ = rc.SetWriteDeadline(time.Time{}) // no deadline for SSE
s.sseBroadcaster.HandleSSE(w, r)
})
// Routes that require DB to be ready
s.router.Group(func(r chi.Router) {
@@ -1546,7 +1551,7 @@ func (s *Service) Start() error {
Handler: s.router,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 0, // Disabled for SSE (long-lived connections)
WriteTimeout: 60 * time.Second, // Default for API routes; SSE handler extends per-request
IdleTimeout: 120 * time.Second,
}