fix: prevent MCP server hanging by adding concurrency, timeouts, and context propagation (#45)

Root cause: synchronous MCP request processing combined with missing
context propagation to the embedding layer caused indefinite hangs when
ONNX inference was slow or the database was contended.

Changes:
- MCP server: dispatch each request in its own goroutine with semaphore
  (cap 10) and WaitGroup for clean shutdown drain
- Embedding: add context-aware mutex acquisition (acquireMutex) so
  callers can bail out instead of blocking forever on a stuck ONNX model
- Vector client: propagate context through getOrComputeEmbedding and
  replace bare RLock() calls with context-aware acquireRLockWithContext
- Worker handlers: add 15s request-scoped timeouts to all search/context
  handlers (handleSearchByPrompt, handleContextInject, handleFileContext,
  handleContextCount, handleGetObservations/Summaries/Prompts)
- Worker HTTP server: set WriteTimeout=60s (was 0); SSE endpoint extends
  deadline per-request via http.ResponseController

Fixes #45
This commit is contained in:
2026-05-26 12:34:36 +01:00
parent 56616d0616
commit 29d57857ff
7 changed files with 244 additions and 40 deletions
+9
View File
@@ -2,6 +2,7 @@
package embedding package embedding
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
) )
@@ -44,6 +45,14 @@ type EmbeddingModel interface {
// EmbedBatch generates embeddings for multiple texts. // EmbedBatch generates embeddings for multiple texts.
EmbedBatch(texts []string) ([][]float32, error) EmbedBatch(texts []string) ([][]float32, error)
// EmbedWithContext generates an embedding for a single text with context-aware cancellation.
// The context controls mutex acquisition timeout — if ctx is cancelled while waiting
// for the model lock, the call returns immediately with ctx.Err().
EmbedWithContext(ctx context.Context, text string) ([]float32, error)
// EmbedBatchWithContext generates embeddings for multiple texts with context-aware cancellation.
EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error)
// Close releases model resources. // Close releases model resources.
Close() error Close() error
} }
+109
View File
@@ -3,6 +3,7 @@ package embedding
import ( import (
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
@@ -225,6 +226,103 @@ func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
return results, nil return results, nil
} }
// acquireMutex attempts to acquire the model mutex, respecting context cancellation.
// On success the caller MUST call the returned unlock function.
// If ctx is cancelled while waiting, returns ctx.Err() and no unlock is needed.
func (m *bgeModel) acquireMutex(ctx context.Context) (unlock func(), err error) {
acquired := make(chan struct{})
go func() {
m.mu.Lock()
close(acquired)
}()
select {
case <-acquired:
// Got the lock normally.
return m.mu.Unlock, nil
case <-ctx.Done():
// Context cancelled while waiting. The goroutine above will eventually
// acquire the mutex — we must ensure it gets unlocked.
go func() {
<-acquired
m.mu.Unlock()
}()
return nil, ctx.Err()
}
}
// EmbedWithContext generates an embedding for a single text with context-aware mutex acquisition.
// If ctx is cancelled while waiting for the model lock, returns immediately with ctx.Err().
func (m *bgeModel) EmbedWithContext(ctx context.Context, text string) ([]float32, error) {
if text == "" {
return make([]float32, EmbeddingDim), nil
}
unlock, err := m.acquireMutex(ctx)
if err != nil {
return nil, fmt.Errorf("acquire embedding lock: %w", err)
}
defer unlock()
results, err := m.computeBatch([]string{text})
if err != nil {
return nil, err
}
if len(results) == 0 {
return make([]float32, EmbeddingDim), nil
}
return results[0], nil
}
// EmbedBatchWithContext generates embeddings for multiple texts with context-aware mutex acquisition.
func (m *bgeModel) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) {
if len(texts) == 0 {
return nil, nil
}
// Filter out empty texts and track indices
nonEmpty := make([]string, 0, len(texts))
indices := make([]int, 0, len(texts))
for i, t := range texts {
if t != "" {
nonEmpty = append(nonEmpty, t)
indices = append(indices, i)
}
}
// If all texts are empty, return zero vectors
if len(nonEmpty) == 0 {
results := make([][]float32, len(texts))
for i := range results {
results[i] = make([]float32, EmbeddingDim)
}
return results, nil
}
unlock, err := m.acquireMutex(ctx)
if err != nil {
return nil, fmt.Errorf("acquire embedding lock: %w", err)
}
defer unlock()
// Compute embeddings for non-empty texts
embeddings, err := m.computeBatch(nonEmpty)
if err != nil {
return nil, fmt.Errorf("compute batch embeddings: %w", err)
}
// Build result with zero vectors for empty texts
results := make([][]float32, len(texts))
for i := range results {
results[i] = make([]float32, EmbeddingDim)
}
for i, idx := range indices {
results[idx] = embeddings[i]
}
return results, nil
}
// computeBatch runs inference on a batch of texts. Must be called with lock held. // computeBatch runs inference on a batch of texts. Must be called with lock held.
func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) { func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
if len(sentences) == 0 { if len(sentences) == 0 {
@@ -509,6 +607,17 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
return s.model.EmbedBatch(texts) return s.model.EmbedBatch(texts)
} }
// EmbedWithContext generates an embedding with context-aware cancellation.
// If ctx is cancelled while waiting for the model lock, returns immediately.
func (s *Service) EmbedWithContext(ctx context.Context, text string) ([]float32, error) {
return s.model.EmbedWithContext(ctx, text)
}
// EmbedBatchWithContext generates embeddings with context-aware cancellation.
func (s *Service) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) {
return s.model.EmbedBatchWithContext(ctx, texts)
}
// Close releases model resources. // Close releases model resources.
func (s *Service) Close() error { func (s *Service) Close() error {
return s.model.Close() return s.model.Close()
+24 -6
View File
@@ -129,13 +129,22 @@ func (s *Server) Run(ctx context.Context) error {
} }
}() }()
// Semaphore limits concurrent request goroutines.
const maxConcurrent = 10
sem := make(chan struct{}, maxConcurrent)
var wg sync.WaitGroup
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
// Drain in-flight requests before returning.
wg.Wait()
return ctx.Err() return ctx.Err()
case line, ok := <-lines: case line, ok := <-lines:
if !ok { if !ok {
// Scanner finished, check for errors // Scanner finished — drain in-flight requests, then check for errors.
wg.Wait()
err := <-scanErr err := <-scanErr
if err != nil { if err != nil {
if errors.Is(err, bufio.ErrTooLong) { if errors.Is(err, bufio.ErrTooLong) {
@@ -154,18 +163,27 @@ func (s *Server) Run(ctx context.Context) error {
var req Request var req Request
if err := json.Unmarshal([]byte(line), &req); err != nil { if err := json.Unmarshal([]byte(line), &req); err != nil {
// Parse errors are cheap — send inline, no goroutine needed.
if werr := s.sendError(nil, -32700, "Parse error", err.Error()); werr != nil { if werr := s.sendError(nil, -32700, "Parse error", err.Error()); werr != nil {
return fmt.Errorf("write error: %w", werr) return fmt.Errorf("write error: %w", werr)
} }
continue continue
} }
resp := s.handleRequest(ctx, &req) // Dispatch request to its own goroutine.
if resp != nil { wg.Add(1)
if werr := s.sendResponse(resp); werr != nil { sem <- struct{}{} // acquire semaphore slot
return fmt.Errorf("write error: %w", werr) go func(r Request) {
defer wg.Done()
defer func() { <-sem }() // release semaphore slot
resp := s.handleRequest(ctx, &r)
if resp != nil {
if werr := s.sendResponse(resp); werr != nil {
log.Error().Err(werr).Msg("Failed to send response")
}
} }
} }(req)
} }
} }
} }
+48 -12
View File
@@ -271,7 +271,7 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
} }
// Generate query embedding OUTSIDE the lock for better concurrency // Generate query embedding OUTSIDE the lock for better concurrency
queryEmb, err := c.getOrComputeEmbedding(query) queryEmb, err := c.getOrComputeEmbedding(ctx, query)
if err != nil { if err != nil {
return nil, fmt.Errorf("embed query: %w", err) return nil, fmt.Errorf("embed query: %w", err)
} }
@@ -282,8 +282,10 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
return nil, fmt.Errorf("serialize query embedding: %w", err) return nil, fmt.Errorf("serialize query embedding: %w", err)
} }
// Now acquire read lock for the actual DB query // Acquire read lock with context awareness to prevent indefinite blocking
c.readMu.RLock() if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return nil, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock() defer c.readMu.RUnlock()
// Build query with filters // Build query with filters
@@ -485,7 +487,7 @@ func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, wh
// Combines results from different field types and deduplicates by document ID. // Combines results from different field types and deduplicates by document ID.
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) { func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
// Generate embedding once // Generate embedding once
queryEmb, err := c.getOrComputeEmbedding(query) queryEmb, err := c.getOrComputeEmbedding(ctx, query)
if err != nil { if err != nil {
return nil, fmt.Errorf("embed query: %w", err) return nil, fmt.Errorf("embed query: %w", err)
} }
@@ -496,7 +498,9 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
return nil, fmt.Errorf("serialize query embedding: %w", err) return nil, fmt.Errorf("serialize query embedding: %w", err)
} }
c.readMu.RLock() if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return nil, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock() defer c.readMu.RUnlock()
// Query with field type aggregation - get best match per document // Query with field type aggregation - get best match per document
@@ -555,6 +559,28 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
return results, rows.Err() return results, rows.Err()
} }
// acquireRLockWithContext acquires a read lock on mu, respecting ctx cancellation.
// If ctx is cancelled while waiting for the lock, the goroutine that eventually
// acquires it will release it immediately to prevent leaks.
func acquireRLockWithContext(ctx context.Context, mu *sync.RWMutex) error {
acquired := make(chan struct{})
go func() {
mu.RLock()
close(acquired)
}()
select {
case <-acquired:
return nil
case <-ctx.Done():
// Goroutine may still acquire lock after ctx cancelled — must unlock
go func() {
<-acquired
mu.RUnlock()
}()
return ctx.Err()
}
}
// truncateString truncates a string to maxLen characters. // truncateString truncates a string to maxLen characters.
func truncateString(s string, maxLen int) string { func truncateString(s string, maxLen int) string {
if len(s) <= maxLen { if len(s) <= maxLen {
@@ -565,7 +591,9 @@ func truncateString(s string, maxLen int) string {
// Count returns the total number of vectors in the store. // Count returns the total number of vectors in the store.
func (c *Client) Count(ctx context.Context) (int64, error) { func (c *Client) Count(ctx context.Context) (int64, error) {
c.readMu.RLock() if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return 0, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock() defer c.readMu.RUnlock()
var count int64 var count int64
@@ -586,7 +614,9 @@ func (c *Client) ModelVersion() string {
// - The vectors table is empty // - The vectors table is empty
// - Any vectors have a different model_version than the current model // - Any vectors have a different model_version than the current model
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) { func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
c.readMu.RLock() if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return false, ""
}
defer c.readMu.RUnlock() defer c.readMu.RUnlock()
currentModel := c.embedSvc.Version() currentModel := c.embedSvc.Version()
@@ -634,7 +664,9 @@ type StaleVectorInfo struct {
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions. // GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
// This enables granular rebuild - only re-embedding documents that need updating. // This enables granular rebuild - only re-embedding documents that need updating.
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) { func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
c.readMu.RLock() if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return nil, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock() defer c.readMu.RUnlock()
currentModel := c.embedSvc.Version() currentModel := c.embedSvc.Version()
@@ -692,7 +724,9 @@ type VectorHealthStats struct {
// GetHealthStats returns comprehensive health statistics about the vector store. // GetHealthStats returns comprehensive health statistics about the vector store.
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) { func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
c.readMu.RLock() if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
return nil, fmt.Errorf("acquire read lock: %w", err)
}
defer c.readMu.RUnlock() defer c.readMu.RUnlock()
stats := &VectorHealthStats{ stats := &VectorHealthStats{
@@ -856,7 +890,9 @@ func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error {
// getOrComputeEmbedding returns a cached embedding or computes a new one. // getOrComputeEmbedding returns a cached embedding or computes a new one.
// Uses singleflight to prevent duplicate concurrent computations for the same query. // Uses singleflight to prevent duplicate concurrent computations for the same query.
func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) { // The context controls timeout on the embedding mutex acquisition -- if the ONNX model
// hangs under CGO, callers can bail out instead of blocking forever.
func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]float32, error) {
now := time.Now().UnixNano() now := time.Now().UnixNano()
// Check cache first (read lock) // Check cache first (read lock)
@@ -885,8 +921,8 @@ func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
// Record cache miss // Record cache miss
c.stats.embeddingMisses.Add(1) c.stats.embeddingMisses.Add(1)
// Compute embedding // Compute embedding with context-aware lock acquisition
emb, err := c.embedSvc.Embed(query) emb, err := c.embedSvc.EmbedWithContext(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+22 -8
View File
@@ -23,6 +23,10 @@ import (
// IMPORTANT: This is on the critical startup path - must be fast! // IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return. // No synchronous verification - just filter by staleness and return.
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) { func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
project := r.URL.Query().Get("project") project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query") query := r.URL.Query().Get("query")
cwd := r.URL.Query().Get("cwd") cwd := r.URL.Query().Get("cwd")
@@ -54,7 +58,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
var expandedQueries []expansion.ExpandedQuery var expandedQueries []expansion.ExpandedQuery
var detectedIntent string var detectedIntent string
if s.queryExpander != nil { if s.queryExpander != nil {
expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second) expandCtx, expandCancel := context.WithTimeout(ctx, 5*time.Second)
cfg := expansion.DefaultConfig() cfg := expansion.DefaultConfig()
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg) expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
@@ -83,7 +87,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
var vectorErrors int var vectorErrors int
for _, eq := range expandedQueries { for _, eq := range expandedQueries {
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where) vectorResults, vecErr := s.vectorClient.Query(ctx, eq.Query, limit*2, where)
if vecErr != nil { if vecErr != nil {
vectorErrors++ vectorErrors++
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed") log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
@@ -125,7 +129,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
if len(obsIDs) > 0 { if len(obsIDs) > 0 {
// Fetch full observations from SQLite // Fetch full observations from SQLite
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit) observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", limit)
if err == nil { if err == nil {
usedVector = true usedVector = true
} }
@@ -138,11 +142,11 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
if vectorSearchFailed { if vectorSearchFailed {
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure") log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
} }
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit) observations, err = s.observationStore.SearchObservationsFTS(ctx, query, project, limit)
if err != nil { if err != nil {
// FTS might fail if query has special chars, try without // FTS might fail if query has special chars, try without
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent") log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit) observations, err = s.observationStore.GetRecentObservations(ctx, project, limit)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@@ -354,7 +358,9 @@ func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
} }
// Search for observations related to each file in parallel // Search for observations related to each file in parallel
ctx := r.Context() // Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, ctxCancel := context.WithTimeout(r.Context(), 15*time.Second)
defer ctxCancel()
// Check if vector search is available // Check if vector search is available
if s.vectorClient == nil || !s.vectorClient.IsConnected() { if s.vectorClient == nil || !s.vectorClient.IsConnected() {
@@ -562,6 +568,10 @@ func splitCamelCase(s string) string {
// IMPORTANT: This is on the critical startup path - must be fast! // IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return. // No synchronous verification - just filter by staleness and return.
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) { func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
project := r.URL.Query().Get("project") project := r.URL.Query().Get("project")
if project == "" { if project == "" {
http.Error(w, "project required", http.StatusBadRequest) http.Error(w, "project required", http.StatusBadRequest)
@@ -592,7 +602,7 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
} }
// Get recent observations // Get recent observations
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit) observations, err := s.observationStore.GetRecentObservations(ctx, project, limit)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@@ -658,13 +668,17 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
// handleContextCount returns the count of observations for a project. // handleContextCount returns the count of observations for a project.
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) { func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
project := r.URL.Query().Get("project") project := r.URL.Query().Get("project")
if project == "" { if project == "" {
http.Error(w, "project required", http.StatusBadRequest) http.Error(w, "project required", http.StatusBadRequest)
return return
} }
count, err := s.getCachedObservationCount(r.Context(), project) count, err := s.getCachedObservationCount(ctx, project)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
+25 -12
View File
@@ -2,6 +2,7 @@
package worker package worker
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"runtime" "runtime"
@@ -20,6 +21,10 @@ import (
// Supports optional query parameter for semantic search via sqlite-vec. // Supports optional query parameter for semantic search via sqlite-vec.
// Supports pagination via limit and offset query parameters. // Supports pagination via limit and offset query parameters.
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) { func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit) pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
project := r.URL.Query().Get("project") project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query") query := r.URL.Query().Get("query")
@@ -38,11 +43,11 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
// Use vector search if query is provided and vector client is available // Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "") where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where) vectorResults, vecErr := s.vectorClient.Query(ctx, query, pagination.Limit*2, where)
if vecErr == nil && len(vectorResults) > 0 { if vecErr == nil && len(vectorResults) > 0 {
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project) obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
if len(obsIDs) > 0 { if len(obsIDs) > 0 {
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit) observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", pagination.Limit)
if err == nil { if err == nil {
usedVector = true usedVector = true
total = int64(len(observations)) // Vector search doesn't have total, use returned count total = int64(len(observations)) // Vector search doesn't have total, use returned count
@@ -55,10 +60,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
if !usedVector { if !usedVector {
if project != "" { if project != "" {
// Strict project filtering for dashboard - only observations from this project // Strict project filtering for dashboard - only observations from this project
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset) observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, pagination.Limit, pagination.Offset)
} else { } else {
// All projects // All projects
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset) observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, pagination.Limit, pagination.Offset)
} }
} }
@@ -90,6 +95,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
// handleGetSummaries returns recent summaries. // handleGetSummaries returns recent summaries.
// Supports optional query parameter for semantic search via sqlite-vec. // Supports optional query parameter for semantic search via sqlite-vec.
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) { func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit) limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
project := r.URL.Query().Get("project") project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query") query := r.URL.Query().Get("query")
@@ -107,11 +116,11 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
// Use vector search if query is provided and vector client is available // Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "") where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where) vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
if vecErr == nil && len(vectorResults) > 0 { if vecErr == nil && len(vectorResults) > 0 {
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project) summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
if len(summaryIDs) > 0 { if len(summaryIDs) > 0 {
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit) summaries, err = s.summaryStore.GetSummariesByIDs(ctx, summaryIDs, "date_desc", limit)
if err == nil { if err == nil {
usedVector = true usedVector = true
} }
@@ -122,9 +131,9 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
// Fall back to SQLite if vector search not used // Fall back to SQLite if vector search not used
if !usedVector { if !usedVector {
if project != "" { if project != "" {
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit) summaries, err = s.summaryStore.GetRecentSummaries(ctx, project, limit)
} else { } else {
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit) summaries, err = s.summaryStore.GetAllRecentSummaries(ctx, limit)
} }
} }
@@ -143,6 +152,10 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
// handleGetPrompts returns recent user prompts. // handleGetPrompts returns recent user prompts.
// Supports optional query parameter for semantic search via sqlite-vec. // Supports optional query parameter for semantic search via sqlite-vec.
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) { func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit) limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
project := r.URL.Query().Get("project") project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query") query := r.URL.Query().Get("query")
@@ -160,11 +173,11 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
// Use vector search if query is provided and vector client is available // Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "") where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where) vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
if vecErr == nil && len(vectorResults) > 0 { if vecErr == nil && len(vectorResults) > 0 {
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project) promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
if len(promptIDs) > 0 { if len(promptIDs) > 0 {
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit) prompts, err = s.promptStore.GetPromptsByIDs(ctx, promptIDs, "date_desc", limit)
if err == nil { if err == nil {
usedVector = true usedVector = true
} }
@@ -175,9 +188,9 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
// Fall back to SQLite if vector search not used // Fall back to SQLite if vector search not used
if !usedVector { if !usedVector {
if project != "" { if project != "" {
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit) prompts, err = s.promptStore.GetRecentUserPromptsByProject(ctx, project, limit)
} else { } else {
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit) prompts, err = s.promptStore.GetAllRecentUserPrompts(ctx, limit)
} }
} }
+7 -2
View File
@@ -1246,7 +1246,12 @@ func (s *Service) setupRoutes() {
s.router.Get("/api/selfcheck", s.handleSelfCheck) s.router.Get("/api/selfcheck", s.handleSelfCheck)
// SSE endpoint (works before DB is ready) // SSE endpoint (works before DB is ready)
s.router.Get("/api/events", s.sseBroadcaster.HandleSSE) // Wrap with middleware that disables write deadline since SSE connections are long-lived
s.router.Get("/api/events", func(w http.ResponseWriter, r *http.Request) {
rc := http.NewResponseController(w)
_ = rc.SetWriteDeadline(time.Time{}) // no deadline for SSE
s.sseBroadcaster.HandleSSE(w, r)
})
// Routes that require DB to be ready // Routes that require DB to be ready
s.router.Group(func(r chi.Router) { s.router.Group(func(r chi.Router) {
@@ -1546,7 +1551,7 @@ func (s *Service) Start() error {
Handler: s.router, Handler: s.router,
ReadHeaderTimeout: 10 * time.Second, ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second, ReadTimeout: 30 * time.Second,
WriteTimeout: 0, // Disabled for SSE (long-lived connections) WriteTimeout: 60 * time.Second, // Default for API routes; SSE handler extends per-request
IdleTimeout: 120 * time.Second, IdleTimeout: 120 * time.Second,
} }