Move from chroma to sqlitevec with local embedding

This commit is contained in:
2025-12-16 11:28:26 +00:00
parent 6c28ecb22a
commit 7fe679f83b
25 changed files with 31649 additions and 1161 deletions
-521
View File
@@ -1,521 +0,0 @@
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
package chroma
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"sync"
"github.com/rs/zerolog/log"
)
// Document represents a document to store in ChromaDB.
// AddDocuments splits each Document into the parallel "ids", "documents"
// and "metadatas" arrays of the chroma_add_documents tool call.
type Document struct {
// ID is the document identifier (JSON key "id").
ID string `json:"id"`
// Content is the document text (JSON key "document").
Content string `json:"document"`
// Metadata is free-form key/value data sent alongside the document.
Metadata map[string]any `json:"metadata"`
}
// QueryResult represents a search result from ChromaDB.
// Values are filled in by parseQueryResults from the first result set of a
// query response.
type QueryResult struct {
// ID is the identifier of the matched document.
ID string
// Distance is the distance reported for the match; it stays 0 when the
// response carries no distance for this position.
Distance float64
// Metadata is the metadata returned for the match; it stays nil when the
// response carries no metadata for this position.
Metadata map[string]any
}
// Client is a ChromaDB client that communicates via MCP protocol.
// It exchanges newline-delimited JSON-RPC messages with a chroma-mcp child
// process over that process's stdin/stdout pipes.
type Client struct {
// collection is the collection name, built by NewClient as "cm__<project>".
collection string
// dataDir is passed to chroma-mcp as --data-dir and created by Connect.
dataDir string
// pythonVer is passed to uvx as --python.
pythonVer string
// batchSize is the number of documents sent per chroma_add_documents call.
batchSize int
// cmd is the running chroma-mcp child process (set by Connect).
cmd *exec.Cmd
// stdin is the write end of the child's stdin; requests are written here.
stdin io.WriteCloser
// stdout buffers the child's stdout; responses are read line by line.
stdout *bufio.Reader
// mu is locked by the exported methods that use the pipes or connected.
mu sync.Mutex
// connected is set by Connect after the initialize handshake and cleared by Close.
connected bool
// requestID is the last JSON-RPC request id handed out by nextID.
requestID int
}
// Config holds configuration for the ChromaDB client.
// Zero values for DataDir, PythonVer and BatchSize are replaced with defaults
// by NewClient.
type Config struct {
// Project is used to derive the collection name ("cm__<Project>").
Project string
// DataDir is the vector database directory; when empty NewClient uses
// "<home>/.claude-mnemonic/vector-db".
DataDir string
// PythonVer is the Python version passed to uvx; defaults to "3.13".
PythonVer string
// BatchSize is the AddDocuments batch size; values <= 0 default to 100.
BatchSize int
}
// NewClient creates a new ChromaDB client from cfg without starting the
// chroma-mcp process; call Connect to do that.
//
// Defaults applied to zero-valued fields:
//   - DataDir:   "<home>/.claude-mnemonic/vector-db"
//   - PythonVer: "3.13"
//   - BatchSize: 100 (also used for negative values)
//
// It returns an error when DataDir is empty and the user's home directory
// cannot be determined, instead of silently falling back to a path relative
// to the current working directory.
func NewClient(cfg Config) (*Client, error) {
	if cfg.DataDir == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			return nil, fmt.Errorf("resolve home dir: %w", err)
		}
		cfg.DataDir = filepath.Join(home, ".claude-mnemonic", "vector-db")
	}

	if cfg.PythonVer == "" {
		cfg.PythonVer = "3.13"
	}

	if cfg.BatchSize <= 0 {
		cfg.BatchSize = 100
	}

	return &Client{
		collection: fmt.Sprintf("cm__%s", cfg.Project),
		dataDir:    cfg.DataDir,
		pythonVer:  cfg.PythonVer,
		batchSize:  cfg.BatchSize,
	}, nil
}
// Connect starts the ChromaDB MCP server and establishes connection.
//
// It creates the data directory, launches chroma-mcp through uvx with its
// stdin/stdout wired to the client, and performs the MCP initialize
// handshake. Calling Connect on an already connected client is a no-op.
//
// The child process is created with exec.CommandContext, so cancelling ctx
// kills it.
//
// On any failure after pipes have been created the pipes and the child
// process are released before the error is returned.
func (c *Client) Connect(ctx context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		return nil
	}

	// Ensure data directory exists
	if err := os.MkdirAll(c.dataDir, 0750); err != nil {
		return fmt.Errorf("create data dir: %w", err)
	}

	// Start chroma-mcp server via uvx
	cmd := exec.CommandContext(ctx, "uvx", // #nosec G204 -- config values from internal settings
		"--python", c.pythonVer,
		"chroma-mcp",
		"--client-type", "persistent",
		"--data-dir", c.dataDir,
	)

	stdin, err := cmd.StdinPipe()
	if err != nil {
		return fmt.Errorf("stdin pipe: %w", err)
	}

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		_ = stdin.Close() // do not leak the already created pipe
		return fmt.Errorf("stdout pipe: %w", err)
	}
	cmd.Stderr = os.Stderr

	if err := cmd.Start(); err != nil {
		_ = stdin.Close()
		return fmt.Errorf("start chroma-mcp: %w", err)
	}

	c.cmd = cmd
	c.stdin = stdin
	c.stdout = bufio.NewReader(stdout)

	// Send initialize request
	if err := c.sendInitialize(); err != nil {
		// Tear the child down inline. Close() must not be called here: it
		// re-acquires c.mu (already held by this method, and sync.Mutex is
		// not reentrant) and it returns early while c.connected is false.
		_ = stdin.Close()
		_ = cmd.Process.Kill()
		_ = cmd.Wait()
		c.cmd, c.stdin, c.stdout = nil, nil, nil
		return fmt.Errorf("initialize: %w", err)
	}

	c.connected = true
	log.Info().
		Str("collection", c.collection).
		Str("dataDir", c.dataDir).
		Msg("Connected to ChromaDB")

	return nil
}
// sendInitialize performs the MCP handshake: it writes one "initialize"
// request and reads one response line, returning the first error hit.
// The response payload itself is discarded.
func (c *Client) sendInitialize() error {
	clientInfo := map[string]any{
		"name":    "claude-mnemonic",
		"version": "1.0.0",
	}
	params := map[string]any{
		"protocolVersion": "2024-11-05",
		"capabilities":    map[string]any{},
		"clientInfo":      clientInfo,
	}
	handshake := map[string]any{
		"jsonrpc": "2.0",
		"id":      c.nextID(),
		"method":  "initialize",
		"params":  params,
	}

	if sendErr := c.send(handshake); sendErr != nil {
		return sendErr
	}

	// Only success or failure of the reply matters here.
	if _, readErr := c.readResponse(); readErr != nil {
		return readErr
	}
	return nil
}
// EnsureCollection ensures the collection exists, creating it if needed.
//
// It asks the server for the collection's info; if that request's response
// cannot be read or is a JSON-RPC error, the collection is created via
// createCollection. A failure to send the info request is returned as is.
//
// Like the other request methods it returns a "not connected" error when
// Connect has not succeeded, rather than writing to a nil pipe.
func (c *Client) EnsureCollection(ctx context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.connected {
		return fmt.Errorf("not connected")
	}

	// Try to get collection info
	req := map[string]any{
		"jsonrpc": "2.0",
		"id":      c.nextID(),
		"method":  "tools/call",
		"params": map[string]any{
			"name": "chroma_get_collection_info",
			"arguments": map[string]any{
				"collection_name": c.collection,
			},
		},
	}

	if err := c.send(req); err != nil {
		return err
	}

	// readResponse already converts a response carrying an "error" member
	// into a Go error, so a single check covers both the read failure and
	// the "collection not found" case.
	if _, err := c.readResponse(); err != nil {
		return c.createCollection()
	}

	return nil
}
// createCollection issues a chroma_create_collection tool call for the
// client's collection using the "default" embedding function, and logs on
// success. A send error is returned unwrapped; a response error is wrapped
// with "create collection".
func (c *Client) createCollection() error {
	arguments := map[string]any{
		"collection_name":         c.collection,
		"embedding_function_name": "default",
	}
	call := map[string]any{
		"jsonrpc": "2.0",
		"id":      c.nextID(),
		"method":  "tools/call",
		"params": map[string]any{
			"name":      "chroma_create_collection",
			"arguments": arguments,
		},
	}

	if sendErr := c.send(call); sendErr != nil {
		return sendErr
	}

	if _, readErr := c.readResponse(); readErr != nil {
		return fmt.Errorf("create collection: %w", readErr)
	}

	log.Info().
		Str("collection", c.collection).
		Msg("Created ChromaDB collection")

	return nil
}
// AddDocuments uploads docs to the collection, batchSize documents per
// chroma_add_documents call. It stops at the first batch that fails to send
// or whose response is an error; earlier batches are not undone.
// Returns "not connected" when Connect has not succeeded.
func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.connected {
		return fmt.Errorf("not connected")
	}

	total := len(docs)
	for start := 0; start < total; start += c.batchSize {
		stop := start + c.batchSize
		if stop > total {
			stop = total
		}
		chunk := docs[start:stop]

		// Split the chunk into the three parallel arrays the tool expects.
		texts := make([]string, 0, len(chunk))
		ids := make([]string, 0, len(chunk))
		metas := make([]map[string]any, 0, len(chunk))
		for _, d := range chunk {
			texts = append(texts, d.Content)
			ids = append(ids, d.ID)
			metas = append(metas, d.Metadata)
		}

		arguments := map[string]any{
			"collection_name": c.collection,
			"documents":       texts,
			"ids":             ids,
			"metadatas":       metas,
		}
		call := map[string]any{
			"jsonrpc": "2.0",
			"id":      c.nextID(),
			"method":  "tools/call",
			"params": map[string]any{
				"name":      "chroma_add_documents",
				"arguments": arguments,
			},
		}

		if sendErr := c.send(call); sendErr != nil {
			return fmt.Errorf("send add_documents: %w", sendErr)
		}
		if _, readErr := c.readResponse(); readErr != nil {
			return fmt.Errorf("add_documents response: %w", readErr)
		}

		log.Debug().
			Int("batchStart", start).
			Int("batchEnd", stop).
			Int("total", total).
			Msg("Added document batch")
	}

	return nil
}
// DeleteDocuments removes the documents with the given ids from the
// collection in a single chroma_delete_documents call. An empty id list is a
// no-op, but only once the connection check has passed.
func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	switch {
	case !c.connected:
		return fmt.Errorf("not connected")
	case len(ids) == 0:
		return nil
	}

	call := map[string]any{
		"jsonrpc": "2.0",
		"id":      c.nextID(),
		"method":  "tools/call",
		"params": map[string]any{
			"name": "chroma_delete_documents",
			"arguments": map[string]any{
				"collection_name": c.collection,
				"ids":             ids,
			},
		},
	}

	if sendErr := c.send(call); sendErr != nil {
		return fmt.Errorf("send delete_documents: %w", sendErr)
	}
	if _, readErr := c.readResponse(); readErr != nil {
		return fmt.Errorf("delete_documents response: %w", readErr)
	}

	log.Debug().
		Int("count", len(ids)).
		Msg("Deleted documents from ChromaDB")

	return nil
}
// Query runs a chroma_query_documents call for a single query text and
// returns the parsed matches. limit is sent as n_results; where, when
// non-nil, is forwarded unchanged as the "where" filter.
// Returns "not connected" when Connect has not succeeded.
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.connected {
		return nil, fmt.Errorf("not connected")
	}

	arguments := map[string]any{
		"collection_name": c.collection,
		"query_texts":     []string{query},
		"n_results":       limit,
		"include":         []string{"documents", "metadatas", "distances"},
	}
	if where != nil {
		arguments["where"] = where
	}

	call := map[string]any{
		"jsonrpc": "2.0",
		"id":      c.nextID(),
		"method":  "tools/call",
		"params": map[string]any{
			"name":      "chroma_query_documents",
			"arguments": arguments,
		},
	}

	if sendErr := c.send(call); sendErr != nil {
		return nil, fmt.Errorf("send query: %w", sendErr)
	}

	reply, readErr := c.readResponse()
	if readErr != nil {
		return nil, fmt.Errorf("query response: %w", readErr)
	}

	return c.parseQueryResults(reply)
}
// parseQueryResults parses the query response into QueryResult structs.
//
// It expects resp["result"]["content"][0]["text"] to hold a JSON document
// with "ids", "distances" and "metadatas" arrays, each nested one level per
// query text; only the first query's entries are used. If any step of that
// path is missing or has an unexpected type, it returns (nil, nil).
//
// "distances" and "metadatas" are optional: when they are absent or shorter
// than "ids", the corresponding QueryResult fields keep their zero values
// instead of causing an index-out-of-range panic.
func (c *Client) parseQueryResults(resp map[string]any) ([]QueryResult, error) {
	result, ok := resp["result"].(map[string]any)
	if !ok {
		return nil, nil
	}

	content, ok := result["content"].([]any)
	if !ok || len(content) == 0 {
		return nil, nil
	}

	first, ok := content[0].(map[string]any)
	if !ok {
		return nil, nil
	}

	text, ok := first["text"].(string)
	if !ok {
		return nil, nil
	}

	var parsed struct {
		IDs       [][]string         `json:"ids"`
		Distances [][]float64        `json:"distances"`
		Metadatas [][]map[string]any `json:"metadatas"`
	}
	if err := json.Unmarshal([]byte(text), &parsed); err != nil {
		return nil, fmt.Errorf("parse query results: %w", err)
	}

	if len(parsed.IDs) == 0 || len(parsed.IDs[0]) == 0 {
		return nil, nil
	}

	// Guard the outer slices before taking element 0; a nil inner slice
	// simply makes the per-index checks below fail.
	var distances []float64
	if len(parsed.Distances) > 0 {
		distances = parsed.Distances[0]
	}
	var metadatas []map[string]any
	if len(parsed.Metadatas) > 0 {
		metadatas = parsed.Metadatas[0]
	}

	results := make([]QueryResult, len(parsed.IDs[0]))
	for i, id := range parsed.IDs[0] {
		results[i] = QueryResult{ID: id}
		if i < len(distances) {
			results[i].Distance = distances[i]
		}
		if i < len(metadatas) {
			results[i].Metadata = metadatas[i]
		}
	}

	return results, nil
}
// send marshals req to JSON, terminates it with a newline and writes it to
// the server's stdin in one Write call. Marshal and write errors are
// returned unwrapped.
func (c *Client) send(req map[string]any) error {
	payload, err := json.Marshal(req)
	if err != nil {
		return err
	}

	if _, err := c.stdin.Write(append(payload, '\n')); err != nil {
		return err
	}
	return nil
}
// readResponse reads one newline-terminated line from the server's stdout
// and decodes it as a JSON object. If the object has an "error" member the
// response is dropped and an "MCP error" is returned instead.
func (c *Client) readResponse() (map[string]any, error) {
	raw, readErr := c.stdout.ReadString('\n')
	if readErr != nil {
		return nil, readErr
	}

	var decoded map[string]any
	if decodeErr := json.Unmarshal([]byte(raw), &decoded); decodeErr != nil {
		return nil, decodeErr
	}

	if rpcErr, failed := decoded["error"]; failed {
		return nil, fmt.Errorf("MCP error: %v", rpcErr)
	}

	return decoded, nil
}
// nextID advances the request counter by one and returns the new value, so
// the first ID handed out is 1. It does no locking of its own.
func (c *Client) nextID() int {
	id := c.requestID + 1
	c.requestID = id
	return id
}
// IsConnected reports the connected flag, read under the client mutex.
func (c *Client) IsConnected() bool {
	c.mu.Lock()
	state := c.connected
	c.mu.Unlock()
	return state
}
// Close shuts the connection down: it clears the connected flag, closes the
// server's stdin, kills the child process and waits for it. Errors from
// those steps are ignored and Close always returns nil. On a client that is
// not connected it does nothing.
func (c *Client) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		c.connected = false

		if in := c.stdin; in != nil {
			_ = in.Close()
		}

		if child := c.cmd; child != nil && child.Process != nil {
			_ = child.Process.Kill()
			_ = child.Wait()
		}

		log.Info().
			Str("collection", c.collection).
			Msg("ChromaDB connection closed")
	}

	return nil
}
// Reconnect tears down the current connection, starts a fresh one and makes
// sure the collection exists again. It is intended for the case where the
// vector database directory was deleted and recreated.
//
// A Close error is only logged. Between Close and Connect the context is
// checked once, without blocking, and its error is returned if it is already
// done.
func (c *Client) Reconnect(ctx context.Context) error {
	log.Info().
		Str("collection", c.collection).
		Msg("Reconnecting to ChromaDB...")

	// Drop the old connection; failure here is not fatal.
	if closeErr := c.Close(); closeErr != nil {
		log.Warn().Err(closeErr).Msg("Error closing ChromaDB during reconnect")
	}

	// Non-blocking cancellation check before starting a new process.
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}

	if connErr := c.Connect(ctx); connErr != nil {
		return fmt.Errorf("reconnect failed: %w", connErr)
	}

	if ensureErr := c.EnsureCollection(ctx); ensureErr != nil {
		return fmt.Errorf("ensure collection after reconnect: %w", ensureErr)
	}

	log.Info().
		Str("collection", c.collection).
		Msg("ChromaDB reconnected successfully")

	return nil
}
-350
View File
@@ -1,350 +0,0 @@
package chroma
import (
"database/sql"
"fmt"
"testing"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
)
// testSync returns a Sync whose client is left nil; it is only suitable for
// exercising the pure formatting helpers.
func testSync() *Sync {
	var s Sync // zero value: client == nil
	return &s
}
// TestSync_FormatObservationDocs checks that an observation with a narrative
// and three facts is split into four documents with the expected IDs,
// contents and metadata.
func TestSync_FormatObservationDocs(t *testing.T) {
	s := testSync()

	observation := &models.Observation{
		ID:             1,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		Scope:          models.ScopeProject,
		Type:           models.ObsTypeDiscovery,
		Title:          sql.NullString{String: "Test Title", Valid: true},
		Subtitle:       sql.NullString{String: "Test Subtitle", Valid: true},
		Narrative:      sql.NullString{String: "Test narrative content", Valid: true},
		Facts:          models.JSONStringArray{"Fact 1", "Fact 2", "Fact 3"},
		Concepts:       models.JSONStringArray{"concept1", "concept2"},
		FilesRead:      models.JSONStringArray{"file1.go", "file2.go"},
		FilesModified:  models.JSONStringArray{"file3.go"},
		CreatedAtEpoch: 1234567890,
	}

	docs := s.formatObservationDocs(observation)

	// 1 narrative + 3 facts
	assert.Len(t, docs, 4)

	// The narrative comes first.
	narrative := docs[0]
	assert.Equal(t, "obs_1_narrative", narrative.ID)
	assert.Equal(t, "Test narrative content", narrative.Content)

	wantMeta := []struct {
		key  string
		want any
	}{
		{"sqlite_id", int64(1)},
		{"doc_type", "observation"},
		{"field_type", "narrative"},
		{"project", "test-project"},
		{"scope", "project"},
		{"title", "Test Title"},
		{"subtitle", "Test Subtitle"},
	}
	for _, m := range wantMeta {
		assert.Equal(t, m.want, narrative.Metadata[m.key])
	}

	// Facts follow in order, indexed from zero.
	for factIdx := 0; factIdx < 3; factIdx++ {
		fact := docs[factIdx+1]
		assert.Equal(t, fmt.Sprintf("obs_1_fact_%d", factIdx), fact.ID)
		assert.Equal(t, fmt.Sprintf("Fact %d", factIdx+1), fact.Content)
		assert.Equal(t, "fact", fact.Metadata["field_type"])
		assert.Equal(t, factIdx, fact.Metadata["fact_index"])
	}
}
// TestSync_FormatObservationDocs_NoNarrative checks that an observation with
// no narrative yields only its fact documents.
func TestSync_FormatObservationDocs_NoNarrative(t *testing.T) {
	s := testSync()

	docs := s.formatObservationDocs(&models.Observation{
		ID:             2,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		Scope:          models.ScopeGlobal,
		Type:           models.ObsTypeBugfix,
		Facts:          models.JSONStringArray{"Only fact"},
		CreatedAtEpoch: 1234567890,
	})

	// A single fact and nothing else.
	assert.Len(t, docs, 1)

	only := docs[0]
	assert.Equal(t, "obs_2_fact_0", only.ID)
	assert.Equal(t, "Only fact", only.Content)
	assert.Equal(t, "global", only.Metadata["scope"])
}
// TestSync_FormatObservationDocs_Empty checks that an observation without a
// narrative or facts produces no documents.
func TestSync_FormatObservationDocs_Empty(t *testing.T) {
	s := testSync()

	docs := s.formatObservationDocs(&models.Observation{
		ID:             3,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		Type:           models.ObsTypeDiscovery,
		CreatedAtEpoch: 1234567890,
	})

	// Nothing to embed, nothing returned.
	assert.Len(t, docs, 0)
}
// TestSync_FormatObservationDocs_EmptyScope checks that an empty Scope is
// reported as "project" in the document metadata.
func TestSync_FormatObservationDocs_EmptyScope(t *testing.T) {
	s := testSync()

	docs := s.formatObservationDocs(&models.Observation{
		ID:             4,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		Scope:          "", // deliberately unset
		Type:           models.ObsTypeDiscovery,
		Narrative:      sql.NullString{String: "Content", Valid: true},
		CreatedAtEpoch: 1234567890,
	})

	assert.Len(t, docs, 1)
	assert.Equal(t, "project", docs[0].Metadata["scope"])
}
// TestSync_FormatSummaryDocs checks that a summary with all six text fields
// set yields six documents, and inspects the first ("request") document.
func TestSync_FormatSummaryDocs(t *testing.T) {
	s := testSync()

	text := func(v string) sql.NullString {
		return sql.NullString{String: v, Valid: true}
	}

	docs := s.formatSummaryDocs(&models.SessionSummary{
		ID:             1,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		Request:        text("Add feature"),
		Investigated:   text("Looked at code"),
		Learned:        text("Found pattern"),
		Completed:      text("Done"),
		NextSteps:      text("Test it"),
		Notes:          text("Notes here"),
		PromptNumber:   sql.NullInt64{Int64: 5, Valid: true},
		CreatedAtEpoch: 1234567890,
	})

	// One document per populated field.
	assert.Len(t, docs, 6)

	request := docs[0]
	assert.Equal(t, "summary_1_request", request.ID)
	assert.Equal(t, "Add feature", request.Content)
	assert.Equal(t, "session_summary", request.Metadata["doc_type"])
	assert.Equal(t, "request", request.Metadata["field_type"])
	assert.Equal(t, int64(5), request.Metadata["prompt_number"])
}
// TestSync_FormatSummaryDocs_PartialFields checks that only the populated
// summary fields ("request" and "completed") become documents.
func TestSync_FormatSummaryDocs_PartialFields(t *testing.T) {
	s := testSync()

	docs := s.formatSummaryDocs(&models.SessionSummary{
		ID:             2,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		Request:        sql.NullString{String: "Only request", Valid: true},
		Completed:      sql.NullString{String: "Only completed", Valid: true},
		CreatedAtEpoch: 1234567890,
	})

	assert.Len(t, docs, 2)

	// Collect the field_type of every document, in order.
	kinds := make([]string, 0, len(docs))
	for _, d := range docs {
		kinds = append(kinds, d.Metadata["field_type"].(string))
	}
	assert.Contains(t, kinds, "request")
	assert.Contains(t, kinds, "completed")
}
// TestSync_FormatSummaryDocs_Empty checks that a summary with no text fields
// produces no documents.
func TestSync_FormatSummaryDocs_Empty(t *testing.T) {
	s := testSync()

	docs := s.formatSummaryDocs(&models.SessionSummary{
		ID:             3,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		CreatedAtEpoch: 1234567890,
	})

	assert.Len(t, docs, 0)
}
// TestSync_FormatSummaryDocs_EmptyStrings checks that a field that is Valid
// but holds an empty string is skipped.
func TestSync_FormatSummaryDocs_EmptyStrings(t *testing.T) {
	s := testSync()

	validButEmpty := sql.NullString{String: "", Valid: true}
	docs := s.formatSummaryDocs(&models.SessionSummary{
		ID:             4,
		SDKSessionID:   "test-session",
		Project:        "test-project",
		Request:        validButEmpty,
		CreatedAtEpoch: 1234567890,
	})

	assert.Len(t, docs, 0)
}
// Test helper functions
func TestJoinStrings(t *testing.T) {
tests := []struct {
name string
strs []string
sep string
expected string
}{
{"empty", []string{}, ",", ""},
{"single", []string{"a"}, ",", "a"},
{"multiple", []string{"a", "b", "c"}, ",", "a,b,c"},
{"different sep", []string{"a", "b"}, "-", "a-b"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := joinStrings(tt.strs, tt.sep)
assert.Equal(t, tt.expected, result)
})
}
}
// TestCopyMetadata checks that copyMetadata returns a new map containing the
// base entries plus the extra key, and leaves the base map untouched.
func TestCopyMetadata(t *testing.T) {
	original := map[string]any{
		"key1": "value1",
		"key2": 42,
	}

	copied := copyMetadata(original, "key3", "value3")

	// The input map must not grow.
	assert.Len(t, original, 2)

	// The copy holds old and new entries.
	assert.Len(t, copied, 3)
	assert.Equal(t, "value1", copied["key1"])
	assert.Equal(t, 42, copied["key2"])
	assert.Equal(t, "value3", copied["key3"])
}
// TestCopyMetadataMulti checks that copyMetadataMulti merges base and extra
// into a new map without modifying base.
func TestCopyMetadataMulti(t *testing.T) {
	original := map[string]any{"key1": "value1"}
	additions := map[string]any{
		"key2": "value2",
		"key3": "value3",
	}

	merged := copyMetadataMulti(original, additions)

	// The input map must not grow.
	assert.Len(t, original, 1)

	// The merged map holds every entry from both inputs.
	assert.Len(t, merged, 3)
	for key, want := range map[string]string{
		"key1": "value1",
		"key2": "value2",
		"key3": "value3",
	} {
		assert.Equal(t, want, merged[key])
	}
}
// TestSync_DeleteObservationIDGeneration builds, inside the test itself, the
// document IDs for two observations (one narrative ID plus 20 fact IDs each)
// and checks the resulting count and a sample of the IDs.
func TestSync_DeleteObservationIDGeneration(t *testing.T) {
	const factsPerObservation = 20
	observations := []int64{1, 2}

	var ids []string
	for _, obs := range observations {
		ids = append(ids, fmt.Sprintf("obs_%d_narrative", obs))
		for fact := 0; fact < factsPerObservation; fact++ {
			ids = append(ids, fmt.Sprintf("obs_%d_fact_%d", obs, fact))
		}
	}

	// 2 observations x (1 narrative + 20 facts)
	assert.Len(t, ids, 42)

	for _, want := range []string{
		"obs_1_narrative",
		"obs_1_fact_0",
		"obs_1_fact_19",
		"obs_2_narrative",
		"obs_2_fact_0",
	} {
		assert.Contains(t, ids, want)
	}
}
// TestSync_DeletePromptIDGeneration builds "prompt_<id>" document IDs inside
// the test and checks that each expected ID is present.
func TestSync_DeletePromptIDGeneration(t *testing.T) {
	prompts := []int64{10, 20, 30}

	ids := make([]string, 0, len(prompts))
	for _, p := range prompts {
		ids = append(ids, fmt.Sprintf("prompt_%d", p))
	}

	assert.Len(t, ids, 3)
	assert.Contains(t, ids, "prompt_10")
	assert.Contains(t, ids, "prompt_20")
	assert.Contains(t, ids, "prompt_30")
}
// TestSync_ObservationMetadataFields checks every metadata key that
// formatObservationDocs attaches to a narrative document.
func TestSync_ObservationMetadataFields(t *testing.T) {
	s := testSync()

	docs := s.formatObservationDocs(&models.Observation{
		ID:             1,
		SDKSessionID:   "sdk-123",
		Project:        "my-project",
		Scope:          models.ScopeGlobal,
		Type:           models.ObsTypeBugfix,
		Title:          sql.NullString{String: "Bug Fix", Valid: true},
		Subtitle:       sql.NullString{String: "Memory leak", Valid: true},
		Narrative:      sql.NullString{String: "Fixed the leak", Valid: true},
		Concepts:       models.JSONStringArray{"memory", "performance"},
		FilesRead:      models.JSONStringArray{"main.go"},
		FilesModified:  models.JSONStringArray{"fix.go"},
		CreatedAtEpoch: 1234567890,
	})

	// Non-fatal assertions: every failing check is reported.
	check := assert.New(t)
	check.Len(docs, 1) // Only narrative, no facts

	meta := docs[0].Metadata
	expectations := []struct {
		key  string
		want any
	}{
		{"sqlite_id", int64(1)},
		{"doc_type", "observation"},
		{"sdk_session_id", "sdk-123"},
		{"project", "my-project"},
		{"scope", "global"},
		{"type", "bugfix"},
		{"title", "Bug Fix"},
		{"subtitle", "Memory leak"},
		{"concepts", "memory,performance"},
		{"files_read", "main.go"},
		{"files_modified", "fix.go"},
		{"created_at_epoch", int64(1234567890)},
		{"field_type", "narrative"},
	}
	for _, e := range expectations {
		check.Equal(e.want, meta[e.key])
	}
}
+254
View File
@@ -0,0 +1,254 @@
// Package sqlitevec provides sqlite-vec based vector database integration for claude-mnemonic.
package sqlitevec
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/rs/zerolog/log"
)
// Client provides vector operations via sqlite-vec.
// It embeds text with embedSvc and reads/writes the "vectors" table through db.
type Client struct {
// db is the caller-supplied database handle; Close does not close it.
db *sql.DB
// embedSvc turns document and query text into embedding vectors.
embedSvc *embedding.Service
// mu is held for the duration of AddDocuments, DeleteDocuments and Query.
mu sync.Mutex
}
// Config holds configuration for the client.
type Config struct {
// DB is the database handle the client operates on; NewClient rejects nil.
DB *sql.DB
}
// NewClient builds a sqlite-vec Client from a database handle and an
// embedding service. Both are mandatory: a nil cfg.DB or a nil embedSvc is
// rejected with an error (the database is checked first).
func NewClient(cfg Config, embedSvc *embedding.Service) (*Client, error) {
	switch {
	case cfg.DB == nil:
		return nil, fmt.Errorf("database connection required")
	case embedSvc == nil:
		return nil, fmt.Errorf("embedding service required")
	}

	client := &Client{
		db:       cfg.DB,
		embedSvc: embedSvc,
	}
	return client, nil
}
// AddDocuments adds documents with their embeddings to the vector store.
//
// It embeds the Content of every document in one EmbedBatch call, then
// writes one row per document into the "vectors" table inside a single
// transaction using INSERT OR REPLACE. The sqlite_id, doc_type, field_type,
// project and scope columns are taken from doc.Metadata; a missing key or a
// value of a different Go type is stored as the zero value.
//
// An empty docs slice is a no-op. On any error the transaction is rolled
// back and nothing is written.
func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
	if len(docs) == 0 {
		return nil
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// Generate embeddings for all documents
	texts := make([]string, len(docs))
	for i, doc := range docs {
		texts[i] = doc.Content
	}

	embeddings, err := c.embedSvc.EmbedBatch(texts)
	if err != nil {
		return fmt.Errorf("generate embeddings: %w", err)
	}
	// Guard the embeddings[i] lookups below against a short result.
	if len(embeddings) != len(docs) {
		return fmt.Errorf("generate embeddings: got %d embeddings for %d documents", len(embeddings), len(docs))
	}

	// Insert into vectors table
	const insertQuery = `
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope)
VALUES (?, ?, ?, ?, ?, ?, ?)
`

	tx, err := c.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Roll back on every exit path. The errors raised inside the loop live
	// in loop-scoped variables, so a rollback conditioned on the outer err
	// would never fire for them. After a successful Commit, Rollback is a
	// no-op that reports sql.ErrTxDone, which is safe to ignore.
	defer func() { _ = tx.Rollback() }()

	stmt, err := tx.PrepareContext(ctx, insertQuery)
	if err != nil {
		return fmt.Errorf("prepare statement: %w", err)
	}
	defer stmt.Close()

	for i, doc := range docs {
		// Serialize embedding to blob format
		embBlob, err := sqlite_vec.SerializeFloat32(embeddings[i])
		if err != nil {
			return fmt.Errorf("serialize embedding for %s: %w", doc.ID, err)
		}

		// Extract metadata; failed type assertions deliberately yield zero values.
		sqliteID, _ := doc.Metadata["sqlite_id"].(int64)
		docType, _ := doc.Metadata["doc_type"].(string)
		fieldType, _ := doc.Metadata["field_type"].(string)
		project, _ := doc.Metadata["project"].(string)
		scope, _ := doc.Metadata["scope"].(string)

		if _, err := stmt.ExecContext(ctx,
			doc.ID,
			embBlob,
			sqliteID,
			docType,
			fieldType,
			project,
			scope,
		); err != nil {
			return fmt.Errorf("insert document %s: %w", doc.ID, err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}

	log.Debug().Int("count", len(docs)).Msg("Added documents to sqlite-vec")
	return nil
}
// DeleteDocuments deletes every row of the "vectors" table whose doc_id is
// in ids, using one parameterized DELETE ... IN (...) statement. An empty
// ids slice is a no-op.
func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
	if len(ids) == 0 {
		return nil
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// One "?" per id, comma separated; the ids themselves travel as bind
	// parameters.
	marks := strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")
	params := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		params = append(params, id)
	}

	// #nosec G201 -- Placeholders are "?" strings, actual values are parameterized via args
	stmt := fmt.Sprintf("DELETE FROM vectors WHERE doc_id IN (%s)", marks)

	if _, execErr := c.db.ExecContext(ctx, stmt, params...); execErr != nil {
		return fmt.Errorf("delete documents: %w", execErr)
	}

	log.Debug().Int("count", len(ids)).Msg("Deleted documents from sqlite-vec")
	return nil
}
// Query embeds the query text and runs a similarity search against the
// "vectors" table, returning at most limit rows ordered by distance.
//
// Two optional filters are read from where (other keys are ignored):
//   - "doc_type": non-empty string, matched exactly;
//   - "project":  non-empty string; rows of that project or rows whose
//     scope is 'global' are returned.
//
// Each result's Metadata carries sqlite_id (as float64), doc_type,
// field_type, project and scope. The result slice is nil when no row matches.
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Embed the query text and serialize it to the blob form used for MATCH.
	vector, err := c.embedSvc.Embed(query)
	if err != nil {
		return nil, fmt.Errorf("embed query: %w", err)
	}
	blob, err := sqlite_vec.SerializeFloat32(vector)
	if err != nil {
		return nil, fmt.Errorf("serialize query embedding: %w", err)
	}

	// Base statement; filters and the ORDER BY/LIMIT tail are appended below.
	// vec0 supports WHERE clauses on metadata columns
	stmt := `
SELECT
doc_id,
distance,
sqlite_id,
doc_type,
field_type,
project,
scope
FROM vectors
WHERE embedding MATCH ?
`
	params := []interface{}{blob}

	if wantType, ok := where["doc_type"].(string); ok && wantType != "" {
		stmt += " AND doc_type = ?"
		params = append(params, wantType)
	}

	if wantProject, ok := where["project"].(string); ok && wantProject != "" {
		// Project-specific rows plus anything with global scope.
		stmt += " AND (project = ? OR scope = 'global')"
		params = append(params, wantProject)
	}

	stmt += " ORDER BY distance LIMIT ?"
	params = append(params, limit)

	rows, err := c.db.QueryContext(ctx, stmt, params...)
	if err != nil {
		return nil, fmt.Errorf("query vectors: %w", err)
	}
	defer rows.Close()

	var matches []QueryResult
	for rows.Next() {
		var (
			hit                                     QueryResult
			rowID                                   int64
			rowType, rowField, rowProject, rowScope sql.NullString
		)
		if scanErr := rows.Scan(&hit.ID, &hit.Distance, &rowID, &rowType, &rowField, &rowProject, &rowScope); scanErr != nil {
			return nil, fmt.Errorf("scan row: %w", scanErr)
		}

		hit.Metadata = map[string]any{
			"sqlite_id":  float64(rowID), // Keep as float64 for compatibility
			"doc_type":   rowType.String,
			"field_type": rowField.String,
			"project":    rowProject.String,
			"scope":      rowScope.String,
		}
		matches = append(matches, hit)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate rows: %w", iterErr)
	}

	log.Debug().
		Str("query", truncateString(query, 50)).
		Int("results", len(matches)).
		Msg("Vector search completed")

	return matches, nil
}
// IsConnected reports whether the client holds a database handle.
// There is no external process to connect to, so this only checks that
// c.db is non-nil; it does not ping the database.
func (c *Client) IsConnected() bool {
return c.db != nil
}
// Close is a no-op (db managed externally).
// It leaves c.db open and always returns nil.
func (c *Client) Close() error {
return nil
}
// truncateString shortens s to at most maxLen bytes and appends "..." when
// anything was cut off. Strings of maxLen bytes or fewer are returned
// unchanged.
//
// The cut is moved back to the nearest rune boundary so a multi-byte UTF-8
// character is never split; the kept prefix may therefore be slightly
// shorter than maxLen bytes. A negative maxLen is treated as 0.
func truncateString(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// Ranging over a string yields the byte offset of each rune start;
	// keep the largest such offset that does not exceed maxLen.
	cut := 0
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
@@ -1,7 +1,7 @@
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
package chroma
// Package sqlitevec provides sqlite-vec based vector database integration for claude-mnemonic.
package sqlitevec
// DocType represents the type of document stored in ChromaDB.
// DocType represents the type of document stored in the vector table.
type DocType string
const (
@@ -10,14 +10,28 @@ const (
DocTypeUserPrompt DocType = "user_prompt"
)
// ExtractedIDs contains SQLite IDs extracted from ChromaDB results, grouped by document type.
// Document represents a document to store with vector embedding.
type Document struct {
// ID identifies the document in the vector store.
ID string
// Content is the text that gets embedded.
Content string
// Metadata holds additional key/value attributes of the document.
Metadata map[string]any
}
// QueryResult represents a search result from vector search.
type QueryResult struct {
// ID identifies the matched document.
ID string
// Distance is the distance reported for the match.
Distance float64
// Metadata holds the attributes returned with the match; the Extract*
// helpers read "sqlite_id", "doc_type", "project" and "scope" from it.
Metadata map[string]any
}
// ExtractedIDs contains SQLite IDs extracted from query results, grouped by document type.
type ExtractedIDs struct {
// ObservationIDs lists SQLite IDs of observation documents.
ObservationIDs []int64
// SummaryIDs lists SQLite IDs of session summary documents.
SummaryIDs []int64
// PromptIDs lists SQLite IDs of user prompt documents.
PromptIDs []int64
}
// BuildWhereFilter creates a where filter map for ChromaDB queries.
// BuildWhereFilter creates a where filter map for vector queries.
// If docType is empty, no doc_type filter is added.
func BuildWhereFilter(docType DocType, project string) map[string]interface{} {
where := make(map[string]interface{})
@@ -30,7 +44,7 @@ func BuildWhereFilter(docType DocType, project string) map[string]interface{} {
return where
}
// ExtractIDsByDocType extracts SQLite IDs from ChromaDB query results,
// ExtractIDsByDocType extracts SQLite IDs from query results,
// grouped by document type and deduplicated.
func ExtractIDsByDocType(results []QueryResult) *ExtractedIDs {
ids := &ExtractedIDs{}
@@ -41,7 +55,12 @@ func ExtractIDsByDocType(results []QueryResult) *ExtractedIDs {
for _, result := range results {
sqliteID, ok := result.Metadata["sqlite_id"].(float64)
if !ok {
continue
// Try int64 directly
if id, ok := result.Metadata["sqlite_id"].(int64); ok {
sqliteID = float64(id)
} else {
continue
}
}
id := int64(sqliteID)
@@ -68,10 +87,8 @@ func ExtractIDsByDocType(results []QueryResult) *ExtractedIDs {
return ids
}
// ExtractObservationIDs extracts observation SQLite IDs from ChromaDB query results,
// ExtractObservationIDs extracts observation SQLite IDs from query results,
// optionally filtering by project or including global scope.
// If project is empty, all observation IDs are returned.
// If project is set, only observations matching the project or with global scope are returned.
func ExtractObservationIDs(results []QueryResult, project string) []int64 {
var ids []int64
seen := make(map[int64]bool)
@@ -79,21 +96,22 @@ func ExtractObservationIDs(results []QueryResult, project string) []int64 {
for _, result := range results {
sqliteID, ok := result.Metadata["sqlite_id"].(float64)
if !ok {
continue
if id, ok := result.Metadata["sqlite_id"].(int64); ok {
sqliteID = float64(id)
} else {
continue
}
}
id := int64(sqliteID)
// Check document type
docType, _ := result.Metadata["doc_type"].(string)
if docType != string(DocTypeObservation) {
continue
}
// Apply project/scope filter if project is specified
if project != "" {
proj, _ := result.Metadata["project"].(string)
scope, _ := result.Metadata["scope"].(string)
// Include if project matches OR scope is global
if proj != project && scope != "global" {
continue
}
@@ -108,7 +126,7 @@ func ExtractObservationIDs(results []QueryResult, project string) []int64 {
return ids
}
// ExtractSummaryIDs extracts session summary SQLite IDs from ChromaDB query results.
// ExtractSummaryIDs extracts session summary SQLite IDs from query results.
func ExtractSummaryIDs(results []QueryResult, project string) []int64 {
var ids []int64
seen := make(map[int64]bool)
@@ -116,7 +134,11 @@ func ExtractSummaryIDs(results []QueryResult, project string) []int64 {
for _, result := range results {
sqliteID, ok := result.Metadata["sqlite_id"].(float64)
if !ok {
continue
if id, ok := result.Metadata["sqlite_id"].(int64); ok {
sqliteID = float64(id)
} else {
continue
}
}
id := int64(sqliteID)
@@ -141,7 +163,7 @@ func ExtractSummaryIDs(results []QueryResult, project string) []int64 {
return ids
}
// ExtractPromptIDs extracts user prompt SQLite IDs from ChromaDB query results.
// ExtractPromptIDs extracts user prompt SQLite IDs from query results.
func ExtractPromptIDs(results []QueryResult, project string) []int64 {
var ids []int64
seen := make(map[int64]bool)
@@ -149,7 +171,11 @@ func ExtractPromptIDs(results []QueryResult, project string) []int64 {
for _, result := range results {
sqliteID, ok := result.Metadata["sqlite_id"].(float64)
if !ok {
continue
if id, ok := result.Metadata["sqlite_id"].(int64); ok {
sqliteID = float64(id)
} else {
continue
}
}
id := int64(sqliteID)
@@ -173,3 +199,36 @@ func ExtractPromptIDs(results []QueryResult, project string) []int64 {
return ids
}
// Helper functions for metadata manipulation
// copyMetadata returns a new map holding every entry of base plus key=value.
// base is never modified; if base already contains key, the returned map
// holds the new value.
func copyMetadata(base map[string]any, key string, value any) map[string]any {
	out := make(map[string]any, len(base)+1)
	for name := range base {
		out[name] = base[name]
	}
	// Written last so it wins over an existing entry with the same key.
	out[key] = value
	return out
}
// copyMetadataMulti returns a new map containing the entries of base
// followed by the entries of extra. Neither input is modified; on a key
// collision the value from extra is kept.
func copyMetadataMulti(base map[string]any, extra map[string]any) map[string]any {
	out := make(map[string]any, len(base)+len(extra))
	// Order matters: extra is applied after base so it takes precedence.
	for _, src := range []map[string]any{base, extra} {
		for name, val := range src {
			out[name] = val
		}
	}
	return out
}
// joinStrings concatenates the elements of strs, placing sep between
// consecutive elements. An empty or nil slice yields "".
func joinStrings(strs []string, sep string) string {
	var joined string
	for idx, part := range strs {
		if idx > 0 {
			joined += sep
		}
		joined += part
	}
	return joined
}
@@ -1,5 +1,5 @@
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
package chroma
// Package sqlitevec provides sqlite-vec based vector database integration for claude-mnemonic.
package sqlitevec
import (
"context"
@@ -9,17 +9,17 @@ import (
"github.com/rs/zerolog/log"
)
// Sync provides synchronization between SQLite and ChromaDB.
// Sync provides synchronization between SQLite data and vector embeddings.
// The zero value carries a nil client; construct a usable Sync with NewSync.
type Sync struct {
// client is the vector store client supplied to NewSync.
client *Client
}
// NewSync creates a new ChromaDB sync service.
// NewSync creates a new sync service.
// The client pointer is stored as given; NewSync performs no validation and
// no I/O of its own.
func NewSync(client *Client) *Sync {
	svc := new(Sync)
	svc.client = client
	return svc
}
// SyncObservation syncs a single observation to ChromaDB.
// SyncObservation syncs a single observation to the vector store.
func (s *Sync) SyncObservation(ctx context.Context, obs *models.Observation) error {
docs := s.formatObservationDocs(obs)
if len(docs) == 0 {
@@ -33,12 +33,12 @@ func (s *Sync) SyncObservation(ctx context.Context, obs *models.Observation) err
log.Debug().
Int64("observationId", obs.ID).
Int("docCount", len(docs)).
Msg("Synced observation to ChromaDB")
Msg("Synced observation to sqlite-vec")
return nil
}
// formatObservationDocs formats an observation into ChromaDB documents.
// formatObservationDocs formats an observation into vector documents.
// Each semantic field becomes a separate vector document (granular approach).
func (s *Sync) formatObservationDocs(obs *models.Observation) []Document {
docs := make([]Document, 0, len(obs.Facts)+2)
@@ -99,7 +99,7 @@ func (s *Sync) formatObservationDocs(obs *models.Observation) []Document {
return docs
}
// SyncSummary syncs a single session summary to ChromaDB.
// SyncSummary syncs a single session summary to the vector store.
func (s *Sync) SyncSummary(ctx context.Context, summary *models.SessionSummary) error {
docs := s.formatSummaryDocs(summary)
if len(docs) == 0 {
@@ -113,12 +113,12 @@ func (s *Sync) SyncSummary(ctx context.Context, summary *models.SessionSummary)
log.Debug().
Int64("summaryId", summary.ID).
Int("docCount", len(docs)).
Msg("Synced summary to ChromaDB")
Msg("Synced summary to sqlite-vec")
return nil
}
// formatSummaryDocs formats a session summary into ChromaDB documents.
// formatSummaryDocs formats a session summary into vector documents.
func (s *Sync) formatSummaryDocs(summary *models.SessionSummary) []Document {
docs := make([]Document, 0, 6)
@@ -127,6 +127,7 @@ func (s *Sync) formatSummaryDocs(summary *models.SessionSummary) []Document {
"doc_type": "session_summary",
"sdk_session_id": summary.SDKSessionID,
"project": summary.Project,
"scope": "", // Summaries don't have scope
"created_at_epoch": summary.CreatedAtEpoch,
}
@@ -161,7 +162,7 @@ func (s *Sync) formatSummaryDocs(summary *models.SessionSummary) []Document {
return docs
}
// SyncUserPrompt syncs a single user prompt to ChromaDB.
// SyncUserPrompt syncs a single user prompt to the vector store.
func (s *Sync) SyncUserPrompt(ctx context.Context, prompt *models.UserPromptWithSession) error {
doc := Document{
ID: fmt.Sprintf("prompt_%d", prompt.ID),
@@ -171,8 +172,10 @@ func (s *Sync) SyncUserPrompt(ctx context.Context, prompt *models.UserPromptWith
"doc_type": "user_prompt",
"sdk_session_id": prompt.SDKSessionID,
"project": prompt.Project,
"scope": "", // Prompts don't have scope
"created_at_epoch": prompt.CreatedAtEpoch,
"prompt_number": prompt.PromptNumber,
"field_type": "prompt",
},
}
@@ -182,14 +185,12 @@ func (s *Sync) SyncUserPrompt(ctx context.Context, prompt *models.UserPromptWith
log.Debug().
Int64("promptId", prompt.ID).
Msg("Synced user prompt to ChromaDB")
Msg("Synced user prompt to sqlite-vec")
return nil
}
// DeleteObservations removes observation documents from ChromaDB.
// Since each observation may have multiple documents (narrative + facts),
// we delete by the sqlite_id metadata prefix pattern.
// DeleteObservations removes observation documents from the vector store.
func (s *Sync) DeleteObservations(ctx context.Context, observationIDs []int64) error {
if len(observationIDs) == 0 {
return nil
@@ -197,7 +198,6 @@ func (s *Sync) DeleteObservations(ctx context.Context, observationIDs []int64) e
// Generate all possible document IDs for these observations
// Pattern: obs_{id}_narrative, obs_{id}_fact_{0..n}
// Since we don't know how many facts each had, we use a reasonable upper bound
const maxFactsPerObs = 20
ids := make([]string, 0, len(observationIDs)*(maxFactsPerObs+1))
@@ -214,18 +214,17 @@ func (s *Sync) DeleteObservations(ctx context.Context, observationIDs []int64) e
log.Debug().
Int("observationCount", len(observationIDs)).
Msg("Deleted observations from ChromaDB")
Msg("Deleted observations from sqlite-vec")
return nil
}
// DeleteUserPrompts removes user prompt documents from ChromaDB.
// DeleteUserPrompts removes user prompt documents from the vector store.
func (s *Sync) DeleteUserPrompts(ctx context.Context, promptIDs []int64) error {
if len(promptIDs) == 0 {
return nil
}
// Each prompt is stored as a single document with ID pattern: prompt_{id}
ids := make([]string, len(promptIDs))
for i, promptID := range promptIDs {
ids[i] = fmt.Sprintf("prompt_%d", promptID)
@@ -237,40 +236,7 @@ func (s *Sync) DeleteUserPrompts(ctx context.Context, promptIDs []int64) error {
log.Debug().
Int("promptCount", len(promptIDs)).
Msg("Deleted user prompts from ChromaDB")
Msg("Deleted user prompts from sqlite-vec")
return nil
}
// Helper functions
// copyMetadata returns a new map holding every entry of base plus the given
// key/value pair. base itself is never modified; if key already exists in
// base, the returned map carries value for it instead.
func copyMetadata(base map[string]any, key string, value any) map[string]any {
	out := make(map[string]any, len(base)+1)
	for name, val := range base {
		out[name] = val
	}
	// Assign last so the explicit pair wins over any entry copied from base.
	out[key] = value
	return out
}
// copyMetadataMulti returns a new map containing the entries of base followed
// by the entries of extra. Neither input is modified; when a key is present in
// both, the value from extra is the one kept.
func copyMetadataMulti(base map[string]any, extra map[string]any) map[string]any {
	merged := make(map[string]any, len(base)+len(extra))
	// Order matters: extra is applied after base so that it takes precedence.
	for _, src := range []map[string]any{base, extra} {
		for name, val := range src {
			merged[name] = val
		}
	}
	return merged
}
// joinStrings concatenates the elements of strs, inserting sep between
// consecutive elements. It returns "" for a nil or empty slice and the sole
// element unchanged for a single-element slice.
//
// The result is assembled in a single pre-sized buffer, so the work is linear
// in the output length; repeated string += would reallocate and recopy the
// accumulated prefix on every iteration.
func joinStrings(strs []string, sep string) string {
	switch len(strs) {
	case 0:
		return ""
	case 1:
		return strs[0]
	}

	// Compute the exact output size up front so the buffer never grows.
	size := len(sep) * (len(strs) - 1)
	for _, s := range strs {
		size += len(s)
	}

	buf := make([]byte, 0, size)
	buf = append(buf, strs[0]...)
	for _, s := range strs[1:] {
		buf = append(buf, sep...)
		buf = append(buf, s...)
	}
	return string(buf)
}