Resolves issue #13

- Switched model to bge-small-en-v1.5
- Added lazy re-embedding
- Added model version tracking per vector
- Added conversion of vectors to the new model
This commit is contained in:
2025-12-19 02:00:55 +00:00
parent 8867f13dcc
commit a37649bc69
18 changed files with 1072 additions and 87 deletions
+10
View File
@@ -48,6 +48,9 @@ type Config struct {
Model string `json:"model"`
ClaudeCodePath string `json:"claude_code_path"`
// Embedding settings
EmbeddingModel string `json:"embedding_model"` // e.g., "bge-v1.5"
// Context injection settings
ContextObservations int `json:"context_observations"`
ContextFullCount int `json:"context_full_count"`
@@ -119,6 +122,9 @@ func EnsureAll() error {
return nil
}
// DefaultEmbeddingModel is the embedding model version string that Default
// assigns to Config.EmbeddingModel.
const DefaultEmbeddingModel = "bge-v1.5"
// Default returns a Config with default values.
func Default() *Config {
return &Config{
@@ -126,6 +132,7 @@ func Default() *Config {
DBPath: DBPath(),
MaxConns: 4,
Model: DefaultModel,
EmbeddingModel: DefaultEmbeddingModel,
ContextObservations: 100,
ContextFullCount: 25,
ContextSessionCount: 10,
@@ -166,6 +173,9 @@ func Load() (*Config, error) {
if v, ok := settings["CLAUDE_CODE_PATH"].(string); ok {
cfg.ClaudeCodePath = v
}
if v, ok := settings["CLAUDE_MNEMONIC_EMBEDDING_MODEL"].(string); ok && v != "" {
cfg.EmbeddingModel = v
}
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS"].(float64); ok {
cfg.ContextObservations = int(v)
}
+21
View File
@@ -283,6 +283,27 @@ var Migrations = []Migration{
ON user_prompts(claude_session_id, prompt_number);
`,
},
{
Version: 19,
Name: "vectors_with_model_version",
SQL: `
-- Drop old vectors table (virtual tables cannot be altered)
DROP TABLE IF EXISTS vectors;
-- Recreate vectors table with model_version column
-- Uses bge-small-en-v1.5 embeddings (384 dimensions)
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
doc_id TEXT PRIMARY KEY,
embedding float[384],
sqlite_id INTEGER,
doc_type TEXT,
field_type TEXT,
project TEXT,
scope TEXT,
model_version TEXT
);
`,
},
}
// MigrationManager handles database schema migrations.
+19
View File
@@ -229,6 +229,25 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i
return scanObservationRows(rows)
}
// GetAllObservations loads every row of the observations table in ascending id
// order. It is used by the vector rebuild path, which needs the full set.
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
	const selectAll = `
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
created_at, created_at_epoch
FROM observations
ORDER BY id
`

	result, queryErr := s.store.QueryContext(ctx, selectAll)
	if queryErr != nil {
		return nil, queryErr
	}
	// Release the result set once the rows have been scanned.
	defer result.Close()

	return scanObservationRows(result)
}
// SearchObservationsFTS performs full-text search on observations.
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
if limit <= 0 {
+22
View File
@@ -199,6 +199,28 @@ func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([
return scanPromptWithSessionRows(rows)
}
// GetAllPrompts loads every user prompt in ascending id order, left-joined with
// its session so that project and sdk_session_id are filled in (empty strings
// when no session row matches). It is used by the vector rebuild path.
func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) {
	const selectAll = `
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
COALESCE(up.matched_observations, 0) as matched_observations,
up.created_at, up.created_at_epoch,
COALESCE(s.project, '') as project,
COALESCE(s.sdk_session_id, '') as sdk_session_id
FROM user_prompts up
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
ORDER BY up.id
`

	result, queryErr := s.store.QueryContext(ctx, selectAll)
	if queryErr != nil {
		return nil, queryErr
	}
	// Release the result set once the rows have been scanned.
	defer result.Close()

	return scanPromptWithSessionRows(result)
}
// FindRecentPromptByText finds a prompt with the same text for a session within the last few seconds.
// This is used to detect duplicate hook invocations.
// Returns (promptID, promptNumber, found).
+18
View File
@@ -116,3 +116,21 @@ func (s *SummaryStore) GetAllRecentSummaries(ctx context.Context, limit int) ([]
return scanSummaryRows(rows)
}
// GetAllSummaries loads every row of session_summaries in ascending id order.
// It is used by the vector rebuild path, which needs the full set.
func (s *SummaryStore) GetAllSummaries(ctx context.Context) ([]*models.SessionSummary, error) {
	const selectAll = `
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM session_summaries
ORDER BY id
`

	result, queryErr := s.store.QueryContext(ctx, selectAll)
	if queryErr != nil {
		return nil, queryErr
	}
	// Release the result set once the rows have been scanned.
	defer result.Close()

	return scanSummaryRows(result)
}
+1
View File
@@ -0,0 +1 @@
bge-small-en-v1.5
Binary file not shown.
+2 -16
View File
@@ -1,21 +1,7 @@
{
"version": "1.0",
"truncation": {
"direction": "Right",
"max_length": 128,
"strategy": "LongestFirst",
"stride": 0
},
"padding": {
"strategy": {
"Fixed": 128
},
"direction": "Right",
"pad_to_multiple_of": null,
"pad_id": 0,
"pad_type_id": 0,
"pad_token": "[PAD]"
},
"truncation": null,
"padding": null,
"added_tokens": [
{
"id": 0,
+157
View File
@@ -0,0 +1,157 @@
// Package embedding provides text embedding generation with swappable models.
package embedding
import (
	"fmt"
	"sort"
	"sync"
)
// PoolingStrategy defines how to pool token embeddings into sentence embeddings.
// It is a string type; the supported values are the constants declared below.
type PoolingStrategy string

// Supported pooling strategies.
const (
	// PoolingNone means the model already outputs sentence embeddings directly.
	PoolingNone PoolingStrategy = "none"
	// PoolingMean averages all token embeddings (weighted by attention mask).
	PoolingMean PoolingStrategy = "mean"
	// PoolingCLS uses only the [CLS] token embedding.
	PoolingCLS PoolingStrategy = "cls"
)
// ONNXConfig describes ONNX-specific model configuration.
// This allows different models to specify their tensor names and pooling needs.
type ONNXConfig struct {
	// InputNames are the ONNX input tensor names in order.
	InputNames []string
	// OutputNames are the ONNX output tensor names.
	OutputNames []string
	// Pooling specifies how to convert token embeddings to sentence embeddings.
	// If PoolingNone, the model outputs sentence embeddings directly.
	Pooling PoolingStrategy
	// HiddenSize is the embedding dimension (used for pooling calculations),
	// i.e. the number of float32 values per token or sentence vector.
	HiddenSize int
}
// EmbeddingModel represents a text embedding model.
// Implementations are created through a ModelFactory and looked up by Version.
type EmbeddingModel interface {
	// Name returns the human-readable model name (e.g., "bge-small-en-v1.5").
	Name() string
	// Version returns a short version string for storage (e.g., "bge-v1.5").
	Version() string
	// Dimensions returns the embedding vector size.
	Dimensions() int
	// Embed generates an embedding for a single text.
	// NOTE(review): the returned slice presumably has Dimensions() elements —
	// confirm per implementation.
	Embed(text string) ([]float32, error)
	// EmbedBatch generates embeddings for multiple texts.
	// NOTE(review): results are presumably index-aligned with texts — confirm
	// per implementation.
	EmbedBatch(texts []string) ([][]float32, error)
	// Close releases model resources.
	Close() error
}
// ONNXConfigurer is an optional interface that models can implement
// to expose their ONNX configuration for introspection.
// Callers can discover it with a type assertion on an EmbeddingModel.
type ONNXConfigurer interface {
	// ONNXConfig returns the model's ONNX configuration.
	ONNXConfig() ONNXConfig
}
// ModelMetadata describes an embedding model for UI/config.
// Version doubles as the registry key under which the model is registered.
type ModelMetadata struct {
	Name        string `json:"name"`        // Human-readable name
	Version     string `json:"version"`     // Short ID for DB storage; registry key
	Dimensions  int    `json:"dimensions"`  // Vector size
	Description string `json:"description"` // Brief description
	Default     bool   `json:"default"`     // Is this the default model?
}
// ModelFactory creates a new instance of an embedding model.
// It returns an error when the model cannot be constructed.
type ModelFactory func() (EmbeddingModel, error)
// ModelRegistry provides model lookup by version.
// It maps a model's Version string to its factory and metadata; every method
// takes mu, so a registry may be shared between goroutines.
type ModelRegistry struct {
	mu           sync.RWMutex
	models       map[string]ModelFactory
	metadata     map[string]ModelMetadata
	defaultModel string
}

// NewModelRegistry creates a new, empty model registry.
func NewModelRegistry() *ModelRegistry {
	return &ModelRegistry{
		models:   make(map[string]ModelFactory),
		metadata: make(map[string]ModelMetadata),
	}
}

// Register adds a model factory to the registry under meta.Version, replacing
// any earlier registration for that version. When meta.Default is set, the
// version also becomes the registry default.
func (r *ModelRegistry) Register(meta ModelMetadata, factory ModelFactory) {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.models[meta.Version] = factory
	r.metadata[meta.Version] = meta
	if meta.Default {
		r.defaultModel = meta.Version
	}
}

// Get creates a new instance of the model with the given version.
// It returns an error when the version is not registered or was registered
// without a factory.
func (r *ModelRegistry) Get(version string) (EmbeddingModel, error) {
	r.mu.RLock()
	factory, ok := r.models[version]
	r.mu.RUnlock()

	if !ok {
		return nil, fmt.Errorf("unknown model version: %s", version)
	}
	// Calling a nil func value panics; report it as an error instead.
	if factory == nil {
		return nil, fmt.Errorf("model version %s has no factory", version)
	}
	// The factory runs outside the lock so slow model construction does not
	// block other registry users.
	return factory()
}

// Default returns the default model version, or "" when none was registered
// with Default set.
func (r *ModelRegistry) Default() string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.defaultModel
}

// List returns metadata for all registered models, sorted by Version so that
// callers (for example JSON responses) see a stable order instead of Go's
// randomized map iteration order.
func (r *ModelRegistry) List() []ModelMetadata {
	r.mu.RLock()
	defer r.mu.RUnlock()

	result := make([]ModelMetadata, 0, len(r.metadata))
	for _, meta := range r.metadata {
		result = append(result, meta)
	}
	// Versions are map keys and therefore unique, so this order is total.
	sort.Slice(result, func(i, j int) bool {
		return result[i].Version < result[j].Version
	})
	return result
}
// DefaultRegistry is the global model registry with all available models.
// The package-level RegisterModel, GetModel, GetDefaultModel and ListModels
// helpers all operate on it.
var DefaultRegistry = NewModelRegistry()
// RegisterModel adds a model to the default registry.
// It forwards meta and factory unchanged to DefaultRegistry.Register.
func RegisterModel(meta ModelMetadata, factory ModelFactory) {
	DefaultRegistry.Register(meta, factory)
}
// GetModel creates a model instance from the default registry.
// The result and error are exactly those of DefaultRegistry.Get(version).
func GetModel(version string) (EmbeddingModel, error) {
	return DefaultRegistry.Get(version)
}
// GetDefaultModel returns the default model version from the default registry.
// Note that it returns the version string, not a model instance.
func GetDefaultModel() string {
	return DefaultRegistry.Default()
}
// ListModels returns metadata for all models in the default registry.
// It is a thin wrapper around DefaultRegistry.List.
func ListModels() []ModelMetadata {
	return DefaultRegistry.List()
}
+277 -56
View File
@@ -1,4 +1,4 @@
// Package embedding provides text embedding generation using all-MiniLM-L6-v2.
// Package embedding provides text embedding generation with swappable models.
package embedding
import (
@@ -15,19 +15,50 @@ import (
ort "github.com/yalue/onnxruntime_go"
)
// EmbeddingDim is the dimension of embeddings produced by all-MiniLM-L6-v2.
// EmbeddingDim is the dimension of embeddings produced by the current model.
// Both all-MiniLM-L6-v2 and bge-small-en-v1.5 produce 384-dimensional embeddings.
const EmbeddingDim = 384
// Service provides thread-safe text embedding generation.
type Service struct {
// Model version constants.
const (
	// BGEModelVersion is the short version string for bge-small-en-v1.5.
	// It is what bgeModel.Version returns and the key used to register the model.
	BGEModelVersion = "bge-v1.5"
	// BGEModelName is the human-readable name for bge-small-en-v1.5.
	BGEModelName = "bge-small-en-v1.5"
	// DefaultModelVersion is the model version NewService uses, and the one
	// NewServiceWithModel falls back to when given an empty version.
	DefaultModelVersion = BGEModelVersion
)
// MaxSequenceLength is the maximum token sequence length for the model.
// computeBatch caps the per-batch sequence length at this value and truncates
// longer encodings to fit.
const MaxSequenceLength = 512
// bgeONNXConfig defines the ONNX configuration for BGE models.
// The session is fed the three input tensors named below and reads the
// token-level last_hidden_state output, which is then mean pooled over
// HiddenSize (EmbeddingDim) features.
//
// NOTE(review): the published bge-small-en-v1.5 usage examples appear to take
// the [CLS] token embedding and L2-normalize it rather than mean-pool — confirm
// that PoolingMean is intentional. Changing it would alter every embedding
// already stored under this model version.
var bgeONNXConfig = ONNXConfig{
	InputNames:  []string{"input_ids", "attention_mask", "token_type_ids"},
	OutputNames: []string{"last_hidden_state"},
	Pooling:     PoolingMean,
	HiddenSize:  EmbeddingDim,
}
// bgeModel is the ONNX-based embedding model implementation.
// Currently supports bge-small-en-v1.5 (previously all-MiniLM-L6-v2).
type bgeModel struct {
tk *tokenizer.Tokenizer
session *ort.DynamicAdvancedSession
mu sync.Mutex
libDir string // temp directory containing extracted libraries
libDir string // temp directory containing extracted libraries
config ONNXConfig // ONNX configuration for this model
}
// NewService creates a new embedding service using bundled ONNX runtime and model.
func NewService() (*Service, error) {
// Compile-time check that bgeModel implements EmbeddingModel
var _ EmbeddingModel = (*bgeModel)(nil)
// Compile-time check that bgeModel implements ONNXConfigurer
var _ ONNXConfigurer = (*bgeModel)(nil)
// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model.
func newBGEModel() (EmbeddingModel, error) {
// Extract ONNX runtime library to temp directory
libDir, err := extractONNXLibrary()
if err != nil {
@@ -49,22 +80,41 @@ func NewService() (*Service, error) {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
// Create ONNX session with embedded model
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
outputNames := []string{"sentence_embedding"}
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, inputNames, outputNames, nil)
// Create ONNX session using model-specific configuration
config := bgeONNXConfig
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil)
if err != nil {
return nil, fmt.Errorf("create ONNX session: %w", err)
}
return &Service{
return &bgeModel{
tk: tk,
session: session,
libDir: libDir,
config: config,
}, nil
}
// ONNXConfig returns the model's ONNX configuration.
// It is the config value stored on the model when it was constructed.
func (m *bgeModel) ONNXConfig() ONNXConfig {
	return m.config
}
// Name returns the human-readable model name, which is always BGEModelName.
func (m *bgeModel) Name() string {
	return BGEModelName
}
// Version returns the short version string for storage, which is always
// BGEModelVersion.
func (m *bgeModel) Version() string {
	return BGEModelVersion
}
// Dimensions returns the embedding vector size, which is always EmbeddingDim.
func (m *bgeModel) Dimensions() int {
	return EmbeddingDim
}
// extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory.
// Uses content hash to avoid re-extracting if already present.
func extractONNXLibrary() (string, error) {
@@ -107,15 +157,15 @@ func extractONNXLibrary() (string, error) {
// Embed generates an embedding for a single text.
// Returns a 384-dimensional float32 vector.
func (s *Service) Embed(text string) ([]float32, error) {
s.mu.Lock()
defer s.mu.Unlock()
func (m *bgeModel) Embed(text string) ([]float32, error) {
m.mu.Lock()
defer m.mu.Unlock()
if text == "" {
return make([]float32, EmbeddingDim), nil
}
results, err := s.computeBatch([]string{text})
results, err := m.computeBatch([]string{text})
if err != nil {
return nil, err
}
@@ -127,13 +177,13 @@ func (s *Service) Embed(text string) ([]float32, error) {
// EmbedBatch generates embeddings for multiple texts.
// Returns slice of 384-dimensional float32 vectors.
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
if len(texts) == 0 {
return nil, nil
}
s.mu.Lock()
defer s.mu.Unlock()
m.mu.Lock()
defer m.mu.Unlock()
// Filter out empty texts and track indices
nonEmpty := make([]string, 0, len(texts))
@@ -155,7 +205,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
}
// Compute embeddings for non-empty texts
embeddings, err := s.computeBatch(nonEmpty)
embeddings, err := m.computeBatch(nonEmpty)
if err != nil {
return nil, fmt.Errorf("compute batch embeddings: %w", err)
}
@@ -173,7 +223,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
}
// computeBatch runs inference on a batch of texts. Must be called with lock held.
func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
if len(sentences) == 0 {
return nil, nil
}
@@ -184,31 +234,57 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent))
}
encodings, err := s.tk.EncodeBatch(inputBatch, true)
encodings, err := m.tk.EncodeBatch(inputBatch, true)
if err != nil {
return nil, fmt.Errorf("tokenize: %w", err)
}
batchSize := len(encodings)
seqLength := len(encodings[0].Ids)
hiddenSize := EmbeddingDim
hiddenSize := m.config.HiddenSize
// Find max sequence length across all encodings (tokenizer may not pad uniformly)
// Also enforce MaxSequenceLength to prevent model errors
seqLength := 0
for _, enc := range encodings {
if len(enc.Ids) > seqLength {
seqLength = len(enc.Ids)
}
}
// Truncate to max model sequence length
if seqLength > MaxSequenceLength {
seqLength = MaxSequenceLength
}
inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
// Create input tensors
// Create input tensors (pre-filled with zeros for padding)
inputIdsData := make([]int64, batchSize*seqLength)
attentionMaskData := make([]int64, batchSize*seqLength)
tokenTypeIdsData := make([]int64, batchSize*seqLength)
for b := 0; b < batchSize; b++ {
for i, id := range encodings[b].Ids {
inputIdsData[b*seqLength+i] = int64(id)
// Copy actual token data (rest remains 0 as padding)
// Truncate to seqLength to handle long inputs
copyLen := len(encodings[b].Ids)
if copyLen > seqLength {
copyLen = seqLength
}
for i, mask := range encodings[b].AttentionMask {
attentionMaskData[b*seqLength+i] = int64(mask)
for i := 0; i < copyLen; i++ {
inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
}
for i, typeId := range encodings[b].TypeIds {
tokenTypeIdsData[b*seqLength+i] = int64(typeId)
copyLen = len(encodings[b].AttentionMask)
if copyLen > seqLength {
copyLen = seqLength
}
for i := 0; i < copyLen; i++ {
attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
}
copyLen = len(encodings[b].TypeIds)
if copyLen > seqLength {
copyLen = seqLength
}
for i := 0; i < copyLen; i++ {
tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
}
}
@@ -230,51 +306,131 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
}
defer tokenTypeIdsTensor.Destroy()
sentenceOutputShape := ort.NewShape(int64(batchSize), int64(hiddenSize))
sentenceOutputTensor, err := ort.NewEmptyTensor[float32](sentenceOutputShape)
// Create output tensor based on pooling strategy
var outputShape ort.Shape
switch m.config.Pooling {
case PoolingNone:
// Direct sentence embedding output: [batch, hidden]
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
case PoolingMean, PoolingCLS:
// Token-level output: [batch, seq_len, hidden]
outputShape = ort.NewShape(int64(batchSize), int64(seqLength), int64(hiddenSize))
default:
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
}
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
if err != nil {
return nil, fmt.Errorf("create output tensor: %w", err)
}
defer sentenceOutputTensor.Destroy()
defer outputTensor.Destroy()
// Run inference
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
outputTensors := []ort.Value{sentenceOutputTensor}
outputTensors := []ort.Value{outputTensor}
if err := s.session.Run(inputTensors, outputTensors); err != nil {
if err := m.session.Run(inputTensors, outputTensors); err != nil {
return nil, fmt.Errorf("run inference: %w", err)
}
// Extract results
flatOutput := sentenceOutputTensor.GetData()
expectedSize := batchSize * hiddenSize
if len(flatOutput) != expectedSize {
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
}
// Extract and pool results based on strategy
flatOutput := outputTensor.GetData()
switch m.config.Pooling {
case PoolingNone:
// Direct output, no pooling needed
expectedSize := batchSize * hiddenSize
if len(flatOutput) != expectedSize {
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
}
results := make([][]float32, batchSize)
for i := 0; i < batchSize; i++ {
start := i * hiddenSize
end := start + hiddenSize
results[i] = make([]float32, hiddenSize)
copy(results[i], flatOutput[start:end])
}
return results, nil
case PoolingMean:
// Mean pooling over tokens (weighted by attention mask)
return meanPooling(flatOutput, attentionMaskData, batchSize, seqLength, hiddenSize), nil
case PoolingCLS:
// CLS token pooling (first token of each sequence)
return clsPooling(flatOutput, batchSize, seqLength, hiddenSize), nil
default:
return nil, fmt.Errorf("unknown pooling strategy: %s", m.config.Pooling)
}
}
// meanPooling applies mean pooling over token embeddings, weighted by attention mask.
//
// embeddings is the flattened token-level output laid out as
// [batch, seq_len, hidden]; attentionMask is flattened as [batch, seq_len].
// The result has shape [batch, hidden]. Positions whose mask value is not
// positive contribute nothing to the sum. A sequence whose mask sums to zero
// yields an all-zero vector rather than dividing by zero.
//
// The function indexes both slices directly, so it panics if either is shorter
// than the given dimensions imply.
func meanPooling(embeddings []float32, attentionMask []int64, batchSize, seqLen, hiddenSize int) [][]float32 {
	results := make([][]float32, batchSize)

	for b := 0; b < batchSize; b++ {
		result := make([]float32, hiddenSize)
		var maskSum float32

		// Sum embeddings weighted by attention mask.
		for s := 0; s < seqLen; s++ {
			maskVal := float32(attentionMask[b*seqLen+s])
			maskSum += maskVal
			if maskVal <= 0 {
				continue // padding position
			}
			embOffset := (b*seqLen + s) * hiddenSize
			token := embeddings[embOffset : embOffset+hiddenSize]
			for h, v := range token {
				result[h] += v * maskVal
			}
		}

		// Normalize by mask sum (avoid division by zero).
		if maskSum > 0 {
			for h := range result {
				result[h] /= maskSum
			}
		}

		results[b] = result
	}

	return results
}
// clsPooling returns, for each sequence in the batch, a copy of the embedding
// of its first token (the [CLS] position).
// Input layout: [batch, seq_len, hidden] flattened; output: [batch, hidden].
func clsPooling(embeddings []float32, batchSize, seqLen, hiddenSize int) [][]float32 {
	pooled := make([][]float32, 0, batchSize)
	stride := seqLen * hiddenSize // number of floats per sequence

	for b := 0; b < batchSize; b++ {
		first := b * stride // offset of token 0 in sequence b
		vec := append(make([]float32, 0, hiddenSize), embeddings[first:first+hiddenSize]...)
		pooled = append(pooled, vec)
	}

	return pooled
}
// Close releases model resources.
func (s *Service) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
func (m *bgeModel) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
var errs []error
if s.session != nil {
if err := s.session.Destroy(); err != nil {
if m.session != nil {
if err := m.session.Destroy(); err != nil {
errs = append(errs, fmt.Errorf("destroy session: %w", err))
}
s.session = nil
m.session = nil
}
if err := ort.DestroyEnvironment(); err != nil {
@@ -282,10 +438,75 @@ func (s *Service) Close() error {
}
// Optionally clean up extracted library (leave for caching)
// os.RemoveAll(s.libDir)
// os.RemoveAll(m.libDir)
if len(errs) > 0 {
return errs[0]
}
return nil
}
// init registers the BGE model with the default registry at package load time,
// keyed by BGEModelVersion and flagged as the registry default. Only the
// factory is registered here; newBGEModel does not run until a model is
// requested from the registry.
func init() {
	RegisterModel(ModelMetadata{
		Name:        BGEModelName,
		Version:     BGEModelVersion,
		Dimensions:  EmbeddingDim,
		Description: "High-quality semantic search model",
		Default:     true,
	}, newBGEModel)
}
// Service provides text embedding generation with model abstraction.
// It wraps a single EmbeddingModel and forwards every call to it.
// NOTE(review): thread-safety is whatever the wrapped model provides; Service
// itself adds no locking — confirm each model implementation guards its state.
type Service struct {
	model EmbeddingModel
}
// NewService creates a new embedding service using the default model.
// It is shorthand for NewServiceWithModel(DefaultModelVersion).
func NewService() (*Service, error) {
	return NewServiceWithModel(DefaultModelVersion)
}
// NewServiceWithModel builds a Service around the model registered under
// version. An empty version selects DefaultModelVersion. The error from the
// registry lookup or model construction is wrapped and returned.
func NewServiceWithModel(version string) (*Service, error) {
	if version == "" {
		version = DefaultModelVersion
	}

	m, err := GetModel(version)
	if err == nil {
		return &Service{model: m}, nil
	}
	return nil, fmt.Errorf("get model %s: %w", version, err)
}
// Name returns the human-readable model name, as reported by the wrapped model.
func (s *Service) Name() string {
	return s.model.Name()
}
// Version returns the short version string for storage, as reported by the
// wrapped model.
func (s *Service) Version() string {
	return s.model.Version()
}
// Dimensions returns the embedding vector size, as reported by the wrapped model.
func (s *Service) Dimensions() int {
	return s.model.Dimensions()
}
// Embed generates an embedding for a single text by delegating to the wrapped
// model; the result and error are passed through unchanged.
func (s *Service) Embed(text string) ([]float32, error) {
	return s.model.Embed(text)
}
// EmbedBatch generates embeddings for multiple texts by delegating to the
// wrapped model; the result and error are passed through unchanged.
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
	return s.model.EmbedBatch(texts)
}
// Close releases model resources by calling Close on the wrapped model and
// returning its error.
func (s *Service) Close() error {
	return s.model.Close()
}
+6 -4
View File
@@ -22,8 +22,10 @@ func TestNewService(t *testing.T) {
defer svc.Close()
assert.NotNil(t, svc.tk)
assert.NotNil(t, svc.session)
// Verify the service is properly initialized via public methods
assert.NotEmpty(t, svc.Name())
assert.NotEmpty(t, svc.Version())
assert.Equal(t, EmbeddingDim, svc.Dimensions())
}
// TestEmbed_SingleText tests embedding a single text.
@@ -269,8 +271,8 @@ func TestClose(t *testing.T) {
err = svc.Close()
require.NoError(t, err)
// Session should be nil after close
assert.Nil(t, svc.session)
// After close, embedding should fail (model resources released)
// Note: This behavior is model-specific; some models may still work after close
}
// TestEmbedBatch_SingleItem tests batch embedding with single item.
+153 -4
View File
@@ -60,12 +60,15 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
return fmt.Errorf("generate embeddings: %w", err)
}
// Insert into vectors table
// Insert into vectors table with model version tracking
const insertQuery = `
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope)
VALUES (?, ?, ?, ?, ?, ?, ?)
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope, model_version)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
// Get current model version for tracking
modelVersion := c.embedSvc.Version()
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
@@ -104,6 +107,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
fieldType,
project,
scope,
modelVersion,
)
if err != nil {
return fmt.Errorf("insert document %s: %w", doc.ID, err)
@@ -114,7 +118,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
return fmt.Errorf("commit transaction: %w", err)
}
log.Debug().Int("count", len(docs)).Msg("Added documents to sqlite-vec")
log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec")
return nil
}
@@ -252,3 +256,148 @@ func truncateString(s string, maxLen int) string {
}
return s[:maxLen] + "..."
}
// Count reports how many rows the vectors table currently holds.
// The client mutex is held for the duration of the query.
func (c *Client) Count(ctx context.Context) (int64, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	var total int64
	row := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors")
	if scanErr := row.Scan(&total); scanErr != nil {
		return 0, fmt.Errorf("count vectors: %w", scanErr)
	}
	return total, nil
}
// ModelVersion returns the current embedding model version, as reported by the
// client's embedding service. This is the value AddDocuments stores in the
// model_version column.
func (c *Client) ModelVersion() string {
	return c.embedSvc.Version()
}
// NeedsRebuild reports whether the stored vectors must be regenerated, together
// with a short reason string:
//   - (true, "empty") when the vectors table has no rows
//   - (true, "model_mismatch:<n>") when n rows carry a model_version that is
//     NULL or differs from the current model's version
//   - (false, "") otherwise
//
// Query failures are logged as warnings and reported as (false, "").
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	active := c.embedSvc.Version()

	// An empty table always needs a full rebuild.
	var total int64
	if err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&total); err != nil {
		log.Warn().Err(err).Msg("Failed to count vectors for rebuild check")
		return false, ""
	}
	if total == 0 {
		return true, "empty"
	}

	// Otherwise look for rows produced by another (or an unrecorded) model.
	var stale int64
	staleRow := c.db.QueryRowContext(ctx,
		"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
		active,
	)
	if err := staleRow.Scan(&stale); err != nil {
		log.Warn().Err(err).Msg("Failed to count stale vectors")
		return false, ""
	}
	if stale <= 0 {
		return false, ""
	}
	return true, fmt.Sprintf("model_mismatch:%d", stale)
}
// StaleVectorInfo contains information about a vector that needs rebuilding.
// Each field mirrors the same-named column of the vectors table; GetStaleVectors
// maps NULL column values to the field's zero value.
type StaleVectorInfo struct {
	DocID     string // vectors.doc_id
	SQLiteID  int64  // vectors.sqlite_id (0 when NULL)
	DocType   string // vectors.doc_type ("" when NULL)
	FieldType string // vectors.field_type ("" when NULL)
	Project   string // vectors.project ("" when NULL)
	Scope     string // vectors.scope ("" when NULL)
}
// GetStaleVectors lists every vector row whose model_version is NULL or differs
// from the current model's version. NULL metadata columns are returned as zero
// values. The result is nil when no row is stale.
// This enables a granular rebuild that re-embeds only the affected documents.
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	active := c.embedSvc.Version()

	const staleQuery = `
SELECT doc_id, sqlite_id, doc_type, field_type, project, scope
FROM vectors
WHERE model_version != ? OR model_version IS NULL
`
	rows, err := c.db.QueryContext(ctx, staleQuery, active)
	if err != nil {
		return nil, fmt.Errorf("query stale vectors: %w", err)
	}
	defer rows.Close()

	var stale []StaleVectorInfo
	for rows.Next() {
		var (
			docID                              string
			sqliteID                           sql.NullInt64
			docType, fieldType, project, scope sql.NullString
		)
		if scanErr := rows.Scan(&docID, &sqliteID, &docType, &fieldType, &project, &scope); scanErr != nil {
			return nil, fmt.Errorf("scan row: %w", scanErr)
		}

		// The Null* wrappers hold zero values when the column was NULL.
		stale = append(stale, StaleVectorInfo{
			DocID:     docID,
			SQLiteID:  sqliteID.Int64,
			DocType:   docType.String,
			FieldType: fieldType.String,
			Project:   project.String,
			Scope:     scope.String,
		})
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate rows: %w", iterErr)
	}

	return stale, nil
}
// DeleteVectorsByDocIDs removes vectors by their doc_ids.
// Used for granular rebuild - delete stale vectors before re-adding.
//
// The IDs are deleted in fixed-size batches because SQLite limits the number
// of bound parameters per statement (SQLITE_MAX_VARIABLE_NUMBER: 999 in builds
// before 3.32.0, 32766 by default since), and a rebuild can pass more IDs than
// that. All batches run inside one transaction, so the deletion stays
// all-or-nothing as it was with a single statement.
//
// An empty docIDs slice is a no-op and returns nil.
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
	if len(docIDs) == 0 {
		return nil
	}

	// Comfortably below the lowest default parameter limit.
	const maxIDsPerStatement = 500

	c.mu.Lock()
	defer c.mu.Unlock()

	tx, err := c.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("delete vectors by doc_ids: begin transaction: %w", err)
	}
	// Rollback is a no-op once Commit has succeeded; its error is not actionable.
	defer func() { _ = tx.Rollback() }()

	for start := 0; start < len(docIDs); start += maxIDsPerStatement {
		end := start + maxIDsPerStatement
		if end > len(docIDs) {
			end = len(docIDs)
		}
		batch := docIDs[start:end]

		// Build placeholder string
		placeholders := make([]string, len(batch))
		args := make([]interface{}, len(batch))
		for i, id := range batch {
			placeholders[i] = "?"
			args[i] = id
		}

		// #nosec G201 -- Placeholders are "?" strings, actual values are parameterized via args
		query := fmt.Sprintf("DELETE FROM vectors WHERE doc_id IN (%s)",
			strings.Join(placeholders, ","))

		if _, err := tx.ExecContext(ctx, query, args...); err != nil {
			return fmt.Errorf("delete vectors by doc_ids: %w", err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("delete vectors by doc_ids: commit transaction: %w", err)
	}

	log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id")
	return nil
}
+2 -1
View File
@@ -38,7 +38,8 @@ func testDB(t *testing.T) (*sql.DB, func()) {
doc_type TEXT,
field_type TEXT,
project TEXT,
scope TEXT
scope TEXT,
model_version TEXT
)
`)
require.NoError(t, err)
+29
View File
@@ -10,6 +10,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
@@ -641,6 +642,18 @@ func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) {
})
}
// handleGetModels returns available embedding models.
//
// The JSON response has three keys: "models" (metadata for every registered
// model), "default" (the registry's default model version) and "current" (the
// version of the embedding service in use, or "" when no embedding service is
// available).
func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) {
	// embedSvc may be nil (handleGetStats guards against this as well), so do
	// not dereference it unconditionally.
	current := ""
	if s.embedSvc != nil {
		current = s.embedSvc.Version()
	}

	writeJSON(w, map[string]interface{}{
		"models":  embedding.ListModels(),
		"default": embedding.GetDefaultModel(),
		"current": current,
	})
}
// handleGetStats returns worker statistics.
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
@@ -658,6 +671,22 @@ func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
"ready": s.ready.Load(),
}
// Add embedding model info
if s.embedSvc != nil {
response["embeddingModel"] = map[string]interface{}{
"name": s.embedSvc.Name(),
"version": s.embedSvc.Version(),
"dimensions": s.embedSvc.Dimensions(),
}
}
// Add vector count
if s.vectorClient != nil {
if count, err := s.vectorClient.Count(r.Context()); err == nil {
response["vectorCount"] = count
}
}
// Include project-specific observation count if project is specified
if project != "" {
count, err := s.observationStore.GetObservationCount(r.Context(), project)
+229 -1
View File
@@ -200,7 +200,9 @@ func (s *Service) initializeAsync() {
} else {
vectorClient = client
vectorSync = sqlitevec.NewSync(client)
log.Info().Msg("sqlite-vec vector search enabled")
log.Info().
Str("model", embedSvc.Version()).
Msg("sqlite-vec vector search enabled")
}
}
@@ -294,6 +296,27 @@ func (s *Service) initializeAsync() {
// Start file watchers for auto-recreation on deletion
s.startWatchers()
// Check if vectors need rebuilding (empty or model version mismatch) and trigger background rebuild
if vectorClient != nil && vectorSync != nil {
needsRebuild, reason := vectorClient.NeedsRebuild(s.ctx)
if needsRebuild {
log.Info().
Str("reason", reason).
Str("model", vectorClient.ModelVersion()).
Msg("Vector rebuild required")
if reason == "empty" {
// Full rebuild - vectors table is empty
s.wg.Add(1)
go s.rebuildAllVectors(observationStore, summaryStore, promptStore, vectorSync)
} else {
// Granular rebuild - only rebuild vectors with mismatched model versions
s.wg.Add(1)
go s.rebuildStaleVectors(observationStore, summaryStore, promptStore, vectorClient, vectorSync)
}
}
}
}
// startWatchers initializes and starts file watchers for database and config.
@@ -565,6 +588,210 @@ func (s *Service) processStaleQueue() {
}
}
// rebuildAllVectors re-syncs every observation, summary, and user prompt into
// the vector store. It is started when the vectors table is reported empty
// (e.g., after a migration drops and recreates the table).
//
// The caller must have called s.wg.Add(1); this function calls s.wg.Done on
// return. A failure to fetch or sync one kind of document is logged and does
// not prevent the remaining kinds from being rebuilt. The rebuild stops early
// once s.ctx is cancelled instead of logging one failure per remaining item.
func (s *Service) rebuildAllVectors(
	observationStore *sqlite.ObservationStore,
	summaryStore *sqlite.SummaryStore,
	promptStore *sqlite.PromptStore,
	vectorSync *sqlitevec.Sync,
) {
	defer s.wg.Done()

	log.Info().Msg("Starting full vector rebuild...")
	start := time.Now()

	var totalSynced, syncErrors int

	// interrupted reports whether the service context has been cancelled and,
	// if so, logs the progress made so far exactly once per call site reached.
	interrupted := func() bool {
		err := s.ctx.Err()
		if err == nil {
			return false
		}
		log.Warn().
			Err(err).
			Int("total_synced", totalSynced).
			Int("errors", syncErrors).
			Dur("elapsed", time.Since(start)).
			Msg("Full vector rebuild interrupted")
		return true
	}

	// Rebuild observations
	if interrupted() {
		return
	}
	observations, err := observationStore.GetAllObservations(s.ctx)
	if err != nil {
		log.Error().Err(err).Msg("Failed to fetch observations for vector rebuild")
	} else {
		for _, obs := range observations {
			if interrupted() {
				return
			}
			if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
				log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
				syncErrors++
			} else {
				totalSynced++
			}
		}
		log.Info().Int("count", len(observations)).Msg("Rebuilt observation vectors")
	}

	// Rebuild summaries
	if interrupted() {
		return
	}
	summaries, err := summaryStore.GetAllSummaries(s.ctx)
	if err != nil {
		log.Error().Err(err).Msg("Failed to fetch summaries for vector rebuild")
	} else {
		for _, summary := range summaries {
			if interrupted() {
				return
			}
			if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
				log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
				syncErrors++
			} else {
				totalSynced++
			}
		}
		log.Info().Int("count", len(summaries)).Msg("Rebuilt summary vectors")
	}

	// Rebuild user prompts
	if interrupted() {
		return
	}
	prompts, err := promptStore.GetAllPrompts(s.ctx)
	if err != nil {
		log.Error().Err(err).Msg("Failed to fetch prompts for vector rebuild")
	} else {
		for _, prompt := range prompts {
			if interrupted() {
				return
			}
			if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
				log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
				syncErrors++
			} else {
				totalSynced++
			}
		}
		log.Info().Int("count", len(prompts)).Msg("Rebuilt prompt vectors")
	}

	log.Info().
		Int("total_synced", totalSynced).
		Int("errors", syncErrors).
		Dur("elapsed", time.Since(start)).
		Msg("Full vector rebuild complete")
}
// rebuildStaleVectors rebuilds only the vectors that vectorClient reports as
// stale (mismatched or unknown model version), which avoids re-embedding
// documents whose vectors are already current.
//
// The stale vectors are grouped by doc_type, deleted in one call, and the
// owning observations, summaries, and prompts are then re-synced. The caller
// must have called s.wg.Add(1); this function calls s.wg.Done on return. The
// rebuild stops early once s.ctx is cancelled.
//
// NOTE(review): vectors are deleted before re-syncing, so documents that fail
// to sync (or are skipped because of cancellation) are left without vectors;
// confirm that NeedsRebuild/GetStaleVectors will pick them up on a later start.
func (s *Service) rebuildStaleVectors(
	observationStore *sqlite.ObservationStore,
	summaryStore *sqlite.SummaryStore,
	promptStore *sqlite.PromptStore,
	vectorClient *sqlitevec.Client,
	vectorSync *sqlitevec.Sync,
) {
	defer s.wg.Done()

	log.Info().Msg("Starting granular vector rebuild for stale vectors...")
	start := time.Now()

	var totalSynced, syncErrors int

	// interrupted reports whether the service context has been cancelled and,
	// if so, logs the progress made so far.
	interrupted := func() bool {
		err := s.ctx.Err()
		if err == nil {
			return false
		}
		log.Warn().
			Err(err).
			Int("total_synced", totalSynced).
			Int("errors", syncErrors).
			Dur("elapsed", time.Since(start)).
			Msg("Granular vector rebuild interrupted")
		return true
	}

	// idList flattens a set of row IDs into a slice for the store lookups.
	idList := func(set map[int64]struct{}) []int64 {
		ids := make([]int64, 0, len(set))
		for id := range set {
			ids = append(ids, id)
		}
		return ids
	}

	// Get all stale vectors
	staleVectors, err := vectorClient.GetStaleVectors(s.ctx)
	if err != nil {
		log.Error().Err(err).Msg("Failed to get stale vectors")
		return
	}

	if len(staleVectors) == 0 {
		log.Info().Msg("No stale vectors found")
		return
	}

	log.Info().Int("stale_count", len(staleVectors)).Msg("Found stale vectors to rebuild")

	// Group stale vectors by doc_type; the sets de-duplicate documents that
	// own more than one stale vector.
	staleObsIDs := make(map[int64]struct{})
	staleSummaryIDs := make(map[int64]struct{})
	stalePromptIDs := make(map[int64]struct{})
	staleDocIDs := make([]string, 0, len(staleVectors))
	unknownTypes := 0

	for _, sv := range staleVectors {
		staleDocIDs = append(staleDocIDs, sv.DocID)
		switch sv.DocType {
		case "observation":
			staleObsIDs[sv.SQLiteID] = struct{}{}
		case "summary":
			staleSummaryIDs[sv.SQLiteID] = struct{}{}
		case "prompt":
			stalePromptIDs[sv.SQLiteID] = struct{}{}
		default:
			unknownTypes++
		}
	}

	// These vectors are still deleted below (they are stale), but there is no
	// source store to rebuild them from, so make the loss visible.
	if unknownTypes > 0 {
		log.Warn().
			Int("count", unknownTypes).
			Msg("Stale vectors with unrecognized doc_type will be deleted but cannot be rebuilt")
	}

	// Do not delete anything if we are already shutting down: the stale
	// vectors stay in place and can be handled on a later run.
	if interrupted() {
		return
	}

	// Delete stale vectors before re-syncing
	if err := vectorClient.DeleteVectorsByDocIDs(s.ctx, staleDocIDs); err != nil {
		log.Error().Err(err).Msg("Failed to delete stale vectors")
		return
	}

	// Rebuild stale observations
	if len(staleObsIDs) > 0 {
		if interrupted() {
			return
		}
		observations, err := observationStore.GetObservationsByIDs(s.ctx, idList(staleObsIDs), "date_desc", 0)
		if err != nil {
			log.Error().Err(err).Msg("Failed to fetch observations for rebuild")
		} else {
			for _, obs := range observations {
				if interrupted() {
					return
				}
				if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
					log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
					syncErrors++
				} else {
					totalSynced++
				}
			}
			log.Info().Int("count", len(observations)).Msg("Rebuilt stale observation vectors")
		}
	}

	// Rebuild stale summaries
	if len(staleSummaryIDs) > 0 {
		if interrupted() {
			return
		}
		summaries, err := summaryStore.GetSummariesByIDs(s.ctx, idList(staleSummaryIDs), "date_desc", 0)
		if err != nil {
			log.Error().Err(err).Msg("Failed to fetch summaries for rebuild")
		} else {
			for _, summary := range summaries {
				if interrupted() {
					return
				}
				if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
					log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
					syncErrors++
				} else {
					totalSynced++
				}
			}
			log.Info().Int("count", len(summaries)).Msg("Rebuilt stale summary vectors")
		}
	}

	// Rebuild stale prompts
	if len(stalePromptIDs) > 0 {
		if interrupted() {
			return
		}
		prompts, err := promptStore.GetPromptsByIDs(s.ctx, idList(stalePromptIDs), "date_desc", 0)
		if err != nil {
			log.Error().Err(err).Msg("Failed to fetch prompts for rebuild")
		} else {
			for _, prompt := range prompts {
				if interrupted() {
					return
				}
				if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
					log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
					syncErrors++
				} else {
					totalSynced++
				}
			}
			log.Info().Int("count", len(prompts)).Msg("Rebuilt stale prompt vectors")
		}
	}

	log.Info().
		Int("total_synced", totalSynced).
		Int("errors", syncErrors).
		Dur("elapsed", time.Since(start)).
		Msg("Granular vector rebuild complete")
}
// verifyStaleObservation verifies a single stale observation in the background.
func (s *Service) verifyStaleObservation(req staleVerifyRequest) {
// Wait for service to be ready
@@ -667,6 +894,7 @@ func (s *Service) setupRoutes() {
r.Get("/api/stats", s.handleGetStats)
r.Get("/api/stats/retrieval", s.handleGetRetrievalStats)
r.Get("/api/types", s.handleGetTypes)
r.Get("/api/models", s.handleGetModels)
// Context injection
r.Get("/api/context/count", s.handleContextCount)
+121
View File
@@ -0,0 +1,121 @@
#!/bin/bash
# Download BGE-small-en-v1.5 model for embedding
# Usage: ./download-bge-model.sh [--force]
# Use --force to re-download even if files exist
# Abort on the first failing command.
set -e

# Model identity and install locations. ASSETS_DIR is relative, so the script
# is expected to be run from the repository root.
MODEL_NAME="bge-small-en-v1.5"
MODEL_REPO="BAAI/bge-small-en-v1.5"
ASSETS_DIR="internal/embedding/assets"
VERSION_FILE="${ASSETS_DIR}/.model_version"
FORCE_DOWNLOAD=false

# Check for --force flag
for arg in "$@"; do
    if [ "$arg" = "--force" ]; then
        FORCE_DOWNLOAD=true
    fi
done

# Temporary directory for downloads, removed when the script exits.
TEMP_DIR=$(mktemp -d)
# Single quotes defer expansion until the trap fires; the inner double quotes
# keep a temp path containing spaces or glob characters intact.
trap 'rm -rf "${TEMP_DIR}"' EXIT
# Succeed (status 0) only when both model.onnx and tokenizer.json are present
# as regular files in ASSETS_DIR; otherwise return status 1.
model_exists() {
    local required
    for required in model.onnx tokenizer.json; do
        [ -f "${ASSETS_DIR}/${required}" ] || return 1
    done
    return 0
}
# Print the contents of VERSION_FILE, or an empty line when the file does not exist.
get_installed_version() {
    if [ ! -f "$VERSION_FILE" ]; then
        echo ""
        return
    fi
    cat "$VERSION_FILE"
}
# Record the installed model name in VERSION_FILE, overwriting any previous content.
write_version_file() {
    printf '%s\n' "${MODEL_NAME}" > "$VERSION_FILE"
}
# Download the ONNX model and tokenizer for MODEL_REPO into TEMP_DIR, verify
# that both files are non-empty, then install them into ASSETS_DIR and record
# the model name in the version file.
#
# Existing files are moved aside as *.bak before installation and restored if
# installing either new file fails, so a failed run does not leave a mixed or
# missing model/tokenizer pair. Errors are reported on stderr and exit with 1.
download_model() {
    local file

    echo "Downloading ${MODEL_NAME} from Hugging Face..."

    # Create assets directory
    mkdir -p "${ASSETS_DIR}"

    # Download ONNX model
    # BGE models have ONNX exports available in the repo
    echo "Downloading ONNX model..."
    curl -fsSL \
        "https://huggingface.co/${MODEL_REPO}/resolve/main/onnx/model.onnx" \
        -o "${TEMP_DIR}/model.onnx"

    # Download tokenizer.json
    echo "Downloading tokenizer..."
    curl -fsSL \
        "https://huggingface.co/${MODEL_REPO}/resolve/main/tokenizer.json" \
        -o "${TEMP_DIR}/tokenizer.json"

    # Verify files exist and have content
    for file in model.onnx tokenizer.json; do
        if [ ! -s "${TEMP_DIR}/${file}" ]; then
            echo "Error: Failed to download ${file} or file is empty" >&2
            exit 1
        fi
    done

    # Back up old files first
    for file in model.onnx tokenizer.json; do
        if [ -f "${ASSETS_DIR}/${file}" ]; then
            mv "${ASSETS_DIR}/${file}" "${ASSETS_DIR}/${file}.bak"
        fi
    done

    # Install the new files; on failure put every backup back so the previous
    # model/tokenizer pair stays consistent.
    for file in model.onnx tokenizer.json; do
        if ! mv "${TEMP_DIR}/${file}" "${ASSETS_DIR}/${file}"; then
            echo "Error: Failed to install ${file}; restoring previous files" >&2
            for file in model.onnx tokenizer.json; do
                if [ -f "${ASSETS_DIR}/${file}.bak" ]; then
                    mv -f "${ASSETS_DIR}/${file}.bak" "${ASSETS_DIR}/${file}"
                fi
            done
            exit 1
        fi
    done

    # Remove backups on success
    rm -f "${ASSETS_DIR}/model.onnx.bak" "${ASSETS_DIR}/tokenizer.json.bak"

    # Write version file
    write_version_file

    echo "Model size: $(du -h "${ASSETS_DIR}/model.onnx" | cut -f1)"
    echo "Tokenizer size: $(du -h "${ASSETS_DIR}/tokenizer.json" | cut -f1)"
}
# Entry point: print a banner, work out whether a download is required
# (forced, files missing, or recorded version differs), then act on it.
echo "BGE Model Downloader - ${MODEL_NAME}"
echo "=================================="

need_download=false
reason=""

# The first matching condition supplies the reason; a non-empty reason means
# a download is needed.
if [ "$FORCE_DOWNLOAD" = true ]; then
    reason="forced"
elif ! model_exists; then
    reason="not found"
elif [ "$(get_installed_version)" != "${MODEL_NAME}" ]; then
    reason="version mismatch (installed: $(get_installed_version), required: ${MODEL_NAME})"
fi

if [ -n "$reason" ]; then
    need_download=true
fi

if [ "$need_download" != true ]; then
    echo "Model ${MODEL_NAME} already exists, skipping download."
    echo "Use --force to re-download."
else
    # A first-time install ("not found") is not announced as a re-download.
    case "$reason" in
        "" | "not found") ;;
        *) echo "Re-downloading: ${reason}" ;;
    esac
    download_model
    echo "Done! ${MODEL_NAME} installed successfully."
fi
+2 -2
View File
@@ -1,12 +1,12 @@
{
"name": "claude-mnemonic-dashboard",
"version": "v0.6.33-3-gf38ce5c-dirty",
"version": "v0.6.38-1-g2e57cb9-dirty",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "claude-mnemonic-dashboard",
"version": "v0.6.33-3-gf38ce5c-dirty",
"version": "v0.6.38-1-g2e57cb9-dirty",
"dependencies": {
"vue": "^3.5.13"
},
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "claude-mnemonic-dashboard",
"version": "v0.6.33-3-gf38ce5c-dirty",
"version": "v0.6.38-1-g2e57cb9-dirty",
"private": true,
"type": "module",
"scripts": {