mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
Move from chroma to sqlitevec with local embedding
This commit is contained in:
+16
-41
@@ -11,9 +11,10 @@ import (
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/mcp"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/search"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/chroma"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -40,7 +41,7 @@ func main() {
|
||||
log.Fatal().Msg("--project is required")
|
||||
}
|
||||
|
||||
// Ensure data directory, vector-db, and settings exist
|
||||
// Ensure data directory and settings exist
|
||||
if err := config.EnsureAll(); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to ensure data directories")
|
||||
}
|
||||
@@ -54,10 +55,8 @@ func main() {
|
||||
|
||||
// Override data directory if specified
|
||||
dbPath := cfg.DBPath
|
||||
vectorDBPath := cfg.VectorDBPath
|
||||
if *dataDir != "" {
|
||||
dbPath = *dataDir + "/claude-mnemonic.db"
|
||||
vectorDBPath = *dataDir + "/vector-db"
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -89,31 +88,26 @@ func main() {
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
|
||||
// Initialize ChromaDB client (optional)
|
||||
var chromaClient *chroma.Client
|
||||
chromaCfg := chroma.Config{
|
||||
Project: *project,
|
||||
DataDir: vectorDBPath,
|
||||
PythonVer: cfg.PythonVersion,
|
||||
BatchSize: 100,
|
||||
}
|
||||
chromaClient, err = chroma.NewClient(chromaCfg)
|
||||
// Initialize embedding service and vector client
|
||||
var vectorClient *sqlitevec.Client
|
||||
embedSvc, err := embedding.NewService()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("ChromaDB unavailable, vector search disabled")
|
||||
log.Warn().Err(err).Msg("Embedding service unavailable, vector search disabled")
|
||||
} else {
|
||||
if err := chromaClient.Connect(ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to connect to ChromaDB, vector search disabled")
|
||||
chromaClient = nil
|
||||
defer embedSvc.Close()
|
||||
vectorClient, err = sqlitevec.NewClient(sqlitevec.Config{DB: store.DB()}, embedSvc)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Vector client unavailable, vector search disabled")
|
||||
} else {
|
||||
defer chromaClient.Close()
|
||||
log.Info().Msg("Vector search enabled via sqlite-vec")
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize search manager
|
||||
searchMgr := search.NewManager(observationStore, summaryStore, promptStore, chromaClient)
|
||||
searchMgr := search.NewManager(observationStore, summaryStore, promptStore, vectorClient)
|
||||
|
||||
// Start file watchers
|
||||
startWatchers(ctx, vectorDBPath, chromaClient)
|
||||
startWatchers(ctx, dbPath)
|
||||
|
||||
// Create and run MCP server
|
||||
server := mcp.NewServer(searchMgr, Version)
|
||||
@@ -124,27 +118,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// startWatchers initializes file watchers for vector DB and config.
|
||||
func startWatchers(ctx context.Context, vectorDBPath string, chromaClient *chroma.Client) {
|
||||
// Watch vector DB directory for deletion
|
||||
if chromaClient != nil {
|
||||
vectorWatcher, err := watcher.New(vectorDBPath, func() {
|
||||
log.Warn().Str("path", vectorDBPath).Msg("Vector database deleted, reconnecting ChromaDB...")
|
||||
if err := chromaClient.Reconnect(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to reconnect ChromaDB after deletion")
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to create vector DB watcher")
|
||||
} else {
|
||||
if err := vectorWatcher.Start(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to start vector DB watcher")
|
||||
} else {
|
||||
log.Info().Str("path", vectorDBPath).Msg("Vector DB file watcher started")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startWatchers initializes file watchers for config.
|
||||
func startWatchers(ctx context.Context, dbPath string) {
|
||||
// Watch config file for changes (triggers process exit for restart)
|
||||
configPath := config.SettingsPath()
|
||||
configWatcher, err := watcher.New(configPath, func() {
|
||||
|
||||
@@ -1,21 +1,33 @@
|
||||
module github.com/lukaszraczylo/claude-mnemonic
|
||||
|
||||
go 1.24.0
|
||||
go 1.25.1
|
||||
|
||||
replace github.com/sugarme/tokenizer => github.com/clems4ever/tokenizer v0.0.0-20250926133620-9ddc80533c43
|
||||
|
||||
require (
|
||||
github.com/asg017/sqlite-vec-go-bindings v0.1.6
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/go-chi/chi/v5 v5.2.3
|
||||
github.com/goccy/go-json v0.10.5
|
||||
github.com/mattn/go-sqlite3 v1.14.32
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/sugarme/tokenizer v0.3.0
|
||||
github.com/yalue/onnxruntime_go v1.21.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/schollz/progressbar/v2 v2.15.0 // indirect
|
||||
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.25.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww=
|
||||
github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q=
|
||||
github.com/clems4ever/tokenizer v0.0.0-20250926133620-9ddc80533c43 h1:j8YQypEqa5OjqbGciCNb9hOcYbo1oTVuEjd/iu9U2SY=
|
||||
github.com/clems4ever/tokenizer v0.0.0-20250926133620-9ddc80533c43/go.mod h1:VJ+DLK5ZEZwzvODOWwY0cw+B1dabTd3nCB5HuFCItCc=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
|
||||
@@ -17,19 +24,35 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
|
||||
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8=
|
||||
github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89ymCaN+5lPH/MwnfSVg4rzGDh4d+iy4=
|
||||
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw=
|
||||
github.com/yalue/onnxruntime_go v1.21.0 h1:DdtvfY7OP5gR8mwPDqAOAQckf+KcI30hPNJL8hQaYWI=
|
||||
github.com/yalue/onnxruntime_go v1.21.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -13,9 +13,6 @@ const (
|
||||
// DefaultWorkerPort is the default HTTP port for the worker service.
|
||||
DefaultWorkerPort = 37777
|
||||
|
||||
// DefaultPythonVersion for ChromaDB (avoid onnxruntime issues with 3.14+).
|
||||
DefaultPythonVersion = "3.13"
|
||||
|
||||
// DefaultModel for SDK agent (use "haiku" for cost-efficient processing).
|
||||
// Claude Code CLI accepts aliases: haiku, sonnet, opus (always latest versions)
|
||||
DefaultModel = "haiku"
|
||||
@@ -47,10 +44,6 @@ type Config struct {
|
||||
DBPath string `json:"db_path"`
|
||||
MaxConns int `json:"max_conns"`
|
||||
|
||||
// ChromaDB settings
|
||||
VectorDBPath string `json:"vector_db_path"`
|
||||
PythonVersion string `json:"python_version"`
|
||||
|
||||
// SDK Agent settings
|
||||
Model string `json:"model"`
|
||||
ClaudeCodePath string `json:"claude_code_path"`
|
||||
@@ -84,11 +77,6 @@ func DBPath() string {
|
||||
return filepath.Join(DataDir(), "claude-mnemonic.db")
|
||||
}
|
||||
|
||||
// VectorDBPath returns the vector database directory path.
|
||||
func VectorDBPath() string {
|
||||
return filepath.Join(DataDir(), "vector-db")
|
||||
}
|
||||
|
||||
// SettingsPath returns the settings file path.
|
||||
func SettingsPath() string {
|
||||
return filepath.Join(DataDir(), "settings.json")
|
||||
@@ -99,11 +87,6 @@ func EnsureDataDir() error {
|
||||
return os.MkdirAll(DataDir(), 0750)
|
||||
}
|
||||
|
||||
// EnsureVectorDBDir creates the vector database directory if it doesn't exist.
|
||||
func EnsureVectorDBDir() error {
|
||||
return os.MkdirAll(VectorDBPath(), 0750)
|
||||
}
|
||||
|
||||
// EnsureSettings creates a default settings file if it doesn't exist.
|
||||
func EnsureSettings() error {
|
||||
path := SettingsPath()
|
||||
@@ -116,7 +99,6 @@ func EnsureSettings() error {
|
||||
// Create default settings file with comments
|
||||
defaultSettings := `{
|
||||
"CLAUDE_MNEMONIC_WORKER_PORT": 37777,
|
||||
"CLAUDE_MNEMONIC_PYTHON_VERSION": "3.13",
|
||||
"CLAUDE_MNEMONIC_MODEL": "haiku",
|
||||
"CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS": 100,
|
||||
"CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT": 25,
|
||||
@@ -131,9 +113,6 @@ func EnsureAll() error {
|
||||
if err := EnsureDataDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := EnsureVectorDBDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := EnsureSettings(); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -146,8 +125,6 @@ func Default() *Config {
|
||||
WorkerPort: DefaultWorkerPort,
|
||||
DBPath: DBPath(),
|
||||
MaxConns: 4,
|
||||
VectorDBPath: VectorDBPath(),
|
||||
PythonVersion: DefaultPythonVersion,
|
||||
Model: DefaultModel,
|
||||
ContextObservations: 100,
|
||||
ContextFullCount: 25,
|
||||
@@ -183,9 +160,6 @@ func Load() (*Config, error) {
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_WORKER_PORT"].(float64); ok {
|
||||
cfg.WorkerPort = int(v)
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_PYTHON_VERSION"].(string); ok {
|
||||
cfg.PythonVersion = v
|
||||
}
|
||||
if v, ok := settings["CLAUDE_MNEMONIC_MODEL"].(string); ok {
|
||||
cfg.Model = v
|
||||
}
|
||||
|
||||
@@ -254,6 +254,24 @@ var Migrations = []Migration{
|
||||
ALTER TABLE user_prompts ADD COLUMN matched_observations INTEGER DEFAULT 0;
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 17,
|
||||
Name: "sqlite_vec_vectors",
|
||||
SQL: `
|
||||
-- Vector embeddings table using sqlite-vec
|
||||
-- Each document (narrative, fact, summary field, prompt) gets one vector
|
||||
-- Uses all-MiniLM-L6-v2 embeddings (384 dimensions)
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vectors USING vec0(
|
||||
doc_id TEXT PRIMARY KEY,
|
||||
embedding float[384],
|
||||
sqlite_id INTEGER,
|
||||
doc_type TEXT,
|
||||
field_type TEXT,
|
||||
project TEXT,
|
||||
scope TEXT
|
||||
);
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
// MigrationManager handles database schema migrations.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
@@ -26,6 +27,9 @@ type StoreConfig struct {
|
||||
|
||||
// NewStore creates a new database store with the given configuration.
|
||||
func NewStore(cfg StoreConfig) (*Store, error) {
|
||||
// Register sqlite-vec extension for vector operations
|
||||
sqlite_vec.Auto()
|
||||
|
||||
// Build connection string with pragmas
|
||||
connStr := cfg.Path + "?_journal_mode=WAL&_synchronous=NORMAL&_foreign_keys=ON"
|
||||
|
||||
@@ -137,3 +141,9 @@ func (s *Store) QueryRowContext(ctx context.Context, query string, args ...inter
|
||||
func (s *Store) Ping() error {
|
||||
return s.db.Ping()
|
||||
}
|
||||
|
||||
// DB returns the underlying database connection for direct access.
|
||||
// Use this sparingly - prefer the store methods for most operations.
|
||||
func (s *Store) DB() *sql.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
// Package embedding provides text embedding generation using all-MiniLM-L6-v2.
|
||||
package embedding
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
// Model and tokenizer files - embedded for all platforms
|
||||
//
|
||||
//go:embed assets/model.onnx
|
||||
var modelData []byte
|
||||
|
||||
//go:embed assets/tokenizer.json
|
||||
var tokenizerData []byte
|
||||
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,8 @@
|
||||
//go:build darwin
|
||||
|
||||
package embedding
|
||||
|
||||
// Darwin doesn't need the providers shared library
|
||||
var onnxRuntimeProvidersLib []byte
|
||||
|
||||
const onnxRuntimeProvidersLibName = ""
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build darwin && amd64
|
||||
|
||||
package embedding
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed assets/lib/darwin-amd64/libonnxruntime.dylib
|
||||
var onnxRuntimeLib []byte
|
||||
|
||||
const onnxRuntimeLibName = "libonnxruntime.dylib"
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build darwin && arm64
|
||||
|
||||
package embedding
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed assets/lib/darwin-arm64/libonnxruntime.dylib
|
||||
var onnxRuntimeLib []byte
|
||||
|
||||
const onnxRuntimeLibName = "libonnxruntime.dylib"
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build linux && amd64
|
||||
|
||||
package embedding
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed assets/lib/linux-amd64/libonnxruntime.so
|
||||
var onnxRuntimeLib []byte
|
||||
|
||||
//go:embed assets/lib/linux-amd64/libonnxruntime_providers_shared.so
|
||||
var onnxRuntimeProvidersLib []byte
|
||||
|
||||
const onnxRuntimeLibName = "libonnxruntime.so"
|
||||
const onnxRuntimeProvidersLibName = "libonnxruntime_providers_shared.so"
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build linux && arm64
|
||||
|
||||
package embedding
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed assets/lib/linux-arm64/libonnxruntime.so
|
||||
var onnxRuntimeLib []byte
|
||||
|
||||
//go:embed assets/lib/linux-arm64/libonnxruntime_providers_shared.so
|
||||
var onnxRuntimeProvidersLib []byte
|
||||
|
||||
const onnxRuntimeLibName = "libonnxruntime.so"
|
||||
const onnxRuntimeProvidersLibName = "libonnxruntime_providers_shared.so"
|
||||
@@ -0,0 +1,291 @@
|
||||
// Package embedding provides text embedding generation using all-MiniLM-L6-v2.
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/sugarme/tokenizer"
|
||||
"github.com/sugarme/tokenizer/pretrained"
|
||||
ort "github.com/yalue/onnxruntime_go"
|
||||
)
|
||||
|
||||
// EmbeddingDim is the dimension of embeddings produced by all-MiniLM-L6-v2.
|
||||
const EmbeddingDim = 384
|
||||
|
||||
// Service provides thread-safe text embedding generation.
|
||||
type Service struct {
|
||||
tk *tokenizer.Tokenizer
|
||||
session *ort.DynamicAdvancedSession
|
||||
mu sync.Mutex
|
||||
libDir string // temp directory containing extracted libraries
|
||||
}
|
||||
|
||||
// NewService creates a new embedding service using bundled ONNX runtime and model.
|
||||
func NewService() (*Service, error) {
|
||||
// Extract ONNX runtime library to temp directory
|
||||
libDir, err := extractONNXLibrary()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("extract ONNX library: %w", err)
|
||||
}
|
||||
|
||||
// Set the library path
|
||||
libPath := filepath.Join(libDir, onnxRuntimeLibName)
|
||||
ort.SetSharedLibraryPath(libPath)
|
||||
|
||||
// Initialize ONNX runtime
|
||||
if err := ort.InitializeEnvironment(); err != nil {
|
||||
return nil, fmt.Errorf("initialize ONNX runtime: %w", err)
|
||||
}
|
||||
|
||||
// Load tokenizer from embedded data
|
||||
tk, err := pretrained.FromReader(bytes.NewReader(tokenizerData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Create ONNX session with embedded model
|
||||
inputNames := []string{"input_ids", "attention_mask", "token_type_ids"}
|
||||
outputNames := []string{"sentence_embedding"}
|
||||
|
||||
session, err := ort.NewDynamicAdvancedSessionWithONNXData(modelData, inputNames, outputNames, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ONNX session: %w", err)
|
||||
}
|
||||
|
||||
return &Service{
|
||||
tk: tk,
|
||||
session: session,
|
||||
libDir: libDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractONNXLibrary extracts the embedded ONNX runtime library to a temp directory.
|
||||
// Uses content hash to avoid re-extracting if already present.
|
||||
func extractONNXLibrary() (string, error) {
|
||||
// Create a hash of the library content for cache key
|
||||
hash := sha256.Sum256(onnxRuntimeLib)
|
||||
hashStr := hex.EncodeToString(hash[:8]) // Use first 8 bytes
|
||||
|
||||
// Create cache directory
|
||||
cacheDir := filepath.Join(os.TempDir(), "claude-mnemonic-onnx", hashStr)
|
||||
libPath := filepath.Join(cacheDir, onnxRuntimeLibName)
|
||||
|
||||
// Check if already extracted
|
||||
if _, err := os.Stat(libPath); err == nil {
|
||||
return cacheDir, nil
|
||||
}
|
||||
|
||||
// Create directory
|
||||
// #nosec G301 -- Cache directory needs 0755 for user access
|
||||
if err := os.MkdirAll(cacheDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("create cache dir: %w", err)
|
||||
}
|
||||
|
||||
// Write main library
|
||||
// #nosec G306 -- Shared library needs executable permission (0755) for dynamic linker
|
||||
if err := os.WriteFile(libPath, onnxRuntimeLib, 0755); err != nil {
|
||||
return "", fmt.Errorf("write library: %w", err)
|
||||
}
|
||||
|
||||
// Write providers library if present (Linux only)
|
||||
if len(onnxRuntimeProvidersLib) > 0 && onnxRuntimeProvidersLibName != "" {
|
||||
providersPath := filepath.Join(cacheDir, onnxRuntimeProvidersLibName)
|
||||
// #nosec G306 -- Shared library needs executable permission (0755) for dynamic linker
|
||||
if err := os.WriteFile(providersPath, onnxRuntimeProvidersLib, 0755); err != nil {
|
||||
return "", fmt.Errorf("write providers library: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return cacheDir, nil
|
||||
}
|
||||
|
||||
// Embed generates an embedding for a single text.
|
||||
// Returns a 384-dimensional float32 vector.
|
||||
func (s *Service) Embed(text string) ([]float32, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if text == "" {
|
||||
return make([]float32, EmbeddingDim), nil
|
||||
}
|
||||
|
||||
results, err := s.computeBatch([]string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return make([]float32, EmbeddingDim), nil
|
||||
}
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
// EmbedBatch generates embeddings for multiple texts.
|
||||
// Returns slice of 384-dimensional float32 vectors.
|
||||
func (s *Service) EmbedBatch(texts []string) ([][]float32, error) {
|
||||
if len(texts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Filter out empty texts and track indices
|
||||
nonEmpty := make([]string, 0, len(texts))
|
||||
indices := make([]int, 0, len(texts))
|
||||
for i, t := range texts {
|
||||
if t != "" {
|
||||
nonEmpty = append(nonEmpty, t)
|
||||
indices = append(indices, i)
|
||||
}
|
||||
}
|
||||
|
||||
// If all texts are empty, return zero vectors
|
||||
if len(nonEmpty) == 0 {
|
||||
results := make([][]float32, len(texts))
|
||||
for i := range results {
|
||||
results[i] = make([]float32, EmbeddingDim)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Compute embeddings for non-empty texts
|
||||
embeddings, err := s.computeBatch(nonEmpty)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compute batch embeddings: %w", err)
|
||||
}
|
||||
|
||||
// Build result with zero vectors for empty texts
|
||||
results := make([][]float32, len(texts))
|
||||
for i := range results {
|
||||
results[i] = make([]float32, EmbeddingDim)
|
||||
}
|
||||
for i, idx := range indices {
|
||||
results[idx] = embeddings[i]
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// computeBatch runs inference on a batch of texts. Must be called with lock held.
|
||||
func (s *Service) computeBatch(sentences []string) ([][]float32, error) {
|
||||
if len(sentences) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Tokenize all sentences
|
||||
inputBatch := make([]tokenizer.EncodeInput, len(sentences))
|
||||
for i, sent := range sentences {
|
||||
inputBatch[i] = tokenizer.NewSingleEncodeInput(tokenizer.NewRawInputSequence(sent))
|
||||
}
|
||||
|
||||
encodings, err := s.tk.EncodeBatch(inputBatch, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tokenize: %w", err)
|
||||
}
|
||||
|
||||
batchSize := len(encodings)
|
||||
seqLength := len(encodings[0].Ids)
|
||||
hiddenSize := EmbeddingDim
|
||||
|
||||
inputShape := ort.NewShape(int64(batchSize), int64(seqLength))
|
||||
|
||||
// Create input tensors
|
||||
inputIdsData := make([]int64, batchSize*seqLength)
|
||||
attentionMaskData := make([]int64, batchSize*seqLength)
|
||||
tokenTypeIdsData := make([]int64, batchSize*seqLength)
|
||||
|
||||
for b := 0; b < batchSize; b++ {
|
||||
for i, id := range encodings[b].Ids {
|
||||
inputIdsData[b*seqLength+i] = int64(id)
|
||||
}
|
||||
for i, mask := range encodings[b].AttentionMask {
|
||||
attentionMaskData[b*seqLength+i] = int64(mask)
|
||||
}
|
||||
for i, typeId := range encodings[b].TypeIds {
|
||||
tokenTypeIdsData[b*seqLength+i] = int64(typeId)
|
||||
}
|
||||
}
|
||||
|
||||
inputIdsTensor, err := ort.NewTensor(inputShape, inputIdsData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create input_ids tensor: %w", err)
|
||||
}
|
||||
defer inputIdsTensor.Destroy()
|
||||
|
||||
attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create attention_mask tensor: %w", err)
|
||||
}
|
||||
defer attentionMaskTensor.Destroy()
|
||||
|
||||
tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create token_type_ids tensor: %w", err)
|
||||
}
|
||||
defer tokenTypeIdsTensor.Destroy()
|
||||
|
||||
sentenceOutputShape := ort.NewShape(int64(batchSize), int64(hiddenSize))
|
||||
sentenceOutputTensor, err := ort.NewEmptyTensor[float32](sentenceOutputShape)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create output tensor: %w", err)
|
||||
}
|
||||
defer sentenceOutputTensor.Destroy()
|
||||
|
||||
// Run inference
|
||||
inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor}
|
||||
outputTensors := []ort.Value{sentenceOutputTensor}
|
||||
|
||||
if err := s.session.Run(inputTensors, outputTensors); err != nil {
|
||||
return nil, fmt.Errorf("run inference: %w", err)
|
||||
}
|
||||
|
||||
// Extract results
|
||||
flatOutput := sentenceOutputTensor.GetData()
|
||||
expectedSize := batchSize * hiddenSize
|
||||
if len(flatOutput) != expectedSize {
|
||||
return nil, fmt.Errorf("unexpected output size: got %d, expected %d", len(flatOutput), expectedSize)
|
||||
}
|
||||
|
||||
results := make([][]float32, batchSize)
|
||||
for i := 0; i < batchSize; i++ {
|
||||
start := i * hiddenSize
|
||||
end := start + hiddenSize
|
||||
results[i] = make([]float32, hiddenSize)
|
||||
copy(results[i], flatOutput[start:end])
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Close releases model resources.
|
||||
func (s *Service) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var errs []error
|
||||
|
||||
if s.session != nil {
|
||||
if err := s.session.Destroy(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("destroy session: %w", err))
|
||||
}
|
||||
s.session = nil
|
||||
}
|
||||
|
||||
if err := ort.DestroyEnvironment(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("destroy environment: %w", err))
|
||||
}
|
||||
|
||||
// Optionally clean up extracted library (leave for caching)
|
||||
// os.RemoveAll(s.libDir)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errs[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+16
-16
@@ -6,16 +6,16 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/chroma"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
)
|
||||
|
||||
// Manager provides unified search across SQLite and ChromaDB.
|
||||
// Manager provides unified search across SQLite and sqlite-vec.
|
||||
type Manager struct {
|
||||
observationStore *sqlite.ObservationStore
|
||||
summaryStore *sqlite.SummaryStore
|
||||
promptStore *sqlite.PromptStore
|
||||
chromaClient *chroma.Client
|
||||
vectorClient *sqlitevec.Client
|
||||
}
|
||||
|
||||
// NewManager creates a new search manager.
|
||||
@@ -23,13 +23,13 @@ func NewManager(
|
||||
observationStore *sqlite.ObservationStore,
|
||||
summaryStore *sqlite.SummaryStore,
|
||||
promptStore *sqlite.PromptStore,
|
||||
chromaClient *chroma.Client,
|
||||
vectorClient *sqlitevec.Client,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
observationStore: observationStore,
|
||||
summaryStore: summaryStore,
|
||||
promptStore: promptStore,
|
||||
chromaClient: chromaClient,
|
||||
vectorClient: vectorClient,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,8 +83,8 @@ func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*Unif
|
||||
params.OrderBy = "date_desc"
|
||||
}
|
||||
|
||||
// If query is provided and Chroma is available, use vector search
|
||||
if params.Query != "" && m.chromaClient != nil {
|
||||
// If query is provided and vector client is available, use vector search
|
||||
if params.Query != "" && m.vectorClient != nil && m.vectorClient.IsConnected() {
|
||||
return m.vectorSearch(ctx, params)
|
||||
}
|
||||
|
||||
@@ -92,29 +92,29 @@ func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*Unif
|
||||
return m.filterSearch(ctx, params)
|
||||
}
|
||||
|
||||
// vectorSearch performs semantic search via ChromaDB.
|
||||
// vectorSearch performs semantic search via sqlite-vec.
|
||||
func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
|
||||
// Build where filter based on search type
|
||||
var docType chroma.DocType
|
||||
var docType sqlitevec.DocType
|
||||
switch params.Type {
|
||||
case "observations":
|
||||
docType = chroma.DocTypeObservation
|
||||
docType = sqlitevec.DocTypeObservation
|
||||
case "sessions":
|
||||
docType = chroma.DocTypeSessionSummary
|
||||
docType = sqlitevec.DocTypeSessionSummary
|
||||
case "prompts":
|
||||
docType = chroma.DocTypeUserPrompt
|
||||
docType = sqlitevec.DocTypeUserPrompt
|
||||
}
|
||||
where := chroma.BuildWhereFilter(docType, params.Project)
|
||||
where := sqlitevec.BuildWhereFilter(docType, params.Project)
|
||||
|
||||
// Query ChromaDB
|
||||
chromaResults, err := m.chromaClient.Query(ctx, params.Query, params.Limit*2, where)
|
||||
// Query sqlite-vec
|
||||
vectorResults, err := m.vectorClient.Query(ctx, params.Query, params.Limit*2, where)
|
||||
if err != nil {
|
||||
// Fall back to filter search on error
|
||||
return m.filterSearch(ctx, params)
|
||||
}
|
||||
|
||||
// Extract IDs grouped by document type using shared helper
|
||||
extracted := chroma.ExtractIDsByDocType(chromaResults)
|
||||
extracted := sqlitevec.ExtractIDsByDocType(vectorResults)
|
||||
obsIDs := extracted.ObservationIDs
|
||||
summaryIDs := extracted.SummaryIDs
|
||||
promptIDs := extracted.PromptIDs
|
||||
|
||||
@@ -1,521 +0,0 @@
|
||||
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
|
||||
package chroma
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Document is one record to store in ChromaDB: an identifier, the text to
// index and free-form metadata. The JSON tags fix the key names used when a
// Document is marshaled.
type Document struct {
	ID       string         `json:"id"`       // document identifier
	Content  string         `json:"document"` // text content, marshaled under "document"
	Metadata map[string]any `json:"metadata"` // arbitrary key/value metadata
}
|
||||
|
||||
// QueryResult is a single hit returned by a ChromaDB query.
type QueryResult struct {
	ID       string         // identifier of the matched document
	Distance float64        // distance reported by ChromaDB for this hit
	Metadata map[string]any // metadata stored alongside the matched document
}
|
||||
|
||||
// Client talks to a ChromaDB MCP server that it launches as a child process,
// exchanging newline-delimited JSON-RPC messages over the child's
// stdin/stdout pipes.
type Client struct {
	// Settings fixed at construction time.
	collection string // collection name used in every tool call
	dataDir    string // directory passed to the server as --data-dir
	pythonVer  string // Python version passed to uvx via --python
	batchSize  int    // number of documents sent per add request

	// Child process and its pipes; populated by Connect.
	cmd    *exec.Cmd
	stdin  io.WriteCloser
	stdout *bufio.Reader
	mu     sync.Mutex // taken by the exported methods around each request/response exchange

	connected bool // true between a successful Connect and Close
	requestID int  // last JSON-RPC request id handed out by nextID
}
|
||||
|
||||
// Config carries the settings used to build a ChromaDB Client.
type Config struct {
	Project   string // project name; used to derive the collection name
	DataDir   string // directory for the persistent vector store
	PythonVer string // Python version handed to uvx
	BatchSize int    // documents per add request
}
|
||||
|
||||
// NewClient creates a new ChromaDB client.
|
||||
func NewClient(cfg Config) (*Client, error) {
|
||||
if cfg.DataDir == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
cfg.DataDir = filepath.Join(home, ".claude-mnemonic", "vector-db")
|
||||
}
|
||||
if cfg.PythonVer == "" {
|
||||
cfg.PythonVer = "3.13"
|
||||
}
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 100
|
||||
}
|
||||
|
||||
return &Client{
|
||||
collection: fmt.Sprintf("cm__%s", cfg.Project),
|
||||
dataDir: cfg.DataDir,
|
||||
pythonVer: cfg.PythonVer,
|
||||
batchSize: cfg.BatchSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connect starts the ChromaDB MCP server and establishes connection.
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure data directory exists
|
||||
if err := os.MkdirAll(c.dataDir, 0750); err != nil {
|
||||
return fmt.Errorf("create data dir: %w", err)
|
||||
}
|
||||
|
||||
// Start chroma-mcp server via uvx
|
||||
c.cmd = exec.CommandContext(ctx, "uvx", // #nosec G204 -- config values from internal settings
|
||||
"--python", c.pythonVer,
|
||||
"chroma-mcp",
|
||||
"--client-type", "persistent",
|
||||
"--data-dir", c.dataDir,
|
||||
)
|
||||
|
||||
var err error
|
||||
c.stdin, err = c.cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := c.cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
c.stdout = bufio.NewReader(stdout)
|
||||
|
||||
c.cmd.Stderr = os.Stderr
|
||||
|
||||
if err := c.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start chroma-mcp: %w", err)
|
||||
}
|
||||
|
||||
// Send initialize request
|
||||
if err := c.sendInitialize(); err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
}
|
||||
|
||||
c.connected = true
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Str("dataDir", c.dataDir).
|
||||
Msg("Connected to ChromaDB")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendInitialize sends the MCP initialize request.
|
||||
func (c *Client) sendInitialize() error {
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "initialize",
|
||||
"params": map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{
|
||||
"name": "claude-mnemonic",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read response
|
||||
_, err := c.readResponse()
|
||||
return err
|
||||
}
|
||||
|
||||
// EnsureCollection ensures the collection exists, creating it if needed.
|
||||
func (c *Client) EnsureCollection(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Try to get collection info
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_get_collection_info",
|
||||
"arguments": map[string]any{
|
||||
"collection_name": c.collection,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
// Collection doesn't exist, create it
|
||||
return c.createCollection()
|
||||
}
|
||||
|
||||
// Check if error in response (collection not found)
|
||||
if _, ok := resp["error"]; ok {
|
||||
return c.createCollection()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createCollection creates a new collection.
|
||||
func (c *Client) createCollection() error {
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_create_collection",
|
||||
"arguments": map[string]any{
|
||||
"collection_name": c.collection,
|
||||
"embedding_function_name": "default",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := c.readResponse()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create collection: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Msg("Created ChromaDB collection")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDocuments adds documents to the collection in batches.
|
||||
func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
for i := 0; i < len(docs); i += c.batchSize {
|
||||
end := i + c.batchSize
|
||||
if end > len(docs) {
|
||||
end = len(docs)
|
||||
}
|
||||
batch := docs[i:end]
|
||||
|
||||
// Extract fields
|
||||
documents := make([]string, len(batch))
|
||||
ids := make([]string, len(batch))
|
||||
metadatas := make([]map[string]any, len(batch))
|
||||
for j, doc := range batch {
|
||||
documents[j] = doc.Content
|
||||
ids[j] = doc.ID
|
||||
metadatas[j] = doc.Metadata
|
||||
}
|
||||
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_add_documents",
|
||||
"arguments": map[string]any{
|
||||
"collection_name": c.collection,
|
||||
"documents": documents,
|
||||
"ids": ids,
|
||||
"metadatas": metadatas,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return fmt.Errorf("send add_documents: %w", err)
|
||||
}
|
||||
|
||||
if _, err := c.readResponse(); err != nil {
|
||||
return fmt.Errorf("add_documents response: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("batchStart", i).
|
||||
Int("batchEnd", end).
|
||||
Int("total", len(docs)).
|
||||
Msg("Added document batch")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteDocuments deletes documents from the collection by their IDs.
|
||||
func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_delete_documents",
|
||||
"arguments": map[string]any{
|
||||
"collection_name": c.collection,
|
||||
"ids": ids,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return fmt.Errorf("send delete_documents: %w", err)
|
||||
}
|
||||
|
||||
if _, err := c.readResponse(); err != nil {
|
||||
return fmt.Errorf("delete_documents response: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("count", len(ids)).
|
||||
Msg("Deleted documents from ChromaDB")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query performs a semantic search on the collection.
|
||||
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
args := map[string]any{
|
||||
"collection_name": c.collection,
|
||||
"query_texts": []string{query},
|
||||
"n_results": limit,
|
||||
"include": []string{"documents", "metadatas", "distances"},
|
||||
}
|
||||
if where != nil {
|
||||
args["where"] = where
|
||||
}
|
||||
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_query_documents",
|
||||
"arguments": args,
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return nil, fmt.Errorf("send query: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query response: %w", err)
|
||||
}
|
||||
|
||||
return c.parseQueryResults(resp)
|
||||
}
|
||||
|
||||
// parseQueryResults parses the query response into QueryResult structs.
|
||||
func (c *Client) parseQueryResults(resp map[string]any) ([]QueryResult, error) {
|
||||
result, ok := resp["result"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
content, ok := result["content"].([]any)
|
||||
if !ok || len(content) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
first, ok := content[0].(map[string]any)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
text, ok := first["text"].(string)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
IDs [][]string `json:"ids"`
|
||||
Distances [][]float64 `json:"distances"`
|
||||
Metadatas [][]map[string]any `json:"metadatas"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(text), &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(parsed.IDs) == 0 || len(parsed.IDs[0]) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
results := make([]QueryResult, len(parsed.IDs[0]))
|
||||
for i := range parsed.IDs[0] {
|
||||
results[i] = QueryResult{
|
||||
ID: parsed.IDs[0][i],
|
||||
}
|
||||
if i < len(parsed.Distances[0]) {
|
||||
results[i].Distance = parsed.Distances[0][i]
|
||||
}
|
||||
if i < len(parsed.Metadatas[0]) {
|
||||
results[i].Metadata = parsed.Metadatas[0][i]
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// send sends a JSON-RPC request to the MCP server.
|
||||
func (c *Client) send(req map[string]any) error {
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = append(data, '\n')
|
||||
_, err = c.stdin.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// readResponse reads a JSON-RPC response from the MCP server.
|
||||
func (c *Client) readResponse() (map[string]any, error) {
|
||||
line, err := c.stdout.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if errObj, ok := resp["error"]; ok {
|
||||
return nil, fmt.Errorf("MCP error: %v", errObj)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// nextID returns the next request ID.
|
||||
func (c *Client) nextID() int {
|
||||
c.requestID++
|
||||
return c.requestID
|
||||
}
|
||||
|
||||
// IsConnected returns whether the client is currently connected to ChromaDB.
|
||||
func (c *Client) IsConnected() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.connected
|
||||
}
|
||||
|
||||
// Close closes the connection to ChromaDB.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.connected = false
|
||||
|
||||
if c.stdin != nil {
|
||||
_ = c.stdin.Close()
|
||||
}
|
||||
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
_ = c.cmd.Process.Kill()
|
||||
_ = c.cmd.Wait()
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Msg("ChromaDB connection closed")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reconnect closes the existing connection and establishes a new one.
|
||||
// This is useful when the vector database directory has been deleted and recreated.
|
||||
func (c *Client) Reconnect(ctx context.Context) error {
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Msg("Reconnecting to ChromaDB...")
|
||||
|
||||
// Close existing connection
|
||||
if err := c.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("Error closing ChromaDB during reconnect")
|
||||
}
|
||||
|
||||
// Small delay to allow cleanup
|
||||
// (ChromaDB may need a moment to release resources)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Reconnect
|
||||
if err := c.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("reconnect failed: %w", err)
|
||||
}
|
||||
|
||||
// Ensure collection exists
|
||||
if err := c.EnsureCollection(ctx); err != nil {
|
||||
return fmt.Errorf("ensure collection after reconnect: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Msg("ChromaDB reconnected successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,350 +0,0 @@
|
||||
package chroma
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// testSync creates a Sync with a nil client for testing format functions.
|
||||
func testSync() *Sync {
|
||||
return &Sync{client: nil}
|
||||
}
|
||||
|
||||
func TestSync_FormatObservationDocs(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Scope: models.ScopeProject,
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: sql.NullString{String: "Test Title", Valid: true},
|
||||
Subtitle: sql.NullString{String: "Test Subtitle", Valid: true},
|
||||
Narrative: sql.NullString{String: "Test narrative content", Valid: true},
|
||||
Facts: models.JSONStringArray{"Fact 1", "Fact 2", "Fact 3"},
|
||||
Concepts: models.JSONStringArray{"concept1", "concept2"},
|
||||
FilesRead: models.JSONStringArray{"file1.go", "file2.go"},
|
||||
FilesModified: models.JSONStringArray{"file3.go"},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
|
||||
// Should have 1 narrative + 3 facts = 4 documents
|
||||
assert.Len(t, docs, 4)
|
||||
|
||||
// Check narrative document
|
||||
narrativeDoc := docs[0]
|
||||
assert.Equal(t, "obs_1_narrative", narrativeDoc.ID)
|
||||
assert.Equal(t, "Test narrative content", narrativeDoc.Content)
|
||||
assert.Equal(t, int64(1), narrativeDoc.Metadata["sqlite_id"])
|
||||
assert.Equal(t, "observation", narrativeDoc.Metadata["doc_type"])
|
||||
assert.Equal(t, "narrative", narrativeDoc.Metadata["field_type"])
|
||||
assert.Equal(t, "test-project", narrativeDoc.Metadata["project"])
|
||||
assert.Equal(t, "project", narrativeDoc.Metadata["scope"])
|
||||
assert.Equal(t, "Test Title", narrativeDoc.Metadata["title"])
|
||||
assert.Equal(t, "Test Subtitle", narrativeDoc.Metadata["subtitle"])
|
||||
|
||||
// Check fact documents
|
||||
for i := 1; i <= 3; i++ {
|
||||
factDoc := docs[i]
|
||||
assert.Equal(t, fmt.Sprintf("obs_1_fact_%d", i-1), factDoc.ID)
|
||||
assert.Equal(t, fmt.Sprintf("Fact %d", i), factDoc.Content)
|
||||
assert.Equal(t, "fact", factDoc.Metadata["field_type"])
|
||||
assert.Equal(t, i-1, factDoc.Metadata["fact_index"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSync_FormatObservationDocs_NoNarrative(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 2,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Scope: models.ScopeGlobal,
|
||||
Type: models.ObsTypeBugfix,
|
||||
Facts: models.JSONStringArray{"Only fact"},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
|
||||
// Should have 1 fact only (no narrative)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "obs_2_fact_0", docs[0].ID)
|
||||
assert.Equal(t, "Only fact", docs[0].Content)
|
||||
assert.Equal(t, "global", docs[0].Metadata["scope"])
|
||||
}
|
||||
|
||||
func TestSync_FormatObservationDocs_Empty(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 3,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Type: models.ObsTypeDiscovery,
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
|
||||
// Should have no documents when no content
|
||||
assert.Len(t, docs, 0)
|
||||
}
|
||||
|
||||
func TestSync_FormatObservationDocs_EmptyScope(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 4,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Scope: "", // Empty scope
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Narrative: sql.NullString{String: "Content", Valid: true},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
|
||||
// Empty scope should default to "project"
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "project", docs[0].Metadata["scope"])
|
||||
}
|
||||
|
||||
func TestSync_FormatSummaryDocs(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 1,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Request: sql.NullString{String: "Add feature", Valid: true},
|
||||
Investigated: sql.NullString{String: "Looked at code", Valid: true},
|
||||
Learned: sql.NullString{String: "Found pattern", Valid: true},
|
||||
Completed: sql.NullString{String: "Done", Valid: true},
|
||||
NextSteps: sql.NullString{String: "Test it", Valid: true},
|
||||
Notes: sql.NullString{String: "Notes here", Valid: true},
|
||||
PromptNumber: sql.NullInt64{Int64: 5, Valid: true},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatSummaryDocs(summary)
|
||||
|
||||
// Should have 6 documents (all fields present)
|
||||
assert.Len(t, docs, 6)
|
||||
|
||||
// Check first document
|
||||
assert.Equal(t, "summary_1_request", docs[0].ID)
|
||||
assert.Equal(t, "Add feature", docs[0].Content)
|
||||
assert.Equal(t, "session_summary", docs[0].Metadata["doc_type"])
|
||||
assert.Equal(t, "request", docs[0].Metadata["field_type"])
|
||||
assert.Equal(t, int64(5), docs[0].Metadata["prompt_number"])
|
||||
}
|
||||
|
||||
func TestSync_FormatSummaryDocs_PartialFields(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 2,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Request: sql.NullString{String: "Only request", Valid: true},
|
||||
Completed: sql.NullString{String: "Only completed", Valid: true},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatSummaryDocs(summary)
|
||||
|
||||
// Should have 2 documents (only valid fields)
|
||||
assert.Len(t, docs, 2)
|
||||
|
||||
// Verify field types
|
||||
fieldTypes := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
fieldTypes[i] = doc.Metadata["field_type"].(string)
|
||||
}
|
||||
assert.Contains(t, fieldTypes, "request")
|
||||
assert.Contains(t, fieldTypes, "completed")
|
||||
}
|
||||
|
||||
func TestSync_FormatSummaryDocs_Empty(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 3,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatSummaryDocs(summary)
|
||||
|
||||
// Should have no documents when no content
|
||||
assert.Len(t, docs, 0)
|
||||
}
|
||||
|
||||
func TestSync_FormatSummaryDocs_EmptyStrings(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 4,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Request: sql.NullString{String: "", Valid: true}, // Valid but empty
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatSummaryDocs(summary)
|
||||
|
||||
// Empty strings should not produce documents
|
||||
assert.Len(t, docs, 0)
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func TestJoinStrings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
strs []string
|
||||
sep string
|
||||
expected string
|
||||
}{
|
||||
{"empty", []string{}, ",", ""},
|
||||
{"single", []string{"a"}, ",", "a"},
|
||||
{"multiple", []string{"a", "b", "c"}, ",", "a,b,c"},
|
||||
{"different sep", []string{"a", "b"}, "-", "a-b"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := joinStrings(tt.strs, tt.sep)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyMetadata(t *testing.T) {
|
||||
base := map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
}
|
||||
|
||||
result := copyMetadata(base, "key3", "value3")
|
||||
|
||||
// Original should be unchanged
|
||||
assert.Len(t, base, 2)
|
||||
|
||||
// Result should have all keys
|
||||
assert.Len(t, result, 3)
|
||||
assert.Equal(t, "value1", result["key1"])
|
||||
assert.Equal(t, 42, result["key2"])
|
||||
assert.Equal(t, "value3", result["key3"])
|
||||
}
|
||||
|
||||
func TestCopyMetadataMulti(t *testing.T) {
|
||||
base := map[string]any{
|
||||
"key1": "value1",
|
||||
}
|
||||
extra := map[string]any{
|
||||
"key2": "value2",
|
||||
"key3": "value3",
|
||||
}
|
||||
|
||||
result := copyMetadataMulti(base, extra)
|
||||
|
||||
// Original should be unchanged
|
||||
assert.Len(t, base, 1)
|
||||
|
||||
// Result should have all keys
|
||||
assert.Len(t, result, 3)
|
||||
assert.Equal(t, "value1", result["key1"])
|
||||
assert.Equal(t, "value2", result["key2"])
|
||||
assert.Equal(t, "value3", result["key3"])
|
||||
}
|
||||
|
||||
// Test ID generation patterns for delete operations
|
||||
func TestSync_DeleteObservationIDGeneration(t *testing.T) {
|
||||
// Test that we generate correct document IDs for deletion
|
||||
obsIDs := []int64{1, 2}
|
||||
maxFactsPerObs := 20
|
||||
|
||||
ids := make([]string, 0, len(obsIDs)*(maxFactsPerObs+1))
|
||||
for _, obsID := range obsIDs {
|
||||
ids = append(ids, fmt.Sprintf("obs_%d_narrative", obsID))
|
||||
for i := 0; i < maxFactsPerObs; i++ {
|
||||
ids = append(ids, fmt.Sprintf("obs_%d_fact_%d", obsID, i))
|
||||
}
|
||||
}
|
||||
|
||||
// Each observation should generate 21 IDs (1 narrative + 20 facts)
|
||||
assert.Len(t, ids, 42)
|
||||
|
||||
// Check some expected IDs
|
||||
assert.Contains(t, ids, "obs_1_narrative")
|
||||
assert.Contains(t, ids, "obs_1_fact_0")
|
||||
assert.Contains(t, ids, "obs_1_fact_19")
|
||||
assert.Contains(t, ids, "obs_2_narrative")
|
||||
assert.Contains(t, ids, "obs_2_fact_0")
|
||||
}
|
||||
|
||||
func TestSync_DeletePromptIDGeneration(t *testing.T) {
|
||||
// Test that we generate correct document IDs for prompt deletion
|
||||
promptIDs := []int64{10, 20, 30}
|
||||
|
||||
ids := make([]string, len(promptIDs))
|
||||
for i, promptID := range promptIDs {
|
||||
ids[i] = fmt.Sprintf("prompt_%d", promptID)
|
||||
}
|
||||
|
||||
assert.Len(t, ids, 3)
|
||||
assert.Contains(t, ids, "prompt_10")
|
||||
assert.Contains(t, ids, "prompt_20")
|
||||
assert.Contains(t, ids, "prompt_30")
|
||||
}
|
||||
|
||||
// Test metadata includes all expected fields
|
||||
func TestSync_ObservationMetadataFields(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
SDKSessionID: "sdk-123",
|
||||
Project: "my-project",
|
||||
Scope: models.ScopeGlobal,
|
||||
Type: models.ObsTypeBugfix,
|
||||
Title: sql.NullString{String: "Bug Fix", Valid: true},
|
||||
Subtitle: sql.NullString{String: "Memory leak", Valid: true},
|
||||
Narrative: sql.NullString{String: "Fixed the leak", Valid: true},
|
||||
Concepts: models.JSONStringArray{"memory", "performance"},
|
||||
FilesRead: models.JSONStringArray{"main.go"},
|
||||
FilesModified: models.JSONStringArray{"fix.go"},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
require := assert.New(t)
|
||||
|
||||
require.Len(docs, 1) // Only narrative, no facts
|
||||
|
||||
meta := docs[0].Metadata
|
||||
require.Equal(int64(1), meta["sqlite_id"])
|
||||
require.Equal("observation", meta["doc_type"])
|
||||
require.Equal("sdk-123", meta["sdk_session_id"])
|
||||
require.Equal("my-project", meta["project"])
|
||||
require.Equal("global", meta["scope"])
|
||||
require.Equal("bugfix", meta["type"])
|
||||
require.Equal("Bug Fix", meta["title"])
|
||||
require.Equal("Memory leak", meta["subtitle"])
|
||||
require.Equal("memory,performance", meta["concepts"])
|
||||
require.Equal("main.go", meta["files_read"])
|
||||
require.Equal("fix.go", meta["files_modified"])
|
||||
require.Equal(int64(1234567890), meta["created_at_epoch"])
|
||||
require.Equal("narrative", meta["field_type"])
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
// Package sqlitevec provides sqlite-vec based vector database integration for claude-mnemonic.
|
||||
package sqlitevec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Client provides vector operations via sqlite-vec.
|
||||
type Client struct {
|
||||
db *sql.DB
|
||||
embedSvc *embedding.Service
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Config carries the settings needed to build a sqlite-vec Client.
type Config struct {
	DB *sql.DB // open database handle; NewClient rejects a nil value
}
|
||||
|
||||
// NewClient creates a new sqlite-vec client.
|
||||
func NewClient(cfg Config, embedSvc *embedding.Service) (*Client, error) {
|
||||
if cfg.DB == nil {
|
||||
return nil, fmt.Errorf("database connection required")
|
||||
}
|
||||
if embedSvc == nil {
|
||||
return nil, fmt.Errorf("embedding service required")
|
||||
}
|
||||
|
||||
return &Client{
|
||||
db: cfg.DB,
|
||||
embedSvc: embedSvc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddDocuments adds documents with their embeddings to the vector store.
|
||||
func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
if len(docs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Generate embeddings for all documents
|
||||
texts := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
texts[i] = doc.Content
|
||||
}
|
||||
|
||||
embeddings, err := c.embedSvc.EmbedBatch(texts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate embeddings: %w", err)
|
||||
}
|
||||
|
||||
// Insert into vectors table
|
||||
const insertQuery = `
|
||||
INSERT OR REPLACE INTO vectors (doc_id, embedding, sqlite_id, doc_type, field_type, project, scope)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, insertQuery)
|
||||
if err != nil {
|
||||
return fmt.Errorf("prepare statement: %w", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for i, doc := range docs {
|
||||
// Serialize embedding to blob format
|
||||
embBlob, err := sqlite_vec.SerializeFloat32(embeddings[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize embedding for %s: %w", doc.ID, err)
|
||||
}
|
||||
|
||||
// Extract metadata
|
||||
sqliteID, _ := doc.Metadata["sqlite_id"].(int64)
|
||||
docType, _ := doc.Metadata["doc_type"].(string)
|
||||
fieldType, _ := doc.Metadata["field_type"].(string)
|
||||
project, _ := doc.Metadata["project"].(string)
|
||||
scope, _ := doc.Metadata["scope"].(string)
|
||||
|
||||
_, err = stmt.ExecContext(ctx,
|
||||
doc.ID,
|
||||
embBlob,
|
||||
sqliteID,
|
||||
docType,
|
||||
fieldType,
|
||||
project,
|
||||
scope,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert document %s: %w", doc.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("count", len(docs)).Msg("Added documents to sqlite-vec")
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteDocuments removes documents by their IDs.
|
||||
func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Build placeholder string
|
||||
placeholders := make([]string, len(ids))
|
||||
args := make([]interface{}, len(ids))
|
||||
for i, id := range ids {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
// #nosec G201 -- Placeholders are "?" strings, actual values are parameterized via args
|
||||
query := fmt.Sprintf("DELETE FROM vectors WHERE doc_id IN (%s)",
|
||||
strings.Join(placeholders, ","))
|
||||
|
||||
_, err := c.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete documents: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("count", len(ids)).Msg("Deleted documents from sqlite-vec")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query performs a vector similarity search.
|
||||
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Generate query embedding
|
||||
queryEmb, err := c.embedSvc.Embed(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
|
||||
// Serialize query embedding
|
||||
queryBlob, err := sqlite_vec.SerializeFloat32(queryEmb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||
}
|
||||
|
||||
// Build query with filters
|
||||
// vec0 supports WHERE clauses on metadata columns
|
||||
args := []interface{}{queryBlob}
|
||||
|
||||
sqlQuery := `
|
||||
SELECT
|
||||
doc_id,
|
||||
distance,
|
||||
sqlite_id,
|
||||
doc_type,
|
||||
field_type,
|
||||
project,
|
||||
scope
|
||||
FROM vectors
|
||||
WHERE embedding MATCH ?
|
||||
`
|
||||
|
||||
// Add filters - these work with vec0 metadata columns
|
||||
if docType, ok := where["doc_type"].(string); ok && docType != "" {
|
||||
sqlQuery += " AND doc_type = ?"
|
||||
args = append(args, docType)
|
||||
}
|
||||
if project, ok := where["project"].(string); ok && project != "" {
|
||||
// Include project-specific OR global scope
|
||||
sqlQuery += " AND (project = ? OR scope = 'global')"
|
||||
args = append(args, project)
|
||||
}
|
||||
|
||||
sqlQuery += " ORDER BY distance LIMIT ?"
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query vectors: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []QueryResult
|
||||
for rows.Next() {
|
||||
var r QueryResult
|
||||
var sqliteID int64
|
||||
var docType, fieldType, project, scope sql.NullString
|
||||
|
||||
if err := rows.Scan(&r.ID, &r.Distance, &sqliteID, &docType, &fieldType, &project, &scope); err != nil {
|
||||
return nil, fmt.Errorf("scan row: %w", err)
|
||||
}
|
||||
|
||||
r.Metadata = map[string]any{
|
||||
"sqlite_id": float64(sqliteID), // Keep as float64 for compatibility
|
||||
"doc_type": docType.String,
|
||||
"field_type": fieldType.String,
|
||||
"project": project.String,
|
||||
"scope": scope.String,
|
||||
}
|
||||
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate rows: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("query", truncateString(query, 50)).
|
||||
Int("results", len(results)).
|
||||
Msg("Vector search completed")
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// IsConnected reports whether the client holds a database handle.
// There is no external process to connect to, so a non-nil db is the only
// condition checked; it returns false when db is nil.
func (c *Client) IsConnected() bool {
	return c.db != nil
}
|
||||
|
||||
// Close does nothing and always returns nil.
// The database handle is managed externally, so it is not closed here.
func (c *Client) Close() error {
	return nil
}
|
||||
|
||||
// truncateString shortens s to at most maxLen bytes and appends "..." when
// anything was cut off. Strings that already fit are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 character: if byte maxLen is a
// continuation byte, the cut moves back to the start of that character, so the
// result stays valid UTF-8 whenever s is. A negative maxLen is treated as 0
// instead of panicking on the slice expression.
func truncateString(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// s[cut] is the first byte that would be dropped. UTF-8 continuation bytes
	// have the bit pattern 10xxxxxx; a character has at most three of them, so
	// stepping back at most three times reaches the lead byte.
	cut := maxLen
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "..."
}
|
||||
@@ -1,7 +1,7 @@
|
||||
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
|
||||
package chroma
|
||||
// Package sqlitevec provides sqlite-vec based vector database integration for claude-mnemonic.
|
||||
package sqlitevec
|
||||
|
||||
// DocType represents the type of document stored in ChromaDB.
|
||||
// DocType represents the type of document stored in the vector table.
|
||||
type DocType string
|
||||
|
||||
const (
|
||||
@@ -10,14 +10,28 @@ const (
|
||||
DocTypeUserPrompt DocType = "user_prompt"
|
||||
)
|
||||
|
||||
// ExtractedIDs contains SQLite IDs extracted from ChromaDB results, grouped by document type.
|
||||
// Document represents a document to store with vector embedding.
type Document struct {
	// ID is the document's string identifier (for example "prompt_<id>" for
	// user prompts).
	ID string
	// Content is the document text.
	// NOTE(review): presumably this is the text that gets embedded — confirm
	// against the client's add path.
	Content string
	// Metadata holds descriptive key/value pairs such as doc_type, project,
	// scope and field_type.
	Metadata map[string]any
}
|
||||
|
||||
// QueryResult represents a search result from vector search.
type QueryResult struct {
	// ID is the matched row's doc_id.
	ID string
	// Distance is the distance value reported for the row; results are
	// ordered by it ascending.
	Distance float64
	// Metadata carries the row's sqlite_id (as float64), doc_type,
	// field_type, project and scope.
	Metadata map[string]any
}
|
||||
|
||||
// ExtractedIDs contains SQLite IDs extracted from query results, grouped by document type.
|
||||
type ExtractedIDs struct {
|
||||
ObservationIDs []int64
|
||||
SummaryIDs []int64
|
||||
PromptIDs []int64
|
||||
}
|
||||
|
||||
// BuildWhereFilter creates a where filter map for ChromaDB queries.
|
||||
// BuildWhereFilter creates a where filter map for vector queries.
|
||||
// If docType is empty, no doc_type filter is added.
|
||||
func BuildWhereFilter(docType DocType, project string) map[string]interface{} {
|
||||
where := make(map[string]interface{})
|
||||
@@ -30,7 +44,7 @@ func BuildWhereFilter(docType DocType, project string) map[string]interface{} {
|
||||
return where
|
||||
}
|
||||
|
||||
// ExtractIDsByDocType extracts SQLite IDs from ChromaDB query results,
|
||||
// ExtractIDsByDocType extracts SQLite IDs from query results,
|
||||
// grouped by document type and deduplicated.
|
||||
func ExtractIDsByDocType(results []QueryResult) *ExtractedIDs {
|
||||
ids := &ExtractedIDs{}
|
||||
@@ -41,7 +55,12 @@ func ExtractIDsByDocType(results []QueryResult) *ExtractedIDs {
|
||||
for _, result := range results {
|
||||
sqliteID, ok := result.Metadata["sqlite_id"].(float64)
|
||||
if !ok {
|
||||
continue
|
||||
// Try int64 directly
|
||||
if id, ok := result.Metadata["sqlite_id"].(int64); ok {
|
||||
sqliteID = float64(id)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
id := int64(sqliteID)
|
||||
|
||||
@@ -68,10 +87,8 @@ func ExtractIDsByDocType(results []QueryResult) *ExtractedIDs {
|
||||
return ids
|
||||
}
|
||||
|
||||
// ExtractObservationIDs extracts observation SQLite IDs from ChromaDB query results,
|
||||
// ExtractObservationIDs extracts observation SQLite IDs from query results,
|
||||
// optionally filtering by project or including global scope.
|
||||
// If project is empty, all observation IDs are returned.
|
||||
// If project is set, only observations matching the project or with global scope are returned.
|
||||
func ExtractObservationIDs(results []QueryResult, project string) []int64 {
|
||||
var ids []int64
|
||||
seen := make(map[int64]bool)
|
||||
@@ -79,21 +96,22 @@ func ExtractObservationIDs(results []QueryResult, project string) []int64 {
|
||||
for _, result := range results {
|
||||
sqliteID, ok := result.Metadata["sqlite_id"].(float64)
|
||||
if !ok {
|
||||
continue
|
||||
if id, ok := result.Metadata["sqlite_id"].(int64); ok {
|
||||
sqliteID = float64(id)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
id := int64(sqliteID)
|
||||
|
||||
// Check document type
|
||||
docType, _ := result.Metadata["doc_type"].(string)
|
||||
if docType != string(DocTypeObservation) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply project/scope filter if project is specified
|
||||
if project != "" {
|
||||
proj, _ := result.Metadata["project"].(string)
|
||||
scope, _ := result.Metadata["scope"].(string)
|
||||
// Include if project matches OR scope is global
|
||||
if proj != project && scope != "global" {
|
||||
continue
|
||||
}
|
||||
@@ -108,7 +126,7 @@ func ExtractObservationIDs(results []QueryResult, project string) []int64 {
|
||||
return ids
|
||||
}
|
||||
|
||||
// ExtractSummaryIDs extracts session summary SQLite IDs from ChromaDB query results.
|
||||
// ExtractSummaryIDs extracts session summary SQLite IDs from query results.
|
||||
func ExtractSummaryIDs(results []QueryResult, project string) []int64 {
|
||||
var ids []int64
|
||||
seen := make(map[int64]bool)
|
||||
@@ -116,7 +134,11 @@ func ExtractSummaryIDs(results []QueryResult, project string) []int64 {
|
||||
for _, result := range results {
|
||||
sqliteID, ok := result.Metadata["sqlite_id"].(float64)
|
||||
if !ok {
|
||||
continue
|
||||
if id, ok := result.Metadata["sqlite_id"].(int64); ok {
|
||||
sqliteID = float64(id)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
id := int64(sqliteID)
|
||||
|
||||
@@ -141,7 +163,7 @@ func ExtractSummaryIDs(results []QueryResult, project string) []int64 {
|
||||
return ids
|
||||
}
|
||||
|
||||
// ExtractPromptIDs extracts user prompt SQLite IDs from ChromaDB query results.
|
||||
// ExtractPromptIDs extracts user prompt SQLite IDs from query results.
|
||||
func ExtractPromptIDs(results []QueryResult, project string) []int64 {
|
||||
var ids []int64
|
||||
seen := make(map[int64]bool)
|
||||
@@ -149,7 +171,11 @@ func ExtractPromptIDs(results []QueryResult, project string) []int64 {
|
||||
for _, result := range results {
|
||||
sqliteID, ok := result.Metadata["sqlite_id"].(float64)
|
||||
if !ok {
|
||||
continue
|
||||
if id, ok := result.Metadata["sqlite_id"].(int64); ok {
|
||||
sqliteID = float64(id)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
id := int64(sqliteID)
|
||||
|
||||
@@ -173,3 +199,36 @@ func ExtractPromptIDs(results []QueryResult, project string) []int64 {
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// Helper functions for metadata manipulation
|
||||
|
||||
// copyMetadata returns a fresh map holding every entry of base plus key set to
// value. If key already exists in base, the copy carries the new value. base
// itself is never modified and may be nil.
func copyMetadata(base map[string]any, key string, value any) map[string]any {
	clone := make(map[string]any, 1+len(base))
	for name := range base {
		clone[name] = base[name]
	}
	// Written last so it overrides any entry copied from base.
	clone[key] = value
	return clone
}
|
||||
|
||||
// copyMetadataMulti returns a fresh map containing the entries of base
// overlaid with the entries of extra; where both define a key, extra wins.
// Neither input is modified, and either may be nil.
func copyMetadataMulti(base map[string]any, extra map[string]any) map[string]any {
	merged := make(map[string]any, len(base)+len(extra))
	// Later sources overwrite earlier ones, so extra is applied after base.
	for _, src := range []map[string]any{base, extra} {
		for name, val := range src {
			merged[name] = val
		}
	}
	return merged
}
|
||||
|
||||
// joinStrings concatenates strs, placing sep between consecutive elements.
// It returns "" for an empty or nil slice and the sole element for a slice of
// one.
//
// The output is assembled in a single buffer sized up front, so the cost is
// linear in the total length rather than re-copying the growing result on
// every element.
func joinStrings(strs []string, sep string) string {
	switch len(strs) {
	case 0:
		return ""
	case 1:
		return strs[0]
	}

	// Exact output size: all elements plus one separator per gap.
	size := len(sep) * (len(strs) - 1)
	for _, s := range strs {
		size += len(s)
	}

	buf := make([]byte, 0, size)
	buf = append(buf, strs[0]...)
	for _, s := range strs[1:] {
		buf = append(buf, sep...)
		buf = append(buf, s...)
	}
	return string(buf)
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
|
||||
package chroma
|
||||
// Package sqlitevec provides sqlite-vec based vector database integration for claude-mnemonic.
|
||||
package sqlitevec
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,17 +9,17 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Sync provides synchronization between SQLite and ChromaDB.
|
||||
// Sync provides synchronization between SQLite data and vector embeddings.
|
||||
type Sync struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// NewSync creates a new ChromaDB sync service.
|
||||
// NewSync creates a new sync service.
|
||||
func NewSync(client *Client) *Sync {
|
||||
return &Sync{client: client}
|
||||
}
|
||||
|
||||
// SyncObservation syncs a single observation to ChromaDB.
|
||||
// SyncObservation syncs a single observation to the vector store.
|
||||
func (s *Sync) SyncObservation(ctx context.Context, obs *models.Observation) error {
|
||||
docs := s.formatObservationDocs(obs)
|
||||
if len(docs) == 0 {
|
||||
@@ -33,12 +33,12 @@ func (s *Sync) SyncObservation(ctx context.Context, obs *models.Observation) err
|
||||
log.Debug().
|
||||
Int64("observationId", obs.ID).
|
||||
Int("docCount", len(docs)).
|
||||
Msg("Synced observation to ChromaDB")
|
||||
Msg("Synced observation to sqlite-vec")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatObservationDocs formats an observation into ChromaDB documents.
|
||||
// formatObservationDocs formats an observation into vector documents.
|
||||
// Each semantic field becomes a separate vector document (granular approach).
|
||||
func (s *Sync) formatObservationDocs(obs *models.Observation) []Document {
|
||||
docs := make([]Document, 0, len(obs.Facts)+2)
|
||||
@@ -99,7 +99,7 @@ func (s *Sync) formatObservationDocs(obs *models.Observation) []Document {
|
||||
return docs
|
||||
}
|
||||
|
||||
// SyncSummary syncs a single session summary to ChromaDB.
|
||||
// SyncSummary syncs a single session summary to the vector store.
|
||||
func (s *Sync) SyncSummary(ctx context.Context, summary *models.SessionSummary) error {
|
||||
docs := s.formatSummaryDocs(summary)
|
||||
if len(docs) == 0 {
|
||||
@@ -113,12 +113,12 @@ func (s *Sync) SyncSummary(ctx context.Context, summary *models.SessionSummary)
|
||||
log.Debug().
|
||||
Int64("summaryId", summary.ID).
|
||||
Int("docCount", len(docs)).
|
||||
Msg("Synced summary to ChromaDB")
|
||||
Msg("Synced summary to sqlite-vec")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatSummaryDocs formats a session summary into ChromaDB documents.
|
||||
// formatSummaryDocs formats a session summary into vector documents.
|
||||
func (s *Sync) formatSummaryDocs(summary *models.SessionSummary) []Document {
|
||||
docs := make([]Document, 0, 6)
|
||||
|
||||
@@ -127,6 +127,7 @@ func (s *Sync) formatSummaryDocs(summary *models.SessionSummary) []Document {
|
||||
"doc_type": "session_summary",
|
||||
"sdk_session_id": summary.SDKSessionID,
|
||||
"project": summary.Project,
|
||||
"scope": "", // Summaries don't have scope
|
||||
"created_at_epoch": summary.CreatedAtEpoch,
|
||||
}
|
||||
|
||||
@@ -161,7 +162,7 @@ func (s *Sync) formatSummaryDocs(summary *models.SessionSummary) []Document {
|
||||
return docs
|
||||
}
|
||||
|
||||
// SyncUserPrompt syncs a single user prompt to ChromaDB.
|
||||
// SyncUserPrompt syncs a single user prompt to the vector store.
|
||||
func (s *Sync) SyncUserPrompt(ctx context.Context, prompt *models.UserPromptWithSession) error {
|
||||
doc := Document{
|
||||
ID: fmt.Sprintf("prompt_%d", prompt.ID),
|
||||
@@ -171,8 +172,10 @@ func (s *Sync) SyncUserPrompt(ctx context.Context, prompt *models.UserPromptWith
|
||||
"doc_type": "user_prompt",
|
||||
"sdk_session_id": prompt.SDKSessionID,
|
||||
"project": prompt.Project,
|
||||
"scope": "", // Prompts don't have scope
|
||||
"created_at_epoch": prompt.CreatedAtEpoch,
|
||||
"prompt_number": prompt.PromptNumber,
|
||||
"field_type": "prompt",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -182,14 +185,12 @@ func (s *Sync) SyncUserPrompt(ctx context.Context, prompt *models.UserPromptWith
|
||||
|
||||
log.Debug().
|
||||
Int64("promptId", prompt.ID).
|
||||
Msg("Synced user prompt to ChromaDB")
|
||||
Msg("Synced user prompt to sqlite-vec")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteObservations removes observation documents from ChromaDB.
|
||||
// Since each observation may have multiple documents (narrative + facts),
|
||||
// we delete by the sqlite_id metadata prefix pattern.
|
||||
// DeleteObservations removes observation documents from the vector store.
|
||||
func (s *Sync) DeleteObservations(ctx context.Context, observationIDs []int64) error {
|
||||
if len(observationIDs) == 0 {
|
||||
return nil
|
||||
@@ -197,7 +198,6 @@ func (s *Sync) DeleteObservations(ctx context.Context, observationIDs []int64) e
|
||||
|
||||
// Generate all possible document IDs for these observations
|
||||
// Pattern: obs_{id}_narrative, obs_{id}_fact_{0..n}
|
||||
// Since we don't know how many facts each had, we use a reasonable upper bound
|
||||
const maxFactsPerObs = 20
|
||||
ids := make([]string, 0, len(observationIDs)*(maxFactsPerObs+1))
|
||||
|
||||
@@ -214,18 +214,17 @@ func (s *Sync) DeleteObservations(ctx context.Context, observationIDs []int64) e
|
||||
|
||||
log.Debug().
|
||||
Int("observationCount", len(observationIDs)).
|
||||
Msg("Deleted observations from ChromaDB")
|
||||
Msg("Deleted observations from sqlite-vec")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUserPrompts removes user prompt documents from ChromaDB.
|
||||
// DeleteUserPrompts removes user prompt documents from the vector store.
|
||||
func (s *Sync) DeleteUserPrompts(ctx context.Context, promptIDs []int64) error {
|
||||
if len(promptIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Each prompt is stored as a single document with ID pattern: prompt_{id}
|
||||
ids := make([]string, len(promptIDs))
|
||||
for i, promptID := range promptIDs {
|
||||
ids[i] = fmt.Sprintf("prompt_%d", promptID)
|
||||
@@ -237,40 +236,7 @@ func (s *Sync) DeleteUserPrompts(ctx context.Context, promptIDs []int64) error {
|
||||
|
||||
log.Debug().
|
||||
Int("promptCount", len(promptIDs)).
|
||||
Msg("Deleted user prompts from ChromaDB")
|
||||
Msg("Deleted user prompts from sqlite-vec")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func copyMetadata(base map[string]any, key string, value any) map[string]any {
|
||||
result := make(map[string]any, len(base)+1)
|
||||
for k, v := range base {
|
||||
result[k] = v
|
||||
}
|
||||
result[key] = value
|
||||
return result
|
||||
}
|
||||
|
||||
func copyMetadataMulti(base map[string]any, extra map[string]any) map[string]any {
|
||||
result := make(map[string]any, len(base)+len(extra))
|
||||
for k, v := range base {
|
||||
result[k] = v
|
||||
}
|
||||
for k, v := range extra {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func joinStrings(strs []string, sep string) string {
|
||||
if len(strs) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := strs[0]
|
||||
for i := 1; i < len(strs); i++ {
|
||||
result += sep + strs[i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
+56
-56
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/chroma"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
@@ -202,7 +202,7 @@ func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to save user prompt")
|
||||
// Non-fatal: continue with session initialization
|
||||
} else if s.chromaSync != nil {
|
||||
} else if s.vectorSync != nil {
|
||||
// Sync to vector DB asynchronously (non-blocking)
|
||||
now := time.Now()
|
||||
promptWithSession := &models.UserPromptWithSession{
|
||||
@@ -221,8 +221,8 @@ func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := s.chromaSync.SyncUserPrompt(ctx, promptWithSession); err != nil {
|
||||
log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to ChromaDB")
|
||||
if err := s.vectorSync.SyncUserPrompt(ctx, promptWithSession); err != nil {
|
||||
log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to sqlite-vec")
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -450,7 +450,7 @@ func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// handleGetObservations returns recent observations.
|
||||
// Supports optional query parameter for semantic search via ChromaDB.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
|
||||
limit := sqlite.ParseLimitParam(r, DefaultObservationsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
@@ -458,25 +458,25 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
var usedChroma bool
|
||||
var usedVector bool
|
||||
|
||||
// Use ChromaDB if query is provided and ChromaDB is available
|
||||
if query != "" && s.chromaClient != nil && s.chromaClient.IsConnected() {
|
||||
where := chroma.BuildWhereFilter(chroma.DocTypeObservation, "")
|
||||
chromaResults, chromaErr := s.chromaClient.Query(r.Context(), query, limit*2, where)
|
||||
if chromaErr == nil && len(chromaResults) > 0 {
|
||||
obsIDs := chroma.ExtractObservationIDs(chromaResults, project)
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
||||
if len(obsIDs) > 0 {
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedChroma = true
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SQLite if ChromaDB not used
|
||||
if !usedChroma {
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
// Filter by project - includes project-scoped and global observations
|
||||
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
@@ -499,7 +499,7 @@ func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// handleGetSummaries returns recent summaries.
|
||||
// Supports optional query parameter for semantic search via ChromaDB.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
limit := sqlite.ParseLimitParam(r, DefaultSummariesLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
@@ -507,25 +507,25 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var summaries []*models.SessionSummary
|
||||
var err error
|
||||
var usedChroma bool
|
||||
var usedVector bool
|
||||
|
||||
// Use ChromaDB if query is provided and ChromaDB is available
|
||||
if query != "" && s.chromaClient != nil && s.chromaClient.IsConnected() {
|
||||
where := chroma.BuildWhereFilter(chroma.DocTypeSessionSummary, "")
|
||||
chromaResults, chromaErr := s.chromaClient.Query(r.Context(), query, limit*2, where)
|
||||
if chromaErr == nil && len(chromaResults) > 0 {
|
||||
summaryIDs := chroma.ExtractSummaryIDs(chromaResults, project)
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
|
||||
if len(summaryIDs) > 0 {
|
||||
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedChroma = true
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SQLite if ChromaDB not used
|
||||
if !usedChroma {
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
|
||||
} else {
|
||||
@@ -546,7 +546,7 @@ func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// handleGetPrompts returns recent user prompts.
|
||||
// Supports optional query parameter for semantic search via ChromaDB.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
limit := sqlite.ParseLimitParam(r, DefaultPromptsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
@@ -554,25 +554,25 @@ func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var prompts []*models.UserPromptWithSession
|
||||
var err error
|
||||
var usedChroma bool
|
||||
var usedVector bool
|
||||
|
||||
// Use ChromaDB if query is provided and ChromaDB is available
|
||||
if query != "" && s.chromaClient != nil && s.chromaClient.IsConnected() {
|
||||
where := chroma.BuildWhereFilter(chroma.DocTypeUserPrompt, "")
|
||||
chromaResults, chromaErr := s.chromaClient.Query(r.Context(), query, limit*2, where)
|
||||
if chromaErr == nil && len(chromaResults) > 0 {
|
||||
promptIDs := chroma.ExtractPromptIDs(chromaResults, project)
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
|
||||
if len(promptIDs) > 0 {
|
||||
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedChroma = true
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SQLite if ChromaDB not used
|
||||
if !usedChroma {
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
|
||||
} else {
|
||||
@@ -683,29 +683,29 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
var usedChroma bool
|
||||
var usedVector bool
|
||||
|
||||
// Try ChromaDB vector search first if available
|
||||
if s.chromaClient != nil && s.chromaClient.IsConnected() {
|
||||
where := chroma.BuildWhereFilter(chroma.DocTypeObservation, "")
|
||||
// Try vector search first if available
|
||||
if s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
|
||||
chromaResults, chromaErr := s.chromaClient.Query(r.Context(), query, limit*2, where)
|
||||
if chromaErr == nil && len(chromaResults) > 0 {
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
// Extract observation IDs with project/scope filtering using shared helper
|
||||
obsIDs := chroma.ExtractObservationIDs(chromaResults, project)
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
||||
|
||||
if len(obsIDs) > 0 {
|
||||
// Fetch full observations from SQLite
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedChroma = true
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to FTS if ChromaDB not available or returned no results
|
||||
if !usedChroma || len(observations) == 0 {
|
||||
// Fall back to FTS if vector search not available or returned no results
|
||||
if !usedVector || len(observations) == 0 {
|
||||
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
|
||||
if err != nil {
|
||||
// FTS might fail if query has special chars, try without
|
||||
@@ -941,22 +941,22 @@ func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
components = append(components, dbStatus)
|
||||
|
||||
// Check ChromaDB
|
||||
chromaStatus := ComponentHealth{Name: "ChromaDB", Status: "healthy"}
|
||||
if s.chromaClient == nil {
|
||||
chromaStatus.Status = "degraded"
|
||||
chromaStatus.Message = "Not configured"
|
||||
// Check Vector DB (sqlite-vec)
|
||||
vectorStatus := ComponentHealth{Name: "Vector DB", Status: "healthy"}
|
||||
if s.vectorClient == nil {
|
||||
vectorStatus.Status = "degraded"
|
||||
vectorStatus.Message = "Not configured"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else if !s.chromaClient.IsConnected() {
|
||||
chromaStatus.Status = "degraded"
|
||||
chromaStatus.Message = "Not connected"
|
||||
} else if !s.vectorClient.IsConnected() {
|
||||
vectorStatus.Status = "degraded"
|
||||
vectorStatus.Message = "Not connected"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
}
|
||||
components = append(components, chromaStatus)
|
||||
components = append(components, vectorStatus)
|
||||
|
||||
// Check SDK Processor
|
||||
sdkStatus := ComponentHealth{Name: "SDK Processor", Status: "healthy"}
|
||||
|
||||
+89
-76
@@ -14,8 +14,9 @@ import (
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/update"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/chroma"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||
@@ -69,9 +70,10 @@ type Service struct {
|
||||
sseBroadcaster *sse.Broadcaster
|
||||
processor *sdk.Processor
|
||||
|
||||
// Vector database
|
||||
chromaClient *chroma.Client
|
||||
chromaSync *chroma.Sync
|
||||
// Vector database (sqlite-vec with local embeddings)
|
||||
embedSvc *embedding.Service
|
||||
vectorClient *sqlitevec.Client
|
||||
vectorSync *sqlitevec.Sync
|
||||
|
||||
// HTTP server
|
||||
router *chi.Mux
|
||||
@@ -151,7 +153,7 @@ func NewService(version string) (*Service, error) {
|
||||
func (s *Service) initializeAsync() {
|
||||
log.Info().Msg("Starting async initialization...")
|
||||
|
||||
// Ensure data directory, vector-db, and settings exist
|
||||
// Ensure data directory and settings exist
|
||||
if err := config.EnsureAll(); err != nil {
|
||||
s.setInitError(fmt.Errorf("ensure data dir: %w", err))
|
||||
return
|
||||
@@ -177,25 +179,26 @@ func (s *Service) initializeAsync() {
|
||||
// Create session manager
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
|
||||
// Create ChromaDB client for vector search (optional - will be nil if unavailable)
|
||||
var chromaClient *chroma.Client
|
||||
var chromaSync *chroma.Sync
|
||||
chromaCfg := chroma.Config{
|
||||
Project: "default", // Collection prefix
|
||||
DataDir: s.config.VectorDBPath,
|
||||
BatchSize: 100,
|
||||
}
|
||||
client, err := chroma.NewClient(chromaCfg)
|
||||
// Create embedding service and sqlite-vec client for vector search (optional)
|
||||
var embedSvc *embedding.Service
|
||||
var vectorClient *sqlitevec.Client
|
||||
var vectorSync *sqlitevec.Sync
|
||||
|
||||
emb, err := embedding.NewService()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("ChromaDB client creation failed - vector sync disabled")
|
||||
log.Warn().Err(err).Msg("Embedding service creation failed - vector search disabled")
|
||||
} else {
|
||||
// Connect to ChromaDB (starts the MCP server)
|
||||
if err := client.Connect(s.ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("ChromaDB connection failed - vector sync disabled")
|
||||
embedSvc = emb
|
||||
// Create sqlite-vec client using the same DB connection
|
||||
client, err := sqlitevec.NewClient(sqlitevec.Config{
|
||||
DB: store.DB(),
|
||||
}, embedSvc)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("sqlite-vec client creation failed - vector search disabled")
|
||||
} else {
|
||||
chromaClient = client
|
||||
chromaSync = chroma.NewSync(client)
|
||||
log.Info().Msg("ChromaDB client connected - vector sync enabled")
|
||||
vectorClient = client
|
||||
vectorSync = sqlitevec.NewSync(client)
|
||||
log.Info().Msg("sqlite-vec vector search enabled")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,38 +225,39 @@ func (s *Service) initializeAsync() {
|
||||
s.promptStore = promptStore
|
||||
s.sessionManager = sessionManager
|
||||
s.processor = processor
|
||||
s.chromaClient = chromaClient
|
||||
s.chromaSync = chromaSync
|
||||
s.embedSvc = embedSvc
|
||||
s.vectorClient = vectorClient
|
||||
s.vectorSync = vectorSync
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Set vector sync callbacks on processor if both are available
|
||||
if processor != nil && chromaSync != nil {
|
||||
if processor != nil && vectorSync != nil {
|
||||
processor.SetSyncObservationFunc(func(obs *models.Observation) {
|
||||
if err := chromaSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to ChromaDB")
|
||||
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec")
|
||||
}
|
||||
})
|
||||
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
|
||||
if err := chromaSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to ChromaDB")
|
||||
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to sqlite-vec")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set cleanup callback on observation store to sync deletes to ChromaDB
|
||||
if observationStore != nil && chromaSync != nil {
|
||||
// Set cleanup callback on observation store to sync deletes to vector store
|
||||
if observationStore != nil && vectorSync != nil {
|
||||
observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := chromaSync.DeleteObservations(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from ChromaDB")
|
||||
if err := vectorSync.DeleteObservations(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from sqlite-vec")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set cleanup callback on prompt store to sync deletes to ChromaDB
|
||||
if promptStore != nil && chromaSync != nil {
|
||||
// Set cleanup callback on prompt store to sync deletes to vector store
|
||||
if promptStore != nil && vectorSync != nil {
|
||||
promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := chromaSync.DeleteUserPrompts(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from ChromaDB")
|
||||
if err := vectorSync.DeleteUserPrompts(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from sqlite-vec")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -336,13 +340,13 @@ func (s *Service) reinitializeDatabase() {
|
||||
s.initMu.Lock()
|
||||
oldStore := s.store
|
||||
oldSessionManager := s.sessionManager
|
||||
oldChromaClient := s.chromaClient
|
||||
oldEmbedSvc := s.embedSvc
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Close old stores
|
||||
if oldChromaClient != nil {
|
||||
if err := oldChromaClient.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("Error closing old ChromaDB client")
|
||||
// Close old embedding service
|
||||
if oldEmbedSvc != nil {
|
||||
if err := oldEmbedSvc.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("Error closing old embedding service")
|
||||
}
|
||||
}
|
||||
if oldStore != nil {
|
||||
@@ -356,7 +360,7 @@ func (s *Service) reinitializeDatabase() {
|
||||
oldSessionManager.ShutdownAll(s.ctx)
|
||||
}
|
||||
|
||||
// Ensure data directory, vector-db, and settings exist (may have been deleted)
|
||||
// Ensure data directory and settings exist (may have been deleted)
|
||||
if err := config.EnsureAll(); err != nil {
|
||||
s.setInitError(fmt.Errorf("ensure data dir on reinit: %w", err))
|
||||
return
|
||||
@@ -382,24 +386,25 @@ func (s *Service) reinitializeDatabase() {
|
||||
// Create new session manager
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
|
||||
// Recreate ChromaDB client
|
||||
var chromaClient *chroma.Client
|
||||
var chromaSync *chroma.Sync
|
||||
chromaCfg := chroma.Config{
|
||||
Project: "default",
|
||||
DataDir: s.config.VectorDBPath,
|
||||
BatchSize: 100,
|
||||
}
|
||||
client, err := chroma.NewClient(chromaCfg)
|
||||
// Recreate embedding service and sqlite-vec client
|
||||
var embedSvc *embedding.Service
|
||||
var vectorClient *sqlitevec.Client
|
||||
var vectorSync *sqlitevec.Sync
|
||||
|
||||
emb, err := embedding.NewService()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("ChromaDB client creation failed after reinit")
|
||||
log.Warn().Err(err).Msg("Embedding service creation failed after reinit")
|
||||
} else {
|
||||
if err := client.Connect(s.ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("ChromaDB connection failed after reinit")
|
||||
embedSvc = emb
|
||||
client, err := sqlitevec.NewClient(sqlitevec.Config{
|
||||
DB: store.DB(),
|
||||
}, embedSvc)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("sqlite-vec client creation failed after reinit")
|
||||
} else {
|
||||
chromaClient = client
|
||||
chromaSync = chroma.NewSync(client)
|
||||
log.Info().Msg("ChromaDB client reconnected after reinit")
|
||||
vectorClient = client
|
||||
vectorSync = sqlitevec.NewSync(client)
|
||||
log.Info().Msg("sqlite-vec reconnected after reinit")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -424,39 +429,40 @@ func (s *Service) reinitializeDatabase() {
|
||||
s.promptStore = promptStore
|
||||
s.sessionManager = sessionManager
|
||||
s.processor = processor
|
||||
s.chromaClient = chromaClient
|
||||
s.chromaSync = chromaSync
|
||||
s.embedSvc = embedSvc
|
||||
s.vectorClient = vectorClient
|
||||
s.vectorSync = vectorSync
|
||||
s.initError = nil
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Set vector sync callbacks on processor if both are available
|
||||
if processor != nil && chromaSync != nil {
|
||||
if processor != nil && vectorSync != nil {
|
||||
processor.SetSyncObservationFunc(func(obs *models.Observation) {
|
||||
if err := chromaSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to ChromaDB")
|
||||
if err := vectorSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec")
|
||||
}
|
||||
})
|
||||
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
|
||||
if err := chromaSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to ChromaDB")
|
||||
if err := vectorSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to sqlite-vec")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set cleanup callback on observation store to sync deletes to ChromaDB
|
||||
if observationStore != nil && chromaSync != nil {
|
||||
// Set cleanup callback on observation store to sync deletes to vector store
|
||||
if observationStore != nil && vectorSync != nil {
|
||||
observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := chromaSync.DeleteObservations(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from ChromaDB")
|
||||
if err := vectorSync.DeleteObservations(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from sqlite-vec")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set cleanup callback on prompt store to sync deletes to ChromaDB
|
||||
if promptStore != nil && chromaSync != nil {
|
||||
// Set cleanup callback on prompt store to sync deletes to vector store
|
||||
if promptStore != nil && vectorSync != nil {
|
||||
promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := chromaSync.DeleteUserPrompts(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from ChromaDB")
|
||||
if err := vectorSync.DeleteUserPrompts(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from sqlite-vec")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -862,10 +868,17 @@ func (s *Service) Shutdown(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Close ChromaDB client
|
||||
if s.chromaClient != nil {
|
||||
if err := s.chromaClient.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("ChromaDB close error")
|
||||
// Close embedding service
|
||||
if s.embedSvc != nil {
|
||||
if err := s.embedSvc.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("Embedding service close error")
|
||||
}
|
||||
}
|
||||
|
||||
// Close vector client
|
||||
if s.vectorClient != nil {
|
||||
if err := s.vectorClient.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("Vector client close error")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Generated
+2
-2
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "v0.6.1-9-g7e49113-dirty",
|
||||
"version": "v0.6.1-10-g6c28ecb-dirty",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "v0.6.1-9-g7e49113-dirty",
|
||||
"version": "v0.6.1-10-g6c28ecb-dirty",
|
||||
"dependencies": {
|
||||
"vue": "^3.5.13"
|
||||
},
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "v0.6.1-9-g7e49113-dirty",
|
||||
"version": "v0.6.1-10-g6c28ecb-dirty",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
||||
Reference in New Issue
Block a user