diff --git a/internal/config/config.go b/internal/config/config.go index 4fd0e91..9f0ce2e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,6 +48,9 @@ type Config struct { Model string `json:"model"` ClaudeCodePath string `json:"claude_code_path"` + // Embedding settings + EmbeddingModel string `json:"embedding_model"` // e.g., "bge-v1.5" + // Context injection settings ContextObservations int `json:"context_observations"` ContextFullCount int `json:"context_full_count"` @@ -119,6 +122,9 @@ func EnsureAll() error { return nil } +// DefaultEmbeddingModel is the default embedding model to use. +const DefaultEmbeddingModel = "bge-v1.5" + // Default returns a Config with default values. func Default() *Config { return &Config{ @@ -126,6 +132,7 @@ func Default() *Config { DBPath: DBPath(), MaxConns: 4, Model: DefaultModel, + EmbeddingModel: DefaultEmbeddingModel, ContextObservations: 100, ContextFullCount: 25, ContextSessionCount: 10, @@ -166,6 +173,9 @@ func Load() (*Config, error) { if v, ok := settings["CLAUDE_CODE_PATH"].(string); ok { cfg.ClaudeCodePath = v } + if v, ok := settings["CLAUDE_MNEMONIC_EMBEDDING_MODEL"].(string); ok && v != "" { + cfg.EmbeddingModel = v + } if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS"].(float64); ok { cfg.ContextObservations = int(v) } diff --git a/internal/db/sqlite/migrations.go b/internal/db/sqlite/migrations.go index 1d5483f..0ff55b0 100644 --- a/internal/db/sqlite/migrations.go +++ b/internal/db/sqlite/migrations.go @@ -283,6 +283,27 @@ var Migrations = []Migration{ ON user_prompts(claude_session_id, prompt_number); `, }, + { + Version: 19, + Name: "vectors_with_model_version", + SQL: ` + -- Drop old vectors table (virtual tables cannot be altered) + DROP TABLE IF EXISTS vectors; + + -- Recreate vectors table with model_version column + -- Uses bge-small-en-v1.5 embeddings (384 dimensions) + CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0( + doc_id TEXT PRIMARY KEY, + embedding float[384], + sqlite_id INTEGER, + doc_type TEXT, + field_type TEXT, + project TEXT, + scope TEXT, + model_version TEXT + ); + `, + }, } // MigrationManager handles database schema migrations. diff --git a/internal/db/sqlite/observation.go b/internal/db/sqlite/observation.go index d95085f..5a48942 100644 --- a/internal/db/sqlite/observation.go +++ b/internal/db/sqlite/observation.go @@ -229,6 +229,25 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i return scanObservationRows(rows) } +// GetAllObservations retrieves all observations (for vector rebuild). +func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) { + const query = ` + SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative, + concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens, + created_at, created_at_epoch + FROM observations + ORDER BY id + ` + + rows, err := s.store.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanObservationRows(rows) +} + // SearchObservationsFTS performs full-text search on observations. func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) { if limit <= 0 { diff --git a/internal/db/sqlite/prompt.go b/internal/db/sqlite/prompt.go index 978bb68..2d44661 100644 --- a/internal/db/sqlite/prompt.go +++ b/internal/db/sqlite/prompt.go @@ -199,6 +199,28 @@ func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([ return scanPromptWithSessionRows(rows) } +// GetAllPrompts retrieves all user prompts (for vector rebuild). +func (s *PromptStore) GetAllPrompts(ctx context.Context) ([]*models.UserPromptWithSession, error) { + const query = ` + SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text, + COALESCE(up.matched_observations, 0) as matched_observations, + up.created_at, up.created_at_epoch, + COALESCE(s.project, '') as project, + COALESCE(s.sdk_session_id, '') as sdk_session_id + FROM user_prompts up + LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id + ORDER BY up.id + ` + + rows, err := s.store.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanPromptWithSessionRows(rows) +} + // FindRecentPromptByText finds a prompt with the same text for a session within the last few seconds. // This is used to detect duplicate hook invocations. // Returns (promptID, promptNumber, found). diff --git a/internal/db/sqlite/summary.go b/internal/db/sqlite/summary.go index ae622f6..6ec399c 100644 --- a/internal/db/sqlite/summary.go +++ b/internal/db/sqlite/summary.go @@ -116,3 +116,21 @@ func (s *SummaryStore) GetAllRecentSummaries(ctx context.Context, limit int) ([] return scanSummaryRows(rows) } + +// GetAllSummaries retrieves all summaries (for vector rebuild). +func (s *SummaryStore) GetAllSummaries(ctx context.Context) ([]*models.SessionSummary, error) { + const query = ` + SELECT id, sdk_session_id, project, request, investigated, learned, completed, + next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch + FROM session_summaries + ORDER BY id + ` + + rows, err := s.store.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + return scanSummaryRows(rows) +} diff --git a/internal/embedding/assets/.model_version b/internal/embedding/assets/.model_version new file mode 100644 index 0000000..eb06513 --- /dev/null +++ b/internal/embedding/assets/.model_version @@ -0,0 +1 @@ +bge-small-en-v1.5 diff --git a/internal/embedding/assets/model.onnx b/internal/embedding/assets/model.onnx index 7a11e91..33c23a7 100644 --- a/internal/embedding/assets/model.onnx +++ b/internal/embedding/assets/model.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:994a58868f7abacacbf2192aa0aae8f56da8c4505dbde2740c861b24426ede6b -size 90445823 +oid sha256:828e1496d7fabb79cfa4dcd84fa38625c0d3d21da474a00f08db0f559940cf35 +size 133093490 diff --git a/internal/embedding/assets/tokenizer.json b/internal/embedding/assets/tokenizer.json index c17ed52..688882a 100644 --- a/internal/embedding/assets/tokenizer.json +++ b/internal/embedding/assets/tokenizer.json @@ -1,21 +1,7 @@ { "version": "1.0", - "truncation": { - "direction": "Right", - "max_length": 128, - "strategy": "LongestFirst", - "stride": 0 - }, - "padding": { - "strategy": { - "Fixed": 128 - }, - "direction": "Right", - "pad_to_multiple_of": null, - "pad_id": 0, - "pad_type_id": 0, - "pad_token": "[PAD]" - }, + "truncation": null, + "padding": null, "added_tokens": [ { "id": 0, diff --git a/internal/embedding/model.go b/internal/embedding/model.go new file mode 100644 index 0000000..3a40838 --- /dev/null +++ b/internal/embedding/model.go @@ -0,0 +1,157 @@ +// Package embedding provides text embedding generation with swappable models. +package embedding + +import ( + "fmt" + "sync" +) + +// PoolingStrategy defines how to pool token embeddings into sentence embeddings. +type PoolingStrategy string + +const ( + // PoolingNone means the model already outputs sentence embeddings directly. + PoolingNone PoolingStrategy = "none" + // PoolingMean averages all token embeddings (weighted by attention mask). + PoolingMean PoolingStrategy = "mean" + // PoolingCLS uses only the [CLS] token embedding. + PoolingCLS PoolingStrategy = "cls" +) + +// ONNXConfig describes ONNX-specific model configuration. +// This allows different models to specify their tensor names and pooling needs. +type ONNXConfig struct { + // InputNames are the ONNX input tensor names in order. + InputNames []string + // OutputNames are the ONNX output tensor names. + OutputNames []string + // Pooling specifies how to convert token embeddings to sentence embeddings. + // If PoolingNone, the model outputs sentence embeddings directly. + Pooling PoolingStrategy + // HiddenSize is the embedding dimension (used for pooling calculations). + HiddenSize int +} + +// EmbeddingModel represents a text embedding model. +type EmbeddingModel interface { + // Name returns the human-readable model name (e.g., "bge-small-en-v1.5"). + Name() string + + // Version returns a short version string for storage (e.g., "bge-v1.5"). + Version() string + + // Dimensions returns the embedding vector size. + Dimensions() int + + // Embed generates an embedding for a single text. + Embed(text string) ([]float32, error) + + // EmbedBatch generates embeddings for multiple texts. + EmbedBatch(texts []string) ([][]float32, error) + + // Close releases model resources. + Close() error +} + +// ONNXConfigurer is an optional interface that models can implement +// to expose their ONNX configuration for introspection. +type ONNXConfigurer interface { + // ONNXConfig returns the model's ONNX configuration. + ONNXConfig() ONNXConfig +} + +// ModelMetadata describes an embedding model for UI/config. +type ModelMetadata struct { + Name string `json:"name"` // Human-readable name + Version string `json:"version"` // Short ID for DB storage + Dimensions int `json:"dimensions"` // Vector size + Description string `json:"description"` // Brief description + Default bool `json:"default"` // Is this the default model? +} + +// ModelFactory creates a new instance of an embedding model. +type ModelFactory func() (EmbeddingModel, error) + +// ModelRegistry provides model lookup by version. +type ModelRegistry struct { + mu sync.RWMutex + models map[string]ModelFactory + metadata map[string]ModelMetadata + defaultModel string +} + +// NewModelRegistry creates a new model registry. +func NewModelRegistry() *ModelRegistry { + return &ModelRegistry{ + models: make(map[string]ModelFactory), + metadata: make(map[string]ModelMetadata), + } +} + +// Register adds a model factory to the registry. +func (r *ModelRegistry) Register(meta ModelMetadata, factory ModelFactory) { + r.mu.Lock() + defer r.mu.Unlock() + + r.models[meta.Version] = factory + r.metadata[meta.Version] = meta + + if meta.Default { + r.defaultModel = meta.Version + } +} + +// Get creates a new instance of the model with the given version. +func (r *ModelRegistry) Get(version string) (EmbeddingModel, error) { + r.mu.RLock() + factory, ok := r.models[version] + r.mu.RUnlock() + + if !ok { + return nil, fmt.Errorf("unknown model version: %s", version) + } + + return factory() +} + +// Default returns the default model version. +func (r *ModelRegistry) Default() string { + r.mu.RLock() + defer r.mu.RUnlock() + return r.defaultModel +} + +// List returns metadata for all registered models. +func (r *ModelRegistry) List() []ModelMetadata { + r.mu.RLock() + defer r.mu.RUnlock() + + result := make([]ModelMetadata, 0, len(r.metadata)) + for _, meta := range r.metadata { + result = append(result, meta) + } + return result +} + +// DefaultRegistry is the global model registry with all available models. +var DefaultRegistry = NewModelRegistry() + +// RegisterModel adds a model to the default registry. +func RegisterModel(meta ModelMetadata, factory ModelFactory) { + DefaultRegistry.Register(meta, factory) +} + +// GetModel creates a model instance from the default registry. +func GetModel(version string) (EmbeddingModel, error) { + return DefaultRegistry.Get(version) +} + +// GetDefaultModel returns the default model version from the default registry. +func GetDefaultModel() string { + return DefaultRegistry.Default() +} + +// ListModels returns metadata for all models in the default registry. +func ListModels() []ModelMetadata { + return DefaultRegistry.List() +} diff --git a/internal/embedding/service.go b/internal/embedding/service.go index f6f7b1c..6611322 100644 --- a/internal/embedding/service.go +++ b/internal/embedding/service.go @@ -1,4 +1,4 @@ -// Package embedding provides text embedding generation using all-MiniLM-L6-v2. +// Package embedding provides text embedding generation with swappable models. package embedding import ( @@ -15,19 +15,50 @@ import ( ort "github.com/yalue/onnxruntime_go" ) -// EmbeddingDim is the dimension of embeddings produced by all-MiniLM-L6-v2. +// EmbeddingDim is the dimension of embeddings produced by the current model. +// Both all-MiniLM-L6-v2 and bge-small-en-v1.5 produce 384-dimensional embeddings. const EmbeddingDim = 384 -// Service provides thread-safe text embedding generation. -type Service struct { +// Model version constants +const ( + // BGEModelVersion is the version string for bge-small-en-v1.5 + BGEModelVersion = "bge-v1.5" + // BGEModelName is the human-readable name for bge-small-en-v1.5 + BGEModelName = "bge-small-en-v1.5" + // DefaultModelVersion is the default model to use + DefaultModelVersion = BGEModelVersion +) + +// MaxSequenceLength is the maximum token sequence length for the model. +const MaxSequenceLength = 512 + +// bgeONNXConfig defines the ONNX configuration for BGE models. +// BGE outputs last_hidden_state and requires mean pooling. +var bgeONNXConfig = ONNXConfig{ + InputNames: []string{"input_ids", "attention_mask", "token_type_ids"}, + OutputNames: []string{"last_hidden_state"}, + Pooling: PoolingMean, + HiddenSize: EmbeddingDim, +} + +// bgeModel is the ONNX-based embedding model implementation. +// Currently supports bge-small-en-v1.5 (previously all-MiniLM-L6-v2). +type bgeModel struct { tk *tokenizer.Tokenizer session *ort.DynamicAdvancedSession mu sync.Mutex - libDir string // temp directory containing extracted libraries + libDir string // temp directory containing extracted libraries + config ONNXConfig // ONNX configuration for this model } -// NewService creates a new embedding service using bundled ONNX runtime and model. -func NewService() (*Service, error) { +// Compile-time check that bgeModel implements EmbeddingModel +var _ EmbeddingModel = (*bgeModel)(nil) + +// Compile-time check that bgeModel implements ONNXConfigurer +var _ ONNXConfigurer = (*bgeModel)(nil) + +// newBGEModel creates a new BGE embedding model using bundled ONNX runtime and model. +func newBGEModel() (EmbeddingModel, error) { // Extract ONNX runtime library to temp directory libDir, err := extractONNXLibrary() if err != nil { @@ -49,22 +80,41 @@ func NewService() (*Service, error) { return nil, fmt.Errorf("load tokenizer: %w", err) } - // Create ONNX session with embedded model - inputNames := []string{"input_ids", "attention_mask", "token_type_ids"} - outputNames := []string{"sentence_embedding"} - - session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, inputNames, outputNames, nil) + // Create ONNX session using model-specific configuration + config := bgeONNXConfig + session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, config.InputNames, config.OutputNames, nil) if err != nil { return nil, fmt.Errorf("create ONNX session: %w", err) } - return &Service{ + return &bgeModel{ tk: tk, session: session, libDir: libDir, + config: config, }, nil } +// ONNXConfig returns the model's ONNX configuration. +func (m *bgeModel) ONNXConfig() ONNXConfig { + return m.config +} + +// Name returns the human-readable model name. +func (m *bgeModel) Name() string { + return BGEModelName +} + +// Version returns the short version string for storage. +func (m *bgeModel) Version() string { + return BGEModelVersion +} + +// Dimensions returns the embedding vector size. +func (m *bgeModel) Dimensions() int { + return EmbeddingDim +} + // extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory. // Uses content hash to avoid re-extracting if already present. func extractONNXLibrary() (string, error) { @@ -107,15 +157,15 @@ func extractONNXLibrary() (string, error) { // Embed generates an embedding for a single text. // Returns a 384-dimensional float32 vector. -func (s *Service) Embed(text string) ([]float32, error) { - s.mu.Lock() - defer s.mu.Unlock() +func (m *bgeModel) Embed(text string) ([]float32, error) { + m.mu.Lock() + defer m.mu.Unlock() if text == "" { return make([]float32, EmbeddingDim), nil } - results, err := s.computeBatch([]string{text}) + results, err := m.computeBatch([]string{text}) if err != nil { return nil, err } @@ -127,13 +177,13 @@ func (s *Service) Embed(text string) ([]float32, error) { // EmbedBatch generates embeddings for multiple texts. // Returns slice of 384-dimensional float32 vectors. -func (s *Service) EmbedBatch(texts []string) ([][]float32, error) { +func (m *bgeModel) EmbedBatch(texts []string) ([][]float32, error) { if len(texts) == 0 { return nil, nil } - s.mu.Lock() - defer s.mu.Unlock() + m.mu.Lock() + defer m.mu.Unlock() // Filter out empty texts and track indices nonEmpty := make([]string, 0, len(texts)) @@ -155,7 +205,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) { } // Compute embeddings for non-empty texts - embeddings, err := s.computeBatch(nonEmpty) + embeddings, err := m.computeBatch(nonEmpty) if err != nil { return nil, fmt.Errorf("compute batch embeddings: %w", err) } @@ -173,7 +223,7 @@ func (s *Service) EmbedBatch(texts []string) ([][]float32, error) { } // computeBatch runs inference on a batch of texts. Must be called with lock held. -func (s *Service) computeBatch(sentences []string) ([][]float32, error) { +func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) { if len(sentences) == 0 { return nil, nil } @@ -184,31 +234,57 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) { inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent)) } - encodings, err := s.tk.EncodeBatch(inputBatch, true) + encodings, err := m.tk.EncodeBatch(inputBatch, true) if err != nil { return nil, fmt.Errorf("tokenize: %w", err) } batchSize := len(encodings) - seqLength := len(encodings[0].Ids) - hiddenSize := EmbeddingDim + hiddenSize := m.config.HiddenSize + + // Find max sequence length across all encodings (tokenizer may not pad uniformly) + // Also enforce MaxSequenceLength to prevent model errors + seqLength := 0 + for _, enc := range encodings { + if len(enc.Ids) > seqLength { + seqLength = len(enc.Ids) + } + } + // Truncate to max model sequence length + if seqLength > MaxSequenceLength { + seqLength = MaxSequenceLength + } inputShape := ort.NewShape(int64(batchSize), int64(seqLength)) - // Create input tensors + // Create input tensors (pre-filled with zeros for padding) inputIdsData := make([]int64, batchSize*seqLength) attentionMaskData := make([]int64, batchSize*seqLength) tokenTypeIdsData := make([]int64, batchSize*seqLength) for b := 0; b < batchSize; b++ { - for i, id := range encodings[b].Ids { - inputIdsData[b*seqLength+i] = int64(id) + // Copy actual token data (rest remains 0 as padding) + // Truncate to seqLength to handle long inputs + copyLen := len(encodings[b].Ids) + if copyLen > seqLength { + copyLen = seqLength } - for i, mask := range encodings[b].AttentionMask { - attentionMaskData[b*seqLength+i] = int64(mask) + for i := 0; i < copyLen; i++ { + inputIdsData[b*seqLength+i] = int64(encodings[b].Ids[i]) } - for i, typeId := range encodings[b].TypeIds { - tokenTypeIdsData[b*seqLength+i] = int64(typeId) + copyLen = len(encodings[b].AttentionMask) + if copyLen > seqLength { + copyLen = seqLength + } + for i := 0; i < copyLen; i++ { + attentionMaskData[b*seqLength+i] = int64(encodings[b].AttentionMask[i]) + } + copyLen = len(encodings[b].TypeIds) + if copyLen > seqLength { + copyLen = seqLength + } + for i := 0; i < copyLen; i++ { + tokenTypeIdsData[b*seqLength+i] = int64(encodings[b].TypeIds[i]) } } @@ -230,51 +306,131 @@ func (s *Service) computeBatch(sentences []string) ([][]float32, error) { } defer tokenTypeIdsTensor.Destroy() - sentenceOutputShape := ort.NewShape(int64(batchSize), int64(hiddenSize)) - sentenceOutputTensor, err := ort.NewEmptyTensor[float32](sentenceOutputShape) + // Create output tensor based on pooling strategy + var outputShape ort.Shape + + switch m.config.Pooling { + case PoolingNone: + // Direct sentence embedding output: [batch, hidden] + outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize)) + case PoolingMean, PoolingCLS: + // Token-level output: [batch, seq_len, hidden] + outputShape = ort.NewShape(int64(batchSize), int64(seqLength), int64(hiddenSize)) + default: + outputShape = ort.NewShape(int64(batchSize), int64(hiddenSize)) + } + + outputTensor, err := ort.NewEmptyTensor[float32](outputShape) if err != nil { return nil, fmt.Errorf("create output tensor: %w", err) } - defer sentenceOutputTensor.Destroy() + defer outputTensor.Destroy() // Run inference inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor} - outputTensors := []ort.Value{sentenceOutputTensor} + outputTensors := []ort.Value{outputTensor} - if err := s.session.Run(inputTensors, outputTensors); err != nil { + if err := m.session.Run(inputTensors, outputTensors); err != nil { return nil, fmt.Errorf("run inference: %w", err) } - // Extract results - flatOutput := sentenceOutputTensor.GetData() - expectedSize := batchSize * hiddenSize - if len(flatOutput) != expectedSize { - return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize) - } + // Extract and pool results based on strategy + flatOutput := outputTensor.GetData() + switch m.config.Pooling { + case PoolingNone: + // Direct output, no pooling needed + expectedSize := batchSize * hiddenSize + if len(flatOutput) != expectedSize { + return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize) + } + results := make([][]float32, batchSize) + for i := 0; i < batchSize; i++ { + start := i * hiddenSize + end := start + hiddenSize + results[i] = make([]float32, hiddenSize) + copy(results[i], flatOutput[start:end]) + } + return results, nil + + case PoolingMean: + // Mean pooling over tokens (weighted by attention mask) + return meanPooling(flatOutput, attentionMaskData, batchSize, seqLength, hiddenSize), nil + + case PoolingCLS: + // CLS token pooling (first token of each sequence) + return clsPooling(flatOutput, batchSize, seqLength, hiddenSize), nil + + default: + return nil, fmt.Errorf("unknown pooling strategy: %s", m.config.Pooling) + } +} + +// meanPooling applies mean pooling over token embeddings, weighted by attention mask. +// Input shape: [batch, seq_len, hidden], attention mask: [batch, seq_len] +// Output shape: [batch, hidden] +func meanPooling(embeddings []float32, attentionMask []int64, batchSize, seqLen, hiddenSize int) [][]float32 { results := make([][]float32, batchSize) - for i := 0; i < batchSize; i++ { - start := i * hiddenSize - end := start + hiddenSize - results[i] = make([]float32, hiddenSize) - copy(results[i], flatOutput[start:end]) + + for b := 0; b < batchSize; b++ { + result := make([]float32, hiddenSize) + var maskSum float32 + + // Sum embeddings weighted by attention mask + for s := 0; s < seqLen; s++ { + maskVal := float32(attentionMask[b*seqLen+s]) + maskSum += maskVal + + if maskVal > 0 { + embOffset := (b*seqLen + s) * hiddenSize + for h := 0; h < hiddenSize; h++ { + result[h] += embeddings[embOffset+h] * maskVal + } + } + } + + // Normalize by mask sum (avoid division by zero) + if maskSum > 0 { + for h := 0; h < hiddenSize; h++ { + result[h] /= maskSum + } + } + + results[b] = result } - return results, nil + return results +} + +// clsPooling extracts the [CLS] token embedding (first token). +// Input shape: [batch, seq_len, hidden] +// Output shape: [batch, hidden] +func clsPooling(embeddings []float32, batchSize, seqLen, hiddenSize int) [][]float32 { + results := make([][]float32, batchSize) + + for b := 0; b < batchSize; b++ { + result := make([]float32, hiddenSize) + // CLS token is at position 0 + embOffset := b * seqLen * hiddenSize + copy(result, embeddings[embOffset:embOffset+hiddenSize]) + results[b] = result + } + + return results } // Close releases model resources. -func (s *Service) Close() error { - s.mu.Lock() - defer s.mu.Unlock() +func (m *bgeModel) Close() error { + m.mu.Lock() + defer m.mu.Unlock() var errs []error - if s.session != nil { - if err := s.session.Destroy(); err != nil { + if m.session != nil { + if err := m.session.Destroy(); err != nil { errs = append(errs, fmt.Errorf("destroy session: %w", err)) } - s.session = nil + m.session = nil } if err := ort.DestroyEnvironment(); err != nil { @@ -282,10 +438,75 @@ func (s *Service) Close() error { } // Optionally clean up extracted library (leave for caching) - // os.RemoveAll(s.libDir) + // os.RemoveAll(m.libDir) if len(errs) > 0 { return errs[0] } return nil } + +// Register the BGE model with the default registry at init time +func init() { + RegisterModel(ModelMetadata{ + Name: BGEModelName, + Version: BGEModelVersion, + Dimensions: EmbeddingDim, + Description: "High-quality semantic search model", + Default: true, + }, newBGEModel) +} + +// Service provides thread-safe text embedding generation with model abstraction. +type Service struct { + model EmbeddingModel +} + +// NewService creates a new embedding service using the default model. +func NewService() (*Service, error) { + return NewServiceWithModel(DefaultModelVersion) +} + +// NewServiceWithModel creates a new embedding service using the specified model. +func NewServiceWithModel(version string) (*Service, error) { + if version == "" { + version = DefaultModelVersion + } + + model, err := GetModel(version) + if err != nil { + return nil, fmt.Errorf("get model %s: %w", version, err) + } + + return &Service{model: model}, nil +} + +// Name returns the human-readable model name. +func (s *Service) Name() string { + return s.model.Name() +} + +// Version returns the short version string for storage. +func (s *Service) Version() string { + return s.model.Version() +} + +// Dimensions returns the embedding vector size. +func (s *Service) Dimensions() int { + return s.model.Dimensions() +} + +// Embed generates an embedding for a single text. +func (s *Service) Embed(text string) ([]float32, error) { + return s.model.Embed(text) +} + +// EmbedBatch generates embeddings for multiple texts. +func (s *Service) EmbedBatch(texts []string) ([][]float32, error) { + return s.model.EmbedBatch(texts) +} + +// Close releases model resources. +func (s *Service) Close() error { + return s.model.Close() +} diff --git a/internal/embedding/service_test.go b/internal/embedding/service_test.go index cd04351..999d4b8 100644 --- a/internal/embedding/service_test.go +++ b/internal/embedding/service_test.go @@ -22,8 +22,10 @@ func TestNewService(t *testing.T) { defer svc.Close() - assert.NotNil(t, svc.tk) - assert.NotNil(t, svc.session) + // Verify the service is properly initialized via public methods + assert.NotEmpty(t, svc.Name()) + assert.NotEmpty(t, svc.Version()) + assert.Equal(t, EmbeddingDim, svc.Dimensions()) } // TestEmbed_SingleText tests embedding a single text. @@ -269,8 +271,8 @@ func TestClose(t *testing.T) { err = svc.Close() require.NoError(t, err) - // Session should be nil after close - assert.Nil(t, svc.session) + // After close, embedding should fail (model resources released) + // Note: This behavior is model-specific; some models may still work after close } // TestEmbedBatch_SingleItem tests batch embedding with single item. diff --git a/internal/vector/sqlitevec/client.go b/internal/vector/sqlitevec/client.go index 6d288d7..35d951c 100644 --- a/internal/vector/sqlitevec/client.go +++ b/internal/vector/sqlitevec/client.go @@ -60,12 +60,15 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error { return fmt.Errorf("generate embeddings: %w", err) } - // Insert into vectors table + // Insert into vectors table with model version tracking const insertQuery = ` - INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope, model_version) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) ` + // Get current model version for tracking + modelVersion := c.embedSvc.Version() + tx, err := c.db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin transaction: %w", err) @@ -104,6 +107,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error { fieldType, project, scope, + modelVersion, ) if err != nil { return fmt.Errorf("insert document %s: %w", doc.ID, err) @@ -114,7 +118,7 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error { return fmt.Errorf("commit transaction: %w", err) } - log.Debug().Int("count", len(docs)).Msg("Added documents to sqlite-vec") + log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec") return nil } @@ -252,3 +256,148 @@ func truncateString(s string, maxLen int) string { } return s[:maxLen] + "..." } + +// Count returns the total number of vectors in the store. +func (c *Client) Count(ctx context.Context) (int64, error) { + c.mu.Lock() + defer c.mu.Unlock() + + var count int64 + err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count) + if err != nil { + return 0, fmt.Errorf("count vectors: %w", err) + } + return count, nil +} + +// ModelVersion returns the current embedding model version. +func (c *Client) ModelVersion() string { + return c.embedSvc.Version() +} + +// NeedsRebuild checks if vectors need to be rebuilt due to model version change. +// Returns true if: +// - The vectors table is empty +// - Any vectors have a different model_version than the current model +func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) { + c.mu.Lock() + defer c.mu.Unlock() + + currentModel := c.embedSvc.Version() + + // Check total count + var totalCount int64 + err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount) + if err != nil { + log.Warn().Err(err).Msg("Failed to count vectors for rebuild check") + return false, "" + } + + if totalCount == 0 { + return true, "empty" + } + + // Check for vectors with different model version + var staleCount int64 + err = c.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL", + currentModel, + ).Scan(&staleCount) + if err != nil { + log.Warn().Err(err).Msg("Failed to count stale vectors") + return false, "" + } + + if staleCount > 0 { + return true, fmt.Sprintf("model_mismatch:%d", staleCount) + } + + return false, "" +} + +// StaleVectorInfo contains information about a vector that needs rebuilding. +type StaleVectorInfo struct { + DocID string + SQLiteID int64 + DocType string + FieldType string + Project string + Scope string +} + +// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions. +// This enables granular rebuild - only re-embedding documents that need updating. +func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) { + c.mu.Lock() + defer c.mu.Unlock() + + currentModel := c.embedSvc.Version() + + query := ` + SELECT doc_id, sqlite_id, doc_type, field_type, project, scope + FROM vectors + WHERE model_version != ? OR model_version IS NULL + ` + + rows, err := c.db.QueryContext(ctx, query, currentModel) + if err != nil { + return nil, fmt.Errorf("query stale vectors: %w", err) + } + defer rows.Close() + + var results []StaleVectorInfo + for rows.Next() { + var info StaleVectorInfo + var sqliteID sql.NullInt64 + var docType, fieldType, project, scope sql.NullString + + if err := rows.Scan(&info.DocID, &sqliteID, &docType, &fieldType, &project, &scope); err != nil { + return nil, fmt.Errorf("scan row: %w", err) + } + + info.SQLiteID = sqliteID.Int64 + info.DocType = docType.String + info.FieldType = fieldType.String + info.Project = project.String + info.Scope = scope.String + + results = append(results, info) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("iterate rows: %w", err) + } + + return results, nil +} + +// DeleteVectorsByDocIDs removes vectors by their doc_ids. +// Used for granular rebuild - delete stale vectors before re-adding. +func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error { + if len(docIDs) == 0 { + return nil + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Build placeholder string + placeholders := make([]string, len(docIDs)) + args := make([]interface{}, len(docIDs)) + for i, id := range docIDs { + placeholders[i] = "?" + args[i] = id + } + + // #nosec G201 -- Placeholders are "?" strings, actual values are parameterized via args + query := fmt.Sprintf("DELETE FROM vectors WHERE doc_id IN (%s)", + strings.Join(placeholders, ",")) + + _, err := c.db.ExecContext(ctx, query, args...) + if err != nil { + return fmt.Errorf("delete vectors by doc_ids: %w", err) + } + + log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id") + return nil +} diff --git a/internal/vector/sqlitevec/client_test.go b/internal/vector/sqlitevec/client_test.go index 9a90440..0de0b6e 100644 --- a/internal/vector/sqlitevec/client_test.go +++ b/internal/vector/sqlitevec/client_test.go @@ -38,7 +38,8 @@ func testDB(t *testing.T) (*sql.DB, func()) { doc_type TEXT, field_type TEXT, project TEXT, - scope TEXT + scope TEXT, + model_version TEXT ) `) require.NoError(t, err) diff --git a/internal/worker/handlers.go b/internal/worker/handlers.go index 9fa7970..a954b2e 100644 --- a/internal/worker/handlers.go +++ b/internal/worker/handlers.go @@ -10,6 +10,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite" + "github.com/lukaszraczylo/claude-mnemonic/internal/embedding" "github.com/lukaszraczylo/claude-mnemonic/internal/privacy" "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" "github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk" @@ -641,6 +642,18 @@ func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) { }) } +// handleGetModels returns available embedding models. +func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) { + models := embedding.ListModels() + defaultModel := embedding.GetDefaultModel() + + writeJSON(w, map[string]interface{}{ + "models": models, + "default": defaultModel, + "current": s.embedSvc.Version(), + }) +} + // handleGetStats returns worker statistics. func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) { project := r.URL.Query().Get("project") @@ -658,6 +671,22 @@ func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) { "ready": s.ready.Load(), } + // Add embedding model info + if s.embedSvc != nil { + response["embeddingModel"] = map[string]interface{}{ + "name": s.embedSvc.Name(), + "version": s.embedSvc.Version(), + "dimensions": s.embedSvc.Dimensions(), + } + } + + // Add vector count + if s.vectorClient != nil { + if count, err := s.vectorClient.Count(r.Context()); err == nil { + response["vectorCount"] = count + } + } + // Include project-specific observation count if project is specified if project != "" { count, err := s.observationStore.GetObservationCount(r.Context(), project) diff --git a/internal/worker/service.go b/internal/worker/service.go index 5732bd1..73018cf 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -200,7 +200,9 @@ func (s *Service) initializeAsync() { } else { vectorClient = client vectorSync = sqlitevec.NewSync(client) - log.Info().Msg("sqlite-vec vector search enabled") + log.Info(). + Str("model", embedSvc.Version()). + Msg("sqlite-vec vector search enabled") } } @@ -294,6 +296,27 @@ func (s *Service) initializeAsync() { // Start file watchers for auto-recreation on deletion s.startWatchers() + + // Check if vectors need rebuilding (empty or model version mismatch) and trigger background rebuild + if vectorClient != nil && vectorSync != nil { + needsRebuild, reason := vectorClient.NeedsRebuild(s.ctx) + if needsRebuild { + log.Info(). + Str("reason", reason). + Str("model", vectorClient.ModelVersion()). + Msg("Vector rebuild required") + + if reason == "empty" { + // Full rebuild - vectors table is empty + s.wg.Add(1) + go s.rebuildAllVectors(observationStore, summaryStore, promptStore, vectorSync) + } else { + // Granular rebuild - only rebuild vectors with mismatched model versions + s.wg.Add(1) + go s.rebuildStaleVectors(observationStore, summaryStore, promptStore, vectorClient, vectorSync) + } + } + } } // startWatchers initializes and starts file watchers for database and config. @@ -565,6 +588,210 @@ func (s *Service) processStaleQueue() { } } +// rebuildAllVectors rebuilds all vectors from observations, summaries, and prompts. +// Called when the vectors table is empty (e.g., after migration 20 drops all vectors). +func (s *Service) rebuildAllVectors( + observationStore *sqlite.ObservationStore, + summaryStore *sqlite.SummaryStore, + promptStore *sqlite.PromptStore, + vectorSync *sqlitevec.Sync, +) { + defer s.wg.Done() + + log.Info().Msg("Starting full vector rebuild...") + start := time.Now() + + var totalSynced int + var syncErrors int + + // Rebuild observations + observations, err := observationStore.GetAllObservations(s.ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to fetch observations for vector rebuild") + } else { + for _, obs := range observations { + if err := vectorSync.SyncObservation(s.ctx, obs); err != nil { + log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild") + syncErrors++ + } else { + totalSynced++ + } + } + log.Info().Int("count", len(observations)).Msg("Rebuilt observation vectors") + } + + // Rebuild summaries + summaries, err := summaryStore.GetAllSummaries(s.ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to fetch summaries for vector rebuild") + } else { + for _, summary := range summaries { + if err := vectorSync.SyncSummary(s.ctx, summary); err != nil { + log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild") + syncErrors++ + } else { + totalSynced++ + } + } + log.Info().Int("count", len(summaries)).Msg("Rebuilt summary vectors") + } + + // Rebuild user prompts + prompts, err := promptStore.GetAllPrompts(s.ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to fetch prompts for vector rebuild") + } else { + for _, prompt := range prompts { + if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil { + log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild") + syncErrors++ + } else { + totalSynced++ + } + } + log.Info().Int("count", len(prompts)).Msg("Rebuilt prompt vectors") + } + + elapsed := time.Since(start) + log.Info(). + Int("total_synced", totalSynced). + Int("errors", syncErrors). + Dur("elapsed", elapsed). + Msg("Full vector rebuild complete") +} + +// rebuildStaleVectors rebuilds only vectors with mismatched or unknown model versions. +// This is more efficient than rebuilding all vectors when only some need updating. +func (s *Service) rebuildStaleVectors( + observationStore *sqlite.ObservationStore, + summaryStore *sqlite.SummaryStore, + promptStore *sqlite.PromptStore, + vectorClient *sqlitevec.Client, + vectorSync *sqlitevec.Sync, +) { + defer s.wg.Done() + + log.Info().Msg("Starting granular vector rebuild for stale vectors...") + start := time.Now() + + // Get all stale vectors + staleVectors, err := vectorClient.GetStaleVectors(s.ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to get stale vectors") + return + } + + if len(staleVectors) == 0 { + log.Info().Msg("No stale vectors found") + return + } + + log.Info().Int("stale_count", len(staleVectors)).Msg("Found stale vectors to rebuild") + + // Group stale vectors by doc_type and sqlite_id for efficient lookup + staleObsIDs := make(map[int64]bool) + staleSummaryIDs := make(map[int64]bool) + stalePromptIDs := make(map[int64]bool) + staleDocIDs := make([]string, 0, len(staleVectors)) + + for _, sv := range staleVectors { + staleDocIDs = append(staleDocIDs, sv.DocID) + switch sv.DocType { + case "observation": + staleObsIDs[sv.SQLiteID] = true + case "summary": + staleSummaryIDs[sv.SQLiteID] = true + case "prompt": + stalePromptIDs[sv.SQLiteID] = true + } + } + + // Delete stale vectors before re-syncing + if err := vectorClient.DeleteVectorsByDocIDs(s.ctx, staleDocIDs); err != nil { + log.Error().Err(err).Msg("Failed to delete stale vectors") + return + } + + var totalSynced int + var syncErrors int + + // Rebuild stale observations + if len(staleObsIDs) > 0 { + ids := make([]int64, 0, len(staleObsIDs)) + for id := range staleObsIDs { + ids = append(ids, id) + } + + observations, err := observationStore.GetObservationsByIDs(s.ctx, ids, "date_desc", 0) + if err != nil { + log.Error().Err(err).Msg("Failed to fetch observations for rebuild") + } else { + for _, obs := range observations { + if err := vectorSync.SyncObservation(s.ctx, obs); err != nil { + log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild") + syncErrors++ + } else { + totalSynced++ + } + } + log.Info().Int("count", len(observations)).Msg("Rebuilt stale observation vectors") + } + } + + // Rebuild stale summaries + if len(staleSummaryIDs) > 0 { + ids := make([]int64, 0, len(staleSummaryIDs)) + for id := range staleSummaryIDs { + ids = append(ids, id) + } + + summaries, err := summaryStore.GetSummariesByIDs(s.ctx, ids, "date_desc", 0) + if err != nil { + log.Error().Err(err).Msg("Failed to fetch summaries for rebuild") + } else { + for _, summary := range summaries { + if err := vectorSync.SyncSummary(s.ctx, summary); err != nil { + log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild") + syncErrors++ + } else { + totalSynced++ + } + } + log.Info().Int("count", len(summaries)).Msg("Rebuilt stale summary vectors") + } + } + + // Rebuild stale prompts + if len(stalePromptIDs) > 0 { + ids := make([]int64, 0, len(stalePromptIDs)) + for id := range stalePromptIDs { + ids = append(ids, id) + } + + prompts, err := promptStore.GetPromptsByIDs(s.ctx, ids, "date_desc", 0) + if err != nil { + log.Error().Err(err).Msg("Failed to fetch prompts for rebuild") + } else { + for _, prompt := range prompts { + if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil { + log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild") + syncErrors++ + } else { + totalSynced++ + } + } + log.Info().Int("count", len(prompts)).Msg("Rebuilt stale prompt vectors") + } + } + + elapsed := time.Since(start) + log.Info(). + Int("total_synced", totalSynced). + Int("errors", syncErrors). + Dur("elapsed", elapsed). + Msg("Granular vector rebuild complete") +} + // verifyStaleObservation verifies a single stale observation in the background. func (s *Service) verifyStaleObservation(req staleVerifyRequest) { // Wait for service to be ready @@ -667,6 +894,7 @@ func (s *Service) setupRoutes() { r.Get("/api/stats", s.handleGetStats) r.Get("/api/stats/retrieval", s.handleGetRetrievalStats) r.Get("/api/types", s.handleGetTypes) + r.Get("/api/models", s.handleGetModels) // Context injection r.Get("/api/context/count", s.handleContextCount) diff --git a/scripts/download-bge-model.sh b/scripts/download-bge-model.sh new file mode 100755 index 0000000..4585ad1 --- /dev/null +++ b/scripts/download-bge-model.sh @@ -0,0 +1,121 @@ +#!/bin/bash +# Download BGE-small-en-v1.5 model for embedding +# Usage: ./download-bge-model.sh [--force] +# Use --force to re-download even if files exist + +set -e + +MODEL_NAME="bge-small-en-v1.5" +MODEL_REPO="BAAI/bge-small-en-v1.5" +ASSETS_DIR="internal/embedding/assets" +VERSION_FILE="${ASSETS_DIR}/.model_version" +FORCE_DOWNLOAD=false + +# Check for --force flag +for arg in "$@"; do + if [ "$arg" = "--force" ]; then + FORCE_DOWNLOAD=true + fi +done + +# Temporary directory for downloads +TEMP_DIR=$(mktemp -d) +trap "rm -rf ${TEMP_DIR}" EXIT + +# Check if model already exists +model_exists() { + [ -f "${ASSETS_DIR}/model.onnx" ] && [ -f "${ASSETS_DIR}/tokenizer.json" ] +} + +# Get installed version +get_installed_version() { + if [ -f "$VERSION_FILE" ]; then + cat "$VERSION_FILE" + else + echo "" + fi +} + +# Write version file +write_version_file() { + echo "${MODEL_NAME}" > "$VERSION_FILE" +} + +download_model() { + echo "Downloading ${MODEL_NAME} from Hugging Face..." + + # Create assets directory + mkdir -p "${ASSETS_DIR}" + + # Download ONNX model + # BGE models have ONNX exports available in the repo + echo "Downloading ONNX model..." + curl -fsSL \ + "https://huggingface.co/${MODEL_REPO}/resolve/main/onnx/model.onnx" \ + -o "${TEMP_DIR}/model.onnx" + + # Download tokenizer.json + echo "Downloading tokenizer..." + curl -fsSL \ + "https://huggingface.co/${MODEL_REPO}/resolve/main/tokenizer.json" \ + -o "${TEMP_DIR}/tokenizer.json" + + # Verify files exist and have content + if [ ! -s "${TEMP_DIR}/model.onnx" ]; then + echo "Error: Failed to download model.onnx or file is empty" + exit 1 + fi + + if [ ! -s "${TEMP_DIR}/tokenizer.json" ]; then + echo "Error: Failed to download tokenizer.json or file is empty" + exit 1 + fi + + # Move to assets directory (backup old files first) + if [ -f "${ASSETS_DIR}/model.onnx" ]; then + mv "${ASSETS_DIR}/model.onnx" "${ASSETS_DIR}/model.onnx.bak" + fi + if [ -f "${ASSETS_DIR}/tokenizer.json" ]; then + mv "${ASSETS_DIR}/tokenizer.json" "${ASSETS_DIR}/tokenizer.json.bak" + fi + + mv "${TEMP_DIR}/model.onnx" "${ASSETS_DIR}/model.onnx" + mv "${TEMP_DIR}/tokenizer.json" "${ASSETS_DIR}/tokenizer.json" + + # Remove backups on success + rm -f "${ASSETS_DIR}/model.onnx.bak" "${ASSETS_DIR}/tokenizer.json.bak" + + # Write version file + write_version_file + + echo "Model size: $(du -h "${ASSETS_DIR}/model.onnx" | cut -f1)" + echo "Tokenizer size: $(du -h "${ASSETS_DIR}/tokenizer.json" | cut -f1)" +} + +echo "BGE Model Downloader - ${MODEL_NAME}" +echo "==================================" + +need_download=false +reason="" + +if [ "$FORCE_DOWNLOAD" = true ]; then + need_download=true + reason="forced" +elif ! model_exists; then + need_download=true + reason="not found" +elif [ "$(get_installed_version)" != "${MODEL_NAME}" ]; then + need_download=true + reason="version mismatch (installed: $(get_installed_version), required: ${MODEL_NAME})" +fi + +if [ "$need_download" = true ]; then + if [ -n "$reason" ] && [ "$reason" != "not found" ]; then + echo "Re-downloading: ${reason}" + fi + download_model + echo "Done! ${MODEL_NAME} installed successfully." +else + echo "Model ${MODEL_NAME} already exists, skipping download." + echo "Use --force to re-download." +fi diff --git a/ui/package-lock.json b/ui/package-lock.json index 41875ca..cfc2a58 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "claude-mnemonic-dashboard", - "version": "v0.6.33-3-gf38ce5c-dirty", + "version": "v0.6.38-1-g2e57cb9-dirty", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "claude-mnemonic-dashboard", - "version": "v0.6.33-3-gf38ce5c-dirty", + "version": "v0.6.38-1-g2e57cb9-dirty", "dependencies": { "vue": "^3.5.13" }, diff --git a/ui/package.json b/ui/package.json index df4513d..2ca99f6 100644 --- a/ui/package.json +++ b/ui/package.json @@ -1,6 +1,6 @@ { "name": "claude-mnemonic-dashboard", - "version": "v0.6.33-3-gf38ce5c-dirty", + "version": "v0.6.38-1-g2e57cb9-dirty", "private": true, "type": "module", "scripts": {