From 29d57857ff81318e0c8e2ac3c93590c6f9403733 Mon Sep 17 00:00:00 2001
From: Lukasz Raczylo
Date: Tue, 26 May 2026 12:34:36 +0100
Subject: [PATCH] fix: prevent MCP server hanging by adding concurrency,
 timeouts, and context propagation (#45)

Root cause: synchronous MCP request processing combined with missing
context propagation to the embedding layer caused indefinite hangs when
ONNX inference was slow or the database was contended.

Changes:
- MCP server: dispatch each request in its own goroutine with semaphore
  (cap 10) and WaitGroup for clean shutdown drain
- Embedding: add context-aware mutex acquisition (acquireMutex) so
  callers can bail out instead of blocking forever on a stuck ONNX model
- Vector client: propagate context through getOrComputeEmbedding and
  replace bare RLock() calls with context-aware acquireRLockWithContext
- Worker handlers: add 15s request-scoped timeouts to all search/context
  handlers (handleSearchByPrompt, handleContextInject, handleFileContext,
  handleContextCount, handleGetObservations/Summaries/Prompts)
- Worker HTTP server: set WriteTimeout=60s (was 0); SSE endpoint extends
  deadline per-request via http.ResponseController

Fixes #45
---
 internal/embedding/model.go         |   9 +++
 internal/embedding/service.go       | 124 ++++++++++++++++++++++++++++
 internal/mcp/server.go              |  30 ++++++--
 internal/vector/sqlitevec/client.go |  60 ++++++++++++---
 internal/worker/handlers_context.go |  30 ++++++--
 internal/worker/handlers_data.go    |  37 +++++++---
 internal/worker/service.go          |   9 ++-
 7 files changed, 259 insertions(+), 40 deletions(-)

diff --git a/internal/embedding/model.go b/internal/embedding/model.go
index 2c1124d..c30caa2 100644
--- a/internal/embedding/model.go
+++ b/internal/embedding/model.go
@@ -2,6 +2,7 @@
 package embedding
 
 import (
+	"context"
 	"fmt"
 	"sync"
 )
@@ -44,6 +45,14 @@ type EmbeddingModel interface {
 	// EmbedBatch generates embeddings for multiple texts.
 	EmbedBatch(texts []string) ([][]float32, error)
 
+	// EmbedWithContext generates an embedding for a single text with context-aware cancellation.
+	// The context controls mutex acquisition timeout — if ctx is cancelled while waiting
+	// for the model lock, the call returns immediately with ctx.Err().
+	EmbedWithContext(ctx context.Context, text string) ([]float32, error)
+
+	// EmbedBatchWithContext generates embeddings for multiple texts with context-aware cancellation.
+	EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error)
+
 	// Close releases model resources.
 	Close() error
 }
diff --git a/internal/embedding/service.go b/internal/embedding/service.go
index 4d0078a..0fec133 100644
--- a/internal/embedding/service.go
+++ b/internal/embedding/service.go
@@ -3,6 +3,7 @@ package embedding
 
 import (
 	"bytes"
+	"context"
 	"crypto/sha256"
 	"encoding/hex"
 	"fmt"
@@ -225,6 +226,118 @@ func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
 	return results, nil
 }
 
