mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-09 23:59:40 +00:00
Resolves issue #13
- Switched model to bge-small-en-v1.5 - Added lazy re-embedding - Added model version tracking per vector - Added conversion of vectors to the new model
This commit is contained in:
@@ -48,6 +48,9 @@ type Config struct {
|
||||
Model string `json:"model"`
|
||||
ClaudeCodePath string `json:"claude_code_path"`
|
||||
|
||||
// Embedding settings
|
||||
EmbeddingModel string `json:"embedding_model"` // e.g., "bge-v1.5"
|
||||
|
||||
// Context injection settings
|
||||
ContextObservations int `json:"context_observations"`
|
||||
ContextFullCount int `json:"context_full_count"`
|
||||
@@ -119,6 +122,9 @@ func EnsureAll() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultEmbeddingModel is the embedding model identifier that Default()
// stores in Config.EmbeddingModel. Load() replaces it when the
// CLAUDE_MNEMONIC_EMBEDDING_MODEL setting is a non-empty string.
// NOTE(review): the embedding package declares the same "bge-v1.5" literal as
// BGEModelVersion; the two appear to need to stay in sync — confirm.
const DefaultEmbeddingModel = "bge-v1.5"
|
||||
|
||||
// Default returns a Config with default values.
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
@@ -126,6 +132,7 @@ func Default() *Config {
|
||||
DBPath: DBPath(),
|
||||
MaxConns: 4,
|
||||
Model: DefaultModel,
|
||||
EmbeddingModel: DefaultEmbeddingModel,
|
||||
ContextObservations: 100,
|
||||
ContextFullCount: 25,
|
||||
ContextSessionCount: 10,
|
||||
@@ -166,6 +173,9 @@ func Load() (*Config, error) {
|
||||
if v, ok := settings["CLAUDE_CODE_PATH"].(string); ok {
|
||||
cfg.ClaudeCodePath = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_EMBEDDING_MODEL"].(string); ok && v != "" {
|
||||
cfg.EmbeddingModel = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS"].(float64); ok {
|
||||
cfg.ContextObservations = int(v)
|
||||
}
|
||||
|
||||
@@ -283,6 +283,27 @@ var Migrations = []Migration{
|
||||
ON user_prompts(claude_session_id, prompt_number);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 19,
|
||||
Name: "vectors_with_model_version",
|
||||
SQL: `
|
||||
-- Drop old vectors table (virtual tables cannot be altered)
|
||||
DROP TABLE IF EXISTS vectors;
|
||||
|
||||
-- Recreate vectors table with model_version column
|
||||
-- Uses bge-small-en-v1.5 embeddings (384 dimensions)
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
embedding float[384],
|
||||
sqlite_id INTEGER,
|
||||
doc_type TEXT,
|
||||
field_type TEXT,
|
||||
project TEXT,
|
||||
scope TEXT,
|
||||
model_version TEXT
|
||||
);
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
// MigrationManager handles database schema migrations.
|
||||
|
||||
@@ -229,6 +229,25 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// GetAllObservations retrieves all observations (for vector rebuild).
|
||||
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
|
||||
const query = `
|
||||
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
|
||||
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
|
||||
created_at, created_at_epoch
|
||||
FROM observations
|
||||
ORDER BY id
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanObservationRows(rows)
|
||||
}
|
||||
|
||||
// SearchObservationsFTS performs full-text search on observations.
|
||||
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
|
||||
if limit <= 0 {
|
||||
|
||||
@@ -199,6 +199,28 @@ func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([
|
||||
return scanPromptWithSessionRows(rows)
|
||||
}
|
||||
|
||||
// GetAllPrompts retrieves all user prompts (for vector rebuild).
|
||||
func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) {
|
||||
const query = `
|
||||
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
|
||||
COALESCE(up.matched_observations, 0) as matched_observations,
|
||||
up.created_at, up.created_at_epoch,
|
||||
COALESCE(s.project, '') as project,
|
||||
COALESCE(s.sdk_session_id, '') as sdk_session_id
|
||||
FROM user_prompts up
|
||||
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
|
||||
ORDER BY up.id
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanPromptWithSessionRows(rows)
|
||||
}
|
||||
|
||||
// FindRecentPromptByText finds a prompt with the same text for a session within the last few seconds.
|
||||
// This is used to detect duplicate hook invocations.
|
||||
// Returns (promptID, promptNumber, found).
|
||||
|
||||
@@ -116,3 +116,21 @@ func (s *SummaryStore) GetAllRecentSummaries(ctx context.Context, limit int) ([]
|
||||
|
||||
return scanSummaryRows(rows)
|
||||
}
|
||||
|
||||
// GetAllSummaries retrieves all summaries (for vector rebuild).
|
||||
func (s *SummaryStore) GetAllSummaries(ctx context.Context) ([]*models.SessionSummary, error) {
|
||||
const query = `
|
||||
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
|
||||
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
|
||||
FROM session_summaries
|
||||
ORDER BY id
|
||||
`
|
||||
|
||||
rows, err := s.store.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanSummaryRows(rows)
|
||||
}
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
bge-small-en-v1.5
|
||||
Binary file not shown.
@@ -1,21 +1,7 @@
|
||||
{
|
||||
"version": "1.0",
|
||||
"truncation": {
|
||||
"direction": "Right",
|
||||
"max_length": 128,
|
||||
"strategy": "LongestFirst",
|
||||
"stride": 0
|
||||
},
|
||||
"padding": {
|
||||
"strategy": {
|
||||
"Fixed": 128
|
||||
},
|
||||
"direction": "Right",
|
||||
"pad_to_multiple_of": null,
|
||||
"pad_id": 0,
|
||||
"pad_type_id": 0,
|
||||
"pad_token": "[PAD]"
|
||||
},
|
||||
"truncation": null,
|
||||
"padding": null,
|
||||
"added_tokens": [
|
||||
{
|
||||
"id": 0,
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
// Package embedding provides text embedding generation with swappable models.
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PoolingStrategy names the way per-token embeddings are reduced to a single
// sentence embedding.
type PoolingStrategy string

// Supported pooling strategies.
const (
	// PoolingNone: the model output is already one vector per sentence.
	PoolingNone = PoolingStrategy("none")
	// PoolingMean: average the token embeddings, weighted by the attention mask.
	PoolingMean = PoolingStrategy("mean")
	// PoolingCLS: take the embedding of the first ([CLS]) token only.
	PoolingCLS = PoolingStrategy("cls")
)
|
||||
|
||||
// ONNXConfig describes ONNX-specific model configuration: the tensor names a
// model's graph exposes and how its output is turned into sentence embeddings.
// This lets different models declare their own tensor names and pooling needs.
type ONNXConfig struct {
	// InputNames are the ONNX input tensor names, in the order the input
	// tensors are supplied to the session.
	InputNames []string
	// OutputNames are the ONNX output tensor names requested from the session.
	OutputNames []string
	// Pooling specifies how to convert token embeddings to sentence embeddings.
	// With PoolingNone the model output is used as the sentence embedding directly.
	Pooling PoolingStrategy
	// HiddenSize is the embedding dimension; it sizes the output tensor and
	// drives the pooling calculations.
	HiddenSize int
}
|
||||
|
||||
// EmbeddingModel represents a text embedding model. Implementations are
// created through a ModelFactory and looked up by their Version string.
type EmbeddingModel interface {
	// Name returns the human-readable model name (e.g., "bge-small-en-v1.5").
	Name() string

	// Version returns a short version string for storage (e.g., "bge-v1.5").
	// It is also the key under which the model is registered.
	Version() string

	// Dimensions returns the embedding vector size.
	Dimensions() int

	// Embed generates an embedding for a single text.
	Embed(text string) ([]float32, error)

	// EmbedBatch generates embeddings for multiple texts, one vector per input.
	EmbedBatch(texts []string) ([][]float32, error)

	// Close releases model resources.
	Close() error
}
|
||||
|
||||
// ONNXConfigurer is an optional interface that models can implement
// to expose their ONNX configuration for introspection. Callers would
// discover it with a type assertion on an EmbeddingModel.
type ONNXConfigurer interface {
	// ONNXConfig returns the model's ONNX configuration.
	ONNXConfig() ONNXConfig
}
|
||||
|
||||
// ModelMetadata describes an embedding model for UI/config. It is supplied to
// Register together with the model's factory and returned by List.
type ModelMetadata struct {
	Name        string `json:"name"`        // Human-readable name
	Version     string `json:"version"`     // Short ID for DB storage; also the registry key
	Dimensions  int    `json:"dimensions"`  // Vector size
	Description string `json:"description"` // Brief description
	Default     bool   `json:"default"`     // Is this the default model? Register records it as the registry default when true
}
|
||||
|
||||
// ModelFactory creates a new instance of an embedding model. The registry
// invokes it on every Get call, so each call yields whatever the factory
// chooses to build or return.
type ModelFactory func() (EmbeddingModel, error)
|
||||
|
||||
// ModelRegistry provides model lookup by version. Register takes the write
// lock; Get, Default and List take the read lock.
type ModelRegistry struct {
	mu           sync.RWMutex             // guards all fields below
	models       map[string]ModelFactory  // factories keyed by ModelMetadata.Version
	metadata     map[string]ModelMetadata // descriptive metadata keyed by the same version
	defaultModel string                   // version of the most recently registered default model
}
|
||||
|
||||
// NewModelRegistry creates a new model registry.
|
||||
func NewModelRegistry() *ModelRegistry {
|
||||
return &ModelRegistry{
|
||||
models: make(map[string]ModelFactory),
|
||||
metadata: make(map[string]ModelMetadata),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a model factory to the registry.
|
||||
func (r *ModelRegistry) Register(meta ModelMetadata, factory ModelFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.models[meta.Version] = factory
|
||||
r.metadata[meta.Version] = meta
|
||||
|
||||
if meta.Default {
|
||||
r.defaultModel = meta.Version
|
||||
}
|
||||
}
|
||||
|
||||
// Get creates a new instance of the model with the given version.
|
||||
func (r *ModelRegistry) Get(version string) (EmbeddingModel, error) {
|
||||
r.mu.RLock()
|
||||
factory, ok := r.models[version]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model version: %s", version)
|
||||
}
|
||||
|
||||
return factory()
|
||||
}
|
||||
|
||||
// Default returns the default model version.
|
||||
func (r *ModelRegistry) Default() string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.defaultModel
|
||||
}
|
||||
|
||||
// List returns metadata for all registered models.
|
||||
func (r *ModelRegistry) List() []ModelMetadata {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
result := make([]ModelMetadata, 0, len(r.metadata))
|
||||
for _, meta := range r.metadata {
|
||||
result = append(result, meta)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// DefaultRegistry is the global model registry used by the package-level
// helpers RegisterModel, GetModel, GetDefaultModel and ListModels. It starts
// empty; models are added through RegisterModel (the BGE model does so from
// an init function).
var DefaultRegistry = NewModelRegistry()
|
||||
|
||||
// RegisterModel adds a model to the default registry.
|
||||
func RegisterModel(meta ModelMetadata, factory ModelFactory) {
|
||||
DefaultRegistry.Register(meta, factory)
|
||||
}
|
||||
|
||||
// GetModel creates a model instance from the default registry.
|
||||
func GetModel(version string) (EmbeddingModel, error) {
|
||||
return DefaultRegistry.Get(version)
|
||||
}
|
||||
|
||||
// GetDefaultModel returns the default model version from the default registry.
|
||||
func GetDefaultModel() string {
|
||||
return DefaultRegistry.Default()
|
||||
}
|
||||
|
||||
// ListModels returns metadata for all models in the default registry.
|
||||
func ListModels() []ModelMetadata {
|
||||
return DefaultRegistry.List()
|
||||
}
|
||||
+277
-56
@@ -1,4 +1,4 @@
|
||||
// Package embedding provides text embedding generation using all-MiniLM-L6-v2.
|
||||
// Package embedding provides text embedding generation with swappable models.
|
||||
package embedding
|
||||
|
||||
import (
|
||||
@@ -15,19 +15,50 @@ import (
|
||||
ort "github.com/yalue/onnxruntime_go"
|
||||
)
|
||||
|
||||
// EmbeddingDim is the dimension of embeddings produced by the current model.
// Both all-MiniLM-L6-v2 and bge-small-en-v1.5 produce 384-dimensional embeddings.
// The vectors table migration declares its embedding column as float[384],
// so the two values have to change together.
const EmbeddingDim = 384
|
||||
|
||||
// Service provides thread-safe text embedding generation.
|
||||
type Service struct {
|
||||
// Model version constants.
const (
	// BGEModelVersion is the version string for bge-small-en-v1.5. It is the
	// registry key for the model and the value its Version method returns.
	BGEModelVersion = "bge-v1.5"
	// BGEModelName is the human-readable name for bge-small-en-v1.5.
	BGEModelName = "bge-small-en-v1.5"
	// DefaultModelVersion is the model NewService uses, and the fallback
	// NewServiceWithModel applies when given an empty version.
	DefaultModelVersion = BGEModelVersion
)
|
||||
|
||||
// MaxSequenceLength is the maximum token sequence length for the model.
// computeBatch caps the padded sequence length at this value and drops any
// tokens beyond it.
const MaxSequenceLength = 512
|
||||
|
||||
// bgeONNXConfig defines the ONNX configuration for BGE models.
// The session requests the token-level last_hidden_state output, which is then
// reduced to one vector per input with attention-mask-weighted mean pooling.
// NOTE(review): the upstream bge-small-en-v1.5 model card describes taking the
// [CLS] token embedding (followed by L2 normalization) rather than mean
// pooling — confirm that PoolingMean is intentional. Changing it alters every
// embedding produced, so stored vectors would need re-embedding.
var bgeONNXConfig = ONNXConfig{
	InputNames:  []string{"input_ids", "attention_mask", "token_type_ids"},
	OutputNames: []string{"last_hidden_state"},
	Pooling:     PoolingMean,
	HiddenSize:  EmbeddingDim,
}
|
||||
|
||||
// bgeModel is the ONNX-based embedding model implementation.
|
||||
// Currently supports bge-small-en-v1.5 (previously all-MiniLM-L6-v2).
|
||||
type bgeModel struct {
|
||||
tk *tokenizer.Tokenizer
|
||||
session *ort.DynamicAdvancedSession
|
||||
mu sync.Mutex
|
||||
libDir string // temp directory containing extracted libraries
|
||||
libDir string // temp directory containing extracted libraries
|
||||
config ONNXConfig // ONNX configuration for this model
|
||||
}
|
||||
|
||||
// NewService creates a new embedding service using bundled ONNX runtime and model.
|
||||
func NewService() (*Service, error) {
|
||||
// Compile-time check that bgeModel implements EmbeddingModel.
// The blank assignment only exists to make the build fail if a method is missing.
var _ EmbeddingModel = (*bgeModel)(nil)

// Compile-time check that bgeModel implements ONNXConfigurer.
var _ ONNXConfigurer = (*bgeModel)(nil)
|
||||
|
||||
// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model.
|
||||
func newBGEModel() (EmbeddingModel, error) {
|
||||
// Extract ONNX runtime library to temp directory
|
||||
libDir, err := extractONNXLibrary()
|
||||
if err != nil {
|
||||
@@ -49,22 +80,41 @@ func NewService() (*Service, error) {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Create ONNX session with embedded model
|
||||
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
|
||||
outputNames := []string{"sentence_embedding"}
|
||||
|
||||
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, inputNames, outputNames, nil)
|
||||
// Create ONNX session using model-specific configuration
|
||||
config := bgeONNXConfig
|
||||
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ONNX session: %w", err)
|
||||
}
|
||||
|
||||
return &Service{
|
||||
return &bgeModel{
|
||||
tk: tk,
|
||||
session: session,
|
||||
libDir: libDir,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ONNXConfig returns the model's ONNX configuration.
|
||||
func (m *bgeModel) ONNXConfig() ONNXConfig {
|
||||
return m.config
|
||||
}
|
||||
|
||||
// Name returns the human-readable model name.
|
||||
func (m *bgeModel) Name() string {
|
||||
return BGEModelName
|
||||
}
|
||||
|
||||
// Version returns the short version string for storage.
|
||||
func (m *bgeModel) Version() string {
|
||||
return BGEModelVersion
|
||||
}
|
||||
|
||||
// Dimensions returns the embedding vector size.
|
||||
func (m *bgeModel) Dimensions() int {
|
||||
return EmbeddingDim
|
||||
}
|
||||
|
||||
// extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory.
|
||||
// Uses content hash to avoid re-extracting if already present.
|
||||
func extractONNXLibrary() (string, error) {
|
||||
@@ -107,15 +157,15 @@ func extractONNXLibrary() (string, error) {
|
||||
|
||||
// Embed generates an embedding for a single text.
|
||||
// Returns a 384-dimensional float32 vector.
|
||||
func (s *Service) Embed(text string) ([]float32, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
func (m *bgeModel) Embed(text string) ([]float32, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if text == "" {
|
||||
return make([]float32, EmbeddingDim), nil
|
||||
}
|
||||
|
||||
results, err := s.computeBatch([]string{text})
|
||||
results, err := m.computeBatch([]string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -127,13 +177,13 @@ func (s *Service) Embed(text string) ([]float32, error) {
|
||||
|
||||
// EmbedBatch generates embeddings for multiple texts.
|
||||
// Returns slice of 384-dimensional float32 vectors.
|
||||
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
||||
func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) {
|
||||
if len(texts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Filter out empty texts and track indices
|
||||
nonEmpty := make([]string, 0, len(texts))
|
||||
@@ -155,7 +205,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
||||
}
|
||||
|
||||
// Compute embeddings for non-empty texts
|
||||
embeddings, err := s.computeBatch(nonEmpty)
|
||||
embeddings, err := m.computeBatch(nonEmpty)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compute batch embeddings: %w", err)
|
||||
}
|
||||
@@ -173,7 +223,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
||||
}
|
||||
|
||||
// computeBatch runs inference on a batch of texts. Must be called with lock held.
|
||||
func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
|
||||
func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) {
|
||||
if len(sentences) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -184,31 +234,57 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
|
||||
inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent))
|
||||
}
|
||||
|
||||
encodings, err := s.tk.EncodeBatch(inputBatch, true)
|
||||
encodings, err := m.tk.EncodeBatch(inputBatch, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tokenize: %w", err)
|
||||
}
|
||||
|
||||
batchSize := len(encodings)
|
||||
seqLength := len(encodings[0].Ids)
|
||||
hiddenSize := EmbeddingDim
|
||||
hiddenSize := m.config.HiddenSize
|
||||
|
||||
// Find max sequence length across all encodings (tokenizer may not pad uniformly)
|
||||
// Also enforce MaxSequenceLength to prevent model errors
|
||||
seqLength := 0
|
||||
for _, enc := range encodings {
|
||||
if len(enc.Ids) > seqLength {
|
||||
seqLength = len(enc.Ids)
|
||||
}
|
||||
}
|
||||
// Truncate to max model sequence length
|
||||
if seqLength > MaxSequenceLength {
|
||||
seqLength = MaxSequenceLength
|
||||
}
|
||||
|
||||
inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
|
||||
|
||||
// Create input tensors
|
||||
// Create input tensors (pre-filled with zeros for padding)
|
||||
inputIdsData := make([]int64, batchSize*seqLength)
|
||||
attentionMaskData := make([]int64, batchSize*seqLength)
|
||||
tokenTypeIdsData := make([]int64, batchSize*seqLength)
|
||||
|
||||
for b := 0; b < batchSize; b++ {
|
||||
for i, id := range encodings[b].Ids {
|
||||
inputIdsData[b*seqLength+i] = int64(id)
|
||||
// Copy actual token data (rest remains 0 as padding)
|
||||
// Truncate to seqLength to handle long inputs
|
||||
copyLen := len(encodings[b].Ids)
|
||||
if copyLen > seqLength {
|
||||
copyLen = seqLength
|
||||
}
|
||||
for i, mask := range encodings[b].AttentionMask {
|
||||
attentionMaskData[b*seqLength+i] = int64(mask)
|
||||
for i := 0; i < copyLen; i++ {
|
||||
inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i])
|
||||
}
|
||||
for i, typeId := range encodings[b].TypeIds {
|
||||
tokenTypeIdsData[b*seqLength+i] = int64(typeId)
|
||||
copyLen = len(encodings[b].AttentionMask)
|
||||
if copyLen > seqLength {
|
||||
copyLen = seqLength
|
||||
}
|
||||
for i := 0; i < copyLen; i++ {
|
||||
attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i])
|
||||
}
|
||||
copyLen = len(encodings[b].TypeIds)
|
||||
if copyLen > seqLength {
|
||||
copyLen = seqLength
|
||||
}
|
||||
for i := 0; i < copyLen; i++ {
|
||||
tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,51 +306,131 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
|
||||
}
|
||||
defer tokenTypeIdsTensor.Destroy()
|
||||
|
||||
sentenceOutputShape := ort.NewShape(int64(batchSize), int64(hiddenSize))
|
||||
sentenceOutputTensor, err := ort.NewEmptyTensor[float32](sentenceOutputShape)
|
||||
// Create output tensor based on pooling strategy
|
||||
var outputShape ort.Shape
|
||||
|
||||
switch m.config.Pooling {
|
||||
case PoolingNone:
|
||||
// Direct sentence embedding output: [batch, hidden]
|
||||
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
|
||||
case PoolingMean, PoolingCLS:
|
||||
// Token-level output: [batch, seq_len, hidden]
|
||||
outputShape = ort.NewShape(int64(batchSize), int64(seqLength), int64(hiddenSize))
|
||||
default:
|
||||
outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize))
|
||||
}
|
||||
|
||||
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create output tensor: %w", err)
|
||||
}
|
||||
defer sentenceOutputTensor.Destroy()
|
||||
defer outputTensor.Destroy()
|
||||
|
||||
// Run inference
|
||||
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
|
||||
outputTensors := []ort.Value{sentenceOutputTensor}
|
||||
outputTensors := []ort.Value{outputTensor}
|
||||
|
||||
if err := s.session.Run(inputTensors, outputTensors); err != nil {
|
||||
if err := m.session.Run(inputTensors, outputTensors); err != nil {
|
||||
return nil, fmt.Errorf("run inference: %w", err)
|
||||
}
|
||||
|
||||
// Extract results
|
||||
flatOutput := sentenceOutputTensor.GetData()
|
||||
expectedSize := batchSize * hiddenSize
|
||||
if len(flatOutput) != expectedSize {
|
||||
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
|
||||
}
|
||||
// Extract and pool results based on strategy
|
||||
flatOutput := outputTensor.GetData()
|
||||
|
||||
switch m.config.Pooling {
|
||||
case PoolingNone:
|
||||
// Direct output, no pooling needed
|
||||
expectedSize := batchSize * hiddenSize
|
||||
if len(flatOutput) != expectedSize {
|
||||
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
|
||||
}
|
||||
results := make([][]float32, batchSize)
|
||||
for i := 0; i < batchSize; i++ {
|
||||
start := i * hiddenSize
|
||||
end := start + hiddenSize
|
||||
results[i] = make([]float32, hiddenSize)
|
||||
copy(results[i], flatOutput[start:end])
|
||||
}
|
||||
return results, nil
|
||||
|
||||
case PoolingMean:
|
||||
// Mean pooling over tokens (weighted by attention mask)
|
||||
return meanPooling(flatOutput, attentionMaskData, batchSize, seqLength, hiddenSize), nil
|
||||
|
||||
case PoolingCLS:
|
||||
// CLS token pooling (first token of each sequence)
|
||||
return clsPooling(flatOutput, batchSize, seqLength, hiddenSize), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown pooling strategy: %s", m.config.Pooling)
|
||||
}
|
||||
}
|
||||
|
||||
// meanPooling applies mean pooling over token embeddings, weighted by attention mask.
//
// embeddings is the flattened token-level output laid out as
// [batch, seq_len, hidden]; attentionMask is the flattened [batch, seq_len]
// mask. The result has shape [batch, hidden]: for each batch item, the sum of
// the embeddings of tokens with a positive mask value (each scaled by that
// value), divided by the sum of the mask values. An item whose mask sums to
// zero or less is left as an all-zero vector to avoid dividing by zero.
//
// The function indexes the inputs directly, so it panics if either slice is
// shorter than the given dimensions imply.
func meanPooling(embeddings []float32, attentionMask []int64, batchSize, seqLen, hiddenSize int) [][]float32 {
	results := make([][]float32, batchSize)

	for b := 0; b < batchSize; b++ {
		pooled := make([]float32, hiddenSize)
		var maskSum float32

		// Sum embeddings weighted by attention mask.
		for s := 0; s < seqLen; s++ {
			weight := float32(attentionMask[b*seqLen+s])
			maskSum += weight

			// Padding positions (mask 0) contribute nothing.
			if weight <= 0 {
				continue
			}
			offset := (b*seqLen + s) * hiddenSize
			for h := 0; h < hiddenSize; h++ {
				pooled[h] += embeddings[offset+h] * weight
			}
		}

		// Normalize by mask sum (avoid division by zero).
		if maskSum > 0 {
			for h := range pooled {
				pooled[h] /= maskSum
			}
		}

		results[b] = pooled
	}

	return results
}
|
||||
|
||||
// clsPooling returns, for each batch item, a copy of the embedding of its
// first token (the [CLS] position).
//
// embeddings is the flattened token-level output laid out as
// [batch, seq_len, hidden]; the result has shape [batch, hidden]. The returned
// vectors do not alias the input slice.
func clsPooling(embeddings []float32, batchSize, seqLen, hiddenSize int) [][]float32 {
	pooled := make([][]float32, batchSize)

	for b := range pooled {
		// Token 0 of item b starts at the beginning of that item's block.
		first := b * seqLen * hiddenSize
		vec := make([]float32, hiddenSize)
		copy(vec, embeddings[first:first+hiddenSize])
		pooled[b] = vec
	}

	return pooled
}
|
||||
|
||||
// Close releases model resources.
|
||||
func (s *Service) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
func (m *bgeModel) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var errs []error
|
||||
|
||||
if s.session != nil {
|
||||
if err := s.session.Destroy(); err != nil {
|
||||
if m.session != nil {
|
||||
if err := m.session.Destroy(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("destroy session: %w", err))
|
||||
}
|
||||
s.session = nil
|
||||
m.session = nil
|
||||
}
|
||||
|
||||
if err := ort.DestroyEnvironment(); err != nil {
|
||||
@@ -282,10 +438,75 @@ func (s *Service) Close() error {
|
||||
}
|
||||
|
||||
// Optionally clean up extracted library (leave for caching)
|
||||
// os.RemoveAll(s.libDir)
|
||||
// os.RemoveAll(m.libDir)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errs[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register the BGE model with the default registry at init time
|
||||
func init() {
|
||||
RegisterModel(ModelMetadata{
|
||||
Name: BGEModelName,
|
||||
Version: BGEModelVersion,
|
||||
Dimensions: EmbeddingDim,
|
||||
Description: "High-quality semantic search model",
|
||||
Default: true,
|
||||
}, newBGEModel)
|
||||
}
|
||||
|
||||
// Service provides thread-safe text embedding generation with model abstraction.
// It holds no lock of its own and forwards every call to the wrapped model, so
// its concurrency safety is that of the model (bgeModel serializes Embed,
// EmbedBatch and Close with a mutex).
type Service struct {
	model EmbeddingModel // the model every method delegates to
}
|
||||
|
||||
// NewService creates a new embedding service using the default model.
|
||||
func NewService() (*Service, error) {
|
||||
return NewServiceWithModel(DefaultModelVersion)
|
||||
}
|
||||
|
||||
// NewServiceWithModel creates a new embedding service using the specified model.
|
||||
func NewServiceWithModel(version string) (*Service, error) {
|
||||
if version == "" {
|
||||
version = DefaultModelVersion
|
||||
}
|
||||
|
||||
model, err := GetModel(version)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model %s: %w", version, err)
|
||||
}
|
||||
|
||||
return &Service{model: model}, nil
|
||||
}
|
||||
|
||||
// Name returns the human-readable model name.
|
||||
func (s *Service) Name() string {
|
||||
return s.model.Name()
|
||||
}
|
||||
|
||||
// Version returns the short version string for storage.
|
||||
func (s *Service) Version() string {
|
||||
return s.model.Version()
|
||||
}
|
||||
|
||||
// Dimensions returns the embedding vector size.
|
||||
func (s *Service) Dimensions() int {
|
||||
return s.model.Dimensions()
|
||||
}
|
||||
|
||||
// Embed generates an embedding for a single text.
|
||||
func (s *Service) Embed(text string) ([]float32, error) {
|
||||
return s.model.Embed(text)
|
||||
}
|
||||
|
||||
// EmbedBatch generates embeddings for multiple texts.
|
||||
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
||||
return s.model.EmbedBatch(texts)
|
||||
}
|
||||
|
||||
// Close releases model resources.
|
||||
func (s *Service) Close() error {
|
||||
return s.model.Close()
|
||||
}
|
||||
|
||||
@@ -22,8 +22,10 @@ func TestNewService(t *testing.T) {
|
||||
|
||||
defer svc.Close()
|
||||
|
||||
assert.NotNil(t, svc.tk)
|
||||
assert.NotNil(t, svc.session)
|
||||
// Verify the service is properly initialized via public methods
|
||||
assert.NotEmpty(t, svc.Name())
|
||||
assert.NotEmpty(t, svc.Version())
|
||||
assert.Equal(t, EmbeddingDim, svc.Dimensions())
|
||||
}
|
||||
|
||||
// TestEmbed_SingleText tests embedding a single text.
|
||||
@@ -269,8 +271,8 @@ func TestClose(t *testing.T) {
|
||||
err = svc.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Session should be nil after close
|
||||
assert.Nil(t, svc.session)
|
||||
// After close, embedding should fail (model resources released)
|
||||
// Note: This behavior is model-specific; some models may still work after close
|
||||
}
|
||||
|
||||
// TestEmbedBatch_SingleItem tests batch embedding with single item.
|
||||
|
||||
@@ -60,12 +60,15 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
return fmt.Errorf("generate embeddings: %w", err)
|
||||
}
|
||||
|
||||
// Insert into vectors table
|
||||
// Insert into vectors table with model version tracking
|
||||
const insertQuery = `
|
||||
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope, model_version)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
// Get current model version for tracking
|
||||
modelVersion := c.embedSvc.Version()
|
||||
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
@@ -104,6 +107,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
fieldType,
|
||||
project,
|
||||
scope,
|
||||
modelVersion,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert document %s: %w", doc.ID, err)
|
||||
@@ -114,7 +118,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("count", len(docs)).Msg("Added documents to sqlite-vec")
|
||||
log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -252,3 +256,148 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// Count returns the total number of vectors in the store.
|
||||
func (c *Client) Count(ctx context.Context) (int64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
var count int64
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("count vectors: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ModelVersion returns the current embedding model version.
|
||||
func (c *Client) ModelVersion() string {
|
||||
return c.embedSvc.Version()
|
||||
}
|
||||
|
||||
// NeedsRebuild checks if vectors need to be rebuilt due to model version change.
|
||||
// Returns true if:
|
||||
// - The vectors table is empty
|
||||
// - Any vectors have a different model_version than the current model
|
||||
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
|
||||
// Check total count
|
||||
var totalCount int64
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to count vectors for rebuild check")
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if totalCount == 0 {
|
||||
return true, "empty"
|
||||
}
|
||||
|
||||
// Check for vectors with different model version
|
||||
var staleCount int64
|
||||
err = c.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
|
||||
currentModel,
|
||||
).Scan(&staleCount)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to count stale vectors")
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if staleCount > 0 {
|
||||
return true, fmt.Sprintf("model_mismatch:%d", staleCount)
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// StaleVectorInfo describes one row of the vectors table whose embedding has
// to be regenerated. GetStaleVectors fills it from the row's metadata columns;
// columns that are NULL in the database show up as zero values.
type StaleVectorInfo struct {
	DocID     string // value of the doc_id column
	SQLiteID  int64  // value of the sqlite_id column (0 when NULL)
	DocType   string // value of the doc_type column
	FieldType string // value of the field_type column
	Project   string // value of the project column
	Scope     string // value of the scope column
}
|
||||
|
||||
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
|
||||
// This enables granular rebuild - only re-embedding documents that need updating.
|
||||
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
|
||||
query := `
|
||||
SELECT doc_id, sqlite_id, doc_type, field_type, project, scope
|
||||
FROM vectors
|
||||
WHERE model_version != ? OR model_version IS NULL
|
||||
`
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, query, currentModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query stale vectors: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []StaleVectorInfo
|
||||
for rows.Next() {
|
||||
var info StaleVectorInfo
|
||||
var sqliteID sql.NullInt64
|
||||
var docType, fieldType, project, scope sql.NullString
|
||||
|
||||
if err := rows.Scan(&info.DocID, &sqliteID, &docType, &fieldType, &project, &scope); err != nil {
|
||||
return nil, fmt.Errorf("scan row: %w", err)
|
||||
}
|
||||
|
||||
info.SQLiteID = sqliteID.Int64
|
||||
info.DocType = docType.String
|
||||
info.FieldType = fieldType.String
|
||||
info.Project = project.String
|
||||
info.Scope = scope.String
|
||||
|
||||
results = append(results, info)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate rows: %w", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// DeleteVectorsByDocIDs removes vectors by their doc_ids.
|
||||
// Used for granular rebuild - delete stale vectors before re-adding.
|
||||
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
|
||||
if len(docIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Build placeholder string
|
||||
placeholders := make([]string, len(docIDs))
|
||||
args := make([]interface{}, len(docIDs))
|
||||
for i, id := range docIDs {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
// #nosec G201 -- Placeholders are "?" strings, actual values are parameterized via args
|
||||
query := fmt.Sprintf("DELETE FROM vectors WHERE doc_id IN (%s)",
|
||||
strings.Join(placeholders, ","))
|
||||
|
||||
_, err := c.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete vectors by doc_ids: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -38,7 +38,8 @@ func testDB(t *testing.T) (*sql.DB, func()) {
|
||||
doc_type TEXT,
|
||||
field_type TEXT,
|
||||
project TEXT,
|
||||
scope TEXT
|
||||
scope TEXT,
|
||||
model_version TEXT
|
||||
)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
||||
@@ -641,6 +642,18 @@ func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetModels returns available embedding models.
|
||||
func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) {
|
||||
models := embedding.ListModels()
|
||||
defaultModel := embedding.GetDefaultModel()
|
||||
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"models": models,
|
||||
"default": defaultModel,
|
||||
"current": s.embedSvc.Version(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetStats returns worker statistics.
|
||||
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
@@ -658,6 +671,22 @@ func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
|
||||
"ready": s.ready.Load(),
|
||||
}
|
||||
|
||||
// Add embedding model info
|
||||
if s.embedSvc != nil {
|
||||
response["embeddingModel"] = map[string]interface{}{
|
||||
"name": s.embedSvc.Name(),
|
||||
"version": s.embedSvc.Version(),
|
||||
"dimensions": s.embedSvc.Dimensions(),
|
||||
}
|
||||
}
|
||||
|
||||
// Add vector count
|
||||
if s.vectorClient != nil {
|
||||
if count, err := s.vectorClient.Count(r.Context()); err == nil {
|
||||
response["vectorCount"] = count
|
||||
}
|
||||
}
|
||||
|
||||
// Include project-specific observation count if project is specified
|
||||
if project != "" {
|
||||
count, err := s.observationStore.GetObservationCount(r.Context(), project)
|
||||
|
||||
+229
-1
@@ -200,7 +200,9 @@ func (s *Service) initializeAsync() {
|
||||
} else {
|
||||
vectorClient = client
|
||||
vectorSync = sqlitevec.NewSync(client)
|
||||
log.Info().Msg("sqlite-vec vector search enabled")
|
||||
log.Info().
|
||||
Str("model", embedSvc.Version()).
|
||||
Msg("sqlite-vec vector search enabled")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -294,6 +296,27 @@ func (s *Service) initializeAsync() {
|
||||
|
||||
// Start file watchers for auto-recreation on deletion
|
||||
s.startWatchers()
|
||||
|
||||
// Check if vectors need rebuilding (empty or model version mismatch) and trigger background rebuild
|
||||
if vectorClient != nil && vectorSync != nil {
|
||||
needsRebuild, reason := vectorClient.NeedsRebuild(s.ctx)
|
||||
if needsRebuild {
|
||||
log.Info().
|
||||
Str("reason", reason).
|
||||
Str("model", vectorClient.ModelVersion()).
|
||||
Msg("Vector rebuild required")
|
||||
|
||||
if reason == "empty" {
|
||||
// Full rebuild - vectors table is empty
|
||||
s.wg.Add(1)
|
||||
go s.rebuildAllVectors(observationStore, summaryStore, promptStore, vectorSync)
|
||||
} else {
|
||||
// Granular rebuild - only rebuild vectors with mismatched model versions
|
||||
s.wg.Add(1)
|
||||
go s.rebuildStaleVectors(observationStore, summaryStore, promptStore, vectorClient, vectorSync)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startWatchers initializes and starts file watchers for database and config.
|
||||
@@ -565,6 +588,210 @@ func (s *Service) processStaleQueue() {
|
||||
}
|
||||
}
|
||||
|
||||
// rebuildAllVectors rebuilds all vectors from observations, summaries, and prompts.
|
||||
// Called when the vectors table is empty (e.g., after migration 20 drops all vectors).
|
||||
func (s *Service) rebuildAllVectors(
|
||||
observationStore *sqlite.ObservationStore,
|
||||
summaryStore *sqlite.SummaryStore,
|
||||
promptStore *sqlite.PromptStore,
|
||||
vectorSync *sqlitevec.Sync,
|
||||
) {
|
||||
defer s.wg.Done()
|
||||
|
||||
log.Info().Msg("Starting full vector rebuild...")
|
||||
start := time.Now()
|
||||
|
||||
var totalSynced int
|
||||
var syncErrors int
|
||||
|
||||
// Rebuild observations
|
||||
observations, err := observationStore.GetAllObservations(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch observations for vector rebuild")
|
||||
} else {
|
||||
for _, obs := range observations {
|
||||
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(observations)).Msg("Rebuilt observation vectors")
|
||||
}
|
||||
|
||||
// Rebuild summaries
|
||||
summaries, err := summaryStore.GetAllSummaries(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch summaries for vector rebuild")
|
||||
} else {
|
||||
for _, summary := range summaries {
|
||||
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(summaries)).Msg("Rebuilt summary vectors")
|
||||
}
|
||||
|
||||
// Rebuild user prompts
|
||||
prompts, err := promptStore.GetAllPrompts(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch prompts for vector rebuild")
|
||||
} else {
|
||||
for _, prompt := range prompts {
|
||||
if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
|
||||
log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(prompts)).Msg("Rebuilt prompt vectors")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
log.Info().
|
||||
Int("total_synced", totalSynced).
|
||||
Int("errors", syncErrors).
|
||||
Dur("elapsed", elapsed).
|
||||
Msg("Full vector rebuild complete")
|
||||
}
|
||||
|
||||
// rebuildStaleVectors rebuilds only vectors with mismatched or unknown model versions.
|
||||
// This is more efficient than rebuilding all vectors when only some need updating.
|
||||
func (s *Service) rebuildStaleVectors(
|
||||
observationStore *sqlite.ObservationStore,
|
||||
summaryStore *sqlite.SummaryStore,
|
||||
promptStore *sqlite.PromptStore,
|
||||
vectorClient *sqlitevec.Client,
|
||||
vectorSync *sqlitevec.Sync,
|
||||
) {
|
||||
defer s.wg.Done()
|
||||
|
||||
log.Info().Msg("Starting granular vector rebuild for stale vectors...")
|
||||
start := time.Now()
|
||||
|
||||
// Get all stale vectors
|
||||
staleVectors, err := vectorClient.GetStaleVectors(s.ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get stale vectors")
|
||||
return
|
||||
}
|
||||
|
||||
if len(staleVectors) == 0 {
|
||||
log.Info().Msg("No stale vectors found")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().Int("stale_count", len(staleVectors)).Msg("Found stale vectors to rebuild")
|
||||
|
||||
// Group stale vectors by doc_type and sqlite_id for efficient lookup
|
||||
staleObsIDs := make(map[int64]bool)
|
||||
staleSummaryIDs := make(map[int64]bool)
|
||||
stalePromptIDs := make(map[int64]bool)
|
||||
staleDocIDs := make([]string, 0, len(staleVectors))
|
||||
|
||||
for _, sv := range staleVectors {
|
||||
staleDocIDs = append(staleDocIDs, sv.DocID)
|
||||
switch sv.DocType {
|
||||
case "observation":
|
||||
staleObsIDs[sv.SQLiteID] = true
|
||||
case "summary":
|
||||
staleSummaryIDs[sv.SQLiteID] = true
|
||||
case "prompt":
|
||||
stalePromptIDs[sv.SQLiteID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Delete stale vectors before re-syncing
|
||||
if err := vectorClient.DeleteVectorsByDocIDs(s.ctx, staleDocIDs); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to delete stale vectors")
|
||||
return
|
||||
}
|
||||
|
||||
var totalSynced int
|
||||
var syncErrors int
|
||||
|
||||
// Rebuild stale observations
|
||||
if len(staleObsIDs) > 0 {
|
||||
ids := make([]int64, 0, len(staleObsIDs))
|
||||
for id := range staleObsIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
observations, err := observationStore.GetObservationsByIDs(s.ctx, ids, "date_desc", 0)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch observations for rebuild")
|
||||
} else {
|
||||
for _, obs := range observations {
|
||||
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(observations)).Msg("Rebuilt stale observation vectors")
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild stale summaries
|
||||
if len(staleSummaryIDs) > 0 {
|
||||
ids := make([]int64, 0, len(staleSummaryIDs))
|
||||
for id := range staleSummaryIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
summaries, err := summaryStore.GetSummariesByIDs(s.ctx, ids, "date_desc", 0)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch summaries for rebuild")
|
||||
} else {
|
||||
for _, summary := range summaries {
|
||||
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(summaries)).Msg("Rebuilt stale summary vectors")
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild stale prompts
|
||||
if len(stalePromptIDs) > 0 {
|
||||
ids := make([]int64, 0, len(stalePromptIDs))
|
||||
for id := range stalePromptIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
prompts, err := promptStore.GetPromptsByIDs(s.ctx, ids, "date_desc", 0)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to fetch prompts for rebuild")
|
||||
} else {
|
||||
for _, prompt := range prompts {
|
||||
if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil {
|
||||
log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild")
|
||||
syncErrors++
|
||||
} else {
|
||||
totalSynced++
|
||||
}
|
||||
}
|
||||
log.Info().Int("count", len(prompts)).Msg("Rebuilt stale prompt vectors")
|
||||
}
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
log.Info().
|
||||
Int("total_synced", totalSynced).
|
||||
Int("errors", syncErrors).
|
||||
Dur("elapsed", elapsed).
|
||||
Msg("Granular vector rebuild complete")
|
||||
}
|
||||
|
||||
// verifyStaleObservation verifies a single stale observation in the background.
|
||||
func (s *Service) verifyStaleObservation(req staleVerifyRequest) {
|
||||
// Wait for service to be ready
|
||||
@@ -667,6 +894,7 @@ func (s *Service) setupRoutes() {
|
||||
r.Get("/api/stats", s.handleGetStats)
|
||||
r.Get("/api/stats/retrieval", s.handleGetRetrievalStats)
|
||||
r.Get("/api/types", s.handleGetTypes)
|
||||
r.Get("/api/models", s.handleGetModels)
|
||||
|
||||
// Context injection
|
||||
r.Get("/api/context/count", s.handleContextCount)
|
||||
|
||||
Executable
+121
@@ -0,0 +1,121 @@
|
||||
#!/bin/bash
# Download BGE-small-en-v1.5 model for embedding
# Usage: ./download-bge-model.sh [--force]
# Use --force to re-download even if files exist
#
# ASSETS_DIR is a relative path, so files are placed relative to the current
# working directory (presumably the repository root - confirm with the build).

set -e

MODEL_NAME="bge-small-en-v1.5"
MODEL_REPO="BAAI/bge-small-en-v1.5"
ASSETS_DIR="internal/embedding/assets"
VERSION_FILE="${ASSETS_DIR}/.model_version"
FORCE_DOWNLOAD=false

# Check for --force flag (any other argument is ignored)
for arg in "$@"; do
    if [ "$arg" = "--force" ]; then
        FORCE_DOWNLOAD=true
    fi
done

# Temporary directory for downloads
TEMP_DIR=$(mktemp -d)
# Single quotes defer expansion until the trap fires and keep the path quoted,
# so a temp directory containing spaces or glob characters is removed intact.
trap 'rm -rf "${TEMP_DIR}"' EXIT
|
||||
|
||||
# Succeeds (status 0) only when both model.onnx and tokenizer.json are present
# as regular files in ASSETS_DIR; fails (status 1) otherwise.
model_exists() {
    local required
    for required in model.onnx tokenizer.json; do
        [ -f "${ASSETS_DIR}/${required}" ] || return 1
    done
    return 0
}
|
||||
|
||||
# Print the contents of VERSION_FILE, or an empty line when the file does not
# exist.
get_installed_version() {
    if [ ! -f "$VERSION_FILE" ]; then
        echo ""
        return
    fi
    cat "$VERSION_FILE"
}
|
||||
|
||||
# Record MODEL_NAME (followed by a newline) in VERSION_FILE, replacing any
# previous content.
write_version_file() {
    printf '%s\n' "${MODEL_NAME}" > "$VERSION_FILE"
}
|
||||
|
||||
# Download model.onnx and tokenizer.json for MODEL_REPO from Hugging Face into
# TEMP_DIR, then install them into ASSETS_DIR and record the version.
#
# Existing files are moved aside as *.bak before the new ones are installed.
# If installing fails, the partially installed files are removed, the backups
# are restored and the script exits with status 1, so ASSETS_DIR never keeps
# a mix of old and new files. Error messages go to stderr.
download_model() {
    echo "Downloading ${MODEL_NAME} from Hugging Face..."

    # Create assets directory
    mkdir -p "${ASSETS_DIR}"

    # Download ONNX model
    # BGE models have ONNX exports available in the repo
    echo "Downloading ONNX model..."
    curl -fsSL \
        "https://huggingface.co/${MODEL_REPO}/resolve/main/onnx/model.onnx" \
        -o "${TEMP_DIR}/model.onnx"

    # Download tokenizer.json
    echo "Downloading tokenizer..."
    curl -fsSL \
        "https://huggingface.co/${MODEL_REPO}/resolve/main/tokenizer.json" \
        -o "${TEMP_DIR}/tokenizer.json"

    # Verify files exist and have content
    if [ ! -s "${TEMP_DIR}/model.onnx" ]; then
        echo "Error: Failed to download model.onnx or file is empty" >&2
        exit 1
    fi

    if [ ! -s "${TEMP_DIR}/tokenizer.json" ]; then
        echo "Error: Failed to download tokenizer.json or file is empty" >&2
        exit 1
    fi

    # Move to assets directory (backup old files first)
    local file
    for file in model.onnx tokenizer.json; do
        if [ -f "${ASSETS_DIR}/${file}" ]; then
            mv "${ASSETS_DIR}/${file}" "${ASSETS_DIR}/${file}.bak"
        fi
    done

    # Running the moves inside an `if` condition keeps `set -e` from aborting
    # before the backups can be restored.
    if ! { mv "${TEMP_DIR}/model.onnx" "${ASSETS_DIR}/model.onnx" &&
           mv "${TEMP_DIR}/tokenizer.json" "${ASSETS_DIR}/tokenizer.json"; }; then
        echo "Error: Failed to install model files, restoring previous version" >&2
        for file in model.onnx tokenizer.json; do
            rm -f "${ASSETS_DIR}/${file}"
            if [ -f "${ASSETS_DIR}/${file}.bak" ]; then
                mv -f "${ASSETS_DIR}/${file}.bak" "${ASSETS_DIR}/${file}"
            fi
        done
        exit 1
    fi

    # Remove backups on success
    rm -f "${ASSETS_DIR}/model.onnx.bak" "${ASSETS_DIR}/tokenizer.json.bak"

    # Write version file
    write_version_file

    echo "Model size: $(du -h "${ASSETS_DIR}/model.onnx" | cut -f1)"
    echo "Tokenizer size: $(du -h "${ASSETS_DIR}/tokenizer.json" | cut -f1)"
}
|
||||
|
||||
echo "BGE Model Downloader - ${MODEL_NAME}"
echo "=================================="

# Decide whether a download is needed. `reason` stays empty when the installed
# model is present and its recorded version matches MODEL_NAME.
reason=""
if [ "$FORCE_DOWNLOAD" = true ]; then
    reason="forced"
elif ! model_exists; then
    reason="not found"
elif [ "$(get_installed_version)" != "${MODEL_NAME}" ]; then
    reason="version mismatch (installed: $(get_installed_version), required: ${MODEL_NAME})"
fi

if [ -z "$reason" ]; then
    echo "Model ${MODEL_NAME} already exists, skipping download."
    echo "Use --force to re-download."
else
    # A first-time install is not announced as a re-download.
    if [ "$reason" != "not found" ]; then
        echo "Re-downloading: ${reason}"
    fi
    download_model
    echo "Done! ${MODEL_NAME} installed successfully."
fi
|
||||
Generated
+2
-2
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "v0.6.33-3-gf38ce5c-dirty",
|
||||
"version": "v0.6.38-1-g2e57cb9-dirty",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "v0.6.33-3-gf38ce5c-dirty",
|
||||
"version": "v0.6.38-1-g2e57cb9-dirty",
|
||||
"dependencies": {
|
||||
"vue": "^3.5.13"
|
||||
},
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "v0.6.33-3-gf38ce5c-dirty",
|
||||
"version": "v0.6.38-1-g2e57cb9-dirty",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
||||
Reference in New Issue
Block a user