mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
fix: prevent MCP server hanging by adding concurrency, timeouts, and context propagation (#45)
Root cause: synchronous MCP request processing combined with missing context propagation to the embedding layer caused indefinite hangs when ONNX inference was slow or the database was contended. Changes: - MCP server: dispatch each request in its own goroutine with semaphore (cap 10) and WaitGroup for clean shutdown drain - Embedding: add context-aware mutex acquisition (acquireMutex) so callers can bail out instead of blocking forever on a stuck ONNX model - Vector client: propagate context through getOrComputeEmbedding and replace bare RLock() calls with context-aware acquireRLockWithContext - Worker handlers: add 15s request-scoped timeouts to all search/context handlers (handleSearchByPrompt, handleContextInject, handleFileContext, handleContextCount, handleGetObservations/Summaries/Prompts) - Worker HTTP server: set WriteTimeout=60s (was 0); SSE endpoint extends deadline per-request via http.ResponseController Fixes #45
This commit is contained in:
@@ -23,6 +23,10 @@ import (
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
cwd := r.URL.Query().Get("cwd")
|
||||
@@ -54,7 +58,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
var expandedQueries []expansion.ExpandedQuery
|
||||
var detectedIntent string
|
||||
if s.queryExpander != nil {
|
||||
expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
expandCtx, expandCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
cfg := expansion.DefaultConfig()
|
||||
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
|
||||
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
|
||||
@@ -83,7 +87,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
var vectorErrors int
|
||||
|
||||
for _, eq := range expandedQueries {
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, eq.Query, limit*2, where)
|
||||
if vecErr != nil {
|
||||
vectorErrors++
|
||||
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
|
||||
@@ -125,7 +129,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if len(obsIDs) > 0 {
|
||||
// Fetch full observations from SQLite
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
|
||||
observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
@@ -138,11 +142,11 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
if vectorSearchFailed {
|
||||
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
|
||||
}
|
||||
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
|
||||
observations, err = s.observationStore.SearchObservationsFTS(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
// FTS might fail if query has special chars, try without
|
||||
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
|
||||
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
observations, err = s.observationStore.GetRecentObservations(ctx, project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -354,7 +358,9 @@ func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Search for observations related to each file in parallel
|
||||
ctx := r.Context()
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, ctxCancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer ctxCancel()
|
||||
|
||||
// Check if vector search is available
|
||||
if s.vectorClient == nil || !s.vectorClient.IsConnected() {
|
||||
@@ -562,6 +568,10 @@ func splitCamelCase(s string) string {
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
@@ -592,7 +602,7 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Get recent observations
|
||||
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
observations, err := s.observationStore.GetRecentObservations(ctx, project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -658,13 +668,17 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleContextCount returns the count of observations for a project.
|
||||
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
count, err := s.getCachedObservationCount(r.Context(), project)
|
||||
count, err := s.getCachedObservationCount(ctx, project)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"runtime"
|
||||
@@ -20,6 +21,10 @@ import (
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
// Supports pagination via limit and offset query parameters.
|
||||
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
@@ -38,11 +43,11 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where)
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, query, pagination.Limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
||||
if len(obsIDs) > 0 {
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit)
|
||||
observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", pagination.Limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
total = int64(len(observations)) // Vector search doesn't have total, use returned count
|
||||
@@ -55,10 +60,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
// Strict project filtering for dashboard - only observations from this project
|
||||
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset)
|
||||
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, pagination.Limit, pagination.Offset)
|
||||
} else {
|
||||
// All projects
|
||||
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset)
|
||||
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, pagination.Limit, pagination.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +95,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
||||
// handleGetSummaries returns recent summaries.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
@@ -107,11 +116,11 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
|
||||
if len(summaryIDs) > 0 {
|
||||
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit)
|
||||
summaries, err = s.summaryStore.GetSummariesByIDs(ctx, summaryIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
@@ -122,9 +131,9 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
|
||||
summaries, err = s.summaryStore.GetRecentSummaries(ctx, project, limit)
|
||||
} else {
|
||||
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
|
||||
summaries, err = s.summaryStore.GetAllRecentSummaries(ctx, limit)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,6 +152,10 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
// handleGetPrompts returns recent user prompts.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
@@ -160,11 +173,11 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
|
||||
if len(promptIDs) > 0 {
|
||||
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit)
|
||||
prompts, err = s.promptStore.GetPromptsByIDs(ctx, promptIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
@@ -175,9 +188,9 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
|
||||
prompts, err = s.promptStore.GetRecentUserPromptsByProject(ctx, project, limit)
|
||||
} else {
|
||||
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
|
||||
prompts, err = s.promptStore.GetAllRecentUserPrompts(ctx, limit)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1246,7 +1246,12 @@ func (s *Service) setupRoutes() {
|
||||
s.router.Get("/api/selfcheck", s.handleSelfCheck)
|
||||
|
||||
// SSE endpoint (works before DB is ready)
|
||||
s.router.Get("/api/events", s.sseBroadcaster.HandleSSE)
|
||||
// Wrap with middleware that disables write deadline since SSE connections are long-lived
|
||||
s.router.Get("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
||||
rc := http.NewResponseController(w)
|
||||
_ = rc.SetWriteDeadline(time.Time{}) // no deadline for SSE
|
||||
s.sseBroadcaster.HandleSSE(w, r)
|
||||
})
|
||||
|
||||
// Routes that require DB to be ready
|
||||
s.router.Group(func(r chi.Router) {
|
||||
@@ -1546,7 +1551,7 @@ func (s *Service) Start() error {
|
||||
Handler: s.router,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 0, // Disabled for SSE (long-lived connections)
|
||||
WriteTimeout: 60 * time.Second, // Default for API routes; SSE handler extends per-request
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user