Files
claude-mnemonic/internal/worker/handlers_context.go
T
lukaszraczylo 29d57857ff fix: prevent MCP server hanging by adding concurrency, timeouts, and context propagation (#45)
Root cause: synchronous MCP request processing combined with missing
context propagation to the embedding layer caused indefinite hangs when
ONNX inference was slow or the database was contended.

Changes:
- MCP server: dispatch each request in its own goroutine with semaphore
  (cap 10) and WaitGroup for clean shutdown drain
- Embedding: add context-aware mutex acquisition (acquireMutex) so
  callers can bail out instead of blocking forever on a stuck ONNX model
- Vector client: propagate context through getOrComputeEmbedding and
  replace bare RLock() calls with context-aware acquireRLockWithContext
- Worker handlers: add 15s request-scoped timeouts to all search/context
  handlers (handleSearchByPrompt, handleContextInject, handleFileContext,
  handleContextCount, handleGetObservations/Summaries/Prompts)
- Worker HTTP server: set WriteTimeout=60s (was 0); SSE endpoint extends
  deadline per-request via http.ResponseController

Fixes #45
2026-05-26 14:29:34 +01:00

692 lines
22 KiB
Go

// Package worker provides context and search-related HTTP handlers.
package worker
import (
"context"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// handleSearchByPrompt searches observations relevant to a user prompt.
// IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return.
//
// Query parameters:
//   - project (required): validated with ValidateProjectName.
//   - query (required): the user prompt to search for.
//   - cwd (optional): base directory for file mtime checks; when empty the
//     staleness filter is skipped entirely.
//   - limit: parsed via gorm.ParseLimitParamWithMax(r, DefaultSearchLimit, 200).
//
// Pipeline, in order:
//  1. Optional query expansion (5s budget inside the 15s request budget).
//  2. Vector search for every expanded query, weighted and threshold-filtered.
//  3. FTS fallback, then "recent observations" fallback if FTS errors.
//  4. Staleness filter based on file mtimes (stale items are queued for
//     background verification and dropped from the response).
//  5. Optional cross-encoder reranking (only for vector-backed results).
//  6. Clustering to drop near-duplicates, score sort, max-results cap.
//  7. Stats/analytics recording and the JSON response.
//
// Responds 400 when project/query are missing or the project name is invalid,
// and 500 only when every retrieval strategy (including the final
// GetRecentObservations fallback) fails. On success it writes a JSON object
// with the keys project, query, intent, expansions, observations, threshold
// and max_results.
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
// Add request-scoped timeout to prevent indefinite blocking on DB operations
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
defer cancel()
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
cwd := r.URL.Query().Get("cwd")
if project == "" || query == "" {
http.Error(w, "project and query required", http.StatusBadRequest)
return
}
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
limit := gorm.ParseLimitParamWithMax(r, DefaultSearchLimit, 200)
// State shared by the retrieval stages below. usedVector is set only when
// the vector path fetched observations without error; similarityScores is
// filled by the vector path and later overwritten by the reranker.
var observations []*models.Observation
var err error
var usedVector bool
similarityScores := make(map[int64]float64) // Track similarity per observation
// Get threshold settings from config
threshold := s.config.ContextRelevanceThreshold
maxResults := s.config.ContextMaxPromptResults
// Generate expanded queries if query expander is available
// Use timeout context to prevent query expansion from blocking
var expandedQueries []expansion.ExpandedQuery
var detectedIntent string
if s.queryExpander != nil {
expandCtx, expandCancel := context.WithTimeout(ctx, 5*time.Second)
cfg := expansion.DefaultConfig()
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
expandCancel() // Release the expansion timer right away; expandCtx is a child of ctx, so the deferred cancel above also covers it
// The intent reported in the response/log is taken from the first expansion.
if len(expandedQueries) > 0 {
detectedIntent = string(expandedQueries[0].Intent)
}
}
if len(expandedQueries) == 0 {
// Fallback to just the original query
expandedQueries = []expansion.ExpandedQuery{
{Query: query, Weight: 1.0, Source: "original"},
}
}
// Try vector search first if available
var vectorSearchFailed bool
if s.vectorClient != nil && s.vectorClient.IsConnected() {
// The filter is built with an empty second argument; project scoping is
// applied afterwards by ExtractObservationIDs.
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
// Search with each expanded query and merge results
// Pre-allocate with estimated capacity to avoid repeated reallocation
estimatedCapacity := len(expandedQueries) * limit * 2
allVectorResults := make([]sqlitevec.QueryResult, 0, estimatedCapacity)
// NOTE(review): queryWeights is written below but never read in this
// function — confirm whether it is still needed.
queryWeights := make(map[string]float64, len(expandedQueries))
var vectorErrors int
for _, eq := range expandedQueries {
// Over-fetch (limit*2) because thresholding and project filtering
// discard some hits later.
vectorResults, vecErr := s.vectorClient.Query(ctx, eq.Query, limit*2, where)
if vecErr != nil {
vectorErrors++
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
} else if len(vectorResults) > 0 {
// Apply weight to similarity scores before merging
for i := range vectorResults {
vectorResults[i].Similarity *= eq.Weight
}
allVectorResults = append(allVectorResults, vectorResults...)
queryWeights[eq.Query] = eq.Weight
}
}
// Track if vector search had issues
// Only a total failure (every expanded query errored) is reported as a
// vector failure; partial failures proceed with whatever was returned.
if vectorErrors > 0 && vectorErrors == len(expandedQueries) {
vectorSearchFailed = true
log.Warn().Int("errors", vectorErrors).Str("project", project).Msg("All vector queries failed, falling back to FTS")
}
if len(allVectorResults) > 0 {
// Filter by relevance threshold before extracting IDs
// Use a slightly lower threshold for expanded queries
effectiveThreshold := threshold * 0.9 // Allow slightly lower scores for expanded queries
filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0)
// Build similarity map for filtered results (keeping highest weighted score per observation)
for _, vr := range filteredResults {
// Results whose "sqlite_id" metadata is missing or not a float64 are
// skipped here and therefore get no similarity score.
if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok {
id := int64(sqliteID)
// Keep the highest score for each observation
if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing {
similarityScores[id] = vr.Similarity
}
}
}
// Extract observation IDs with project/scope filtering using shared helper
obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project)
if len(obsIDs) > 0 {
// Fetch full observations from SQLite
observations, err = s.observationStore.GetObservationsByIDs(ctx, obsIDs, "date_desc", limit)
// A fetch error is not reported to the client: usedVector stays false
// and the FTS fallback below takes over.
if err == nil {
usedVector = true
}
}
}
}
// Fall back to FTS if vector search not available, failed, or returned no results
// NOTE(review): neither usedVector nor similarityScores is reset when this
// fallback runs. If the vector fetch succeeded with zero rows, usedVector
// stays true for the reranker gate and trackSearchQuery; and any scores
// collected above still trigger the similarity sort further down for
// FTS-sourced observations — confirm this is intended.
if !usedVector || len(observations) == 0 {
if vectorSearchFailed {
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
}
observations, err = s.observationStore.SearchObservationsFTS(ctx, query, project, limit)
if err != nil {
// FTS might fail if query has special chars, try without
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
observations, err = s.observationStore.GetRecentObservations(ctx, project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
}
// Fast staleness filter - NO verification (that's too slow for interactive use)
// Just check mtimes and exclude obviously stale observations
var staleCount int
freshObservations := make([]*models.Observation, 0, len(observations))
for _, obs := range observations {
// Observations without recorded mtimes, or requests without cwd, are
// always treated as fresh.
if len(obs.FileMtimes) > 0 && cwd != "" {
var paths []string
for path := range obs.FileMtimes {
paths = append(paths, path)
}
currentMtimes := sdk.GetFileMtimes(paths, cwd)
if obs.CheckStaleness(currentMtimes) {
// Stale - exclude but don't verify (too slow)
// Queue for background verification instead
staleCount++
s.queueStaleVerification(obs.ID, cwd)
continue
}
}
freshObservations = append(freshObservations, obs)
}
// Apply cross-encoder reranking if available
var reranked bool
if s.reranker != nil && len(freshObservations) > 0 && usedVector {
// Build candidates from observations with their bi-encoder scores
candidates := make([]reranking.Candidate, len(freshObservations))
for i, obs := range freshObservations {
// Use strings.Builder for efficient concatenation
// Candidate text is "<title> <narrative>" when a narrative exists,
// otherwise the title alone.
var content string
if obs.Narrative.Valid && obs.Narrative.String != "" {
var sb strings.Builder
sb.Grow(len(obs.Title.String) + 1 + len(obs.Narrative.String))
sb.WriteString(obs.Title.String)
sb.WriteByte(' ')
sb.WriteString(obs.Narrative.String)
content = sb.String()
} else {
content = obs.Title.String
}
candidates[i] = reranking.Candidate{
ID: strconv.FormatInt(obs.ID, 10), // Faster than fmt.Sprintf
Content: content,
Score: similarityScores[obs.ID],
Metadata: map[string]any{"obs_idx": i},
}
}
// Rerank using cross-encoder - use pure mode or combined scores
var rerankResults []reranking.RerankResult
var rerankErr error
if s.config.RerankingPureMode {
rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults)
} else {
rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults)
}
if rerankErr != nil {
// Reranking is best-effort: on error the pre-rerank list is kept.
log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order")
} else if len(rerankResults) > 0 {
// Update similarity scores with reranked scores
// CombinedScore is stored in both reranking modes.
for _, rr := range rerankResults {
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
similarityScores[id] = rr.CombinedScore
}
}
// Reorder observations based on rerank results
// Observations absent from rerankResults are dropped from the list here.
reorderedObs := make([]*models.Observation, 0, len(rerankResults))
obsMap := make(map[int64]*models.Observation)
for _, obs := range freshObservations {
obsMap[obs.ID] = obs
}
for _, rr := range rerankResults {
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
if obs, ok := obsMap[id]; ok {
reorderedObs = append(reorderedObs, obs)
}
}
}
freshObservations = reorderedObs
reranked = true
log.Debug().
Int("candidates", len(candidates)).
Int("returned", len(rerankResults)).
Msg("Cross-encoder reranking complete")
}
}
// Cluster similar observations to remove duplicates
clusteredObservations := clusterObservations(freshObservations, 0.4)
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
// Sort by similarity score (highest first) if we have scores and didn't rerank
// NOTE(review): sort.Slice is not guaranteed to be stable, so observations
// with equal (or missing, i.e. zero) scores may lose their retrieval order.
if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked {
sort.Slice(clusteredObservations, func(i, j int) bool {
scoreI := similarityScores[clusteredObservations[i].ID]
scoreJ := similarityScores[clusteredObservations[j].ID]
return scoreI > scoreJ
})
}
// Apply max results cap if configured
if maxResults > 0 && len(clusteredObservations) > maxResults {
clusteredObservations = clusteredObservations[:maxResults]
}
// Record retrieval stats with staleness metrics
s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0,
int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), true)
// Increment retrieval counts for scoring (async, non-blocking)
if len(clusteredObservations) > 0 {
ids := make([]int64, len(clusteredObservations))
for i, obs := range clusteredObservations {
ids[i] = obs.ID
}
s.incrementRetrievalCounts(ids)
}
log.Info().
Str("project", project).
Str("query", query).
Str("intent", detectedIntent).
Int("expansions", len(expandedQueries)).
Int("found", len(clusteredObservations)).
Int("stale_excluded", staleCount).
Float64("threshold", threshold).
Msg("Prompt-based observation search")
// Build response with similarity scores
// "similarity" is only attached to observations that have a recorded score.
obsWithScores := make([]map[string]any, len(clusteredObservations))
for i, obs := range clusteredObservations {
obsMap := obs.ToMap()
if score, ok := similarityScores[obs.ID]; ok {
obsMap["similarity"] = score
}
obsWithScores[i] = obsMap
}
// Build expansion info for response
expansionInfo := make([]map[string]any, len(expandedQueries))
for i, eq := range expandedQueries {
expansionInfo[i] = map[string]any{
"query": eq.Query,
"weight": eq.Weight,
"source": eq.Source,
}
}
// Track this search for analytics
s.trackSearchQuery(query, project, "observations", len(clusteredObservations), usedVector)
writeJSON(w, map[string]any{
"project": project,
"query": query,
"intent": detectedIntent,
"expansions": expansionInfo,
"observations": obsWithScores,
"threshold": threshold,
"max_results": maxResults,
})
}
// handleFileContext returns observations relevant to specific files being worked on.
// Uses vector similarity search to find observations that mention or relate to the given files.
//
// Query parameters:
//   - project (required): project used for scope filtering of the results.
//   - files (required): comma-separated file paths. Entries are trimmed, blank
//     and duplicate entries are dropped, and at most 20 distinct files are searched.
//   - limit (optional): maximum observations per file, 1-50 (default 10);
//     values outside that range fall back to the default.
//
// The response is a JSON object with "files" (the normalized file list),
// "results" (file path -> list of observations) and "count" (number of files
// that produced at least one observation). When vector search is unavailable
// the same shape is returned with empty results and an "error" field.
//
// An observation matched by several files is reported once, under the first
// such file in request order. Per-file search failures are logged and that
// file is simply omitted from "results".
func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
	const (
		maxFiles     = 20  // upper bound on files searched per request
		defaultLimit = 10  // observations per file when "limit" is absent/invalid
		maxLimit     = 50  // largest accepted "limit"
		maxParallel  = 5   // concurrent vector searches
		minScore     = 0.3 // observations scoring below this are dropped
	)

	// NOTE(review): unlike handleSearchByPrompt and handleContextInject, this
	// handler does not call ValidateProjectName — confirm whether it should.
	project := r.URL.Query().Get("project")
	if project == "" {
		http.Error(w, "project required", http.StatusBadRequest)
		return
	}

	filesParam := r.URL.Query().Get("files")
	if filesParam == "" {
		http.Error(w, "files required", http.StatusBadRequest)
		return
	}

	// Normalize the comma-separated list up front. strings.Split never returns
	// an empty slice, so blank entries must be dropped before checking that at
	// least one usable path was supplied; doing it here also keeps blanks and
	// duplicates from consuming the maxFiles budget.
	rawFiles := strings.Split(filesParam, ",")
	files := make([]string, 0, len(rawFiles))
	seenFiles := make(map[string]struct{}, len(rawFiles))
	for _, raw := range rawFiles {
		file := strings.TrimSpace(raw)
		if file == "" {
			continue
		}
		if _, dup := seenFiles[file]; dup {
			continue
		}
		seenFiles[file] = struct{}{}
		files = append(files, file)
		if len(files) == maxFiles {
			break
		}
	}
	if len(files) == 0 {
		http.Error(w, "at least one file required", http.StatusBadRequest)
		return
	}

	// Get limit parameter (default 10 per file)
	limit := defaultLimit
	if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
		if parsed, err := strconv.Atoi(limitStr); err == nil && parsed > 0 && parsed <= maxLimit {
			limit = parsed
		}
	}

	// Add request-scoped timeout to prevent indefinite blocking on DB operations
	ctx, ctxCancel := context.WithTimeout(r.Context(), 15*time.Second)
	defer ctxCancel()

	// Check if vector search is available
	if s.vectorClient == nil || !s.vectorClient.IsConnected() {
		writeJSON(w, map[string]any{
			"files":   files,
			"results": map[string]any{},
			"count":   0,
			"error":   "vector search not available",
		})
		return
	}

	// One slot per file: every goroutine writes only its own index, so no
	// locking is needed, and reading the slots in order after wg.Wait makes the
	// cross-file deduplication deterministic.
	type fileResult struct {
		results []map[string]any
		obsIDs  []int64 // obsIDs[i] is the observation ID of results[i]
	}
	perFile := make([]fileResult, len(files))
	sem := make(chan struct{}, maxParallel)
	var wg sync.WaitGroup

	for i, file := range files {
		wg.Add(1)
		go func(idx int, file string) {
			defer wg.Done()

			// Wait for a search slot, but give up once the request context is
			// done instead of queueing work that can no longer succeed.
			select {
			case sem <- struct{}{}:
				defer func() { <-sem }()
			case <-ctx.Done():
				log.Debug().Err(ctx.Err()).Str("file", file).Msg("File context search skipped, request context done")
				return
			}

			// Build search query from file path
			query := buildFileQuery(file)
			where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
			vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
			if vecErr != nil {
				log.Warn().Err(vecErr).Str("file", file).Msg("Vector search failed for file context")
				return
			}

			// Extract observation IDs from vector results
			obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
			if len(obsIDs) == 0 {
				return
			}

			// Fetch observations
			observations, err := s.observationStore.GetObservationsByIDs(ctx, obsIDs, "score_desc", limit*2)
			if err != nil {
				log.Warn().Err(err).Str("file", file).Msg("Failed to fetch observations for file context")
				return
			}

			// Build an ID -> best similarity map in a single pass over the
			// vector results, and the average similarity used as a fallback.
			scoreMap := make(map[int64]float64, len(vectorResults))
			var avgScore float64
			for _, vr := range vectorResults {
				avgScore += vr.Similarity
				// Vector result IDs look like obs_{id}_{field}; anything else
				// contributes to the average only.
				if !strings.HasPrefix(vr.ID, "obs_") {
					continue
				}
				idStr, _, _ := strings.Cut(vr.ID[len("obs_"):], "_")
				id, parseErr := strconv.ParseInt(idStr, 10, 64)
				if parseErr != nil {
					continue
				}
				// Keep highest score for each observation
				if existing, exists := scoreMap[id]; !exists || vr.Similarity > existing {
					scoreMap[id] = vr.Similarity
				}
			}
			if len(vectorResults) > 0 {
				avgScore /= float64(len(vectorResults))
			}

			fileResults := make([]map[string]any, 0, limit)
			usedIDs := make([]int64, 0, limit)
			for _, obs := range observations {
				// Check project scope
				if obs.Scope == "project" && obs.Project != project {
					continue
				}
				score, found := scoreMap[obs.ID]
				if !found {
					// Use average score as fallback
					score = avgScore
				}
				// Only include if score is above threshold
				if score < minScore {
					continue
				}
				fileResults = append(fileResults, map[string]any{
					"id":        obs.ID,
					"title":     obs.Title.String,
					"type":      obs.Type,
					"narrative": obs.Narrative.String,
					"facts":     obs.Facts,
					"score":     score,
				})
				usedIDs = append(usedIDs, obs.ID)
				if len(fileResults) >= limit {
					break
				}
			}

			if len(fileResults) > 0 {
				perFile[idx] = fileResult{results: fileResults, obsIDs: usedIDs}
			}
		}(i, file)
	}
	wg.Wait()

	// Collect results in request order and drop observations already reported
	// for an earlier file.
	allResults := make(map[string]any)
	seenObservationIDs := make(map[int64]bool)
	for idx, res := range perFile {
		if len(res.results) == 0 {
			continue
		}
		dedupedResults := make([]map[string]any, 0, len(res.results))
		for i, item := range res.results {
			obsID := res.obsIDs[i]
			if seenObservationIDs[obsID] {
				continue
			}
			seenObservationIDs[obsID] = true
			dedupedResults = append(dedupedResults, item)
		}
		if len(dedupedResults) > 0 {
			allResults[files[idx]] = dedupedResults
		}
	}

	writeJSON(w, map[string]any{
		"files":   files,
		"results": allResults,
		"count":   len(allResults),
	})
}
// buildFileQuery extracts meaningful search terms from a file path.
//
// The path is split on "/" and each segment is handled on its own:
//   - empty segments (leading, trailing or doubled slashes) and the relative
//     markers "." and ".." are dropped;
//   - common structural directory names (src, lib, internal, ...) are dropped,
//     compared case-insensitively;
//   - everything after the last "." is removed, unless the dot is the first
//     character (so dotfiles keep their name);
//   - camelCase/PascalCase, snake_case and kebab-case are split into words.
//
// The resulting words are joined with single spaces, so the query never
// contains leading, trailing or repeated spaces. An input with no meaningful
// segment yields the empty string.
func buildFileQuery(filePath string) string {
	segments := strings.Split(strings.TrimPrefix(filePath, "/"), "/")
	terms := make([]string, 0, len(segments))
	for _, segment := range segments {
		switch strings.ToLower(segment) {
		case "", ".", "..":
			// Nothing searchable: these would only add blank or noise tokens.
			continue
		case "src", "lib", "internal", "pkg", "cmd", "api", "app", "test", "tests", "spec", "specs":
			// Skip common directory names that aren't meaningful
			continue
		}

		// Remove file extension
		if idx := strings.LastIndex(segment, "."); idx > 0 {
			segment = segment[:idx]
		}
		// Convert camelCase/PascalCase to spaces
		segment = splitCamelCase(segment)
		// Convert snake_case to spaces
		segment = strings.ReplaceAll(segment, "_", " ")
		// Convert kebab-case to spaces
		segment = strings.ReplaceAll(segment, "-", " ")

		// Fields collapses the runs of spaces that adjacent separators
		// (e.g. "foo__bar" or "foo_Bar") would otherwise leave behind.
		terms = append(terms, strings.Fields(segment)...)
	}
	return strings.Join(terms, " ")
}
// splitCamelCase splits camelCase or PascalCase into separate words by
// inserting a single space at each word boundary.
//
// A boundary is placed before an ASCII uppercase letter when either
//   - the previous character is not an ASCII uppercase letter
//     ("fileContext" -> "file Context"), or
//   - the previous character is uppercase and the next one is an ASCII
//     lowercase letter, i.e. the letter starts a word right after an acronym
//     ("HTTPServer" -> "HTTP Server").
//
// Runs of capitals are therefore kept together ("README", "parseURL" ->
// "parse URL") instead of being broken into single letters. Only ASCII
// letters are considered; all other characters are copied through unchanged.
func splitCamelCase(s string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }

	// Work on runes so the look-behind/look-ahead below index whole characters.
	runes := []rune(s)
	var result strings.Builder
	result.Grow(len(s) + 4)
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			afterNonUpper := !isUpper(runes[i-1])
			endsAcronym := i+1 < len(runes) && isLower(runes[i+1])
			if afterNonUpper || endsAcronym {
				result.WriteRune(' ')
			}
		}
		result.WriteRune(r)
	}
	return result.String()
}
// handleContextInject serves the observations injected into a session when it
// starts. It sits on the critical startup path, so it never verifies
// observations synchronously: anything whose tracked files look modified is
// excluded from the response and handed to the background verifier.
//
// Query parameters: project (required, validated) and cwd (optional, "/" when
// omitted). The JSON response carries project, observations, full_count,
// stale_excluded and duplicates_removed.
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
	// Every store call below shares one 15s budget so a contended DB cannot
	// hold the request open indefinitely.
	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
	defer cancel()

	params := r.URL.Query()

	project := params.Get("project")
	if project == "" {
		http.Error(w, "project required", http.StatusBadRequest)
		return
	}
	// Reject names that could be used for path traversal.
	if nameErr := ValidateProjectName(project); nameErr != nil {
		http.Error(w, nameErr.Error(), http.StatusBadRequest)
		return
	}

	workDir := params.Get("cwd")
	if workDir == "" {
		workDir = "/"
	}

	// How many recent observations to load; non-positive config values fall
	// back to the package default.
	fetchLimit := s.config.ContextObservations
	if fetchLimit <= 0 {
		fetchLimit = DefaultContextLimit
	}

	// How many of them the client should render in full detail.
	detailCount := s.config.ContextFullCount
	if detailCount <= 0 {
		detailCount = 25
	}

	recent, fetchErr := s.observationStore.GetRecentObservations(ctx, project, fetchLimit)
	if fetchErr != nil {
		http.Error(w, fetchErr.Error(), http.StatusInternalServerError)
		return
	}

	// Cheap mtime-based staleness pass.
	excluded := 0
	fresh := make([]*models.Observation, 0, len(recent))
	for _, obs := range recent {
		// Nothing tracked means nothing can be stale.
		if len(obs.FileMtimes) == 0 {
			fresh = append(fresh, obs)
			continue
		}

		tracked := make([]string, 0, len(obs.FileMtimes))
		for p := range obs.FileMtimes {
			tracked = append(tracked, p)
		}

		if !obs.CheckStaleness(sdk.GetFileMtimes(tracked, workDir)) {
			fresh = append(fresh, obs)
			continue
		}

		// Stale: leave it out now, let the background worker decide later.
		excluded++
		s.queueStaleVerification(obs.ID, workDir)
	}

	// Collapse near-duplicate observations.
	clustered := clusterObservations(fresh, 0.4)
	dupes := len(fresh) - len(clustered)

	// Retrieval statistics, including the staleness numbers.
	s.recordRetrievalStatsExtended(project, int64(len(clustered)), 0, 0,
		int64(excluded), int64(len(fresh)), int64(dupes), false)

	// Bump retrieval counters used for scoring.
	if len(clustered) > 0 {
		returnedIDs := make([]int64, 0, len(clustered))
		for _, obs := range clustered {
			returnedIDs = append(returnedIDs, obs.ID)
		}
		s.incrementRetrievalCounts(returnedIDs)
	}

	log.Info().
		Str("project", project).
		Int("total", len(recent)).
		Int("fresh", len(fresh)).
		Int("clustered", len(clustered)).
		Int("duplicates", dupes).
		Int("stale_excluded", excluded).
		Msg("Context injection with clustering")

	payload := map[string]any{
		"project":            project,
		"observations":       clustered,
		"full_count":         detailCount,
		"stale_excluded":     excluded,
		"duplicates_removed": dupes,
	}
	writeJSON(w, payload)
}
// handleContextCount reports how many observations exist for a project.
//
// The "project" query parameter is mandatory (400 when missing). The count is
// obtained from getCachedObservationCount under a 15s request-scoped timeout;
// a lookup error is returned as a 500. On success the response is a JSON
// object with the keys "project" and "count".
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
	// Bound the lookup so a blocked DB operation cannot stall the request.
	reqCtx, release := context.WithTimeout(r.Context(), 15*time.Second)
	defer release()

	name := r.URL.Query().Get("project")
	if len(name) == 0 {
		http.Error(w, "project required", http.StatusBadRequest)
		return
	}

	total, countErr := s.getCachedObservationCount(reqCtx, name)
	if countErr != nil {
		http.Error(w, countErr.Error(), http.StatusInternalServerError)
		return
	}

	payload := map[string]any{
		"project": name,
		"count":   total,
	}
	writeJSON(w, payload)
}