+// acquireMutex acquires the model mutex, giving up if ctx is cancelled first.
+// On success the caller MUST call the returned unlock function exactly once.
+// On error the mutex is not held and unlock is nil. Only the wait is bounded:
+// a holder stuck inside inference keeps the mutex until it returns.
+func (m *bgeModel) acquireMutex(ctx context.Context) (unlock func(), err error) {
+	// select chooses randomly among ready cases, so without this check an
+	// already-cancelled ctx could still win a free lock and start inference.
+	if cerr := ctx.Err(); cerr != nil {
+		return nil, cerr
+	}
+
+	granted := make(chan struct{})   // waiter -> caller: the lock is yours
+	abandoned := make(chan struct{}) // closed once the caller stops waiting
+	go func() {
+		m.mu.Lock()
+		select {
+		case granted <- struct{}{}:
+			// Ownership passes to the caller, who releases it via unlock.
+		case <-abandoned:
+			// The caller gave up; release here so the lock is not leaked.
+			m.mu.Unlock()
+		}
+	}()
+
+	select {
+	case <-granted:
+		return m.mu.Unlock, nil
+	case <-ctx.Done():
+		close(abandoned)
+		return nil, ctx.Err()
+	}
+}
+
+// EmbedWithContext generates an embedding for a single text with context-aware mutex acquisition.
+// An empty text yields a zero vector without taking the lock. If ctx is cancelled while
+// waiting for the model lock, it returns an error wrapping ctx.Err().
+func (m *bgeModel) EmbedWithContext(ctx context.Context, text string) ([]float32, error) {
+	if text == "" {
+		return make([]float32, EmbeddingDim), nil
+	}
+
+	unlock, err := m.acquireMutex(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("acquire embedding lock: %w", err)
+	}
+	defer unlock()
+
+	results, err := m.computeBatch([]string{text})
+	if err != nil {
+		return nil, err
+	}
+	if len(results) == 0 {
+		return make([]float32, EmbeddingDim), nil
+	}
+	return results[0], nil
+}
+
+// EmbedBatchWithContext generates embeddings for multiple texts with context-aware mutex acquisition.
+// The result has one vector per input, in input order; empty texts yield zero vectors.
+// If ctx is cancelled while waiting for the model lock, it returns an error wrapping ctx.Err().
+func (m *bgeModel) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) {
+	if len(texts) == 0 {
+		return nil, nil
+	}
+
+	// Filter out empty texts and track indices
+	nonEmpty := make([]string, 0, len(texts))
+	indices := make([]int, 0, len(texts))
+	for i, t := range texts {
+		if t != "" {
+			nonEmpty = append(nonEmpty, t)
+			indices = append(indices, i)
+		}
+	}
+
+	// If all texts are empty, return zero vectors
+	if len(nonEmpty) == 0 {
+		results := make([][]float32, len(texts))
+		for i := range results {
+			results[i] = make([]float32, EmbeddingDim)
+		}
+		return results, nil
+	}
+
+	unlock, err := m.acquireMutex(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("acquire embedding lock: %w", err)
+	}
+	defer unlock()
+
+	// Compute embeddings for non-empty texts
+	embeddings, err := m.computeBatch(nonEmpty)
+	if err != nil {
+		return nil, fmt.Errorf("compute batch embeddings: %w", err)
+	}
+	// A short result would make the indexed copy below panic; fail cleanly instead.
+	if len(embeddings) != len(nonEmpty) {
+		return nil, fmt.Errorf("compute batch embeddings: got %d vectors for %d texts", len(embeddings), len(nonEmpty))
+	}
+
+	// Build result with zero vectors for empty texts
+	results := make([][]float32, len(texts))
+	for i := range results {
+		results[i] = make([]float32, EmbeddingDim)
+	}
+	for i, idx := range indices {
+		results[idx] = embeddings[i]
+	}
+
+	return results, nil
+}
+
 // computeBatch runs inference on a batch of texts. Must be called with lock held.
