mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-11 00:09:28 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,514 @@
|
||||
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
|
||||
package chroma
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Document is one record destined for a ChromaDB collection. The JSON tags
// control serialization; note that the text body is emitted as "document".
type Document struct {
	// ID identifies the document.
	ID string `json:"id"`
	// Content is the document text, serialized under the "document" key.
	Content string `json:"document"`
	// Metadata carries arbitrary key/value attributes for the document.
	Metadata map[string]any `json:"metadata"`
}
|
||||
|
||||
// QueryResult is a single hit produced by a ChromaDB query.
type QueryResult struct {
	// ID is the identifier of the matched document.
	ID string
	// Distance is the distance value reported for this hit.
	Distance float64
	// Metadata is the metadata map reported for this hit.
	Metadata map[string]any
}
|
||||
|
||||
// Client talks to ChromaDB by exchanging MCP (JSON-RPC) messages with a child
// process over its stdin/stdout pipes.
type Client struct {
	// Settings captured at construction time.
	collection string
	dataDir    string
	pythonVer  string
	batchSize  int

	// Child process and the pipes used to exchange messages with it.
	cmd    *exec.Cmd
	stdin  io.WriteCloser
	stdout *bufio.Reader
	// mu is held by the exported methods for the duration of each call.
	mu sync.Mutex

	// connected reports whether Connect completed; requestID is the last
	// JSON-RPC request id handed out.
	connected bool
	requestID int
}
|
||||
|
||||
// Config carries the settings accepted by NewClient.
type Config struct {
	// Project is used to derive the collection name.
	Project string
	// DataDir is the directory handed to the ChromaDB server.
	DataDir string
	// PythonVer is the Python version requested when launching the server.
	PythonVer string
	// BatchSize is the number of documents sent per add request.
	BatchSize int
}
|
||||
|
||||
// NewClient creates a new ChromaDB client.
|
||||
func NewClient(cfg Config) (*Client, error) {
|
||||
if cfg.DataDir == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
cfg.DataDir = filepath.Join(home, ".claude-mnemonic", "vector-db")
|
||||
}
|
||||
if cfg.PythonVer == "" {
|
||||
cfg.PythonVer = "3.13"
|
||||
}
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 100
|
||||
}
|
||||
|
||||
return &Client{
|
||||
collection: fmt.Sprintf("cm__%s", cfg.Project),
|
||||
dataDir: cfg.DataDir,
|
||||
pythonVer: cfg.PythonVer,
|
||||
batchSize: cfg.BatchSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connect starts the ChromaDB MCP server and establishes connection.
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure data directory exists
|
||||
if err := os.MkdirAll(c.dataDir, 0750); err != nil {
|
||||
return fmt.Errorf("create data dir: %w", err)
|
||||
}
|
||||
|
||||
// Start chroma-mcp server via uvx
|
||||
c.cmd = exec.CommandContext(ctx, "uvx", // #nosec G204 -- config values from internal settings
|
||||
"--python", c.pythonVer,
|
||||
"chroma-mcp",
|
||||
"--client-type", "persistent",
|
||||
"--data-dir", c.dataDir,
|
||||
)
|
||||
|
||||
var err error
|
||||
c.stdin, err = c.cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := c.cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
c.stdout = bufio.NewReader(stdout)
|
||||
|
||||
c.cmd.Stderr = os.Stderr
|
||||
|
||||
if err := c.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start chroma-mcp: %w", err)
|
||||
}
|
||||
|
||||
// Send initialize request
|
||||
if err := c.sendInitialize(); err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
}
|
||||
|
||||
c.connected = true
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Str("dataDir", c.dataDir).
|
||||
Msg("Connected to ChromaDB")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendInitialize sends the MCP initialize request.
|
||||
func (c *Client) sendInitialize() error {
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "initialize",
|
||||
"params": map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{
|
||||
"name": "claude-mnemonic",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read response
|
||||
_, err := c.readResponse()
|
||||
return err
|
||||
}
|
||||
|
||||
// EnsureCollection ensures the collection exists, creating it if needed.
|
||||
func (c *Client) EnsureCollection(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Try to get collection info
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_get_collection_info",
|
||||
"arguments": map[string]any{
|
||||
"collection_name": c.collection,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
// Collection doesn't exist, create it
|
||||
return c.createCollection()
|
||||
}
|
||||
|
||||
// Check if error in response (collection not found)
|
||||
if _, ok := resp["error"]; ok {
|
||||
return c.createCollection()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createCollection creates a new collection.
|
||||
func (c *Client) createCollection() error {
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_create_collection",
|
||||
"arguments": map[string]any{
|
||||
"collection_name": c.collection,
|
||||
"embedding_function_name": "default",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := c.readResponse()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create collection: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Msg("Created ChromaDB collection")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDocuments adds documents to the collection in batches.
|
||||
func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
for i := 0; i < len(docs); i += c.batchSize {
|
||||
end := i + c.batchSize
|
||||
if end > len(docs) {
|
||||
end = len(docs)
|
||||
}
|
||||
batch := docs[i:end]
|
||||
|
||||
// Extract fields
|
||||
documents := make([]string, len(batch))
|
||||
ids := make([]string, len(batch))
|
||||
metadatas := make([]map[string]any, len(batch))
|
||||
for j, doc := range batch {
|
||||
documents[j] = doc.Content
|
||||
ids[j] = doc.ID
|
||||
metadatas[j] = doc.Metadata
|
||||
}
|
||||
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_add_documents",
|
||||
"arguments": map[string]any{
|
||||
"collection_name": c.collection,
|
||||
"documents": documents,
|
||||
"ids": ids,
|
||||
"metadatas": metadatas,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return fmt.Errorf("send add_documents: %w", err)
|
||||
}
|
||||
|
||||
if _, err := c.readResponse(); err != nil {
|
||||
return fmt.Errorf("add_documents response: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("batchStart", i).
|
||||
Int("batchEnd", end).
|
||||
Int("total", len(docs)).
|
||||
Msg("Added document batch")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteDocuments deletes documents from the collection by their IDs.
|
||||
func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_delete_documents",
|
||||
"arguments": map[string]any{
|
||||
"collection_name": c.collection,
|
||||
"ids": ids,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return fmt.Errorf("send delete_documents: %w", err)
|
||||
}
|
||||
|
||||
if _, err := c.readResponse(); err != nil {
|
||||
return fmt.Errorf("delete_documents response: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("count", len(ids)).
|
||||
Msg("Deleted documents from ChromaDB")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query performs a semantic search on the collection.
|
||||
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
args := map[string]any{
|
||||
"collection_name": c.collection,
|
||||
"query_texts": []string{query},
|
||||
"n_results": limit,
|
||||
"include": []string{"documents", "metadatas", "distances"},
|
||||
}
|
||||
if where != nil {
|
||||
args["where"] = where
|
||||
}
|
||||
|
||||
req := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": c.nextID(),
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "chroma_query_documents",
|
||||
"arguments": args,
|
||||
},
|
||||
}
|
||||
|
||||
if err := c.send(req); err != nil {
|
||||
return nil, fmt.Errorf("send query: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query response: %w", err)
|
||||
}
|
||||
|
||||
return c.parseQueryResults(resp)
|
||||
}
|
||||
|
||||
// parseQueryResults parses the query response into QueryResult structs.
|
||||
func (c *Client) parseQueryResults(resp map[string]any) ([]QueryResult, error) {
|
||||
result, ok := resp["result"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
content, ok := result["content"].([]any)
|
||||
if !ok || len(content) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
first, ok := content[0].(map[string]any)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
text, ok := first["text"].(string)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var parsed struct {
|
||||
IDs [][]string `json:"ids"`
|
||||
Distances [][]float64 `json:"distances"`
|
||||
Metadatas [][]map[string]any `json:"metadatas"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(text), &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(parsed.IDs) == 0 || len(parsed.IDs[0]) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
results := make([]QueryResult, len(parsed.IDs[0]))
|
||||
for i := range parsed.IDs[0] {
|
||||
results[i] = QueryResult{
|
||||
ID: parsed.IDs[0][i],
|
||||
}
|
||||
if i < len(parsed.Distances[0]) {
|
||||
results[i].Distance = parsed.Distances[0][i]
|
||||
}
|
||||
if i < len(parsed.Metadatas[0]) {
|
||||
results[i].Metadata = parsed.Metadatas[0][i]
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// send sends a JSON-RPC request to the MCP server.
|
||||
func (c *Client) send(req map[string]any) error {
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = append(data, '\n')
|
||||
_, err = c.stdin.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// readResponse reads a JSON-RPC response from the MCP server.
|
||||
func (c *Client) readResponse() (map[string]any, error) {
|
||||
line, err := c.stdout.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if errObj, ok := resp["error"]; ok {
|
||||
return nil, fmt.Errorf("MCP error: %v", errObj)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// nextID returns the next request ID.
|
||||
func (c *Client) nextID() int {
|
||||
c.requestID++
|
||||
return c.requestID
|
||||
}
|
||||
|
||||
// Close closes the connection to ChromaDB.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.connected = false
|
||||
|
||||
if c.stdin != nil {
|
||||
_ = c.stdin.Close()
|
||||
}
|
||||
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
_ = c.cmd.Process.Kill()
|
||||
_ = c.cmd.Wait()
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Msg("ChromaDB connection closed")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reconnect closes the existing connection and establishes a new one.
|
||||
// This is useful when the vector database directory has been deleted and recreated.
|
||||
func (c *Client) Reconnect(ctx context.Context) error {
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Msg("Reconnecting to ChromaDB...")
|
||||
|
||||
// Close existing connection
|
||||
if err := c.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("Error closing ChromaDB during reconnect")
|
||||
}
|
||||
|
||||
// Small delay to allow cleanup
|
||||
// (ChromaDB may need a moment to release resources)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Reconnect
|
||||
if err := c.Connect(ctx); err != nil {
|
||||
return fmt.Errorf("reconnect failed: %w", err)
|
||||
}
|
||||
|
||||
// Ensure collection exists
|
||||
if err := c.EnsureCollection(ctx); err != nil {
|
||||
return fmt.Errorf("ensure collection after reconnect: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("collection", c.collection).
|
||||
Msg("ChromaDB reconnected successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,276 @@
|
||||
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
|
||||
package chroma
|
||||
|
||||
import (
	"context"
	"fmt"
	"strings"

	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
	"github.com/rs/zerolog/log"
)
|
||||
|
||||
// Sync provides synchronization between SQLite and ChromaDB.
|
||||
type Sync struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// NewSync creates a new ChromaDB sync service.
|
||||
func NewSync(client *Client) *Sync {
|
||||
return &Sync{client: client}
|
||||
}
|
||||
|
||||
// SyncObservation syncs a single observation to ChromaDB.
|
||||
func (s *Sync) SyncObservation(ctx context.Context, obs *models.Observation) error {
|
||||
docs := s.formatObservationDocs(obs)
|
||||
if len(docs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||
return fmt.Errorf("add observation docs: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int64("observationId", obs.ID).
|
||||
Int("docCount", len(docs)).
|
||||
Msg("Synced observation to ChromaDB")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatObservationDocs formats an observation into ChromaDB documents.
|
||||
// Each semantic field becomes a separate vector document (granular approach).
|
||||
func (s *Sync) formatObservationDocs(obs *models.Observation) []Document {
|
||||
docs := make([]Document, 0, len(obs.Facts)+2)
|
||||
|
||||
// Determine scope for metadata
|
||||
scope := string(obs.Scope)
|
||||
if scope == "" {
|
||||
scope = "project"
|
||||
}
|
||||
|
||||
baseMetadata := map[string]any{
|
||||
"sqlite_id": obs.ID,
|
||||
"doc_type": "observation",
|
||||
"sdk_session_id": obs.SDKSessionID,
|
||||
"project": obs.Project,
|
||||
"scope": scope,
|
||||
"created_at_epoch": obs.CreatedAtEpoch,
|
||||
"type": string(obs.Type),
|
||||
}
|
||||
|
||||
if obs.Title.Valid {
|
||||
baseMetadata["title"] = obs.Title.String
|
||||
}
|
||||
if obs.Subtitle.Valid {
|
||||
baseMetadata["subtitle"] = obs.Subtitle.String
|
||||
}
|
||||
if len(obs.Concepts) > 0 {
|
||||
baseMetadata["concepts"] = joinStrings(obs.Concepts, ",")
|
||||
}
|
||||
if len(obs.FilesRead) > 0 {
|
||||
baseMetadata["files_read"] = joinStrings(obs.FilesRead, ",")
|
||||
}
|
||||
if len(obs.FilesModified) > 0 {
|
||||
baseMetadata["files_modified"] = joinStrings(obs.FilesModified, ",")
|
||||
}
|
||||
|
||||
// Narrative as separate document
|
||||
if obs.Narrative.Valid && obs.Narrative.String != "" {
|
||||
docs = append(docs, Document{
|
||||
ID: fmt.Sprintf("obs_%d_narrative", obs.ID),
|
||||
Content: obs.Narrative.String,
|
||||
Metadata: copyMetadata(baseMetadata, "field_type", "narrative"),
|
||||
})
|
||||
}
|
||||
|
||||
// Each fact as separate document
|
||||
for i, fact := range obs.Facts {
|
||||
docs = append(docs, Document{
|
||||
ID: fmt.Sprintf("obs_%d_fact_%d", obs.ID, i),
|
||||
Content: fact,
|
||||
Metadata: copyMetadataMulti(baseMetadata, map[string]any{
|
||||
"field_type": "fact",
|
||||
"fact_index": i,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
return docs
|
||||
}
|
||||
|
||||
// SyncSummary syncs a single session summary to ChromaDB.
|
||||
func (s *Sync) SyncSummary(ctx context.Context, summary *models.SessionSummary) error {
|
||||
docs := s.formatSummaryDocs(summary)
|
||||
if len(docs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||
return fmt.Errorf("add summary docs: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int64("summaryId", summary.ID).
|
||||
Int("docCount", len(docs)).
|
||||
Msg("Synced summary to ChromaDB")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatSummaryDocs formats a session summary into ChromaDB documents.
|
||||
func (s *Sync) formatSummaryDocs(summary *models.SessionSummary) []Document {
|
||||
docs := make([]Document, 0, 6)
|
||||
|
||||
baseMetadata := map[string]any{
|
||||
"sqlite_id": summary.ID,
|
||||
"doc_type": "session_summary",
|
||||
"sdk_session_id": summary.SDKSessionID,
|
||||
"project": summary.Project,
|
||||
"created_at_epoch": summary.CreatedAtEpoch,
|
||||
}
|
||||
|
||||
if summary.PromptNumber.Valid {
|
||||
baseMetadata["prompt_number"] = summary.PromptNumber.Int64
|
||||
}
|
||||
|
||||
// Each field as separate document
|
||||
fields := []struct {
|
||||
name string
|
||||
value string
|
||||
valid bool
|
||||
}{
|
||||
{"request", summary.Request.String, summary.Request.Valid},
|
||||
{"investigated", summary.Investigated.String, summary.Investigated.Valid},
|
||||
{"learned", summary.Learned.String, summary.Learned.Valid},
|
||||
{"completed", summary.Completed.String, summary.Completed.Valid},
|
||||
{"next_steps", summary.NextSteps.String, summary.NextSteps.Valid},
|
||||
{"notes", summary.Notes.String, summary.Notes.Valid},
|
||||
}
|
||||
|
||||
for _, field := range fields {
|
||||
if field.valid && field.value != "" {
|
||||
docs = append(docs, Document{
|
||||
ID: fmt.Sprintf("summary_%d_%s", summary.ID, field.name),
|
||||
Content: field.value,
|
||||
Metadata: copyMetadata(baseMetadata, "field_type", field.name),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return docs
|
||||
}
|
||||
|
||||
// SyncUserPrompt syncs a single user prompt to ChromaDB.
|
||||
func (s *Sync) SyncUserPrompt(ctx context.Context, prompt *models.UserPromptWithSession) error {
|
||||
doc := Document{
|
||||
ID: fmt.Sprintf("prompt_%d", prompt.ID),
|
||||
Content: prompt.PromptText,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": prompt.ID,
|
||||
"doc_type": "user_prompt",
|
||||
"sdk_session_id": prompt.SDKSessionID,
|
||||
"project": prompt.Project,
|
||||
"created_at_epoch": prompt.CreatedAtEpoch,
|
||||
"prompt_number": prompt.PromptNumber,
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.client.AddDocuments(ctx, []Document{doc}); err != nil {
|
||||
return fmt.Errorf("add prompt doc: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int64("promptId", prompt.ID).
|
||||
Msg("Synced user prompt to ChromaDB")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteObservations removes observation documents from ChromaDB.
|
||||
// Since each observation may have multiple documents (narrative + facts),
|
||||
// we delete by the sqlite_id metadata prefix pattern.
|
||||
func (s *Sync) DeleteObservations(ctx context.Context, observationIDs []int64) error {
|
||||
if len(observationIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate all possible document IDs for these observations
|
||||
// Pattern: obs_{id}_narrative, obs_{id}_fact_{0..n}
|
||||
// Since we don't know how many facts each had, we use a reasonable upper bound
|
||||
const maxFactsPerObs = 20
|
||||
ids := make([]string, 0, len(observationIDs)*(maxFactsPerObs+1))
|
||||
|
||||
for _, obsID := range observationIDs {
|
||||
ids = append(ids, fmt.Sprintf("obs_%d_narrative", obsID))
|
||||
for i := 0; i < maxFactsPerObs; i++ {
|
||||
ids = append(ids, fmt.Sprintf("obs_%d_fact_%d", obsID, i))
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.client.DeleteDocuments(ctx, ids); err != nil {
|
||||
return fmt.Errorf("delete observation docs: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("observationCount", len(observationIDs)).
|
||||
Msg("Deleted observations from ChromaDB")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUserPrompts removes user prompt documents from ChromaDB.
|
||||
func (s *Sync) DeleteUserPrompts(ctx context.Context, promptIDs []int64) error {
|
||||
if len(promptIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Each prompt is stored as a single document with ID pattern: prompt_{id}
|
||||
ids := make([]string, len(promptIDs))
|
||||
for i, promptID := range promptIDs {
|
||||
ids[i] = fmt.Sprintf("prompt_%d", promptID)
|
||||
}
|
||||
|
||||
if err := s.client.DeleteDocuments(ctx, ids); err != nil {
|
||||
return fmt.Errorf("delete prompt docs: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("promptCount", len(promptIDs)).
|
||||
Msg("Deleted user prompts from ChromaDB")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// copyMetadata returns a new map holding every entry of base plus key=value.
// base is left untouched; if key already exists in base, value wins.
func copyMetadata(base map[string]any, key string, value any) map[string]any {
	merged := make(map[string]any, len(base)+1)
	for name, val := range base {
		merged[name] = val
	}
	// Set last so it overrides any entry copied from base.
	merged[key] = value
	return merged
}
|
||||
|
||||
// copyMetadataMulti returns a new map holding the entries of base overlaid
// with the entries of extra. Neither input is modified; on a key collision
// the value from extra wins.
func copyMetadataMulti(base map[string]any, extra map[string]any) map[string]any {
	merged := make(map[string]any, len(base)+len(extra))
	// Later sources overwrite earlier ones, so extra is applied second.
	for _, source := range []map[string]any{base, extra} {
		for name, val := range source {
			merged[name] = val
		}
	}
	return merged
}
|
||||
|
||||
// joinStrings concatenates strs, placing sep between consecutive elements.
// It returns "" for an empty or nil slice and the element itself for a
// single-element slice.
//
// It delegates to strings.Join, which sizes the result once instead of
// re-allocating the accumulated string for every element.
func joinStrings(strs []string, sep string) string {
	return strings.Join(strs, sep)
}
|
||||
@@ -0,0 +1,350 @@
|
||||
package chroma
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// testSync creates a Sync with a nil client for testing format functions.
|
||||
func testSync() *Sync {
|
||||
return &Sync{client: nil}
|
||||
}
|
||||
|
||||
func TestSync_FormatObservationDocs(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Scope: models.ScopeProject,
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: sql.NullString{String: "Test Title", Valid: true},
|
||||
Subtitle: sql.NullString{String: "Test Subtitle", Valid: true},
|
||||
Narrative: sql.NullString{String: "Test narrative content", Valid: true},
|
||||
Facts: models.JSONStringArray{"Fact 1", "Fact 2", "Fact 3"},
|
||||
Concepts: models.JSONStringArray{"concept1", "concept2"},
|
||||
FilesRead: models.JSONStringArray{"file1.go", "file2.go"},
|
||||
FilesModified: models.JSONStringArray{"file3.go"},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
|
||||
// Should have 1 narrative + 3 facts = 4 documents
|
||||
assert.Len(t, docs, 4)
|
||||
|
||||
// Check narrative document
|
||||
narrativeDoc := docs[0]
|
||||
assert.Equal(t, "obs_1_narrative", narrativeDoc.ID)
|
||||
assert.Equal(t, "Test narrative content", narrativeDoc.Content)
|
||||
assert.Equal(t, int64(1), narrativeDoc.Metadata["sqlite_id"])
|
||||
assert.Equal(t, "observation", narrativeDoc.Metadata["doc_type"])
|
||||
assert.Equal(t, "narrative", narrativeDoc.Metadata["field_type"])
|
||||
assert.Equal(t, "test-project", narrativeDoc.Metadata["project"])
|
||||
assert.Equal(t, "project", narrativeDoc.Metadata["scope"])
|
||||
assert.Equal(t, "Test Title", narrativeDoc.Metadata["title"])
|
||||
assert.Equal(t, "Test Subtitle", narrativeDoc.Metadata["subtitle"])
|
||||
|
||||
// Check fact documents
|
||||
for i := 1; i <= 3; i++ {
|
||||
factDoc := docs[i]
|
||||
assert.Equal(t, fmt.Sprintf("obs_1_fact_%d", i-1), factDoc.ID)
|
||||
assert.Equal(t, fmt.Sprintf("Fact %d", i), factDoc.Content)
|
||||
assert.Equal(t, "fact", factDoc.Metadata["field_type"])
|
||||
assert.Equal(t, i-1, factDoc.Metadata["fact_index"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSync_FormatObservationDocs_NoNarrative(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 2,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Scope: models.ScopeGlobal,
|
||||
Type: models.ObsTypeBugfix,
|
||||
Facts: models.JSONStringArray{"Only fact"},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
|
||||
// Should have 1 fact only (no narrative)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "obs_2_fact_0", docs[0].ID)
|
||||
assert.Equal(t, "Only fact", docs[0].Content)
|
||||
assert.Equal(t, "global", docs[0].Metadata["scope"])
|
||||
}
|
||||
|
||||
func TestSync_FormatObservationDocs_Empty(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 3,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Type: models.ObsTypeDiscovery,
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
|
||||
// Should have no documents when no content
|
||||
assert.Len(t, docs, 0)
|
||||
}
|
||||
|
||||
func TestSync_FormatObservationDocs_EmptyScope(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 4,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Scope: "", // Empty scope
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Narrative: sql.NullString{String: "Content", Valid: true},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
|
||||
// Empty scope should default to "project"
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "project", docs[0].Metadata["scope"])
|
||||
}
|
||||
|
||||
func TestSync_FormatSummaryDocs(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 1,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Request: sql.NullString{String: "Add feature", Valid: true},
|
||||
Investigated: sql.NullString{String: "Looked at code", Valid: true},
|
||||
Learned: sql.NullString{String: "Found pattern", Valid: true},
|
||||
Completed: sql.NullString{String: "Done", Valid: true},
|
||||
NextSteps: sql.NullString{String: "Test it", Valid: true},
|
||||
Notes: sql.NullString{String: "Notes here", Valid: true},
|
||||
PromptNumber: sql.NullInt64{Int64: 5, Valid: true},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatSummaryDocs(summary)
|
||||
|
||||
// Should have 6 documents (all fields present)
|
||||
assert.Len(t, docs, 6)
|
||||
|
||||
// Check first document
|
||||
assert.Equal(t, "summary_1_request", docs[0].ID)
|
||||
assert.Equal(t, "Add feature", docs[0].Content)
|
||||
assert.Equal(t, "session_summary", docs[0].Metadata["doc_type"])
|
||||
assert.Equal(t, "request", docs[0].Metadata["field_type"])
|
||||
assert.Equal(t, int64(5), docs[0].Metadata["prompt_number"])
|
||||
}
|
||||
|
||||
func TestSync_FormatSummaryDocs_PartialFields(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 2,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Request: sql.NullString{String: "Only request", Valid: true},
|
||||
Completed: sql.NullString{String: "Only completed", Valid: true},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatSummaryDocs(summary)
|
||||
|
||||
// Should have 2 documents (only valid fields)
|
||||
assert.Len(t, docs, 2)
|
||||
|
||||
// Verify field types
|
||||
fieldTypes := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
fieldTypes[i] = doc.Metadata["field_type"].(string)
|
||||
}
|
||||
assert.Contains(t, fieldTypes, "request")
|
||||
assert.Contains(t, fieldTypes, "completed")
|
||||
}
|
||||
|
||||
func TestSync_FormatSummaryDocs_Empty(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 3,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatSummaryDocs(summary)
|
||||
|
||||
// Should have no documents when no content
|
||||
assert.Len(t, docs, 0)
|
||||
}
|
||||
|
||||
func TestSync_FormatSummaryDocs_EmptyStrings(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
summary := &models.SessionSummary{
|
||||
ID: 4,
|
||||
SDKSessionID: "test-session",
|
||||
Project: "test-project",
|
||||
Request: sql.NullString{String: "", Valid: true}, // Valid but empty
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatSummaryDocs(summary)
|
||||
|
||||
// Empty strings should not produce documents
|
||||
assert.Len(t, docs, 0)
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func TestJoinStrings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
strs []string
|
||||
sep string
|
||||
expected string
|
||||
}{
|
||||
{"empty", []string{}, ",", ""},
|
||||
{"single", []string{"a"}, ",", "a"},
|
||||
{"multiple", []string{"a", "b", "c"}, ",", "a,b,c"},
|
||||
{"different sep", []string{"a", "b"}, "-", "a-b"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := joinStrings(tt.strs, tt.sep)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyMetadata(t *testing.T) {
|
||||
base := map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
}
|
||||
|
||||
result := copyMetadata(base, "key3", "value3")
|
||||
|
||||
// Original should be unchanged
|
||||
assert.Len(t, base, 2)
|
||||
|
||||
// Result should have all keys
|
||||
assert.Len(t, result, 3)
|
||||
assert.Equal(t, "value1", result["key1"])
|
||||
assert.Equal(t, 42, result["key2"])
|
||||
assert.Equal(t, "value3", result["key3"])
|
||||
}
|
||||
|
||||
func TestCopyMetadataMulti(t *testing.T) {
|
||||
base := map[string]any{
|
||||
"key1": "value1",
|
||||
}
|
||||
extra := map[string]any{
|
||||
"key2": "value2",
|
||||
"key3": "value3",
|
||||
}
|
||||
|
||||
result := copyMetadataMulti(base, extra)
|
||||
|
||||
// Original should be unchanged
|
||||
assert.Len(t, base, 1)
|
||||
|
||||
// Result should have all keys
|
||||
assert.Len(t, result, 3)
|
||||
assert.Equal(t, "value1", result["key1"])
|
||||
assert.Equal(t, "value2", result["key2"])
|
||||
assert.Equal(t, "value3", result["key3"])
|
||||
}
|
||||
|
||||
// Test ID generation patterns for delete operations
|
||||
func TestSync_DeleteObservationIDGeneration(t *testing.T) {
|
||||
// Test that we generate correct document IDs for deletion
|
||||
obsIDs := []int64{1, 2}
|
||||
maxFactsPerObs := 20
|
||||
|
||||
ids := make([]string, 0, len(obsIDs)*(maxFactsPerObs+1))
|
||||
for _, obsID := range obsIDs {
|
||||
ids = append(ids, fmt.Sprintf("obs_%d_narrative", obsID))
|
||||
for i := 0; i < maxFactsPerObs; i++ {
|
||||
ids = append(ids, fmt.Sprintf("obs_%d_fact_%d", obsID, i))
|
||||
}
|
||||
}
|
||||
|
||||
// Each observation should generate 21 IDs (1 narrative + 20 facts)
|
||||
assert.Len(t, ids, 42)
|
||||
|
||||
// Check some expected IDs
|
||||
assert.Contains(t, ids, "obs_1_narrative")
|
||||
assert.Contains(t, ids, "obs_1_fact_0")
|
||||
assert.Contains(t, ids, "obs_1_fact_19")
|
||||
assert.Contains(t, ids, "obs_2_narrative")
|
||||
assert.Contains(t, ids, "obs_2_fact_0")
|
||||
}
|
||||
|
||||
func TestSync_DeletePromptIDGeneration(t *testing.T) {
|
||||
// Test that we generate correct document IDs for prompt deletion
|
||||
promptIDs := []int64{10, 20, 30}
|
||||
|
||||
ids := make([]string, len(promptIDs))
|
||||
for i, promptID := range promptIDs {
|
||||
ids[i] = fmt.Sprintf("prompt_%d", promptID)
|
||||
}
|
||||
|
||||
assert.Len(t, ids, 3)
|
||||
assert.Contains(t, ids, "prompt_10")
|
||||
assert.Contains(t, ids, "prompt_20")
|
||||
assert.Contains(t, ids, "prompt_30")
|
||||
}
|
||||
|
||||
// Test metadata includes all expected fields
|
||||
func TestSync_ObservationMetadataFields(t *testing.T) {
|
||||
sync := testSync()
|
||||
|
||||
obs := &models.Observation{
|
||||
ID: 1,
|
||||
SDKSessionID: "sdk-123",
|
||||
Project: "my-project",
|
||||
Scope: models.ScopeGlobal,
|
||||
Type: models.ObsTypeBugfix,
|
||||
Title: sql.NullString{String: "Bug Fix", Valid: true},
|
||||
Subtitle: sql.NullString{String: "Memory leak", Valid: true},
|
||||
Narrative: sql.NullString{String: "Fixed the leak", Valid: true},
|
||||
Concepts: models.JSONStringArray{"memory", "performance"},
|
||||
FilesRead: models.JSONStringArray{"main.go"},
|
||||
FilesModified: models.JSONStringArray{"fix.go"},
|
||||
CreatedAtEpoch: 1234567890,
|
||||
}
|
||||
|
||||
docs := sync.formatObservationDocs(obs)
|
||||
require := assert.New(t)
|
||||
|
||||
require.Len(docs, 1) // Only narrative, no facts
|
||||
|
||||
meta := docs[0].Metadata
|
||||
require.Equal(int64(1), meta["sqlite_id"])
|
||||
require.Equal("observation", meta["doc_type"])
|
||||
require.Equal("sdk-123", meta["sdk_session_id"])
|
||||
require.Equal("my-project", meta["project"])
|
||||
require.Equal("global", meta["scope"])
|
||||
require.Equal("bugfix", meta["type"])
|
||||
require.Equal("Bug Fix", meta["title"])
|
||||
require.Equal("Memory leak", meta["subtitle"])
|
||||
require.Equal("memory,performance", meta["concepts"])
|
||||
require.Equal("main.go", meta["files_read"])
|
||||
require.Equal("fix.go", meta["files_modified"])
|
||||
require.Equal(int64(1234567890), meta["created_at_epoch"])
|
||||
require.Equal("narrative", meta["field_type"])
|
||||
}
|
||||
Reference in New Issue
Block a user