Files
claude-mnemonic/internal/embedding/model.go
T
lukaszraczylo 29d57857ff fix: prevent MCP server hanging by adding concurrency, timeouts, and context propagation (#45)
Root cause: synchronous MCP request processing combined with missing
context propagation to the embedding layer caused indefinite hangs when
ONNX inference was slow or the database was contended.

Changes:
- MCP server: dispatch each request in its own goroutine with semaphore
  (cap 10) and WaitGroup for clean shutdown drain
- Embedding: add context-aware mutex acquisition (acquireMutex) so
  callers can bail out instead of blocking forever on a stuck ONNX model
- Vector client: propagate context through getOrComputeEmbedding and
  replace bare RLock() calls with context-aware acquireRLockWithContext
- Worker handlers: add 15s request-scoped timeouts to all search/context
  handlers (handleSearchByPrompt, handleContextInject, handleFileContext,
  handleContextCount, handleGetObservations/Summaries/Prompts)
- Worker HTTP server: set WriteTimeout=60s (was 0); SSE endpoint extends
  deadline per-request via http.ResponseController

Fixes #45
2026-05-26 14:29:34 +01:00

162 lines
4.6 KiB
Go

// Package embedding provides text embedding generation with swappable models.
package embedding
import (
	"context"
	"fmt"
	"sort"
	"sync"
)
// PoolingStrategy names the method used to reduce per-token embeddings to a
// single sentence embedding.
type PoolingStrategy string

// Recognized pooling strategies.
const (
	// PoolingNone indicates the model already emits sentence embeddings, so
	// no pooling step is needed.
	PoolingNone PoolingStrategy = "none"

	// PoolingMean averages the token embeddings, weighted by the attention mask.
	PoolingMean PoolingStrategy = "mean"

	// PoolingCLS keeps only the embedding of the [CLS] token.
	PoolingCLS PoolingStrategy = "cls"
)

// ONNXConfig holds the ONNX-specific settings of a model, so that different
// models can declare their own tensor names and pooling needs.
type ONNXConfig struct {
	// Pooling selects how token embeddings are combined into one vector.
	Pooling PoolingStrategy
	// InputNames lists the names of the model's input tensors.
	InputNames []string
	// OutputNames lists the names of the model's output tensors.
	OutputNames []string
	// HiddenSize is presumably the per-token embedding width — confirm
	// against the model definitions that populate it.
	HiddenSize int
}
// EmbeddingModel is the contract every text embedding backend satisfies;
// it is the abstraction that makes models swappable.
type EmbeddingModel interface {
	// Name reports the human-readable model name (e.g., "bge-small-en-v1.5").
	Name() string
	// Version reports a short version string used for storage (e.g., "bge-v1.5").
	Version() string
	// Dimensions reports the size of the embedding vectors the model produces.
	Dimensions() int

	// Embed generates the embedding of a single text.
	Embed(text string) ([]float32, error)
	// EmbedBatch generates embeddings for several texts in one call.
	EmbedBatch(texts []string) ([][]float32, error)

	// EmbedWithContext is the cancellable form of Embed. The context bounds
	// how long the call waits for the model lock: if ctx is cancelled during
	// that wait, the call returns ctx.Err() immediately.
	EmbedWithContext(ctx context.Context, text string) ([]float32, error)
	// EmbedBatchWithContext is the context-aware counterpart of EmbedBatch.
	EmbedBatchWithContext(ctx context.Context, texts []string) ([][]float32, error)

	// Close releases the resources held by the model.
	Close() error
}
// ONNXConfigurer is an optional interface through which a model can expose
// its ONNX configuration for introspection. Because it is optional, code that
// wants the configuration must check for it (for example with a type
// assertion) rather than assume every model provides it.
type ONNXConfigurer interface {
	// ONNXConfig reports the model's ONNX configuration.
	ONNXConfig() ONNXConfig
}
// ModelMetadata is the descriptive record of an embedding model, used for
// UI and configuration. The JSON tags define its serialized form.
type ModelMetadata struct {
	// Name is the model's name, e.g. for display.
	Name string `json:"name"`
	// Version is the short version string the registry keys models by.
	Version string `json:"version"`
	// Description is free-form text describing the model.
	Description string `json:"description"`
	// Dimensions is the size of the vectors the model produces.
	Dimensions int `json:"dimensions"`
	// Default requests that this model become the registry default when it
	// is registered.
	Default bool `json:"default"`
}
// ModelFactory creates a new instance of an embedding model.
// ModelRegistry.Get invokes the factory once per call.
type ModelFactory func() (EmbeddingModel, error)

// ModelRegistry maps model versions to their factories and metadata.
//
// All methods are guarded by an internal RWMutex and may be called
// concurrently. The zero value is ready to use; NewModelRegistry remains the
// conventional constructor.
type ModelRegistry struct {
	models       map[string]ModelFactory  // factory per version
	metadata     map[string]ModelMetadata // metadata as registered, per version
	defaultModel string                   // version of the latest Default registration
	mu           sync.RWMutex             // guards the fields above
}

// NewModelRegistry creates a new, empty model registry.
func NewModelRegistry() *ModelRegistry {
	return &ModelRegistry{
		models:   make(map[string]ModelFactory),
		metadata: make(map[string]ModelMetadata),
	}
}

// Register adds the model identified by meta.Version, replacing any earlier
// registration of the same version.
//
// If meta.Default is set, the model becomes the registry default,
// superseding any previous default.
func (r *ModelRegistry) Register(meta ModelMetadata, factory ModelFactory) {
	r.mu.Lock()
	defer r.mu.Unlock()
	// Allocate lazily so a zero-value ModelRegistry does not panic writing
	// to a nil map.
	if r.models == nil {
		r.models = make(map[string]ModelFactory)
	}
	if r.metadata == nil {
		r.metadata = make(map[string]ModelMetadata)
	}
	r.models[meta.Version] = factory
	r.metadata[meta.Version] = meta
	if meta.Default {
		r.defaultModel = meta.Version
	}
}

// Get creates a new instance of the model with the given version.
//
// It returns an error if the version is unknown or was registered with a nil
// factory; otherwise it returns whatever the factory returns. The factory
// runs after the lock is released, so slow model construction does not block
// other registry operations.
func (r *ModelRegistry) Get(version string) (EmbeddingModel, error) {
	r.mu.RLock()
	factory, ok := r.models[version]
	r.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("unknown model version: %s", version)
	}
	if factory == nil {
		// Calling a nil func would panic; report it as an error instead.
		return nil, fmt.Errorf("model version %s has no factory", version)
	}
	return factory()
}