func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) { if len(sentences) == 0 { @@ -509,6 +607,17 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) { return s.model.EmbedBatch(texts) } +// EmbedWithContext generates an embedding with context-aware cancellation. +// If ctx is cancelled while waiting for the model lock, returns immediately. +func (s *Service) EmbedWithContext(ctx context.Context, text string) ([]float32, error) { + return s.model.EmbedWithContext(ctx, text) +} + +// EmbedBatchWithContext generates embeddings with context-aware cancellation. +func (s *Service) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) { + return s.model.EmbedBatchWithContext(ctx, texts) +} + // Close releases model resources. func (s *Service) Close() error { return s.model.Close() diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 974e486..0e84aa5 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -129,13 +129,22 @@ func (s *Server) Run(ctx context.Context) error { } }() + // Semaphore limits concurrent request goroutines. + const maxConcurrent = 10 + sem := make(chan struct{}, maxConcurrent) + + var wg sync.WaitGroup + for { select { case <-ctx.Done(): + // Drain in-flight requests before returning. + wg.Wait() return ctx.Err() case line, ok := <-lines: if !ok { - // Scanner finished, check for errors + // Scanner finished — drain in-flight requests, then check for errors. + wg.Wait() err := <-scanErr if err != nil { if errors.Is(err, bufio.ErrTooLong) { @@ -154,18 +163,27 @@ func (s *Server) Run(ctx context.Context) error { var req Request if err := json.Unmarshal([]byte(line), &req); err != nil { + // Parse errors are cheap — send inline, no goroutine needed. 
if werr := s.sendError(nil, -32700, "Parse error", err.Error()); werr != nil { return fmt.Errorf("write error: %w", werr) } continue } - resp := s.handleRequest(ctx, &req) - if resp != nil { - if werr := s.sendResponse(resp); werr != nil { - return fmt.Errorf("write error: %w", werr) + // Dispatch request to its own goroutine. + wg.Add(1) + sem <- struct{}{} // acquire semaphore slot + go func(r Request) { + defer wg.Done() + defer func() { <-sem }() // release semaphore slot + + resp := s.handleRequest(ctx, &r) + if resp != nil { + if werr := s.sendResponse(resp); werr != nil { + log.Error().Err(werr).Msg("Failed to send response") + } } - } + }(req) } } } diff --git a/internal/vector/sqlitevec/client.go b/internal/vector/sqlitevec/client.go index 76f194a..d75fb28 100644 --- a/internal/vector/sqlitevec/client.go +++ b/internal/vector/sqlitevec/client.go @@ -271,7 +271,7 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s } // Generate query embedding OUTSIDE the lock for better concurrency - queryEmb, err := c.getOrComputeEmbedding(query) + queryEmb, err := c.getOrComputeEmbedding(ctx, query) if err != nil { return nil, fmt.Errorf("embed query: %w", err) } @@ -282,8 +282,10 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s return nil, fmt.Errorf("serialize query embedding: %w", err) } - // Now acquire read lock for the actual DB query - c.readMu.RLock() + // Acquire read lock with context awareness to prevent indefinite blocking + if err := acquireRLockWithContext(ctx, &c.readMu); err != nil { + return nil, fmt.Errorf("acquire read lock: %w", err) + } defer c.readMu.RUnlock() // Build query with filters @@ -485,7 +487,7 @@ func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, wh // Combines results from different field types and deduplicates by document ID. 
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) { // Generate embedding once - queryEmb, err := c.getOrComputeEmbedding(query) + queryEmb, err := c.getOrComputeEmbedding(ctx, query) if err != nil { return nil, fmt.Errorf("embed query: %w", err) } @@ -496,7 +498,9 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d return nil, fmt.Errorf("serialize query embedding: %w", err) } - c.readMu.RLock() + if err := acquireRLockWithContext(ctx, &c.readMu); err != nil { + return nil, fmt.Errorf("acquire read lock: %w", err) + } defer c.readMu.RUnlock() // Query with field type aggregation - get best match per document @@ -555,6 +559,28 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d return results, rows.Err() } +// acquireRLockWithContext acquires a read lock on mu, respecting ctx cancellation. +// If ctx is cancelled while waiting for the lock, the goroutine that eventually +// acquires it will release it immediately to prevent leaks. +func acquireRLockWithContext(ctx context.Context, mu *sync.RWMutex) error { + acquired := make(chan struct{}) + go func() { + mu.RLock() + close(acquired) + }() + select { + case <-acquired: + return nil + case <-ctx.Done(): + // Goroutine may still acquire lock after ctx cancelled — must unlock + go func() { + <-acquired + mu.RUnlock() + }() + return ctx.Err() + } +} + // truncateString truncates a string to maxLen characters. func truncateString(s string, maxLen int) string { if len(s) <= maxLen { @@ -565,7 +591,9 @@ func truncateString(s string, maxLen int) string { // Count returns the total number of vectors in the store. 
func (c *Client) Count(ctx context.Context) (int64, error) { - c.readMu.RLock() + if err := acquireRLockWithContext(ctx, &c.readMu); err != nil { + return 0, fmt.Errorf("acquire read lock: %w", err) + } defer c.readMu.RUnlock() var count int64 @@ -586,7 +614,9 @@ func (c *Client) ModelVersion() string { // - The vectors table is empty // - Any vectors have a different model_version than the current model func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) { - c.readMu.RLock() + if err := acquireRLockWithContext(ctx, &c.readMu); err != nil { + return false, "" + } defer c.readMu.RUnlock() currentModel := c.embedSvc.Version() @@ -634,7 +664,9 @@ type StaleVectorInfo struct { // GetStaleVectors returns doc_ids of vectors with mismatched or null model versions. // This enables granular rebuild - only re-embedding documents that need updating. func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) { - c.readMu.RLock() + if err := acquireRLockWithContext(ctx, &c.readMu); err != nil { + return nil, fmt.Errorf("acquire read lock: %w", err) + } defer c.readMu.RUnlock() currentModel := c.embedSvc.Version() @@ -692,7 +724,9 @@ type VectorHealthStats struct { // GetHealthStats returns comprehensive health statistics about the vector store. func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) { - c.readMu.RLock() + if err := acquireRLockWithContext(ctx, &c.readMu); err != nil { + return nil, fmt.Errorf("acquire read lock: %w", err) + } defer c.readMu.RUnlock() stats := &VectorHealthStats{ @@ -856,7 +890,9 @@ func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error { // getOrComputeEmbedding returns a cached embedding or computes a new one. // Uses singleflight to prevent duplicate concurrent computations for the same query. 
-func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) { +// The context controls timeout on the embedding mutex acquisition -- if the ONNX model +// hangs under CGO, callers can bail out instead of blocking forever. +func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]float32, error) { now := time.Now().UnixNano() // Check cache first (read lock) @@ -885,8 +921,8 @@ func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) { // Record cache miss c.stats.embeddingMisses.Add(1) - // Compute embedding - emb, err := c.embedSvc.Embed(query) + // Compute embedding with context-aware lock acquisition + emb, err := c.embedSvc.EmbedWithContext(ctx, query) if err != nil { return nil, err } diff --git a/internal/worker/handlers_context.go b/internal/worker/handlers_context.go index c69e905..9b02ba4 100644 --- a/internal/worker/handlers_context.go +++ b/internal/worker/handlers_context.go @@ -23,6 +23,10 @@ import ( // IMPORTANT: This is on the critical startup path - must be fast! // No synchronous verification - just filter by staleness and return. 
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) { + // Add request-scoped timeout to prevent indefinite blocking on DB operations + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + project := r.URL.Query().Get("project") query := r.URL.Query().Get("query") cwd := r.URL.Query().Get("cwd") @@ -54,7 +58,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) { var expandedQueries []expansion.ExpandedQuery var detectedIntent string if s.queryExpander != nil { - expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second) + expandCtx, expandCancel := context.WithTimeout(ctx, 5*time.Second) cfg := expansion.DefaultConfig() cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg) @@ -83,7 +87,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) { var vectorErrors int for _, eq := range expandedQueries { - vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where) + vectorResults, vecErr := s.vectorClient.Query(ctx, eq.Query, limit*2, where) if vecErr != nil { vectorErrors++ log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed") @@ -125,7 +129,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) { if len(obsIDs) > 0 { // Fetch full observations from SQLite - observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit) + observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", limit) if err == nil { usedVector = true } @@ -138,11 +142,11 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) { if vectorSearchFailed { log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure") } - observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, 
project, limit) + observations, err = s.observationStore.SearchObservationsFTS(ctx, query, project, limit) if err != nil { // FTS might fail if query has special chars, try without log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent") - observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit) + observations, err = s.observationStore.GetRecentObservations(ctx, project, limit) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -354,7 +358,9 @@ func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) { } // Search for observations related to each file in parallel - ctx := r.Context() + // Add request-scoped timeout to prevent indefinite blocking on DB operations + ctx, ctxCancel := context.WithTimeout(r.Context(), 15*time.Second) + defer ctxCancel() // Check if vector search is available if s.vectorClient == nil || !s.vectorClient.IsConnected() { @@ -562,6 +568,10 @@ func splitCamelCase(s string) string { // IMPORTANT: This is on the critical startup path - must be fast! // No synchronous verification - just filter by staleness and return. 
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) { + // Add request-scoped timeout to prevent indefinite blocking on DB operations + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + project := r.URL.Query().Get("project") if project == "" { http.Error(w, "project required", http.StatusBadRequest) @@ -592,7 +602,7 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) { } // Get recent observations - observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit) + observations, err := s.observationStore.GetRecentObservations(ctx, project, limit) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -658,13 +668,17 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) { // handleContextCount returns the count of observations for a project. func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) { + // Add request-scoped timeout to prevent indefinite blocking on DB operations + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + project := r.URL.Query().Get("project") if project == "" { http.Error(w, "project required", http.StatusBadRequest) return } - count, err := s.getCachedObservationCount(r.Context(), project) + count, err := s.getCachedObservationCount(ctx, project) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/internal/worker/handlers_data.go b/internal/worker/handlers_data.go index 9a1950a..5cd8189 100644 --- a/internal/worker/handlers_data.go +++ b/internal/worker/handlers_data.go @@ -2,6 +2,7 @@ package worker import ( + "context" "encoding/json" "net/http" "runtime" @@ -20,6 +21,10 @@ import ( // Supports optional query parameter for semantic search via sqlite-vec. // Supports pagination via limit and offset query parameters. 
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) { + // Add request-scoped timeout to prevent indefinite blocking on DB operations + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit) project := r.URL.Query().Get("project") query := r.URL.Query().Get("query") @@ -38,11 +43,11 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) // Use vector search if query is provided and vector client is available if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "") - vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where) + vectorResults, vecErr := s.vectorClient.Query(ctx, query, pagination.Limit*2, where) if vecErr == nil && len(vectorResults) > 0 { obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project) if len(obsIDs) > 0 { - observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit) + observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", pagination.Limit) if err == nil { usedVector = true total = int64(len(observations)) // Vector search doesn't have total, use returned count @@ -55,10 +60,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) if !usedVector { if project != "" { // Strict project filtering for dashboard - only observations from this project - observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset) + observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, pagination.Limit, pagination.Offset) } else { // All projects - observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), 
pagination.Limit, pagination.Offset) + observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, pagination.Limit, pagination.Offset) } } @@ -90,6 +95,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) // handleGetSummaries returns recent summaries. // Supports optional query parameter for semantic search via sqlite-vec. func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) { + // Add request-scoped timeout to prevent indefinite blocking on DB operations + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + limit := gorm.ParseLimitParam(r, DefaultSummariesLimit) project := r.URL.Query().Get("project") query := r.URL.Query().Get("query") @@ -107,11 +116,11 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) { // Use vector search if query is provided and vector client is available if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "") - vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where) + vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where) if vecErr == nil && len(vectorResults) > 0 { summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project) if len(summaryIDs) > 0 { - summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit) + summaries, err = s.summaryStore.GetSummariesByIDs(ctx, summaryIDs, "date_desc", limit) if err == nil { usedVector = true } @@ -122,9 +131,9 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) { // Fall back to SQLite if vector search not used if !usedVector { if project != "" { - summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit) + summaries, err = s.summaryStore.GetRecentSummaries(ctx, project, limit) } else { - summaries, err = 
s.summaryStore.GetAllRecentSummaries(r.Context(), limit) + summaries, err = s.summaryStore.GetAllRecentSummaries(ctx, limit) } } @@ -143,6 +152,10 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) { // handleGetPrompts returns recent user prompts. // Supports optional query parameter for semantic search via sqlite-vec. func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) { + // Add request-scoped timeout to prevent indefinite blocking on DB operations + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second) + defer cancel() + limit := gorm.ParseLimitParam(r, DefaultPromptsLimit) project := r.URL.Query().Get("project") query := r.URL.Query().Get("query") @@ -160,11 +173,11 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) { // Use vector search if query is provided and vector client is available if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "") - vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where) + vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where) if vecErr == nil && len(vectorResults) > 0 { promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project) if len(promptIDs) > 0 { - prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit) + prompts, err = s.promptStore.GetPromptsByIDs(ctx, promptIDs, "date_desc", limit) if err == nil { usedVector = true } @@ -175,9 +188,9 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) { // Fall back to SQLite if vector search not used if !usedVector { if project != "" { - prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit) + prompts, err = s.promptStore.GetRecentUserPromptsByProject(ctx, project, limit) } else { - prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit) + prompts, err = 
s.promptStore.GetAllRecentUserPrompts(ctx, limit) } } diff --git a/internal/worker/service.go b/internal/worker/service.go index 98e28ea..24f6f94 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -1246,7 +1246,12 @@ func (s *Service) setupRoutes() { s.router.Get("/api/selfcheck", s.handleSelfCheck) // SSE endpoint (works before DB is ready) - s.router.Get("/api/events", s.sseBroadcaster.HandleSSE) + // Wrap with middleware that disables write deadline since SSE connections are long-lived + s.router.Get("/api/events", func(w http.ResponseWriter, r *http.Request) { + rc := http.NewResponseController(w) + _ = rc.SetWriteDeadline(time.Time{}) // no deadline for SSE + s.sseBroadcaster.HandleSSE(w, r) + }) // Routes that require DB to be ready s.router.Group(func(r chi.Router) { @@ -1546,7 +1551,7 @@ func (s *Service) Start() error { Handler: s.router, ReadHeaderTimeout: 10 * time.Second, ReadTimeout: 30 * time.Second, - WriteTimeout: 0, // Disabled for SSE (long-lived connections) + WriteTimeout: 60 * time.Second, // Default for API routes; SSE handler extends per-request IdleTimeout: 120 * time.Second, }