mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
fix: prevent MCP server hanging by adding concurrency, timeouts, and context propagation (#45)
Root cause: synchronous MCP request processing, combined with missing context propagation to the embedding layer, caused indefinite hangs when ONNX inference was slow or the database was contended.
Changes:
- MCP server: dispatch each request in its own goroutine, bounded by a semaphore (cap 10), with a WaitGroup for clean shutdown drain
- Embedding: add context-aware mutex acquisition (acquireMutex) so callers can bail out instead of blocking forever on a stuck ONNX model
- Vector client: propagate context through getOrComputeEmbedding and replace bare RLock() calls with context-aware acquireRLockWithContext
- Worker handlers: add 15s request-scoped timeouts to all search/context handlers (handleSearchByPrompt, handleContextInject, handleFileContext, handleContextCount, handleGetObservations/Summaries/Prompts)
- Worker HTTP server: set WriteTimeout=60s (was 0); the SSE endpoint clears the write deadline per-request via http.ResponseController
Fixes #45
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
@@ -44,6 +45,14 @@ type EmbeddingModel interface {
|
||||
// EmbedBatch generates embeddings for multiple texts.
|
||||
EmbedBatch(texts []string) ([][]float32, error)
|
||||
|
||||
// EmbedWithContext generates an embedding for a single text with context-aware cancellation.
|
||||
// The context controls mutex acquisition timeout — if ctx is cancelled while waiting
|
||||
// for the model lock, the call returns immediately with ctx.Err().
|
||||
EmbedWithContext(ctx context.Context, text string) ([]float32, error)
|
||||
|
||||
// EmbedBatchWithContext generates embeddings for multiple texts with context-aware cancellation.
|
||||
EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error)
|
||||
|
||||
// Close releases model resources.
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package embedding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
@@ -225,6 +226,103 @@ func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// acquireMutex attempts to acquire the model mutex, respecting context cancellation.
|
||||
// On success the caller MUST call the returned unlock function.
|
||||
// If ctx is cancelled while waiting, returns ctx.Err() and no unlock is needed.
|
||||
func (m *bgeModel) acquireMutex(ctx context.Context) (unlock func(), err error) {
|
||||
acquired := make(chan struct{})
|
||||
go func() {
|
||||
m.mu.Lock()
|
||||
close(acquired)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-acquired:
|
||||
// Got the lock normally.
|
||||
return m.mu.Unlock, nil
|
||||
case <-ctx.Done():
|
||||
// Context cancelled while waiting. The goroutine above will eventually
|
||||
// acquire the mutex — we must ensure it gets unlocked.
|
||||
go func() {
|
||||
<-acquired
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// EmbedWithContext generates an embedding for a single text with context-aware mutex acquisition.
|
||||
// If ctx is cancelled while waiting for the model lock, returns immediately with ctx.Err().
|
||||
func (m *bgeModel) EmbedWithContext(ctx context.Context, text string) ([]float32, error) {
|
||||
if text == "" {
|
||||
return make([]float32, EmbeddingDim), nil
|
||||
}
|
||||
|
||||
unlock, err := m.acquireMutex(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acquire embedding lock: %w", err)
|
||||
}
|
||||
defer unlock()
|
||||
|
||||
results, err := m.computeBatch([]string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return make([]float32, EmbeddingDim), nil
|
||||
}
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
// EmbedBatchWithContext generates embeddings for multiple texts with context-aware mutex acquisition.
|
||||
func (m *bgeModel) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
if len(texts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Filter out empty texts and track indices
|
||||
nonEmpty := make([]string, 0, len(texts))
|
||||
indices := make([]int, 0, len(texts))
|
||||
for i, t := range texts {
|
||||
if t != "" {
|
||||
nonEmpty = append(nonEmpty, t)
|
||||
indices = append(indices, i)
|
||||
}
|
||||
}
|
||||
|
||||
// If all texts are empty, return zero vectors
|
||||
if len(nonEmpty) == 0 {
|
||||
results := make([][]float32, len(texts))
|
||||
for i := range results {
|
||||
results[i] = make([]float32, EmbeddingDim)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
unlock, err := m.acquireMutex(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acquire embedding lock: %w", err)
|
||||
}
|
||||
defer unlock()
|
||||
|
||||
// Compute embeddings for non-empty texts
|
||||
embeddings, err := m.computeBatch(nonEmpty)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compute batch embeddings: %w", err)
|
||||
}
|
||||
|
||||
// Build result with zero vectors for empty texts
|
||||
results := make([][]float32, len(texts))
|
||||
for i := range results {
|
||||
results[i] = make([]float32, EmbeddingDim)
|
||||
}
|
||||
for i, idx := range indices {
|
||||
results[idx] = embeddings[i]
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// computeBatch runs inference on a batch of texts. Must be called with lock held.
|
||||
func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
|
||||
if len(sentences) == 0 {
|
||||
@@ -509,6 +607,17 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
||||
return s.model.EmbedBatch(texts)
|
||||
}
|
||||
|
||||
// EmbedWithContext generates an embedding with context-aware cancellation.
|
||||
// If ctx is cancelled while waiting for the model lock, returns immediately.
|
||||
func (s *Service) EmbedWithContext(ctx context.Context, text string) ([]float32, error) {
|
||||
return s.model.EmbedWithContext(ctx, text)
|
||||
}
|
||||
|
||||
// EmbedBatchWithContext generates embeddings with context-aware cancellation.
|
||||
func (s *Service) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
return s.model.EmbedBatchWithContext(ctx, texts)
|
||||
}
|
||||
|
||||
// Close releases model resources.
|
||||
func (s *Service) Close() error {
|
||||
return s.model.Close()
|
||||
|
||||
+24
-6
@@ -129,13 +129,22 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
// Semaphore limits concurrent request goroutines.
|
||||
const maxConcurrent = 10
|
||||
sem := make(chan struct{}, maxConcurrent)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Drain in-flight requests before returning.
|
||||
wg.Wait()
|
||||
return ctx.Err()
|
||||
case line, ok := <-lines:
|
||||
if !ok {
|
||||
// Scanner finished, check for errors
|
||||
// Scanner finished — drain in-flight requests, then check for errors.
|
||||
wg.Wait()
|
||||
err := <-scanErr
|
||||
if err != nil {
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
@@ -154,18 +163,27 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
|
||||
var req Request
|
||||
if err := json.Unmarshal([]byte(line), &req); err != nil {
|
||||
// Parse errors are cheap — send inline, no goroutine needed.
|
||||
if werr := s.sendError(nil, -32700, "Parse error", err.Error()); werr != nil {
|
||||
return fmt.Errorf("write error: %w", werr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
resp := s.handleRequest(ctx, &req)
|
||||
if resp != nil {
|
||||
if werr := s.sendResponse(resp); werr != nil {
|
||||
return fmt.Errorf("write error: %w", werr)
|
||||
// Dispatch request to its own goroutine.
|
||||
wg.Add(1)
|
||||
sem <- struct{}{} // acquire semaphore slot
|
||||
go func(r Request) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }() // release semaphore slot
|
||||
|
||||
resp := s.handleRequest(ctx, &r)
|
||||
if resp != nil {
|
||||
if werr := s.sendResponse(resp); werr != nil {
|
||||
log.Error().Err(werr).Msg("Failed to send response")
|
||||
}
|
||||
}
|
||||
}
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
||||
}
|
||||
|
||||
// Generate query embedding OUTSIDE the lock for better concurrency
|
||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
||||
queryEmb, err := c.getOrComputeEmbedding(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
@@ -282,8 +282,10 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||
}
|
||||
|
||||
// Now acquire read lock for the actual DB query
|
||||
c.readMu.RLock()
|
||||
// Acquire read lock with context awareness to prevent indefinite blocking
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
// Build query with filters
|
||||
@@ -485,7 +487,7 @@ func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, wh
|
||||
// Combines results from different field types and deduplicates by document ID.
|
||||
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
|
||||
// Generate embedding once
|
||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
||||
queryEmb, err := c.getOrComputeEmbedding(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
@@ -496,7 +498,9 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
|
||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||
}
|
||||
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
// Query with field type aggregation - get best match per document
|
||||
@@ -555,6 +559,28 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// acquireRLockWithContext takes a read lock on mu, giving up as soon as ctx is
// done.
//
// It returns nil once the read lock is held; the caller must then call
// mu.RUnlock. It returns ctx.Err() if ctx is done first, in which case the
// caller holds nothing and must not unlock.
//
// A context that is already done on entry is rejected before any locking is
// attempted. Without that check both select cases below could be ready at the
// same time and Go would pick one at random, so a cancelled caller could still
// be handed the lock.
func acquireRLockWithContext(ctx context.Context, mu *sync.RWMutex) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	// sync.RWMutex has no cancellable RLock, so a helper goroutine blocks on it
	// and then hands ownership over on an unbuffered channel. The send can only
	// complete if the caller is still receiving; if the caller has already left
	// because ctx is done, the helper releases the read lock itself so it is
	// never leaked.
	acquired := make(chan struct{})
	go func() {
		mu.RLock()
		select {
		case acquired <- struct{}{}:
			// Ownership transferred to the caller.
		case <-ctx.Done():
			mu.RUnlock()
		}
	}()

	select {
	case <-acquired:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
|
||||
|
||||
// truncateString truncates a string to maxLen characters.
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
@@ -565,7 +591,9 @@ func truncateString(s string, maxLen int) string {
|
||||
|
||||
// Count returns the total number of vectors in the store.
|
||||
func (c *Client) Count(ctx context.Context) (int64, error) {
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return 0, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
var count int64
|
||||
@@ -586,7 +614,9 @@ func (c *Client) ModelVersion() string {
|
||||
// - The vectors table is empty
|
||||
// - Any vectors have a different model_version than the current model
|
||||
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return false, ""
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
@@ -634,7 +664,9 @@ type StaleVectorInfo struct {
|
||||
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
|
||||
// This enables granular rebuild - only re-embedding documents that need updating.
|
||||
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
@@ -692,7 +724,9 @@ type VectorHealthStats struct {
|
||||
|
||||
// GetHealthStats returns comprehensive health statistics about the vector store.
|
||||
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
|
||||
c.readMu.RLock()
|
||||
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||
}
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
stats := &VectorHealthStats{
|
||||
@@ -856,7 +890,9 @@ func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error {
|
||||
|
||||
// getOrComputeEmbedding returns a cached embedding or computes a new one.
|
||||
// Uses singleflight to prevent duplicate concurrent computations for the same query.
|
||||
func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
|
||||
// The context controls timeout on the embedding mutex acquisition -- if the ONNX model
|
||||
// hangs under CGO, callers can bail out instead of blocking forever.
|
||||
func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]float32, error) {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
// Check cache first (read lock)
|
||||
@@ -885,8 +921,8 @@ func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
|
||||
// Record cache miss
|
||||
c.stats.embeddingMisses.Add(1)
|
||||
|
||||
// Compute embedding
|
||||
emb, err := c.embedSvc.Embed(query)
|
||||
// Compute embedding with context-aware lock acquisition
|
||||
emb, err := c.embedSvc.EmbedWithContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -23,6 +23,10 @@ import (
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
cwd := r.URL.Query().Get("cwd")
|
||||
@@ -54,7 +58,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
var expandedQueries []expansion.ExpandedQuery
|
||||
var detectedIntent string
|
||||
if s.queryExpander != nil {
|
||||
expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
expandCtx, expandCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
cfg := expansion.DefaultConfig()
|
||||
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
|
||||
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
|
||||
@@ -83,7 +87,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
var vectorErrors int
|
||||
|
||||
for _, eq := range expandedQueries {
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, eq.Query, limit*2, where)
|
||||
if vecErr != nil {
|
||||
vectorErrors++
|
||||
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
|
||||
@@ -125,7 +129,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if len(obsIDs) > 0 {
|
||||
// Fetch full observations from SQLite
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
|
||||
observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
@@ -138,11 +142,11 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
if vectorSearchFailed {
|
||||
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
|
||||
}
|
||||
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
|
||||
observations, err = s.observationStore.SearchObservationsFTS(ctx, query, project, limit)
|
||||
if err != nil {
|
||||
// FTS might fail if query has special chars, try without
|
||||
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
|
||||
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
observations, err = s.observationStore.GetRecentObservations(ctx, project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -354,7 +358,9 @@ func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Search for observations related to each file in parallel
|
||||
ctx := r.Context()
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, ctxCancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer ctxCancel()
|
||||
|
||||
// Check if vector search is available
|
||||
if s.vectorClient == nil || !s.vectorClient.IsConnected() {
|
||||
@@ -562,6 +568,10 @@ func splitCamelCase(s string) string {
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
@@ -592,7 +602,7 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Get recent observations
|
||||
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
observations, err := s.observationStore.GetRecentObservations(ctx, project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -658,13 +668,17 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleContextCount returns the count of observations for a project.
|
||||
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
count, err := s.getCachedObservationCount(r.Context(), project)
|
||||
count, err := s.getCachedObservationCount(ctx, project)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"runtime"
|
||||
@@ -20,6 +21,10 @@ import (
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
// Supports pagination via limit and offset query parameters.
|
||||
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
@@ -38,11 +43,11 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where)
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, query, pagination.Limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
||||
if len(obsIDs) > 0 {
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit)
|
||||
observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", pagination.Limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
total = int64(len(observations)) // Vector search doesn't have total, use returned count
|
||||
@@ -55,10 +60,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
// Strict project filtering for dashboard - only observations from this project
|
||||
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset)
|
||||
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, pagination.Limit, pagination.Offset)
|
||||
} else {
|
||||
// All projects
|
||||
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset)
|
||||
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, pagination.Limit, pagination.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +95,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
||||
// handleGetSummaries returns recent summaries.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
@@ -107,11 +116,11 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
|
||||
if len(summaryIDs) > 0 {
|
||||
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit)
|
||||
summaries, err = s.summaryStore.GetSummariesByIDs(ctx, summaryIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
@@ -122,9 +131,9 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
|
||||
summaries, err = s.summaryStore.GetRecentSummaries(ctx, project, limit)
|
||||
} else {
|
||||
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
|
||||
summaries, err = s.summaryStore.GetAllRecentSummaries(ctx, limit)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,6 +152,10 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
// handleGetPrompts returns recent user prompts.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
@@ -160,11 +173,11 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
|
||||
if len(promptIDs) > 0 {
|
||||
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit)
|
||||
prompts, err = s.promptStore.GetPromptsByIDs(ctx, promptIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
@@ -175,9 +188,9 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
|
||||
prompts, err = s.promptStore.GetRecentUserPromptsByProject(ctx, project, limit)
|
||||
} else {
|
||||
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
|
||||
prompts, err = s.promptStore.GetAllRecentUserPrompts(ctx, limit)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1246,7 +1246,12 @@ func (s *Service) setupRoutes() {
|
||||
s.router.Get("/api/selfcheck", s.handleSelfCheck)
|
||||
|
||||
// SSE endpoint (works before DB is ready)
|
||||
s.router.Get("/api/events", s.sseBroadcaster.HandleSSE)
|
||||
// Wrap with middleware that disables write deadline since SSE connections are long-lived
|
||||
s.router.Get("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
||||
rc := http.NewResponseController(w)
|
||||
_ = rc.SetWriteDeadline(time.Time{}) // no deadline for SSE
|
||||
s.sseBroadcaster.HandleSSE(w, r)
|
||||
})
|
||||
|
||||
// Routes that require DB to be ready
|
||||
s.router.Group(func(r chi.Router) {
|
||||
@@ -1546,7 +1551,7 @@ func (s *Service) Start() error {
|
||||
Handler: s.router,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 0, // Disabled for SSE (long-lived connections)
|
||||
WriteTimeout: 60 * time.Second, // Default for API routes; SSE handler extends per-request
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user