mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
fix: prevent MCP server hanging by adding concurrency, timeouts, and context propagation (#45)
Root cause: synchronous MCP request processing combined with missing context propagation to the embedding layer caused indefinite hangs when ONNX inference was slow or the database was contended. Changes: - MCP server: dispatch each request in its own goroutine with semaphore (cap 10) and WaitGroup for clean shutdown drain - Embedding: add context-aware mutex acquisition (acquireMutex) so callers can bail out instead of blocking forever on a stuck ONNX model - Vector client: propagate context through getOrComputeEmbedding and replace bare RLock() calls with context-aware acquireRLockWithContext - Worker handlers: add 15s request-scoped timeouts to all search/context handlers (handleSearchByPrompt, handleContextInject, handleFileContext, handleContextCount, handleGetObservations/Summaries/Prompts) - Worker HTTP server: set WriteTimeout=60s (was 0); SSE endpoint extends deadline per-request via http.ResponseController Fixes #45
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
package embedding
|
package embedding
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
@@ -44,6 +45,14 @@ type EmbeddingModel interface {
|
|||||||
// EmbedBatch generates embeddings for multiple texts.
|
// EmbedBatch generates embeddings for multiple texts.
|
||||||
EmbedBatch(texts []string) ([][]float32, error)
|
EmbedBatch(texts []string) ([][]float32, error)
|
||||||
|
|
||||||
|
// EmbedWithContext generates an embedding for a single text with context-aware cancellation.
|
||||||
|
// The context controls mutex acquisition timeout — if ctx is cancelled while waiting
|
||||||
|
// for the model lock, the call returns immediately with ctx.Err().
|
||||||
|
EmbedWithContext(ctx context.Context, text string) ([]float32, error)
|
||||||
|
|
||||||
|
// EmbedBatchWithContext generates embeddings for multiple texts with context-aware cancellation.
|
||||||
|
EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error)
|
||||||
|
|
||||||
// Close releases model resources.
|
// Close releases model resources.
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package embedding
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -225,6 +226,103 @@ func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// acquireMutex attempts to acquire the model mutex, respecting context cancellation.
|
||||||
|
// On success the caller MUST call the returned unlock function.
|
||||||
|
// If ctx is cancelled while waiting, returns ctx.Err() and no unlock is needed.
|
||||||
|
func (m *bgeModel) acquireMutex(ctx context.Context) (unlock func(), err error) {
|
||||||
|
acquired := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
m.mu.Lock()
|
||||||
|
close(acquired)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-acquired:
|
||||||
|
// Got the lock normally.
|
||||||
|
return m.mu.Unlock, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Context cancelled while waiting. The goroutine above will eventually
|
||||||
|
// acquire the mutex — we must ensure it gets unlocked.
|
||||||
|
go func() {
|
||||||
|
<-acquired
|
||||||
|
m.mu.Unlock()
|
||||||
|
}()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedWithContext generates an embedding for a single text with context-aware mutex acquisition.
|
||||||
|
// If ctx is cancelled while waiting for the model lock, returns immediately with ctx.Err().
|
||||||
|
func (m *bgeModel) EmbedWithContext(ctx context.Context, text string) ([]float32, error) {
|
||||||
|
if text == "" {
|
||||||
|
return make([]float32, EmbeddingDim), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
unlock, err := m.acquireMutex(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("acquire embedding lock: %w", err)
|
||||||
|
}
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
results, err := m.computeBatch([]string{text})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(results) == 0 {
|
||||||
|
return make([]float32, EmbeddingDim), nil
|
||||||
|
}
|
||||||
|
return results[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedBatchWithContext generates embeddings for multiple texts with context-aware mutex acquisition.
|
||||||
|
func (m *bgeModel) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) {
|
||||||
|
if len(texts) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out empty texts and track indices
|
||||||
|
nonEmpty := make([]string, 0, len(texts))
|
||||||
|
indices := make([]int, 0, len(texts))
|
||||||
|
for i, t := range texts {
|
||||||
|
if t != "" {
|
||||||
|
nonEmpty = append(nonEmpty, t)
|
||||||
|
indices = append(indices, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If all texts are empty, return zero vectors
|
||||||
|
if len(nonEmpty) == 0 {
|
||||||
|
results := make([][]float32, len(texts))
|
||||||
|
for i := range results {
|
||||||
|
results[i] = make([]float32, EmbeddingDim)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
unlock, err := m.acquireMutex(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("acquire embedding lock: %w", err)
|
||||||
|
}
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
// Compute embeddings for non-empty texts
|
||||||
|
embeddings, err := m.computeBatch(nonEmpty)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("compute batch embeddings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build result with zero vectors for empty texts
|
||||||
|
results := make([][]float32, len(texts))
|
||||||
|
for i := range results {
|
||||||
|
results[i] = make([]float32, EmbeddingDim)
|
||||||
|
}
|
||||||
|
for i, idx := range indices {
|
||||||
|
results[idx] = embeddings[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
// computeBatch runs inference on a batch of texts. Must be called with lock held.
|
// computeBatch runs inference on a batch of texts. Must be called with lock held.
|
||||||
func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
|
func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
|
||||||
if len(sentences) == 0 {
|
if len(sentences) == 0 {
|
||||||
@@ -509,6 +607,17 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
|||||||
return s.model.EmbedBatch(texts)
|
return s.model.EmbedBatch(texts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmbedWithContext generates an embedding with context-aware cancellation.
|
||||||
|
// If ctx is cancelled while waiting for the model lock, returns immediately.
|
||||||
|
func (s *Service) EmbedWithContext(ctx context.Context, text string) ([]float32, error) {
|
||||||
|
return s.model.EmbedWithContext(ctx, text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedBatchWithContext generates embeddings with context-aware cancellation.
|
||||||
|
func (s *Service) EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error) {
|
||||||
|
return s.model.EmbedBatchWithContext(ctx, texts)
|
||||||
|
}
|
||||||
|
|
||||||
// Close releases model resources.
|
// Close releases model resources.
|
||||||
func (s *Service) Close() error {
|
func (s *Service) Close() error {
|
||||||
return s.model.Close()
|
return s.model.Close()
|
||||||
|
|||||||
+24
-6
@@ -129,13 +129,22 @@ func (s *Server) Run(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Semaphore limits concurrent request goroutines.
|
||||||
|
const maxConcurrent = 10
|
||||||
|
sem := make(chan struct{}, maxConcurrent)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
// Drain in-flight requests before returning.
|
||||||
|
wg.Wait()
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case line, ok := <-lines:
|
case line, ok := <-lines:
|
||||||
if !ok {
|
if !ok {
|
||||||
// Scanner finished, check for errors
|
// Scanner finished — drain in-flight requests, then check for errors.
|
||||||
|
wg.Wait()
|
||||||
err := <-scanErr
|
err := <-scanErr
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, bufio.ErrTooLong) {
|
if errors.Is(err, bufio.ErrTooLong) {
|
||||||
@@ -154,18 +163,27 @@ func (s *Server) Run(ctx context.Context) error {
|
|||||||
|
|
||||||
var req Request
|
var req Request
|
||||||
if err := json.Unmarshal([]byte(line), &req); err != nil {
|
if err := json.Unmarshal([]byte(line), &req); err != nil {
|
||||||
|
// Parse errors are cheap — send inline, no goroutine needed.
|
||||||
if werr := s.sendError(nil, -32700, "Parse error", err.Error()); werr != nil {
|
if werr := s.sendError(nil, -32700, "Parse error", err.Error()); werr != nil {
|
||||||
return fmt.Errorf("write error: %w", werr)
|
return fmt.Errorf("write error: %w", werr)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := s.handleRequest(ctx, &req)
|
// Dispatch request to its own goroutine.
|
||||||
if resp != nil {
|
wg.Add(1)
|
||||||
if werr := s.sendResponse(resp); werr != nil {
|
sem <- struct{}{} // acquire semaphore slot
|
||||||
return fmt.Errorf("write error: %w", werr)
|
go func(r Request) {
|
||||||
|
defer wg.Done()
|
||||||
|
defer func() { <-sem }() // release semaphore slot
|
||||||
|
|
||||||
|
resp := s.handleRequest(ctx, &r)
|
||||||
|
if resp != nil {
|
||||||
|
if werr := s.sendResponse(resp); werr != nil {
|
||||||
|
log.Error().Err(werr).Msg("Failed to send response")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}(req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate query embedding OUTSIDE the lock for better concurrency
|
// Generate query embedding OUTSIDE the lock for better concurrency
|
||||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
queryEmb, err := c.getOrComputeEmbedding(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("embed query: %w", err)
|
return nil, fmt.Errorf("embed query: %w", err)
|
||||||
}
|
}
|
||||||
@@ -282,8 +282,10 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
|||||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now acquire read lock for the actual DB query
|
// Acquire read lock with context awareness to prevent indefinite blocking
|
||||||
c.readMu.RLock()
|
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||||
|
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||||
|
}
|
||||||
defer c.readMu.RUnlock()
|
defer c.readMu.RUnlock()
|
||||||
|
|
||||||
// Build query with filters
|
// Build query with filters
|
||||||
@@ -485,7 +487,7 @@ func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, wh
|
|||||||
// Combines results from different field types and deduplicates by document ID.
|
// Combines results from different field types and deduplicates by document ID.
|
||||||
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
|
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
|
||||||
// Generate embedding once
|
// Generate embedding once
|
||||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
queryEmb, err := c.getOrComputeEmbedding(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("embed query: %w", err)
|
return nil, fmt.Errorf("embed query: %w", err)
|
||||||
}
|
}
|
||||||
@@ -496,7 +498,9 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
|
|||||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.readMu.RLock()
|
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||||
|
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||||
|
}
|
||||||
defer c.readMu.RUnlock()
|
defer c.readMu.RUnlock()
|
||||||
|
|
||||||
// Query with field type aggregation - get best match per document
|
// Query with field type aggregation - get best match per document
|
||||||
@@ -555,6 +559,28 @@ func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, d
|
|||||||
return results, rows.Err()
|
return results, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// acquireRLockWithContext acquires a read lock on mu, respecting ctx cancellation.
|
||||||
|
// If ctx is cancelled while waiting for the lock, the goroutine that eventually
|
||||||
|
// acquires it will release it immediately to prevent leaks.
|
||||||
|
func acquireRLockWithContext(ctx context.Context, mu *sync.RWMutex) error {
|
||||||
|
acquired := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
mu.RLock()
|
||||||
|
close(acquired)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-acquired:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Goroutine may still acquire lock after ctx cancelled — must unlock
|
||||||
|
go func() {
|
||||||
|
<-acquired
|
||||||
|
mu.RUnlock()
|
||||||
|
}()
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// truncateString truncates a string to maxLen characters.
|
// truncateString truncates a string to maxLen characters.
|
||||||
func truncateString(s string, maxLen int) string {
|
func truncateString(s string, maxLen int) string {
|
||||||
if len(s) <= maxLen {
|
if len(s) <= maxLen {
|
||||||
@@ -565,7 +591,9 @@ func truncateString(s string, maxLen int) string {
|
|||||||
|
|
||||||
// Count returns the total number of vectors in the store.
|
// Count returns the total number of vectors in the store.
|
||||||
func (c *Client) Count(ctx context.Context) (int64, error) {
|
func (c *Client) Count(ctx context.Context) (int64, error) {
|
||||||
c.readMu.RLock()
|
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||||
|
return 0, fmt.Errorf("acquire read lock: %w", err)
|
||||||
|
}
|
||||||
defer c.readMu.RUnlock()
|
defer c.readMu.RUnlock()
|
||||||
|
|
||||||
var count int64
|
var count int64
|
||||||
@@ -586,7 +614,9 @@ func (c *Client) ModelVersion() string {
|
|||||||
// - The vectors table is empty
|
// - The vectors table is empty
|
||||||
// - Any vectors have a different model_version than the current model
|
// - Any vectors have a different model_version than the current model
|
||||||
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||||
c.readMu.RLock()
|
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
defer c.readMu.RUnlock()
|
defer c.readMu.RUnlock()
|
||||||
|
|
||||||
currentModel := c.embedSvc.Version()
|
currentModel := c.embedSvc.Version()
|
||||||
@@ -634,7 +664,9 @@ type StaleVectorInfo struct {
|
|||||||
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
|
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
|
||||||
// This enables granular rebuild - only re-embedding documents that need updating.
|
// This enables granular rebuild - only re-embedding documents that need updating.
|
||||||
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
|
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
|
||||||
c.readMu.RLock()
|
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||||
|
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||||
|
}
|
||||||
defer c.readMu.RUnlock()
|
defer c.readMu.RUnlock()
|
||||||
|
|
||||||
currentModel := c.embedSvc.Version()
|
currentModel := c.embedSvc.Version()
|
||||||
@@ -692,7 +724,9 @@ type VectorHealthStats struct {
|
|||||||
|
|
||||||
// GetHealthStats returns comprehensive health statistics about the vector store.
|
// GetHealthStats returns comprehensive health statistics about the vector store.
|
||||||
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
|
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
|
||||||
c.readMu.RLock()
|
if err := acquireRLockWithContext(ctx, &c.readMu); err != nil {
|
||||||
|
return nil, fmt.Errorf("acquire read lock: %w", err)
|
||||||
|
}
|
||||||
defer c.readMu.RUnlock()
|
defer c.readMu.RUnlock()
|
||||||
|
|
||||||
stats := &VectorHealthStats{
|
stats := &VectorHealthStats{
|
||||||
@@ -856,7 +890,9 @@ func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error {
|
|||||||
|
|
||||||
// getOrComputeEmbedding returns a cached embedding or computes a new one.
|
// getOrComputeEmbedding returns a cached embedding or computes a new one.
|
||||||
// Uses singleflight to prevent duplicate concurrent computations for the same query.
|
// Uses singleflight to prevent duplicate concurrent computations for the same query.
|
||||||
func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
|
// The context controls timeout on the embedding mutex acquisition -- if the ONNX model
|
||||||
|
// hangs under CGO, callers can bail out instead of blocking forever.
|
||||||
|
func (c *Client) getOrComputeEmbedding(ctx context.Context, query string) ([]float32, error) {
|
||||||
now := time.Now().UnixNano()
|
now := time.Now().UnixNano()
|
||||||
|
|
||||||
// Check cache first (read lock)
|
// Check cache first (read lock)
|
||||||
@@ -885,8 +921,8 @@ func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
|
|||||||
// Record cache miss
|
// Record cache miss
|
||||||
c.stats.embeddingMisses.Add(1)
|
c.stats.embeddingMisses.Add(1)
|
||||||
|
|
||||||
// Compute embedding
|
// Compute embedding with context-aware lock acquisition
|
||||||
emb, err := c.embedSvc.Embed(query)
|
emb, err := c.embedSvc.EmbedWithContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ import (
|
|||||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||||
// No synchronous verification - just filter by staleness and return.
|
// No synchronous verification - just filter by staleness and return.
|
||||||
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
project := r.URL.Query().Get("project")
|
project := r.URL.Query().Get("project")
|
||||||
query := r.URL.Query().Get("query")
|
query := r.URL.Query().Get("query")
|
||||||
cwd := r.URL.Query().Get("cwd")
|
cwd := r.URL.Query().Get("cwd")
|
||||||
@@ -54,7 +58,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
|||||||
var expandedQueries []expansion.ExpandedQuery
|
var expandedQueries []expansion.ExpandedQuery
|
||||||
var detectedIntent string
|
var detectedIntent string
|
||||||
if s.queryExpander != nil {
|
if s.queryExpander != nil {
|
||||||
expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second)
|
expandCtx, expandCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
cfg := expansion.DefaultConfig()
|
cfg := expansion.DefaultConfig()
|
||||||
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
|
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
|
||||||
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
|
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
|
||||||
@@ -83,7 +87,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
|||||||
var vectorErrors int
|
var vectorErrors int
|
||||||
|
|
||||||
for _, eq := range expandedQueries {
|
for _, eq := range expandedQueries {
|
||||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
|
vectorResults, vecErr := s.vectorClient.Query(ctx, eq.Query, limit*2, where)
|
||||||
if vecErr != nil {
|
if vecErr != nil {
|
||||||
vectorErrors++
|
vectorErrors++
|
||||||
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
|
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
|
||||||
@@ -125,7 +129,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if len(obsIDs) > 0 {
|
if len(obsIDs) > 0 {
|
||||||
// Fetch full observations from SQLite
|
// Fetch full observations from SQLite
|
||||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
|
observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", limit)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
usedVector = true
|
usedVector = true
|
||||||
}
|
}
|
||||||
@@ -138,11 +142,11 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
|||||||
if vectorSearchFailed {
|
if vectorSearchFailed {
|
||||||
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
|
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
|
||||||
}
|
}
|
||||||
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
|
observations, err = s.observationStore.SearchObservationsFTS(ctx, query, project, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// FTS might fail if query has special chars, try without
|
// FTS might fail if query has special chars, try without
|
||||||
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
|
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
|
||||||
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
observations, err = s.observationStore.GetRecentObservations(ctx, project, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
@@ -354,7 +358,9 @@ func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Search for observations related to each file in parallel
|
// Search for observations related to each file in parallel
|
||||||
ctx := r.Context()
|
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||||
|
ctx, ctxCancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer ctxCancel()
|
||||||
|
|
||||||
// Check if vector search is available
|
// Check if vector search is available
|
||||||
if s.vectorClient == nil || !s.vectorClient.IsConnected() {
|
if s.vectorClient == nil || !s.vectorClient.IsConnected() {
|
||||||
@@ -562,6 +568,10 @@ func splitCamelCase(s string) string {
|
|||||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||||
// No synchronous verification - just filter by staleness and return.
|
// No synchronous verification - just filter by staleness and return.
|
||||||
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
project := r.URL.Query().Get("project")
|
project := r.URL.Query().Get("project")
|
||||||
if project == "" {
|
if project == "" {
|
||||||
http.Error(w, "project required", http.StatusBadRequest)
|
http.Error(w, "project required", http.StatusBadRequest)
|
||||||
@@ -592,7 +602,7 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get recent observations
|
// Get recent observations
|
||||||
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
observations, err := s.observationStore.GetRecentObservations(ctx, project, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
@@ -658,13 +668,17 @@ func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// handleContextCount returns the count of observations for a project.
|
// handleContextCount returns the count of observations for a project.
|
||||||
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
project := r.URL.Query().Get("project")
|
project := r.URL.Query().Get("project")
|
||||||
if project == "" {
|
if project == "" {
|
||||||
http.Error(w, "project required", http.StatusBadRequest)
|
http.Error(w, "project required", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := s.getCachedObservationCount(r.Context(), project)
|
count, err := s.getCachedObservationCount(ctx, project)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
package worker
|
package worker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -20,6 +21,10 @@ import (
|
|||||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||||
// Supports pagination via limit and offset query parameters.
|
// Supports pagination via limit and offset query parameters.
|
||||||
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
|
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
|
||||||
project := r.URL.Query().Get("project")
|
project := r.URL.Query().Get("project")
|
||||||
query := r.URL.Query().Get("query")
|
query := r.URL.Query().Get("query")
|
||||||
@@ -38,11 +43,11 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
|||||||
// Use vector search if query is provided and vector client is available
|
// Use vector search if query is provided and vector client is available
|
||||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where)
|
vectorResults, vecErr := s.vectorClient.Query(ctx, query, pagination.Limit*2, where)
|
||||||
if vecErr == nil && len(vectorResults) > 0 {
|
if vecErr == nil && len(vectorResults) > 0 {
|
||||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
||||||
if len(obsIDs) > 0 {
|
if len(obsIDs) > 0 {
|
||||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit)
|
observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", pagination.Limit)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
usedVector = true
|
usedVector = true
|
||||||
total = int64(len(observations)) // Vector search doesn't have total, use returned count
|
total = int64(len(observations)) // Vector search doesn't have total, use returned count
|
||||||
@@ -55,10 +60,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
|||||||
if !usedVector {
|
if !usedVector {
|
||||||
if project != "" {
|
if project != "" {
|
||||||
// Strict project filtering for dashboard - only observations from this project
|
// Strict project filtering for dashboard - only observations from this project
|
||||||
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset)
|
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, pagination.Limit, pagination.Offset)
|
||||||
} else {
|
} else {
|
||||||
// All projects
|
// All projects
|
||||||
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset)
|
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, pagination.Limit, pagination.Offset)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,6 +95,10 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
|||||||
// handleGetSummaries returns recent summaries.
|
// handleGetSummaries returns recent summaries.
|
||||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||||
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
|
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
|
||||||
project := r.URL.Query().Get("project")
|
project := r.URL.Query().Get("project")
|
||||||
query := r.URL.Query().Get("query")
|
query := r.URL.Query().Get("query")
|
||||||
@@ -107,11 +116,11 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Use vector search if query is provided and vector client is available
|
// Use vector search if query is provided and vector client is available
|
||||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
|
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
|
||||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
|
||||||
if vecErr == nil && len(vectorResults) > 0 {
|
if vecErr == nil && len(vectorResults) > 0 {
|
||||||
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
|
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
|
||||||
if len(summaryIDs) > 0 {
|
if len(summaryIDs) > 0 {
|
||||||
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit)
|
summaries, err = s.summaryStore.GetSummariesByIDs(ctx, summaryIDs, "date_desc", limit)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
usedVector = true
|
usedVector = true
|
||||||
}
|
}
|
||||||
@@ -122,9 +131,9 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Fall back to SQLite if vector search not used
|
// Fall back to SQLite if vector search not used
|
||||||
if !usedVector {
|
if !usedVector {
|
||||||
if project != "" {
|
if project != "" {
|
||||||
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
|
summaries, err = s.summaryStore.GetRecentSummaries(ctx, project, limit)
|
||||||
} else {
|
} else {
|
||||||
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
|
summaries, err = s.summaryStore.GetAllRecentSummaries(ctx, limit)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,6 +152,10 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
|||||||
// handleGetPrompts returns recent user prompts.
|
// handleGetPrompts returns recent user prompts.
|
||||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||||
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Add request-scoped timeout to prevent indefinite blocking on DB operations
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
|
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
|
||||||
project := r.URL.Query().Get("project")
|
project := r.URL.Query().Get("project")
|
||||||
query := r.URL.Query().Get("query")
|
query := r.URL.Query().Get("query")
|
||||||
@@ -160,11 +173,11 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Use vector search if query is provided and vector client is available
|
// Use vector search if query is provided and vector client is available
|
||||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
|
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
|
||||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
|
||||||
if vecErr == nil && len(vectorResults) > 0 {
|
if vecErr == nil && len(vectorResults) > 0 {
|
||||||
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
|
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
|
||||||
if len(promptIDs) > 0 {
|
if len(promptIDs) > 0 {
|
||||||
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit)
|
prompts, err = s.promptStore.GetPromptsByIDs(ctx, promptIDs, "date_desc", limit)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
usedVector = true
|
usedVector = true
|
||||||
}
|
}
|
||||||
@@ -175,9 +188,9 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Fall back to SQLite if vector search not used
|
// Fall back to SQLite if vector search not used
|
||||||
if !usedVector {
|
if !usedVector {
|
||||||
if project != "" {
|
if project != "" {
|
||||||
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
|
prompts, err = s.promptStore.GetRecentUserPromptsByProject(ctx, project, limit)
|
||||||
} else {
|
} else {
|
||||||
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
|
prompts, err = s.promptStore.GetAllRecentUserPrompts(ctx, limit)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1246,7 +1246,12 @@ func (s *Service) setupRoutes() {
|
|||||||
s.router.Get("/api/selfcheck", s.handleSelfCheck)
|
s.router.Get("/api/selfcheck", s.handleSelfCheck)
|
||||||
|
|
||||||
// SSE endpoint (works before DB is ready)
|
// SSE endpoint (works before DB is ready)
|
||||||
s.router.Get("/api/events", s.sseBroadcaster.HandleSSE)
|
// Wrap with middleware that disables write deadline since SSE connections are long-lived
|
||||||
|
s.router.Get("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
rc := http.NewResponseController(w)
|
||||||
|
_ = rc.SetWriteDeadline(time.Time{}) // no deadline for SSE
|
||||||
|
s.sseBroadcaster.HandleSSE(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
// Routes that require DB to be ready
|
// Routes that require DB to be ready
|
||||||
s.router.Group(func(r chi.Router) {
|
s.router.Group(func(r chi.Router) {
|
||||||
@@ -1546,7 +1551,7 @@ func (s *Service) Start() error {
|
|||||||
Handler: s.router,
|
Handler: s.router,
|
||||||
ReadHeaderTimeout: 10 * time.Second,
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
ReadTimeout: 30 * time.Second,
|
ReadTimeout: 30 * time.Second,
|
||||||
WriteTimeout: 0, // Disabled for SSE (long-lived connections)
|
WriteTimeout: 60 * time.Second, // Default for API routes; SSE handler extends per-request
|
||||||
IdleTimeout: 120 * time.Second,
|
IdleTimeout: 120 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user