// Default returns the default model version, or "" if no registered model
// requested to be the default.
func (r *ModelRegistry) Default() string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.defaultModel
}

// List returns metadata for all registered models, sorted by Version.
//
// An entry whose Default flag was superseded by a later default registration
// is reported with Default=false. At most one entry is therefore marked as
// default, and any marked entry is the one Default returns.
func (r *ModelRegistry) List() []ModelMetadata {
	r.mu.RLock()
	defer r.mu.RUnlock()
	result := make([]ModelMetadata, 0, len(r.metadata))
	for _, meta := range r.metadata {
		meta.Default = meta.Default && meta.Version == r.defaultModel
		result = append(result, meta)
	}
	// Map iteration order is randomized; sort so callers get a stable order.
	sort.Slice(result, func(i, j int) bool { return result[i].Version < result[j].Version })
	return result
}
// DefaultRegistry is the process-wide registry; the package-level helpers
// RegisterModel, GetModel, GetDefaultModel and ListModels all delegate to it.
var DefaultRegistry = NewModelRegistry()

// RegisterModel registers factory under meta.Version in DefaultRegistry.
func RegisterModel(meta ModelMetadata, factory ModelFactory) {
	DefaultRegistry.Register(meta, factory)
}

// GetModel asks DefaultRegistry for a new instance of the given version.
func GetModel(version string) (EmbeddingModel, error) {
	model, err := DefaultRegistry.Get(version)
	return model, err
}

// GetDefaultModel reports the version DefaultRegistry treats as its default.
func GetDefaultModel() string {
	version := DefaultRegistry.Default()
	return version
}

// ListModels reports the metadata of every model in DefaultRegistry.
func ListModels() []ModelMetadata {
	models := DefaultRegistry.List()
	return models
}