Initial commit

This commit is contained in:
2025-12-14 21:59:59 +00:00
commit d7c20cea54
126 changed files with 21728 additions and 0 deletions
+251
View File
@@ -0,0 +1,251 @@
// Package config provides configuration management for claude-mnemonic.
package config
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"sync"
)
const (
// DefaultWorkerPort is the default HTTP port for the worker service.
DefaultWorkerPort = 37777
// DefaultPythonVersion for ChromaDB (avoid onnxruntime issues with 3.14+).
DefaultPythonVersion = "3.13"
// DefaultModel for SDK agent (use "haiku" for cost-efficient processing).
// Claude Code CLI accepts aliases: haiku, sonnet, opus (always latest versions)
DefaultModel = "haiku"
)
// DefaultObservationTypes are the observation types to include in context.
var DefaultObservationTypes = []string{
"bugfix", "feature", "refactor", "change", "discovery", "decision",
}
// DefaultObservationConcepts are the concept tags to include in context.
var DefaultObservationConcepts = []string{
"how-it-works", "why-it-exists", "what-changed",
"problem-solution", "gotcha", "pattern", "trade-off",
}
// CriticalConcepts are concepts that indicate "must know" information.
// Observations with these concepts are prioritized in context injection.
var CriticalConcepts = []string{
"gotcha", "pattern", "problem-solution", "trade-off",
}
// Config holds the application configuration.
type Config struct {
// Worker settings
WorkerPort int `json:"worker_port"`
// Database settings
DBPath string `json:"db_path"`
MaxConns int `json:"max_conns"`
// ChromaDB settings
VectorDBPath string `json:"vector_db_path"`
PythonVersion string `json:"python_version"`
// SDK Agent settings
Model string `json:"model"`
ClaudeCodePath string `json:"claude_code_path"`
// Context injection settings
ContextObservations int `json:"context_observations"`
ContextFullCount int `json:"context_full_count"`
ContextSessionCount int `json:"context_session_count"`
ContextShowReadTokens bool `json:"context_show_read_tokens"`
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
ContextFullField string `json:"context_full_field"`
ContextShowLastSummary bool `json:"context_show_last_summary"`
ContextObsTypes []string `json:"context_obs_types"`
ContextObsConcepts []string `json:"context_obs_concepts"`
}
var (
globalConfig *Config
configOnce sync.Once
configMu sync.RWMutex
)
// DataDir returns the data directory path (~/.claude-mnemonic).
func DataDir() string {
home, _ := os.UserHomeDir()
return filepath.Join(home, ".claude-mnemonic")
}
// DBPath returns the database file path.
func DBPath() string {
return filepath.Join(DataDir(), "claude-mnemonic.db")
}
// VectorDBPath returns the vector database directory path.
func VectorDBPath() string {
return filepath.Join(DataDir(), "vector-db")
}
// SettingsPath returns the settings file path.
func SettingsPath() string {
return filepath.Join(DataDir(), "settings.json")
}
// EnsureDataDir creates the data directory if it doesn't exist.
func EnsureDataDir() error {
return os.MkdirAll(DataDir(), 0750)
}
// EnsureVectorDBDir creates the vector database directory if it doesn't exist.
func EnsureVectorDBDir() error {
return os.MkdirAll(VectorDBPath(), 0750)
}
// EnsureSettings creates a default settings file if it doesn't exist.
func EnsureSettings() error {
path := SettingsPath()
// Check if file exists
if _, err := os.Stat(path); err == nil {
return nil // File exists
}
// Create default settings file with comments
defaultSettings := `{
"CLAUDE_MNEMONIC_WORKER_PORT": 37777,
"CLAUDE_MNEMONIC_PYTHON_VERSION": "3.13",
"CLAUDE_MNEMONIC_MODEL": "haiku",
"CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS": 100,
"CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT": 25,
"CLAUDE_MNEMONIC_CONTEXT_SESSION_COUNT": 10
}
`
return os.WriteFile(path, []byte(defaultSettings), 0600)
}
// EnsureAll ensures all required directories and files exist.
func EnsureAll() error {
if err := EnsureDataDir(); err != nil {
return err
}
if err := EnsureVectorDBDir(); err != nil {
return err
}
if err := EnsureSettings(); err != nil {
return err
}
return nil
}
// Default returns a Config with default values.
func Default() *Config {
return &Config{
WorkerPort: DefaultWorkerPort,
DBPath: DBPath(),
MaxConns: 4,
VectorDBPath: VectorDBPath(),
PythonVersion: DefaultPythonVersion,
Model: DefaultModel,
ContextObservations: 100,
ContextFullCount: 25,
ContextSessionCount: 10,
ContextShowReadTokens: true,
ContextShowWorkTokens: true,
ContextFullField: "narrative",
ContextShowLastSummary: true,
ContextObsTypes: DefaultObservationTypes,
ContextObsConcepts: DefaultObservationConcepts,
}
}
// Load loads configuration from the settings file, merging with defaults.
func Load() (*Config, error) {
cfg := Default()
data, err := os.ReadFile(SettingsPath())
if err != nil {
if os.IsNotExist(err) {
return cfg, nil
}
return nil, err
}
// Load settings into a map to preserve unknown fields
var settings map[string]interface{}
if err := json.Unmarshal(data, &settings); err != nil {
return cfg, nil // Return defaults on parse error
}
// Map settings to config
if v, ok := settings["CLAUDE_MNEMONIC_WORKER_PORT"].(float64); ok {
cfg.WorkerPort = int(v)
}
if v, ok := settings["CLAUDE_MNEMONIC_PYTHON_VERSION"].(string); ok {
cfg.PythonVersion = v
}
if v, ok := settings["CLAUDE_MNEMONIC_MODEL"].(string); ok {
cfg.Model = v
}
if v, ok := settings["CLAUDE_CODE_PATH"].(string); ok {
cfg.ClaudeCodePath = v
}
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBSERVATIONS"].(float64); ok {
cfg.ContextObservations = int(v)
}
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_FULL_COUNT"].(float64); ok {
cfg.ContextFullCount = int(v)
}
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_SESSION_COUNT"].(float64); ok {
cfg.ContextSessionCount = int(v)
}
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBS_TYPES"].(string); ok && v != "" {
cfg.ContextObsTypes = splitTrim(v)
}
if v, ok := settings["CLAUDE_MNEMONIC_CONTEXT_OBS_CONCEPTS"].(string); ok && v != "" {
cfg.ContextObsConcepts = splitTrim(v)
}
return cfg, nil
}
// splitTrim splits a comma-separated string and trims whitespace.
func splitTrim(s string) []string {
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
result = append(result, p)
}
}
return result
}
// Get returns the global configuration, loading it if necessary.
func Get() *Config {
configOnce.Do(func() {
var err error
globalConfig, err = Load()
if err != nil {
globalConfig = Default()
}
})
configMu.RLock()
defer configMu.RUnlock()
return globalConfig
}
// GetWorkerPort returns the worker port from environment or config.
func GetWorkerPort() int {
if port := os.Getenv("CLAUDE_MNEMONIC_WORKER_PORT"); port != "" {
var p int
if err := json.Unmarshal([]byte(port), &p); err == nil && p > 0 {
return p
}
}
return Get().WorkerPort
}
+347
View File
@@ -0,0 +1,347 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"database/sql"
"fmt"
"time"
)
// Migration represents a database schema migration.
type Migration struct {
Version int
Name string
SQL string
}
// Migrations is the list of all database migrations in order.
var Migrations = []Migration{
{
Version: 4,
Name: "sdk_agent_architecture",
SQL: `
-- SDK Sessions (main session tracking)
CREATE TABLE IF NOT EXISTS sdk_sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
claude_session_id TEXT UNIQUE NOT NULL,
sdk_session_id TEXT UNIQUE,
project TEXT NOT NULL,
user_prompt TEXT,
started_at TEXT NOT NULL,
started_at_epoch INTEGER NOT NULL,
completed_at TEXT,
completed_at_epoch INTEGER,
status TEXT CHECK(status IN ('active', 'completed', 'failed')) NOT NULL DEFAULT 'active'
);
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_claude_id ON sdk_sessions(claude_session_id);
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_sdk_id ON sdk_sessions(sdk_session_id);
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_project ON sdk_sessions(project);
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_status ON sdk_sessions(status);
CREATE INDEX IF NOT EXISTS idx_sdk_sessions_started ON sdk_sessions(started_at_epoch DESC);
-- Observations (extracted learnings)
CREATE TABLE IF NOT EXISTS observations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
sdk_session_id TEXT NOT NULL,
project TEXT NOT NULL,
text TEXT,
type TEXT NOT NULL CHECK(type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change')),
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_observations_sdk_session ON observations(sdk_session_id);
CREATE INDEX IF NOT EXISTS idx_observations_project ON observations(project);
CREATE INDEX IF NOT EXISTS idx_observations_type ON observations(type);
CREATE INDEX IF NOT EXISTS idx_observations_created ON observations(created_at_epoch DESC);
-- Session Summaries
CREATE TABLE IF NOT EXISTS session_summaries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
sdk_session_id TEXT NOT NULL,
project TEXT NOT NULL,
request TEXT,
investigated TEXT,
learned TEXT,
completed TEXT,
next_steps TEXT,
files_read TEXT,
files_edited TEXT,
notes TEXT,
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_session_summaries_sdk_session ON session_summaries(sdk_session_id);
CREATE INDEX IF NOT EXISTS idx_session_summaries_project ON session_summaries(project);
CREATE INDEX IF NOT EXISTS idx_session_summaries_created ON session_summaries(created_at_epoch DESC);
`,
},
{
Version: 5,
Name: "worker_port_column",
SQL: `ALTER TABLE sdk_sessions ADD COLUMN worker_port INTEGER;`,
},
{
Version: 6,
Name: "prompt_tracking_columns",
SQL: `
ALTER TABLE sdk_sessions ADD COLUMN prompt_counter INTEGER DEFAULT 0;
ALTER TABLE observations ADD COLUMN prompt_number INTEGER;
ALTER TABLE session_summaries ADD COLUMN prompt_number INTEGER;
`,
},
{
Version: 8,
Name: "observation_hierarchical_fields",
SQL: `
ALTER TABLE observations ADD COLUMN title TEXT;
ALTER TABLE observations ADD COLUMN subtitle TEXT;
ALTER TABLE observations ADD COLUMN facts TEXT;
ALTER TABLE observations ADD COLUMN narrative TEXT;
ALTER TABLE observations ADD COLUMN concepts TEXT;
ALTER TABLE observations ADD COLUMN files_read TEXT;
ALTER TABLE observations ADD COLUMN files_modified TEXT;
`,
},
{
Version: 10,
Name: "user_prompts_table",
SQL: `
-- User prompts table
CREATE TABLE IF NOT EXISTS user_prompts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
claude_session_id TEXT NOT NULL,
prompt_number INTEGER NOT NULL,
prompt_text TEXT NOT NULL,
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
FOREIGN KEY(claude_session_id) REFERENCES sdk_sessions(claude_session_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_user_prompts_claude_session ON user_prompts(claude_session_id);
CREATE INDEX IF NOT EXISTS idx_user_prompts_created ON user_prompts(created_at_epoch DESC);
CREATE INDEX IF NOT EXISTS idx_user_prompts_prompt_number ON user_prompts(prompt_number);
CREATE INDEX IF NOT EXISTS idx_user_prompts_lookup ON user_prompts(claude_session_id, prompt_number);
-- FTS5 virtual table for user prompts
CREATE VIRTUAL TABLE IF NOT EXISTS user_prompts_fts USING fts5(
prompt_text,
content='user_prompts',
content_rowid='id'
);
-- Triggers for FTS5 sync
CREATE TRIGGER IF NOT EXISTS user_prompts_ai AFTER INSERT ON user_prompts BEGIN
INSERT INTO user_prompts_fts(rowid, prompt_text)
VALUES (new.id, new.prompt_text);
END;
CREATE TRIGGER IF NOT EXISTS user_prompts_ad AFTER DELETE ON user_prompts BEGIN
INSERT INTO user_prompts_fts(user_prompts_fts, rowid, prompt_text)
VALUES('delete', old.id, old.prompt_text);
END;
CREATE TRIGGER IF NOT EXISTS user_prompts_au AFTER UPDATE ON user_prompts BEGIN
INSERT INTO user_prompts_fts(user_prompts_fts, rowid, prompt_text)
VALUES('delete', old.id, old.prompt_text);
INSERT INTO user_prompts_fts(rowid, prompt_text)
VALUES (new.id, new.prompt_text);
END;
`,
},
{
Version: 11,
Name: "discovery_tokens_column",
SQL: `
ALTER TABLE observations ADD COLUMN discovery_tokens INTEGER DEFAULT 0;
ALTER TABLE session_summaries ADD COLUMN discovery_tokens INTEGER DEFAULT 0;
`,
},
{
Version: 12,
Name: "observations_fts",
SQL: `
-- FTS5 virtual table for observations
CREATE VIRTUAL TABLE IF NOT EXISTS observations_fts USING fts5(
title, subtitle, narrative,
content='observations',
content_rowid='id'
);
-- Triggers for FTS5 sync
CREATE TRIGGER IF NOT EXISTS observations_ai AFTER INSERT ON observations BEGIN
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
VALUES (new.id, new.title, new.subtitle, new.narrative);
END;
CREATE TRIGGER IF NOT EXISTS observations_ad AFTER DELETE ON observations BEGIN
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
VALUES('delete', old.id, old.title, old.subtitle, old.narrative);
END;
CREATE TRIGGER IF NOT EXISTS observations_au AFTER UPDATE ON observations BEGIN
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
VALUES('delete', old.id, old.title, old.subtitle, old.narrative);
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
VALUES (new.id, new.title, new.subtitle, new.narrative);
END;
`,
},
{
Version: 13,
Name: "session_summaries_fts",
SQL: `
-- FTS5 virtual table for session summaries
CREATE VIRTUAL TABLE IF NOT EXISTS session_summaries_fts USING fts5(
request, investigated, learned, completed, next_steps, notes,
content='session_summaries',
content_rowid='id'
);
-- Triggers for FTS5 sync
CREATE TRIGGER IF NOT EXISTS session_summaries_ai AFTER INSERT ON session_summaries BEGIN
INSERT INTO session_summaries_fts(rowid, request, investigated, learned, completed, next_steps, notes)
VALUES (new.id, new.request, new.investigated, new.learned, new.completed, new.next_steps, new.notes);
END;
CREATE TRIGGER IF NOT EXISTS session_summaries_ad AFTER DELETE ON session_summaries BEGIN
INSERT INTO session_summaries_fts(session_summaries_fts, rowid, request, investigated, learned, completed, next_steps, notes)
VALUES('delete', old.id, old.request, old.investigated, old.learned, old.completed, old.next_steps, old.notes);
END;
CREATE TRIGGER IF NOT EXISTS session_summaries_au AFTER UPDATE ON session_summaries BEGIN
INSERT INTO session_summaries_fts(session_summaries_fts, rowid, request, investigated, learned, completed, next_steps, notes)
VALUES('delete', old.id, old.request, old.investigated, old.learned, old.completed, old.next_steps, old.notes);
INSERT INTO session_summaries_fts(rowid, request, investigated, learned, completed, next_steps, notes)
VALUES (new.id, new.request, new.investigated, new.learned, new.completed, new.next_steps, new.notes);
END;
`,
},
{
Version: 14,
Name: "observation_scope_column",
SQL: `
-- Add scope column for project isolation
-- 'project' = only visible within same project (default)
-- 'global' = visible across all projects (best practices, patterns)
ALTER TABLE observations ADD COLUMN scope TEXT DEFAULT 'project' CHECK(scope IN ('project', 'global'));
-- Index for efficient scope-based queries
CREATE INDEX IF NOT EXISTS idx_observations_scope ON observations(scope);
CREATE INDEX IF NOT EXISTS idx_observations_project_scope ON observations(project, scope);
`,
},
{
Version: 15,
Name: "observation_file_mtimes",
SQL: `
-- Store file modification times at observation creation
-- JSON object: {"path": mtime_epoch_ms, ...}
-- Used to detect staleness when files change
ALTER TABLE observations ADD COLUMN file_mtimes TEXT;
`,
},
{
Version: 16,
Name: "prompt_matched_observations",
SQL: `
-- Track how many observations were found relevant for each prompt
-- Displayed in dashboard timeline
ALTER TABLE user_prompts ADD COLUMN matched_observations INTEGER DEFAULT 0;
`,
},
}
// MigrationManager handles database schema migrations.
type MigrationManager struct {
db *sql.DB
}
// NewMigrationManager creates a new migration manager.
func NewMigrationManager(db *sql.DB) *MigrationManager {
return &MigrationManager{db: db}
}
// EnsureSchemaVersionsTable creates the schema_versions table if it doesn't exist.
func (m *MigrationManager) EnsureSchemaVersionsTable() error {
_, err := m.db.Exec(`
CREATE TABLE IF NOT EXISTS schema_versions (
id INTEGER PRIMARY KEY,
version INTEGER UNIQUE NOT NULL,
applied_at TEXT NOT NULL
)
`)
return err
}
// GetAppliedVersions returns all applied migration versions.
func (m *MigrationManager) GetAppliedVersions() (map[int]bool, error) {
rows, err := m.db.Query("SELECT version FROM schema_versions ORDER BY version")
if err != nil {
return nil, err
}
defer rows.Close()
versions := make(map[int]bool)
for rows.Next() {
var version int
if err := rows.Scan(&version); err != nil {
return nil, err
}
versions[version] = true
}
return versions, rows.Err()
}
// ApplyMigration applies a single migration.
func (m *MigrationManager) ApplyMigration(migration Migration) error {
tx, err := m.db.Begin()
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}
defer tx.Rollback()
// Execute migration SQL
if _, err := tx.Exec(migration.SQL); err != nil {
return fmt.Errorf("execute migration %d (%s): %w", migration.Version, migration.Name, err)
}
// Record migration
_, err = tx.Exec(
"INSERT INTO schema_versions (version, applied_at) VALUES (?, ?)",
migration.Version, time.Now().Format(time.RFC3339),
)
if err != nil {
return fmt.Errorf("record migration %d: %w", migration.Version, err)
}
return tx.Commit()
}
// RunMigrations applies all pending migrations.
func (m *MigrationManager) RunMigrations() error {
if err := m.EnsureSchemaVersionsTable(); err != nil {
return fmt.Errorf("ensure schema_versions table: %w", err)
}
applied, err := m.GetAppliedVersions()
if err != nil {
return fmt.Errorf("get applied versions: %w", err)
}
for _, migration := range Migrations {
if applied[migration.Version] {
continue
}
if err := m.ApplyMigration(migration); err != nil {
return err
}
}
return nil
}
+513
View File
@@ -0,0 +1,513 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"strings"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// CleanupFunc is a callback for when observations are cleaned up.
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
// ObservationStore provides observation-related database operations.
type ObservationStore struct {
store *Store
cleanupFunc CleanupFunc
}
// NewObservationStore creates a new observation store.
func NewObservationStore(store *Store) *ObservationStore {
return &ObservationStore{store: store}
}
// SetCleanupFunc sets the callback for when observations are deleted during cleanup.
func (s *ObservationStore) SetCleanupFunc(fn CleanupFunc) {
s.cleanupFunc = fn
}
// StoreObservation stores a new observation.
func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, project string, obs *models.ParsedObservation, promptNumber int, discoveryTokens int64) (int64, int64, error) {
now := time.Now()
nowEpoch := now.UnixMilli()
// Ensure session exists (auto-create if missing)
if err := s.ensureSessionExists(ctx, sdkSessionID, project); err != nil {
return 0, 0, err
}
// Determine scope: use parsed scope if set, otherwise auto-determine from concepts
scope := obs.Scope
if scope == "" {
scope = models.DetermineScope(obs.Concepts)
}
factsJSON, _ := json.Marshal(obs.Facts)
conceptsJSON, _ := json.Marshal(obs.Concepts)
filesReadJSON, _ := json.Marshal(obs.FilesRead)
filesModifiedJSON, _ := json.Marshal(obs.FilesModified)
fileMtimesJSON, _ := json.Marshal(obs.FileMtimes)
const query = `
INSERT INTO observations
(sdk_session_id, project, scope, type, title, subtitle, facts, narrative, concepts,
files_read, files_modified, file_mtimes, prompt_number, discovery_tokens, created_at, created_at_epoch)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
result, err := s.store.ExecContext(ctx, query,
sdkSessionID, project, string(scope), string(obs.Type),
nullString(obs.Title), nullString(obs.Subtitle),
string(factsJSON), nullString(obs.Narrative), string(conceptsJSON),
string(filesReadJSON), string(filesModifiedJSON), string(fileMtimesJSON),
nullInt(promptNumber), discoveryTokens,
now.Format(time.RFC3339), nowEpoch,
)
if err != nil {
return 0, 0, err
}
id, _ := result.LastInsertId()
// Cleanup old observations beyond the limit for this project
if project != "" {
deletedIDs, _ := s.CleanupOldObservations(ctx, project)
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
s.cleanupFunc(ctx, deletedIDs)
}
}
return id, nowEpoch, nil
}
// ensureSessionExists creates a session if it doesn't exist.
func (s *ObservationStore) ensureSessionExists(ctx context.Context, sdkSessionID, project string) error {
const checkQuery = `SELECT id FROM sdk_sessions WHERE sdk_session_id = ?`
var id int64
err := s.store.QueryRowContext(ctx, checkQuery, sdkSessionID).Scan(&id)
if err == nil {
return nil // Session exists
}
if err != sql.ErrNoRows {
return err
}
// Auto-create session
now := time.Now()
const insertQuery = `
INSERT INTO sdk_sessions
(claude_session_id, sdk_session_id, project, started_at, started_at_epoch, status)
VALUES (?, ?, ?, ?, ?, 'active')
`
_, err = s.store.ExecContext(ctx, insertQuery,
sdkSessionID, sdkSessionID, project,
now.Format(time.RFC3339), now.UnixMilli(),
)
return err
}
// GetObservationByID retrieves an observation by ID.
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
const query = `
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
created_at, created_at_epoch
FROM observations
WHERE id = ?
`
obs, err := scanObservation(s.store.QueryRowContext(ctx, query, id))
if err == sql.ErrNoRows {
return nil, nil
}
return obs, err
}
// GetObservationsByIDs retrieves observations by a list of IDs.
func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.Observation, error) {
if len(ids) == 0 {
return nil, nil
}
// Build query with placeholders
// #nosec G202 -- query uses parameterized placeholders, not user input
query := `
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
created_at, created_at_epoch
FROM observations
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
ORDER BY created_at_epoch `
if orderBy == "date_asc" {
query += "ASC"
} else {
query += "DESC"
}
if limit > 0 {
query += " LIMIT ?"
}
// Convert []int64 to []interface{}
args := make([]interface{}, len(ids))
for i, id := range ids {
args[i] = id
}
if limit > 0 {
args = append(args, limit)
}
rows, err := s.store.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// GetRecentObservations retrieves recent observations for a project.
// This includes project-scoped observations for the specified project AND global observations.
func (s *ObservationStore) GetRecentObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
const query = `
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
created_at, created_at_epoch
FROM observations
WHERE (project = ? AND (scope IS NULL OR scope = 'project'))
OR scope = 'global'
ORDER BY created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, project, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// GetObservationCount returns the count of observations for a project (including global).
func (s *ObservationStore) GetObservationCount(ctx context.Context, project string) (int, error) {
const query = `
SELECT COUNT(*) FROM observations
WHERE project = ? OR scope = 'global'
`
var count int
err := s.store.QueryRowContext(ctx, query, project).Scan(&count)
return count, err
}
// GetAllRecentObservations retrieves recent observations across all projects.
func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit int) ([]*models.Observation, error) {
const query = `
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type, title, subtitle, facts, narrative,
concepts, files_read, files_modified, file_mtimes, prompt_number, discovery_tokens,
created_at, created_at_epoch
FROM observations
ORDER BY created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, limit)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// SearchObservationsFTS performs full-text search on observations.
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
if limit <= 0 {
limit = 10
}
// Extract keywords from the query (words > 3 chars, not common)
keywords := extractKeywords(query)
if len(keywords) == 0 {
return nil, nil
}
// Build FTS5 query: keyword1 OR keyword2 OR keyword3
ftsTerms := strings.Join(keywords, " OR ")
// Use FTS5 to search title, subtitle, and narrative
const ftsQuery = `
SELECT o.id, o.sdk_session_id, o.project, COALESCE(o.scope, 'project') as scope, o.type,
o.title, o.subtitle, o.facts, o.narrative, o.concepts, o.files_read, o.files_modified,
o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch
FROM observations o
JOIN observations_fts fts ON o.id = fts.rowid
WHERE observations_fts MATCH ?
AND (o.project = ? OR o.scope = 'global')
ORDER BY rank
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, ftsQuery, ftsTerms, project, limit)
if err != nil {
// FTS failed, try LIKE fallback
return s.searchObservationsLike(ctx, keywords, project, limit)
}
defer rows.Close()
observations, err := scanObservationRows(rows)
if err != nil {
return nil, err
}
// If FTS returned nothing, try LIKE search
if len(observations) == 0 {
return s.searchObservationsLike(ctx, keywords, project, limit)
}
return observations, nil
}
// searchObservationsLike performs fallback LIKE search on observations.
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
if len(keywords) == 0 {
return nil, nil
}
// Build LIKE conditions for each keyword
var conditions []string
var args []interface{}
for _, kw := range keywords {
pattern := "%" + kw + "%"
conditions = append(conditions, "(title LIKE ? OR subtitle LIKE ? OR narrative LIKE ?)")
args = append(args, pattern, pattern, pattern)
}
// #nosec G202 -- query uses parameterized placeholders, not user input
query := `
SELECT id, sdk_session_id, project, COALESCE(scope, 'project') as scope, type,
title, subtitle, facts, narrative, concepts, files_read, files_modified,
file_mtimes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM observations
WHERE (` + strings.Join(conditions, " OR ") + `)
AND (project = ? OR scope = 'global')
ORDER BY created_at_epoch DESC
LIMIT ?
`
args = append(args, project, limit)
rows, err := s.store.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanObservationRows(rows)
}
// extractKeywords extracts significant words from a query.
func extractKeywords(query string) []string {
// Common words to skip
stopWords := map[string]bool{
"what": true, "is": true, "the": true, "a": true, "an": true,
"how": true, "does": true, "do": true, "can": true, "could": true,
"would": true, "should": true, "where": true, "when": true, "why": true,
"which": true, "who": true, "this": true, "that": true, "these": true,
"those": true, "it": true, "its": true, "for": true, "from": true,
"with": true, "about": true, "into": true, "through": true, "during": true,
"before": true, "after": true, "above": true, "below": true, "to": true,
"of": true, "in": true, "on": true, "at": true, "by": true, "and": true,
"or": true, "but": true, "if": true, "then": true, "else": true,
"function": true, "method": true, "class": true, "file": true,
"code": true, "work": true, "works": true, "working": true,
"please": true, "help": true, "me": true, "my": true, "i": true,
"tell": true, "show": true, "explain": true, "describe": true,
}
// Split and filter
words := strings.FieldsFunc(strings.ToLower(query), func(r rune) bool {
return !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_')
})
var keywords []string
seen := make(map[string]bool)
for _, word := range words {
// Skip short words, stop words, and duplicates
if len(word) < 4 || stopWords[word] || seen[word] {
continue
}
seen[word] = true
keywords = append(keywords, word)
}
return keywords
}
// ExistsSimilarObservation checks if an observation about the same files exists for a project.
// Used to prevent duplicate observations when re-reading the same files.
func (s *ObservationStore) ExistsSimilarObservation(ctx context.Context, project string, filesRead, filesModified []string) (bool, error) {
// If no files tracked, can't deduplicate
if len(filesRead) == 0 && len(filesModified) == 0 {
return false, nil
}
// Check if any observation exists with the same primary file
// Use the first file as the key identifier
var primaryFile string
if len(filesRead) > 0 {
primaryFile = filesRead[0]
} else if len(filesModified) > 0 {
primaryFile = filesModified[0]
}
const query = `
SELECT COUNT(*) FROM observations
WHERE project = ? AND (files_read LIKE ? OR files_modified LIKE ?)
`
pattern := "%" + primaryFile + "%"
var count int
err := s.store.QueryRowContext(ctx, query, project, pattern, pattern).Scan(&count)
if err != nil {
return false, err
}
return count > 0, nil
}
// DeleteObservations deletes multiple observations by ID.
func (s *ObservationStore) DeleteObservations(ctx context.Context, ids []int64) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
query := `DELETE FROM observations WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)` // #nosec G202 -- uses parameterized placeholders
args := make([]interface{}, len(ids))
for i, id := range ids {
args[i] = id
}
result, err := s.store.db.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// MaxObservationsPerProject is the hard limit of observations per project.
const MaxObservationsPerProject = 100
// CleanupOldObservations deletes observations beyond the limit for a project.
// Keeps the most recent MaxObservationsPerProject observations per project.
// Returns the IDs of deleted observations for downstream cleanup (e.g., vector DB).
func (s *ObservationStore) CleanupOldObservations(ctx context.Context, project string) ([]int64, error) {
// First, find IDs that will be deleted
const selectQuery = `
SELECT id FROM observations
WHERE project = ? AND id NOT IN (
SELECT id FROM observations
WHERE project = ?
ORDER BY created_at_epoch DESC
LIMIT ?
)
`
rows, err := s.store.QueryContext(ctx, selectQuery, project, project, MaxObservationsPerProject)
if err != nil {
return nil, err
}
defer rows.Close()
var toDelete []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, err
}
toDelete = append(toDelete, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
if len(toDelete) == 0 {
return nil, nil
}
// Delete the observations
const deleteQuery = `
DELETE FROM observations
WHERE project = ? AND id NOT IN (
SELECT id FROM observations
WHERE project = ?
ORDER BY created_at_epoch DESC
LIMIT ?
)
`
_, err = s.store.ExecContext(ctx, deleteQuery, project, project, MaxObservationsPerProject)
if err != nil {
return nil, err
}
return toDelete, nil
}
// Helper functions
// scanObservation scans a single observation from a row scanner.
// This reduces code duplication across all observation query methods.
func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.Observation, error) {
var obs models.Observation
if err := scanner.Scan(
&obs.ID, &obs.SDKSessionID, &obs.Project, &obs.Scope, &obs.Type,
&obs.Title, &obs.Subtitle, &obs.Facts, &obs.Narrative,
&obs.Concepts, &obs.FilesRead, &obs.FilesModified, &obs.FileMtimes,
&obs.PromptNumber, &obs.DiscoveryTokens,
&obs.CreatedAt, &obs.CreatedAtEpoch,
); err != nil {
return nil, err
}
return &obs, nil
}
// scanObservationRows scans multiple observations from rows.
// Caller must close rows after calling this function.
func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
var observations []*models.Observation
for rows.Next() {
obs, err := scanObservation(rows)
if err != nil {
return nil, err
}
observations = append(observations, obs)
}
return observations, rows.Err()
}
func nullString(s string) sql.NullString {
return sql.NullString{String: s, Valid: s != ""}
}
func nullInt(i int) sql.NullInt64 {
return sql.NullInt64{Int64: int64(i), Valid: i > 0}
}
func repeatPlaceholders(n int) string {
if n <= 0 {
return ""
}
result := ""
for i := 0; i < n; i++ {
result += ", ?"
}
return result
}
+374
View File
@@ -0,0 +1,374 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"testing"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testObservationStore creates an ObservationStore with a test database including FTS5.
func testObservationStore(t *testing.T) (*ObservationStore, *Store, func()) {
t.Helper()
db, _, cleanup := testDB(t)
createAllTables(t, db)
store := newStoreFromDB(db)
obsStore := NewObservationStore(store)
return obsStore, store, cleanup
}
func TestObservationStore_StoreAndRetrieve(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Test Observation",
Subtitle: "A subtitle",
Narrative: "This is a test observation about testing",
Facts: []string{"Fact 1", "Fact 2"},
Concepts: []string{"testing", "golang"},
FilesRead: []string{"test.go"},
FilesModified: []string{},
}
id, epoch, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
require.NoError(t, err)
assert.Greater(t, id, int64(0))
assert.Greater(t, epoch, int64(0))
// Retrieve by ID
retrieved, err := obsStore.GetObservationByID(ctx, id)
require.NoError(t, err)
require.NotNil(t, retrieved)
assert.Equal(t, id, retrieved.ID)
assert.Equal(t, "session-1", retrieved.SDKSessionID)
assert.Equal(t, "project-a", retrieved.Project)
assert.Equal(t, models.ObsTypeDiscovery, retrieved.Type)
assert.Equal(t, "Test Observation", retrieved.Title.String)
assert.Equal(t, "A subtitle", retrieved.Subtitle.String)
assert.Equal(t, "This is a test observation about testing", retrieved.Narrative.String)
assert.Equal(t, []string{"Fact 1", "Fact 2"}, []string(retrieved.Facts))
assert.Equal(t, []string{"testing", "golang"}, []string(retrieved.Concepts))
}
func TestObservationStore_GetRecentObservations(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create multiple observations
for i := 0; i < 10; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Observation " + string(rune('A'+i)),
Narrative: "Content " + string(rune('A'+i)),
Concepts: []string{"test"},
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100)
require.NoError(t, err)
time.Sleep(time.Millisecond) // Ensure different timestamps
}
// Get recent with limit 5
recent, err := obsStore.GetRecentObservations(ctx, "project-a", 5)
require.NoError(t, err)
assert.Len(t, recent, 5)
// Get recent with limit 20 (more than exists)
recent, err = obsStore.GetRecentObservations(ctx, "project-a", 20)
require.NoError(t, err)
assert.Len(t, recent, 10)
}
func TestObservationStore_SearchObservationsFTS(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
// FTS5 tables are created by testObservationStore via testutil.CreateAllTables
ctx := context.Background()
// Create observations with different content
observations := []struct {
title string
narrative string
}{
{"Authentication implementation", "JWT based authentication flow"},
{"Database setup", "PostgreSQL configuration and migrations"},
{"Caching layer", "Redis caching implementation"},
{"User authentication fix", "Fixed authentication bug in login"},
{"API endpoints", "REST API implementation details"},
}
for _, o := range observations {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: o.title,
Narrative: o.narrative,
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
require.NoError(t, err)
time.Sleep(time.Millisecond)
}
// Search for authentication - should find 2 observations
results, err := obsStore.SearchObservationsFTS(ctx, "authentication", "project-a", 50)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(results), 2, "should find at least 2 authentication-related observations")
// Search for database - should find 1 observation
results, err = obsStore.SearchObservationsFTS(ctx, "database PostgreSQL", "project-a", 50)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(results), 1, "should find at least 1 database-related observation")
}
func TestObservationStore_SearchObservationsFTS_LimitRespected(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
// FTS5 tables are created by testObservationStore via testutil.CreateAllTables
ctx := context.Background()
// Create 20 observations with similar content
for i := 0; i < 20; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Testing observation " + string(rune('A'+i)),
Narrative: "This is about testing and quality assurance " + string(rune('A'+i)),
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
require.NoError(t, err)
time.Sleep(time.Millisecond)
}
// Search with limit 5
results, err := obsStore.SearchObservationsFTS(ctx, "testing quality", "project-a", 5)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 5, "should respect limit of 5")
// Search with limit 15
results, err = obsStore.SearchObservationsFTS(ctx, "testing quality", "project-a", 15)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 15, "should respect limit of 15")
// Search with limit 50 (our new default)
results, err = obsStore.SearchObservationsFTS(ctx, "testing quality", "project-a", 50)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 50, "should respect limit of 50")
assert.Equal(t, 20, len(results), "should return all 20 matching observations")
}
func TestObservationStore_GlobalScope(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create a project-scoped observation
projectObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Project specific code",
Narrative: "This is specific to project-a",
Concepts: []string{"project-specific"},
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", projectObs, 1, 100)
require.NoError(t, err)
// Create a global-scoped observation (has a globalizable concept)
globalObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Security best practice",
Narrative: "Always validate user input",
Concepts: []string{"security", "best-practice"}, // "security" is in GlobalizableConcepts
}
_, _, err = obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 1, 100)
require.NoError(t, err)
// Get recent for project-a - should see both
results, err := obsStore.GetRecentObservations(ctx, "project-a", 10)
require.NoError(t, err)
assert.Len(t, results, 2)
// Get recent for project-b - should only see global observation
results, err = obsStore.GetRecentObservations(ctx, "project-b", 10)
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, "Security best practice", results[0].Title.String)
assert.Equal(t, models.ScopeGlobal, results[0].Scope)
}
func TestObservationStore_DeleteObservations(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create observations
var ids []int64
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Observation " + string(rune('A'+i)),
}
id, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
require.NoError(t, err)
ids = append(ids, id)
}
// Verify all exist
all, err := obsStore.GetRecentObservations(ctx, "project-a", 10)
require.NoError(t, err)
assert.Len(t, all, 5)
// Delete first 3
deleted, err := obsStore.DeleteObservations(ctx, ids[:3])
require.NoError(t, err)
assert.Equal(t, int64(3), deleted)
// Verify only 2 remain
remaining, err := obsStore.GetRecentObservations(ctx, "project-a", 10)
require.NoError(t, err)
assert.Len(t, remaining, 2)
}
func TestObservationStore_GetObservationCount(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create observations for different projects
for i := 0; i < 5; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Project A observation " + string(rune('0'+i)),
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
require.NoError(t, err)
}
for i := 0; i < 3; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Project B observation " + string(rune('0'+i)),
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-b", obs, 1, 100)
require.NoError(t, err)
}
// Create a global observation
globalObs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Global observation",
Concepts: []string{"best-practice"}, // Makes it global
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", globalObs, 1, 100)
require.NoError(t, err)
// Count for project-a includes its own + global
count, err := obsStore.GetObservationCount(ctx, "project-a")
require.NoError(t, err)
assert.Equal(t, 6, count) // 5 project-a + 1 global
// Count for project-b includes its own + global
count, err = obsStore.GetObservationCount(ctx, "project-b")
require.NoError(t, err)
assert.Equal(t, 4, count) // 3 project-b + 1 global
}
func TestObservationStore_CleanupOldObservations(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Create more observations than the limit (MaxObservationsPerProject = 100)
// We'll create a smaller number and verify the logic works
for i := 0; i < 10; i++ {
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Observation " + string(rune('A'+i)),
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, i+1, 100)
require.NoError(t, err)
time.Sleep(time.Millisecond)
}
// Cleanup should return empty since we're under the limit
deletedIDs, err := obsStore.CleanupOldObservations(ctx, "project-a")
require.NoError(t, err)
assert.Empty(t, deletedIDs)
// All 10 should still exist
count, err := obsStore.GetObservationCount(ctx, "project-a")
require.NoError(t, err)
assert.Equal(t, 10, count)
}
func TestObservationStore_SetCleanupFunc(t *testing.T) {
obsStore, _, cleanup := testObservationStore(t)
defer cleanup()
ctx := context.Background()
// Track cleanup calls
var cleanupCalledWith []int64
obsStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
cleanupCalledWith = deletedIDs
})
// Store an observation (should trigger cleanup, but won't delete anything under limit)
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: "Test observation",
}
_, _, err := obsStore.StoreObservation(ctx, "session-1", "project-a", obs, 1, 100)
require.NoError(t, err)
// Cleanup func should not have been called since nothing was deleted
assert.Empty(t, cleanupCalledWith)
}
func TestExtractKeywords(t *testing.T) {
tests := []struct {
query string
expected []string
}{
{
query: "What is the authentication flow?",
expected: []string{"authentication", "flow"},
},
{
query: "How does the database connection work?",
expected: []string{"database", "connection"},
},
{
query: "JWT token validation",
expected: []string{"token", "validation"},
},
{
query: "the a an is are", // All stop words
expected: nil,
},
}
for _, tt := range tests {
t.Run(tt.query, func(t *testing.T) {
keywords := extractKeywords(tt.query)
for _, exp := range tt.expected {
assert.Contains(t, keywords, exp, "should contain keyword: "+exp)
}
})
}
}
+241
View File
@@ -0,0 +1,241 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// PromptCleanupFunc is a callback for when prompts are cleaned up.
// Receives the IDs of deleted prompts for downstream cleanup (e.g., vector DB).
type PromptCleanupFunc func(ctx context.Context, deletedIDs []int64)
// MaxPromptsGlobal is the hard limit of prompts across all projects.
const MaxPromptsGlobal = 500
// PromptStore provides user prompt-related database operations.
type PromptStore struct {
store *Store
cleanupFunc PromptCleanupFunc
}
// NewPromptStore creates a new prompt store.
func NewPromptStore(store *Store) *PromptStore {
return &PromptStore{store: store}
}
// SetCleanupFunc sets the callback for when prompts are deleted during cleanup.
func (s *PromptStore) SetCleanupFunc(fn PromptCleanupFunc) {
s.cleanupFunc = fn
}
// SaveUserPromptWithMatches saves a user prompt with matched observation count.
func (s *PromptStore) SaveUserPromptWithMatches(ctx context.Context, claudeSessionID string, promptNumber int, promptText string, matchedObservations int) (int64, error) {
now := time.Now()
const query = `
INSERT INTO user_prompts
(claude_session_id, prompt_number, prompt_text, matched_observations, created_at, created_at_epoch)
VALUES (?, ?, ?, ?, ?, ?)
`
result, err := s.store.ExecContext(ctx, query,
claudeSessionID, promptNumber, promptText, matchedObservations,
now.Format(time.RFC3339), now.UnixMilli(),
)
if err != nil {
return 0, err
}
id, _ := result.LastInsertId()
// Cleanup old prompts beyond the global limit
deletedIDs, _ := s.CleanupOldPrompts(ctx)
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
s.cleanupFunc(ctx, deletedIDs)
}
return id, nil
}
// CleanupOldPrompts deletes prompts beyond the global limit.
// Keeps the most recent MaxPromptsGlobal prompts.
// Returns the IDs of deleted prompts for downstream cleanup (e.g., vector DB).
func (s *PromptStore) CleanupOldPrompts(ctx context.Context) ([]int64, error) {
// First, find IDs that will be deleted
const selectQuery = `
SELECT id FROM user_prompts
WHERE id NOT IN (
SELECT id FROM user_prompts
ORDER BY created_at_epoch DESC
LIMIT ?
)
`
rows, err := s.store.QueryContext(ctx, selectQuery, MaxPromptsGlobal)
if err != nil {
return nil, err
}
defer rows.Close()
var toDelete []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, err
}
toDelete = append(toDelete, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
if len(toDelete) == 0 {
return nil, nil
}
// Delete the prompts
const deleteQuery = `
DELETE FROM user_prompts
WHERE id NOT IN (
SELECT id FROM user_prompts
ORDER BY created_at_epoch DESC
LIMIT ?
)
`
_, err = s.store.ExecContext(ctx, deleteQuery, MaxPromptsGlobal)
if err != nil {
return nil, err
}
return toDelete, nil
}
// GetPromptsByIDs retrieves user prompts by a list of IDs.
func (s *PromptStore) GetPromptsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.UserPromptWithSession, error) {
if len(ids) == 0 {
return nil, nil
}
// Build query with placeholders
// #nosec G202 -- query uses parameterized placeholders, not user input
query := `
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
up.created_at, up.created_at_epoch, s.project, s.sdk_session_id
FROM user_prompts up
JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
WHERE up.id IN (?` + repeatPlaceholders(len(ids)-1) + `)
ORDER BY up.created_at_epoch `
if orderBy == "date_asc" {
query += "ASC"
} else {
query += "DESC"
}
if limit > 0 {
query += " LIMIT ?"
}
// Convert []int64 to []interface{}
args := make([]interface{}, len(ids))
for i, id := range ids {
args[i] = id
}
if limit > 0 {
args = append(args, limit)
}
rows, err := s.store.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var prompts []*models.UserPromptWithSession
for rows.Next() {
var prompt models.UserPromptWithSession
if err := rows.Scan(
&prompt.ID, &prompt.ClaudeSessionID, &prompt.PromptNumber, &prompt.PromptText,
&prompt.CreatedAt, &prompt.CreatedAtEpoch, &prompt.Project, &prompt.SDKSessionID,
); err != nil {
return nil, err
}
prompts = append(prompts, &prompt)
}
return prompts, rows.Err()
}
// GetAllRecentUserPrompts retrieves recent user prompts across all sessions.
func (s *PromptStore) GetAllRecentUserPrompts(ctx context.Context, limit int) ([]*models.UserPromptWithSession, error) {
const query = `
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
COALESCE(up.matched_observations, 0) as matched_observations,
up.created_at, up.created_at_epoch,
COALESCE(s.project, '') as project,
COALESCE(s.sdk_session_id, '') as sdk_session_id
FROM user_prompts up
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
ORDER BY up.created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var prompts []*models.UserPromptWithSession
for rows.Next() {
var prompt models.UserPromptWithSession
if err := rows.Scan(
&prompt.ID, &prompt.ClaudeSessionID, &prompt.PromptNumber, &prompt.PromptText,
&prompt.MatchedObservations, &prompt.CreatedAt, &prompt.CreatedAtEpoch,
&prompt.Project, &prompt.SDKSessionID,
); err != nil {
return nil, err
}
prompts = append(prompts, &prompt)
}
return prompts, rows.Err()
}
// GetRecentUserPromptsByProject retrieves recent user prompts for a specific project.
func (s *PromptStore) GetRecentUserPromptsByProject(ctx context.Context, project string, limit int) ([]*models.UserPromptWithSession, error) {
const query = `
SELECT up.id, up.claude_session_id, up.prompt_number, up.prompt_text,
COALESCE(up.matched_observations, 0) as matched_observations,
up.created_at, up.created_at_epoch,
COALESCE(s.project, '') as project,
COALESCE(s.sdk_session_id, '') as sdk_session_id
FROM user_prompts up
LEFT JOIN sdk_sessions s ON up.claude_session_id = s.claude_session_id
WHERE s.project = ?
ORDER BY up.created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, project, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var prompts []*models.UserPromptWithSession
for rows.Next() {
var prompt models.UserPromptWithSession
if err := rows.Scan(
&prompt.ID, &prompt.ClaudeSessionID, &prompt.PromptNumber, &prompt.PromptText,
&prompt.MatchedObservations, &prompt.CreatedAt, &prompt.CreatedAtEpoch,
&prompt.Project, &prompt.SDKSessionID,
); err != nil {
return nil, err
}
prompts = append(prompts, &prompt)
}
return prompts, rows.Err()
}
+196
View File
@@ -0,0 +1,196 @@
package sqlite
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testPromptStore(t *testing.T) (*PromptStore, *Store, func()) {
t.Helper()
db, _, cleanup := testDB(t)
createAllTables(t, db)
store := newStoreFromDB(db)
promptStore := NewPromptStore(store)
return promptStore, store, cleanup
}
func TestPromptStore_SaveUserPromptWithMatches(t *testing.T) {
promptStore, store, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Create a session first
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Save a prompt
id, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Help me fix this bug", 5)
require.NoError(t, err)
assert.Greater(t, id, int64(0))
// Verify it was saved
var count int
err = storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts WHERE id = ?", id).Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count)
}
func TestPromptStore_GetAllRecentUserPrompts(t *testing.T) {
promptStore, store, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Save multiple prompts
for i := 1; i <= 5; i++ {
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt "+string(rune('A'+i-1)), i)
require.NoError(t, err)
time.Sleep(time.Millisecond) // Ensure different timestamps
}
// Get recent prompts
prompts, err := promptStore.GetAllRecentUserPrompts(ctx, 3)
require.NoError(t, err)
assert.Len(t, prompts, 3)
// Should be in descending order (most recent first)
assert.Equal(t, 5, prompts[0].PromptNumber)
}
func TestPromptStore_GetRecentUserPromptsByProject(t *testing.T) {
promptStore, store, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Create sessions for different projects
seedSession(t, storeDB(store), "claude-1", "sdk-1", "project-a")
seedSession(t, storeDB(store), "claude-2", "sdk-2", "project-b")
// Save prompts for both projects
for i := 1; i <= 3; i++ {
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Project A prompt", 0)
require.NoError(t, err)
}
for i := 1; i <= 2; i++ {
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-2", i, "Project B prompt", 0)
require.NoError(t, err)
}
// Get prompts for project-a
prompts, err := promptStore.GetRecentUserPromptsByProject(ctx, "project-a", 10)
require.NoError(t, err)
assert.Len(t, prompts, 3)
// Get prompts for project-b
prompts, err = promptStore.GetRecentUserPromptsByProject(ctx, "project-b", 10)
require.NoError(t, err)
assert.Len(t, prompts, 2)
}
func TestPromptStore_CleanupOldPrompts(t *testing.T) {
promptStore, store, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Save more prompts than the limit
// Note: MaxPromptsGlobal is 500, but we'll test with a smaller number
// by directly calling CleanupOldPrompts
for i := 1; i <= 10; i++ {
_, err := storeDB(store).Exec(`
INSERT INTO user_prompts (claude_session_id, prompt_number, prompt_text, created_at, created_at_epoch)
VALUES (?, ?, ?, datetime('now'), ?)
`, "claude-1", i, "Prompt "+string(rune('A'+i-1)), time.Now().UnixMilli()+int64(i))
require.NoError(t, err)
}
// Verify we have 10 prompts
var count int
err := storeDB(store).QueryRow("SELECT COUNT(*) FROM user_prompts").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 10, count)
// Cleanup should return empty since we're under the limit
deletedIDs, err := promptStore.CleanupOldPrompts(ctx)
require.NoError(t, err)
assert.Empty(t, deletedIDs)
}
func TestPromptStore_SetCleanupFunc(t *testing.T) {
promptStore, store, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Track cleanup calls
var cleanupCalledWith []int64
promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
cleanupCalledWith = deletedIDs
})
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Save a prompt (should trigger cleanup, but won't delete anything under limit)
_, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", 1, "Test prompt", 0)
require.NoError(t, err)
// Cleanup func should not have been called since nothing was deleted
assert.Empty(t, cleanupCalledWith)
}
func TestPromptStore_GetPromptsByIDs(t *testing.T) {
promptStore, store, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Save some prompts and collect their IDs
var ids []int64
for i := 1; i <= 5; i++ {
id, err := promptStore.SaveUserPromptWithMatches(ctx, "claude-1", i, "Prompt "+string(rune('A'+i-1)), 0)
require.NoError(t, err)
ids = append(ids, id)
time.Sleep(time.Millisecond)
}
// Get specific prompts by ID
prompts, err := promptStore.GetPromptsByIDs(ctx, ids[:3], "date_desc", 10)
require.NoError(t, err)
assert.Len(t, prompts, 3)
// Test with ascending order
prompts, err = promptStore.GetPromptsByIDs(ctx, ids, "date_asc", 2)
require.NoError(t, err)
assert.Len(t, prompts, 2)
assert.Equal(t, 1, prompts[0].PromptNumber)
}
func TestPromptStore_GetPromptsByIDs_EmptyInput(t *testing.T) {
promptStore, _, cleanup := testPromptStore(t)
defer cleanup()
ctx := context.Background()
// Empty IDs should return nil
prompts, err := promptStore.GetPromptsByIDs(ctx, []int64{}, "date_desc", 10)
require.NoError(t, err)
assert.Nil(t, prompts)
}
+184
View File
@@ -0,0 +1,184 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"database/sql"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// SessionStore provides session-related database operations.
type SessionStore struct {
store *Store
}
// NewSessionStore creates a new session store.
func NewSessionStore(store *Store) *SessionStore {
return &SessionStore{store: store}
}
// CreateSDKSession creates a new SDK session (idempotent - returns existing ID if exists).
// This is the KEY to how claude-mnemonic stays unified across hooks.
func (s *SessionStore) CreateSDKSession(ctx context.Context, claudeSessionID, project, userPrompt string) (int64, error) {
now := time.Now()
// CRITICAL: INSERT OR IGNORE makes this idempotent
const query = `
INSERT OR IGNORE INTO sdk_sessions
(claude_session_id, sdk_session_id, project, user_prompt, started_at, started_at_epoch, status)
VALUES (?, ?, ?, ?, ?, ?, 'active')
`
result, err := s.store.ExecContext(ctx, query,
claudeSessionID, claudeSessionID, project, userPrompt,
now.Format(time.RFC3339), now.UnixMilli(),
)
if err != nil {
return 0, err
}
// Check if insert happened
rowsAffected, _ := result.RowsAffected()
if rowsAffected == 0 {
// Session exists - UPDATE project and user_prompt if we have non-empty values
if project != "" {
const updateQuery = `
UPDATE sdk_sessions
SET project = ?, user_prompt = ?
WHERE claude_session_id = ?
`
_, _ = s.store.ExecContext(ctx, updateQuery, project, userPrompt, claudeSessionID)
}
// Fetch existing ID
var id int64
const selectQuery = `SELECT id FROM sdk_sessions WHERE claude_session_id = ? LIMIT 1`
err := s.store.QueryRowContext(ctx, selectQuery, claudeSessionID).Scan(&id)
return id, err
}
id, _ := result.LastInsertId()
return id, nil
}
// GetSessionByID retrieves a session by its database ID.
func (s *SessionStore) GetSessionByID(ctx context.Context, id int64) (*models.SDKSession, error) {
const query = `
SELECT id, claude_session_id, sdk_session_id, project, user_prompt,
worker_port, prompt_counter, status, started_at, started_at_epoch,
completed_at, completed_at_epoch
FROM sdk_sessions
WHERE id = ?
LIMIT 1
`
var sess models.SDKSession
err := s.store.QueryRowContext(ctx, query, id).Scan(
&sess.ID, &sess.ClaudeSessionID, &sess.SDKSessionID, &sess.Project, &sess.UserPrompt,
&sess.WorkerPort, &sess.PromptCounter, &sess.Status, &sess.StartedAt, &sess.StartedAtEpoch,
&sess.CompletedAt, &sess.CompletedAtEpoch,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &sess, nil
}
// FindAnySDKSession finds any session by Claude session ID (any status).
func (s *SessionStore) FindAnySDKSession(ctx context.Context, claudeSessionID string) (*models.SDKSession, error) {
const query = `
SELECT id, claude_session_id, sdk_session_id, project, user_prompt,
worker_port, prompt_counter, status, started_at, started_at_epoch,
completed_at, completed_at_epoch
FROM sdk_sessions
WHERE claude_session_id = ?
LIMIT 1
`
var sess models.SDKSession
err := s.store.QueryRowContext(ctx, query, claudeSessionID).Scan(
&sess.ID, &sess.ClaudeSessionID, &sess.SDKSessionID, &sess.Project, &sess.UserPrompt,
&sess.WorkerPort, &sess.PromptCounter, &sess.Status, &sess.StartedAt, &sess.StartedAtEpoch,
&sess.CompletedAt, &sess.CompletedAtEpoch,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &sess, nil
}
// IncrementPromptCounter increments the prompt counter and returns the new value.
func (s *SessionStore) IncrementPromptCounter(ctx context.Context, id int64) (int, error) {
const updateQuery = `
UPDATE sdk_sessions
SET prompt_counter = COALESCE(prompt_counter, 0) + 1
WHERE id = ?
`
if _, err := s.store.ExecContext(ctx, updateQuery, id); err != nil {
return 0, err
}
const selectQuery = `SELECT prompt_counter FROM sdk_sessions WHERE id = ?`
var counter int
err := s.store.QueryRowContext(ctx, selectQuery, id).Scan(&counter)
return counter, err
}
// GetPromptCounter returns the current prompt counter for a session.
func (s *SessionStore) GetPromptCounter(ctx context.Context, id int64) (int, error) {
const query = `SELECT COALESCE(prompt_counter, 0) FROM sdk_sessions WHERE id = ?`
var counter int
err := s.store.QueryRowContext(ctx, query, id).Scan(&counter)
return counter, err
}
// GetSessionsToday returns the count of sessions started today.
func (s *SessionStore) GetSessionsToday(ctx context.Context) (int, error) {
// Get start of today in milliseconds
now := time.Now()
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
startEpoch := startOfDay.UnixMilli()
const query = `SELECT COUNT(*) FROM sdk_sessions WHERE started_at_epoch >= ?`
var count int
err := s.store.QueryRowContext(ctx, query, startEpoch).Scan(&count)
if err != nil {
return 0, err
}
return count, nil
}
// GetAllProjects returns all unique project names.
func (s *SessionStore) GetAllProjects(ctx context.Context) ([]string, error) {
const query = `
SELECT DISTINCT project
FROM sdk_sessions
WHERE project IS NOT NULL AND project != ''
ORDER BY project ASC
`
rows, err := s.store.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var projects []string
for rows.Next() {
var project string
if err := rows.Scan(&project); err != nil {
return nil, err
}
projects = append(projects, project)
}
return projects, rows.Err()
}
+218
View File
@@ -0,0 +1,218 @@
package sqlite
import (
"context"
"testing"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testSessionStore(t *testing.T) (*SessionStore, *Store, func()) {
t.Helper()
db, _, cleanup := testDB(t)
createAllTables(t, db)
store := newStoreFromDB(db)
sessionStore := NewSessionStore(store)
return sessionStore, store, cleanup
}
func TestSessionStore_CreateSDKSession(t *testing.T) {
sessionStore, _, cleanup := testSessionStore(t)
defer cleanup()
ctx := context.Background()
// Create a new session
id, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "initial prompt")
require.NoError(t, err)
assert.Greater(t, id, int64(0))
// Retrieve and verify
sess, err := sessionStore.GetSessionByID(ctx, id)
require.NoError(t, err)
require.NotNil(t, sess)
assert.Equal(t, "claude-1", sess.ClaudeSessionID)
assert.Equal(t, "test-project", sess.Project)
assert.Equal(t, models.SessionStatusActive, sess.Status)
}
func TestSessionStore_CreateSDKSession_Idempotent(t *testing.T) {
sessionStore, _, cleanup := testSessionStore(t)
defer cleanup()
ctx := context.Background()
// Create first session
id1, err := sessionStore.CreateSDKSession(ctx, "claude-1", "project-a", "prompt 1")
require.NoError(t, err)
// Create again with same claude_session_id but different project
id2, err := sessionStore.CreateSDKSession(ctx, "claude-1", "project-b", "prompt 2")
require.NoError(t, err)
// Should return same ID (idempotent)
assert.Equal(t, id1, id2)
// Should have updated project to project-b
sess, err := sessionStore.GetSessionByID(ctx, id1)
require.NoError(t, err)
assert.Equal(t, "project-b", sess.Project)
}
func TestSessionStore_FindAnySDKSession(t *testing.T) {
sessionStore, _, cleanup := testSessionStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
require.NoError(t, err)
// Find it
sess, err := sessionStore.FindAnySDKSession(ctx, "claude-1")
require.NoError(t, err)
require.NotNil(t, sess)
assert.Equal(t, "claude-1", sess.ClaudeSessionID)
// Try to find non-existent
sess, err = sessionStore.FindAnySDKSession(ctx, "claude-nonexistent")
require.NoError(t, err)
assert.Nil(t, sess)
}
func TestSessionStore_IncrementPromptCounter(t *testing.T) {
sessionStore, _, cleanup := testSessionStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
id, err := sessionStore.CreateSDKSession(ctx, "claude-1", "test-project", "")
require.NoError(t, err)
// Initial counter should be 0
counter, err := sessionStore.GetPromptCounter(ctx, id)
require.NoError(t, err)
assert.Equal(t, 0, counter)
// Increment
counter, err = sessionStore.IncrementPromptCounter(ctx, id)
require.NoError(t, err)
assert.Equal(t, 1, counter)
// Increment again
counter, err = sessionStore.IncrementPromptCounter(ctx, id)
require.NoError(t, err)
assert.Equal(t, 2, counter)
// Verify via GetPromptCounter
counter, err = sessionStore.GetPromptCounter(ctx, id)
require.NoError(t, err)
assert.Equal(t, 2, counter)
}
func TestSessionStore_GetSessionsToday(t *testing.T) {
sessionStore, _, cleanup := testSessionStore(t)
defer cleanup()
ctx := context.Background()
// Initially no sessions today
count, err := sessionStore.GetSessionsToday(ctx)
require.NoError(t, err)
assert.Equal(t, 0, count)
// Create some sessions
_, err = sessionStore.CreateSDKSession(ctx, "claude-1", "project-a", "")
require.NoError(t, err)
_, err = sessionStore.CreateSDKSession(ctx, "claude-2", "project-b", "")
require.NoError(t, err)
_, err = sessionStore.CreateSDKSession(ctx, "claude-3", "project-c", "")
require.NoError(t, err)
// Should have 3 sessions today
count, err = sessionStore.GetSessionsToday(ctx)
require.NoError(t, err)
assert.Equal(t, 3, count)
}
func TestSessionStore_GetAllProjects(t *testing.T) {
sessionStore, _, cleanup := testSessionStore(t)
defer cleanup()
ctx := context.Background()
// Create sessions for different projects
_, err := sessionStore.CreateSDKSession(ctx, "claude-1", "alpha-project", "")
require.NoError(t, err)
_, err = sessionStore.CreateSDKSession(ctx, "claude-2", "beta-project", "")
require.NoError(t, err)
_, err = sessionStore.CreateSDKSession(ctx, "claude-3", "alpha-project", "") // duplicate
require.NoError(t, err)
_, err = sessionStore.CreateSDKSession(ctx, "claude-4", "gamma-project", "")
require.NoError(t, err)
// Get all projects
projects, err := sessionStore.GetAllProjects(ctx)
require.NoError(t, err)
assert.Len(t, projects, 3)
assert.Contains(t, projects, "alpha-project")
assert.Contains(t, projects, "beta-project")
assert.Contains(t, projects, "gamma-project")
// Should be sorted alphabetically
assert.Equal(t, "alpha-project", projects[0])
}
func TestSessionStore_GetSessionByID_NotFound(t *testing.T) {
sessionStore, _, cleanup := testSessionStore(t)
defer cleanup()
ctx := context.Background()
// Non-existent ID should return nil, nil (not an error)
sess, err := sessionStore.GetSessionByID(ctx, 999)
require.NoError(t, err)
assert.Nil(t, sess)
}
func TestSessionStore_SessionFields(t *testing.T) {
sessionStore, store, cleanup := testSessionStore(t)
defer cleanup()
ctx := context.Background()
// Create a session with full details
id, err := sessionStore.CreateSDKSession(ctx, "claude-full", "full-project", "full user prompt")
require.NoError(t, err)
// Manually update additional fields for testing
now := time.Now()
_, err = storeDB(store).Exec(`
UPDATE sdk_sessions
SET worker_port = ?, completed_at = ?, completed_at_epoch = ?, status = 'completed'
WHERE id = ?
`, 37777, now.Format(time.RFC3339), now.UnixMilli(), id)
require.NoError(t, err)
// Retrieve and verify all fields
sess, err := sessionStore.GetSessionByID(ctx, id)
require.NoError(t, err)
require.NotNil(t, sess)
assert.Equal(t, id, sess.ID)
assert.Equal(t, "claude-full", sess.ClaudeSessionID)
assert.Equal(t, "full-project", sess.Project)
assert.Equal(t, models.SessionStatusCompleted, sess.Status)
assert.True(t, sess.WorkerPort.Valid)
assert.Equal(t, int64(37777), sess.WorkerPort.Int64)
assert.True(t, sess.CompletedAt.Valid)
assert.True(t, sess.CompletedAtEpoch.Valid)
}
+134
View File
@@ -0,0 +1,134 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"database/sql"
"fmt"
"sync"
_ "github.com/mattn/go-sqlite3"
)
// Store provides database operations with connection pooling and prepared statements.
type Store struct {
db *sql.DB
stmtCache map[string]*sql.Stmt
stmtMu sync.RWMutex
}
// StoreConfig holds configuration for the database store.
type StoreConfig struct {
Path string
MaxConns int
WALMode bool
}
// NewStore creates a new database store with the given configuration.
func NewStore(cfg StoreConfig) (*Store, error) {
// Build connection string with pragmas
connStr := cfg.Path + "?_journal_mode=WAL&_synchronous=NORMAL&_foreign_keys=ON"
db, err := sql.Open("sqlite3", connStr)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
// Configure connection pool
maxConns := cfg.MaxConns
if maxConns <= 0 {
maxConns = 4
}
db.SetMaxOpenConns(maxConns)
db.SetMaxIdleConns(maxConns)
db.SetConnMaxLifetime(0) // Never expire - SQLite connections are cheap
// Verify connection
if err := db.Ping(); err != nil {
_ = db.Close()
return nil, fmt.Errorf("ping database: %w", err)
}
store := &Store{
db: db,
stmtCache: make(map[string]*sql.Stmt),
}
// Run migrations
mgr := NewMigrationManager(db)
if err := mgr.RunMigrations(); err != nil {
_ = db.Close()
return nil, fmt.Errorf("run migrations: %w", err)
}
return store, nil
}
// Close closes the database connection and all cached statements.
func (s *Store) Close() error {
s.stmtMu.Lock()
defer s.stmtMu.Unlock()
for _, stmt := range s.stmtCache {
_ = stmt.Close()
}
s.stmtCache = nil
return s.db.Close()
}
// GetStmt returns a cached prepared statement, creating it if necessary.
func (s *Store) GetStmt(query string) (*sql.Stmt, error) {
s.stmtMu.RLock()
stmt, ok := s.stmtCache[query]
s.stmtMu.RUnlock()
if ok {
return stmt, nil
}
s.stmtMu.Lock()
defer s.stmtMu.Unlock()
// Double-check after acquiring write lock
if stmt, ok := s.stmtCache[query]; ok {
return stmt, nil
}
stmt, err := s.db.Prepare(query)
if err != nil {
return nil, err
}
s.stmtCache[query] = stmt
return stmt, nil
}
// ExecContext executes a query that doesn't return rows.
func (s *Store) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
stmt, err := s.GetStmt(query)
if err != nil {
// Fall back to direct execution
return s.db.ExecContext(ctx, query, args...)
}
return stmt.ExecContext(ctx, args...)
}
// QueryContext executes a query that returns rows.
func (s *Store) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := s.GetStmt(query)
if err != nil {
// Fall back to direct execution
return s.db.QueryContext(ctx, query, args...)
}
return stmt.QueryContext(ctx, args...)
}
// QueryRowContext executes a query that returns a single row.
func (s *Store) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := s.GetStmt(query)
if err != nil {
// Fall back to direct execution
return s.db.QueryRowContext(ctx, query, args...)
}
return stmt.QueryRowContext(ctx, args...)
}
+200
View File
@@ -0,0 +1,200 @@
// Package sqlite provides SQLite database operations for claude-mnemonic.
package sqlite
import (
"context"
"database/sql"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// SummaryStore provides summary-related database operations.
type SummaryStore struct {
store *Store
}
// NewSummaryStore creates a new summary store.
func NewSummaryStore(store *Store) *SummaryStore {
return &SummaryStore{store: store}
}
// StoreSummary stores a new session summary.
func (s *SummaryStore) StoreSummary(ctx context.Context, sdkSessionID, project string, summary *models.ParsedSummary, promptNumber int, discoveryTokens int64) (int64, int64, error) {
now := time.Now()
nowEpoch := now.UnixMilli()
// Ensure session exists (auto-create if missing)
if err := s.ensureSessionExists(ctx, sdkSessionID, project); err != nil {
return 0, 0, err
}
const query = `
INSERT INTO session_summaries
(sdk_session_id, project, request, investigated, learned, completed,
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
result, err := s.store.ExecContext(ctx, query,
sdkSessionID, project,
nullString(summary.Request), nullString(summary.Investigated),
nullString(summary.Learned), nullString(summary.Completed),
nullString(summary.NextSteps), nullString(summary.Notes),
nullInt(promptNumber), discoveryTokens,
now.Format(time.RFC3339), nowEpoch,
)
if err != nil {
return 0, 0, err
}
id, _ := result.LastInsertId()
return id, nowEpoch, nil
}
// ensureSessionExists creates a session if it doesn't exist.
func (s *SummaryStore) ensureSessionExists(ctx context.Context, sdkSessionID, project string) error {
const checkQuery = `SELECT id FROM sdk_sessions WHERE sdk_session_id = ?`
var id int64
err := s.store.QueryRowContext(ctx, checkQuery, sdkSessionID).Scan(&id)
if err == nil {
return nil // Session exists
}
if err != sql.ErrNoRows {
return err
}
// Auto-create session
now := time.Now()
const insertQuery = `
INSERT INTO sdk_sessions
(claude_session_id, sdk_session_id, project, started_at, started_at_epoch, status)
VALUES (?, ?, ?, ?, ?, 'active')
`
_, err = s.store.ExecContext(ctx, insertQuery,
sdkSessionID, sdkSessionID, project,
now.Format(time.RFC3339), now.UnixMilli(),
)
return err
}
// GetSummariesByIDs retrieves summaries by a list of IDs.
func (s *SummaryStore) GetSummariesByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.SessionSummary, error) {
if len(ids) == 0 {
return nil, nil
}
// Build query with placeholders
// #nosec G202 -- query uses parameterized placeholders, not user input
query := `
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM session_summaries
WHERE id IN (?` + repeatPlaceholders(len(ids)-1) + `)
ORDER BY created_at_epoch `
if orderBy == "date_asc" {
query += "ASC"
} else {
query += "DESC"
}
if limit > 0 {
query += " LIMIT ?"
}
// Convert []int64 to []interface{}
args := make([]interface{}, len(ids))
for i, id := range ids {
args[i] = id
}
if limit > 0 {
args = append(args, limit)
}
rows, err := s.store.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var summaries []*models.SessionSummary
for rows.Next() {
var summary models.SessionSummary
if err := rows.Scan(
&summary.ID, &summary.SDKSessionID, &summary.Project,
&summary.Request, &summary.Investigated, &summary.Learned, &summary.Completed,
&summary.NextSteps, &summary.Notes, &summary.PromptNumber, &summary.DiscoveryTokens,
&summary.CreatedAt, &summary.CreatedAtEpoch,
); err != nil {
return nil, err
}
summaries = append(summaries, &summary)
}
return summaries, rows.Err()
}
// GetRecentSummaries retrieves recent summaries for a project.
func (s *SummaryStore) GetRecentSummaries(ctx context.Context, project string, limit int) ([]*models.SessionSummary, error) {
const query = `
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM session_summaries
WHERE project = ?
ORDER BY created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, project, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var summaries []*models.SessionSummary
for rows.Next() {
var summary models.SessionSummary
if err := rows.Scan(
&summary.ID, &summary.SDKSessionID, &summary.Project,
&summary.Request, &summary.Investigated, &summary.Learned, &summary.Completed,
&summary.NextSteps, &summary.Notes, &summary.PromptNumber, &summary.DiscoveryTokens,
&summary.CreatedAt, &summary.CreatedAtEpoch,
); err != nil {
return nil, err
}
summaries = append(summaries, &summary)
}
return summaries, rows.Err()
}
// GetAllRecentSummaries retrieves recent summaries across all projects.
func (s *SummaryStore) GetAllRecentSummaries(ctx context.Context, limit int) ([]*models.SessionSummary, error) {
const query = `
SELECT id, sdk_session_id, project, request, investigated, learned, completed,
next_steps, notes, prompt_number, discovery_tokens, created_at, created_at_epoch
FROM session_summaries
ORDER BY created_at_epoch DESC
LIMIT ?
`
rows, err := s.store.QueryContext(ctx, query, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var summaries []*models.SessionSummary
for rows.Next() {
var summary models.SessionSummary
if err := rows.Scan(
&summary.ID, &summary.SDKSessionID, &summary.Project,
&summary.Request, &summary.Investigated, &summary.Learned, &summary.Completed,
&summary.NextSteps, &summary.Notes, &summary.PromptNumber, &summary.DiscoveryTokens,
&summary.CreatedAt, &summary.CreatedAtEpoch,
); err != nil {
return nil, err
}
summaries = append(summaries, &summary)
}
return summaries, rows.Err()
}
+242
View File
@@ -0,0 +1,242 @@
package sqlite
import (
"context"
"testing"
"time"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testSummaryStore(t *testing.T) (*SummaryStore, *Store, func()) {
t.Helper()
db, _, cleanup := testDB(t)
createAllTables(t, db)
store := newStoreFromDB(db)
summaryStore := NewSummaryStore(store)
return summaryStore, store, cleanup
}
func TestSummaryStore_StoreSummary(t *testing.T) {
summaryStore, store, cleanup := testSummaryStore(t)
defer cleanup()
ctx := context.Background()
// Create a session first
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
summary := &models.ParsedSummary{
Request: "Add new feature",
Investigated: "Looked at existing code",
Learned: "Found the pattern to follow",
Completed: "Implemented the feature",
NextSteps: "Add tests",
Notes: "Some additional notes",
}
id, epoch, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, 1, 100)
require.NoError(t, err)
assert.Greater(t, id, int64(0))
assert.Greater(t, epoch, int64(0))
// Verify it was saved
var count int
err = storeDB(store).QueryRow("SELECT COUNT(*) FROM session_summaries WHERE id = ?", id).Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count)
}
func TestSummaryStore_StoreSummary_AutoCreateSession(t *testing.T) {
summaryStore, store, cleanup := testSummaryStore(t)
defer cleanup()
ctx := context.Background()
// Don't create session beforehand - should be auto-created
summary := &models.ParsedSummary{
Request: "Test request",
}
id, _, err := summaryStore.StoreSummary(ctx, "auto-session", "test-project", summary, 1, 0)
require.NoError(t, err)
assert.Greater(t, id, int64(0))
// Verify session was auto-created
var sessionCount int
err = storeDB(store).QueryRow("SELECT COUNT(*) FROM sdk_sessions WHERE sdk_session_id = ?", "auto-session").Scan(&sessionCount)
require.NoError(t, err)
assert.Equal(t, 1, sessionCount)
}
func TestSummaryStore_GetRecentSummaries(t *testing.T) {
summaryStore, store, cleanup := testSummaryStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Store multiple summaries
for i := 0; i < 5; i++ {
summary := &models.ParsedSummary{
Request: "Request " + string(rune('A'+i)),
}
_, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, i+1, 0)
require.NoError(t, err)
time.Sleep(time.Millisecond) // Ensure different timestamps
}
// Get recent summaries with limit
summaries, err := summaryStore.GetRecentSummaries(ctx, "test-project", 3)
require.NoError(t, err)
assert.Len(t, summaries, 3)
// Should be in descending order
assert.Equal(t, int64(5), summaries[0].PromptNumber.Int64)
}
func TestSummaryStore_GetAllRecentSummaries(t *testing.T) {
summaryStore, store, cleanup := testSummaryStore(t)
defer cleanup()
ctx := context.Background()
// Create sessions for different projects
seedSession(t, storeDB(store), "claude-1", "sdk-1", "project-a")
seedSession(t, storeDB(store), "claude-2", "sdk-2", "project-b")
// Store summaries for both projects
for i := 0; i < 3; i++ {
summary := &models.ParsedSummary{Request: "Project A request"}
_, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "project-a", summary, i+1, 0)
require.NoError(t, err)
}
for i := 0; i < 2; i++ {
summary := &models.ParsedSummary{Request: "Project B request"}
_, _, err := summaryStore.StoreSummary(ctx, "sdk-2", "project-b", summary, i+1, 0)
require.NoError(t, err)
}
// Get all summaries (should include both projects)
summaries, err := summaryStore.GetAllRecentSummaries(ctx, 10)
require.NoError(t, err)
assert.Len(t, summaries, 5)
}
func TestSummaryStore_GetSummariesByIDs(t *testing.T) {
summaryStore, store, cleanup := testSummaryStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Store summaries and collect IDs
var ids []int64
for i := 0; i < 5; i++ {
summary := &models.ParsedSummary{Request: "Request " + string(rune('A'+i))}
id, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, i+1, 0)
require.NoError(t, err)
ids = append(ids, id)
time.Sleep(time.Millisecond)
}
// Get specific summaries by ID
summaries, err := summaryStore.GetSummariesByIDs(ctx, ids[:3], "date_desc", 10)
require.NoError(t, err)
assert.Len(t, summaries, 3)
// Test with ascending order
summaries, err = summaryStore.GetSummariesByIDs(ctx, ids, "date_asc", 2)
require.NoError(t, err)
assert.Len(t, summaries, 2)
assert.Equal(t, int64(1), summaries[0].PromptNumber.Int64)
}
func TestSummaryStore_GetSummariesByIDs_EmptyInput(t *testing.T) {
summaryStore, _, cleanup := testSummaryStore(t)
defer cleanup()
ctx := context.Background()
// Empty IDs should return nil
summaries, err := summaryStore.GetSummariesByIDs(ctx, []int64{}, "date_desc", 10)
require.NoError(t, err)
assert.Nil(t, summaries)
}
func TestSummaryStore_SummaryFields(t *testing.T) {
summaryStore, store, cleanup := testSummaryStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Store a summary with all fields
summary := &models.ParsedSummary{
Request: "Add authentication",
Investigated: "Reviewed existing auth code",
Learned: "OAuth is preferred",
Completed: "Implemented OAuth flow",
NextSteps: "Add refresh token support",
Notes: "Consider rate limiting",
}
id, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, 5, 1500)
require.NoError(t, err)
// Retrieve and verify all fields
summaries, err := summaryStore.GetSummariesByIDs(ctx, []int64{id}, "date_desc", 1)
require.NoError(t, err)
require.Len(t, summaries, 1)
s := summaries[0]
assert.Equal(t, id, s.ID)
assert.Equal(t, "sdk-1", s.SDKSessionID)
assert.Equal(t, "test-project", s.Project)
assert.Equal(t, "Add authentication", s.Request.String)
assert.Equal(t, "Reviewed existing auth code", s.Investigated.String)
assert.Equal(t, "OAuth is preferred", s.Learned.String)
assert.Equal(t, "Implemented OAuth flow", s.Completed.String)
assert.Equal(t, "Add refresh token support", s.NextSteps.String)
assert.Equal(t, "Consider rate limiting", s.Notes.String)
assert.Equal(t, int64(5), s.PromptNumber.Int64)
assert.Equal(t, int64(1500), s.DiscoveryTokens)
}
func TestSummaryStore_EmptySummary(t *testing.T) {
summaryStore, store, cleanup := testSummaryStore(t)
defer cleanup()
ctx := context.Background()
// Create a session
seedSession(t, storeDB(store), "claude-1", "sdk-1", "test-project")
// Store an empty summary
summary := &models.ParsedSummary{}
id, _, err := summaryStore.StoreSummary(ctx, "sdk-1", "test-project", summary, 0, 0)
require.NoError(t, err)
assert.Greater(t, id, int64(0))
// Retrieve and verify null fields
summaries, err := summaryStore.GetSummariesByIDs(ctx, []int64{id}, "date_desc", 1)
require.NoError(t, err)
require.Len(t, summaries, 1)
s := summaries[0]
assert.False(t, s.Request.Valid || s.Request.String != "")
assert.False(t, s.Investigated.Valid || s.Investigated.String != "")
assert.False(t, s.Learned.Valid || s.Learned.String != "")
}
+315
View File
@@ -0,0 +1,315 @@
package sqlite
import (
"database/sql"
"os"
"testing"
_ "github.com/mattn/go-sqlite3"
)
// newStoreFromDB creates a Store from an existing database connection for testing.
func newStoreFromDB(db *sql.DB) *Store {
return &Store{
db: db,
stmtCache: make(map[string]*sql.Stmt),
}
}
// storeDB returns the underlying database connection from a store for testing.
func storeDB(s *Store) *sql.DB {
return s.db
}
// testDB creates a temporary SQLite database for testing.
// Returns the database, path, and a cleanup function.
func testDB(t *testing.T) (*sql.DB, string, func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "claude-mnemonic-test-*")
if err != nil {
t.Fatalf("create temp dir: %v", err)
}
dbPath := tmpDir + "/test.db"
connStr := dbPath + "?_journal_mode=WAL&_synchronous=NORMAL&_foreign_keys=ON"
db, err := sql.Open("sqlite3", connStr)
if err != nil {
_ = os.RemoveAll(tmpDir)
t.Fatalf("open database: %v", err)
}
cleanup := func() {
_ = db.Close()
_ = os.RemoveAll(tmpDir)
}
return db, dbPath, cleanup
}
// createBaseTables creates the base tables without FTS5 for unit testing.
func createBaseTables(t *testing.T, db *sql.DB) {
t.Helper()
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS schema_versions (
id INTEGER PRIMARY KEY,
version INTEGER UNIQUE NOT NULL,
applied_at TEXT NOT NULL
)
`)
if err != nil {
t.Fatalf("create schema_versions: %v", err)
}
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS sdk_sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
claude_session_id TEXT UNIQUE NOT NULL,
sdk_session_id TEXT UNIQUE,
project TEXT NOT NULL,
user_prompt TEXT,
started_at TEXT NOT NULL,
started_at_epoch INTEGER NOT NULL,
completed_at TEXT,
completed_at_epoch INTEGER,
status TEXT CHECK(status IN ('active', 'completed', 'failed')) NOT NULL DEFAULT 'active',
worker_port INTEGER,
prompt_counter INTEGER DEFAULT 0
)
`)
if err != nil {
t.Fatalf("create sdk_sessions: %v", err)
}
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS observations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
sdk_session_id TEXT NOT NULL,
project TEXT NOT NULL,
text TEXT,
type TEXT NOT NULL CHECK(type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change')),
title TEXT,
subtitle TEXT,
facts TEXT,
narrative TEXT,
concepts TEXT,
files_read TEXT,
files_modified TEXT,
file_mtimes TEXT,
scope TEXT DEFAULT 'project' CHECK(scope IN ('project', 'global')),
prompt_number INTEGER,
discovery_tokens INTEGER DEFAULT 0,
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
)
`)
if err != nil {
t.Fatalf("create observations: %v", err)
}
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS session_summaries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
sdk_session_id TEXT NOT NULL,
project TEXT NOT NULL,
request TEXT,
investigated TEXT,
learned TEXT,
completed TEXT,
next_steps TEXT,
files_read TEXT,
files_edited TEXT,
notes TEXT,
prompt_number INTEGER,
discovery_tokens INTEGER DEFAULT 0,
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
FOREIGN KEY(sdk_session_id) REFERENCES sdk_sessions(sdk_session_id) ON DELETE CASCADE
)
`)
if err != nil {
t.Fatalf("create session_summaries: %v", err)
}
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS user_prompts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
claude_session_id TEXT NOT NULL,
prompt_number INTEGER NOT NULL,
prompt_text TEXT NOT NULL,
matched_observations INTEGER DEFAULT 0,
created_at TEXT NOT NULL,
created_at_epoch INTEGER NOT NULL,
FOREIGN KEY(claude_session_id) REFERENCES sdk_sessions(claude_session_id) ON DELETE CASCADE
)
`)
if err != nil {
t.Fatalf("create user_prompts: %v", err)
}
indexes := []string{
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_claude_id ON sdk_sessions(claude_session_id)`,
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_sdk_id ON sdk_sessions(sdk_session_id)`,
`CREATE INDEX IF NOT EXISTS idx_sdk_sessions_project ON sdk_sessions(project)`,
`CREATE INDEX IF NOT EXISTS idx_observations_sdk_session ON observations(sdk_session_id)`,
`CREATE INDEX IF NOT EXISTS idx_observations_project ON observations(project)`,
`CREATE INDEX IF NOT EXISTS idx_observations_scope ON observations(scope)`,
`CREATE INDEX IF NOT EXISTS idx_observations_created ON observations(created_at_epoch DESC)`,
`CREATE INDEX IF NOT EXISTS idx_session_summaries_sdk_session ON session_summaries(sdk_session_id)`,
`CREATE INDEX IF NOT EXISTS idx_session_summaries_project ON session_summaries(project)`,
`CREATE INDEX IF NOT EXISTS idx_user_prompts_claude_session ON user_prompts(claude_session_id)`,
`CREATE INDEX IF NOT EXISTS idx_user_prompts_created ON user_prompts(created_at_epoch DESC)`,
}
for _, idx := range indexes {
if _, err := db.Exec(idx); err != nil {
t.Fatalf("create index: %v", err)
}
}
}
// seedSession creates a test session in the database.
func seedSession(t *testing.T, db *sql.DB, claudeSessionID, sdkSessionID, project string) {
t.Helper()
_, err := db.Exec(`
INSERT INTO sdk_sessions (claude_session_id, sdk_session_id, project, started_at, started_at_epoch, status)
VALUES (?, ?, ?, datetime('now'), strftime('%s', 'now') * 1000, 'active')
`, claudeSessionID, sdkSessionID, project)
if err != nil {
t.Fatalf("seed session: %v", err)
}
}
// hasFTS5 checks if FTS5 is available in the SQLite build.
func hasFTS5(db *sql.DB) bool {
_, err := db.Exec("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(content)")
if err != nil {
return false
}
_, _ = db.Exec("DROP TABLE IF EXISTS fts5_test")
return true
}
// createFTSTables creates FTS5 virtual tables and triggers for full-text search.
func createFTSTables(t *testing.T, db *sql.DB) {
t.Helper()
if !hasFTS5(db) {
t.Skip("FTS5 not available in this SQLite build")
}
_, err := db.Exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS observations_fts USING fts5(
title, subtitle, narrative,
content='observations',
content_rowid='id'
)
`)
if err != nil {
t.Fatalf("create observations_fts: %v", err)
}
_, err = db.Exec(`
CREATE TRIGGER IF NOT EXISTS observations_ai AFTER INSERT ON observations BEGIN
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
VALUES (new.id, new.title, new.subtitle, new.narrative);
END
`)
if err != nil {
t.Fatalf("create observations_ai trigger: %v", err)
}
_, err = db.Exec(`
CREATE TRIGGER IF NOT EXISTS observations_ad AFTER DELETE ON observations BEGIN
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
VALUES ('delete', old.id, old.title, old.subtitle, old.narrative);
END
`)
if err != nil {
t.Fatalf("create observations_ad trigger: %v", err)
}
_, err = db.Exec(`
CREATE TRIGGER IF NOT EXISTS observations_au AFTER UPDATE ON observations BEGIN
INSERT INTO observations_fts(observations_fts, rowid, title, subtitle, narrative)
VALUES ('delete', old.id, old.title, old.subtitle, old.narrative);
INSERT INTO observations_fts(rowid, title, subtitle, narrative)
VALUES (new.id, new.title, new.subtitle, new.narrative);
END
`)
if err != nil {
t.Fatalf("create observations_au trigger: %v", err)
}
_, err = db.Exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS session_summaries_fts USING fts5(
request, investigated, learned, completed, next_steps, notes,
content='session_summaries',
content_rowid='id'
)
`)
if err != nil {
t.Fatalf("create session_summaries_fts: %v", err)
}
_, err = db.Exec(`
CREATE TRIGGER IF NOT EXISTS summaries_ai AFTER INSERT ON session_summaries BEGIN
INSERT INTO session_summaries_fts(rowid, request, investigated, learned, completed, next_steps, notes)
VALUES (new.id, new.request, new.investigated, new.learned, new.completed, new.next_steps, new.notes);
END
`)
if err != nil {
t.Fatalf("create summaries_ai trigger: %v", err)
}
_, err = db.Exec(`
CREATE TRIGGER IF NOT EXISTS summaries_ad AFTER DELETE ON session_summaries BEGIN
INSERT INTO session_summaries_fts(session_summaries_fts, rowid, request, investigated, learned, completed, next_steps, notes)
VALUES ('delete', old.id, old.request, old.investigated, old.learned, old.completed, old.next_steps, old.notes);
END
`)
if err != nil {
t.Fatalf("create summaries_ad trigger: %v", err)
}
_, err = db.Exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS user_prompts_fts USING fts5(
prompt_text,
content='user_prompts',
content_rowid='id'
)
`)
if err != nil {
t.Fatalf("create user_prompts_fts: %v", err)
}
_, err = db.Exec(`
CREATE TRIGGER IF NOT EXISTS prompts_ai AFTER INSERT ON user_prompts BEGIN
INSERT INTO user_prompts_fts(rowid, prompt_text)
VALUES (new.id, new.prompt_text);
END
`)
if err != nil {
t.Fatalf("create prompts_ai trigger: %v", err)
}
_, err = db.Exec(`
CREATE TRIGGER IF NOT EXISTS prompts_ad AFTER DELETE ON user_prompts BEGIN
INSERT INTO user_prompts_fts(user_prompts_fts, rowid, prompt_text)
VALUES ('delete', old.id, old.prompt_text);
END
`)
if err != nil {
t.Fatalf("create prompts_ad trigger: %v", err)
}
}
// createAllTables creates all tables including FTS5 for comprehensive testing.
func createAllTables(t *testing.T, db *sql.DB) {
t.Helper()
createBaseTables(t, db)
createFTSTables(t, db)
}
+562
View File
@@ -0,0 +1,562 @@
// Package mcp provides the MCP (Model Context Protocol) server for claude-mnemonic.
package mcp
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"github.com/lukaszraczylo/claude-mnemonic/internal/search"
"github.com/rs/zerolog/log"
)
// Server is the MCP server that exposes search tools.
type Server struct {
searchMgr *search.Manager
version string
stdin io.Reader
stdout io.Writer
}
// NewServer creates a new MCP server.
func NewServer(searchMgr *search.Manager, version string) *Server {
return &Server{
searchMgr: searchMgr,
version: version,
stdin: os.Stdin,
stdout: os.Stdout,
}
}
// Request represents a JSON-RPC request.
type Request struct {
JSONRPC string `json:"jsonrpc"`
ID any `json:"id"`
Method string `json:"method"`
Params json.RawMessage `json:"params,omitempty"`
}
// Response represents a JSON-RPC response.
type Response struct {
JSONRPC string `json:"jsonrpc"`
ID any `json:"id"`
Result any `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
}
// Error represents a JSON-RPC error.
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data,omitempty"`
}
// ToolCallParams represents parameters for tools/call method.
type ToolCallParams struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
}
// Tool represents an MCP tool definition.
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema map[string]any `json:"inputSchema"`
}
// Run starts the MCP server loop.
func (s *Server) Run(ctx context.Context) error {
scanner := bufio.NewScanner(s.stdin)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
continue
}
var req Request
if err := json.Unmarshal([]byte(line), &req); err != nil {
s.sendError(nil, -32700, "Parse error", err)
continue
}
resp := s.handleRequest(ctx, &req)
s.sendResponse(resp)
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scanner error: %w", err)
}
return nil
}
// handleRequest dispatches the request to the appropriate handler.
func (s *Server) handleRequest(ctx context.Context, req *Request) *Response {
switch req.Method {
case "initialize":
return s.handleInitialize(req)
case "tools/list":
return s.handleToolsList(req)
case "tools/call":
return s.handleToolsCall(ctx, req)
default:
return &Response{
JSONRPC: "2.0",
ID: req.ID,
Error: &Error{
Code: -32601,
Message: "Method not found",
},
}
}
}
// handleInitialize handles the initialize request.
func (s *Server) handleInitialize(req *Request) *Response {
return &Response{
JSONRPC: "2.0",
ID: req.ID,
Result: map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{
"tools": map[string]any{},
},
"serverInfo": map[string]any{
"name": "claude-mnemonic",
"version": s.version,
},
},
}
}
// handleToolsList returns the list of available tools.
func (s *Server) handleToolsList(req *Request) *Response {
tools := []Tool{
{
Name: "search",
Description: "Unified search across all memory types (observations, sessions, and user prompts) using vector-first semantic search (ChromaDB).",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{"type": "string", "description": "Natural language search query for semantic ranking"},
"type": map[string]any{"type": "string", "enum": []string{"observations", "sessions", "prompts"}, "description": "Filter by document type"},
"project": map[string]any{"type": "string", "description": "Filter by project name"},
"obs_type": map[string]any{"type": "string", "description": "Filter observations by type"},
"concepts": map[string]any{"type": "string", "description": "Filter by concept tags"},
"files": map[string]any{"type": "string", "description": "Filter by file paths"},
"dateStart": map[string]any{"type": []string{"string", "number"}, "description": "Start date for filtering"},
"dateEnd": map[string]any{"type": []string{"string", "number"}, "description": "End date for filtering"},
"orderBy": map[string]any{"type": "string", "enum": []string{"relevance", "date_desc", "date_asc"}, "default": "date_desc"},
"limit": map[string]any{"type": "number", "default": 20, "minimum": 1, "maximum": 100},
"offset": map[string]any{"type": "number", "default": 0, "minimum": 0},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "timeline",
Description: "Fetch timeline of observations around a specific point in time.",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"anchor_id": map[string]any{"type": "number", "description": "Observation ID to use as anchor"},
"query": map[string]any{"type": "string", "description": "Natural language query to find anchor observation"},
"before": map[string]any{"type": "number", "default": 10, "minimum": 0, "maximum": 100},
"after": map[string]any{"type": "number", "default": 10, "minimum": 0, "maximum": 100},
"project": map[string]any{"type": "string"},
"concepts": map[string]any{"type": "string"},
"files": map[string]any{"type": "string"},
"obs_type": map[string]any{"type": "string"},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "decisions",
Description: "Semantic shortcut for finding architectural, design, and implementation decisions.",
InputSchema: map[string]any{
"type": "object",
"required": []string{"query"},
"properties": map[string]any{
"query": map[string]any{"type": "string", "description": "Natural language query for finding decisions"},
"dateStart": map[string]any{"type": []string{"string", "number"}},
"dateEnd": map[string]any{"type": []string{"string", "number"}},
"limit": map[string]any{"type": "number", "default": 20, "minimum": 1, "maximum": 100},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "changes",
Description: "Semantic shortcut for finding code changes, refactorings, and modifications.",
InputSchema: map[string]any{
"type": "object",
"required": []string{"query"},
"properties": map[string]any{
"query": map[string]any{"type": "string", "description": "Natural language query for finding changes"},
"dateStart": map[string]any{"type": []string{"string", "number"}},
"dateEnd": map[string]any{"type": []string{"string", "number"}},
"limit": map[string]any{"type": "number", "default": 20, "minimum": 1, "maximum": 100},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "how_it_works",
Description: "Semantic shortcut for understanding system architecture, design patterns, and implementation details.",
InputSchema: map[string]any{
"type": "object",
"required": []string{"query"},
"properties": map[string]any{
"query": map[string]any{"type": "string", "description": "Natural language query for understanding how something works"},
"dateStart": map[string]any{"type": []string{"string", "number"}},
"dateEnd": map[string]any{"type": []string{"string", "number"}},
"limit": map[string]any{"type": "number", "default": 20, "minimum": 1, "maximum": 100},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "find_by_concept",
Description: "Find observations tagged with specific concepts.",
InputSchema: map[string]any{
"type": "object",
"required": []string{"concepts"},
"properties": map[string]any{
"concepts": map[string]any{"type": "string", "description": "Concept tag(s) to filter by"},
"type": map[string]any{"type": "string"},
"files": map[string]any{"type": "string"},
"project": map[string]any{"type": "string"},
"dateStart": map[string]any{"type": []string{"string", "number"}},
"dateEnd": map[string]any{"type": []string{"string", "number"}},
"orderBy": map[string]any{"type": "string", "enum": []string{"date_desc", "date_asc"}, "default": "date_desc"},
"limit": map[string]any{"type": "number", "default": 20},
"offset": map[string]any{"type": "number", "default": 0},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "find_by_file",
Description: "Find observations related to specific file paths.",
InputSchema: map[string]any{
"type": "object",
"required": []string{"files"},
"properties": map[string]any{
"files": map[string]any{"type": "string", "description": "File path(s) to filter by"},
"type": map[string]any{"type": "string"},
"concepts": map[string]any{"type": "string"},
"project": map[string]any{"type": "string"},
"dateStart": map[string]any{"type": []string{"string", "number"}},
"dateEnd": map[string]any{"type": []string{"string", "number"}},
"orderBy": map[string]any{"type": "string", "enum": []string{"date_desc", "date_asc"}, "default": "date_desc"},
"limit": map[string]any{"type": "number", "default": 20},
"offset": map[string]any{"type": "number", "default": 0},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "find_by_type",
Description: "Find observations of specific types.",
InputSchema: map[string]any{
"type": "object",
"required": []string{"type"},
"properties": map[string]any{
"type": map[string]any{"type": "string", "description": "Observation type(s) to filter by"},
"concepts": map[string]any{"type": "string"},
"files": map[string]any{"type": "string"},
"project": map[string]any{"type": "string"},
"dateStart": map[string]any{"type": []string{"string", "number"}},
"dateEnd": map[string]any{"type": []string{"string", "number"}},
"orderBy": map[string]any{"type": "string", "enum": []string{"date_desc", "date_asc"}, "default": "date_desc"},
"limit": map[string]any{"type": "number", "default": 20},
"offset": map[string]any{"type": "number", "default": 0},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "get_recent_context",
Description: "Get recent session context for timeline display.",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"project": map[string]any{"type": "string"},
"type": map[string]any{"type": "string"},
"concepts": map[string]any{"type": "string"},
"files": map[string]any{"type": "string"},
"dateStart": map[string]any{"type": []string{"string", "number"}},
"dateEnd": map[string]any{"type": []string{"string", "number"}},
"limit": map[string]any{"type": "number", "default": 30, "minimum": 1, "maximum": 100},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "get_context_timeline",
Description: "Get timeline of observations around a specific observation ID.",
InputSchema: map[string]any{
"type": "object",
"required": []string{"anchor_id"},
"properties": map[string]any{
"anchor_id": map[string]any{"type": "number", "description": "Observation ID to use as anchor point"},
"before": map[string]any{"type": "number", "default": 10, "minimum": 0, "maximum": 100},
"after": map[string]any{"type": "number", "default": 10, "minimum": 0, "maximum": 100},
"project": map[string]any{"type": "string"},
"type": map[string]any{"type": "string"},
"concepts": map[string]any{"type": "string"},
"files": map[string]any{"type": "string"},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
{
Name: "get_timeline_by_query",
Description: "Combined search + timeline tool. First searches for observations matching the query, then returns timeline around the best match.",
InputSchema: map[string]any{
"type": "object",
"required": []string{"query"},
"properties": map[string]any{
"query": map[string]any{"type": "string", "description": "Natural language query to find anchor observation"},
"before": map[string]any{"type": "number", "default": 10, "minimum": 0, "maximum": 100},
"after": map[string]any{"type": "number", "default": 10, "minimum": 0, "maximum": 100},
"project": map[string]any{"type": "string"},
"type": map[string]any{"type": "string"},
"concepts": map[string]any{"type": "string"},
"files": map[string]any{"type": "string"},
"dateStart": map[string]any{"type": []string{"string", "number"}},
"dateEnd": map[string]any{"type": []string{"string", "number"}},
"format": map[string]any{"type": "string", "enum": []string{"index", "full"}, "default": "index"},
},
},
},
}
return &Response{
JSONRPC: "2.0",
ID: req.ID,
Result: map[string]any{
"tools": tools,
},
}
}
// handleToolsCall handles tool invocations.
func (s *Server) handleToolsCall(ctx context.Context, req *Request) *Response {
var params ToolCallParams
if err := json.Unmarshal(req.Params, &params); err != nil {
return &Response{
JSONRPC: "2.0",
ID: req.ID,
Error: &Error{
Code: -32602,
Message: "Invalid params",
Data: err.Error(),
},
}
}
result, err := s.callTool(ctx, params.Name, params.Arguments)
if err != nil {
return &Response{
JSONRPC: "2.0",
ID: req.ID,
Error: &Error{
Code: -32000,
Message: "Tool error",
Data: err.Error(),
},
}
}
return &Response{
JSONRPC: "2.0",
ID: req.ID,
Result: map[string]any{
"content": []map[string]any{
{
"type": "text",
"text": result,
},
},
},
}
}
// callTool dispatches to the appropriate tool handler.
func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage) (string, error) {
var params search.SearchParams
if err := json.Unmarshal(args, &params); err != nil {
return "", fmt.Errorf("invalid arguments: %w", err)
}
var result *search.UnifiedSearchResult
var err error
switch name {
case "search":
result, err = s.searchMgr.UnifiedSearch(ctx, params)
case "timeline":
result, err = s.handleTimeline(ctx, args)
case "decisions":
result, err = s.searchMgr.Decisions(ctx, params)
case "changes":
result, err = s.searchMgr.Changes(ctx, params)
case "how_it_works":
result, err = s.searchMgr.HowItWorks(ctx, params)
case "find_by_concept":
params.Type = "observations"
result, err = s.searchMgr.UnifiedSearch(ctx, params)
case "find_by_file":
params.Type = "observations"
result, err = s.searchMgr.UnifiedSearch(ctx, params)
case "find_by_type":
params.Type = "observations"
result, err = s.searchMgr.UnifiedSearch(ctx, params)
case "get_recent_context":
result, err = s.searchMgr.UnifiedSearch(ctx, params)
case "get_context_timeline":
result, err = s.handleTimeline(ctx, args)
case "get_timeline_by_query":
result, err = s.handleTimelineByQuery(ctx, args)
default:
return "", fmt.Errorf("unknown tool: %s", name)
}
if err != nil {
return "", err
}
output, err := json.Marshal(result)
if err != nil {
return "", fmt.Errorf("marshal result: %w", err)
}
return string(output), nil
}
// TimelineParams represents parameters for timeline operations.
type TimelineParams struct {
AnchorID int64 `json:"anchor_id"`
Query string `json:"query"`
Before int `json:"before"`
After int `json:"after"`
Project string `json:"project"`
ObsType string `json:"obs_type"`
Concepts string `json:"concepts"`
Files string `json:"files"`
DateStart int64 `json:"dateStart"`
DateEnd int64 `json:"dateEnd"`
Format string `json:"format"`
}
// handleTimeline handles timeline requests.
func (s *Server) handleTimeline(ctx context.Context, args json.RawMessage) (*search.UnifiedSearchResult, error) {
var params TimelineParams
if err := json.Unmarshal(args, &params); err != nil {
return nil, fmt.Errorf("invalid timeline params: %w", err)
}
if params.Before <= 0 {
params.Before = 10
}
if params.After <= 0 {
params.After = 10
}
// If query provided, first find anchor
if params.Query != "" && params.AnchorID == 0 {
searchParams := search.SearchParams{
Query: params.Query,
Type: "observations",
Project: params.Project,
Limit: 1,
}
result, err := s.searchMgr.UnifiedSearch(ctx, searchParams)
if err != nil {
return nil, err
}
if len(result.Results) > 0 {
params.AnchorID = result.Results[0].ID
}
}
if params.AnchorID == 0 {
return &search.UnifiedSearchResult{Results: []search.SearchResult{}}, nil
}
// Fetch observations around anchor
searchParams := search.SearchParams{
Type: "observations",
Project: params.Project,
ObsType: params.ObsType,
Concepts: params.Concepts,
Files: params.Files,
Limit: params.Before + params.After + 1,
Format: params.Format,
}
return s.searchMgr.UnifiedSearch(ctx, searchParams)
}
// handleTimelineByQuery handles combined search + timeline requests.
func (s *Server) handleTimelineByQuery(ctx context.Context, args json.RawMessage) (*search.UnifiedSearchResult, error) {
var params TimelineParams
if err := json.Unmarshal(args, &params); err != nil {
return nil, fmt.Errorf("invalid timeline params: %w", err)
}
if params.Query == "" {
return nil, fmt.Errorf("query is required")
}
// First search
searchParams := search.SearchParams{
Query: params.Query,
Type: "observations",
Project: params.Project,
DateStart: params.DateStart,
DateEnd: params.DateEnd,
Limit: 1,
}
result, err := s.searchMgr.UnifiedSearch(ctx, searchParams)
if err != nil {
return nil, err
}
if len(result.Results) == 0 {
return result, nil
}
// Now get timeline around that result
params.AnchorID = result.Results[0].ID
return s.handleTimeline(ctx, args)
}
// sendResponse sends a JSON-RPC response.
func (s *Server) sendResponse(resp *Response) {
data, err := json.Marshal(resp)
if err != nil {
log.Error().Err(err).Msg("Failed to marshal response")
return
}
fmt.Fprintln(s.stdout, string(data))
}
// sendError sends a JSON-RPC error response.
func (s *Server) sendError(id any, code int, message string, data any) {
resp := &Response{
JSONRPC: "2.0",
ID: id,
Error: &Error{
Code: code,
Message: message,
Data: data,
},
}
s.sendResponse(resp)
}
+47
View File
@@ -0,0 +1,47 @@
// Package privacy provides privacy tag handling for claude-mnemonic.
package privacy
import (
"regexp"
"strings"
)
var (
// privateTagRegex matches <private>...</private> tags
privateTagRegex = regexp.MustCompile(`(?s)<private>.*?</private>`)
// memoryTagRegex matches <claude-mnemonic-context>...</claude-mnemonic-context> tags
memoryTagRegex = regexp.MustCompile(`(?s)<claude-mnemonic-context>.*?</claude-mnemonic-context>`)
)
// StripPrivateTags removes all <private>...</private> content from text.
func StripPrivateTags(text string) string {
return privateTagRegex.ReplaceAllString(text, "")
}
// StripMemoryTags removes all <claude-mnemonic-context>...</claude-mnemonic-context> content from text.
func StripMemoryTags(text string) string {
return memoryTagRegex.ReplaceAllString(text, "")
}
// StripAllTags removes both private and memory context tags.
func StripAllTags(text string) string {
text = StripPrivateTags(text)
text = StripMemoryTags(text)
return text
}
// IsEntirelyPrivate checks if the text is entirely within <private> tags.
func IsEntirelyPrivate(text string) bool {
stripped := StripPrivateTags(text)
return strings.TrimSpace(stripped) == ""
}
// Clean performs full privacy cleaning on text.
// This is the main function to use before storing any user content.
func Clean(text string) string {
// Strip both types of tags
text = StripAllTags(text)
// Trim whitespace
return strings.TrimSpace(text)
}
+273
View File
@@ -0,0 +1,273 @@
package privacy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestStripPrivateTags(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "no tags",
input: "Hello world",
expected: "Hello world",
},
{
name: "single private tag",
input: "Hello <private>secret</private> world",
expected: "Hello world",
},
{
name: "multiple private tags",
input: "Hello <private>secret1</private> and <private>secret2</private> world",
expected: "Hello and world",
},
{
name: "nested content in private tag",
input: "Hello <private>secret with\nnewline</private> world",
expected: "Hello world",
},
{
name: "multiline private tag",
input: "Hello <private>\nmultiline\nsecret\n</private> world",
expected: "Hello world",
},
{
name: "empty private tag",
input: "Hello <private></private> world",
expected: "Hello world",
},
{
name: "entirely private",
input: "<private>everything is secret</private>",
expected: "",
},
{
name: "unmatched opening tag",
input: "Hello <private>unclosed",
expected: "Hello <private>unclosed",
},
{
name: "unmatched closing tag",
input: "Hello </private> world",
expected: "Hello </private> world",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := StripPrivateTags(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestStripMemoryTags(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "no tags",
input: "Hello world",
expected: "Hello world",
},
{
name: "single memory tag",
input: "Hello <claude-mnemonic-context>memory</claude-mnemonic-context> world",
expected: "Hello world",
},
{
name: "multiline memory tag",
input: "Hello <claude-mnemonic-context>\nmemory\ncontent\n</claude-mnemonic-context> world",
expected: "Hello world",
},
{
name: "entirely memory context",
input: "<claude-mnemonic-context>all memory</claude-mnemonic-context>",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := StripMemoryTags(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestStripAllTags(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "no tags",
input: "Hello world",
expected: "Hello world",
},
{
name: "both tag types",
input: "Hello <private>secret</private> and <claude-mnemonic-context>memory</claude-mnemonic-context> world",
expected: "Hello and world",
},
{
name: "interleaved tags",
input: "A <private>B</private> C <claude-mnemonic-context>D</claude-mnemonic-context> E",
expected: "A C E",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := StripAllTags(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestIsEntirelyPrivate(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "not private",
input: "Hello world",
expected: false,
},
{
name: "entirely private",
input: "<private>secret</private>",
expected: true,
},
{
name: "entirely private with whitespace",
input: " <private>secret</private> ",
expected: true,
},
{
name: "partially private",
input: "Hello <private>secret</private>",
expected: false,
},
{
name: "multiple private tags covering everything",
input: "<private>a</private><private>b</private>",
expected: true,
},
{
name: "empty string",
input: "",
expected: true, // Empty after stripping means nothing remains
},
{
name: "only whitespace",
input: " ",
expected: true, // Whitespace-only after stripping is empty
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsEntirelyPrivate(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestClean(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "no tags or whitespace",
input: "Hello world",
expected: "Hello world",
},
{
name: "strips private tags and trims",
input: " Hello <private>secret</private> world ",
expected: "Hello world",
},
{
name: "strips memory tags and trims",
input: " Hello <claude-mnemonic-context>memory</claude-mnemonic-context> world ",
expected: "Hello world",
},
{
name: "strips both tag types and trims",
input: "\n Hello <private>secret</private> and <claude-mnemonic-context>memory</claude-mnemonic-context> world \n",
expected: "Hello and world",
},
{
name: "entirely stripped content",
input: " <private>secret</private> ",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := Clean(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// Edge cases and security-related tests
func TestPrivacyEdgeCases(t *testing.T) {
t.Run("nested tags are handled correctly", func(t *testing.T) {
// Inner tag should be stripped as part of outer content
input := "<private>outer <private>inner</private> outer</private>"
result := StripPrivateTags(input)
// The regex is non-greedy, so it matches the first closing tag
assert.Equal(t, " outer</private>", result)
})
t.Run("html-like content is not confused with tags", func(t *testing.T) {
input := "Hello <div>world</div>"
result := StripPrivateTags(input)
assert.Equal(t, "Hello <div>world</div>", result)
})
t.Run("case sensitive tags", func(t *testing.T) {
input := "Hello <PRIVATE>secret</PRIVATE> world"
result := StripPrivateTags(input)
// Should not strip uppercase tags
assert.Equal(t, "Hello <PRIVATE>secret</PRIVATE> world", result)
})
t.Run("special characters in private content", func(t *testing.T) {
input := "Hello <private>secret$%^&*()</private> world"
result := StripPrivateTags(input)
assert.Equal(t, "Hello world", result)
})
t.Run("unicode content", func(t *testing.T) {
input := "Hello <private>秘密 🔒</private> world"
result := StripPrivateTags(input)
assert.Equal(t, "Hello world", result)
})
t.Run("very long private content", func(t *testing.T) {
longSecret := ""
for i := 0; i < 10000; i++ {
longSecret += "x"
}
input := "Hello <private>" + longSecret + "</private> world"
result := StripPrivateTags(input)
assert.Equal(t, "Hello world", result)
})
}
+327
View File
@@ -0,0 +1,327 @@
// Package search provides unified search capabilities for claude-mnemonic.
package search
import (
"context"
"strings"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/chroma"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
// Manager provides unified search across SQLite and ChromaDB.
type Manager struct {
observationStore *sqlite.ObservationStore
summaryStore *sqlite.SummaryStore
promptStore *sqlite.PromptStore
chromaClient *chroma.Client
}
// NewManager creates a new search manager.
func NewManager(
observationStore *sqlite.ObservationStore,
summaryStore *sqlite.SummaryStore,
promptStore *sqlite.PromptStore,
chromaClient *chroma.Client,
) *Manager {
return &Manager{
observationStore: observationStore,
summaryStore: summaryStore,
promptStore: promptStore,
chromaClient: chromaClient,
}
}
// SearchParams contains parameters for unified search.
type SearchParams struct {
Query string
Type string // "observations", "sessions", "prompts", or empty for all
Project string
ObsType string // Observation type filter
Concepts string
Files string
DateStart int64
DateEnd int64
OrderBy string // "relevance", "date_desc", "date_asc"
Limit int
Offset int
Format string // "index" or "full"
Scope string // "project", "global", or empty for project+global
IncludeGlobal bool // If true, include global observations along with project-scoped
}
// SearchResult represents a unified search result.
type SearchResult struct {
Type string `json:"type"` // "observation", "session", "prompt"
ID int64 `json:"id"`
Title string `json:"title,omitempty"`
Content string `json:"content,omitempty"`
Project string `json:"project"`
Scope string `json:"scope,omitempty"` // "project" or "global"
CreatedAt int64 `json:"created_at_epoch"`
Score float64 `json:"score,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// UnifiedSearchResult contains the combined search results.
type UnifiedSearchResult struct {
Results []SearchResult `json:"results"`
TotalCount int `json:"total_count"`
Query string `json:"query,omitempty"`
}
// UnifiedSearch performs a unified search across all document types.
func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
if params.Limit <= 0 {
params.Limit = 20
}
if params.Limit > 100 {
params.Limit = 100
}
if params.OrderBy == "" {
params.OrderBy = "date_desc"
}
// If query is provided and Chroma is available, use vector search
if params.Query != "" && m.chromaClient != nil {
return m.vectorSearch(ctx, params)
}
// Otherwise fall back to structured filter search
return m.filterSearch(ctx, params)
}
// vectorSearch performs semantic search via ChromaDB.
func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
// Build where filter
where := make(map[string]interface{})
if params.Project != "" {
where["project"] = params.Project
}
if params.Type == "observations" {
where["doc_type"] = "observation"
} else if params.Type == "sessions" {
where["doc_type"] = "session_summary"
} else if params.Type == "prompts" {
where["doc_type"] = "user_prompt"
}
// Query ChromaDB
chromaResults, err := m.chromaClient.Query(ctx, params.Query, params.Limit*2, where)
if err != nil {
// Fall back to filter search on error
return m.filterSearch(ctx, params)
}
// Collect unique IDs by type
obsIDs := make([]int64, 0)
summaryIDs := make([]int64, 0)
promptIDs := make([]int64, 0)
seenObs := make(map[int64]bool)
seenSummary := make(map[int64]bool)
seenPrompt := make(map[int64]bool)
for _, result := range chromaResults {
sqliteID, ok := result.Metadata["sqlite_id"].(float64)
if !ok {
continue
}
id := int64(sqliteID)
docType, _ := result.Metadata["doc_type"].(string)
switch docType {
case "observation":
if !seenObs[id] {
seenObs[id] = true
obsIDs = append(obsIDs, id)
}
case "session_summary":
if !seenSummary[id] {
seenSummary[id] = true
summaryIDs = append(summaryIDs, id)
}
case "user_prompt":
if !seenPrompt[id] {
seenPrompt[id] = true
promptIDs = append(promptIDs, id)
}
}
}
// Fetch full records from SQLite
var results []SearchResult
if len(obsIDs) > 0 && (params.Type == "" || params.Type == "observations") {
obs, err := m.observationStore.GetObservationsByIDs(ctx, obsIDs, params.OrderBy, 0)
if err == nil {
for _, o := range obs {
results = append(results, m.observationToResult(o, params.Format))
}
}
}
if len(summaryIDs) > 0 && (params.Type == "" || params.Type == "sessions") {
summaries, err := m.summaryStore.GetSummariesByIDs(ctx, summaryIDs, params.OrderBy, 0)
if err == nil {
for _, s := range summaries {
results = append(results, m.summaryToResult(s, params.Format))
}
}
}
if len(promptIDs) > 0 && (params.Type == "" || params.Type == "prompts") {
prompts, err := m.promptStore.GetPromptsByIDs(ctx, promptIDs, params.OrderBy, 0)
if err == nil {
for _, p := range prompts {
results = append(results, m.promptToResult(p, params.Format))
}
}
}
// Apply limit
if len(results) > params.Limit {
results = results[:params.Limit]
}
return &UnifiedSearchResult{
Results: results,
TotalCount: len(results),
Query: params.Query,
}, nil
}
// filterSearch performs structured filter search via SQLite.
func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
var results []SearchResult
// Search observations
if params.Type == "" || params.Type == "observations" {
obs, err := m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit)
if err == nil {
for _, o := range obs {
results = append(results, m.observationToResult(o, params.Format))
}
}
}
// Search summaries
if params.Type == "" || params.Type == "sessions" {
summaries, err := m.summaryStore.GetRecentSummaries(ctx, params.Project, params.Limit)
if err == nil {
for _, s := range summaries {
results = append(results, m.summaryToResult(s, params.Format))
}
}
}
// Apply limit
if len(results) > params.Limit {
results = results[:params.Limit]
}
return &UnifiedSearchResult{
Results: results,
TotalCount: len(results),
}, nil
}
// Decisions performs a semantic search optimized for finding decisions.
func (m *Manager) Decisions(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
// Boost query with decision-related keywords
if params.Query != "" {
params.Query = params.Query + " decision chose architecture"
}
params.Type = "observations"
return m.UnifiedSearch(ctx, params)
}
// Changes performs a semantic search optimized for finding code changes.
func (m *Manager) Changes(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
// Boost query with change-related keywords
if params.Query != "" {
params.Query = params.Query + " changed modified refactored"
}
params.Type = "observations"
return m.UnifiedSearch(ctx, params)
}
// HowItWorks performs a semantic search optimized for understanding architecture.
func (m *Manager) HowItWorks(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
// Boost query with architecture-related keywords
if params.Query != "" {
params.Query = params.Query + " architecture design pattern implements"
}
params.Type = "observations"
return m.UnifiedSearch(ctx, params)
}
// Helper methods
func (m *Manager) observationToResult(obs *models.Observation, format string) SearchResult {
result := SearchResult{
Type: "observation",
ID: obs.ID,
Project: obs.Project,
Scope: string(obs.Scope),
CreatedAt: obs.CreatedAtEpoch,
Metadata: map[string]interface{}{
"obs_type": string(obs.Type),
"scope": string(obs.Scope),
},
}
if obs.Title.Valid {
result.Title = obs.Title.String
}
if format == "full" && obs.Narrative.Valid {
result.Content = obs.Narrative.String
}
return result
}
func (m *Manager) summaryToResult(summary *models.SessionSummary, format string) SearchResult {
result := SearchResult{
Type: "session",
ID: summary.ID,
Project: summary.Project,
CreatedAt: summary.CreatedAtEpoch,
}
if summary.Request.Valid {
result.Title = truncate(summary.Request.String, 100)
}
if format == "full" && summary.Learned.Valid {
result.Content = summary.Learned.String
}
return result
}
func (m *Manager) promptToResult(prompt *models.UserPromptWithSession, format string) SearchResult {
result := SearchResult{
Type: "prompt",
ID: prompt.ID,
Project: prompt.Project,
CreatedAt: prompt.CreatedAtEpoch,
}
result.Title = truncate(prompt.PromptText, 100)
if format == "full" {
result.Content = prompt.PromptText
}
return result
}
func truncate(s string, maxLen int) string {
s = strings.TrimSpace(s)
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}
+514
View File
@@ -0,0 +1,514 @@
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
package chroma
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"sync"
"github.com/rs/zerolog/log"
)
// Document represents a document to store in ChromaDB.
type Document struct {
ID string `json:"id"`
Content string `json:"document"`
Metadata map[string]any `json:"metadata"`
}
// QueryResult represents a search result from ChromaDB.
type QueryResult struct {
ID string
Distance float64
Metadata map[string]any
}
// Client is a ChromaDB client that communicates via MCP protocol.
type Client struct {
collection string
dataDir string
pythonVer string
batchSize int
cmd *exec.Cmd
stdin io.WriteCloser
stdout *bufio.Reader
mu sync.Mutex
connected bool
requestID int
}
// Config holds configuration for the ChromaDB client.
type Config struct {
Project string
DataDir string
PythonVer string
BatchSize int
}
// NewClient creates a new ChromaDB client.
func NewClient(cfg Config) (*Client, error) {
if cfg.DataDir == "" {
home, _ := os.UserHomeDir()
cfg.DataDir = filepath.Join(home, ".claude-mnemonic", "vector-db")
}
if cfg.PythonVer == "" {
cfg.PythonVer = "3.13"
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 100
}
return &Client{
collection: fmt.Sprintf("cm__%s", cfg.Project),
dataDir: cfg.DataDir,
pythonVer: cfg.PythonVer,
batchSize: cfg.BatchSize,
}, nil
}
// Connect starts the ChromaDB MCP server and establishes connection.
func (c *Client) Connect(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connected {
return nil
}
// Ensure data directory exists
if err := os.MkdirAll(c.dataDir, 0750); err != nil {
return fmt.Errorf("create data dir: %w", err)
}
// Start chroma-mcp server via uvx
c.cmd = exec.CommandContext(ctx, "uvx", // #nosec G204 -- config values from internal settings
"--python", c.pythonVer,
"chroma-mcp",
"--client-type", "persistent",
"--data-dir", c.dataDir,
)
var err error
c.stdin, err = c.cmd.StdinPipe()
if err != nil {
return fmt.Errorf("stdin pipe: %w", err)
}
stdout, err := c.cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("stdout pipe: %w", err)
}
c.stdout = bufio.NewReader(stdout)
c.cmd.Stderr = os.Stderr
if err := c.cmd.Start(); err != nil {
return fmt.Errorf("start chroma-mcp: %w", err)
}
// Send initialize request
if err := c.sendInitialize(); err != nil {
_ = c.Close()
return fmt.Errorf("initialize: %w", err)
}
c.connected = true
log.Info().
Str("collection", c.collection).
Str("dataDir", c.dataDir).
Msg("Connected to ChromaDB")
return nil
}
// sendInitialize sends the MCP initialize request.
func (c *Client) sendInitialize() error {
req := map[string]any{
"jsonrpc": "2.0",
"id": c.nextID(),
"method": "initialize",
"params": map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{},
"clientInfo": map[string]any{
"name": "claude-mnemonic",
"version": "1.0.0",
},
},
}
if err := c.send(req); err != nil {
return err
}
// Read response
_, err := c.readResponse()
return err
}
// EnsureCollection ensures the collection exists, creating it if needed.
func (c *Client) EnsureCollection(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
// Try to get collection info
req := map[string]any{
"jsonrpc": "2.0",
"id": c.nextID(),
"method": "tools/call",
"params": map[string]any{
"name": "chroma_get_collection_info",
"arguments": map[string]any{
"collection_name": c.collection,
},
},
}
if err := c.send(req); err != nil {
return err
}
resp, err := c.readResponse()
if err != nil {
// Collection doesn't exist, create it
return c.createCollection()
}
// Check if error in response (collection not found)
if _, ok := resp["error"]; ok {
return c.createCollection()
}
return nil
}
// createCollection creates a new collection.
func (c *Client) createCollection() error {
req := map[string]any{
"jsonrpc": "2.0",
"id": c.nextID(),
"method": "tools/call",
"params": map[string]any{
"name": "chroma_create_collection",
"arguments": map[string]any{
"collection_name": c.collection,
"embedding_function_name": "default",
},
},
}
if err := c.send(req); err != nil {
return err
}
_, err := c.readResponse()
if err != nil {
return fmt.Errorf("create collection: %w", err)
}
log.Info().
Str("collection", c.collection).
Msg("Created ChromaDB collection")
return nil
}
// AddDocuments adds documents to the collection in batches.
func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return fmt.Errorf("not connected")
}
for i := 0; i < len(docs); i += c.batchSize {
end := i + c.batchSize
if end > len(docs) {
end = len(docs)
}
batch := docs[i:end]
// Extract fields
documents := make([]string, len(batch))
ids := make([]string, len(batch))
metadatas := make([]map[string]any, len(batch))
for j, doc := range batch {
documents[j] = doc.Content
ids[j] = doc.ID
metadatas[j] = doc.Metadata
}
req := map[string]any{
"jsonrpc": "2.0",
"id": c.nextID(),
"method": "tools/call",
"params": map[string]any{
"name": "chroma_add_documents",
"arguments": map[string]any{
"collection_name": c.collection,
"documents": documents,
"ids": ids,
"metadatas": metadatas,
},
},
}
if err := c.send(req); err != nil {
return fmt.Errorf("send add_documents: %w", err)
}
if _, err := c.readResponse(); err != nil {
return fmt.Errorf("add_documents response: %w", err)
}
log.Debug().
Int("batchStart", i).
Int("batchEnd", end).
Int("total", len(docs)).
Msg("Added document batch")
}
return nil
}
// DeleteDocuments deletes documents from the collection by their IDs.
func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return fmt.Errorf("not connected")
}
if len(ids) == 0 {
return nil
}
req := map[string]any{
"jsonrpc": "2.0",
"id": c.nextID(),
"method": "tools/call",
"params": map[string]any{
"name": "chroma_delete_documents",
"arguments": map[string]any{
"collection_name": c.collection,
"ids": ids,
},
},
}
if err := c.send(req); err != nil {
return fmt.Errorf("send delete_documents: %w", err)
}
if _, err := c.readResponse(); err != nil {
return fmt.Errorf("delete_documents response: %w", err)
}
log.Debug().
Int("count", len(ids)).
Msg("Deleted documents from ChromaDB")
return nil
}
// Query performs a semantic search on the collection.
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return nil, fmt.Errorf("not connected")
}
args := map[string]any{
"collection_name": c.collection,
"query_texts": []string{query},
"n_results": limit,
"include": []string{"documents", "metadatas", "distances"},
}
if where != nil {
args["where"] = where
}
req := map[string]any{
"jsonrpc": "2.0",
"id": c.nextID(),
"method": "tools/call",
"params": map[string]any{
"name": "chroma_query_documents",
"arguments": args,
},
}
if err := c.send(req); err != nil {
return nil, fmt.Errorf("send query: %w", err)
}
resp, err := c.readResponse()
if err != nil {
return nil, fmt.Errorf("query response: %w", err)
}
return c.parseQueryResults(resp)
}
// parseQueryResults parses the query response into QueryResult structs.
func (c *Client) parseQueryResults(resp map[string]any) ([]QueryResult, error) {
result, ok := resp["result"].(map[string]any)
if !ok {
return nil, nil
}
content, ok := result["content"].([]any)
if !ok || len(content) == 0 {
return nil, nil
}
first, ok := content[0].(map[string]any)
if !ok {
return nil, nil
}
text, ok := first["text"].(string)
if !ok {
return nil, nil
}
var parsed struct {
IDs [][]string `json:"ids"`
Distances [][]float64 `json:"distances"`
Metadatas [][]map[string]any `json:"metadatas"`
}
if err := json.Unmarshal([]byte(text), &parsed); err != nil {
return nil, err
}
if len(parsed.IDs) == 0 || len(parsed.IDs[0]) == 0 {
return nil, nil
}
results := make([]QueryResult, len(parsed.IDs[0]))
for i := range parsed.IDs[0] {
results[i] = QueryResult{
ID: parsed.IDs[0][i],
}
if i < len(parsed.Distances[0]) {
results[i].Distance = parsed.Distances[0][i]
}
if i < len(parsed.Metadatas[0]) {
results[i].Metadata = parsed.Metadatas[0][i]
}
}
return results, nil
}
// send sends a JSON-RPC request to the MCP server.
func (c *Client) send(req map[string]any) error {
data, err := json.Marshal(req)
if err != nil {
return err
}
data = append(data, '\n')
_, err = c.stdin.Write(data)
return err
}
// readResponse reads a JSON-RPC response from the MCP server.
func (c *Client) readResponse() (map[string]any, error) {
line, err := c.stdout.ReadString('\n')
if err != nil {
return nil, err
}
var resp map[string]any
if err := json.Unmarshal([]byte(line), &resp); err != nil {
return nil, err
}
if errObj, ok := resp["error"]; ok {
return nil, fmt.Errorf("MCP error: %v", errObj)
}
return resp, nil
}
// nextID returns the next request ID.
func (c *Client) nextID() int {
c.requestID++
return c.requestID
}
// Close closes the connection to ChromaDB.
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected {
return nil
}
c.connected = false
if c.stdin != nil {
_ = c.stdin.Close()
}
if c.cmd != nil && c.cmd.Process != nil {
_ = c.cmd.Process.Kill()
_ = c.cmd.Wait()
}
log.Info().
Str("collection", c.collection).
Msg("ChromaDB connection closed")
return nil
}
// Reconnect closes the existing connection and establishes a new one.
// This is useful when the vector database directory has been deleted and recreated.
func (c *Client) Reconnect(ctx context.Context) error {
log.Info().
Str("collection", c.collection).
Msg("Reconnecting to ChromaDB...")
// Close existing connection
if err := c.Close(); err != nil {
log.Warn().Err(err).Msg("Error closing ChromaDB during reconnect")
}
// Small delay to allow cleanup
// (ChromaDB may need a moment to release resources)
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// Reconnect
if err := c.Connect(ctx); err != nil {
return fmt.Errorf("reconnect failed: %w", err)
}
// Ensure collection exists
if err := c.EnsureCollection(ctx); err != nil {
return fmt.Errorf("ensure collection after reconnect: %w", err)
}
log.Info().
Str("collection", c.collection).
Msg("ChromaDB reconnected successfully")
return nil
}
+276
View File
@@ -0,0 +1,276 @@
// Package chroma provides ChromaDB vector database integration for claude-mnemonic.
package chroma
import (
"context"
"fmt"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// Sync provides synchronization between SQLite and ChromaDB.
type Sync struct {
client *Client
}
// NewSync creates a new ChromaDB sync service.
func NewSync(client *Client) *Sync {
return &Sync{client: client}
}
// SyncObservation syncs a single observation to ChromaDB.
func (s *Sync) SyncObservation(ctx context.Context, obs *models.Observation) error {
docs := s.formatObservationDocs(obs)
if len(docs) == 0 {
return nil
}
if err := s.client.AddDocuments(ctx, docs); err != nil {
return fmt.Errorf("add observation docs: %w", err)
}
log.Debug().
Int64("observationId", obs.ID).
Int("docCount", len(docs)).
Msg("Synced observation to ChromaDB")
return nil
}
// formatObservationDocs formats an observation into ChromaDB documents.
// Each semantic field becomes a separate vector document (granular approach).
func (s *Sync) formatObservationDocs(obs *models.Observation) []Document {
docs := make([]Document, 0, len(obs.Facts)+2)
// Determine scope for metadata
scope := string(obs.Scope)
if scope == "" {
scope = "project"
}
baseMetadata := map[string]any{
"sqlite_id": obs.ID,
"doc_type": "observation",
"sdk_session_id": obs.SDKSessionID,
"project": obs.Project,
"scope": scope,
"created_at_epoch": obs.CreatedAtEpoch,
"type": string(obs.Type),
}
if obs.Title.Valid {
baseMetadata["title"] = obs.Title.String
}
if obs.Subtitle.Valid {
baseMetadata["subtitle"] = obs.Subtitle.String
}
if len(obs.Concepts) > 0 {
baseMetadata["concepts"] = joinStrings(obs.Concepts, ",")
}
if len(obs.FilesRead) > 0 {
baseMetadata["files_read"] = joinStrings(obs.FilesRead, ",")
}
if len(obs.FilesModified) > 0 {
baseMetadata["files_modified"] = joinStrings(obs.FilesModified, ",")
}
// Narrative as separate document
if obs.Narrative.Valid && obs.Narrative.String != "" {
docs = append(docs, Document{
ID: fmt.Sprintf("obs_%d_narrative", obs.ID),
Content: obs.Narrative.String,
Metadata: copyMetadata(baseMetadata, "field_type", "narrative"),
})
}
// Each fact as separate document
for i, fact := range obs.Facts {
docs = append(docs, Document{
ID: fmt.Sprintf("obs_%d_fact_%d", obs.ID, i),
Content: fact,
Metadata: copyMetadataMulti(baseMetadata, map[string]any{
"field_type": "fact",
"fact_index": i,
}),
})
}
return docs
}
// SyncSummary syncs a single session summary to ChromaDB.
func (s *Sync) SyncSummary(ctx context.Context, summary *models.SessionSummary) error {
docs := s.formatSummaryDocs(summary)
if len(docs) == 0 {
return nil
}
if err := s.client.AddDocuments(ctx, docs); err != nil {
return fmt.Errorf("add summary docs: %w", err)
}
log.Debug().
Int64("summaryId", summary.ID).
Int("docCount", len(docs)).
Msg("Synced summary to ChromaDB")
return nil
}
// formatSummaryDocs formats a session summary into ChromaDB documents.
func (s *Sync) formatSummaryDocs(summary *models.SessionSummary) []Document {
docs := make([]Document, 0, 6)
baseMetadata := map[string]any{
"sqlite_id": summary.ID,
"doc_type": "session_summary",
"sdk_session_id": summary.SDKSessionID,
"project": summary.Project,
"created_at_epoch": summary.CreatedAtEpoch,
}
if summary.PromptNumber.Valid {
baseMetadata["prompt_number"] = summary.PromptNumber.Int64
}
// Each field as separate document
fields := []struct {
name string
value string
valid bool
}{
{"request", summary.Request.String, summary.Request.Valid},
{"investigated", summary.Investigated.String, summary.Investigated.Valid},
{"learned", summary.Learned.String, summary.Learned.Valid},
{"completed", summary.Completed.String, summary.Completed.Valid},
{"next_steps", summary.NextSteps.String, summary.NextSteps.Valid},
{"notes", summary.Notes.String, summary.Notes.Valid},
}
for _, field := range fields {
if field.valid && field.value != "" {
docs = append(docs, Document{
ID: fmt.Sprintf("summary_%d_%s", summary.ID, field.name),
Content: field.value,
Metadata: copyMetadata(baseMetadata, "field_type", field.name),
})
}
}
return docs
}
// SyncUserPrompt syncs a single user prompt to ChromaDB.
func (s *Sync) SyncUserPrompt(ctx context.Context, prompt *models.UserPromptWithSession) error {
doc := Document{
ID: fmt.Sprintf("prompt_%d", prompt.ID),
Content: prompt.PromptText,
Metadata: map[string]any{
"sqlite_id": prompt.ID,
"doc_type": "user_prompt",
"sdk_session_id": prompt.SDKSessionID,
"project": prompt.Project,
"created_at_epoch": prompt.CreatedAtEpoch,
"prompt_number": prompt.PromptNumber,
},
}
if err := s.client.AddDocuments(ctx, []Document{doc}); err != nil {
return fmt.Errorf("add prompt doc: %w", err)
}
log.Debug().
Int64("promptId", prompt.ID).
Msg("Synced user prompt to ChromaDB")
return nil
}
// DeleteObservations removes observation documents from ChromaDB.
// Since each observation may have multiple documents (narrative + facts),
// we delete by the sqlite_id metadata prefix pattern.
func (s *Sync) DeleteObservations(ctx context.Context, observationIDs []int64) error {
if len(observationIDs) == 0 {
return nil
}
// Generate all possible document IDs for these observations
// Pattern: obs_{id}_narrative, obs_{id}_fact_{0..n}
// Since we don't know how many facts each had, we use a reasonable upper bound
const maxFactsPerObs = 20
ids := make([]string, 0, len(observationIDs)*(maxFactsPerObs+1))
for _, obsID := range observationIDs {
ids = append(ids, fmt.Sprintf("obs_%d_narrative", obsID))
for i := 0; i < maxFactsPerObs; i++ {
ids = append(ids, fmt.Sprintf("obs_%d_fact_%d", obsID, i))
}
}
if err := s.client.DeleteDocuments(ctx, ids); err != nil {
return fmt.Errorf("delete observation docs: %w", err)
}
log.Debug().
Int("observationCount", len(observationIDs)).
Msg("Deleted observations from ChromaDB")
return nil
}
// DeleteUserPrompts removes user prompt documents from ChromaDB.
func (s *Sync) DeleteUserPrompts(ctx context.Context, promptIDs []int64) error {
if len(promptIDs) == 0 {
return nil
}
// Each prompt is stored as a single document with ID pattern: prompt_{id}
ids := make([]string, len(promptIDs))
for i, promptID := range promptIDs {
ids[i] = fmt.Sprintf("prompt_%d", promptID)
}
if err := s.client.DeleteDocuments(ctx, ids); err != nil {
return fmt.Errorf("delete prompt docs: %w", err)
}
log.Debug().
Int("promptCount", len(promptIDs)).
Msg("Deleted user prompts from ChromaDB")
return nil
}
// Helper functions
func copyMetadata(base map[string]any, key string, value any) map[string]any {
result := make(map[string]any, len(base)+1)
for k, v := range base {
result[k] = v
}
result[key] = value
return result
}
func copyMetadataMulti(base map[string]any, extra map[string]any) map[string]any {
result := make(map[string]any, len(base)+len(extra))
for k, v := range base {
result[k] = v
}
for k, v := range extra {
result[k] = v
}
return result
}
func joinStrings(strs []string, sep string) string {
if len(strs) == 0 {
return ""
}
result := strs[0]
for i := 1; i < len(strs); i++ {
result += sep + strs[i]
}
return result
}
+350
View File
@@ -0,0 +1,350 @@
package chroma
import (
"database/sql"
"fmt"
"testing"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
)
// testSync creates a Sync with a nil client for testing format functions.
func testSync() *Sync {
return &Sync{client: nil}
}
func TestSync_FormatObservationDocs(t *testing.T) {
sync := testSync()
obs := &models.Observation{
ID: 1,
SDKSessionID: "test-session",
Project: "test-project",
Scope: models.ScopeProject,
Type: models.ObsTypeDiscovery,
Title: sql.NullString{String: "Test Title", Valid: true},
Subtitle: sql.NullString{String: "Test Subtitle", Valid: true},
Narrative: sql.NullString{String: "Test narrative content", Valid: true},
Facts: models.JSONStringArray{"Fact 1", "Fact 2", "Fact 3"},
Concepts: models.JSONStringArray{"concept1", "concept2"},
FilesRead: models.JSONStringArray{"file1.go", "file2.go"},
FilesModified: models.JSONStringArray{"file3.go"},
CreatedAtEpoch: 1234567890,
}
docs := sync.formatObservationDocs(obs)
// Should have 1 narrative + 3 facts = 4 documents
assert.Len(t, docs, 4)
// Check narrative document
narrativeDoc := docs[0]
assert.Equal(t, "obs_1_narrative", narrativeDoc.ID)
assert.Equal(t, "Test narrative content", narrativeDoc.Content)
assert.Equal(t, int64(1), narrativeDoc.Metadata["sqlite_id"])
assert.Equal(t, "observation", narrativeDoc.Metadata["doc_type"])
assert.Equal(t, "narrative", narrativeDoc.Metadata["field_type"])
assert.Equal(t, "test-project", narrativeDoc.Metadata["project"])
assert.Equal(t, "project", narrativeDoc.Metadata["scope"])
assert.Equal(t, "Test Title", narrativeDoc.Metadata["title"])
assert.Equal(t, "Test Subtitle", narrativeDoc.Metadata["subtitle"])
// Check fact documents
for i := 1; i <= 3; i++ {
factDoc := docs[i]
assert.Equal(t, fmt.Sprintf("obs_1_fact_%d", i-1), factDoc.ID)
assert.Equal(t, fmt.Sprintf("Fact %d", i), factDoc.Content)
assert.Equal(t, "fact", factDoc.Metadata["field_type"])
assert.Equal(t, i-1, factDoc.Metadata["fact_index"])
}
}
func TestSync_FormatObservationDocs_NoNarrative(t *testing.T) {
sync := testSync()
obs := &models.Observation{
ID: 2,
SDKSessionID: "test-session",
Project: "test-project",
Scope: models.ScopeGlobal,
Type: models.ObsTypeBugfix,
Facts: models.JSONStringArray{"Only fact"},
CreatedAtEpoch: 1234567890,
}
docs := sync.formatObservationDocs(obs)
// Should have 1 fact only (no narrative)
assert.Len(t, docs, 1)
assert.Equal(t, "obs_2_fact_0", docs[0].ID)
assert.Equal(t, "Only fact", docs[0].Content)
assert.Equal(t, "global", docs[0].Metadata["scope"])
}
func TestSync_FormatObservationDocs_Empty(t *testing.T) {
sync := testSync()
obs := &models.Observation{
ID: 3,
SDKSessionID: "test-session",
Project: "test-project",
Type: models.ObsTypeDiscovery,
CreatedAtEpoch: 1234567890,
}
docs := sync.formatObservationDocs(obs)
// Should have no documents when no content
assert.Len(t, docs, 0)
}
func TestSync_FormatObservationDocs_EmptyScope(t *testing.T) {
sync := testSync()
obs := &models.Observation{
ID: 4,
SDKSessionID: "test-session",
Project: "test-project",
Scope: "", // Empty scope
Type: models.ObsTypeDiscovery,
Narrative: sql.NullString{String: "Content", Valid: true},
CreatedAtEpoch: 1234567890,
}
docs := sync.formatObservationDocs(obs)
// Empty scope should default to "project"
assert.Len(t, docs, 1)
assert.Equal(t, "project", docs[0].Metadata["scope"])
}
func TestSync_FormatSummaryDocs(t *testing.T) {
sync := testSync()
summary := &models.SessionSummary{
ID: 1,
SDKSessionID: "test-session",
Project: "test-project",
Request: sql.NullString{String: "Add feature", Valid: true},
Investigated: sql.NullString{String: "Looked at code", Valid: true},
Learned: sql.NullString{String: "Found pattern", Valid: true},
Completed: sql.NullString{String: "Done", Valid: true},
NextSteps: sql.NullString{String: "Test it", Valid: true},
Notes: sql.NullString{String: "Notes here", Valid: true},
PromptNumber: sql.NullInt64{Int64: 5, Valid: true},
CreatedAtEpoch: 1234567890,
}
docs := sync.formatSummaryDocs(summary)
// Should have 6 documents (all fields present)
assert.Len(t, docs, 6)
// Check first document
assert.Equal(t, "summary_1_request", docs[0].ID)
assert.Equal(t, "Add feature", docs[0].Content)
assert.Equal(t, "session_summary", docs[0].Metadata["doc_type"])
assert.Equal(t, "request", docs[0].Metadata["field_type"])
assert.Equal(t, int64(5), docs[0].Metadata["prompt_number"])
}
func TestSync_FormatSummaryDocs_PartialFields(t *testing.T) {
sync := testSync()
summary := &models.SessionSummary{
ID: 2,
SDKSessionID: "test-session",
Project: "test-project",
Request: sql.NullString{String: "Only request", Valid: true},
Completed: sql.NullString{String: "Only completed", Valid: true},
CreatedAtEpoch: 1234567890,
}
docs := sync.formatSummaryDocs(summary)
// Should have 2 documents (only valid fields)
assert.Len(t, docs, 2)
// Verify field types
fieldTypes := make([]string, len(docs))
for i, doc := range docs {
fieldTypes[i] = doc.Metadata["field_type"].(string)
}
assert.Contains(t, fieldTypes, "request")
assert.Contains(t, fieldTypes, "completed")
}
func TestSync_FormatSummaryDocs_Empty(t *testing.T) {
sync := testSync()
summary := &models.SessionSummary{
ID: 3,
SDKSessionID: "test-session",
Project: "test-project",
CreatedAtEpoch: 1234567890,
}
docs := sync.formatSummaryDocs(summary)
// Should have no documents when no content
assert.Len(t, docs, 0)
}
func TestSync_FormatSummaryDocs_EmptyStrings(t *testing.T) {
sync := testSync()
summary := &models.SessionSummary{
ID: 4,
SDKSessionID: "test-session",
Project: "test-project",
Request: sql.NullString{String: "", Valid: true}, // Valid but empty
CreatedAtEpoch: 1234567890,
}
docs := sync.formatSummaryDocs(summary)
// Empty strings should not produce documents
assert.Len(t, docs, 0)
}
// Test helper functions
func TestJoinStrings(t *testing.T) {
tests := []struct {
name string
strs []string
sep string
expected string
}{
{"empty", []string{}, ",", ""},
{"single", []string{"a"}, ",", "a"},
{"multiple", []string{"a", "b", "c"}, ",", "a,b,c"},
{"different sep", []string{"a", "b"}, "-", "a-b"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := joinStrings(tt.strs, tt.sep)
assert.Equal(t, tt.expected, result)
})
}
}
func TestCopyMetadata(t *testing.T) {
base := map[string]any{
"key1": "value1",
"key2": 42,
}
result := copyMetadata(base, "key3", "value3")
// Original should be unchanged
assert.Len(t, base, 2)
// Result should have all keys
assert.Len(t, result, 3)
assert.Equal(t, "value1", result["key1"])
assert.Equal(t, 42, result["key2"])
assert.Equal(t, "value3", result["key3"])
}
func TestCopyMetadataMulti(t *testing.T) {
base := map[string]any{
"key1": "value1",
}
extra := map[string]any{
"key2": "value2",
"key3": "value3",
}
result := copyMetadataMulti(base, extra)
// Original should be unchanged
assert.Len(t, base, 1)
// Result should have all keys
assert.Len(t, result, 3)
assert.Equal(t, "value1", result["key1"])
assert.Equal(t, "value2", result["key2"])
assert.Equal(t, "value3", result["key3"])
}
// Test ID generation patterns for delete operations
func TestSync_DeleteObservationIDGeneration(t *testing.T) {
// Test that we generate correct document IDs for deletion
obsIDs := []int64{1, 2}
maxFactsPerObs := 20
ids := make([]string, 0, len(obsIDs)*(maxFactsPerObs+1))
for _, obsID := range obsIDs {
ids = append(ids, fmt.Sprintf("obs_%d_narrative", obsID))
for i := 0; i < maxFactsPerObs; i++ {
ids = append(ids, fmt.Sprintf("obs_%d_fact_%d", obsID, i))
}
}
// Each observation should generate 21 IDs (1 narrative + 20 facts)
assert.Len(t, ids, 42)
// Check some expected IDs
assert.Contains(t, ids, "obs_1_narrative")
assert.Contains(t, ids, "obs_1_fact_0")
assert.Contains(t, ids, "obs_1_fact_19")
assert.Contains(t, ids, "obs_2_narrative")
assert.Contains(t, ids, "obs_2_fact_0")
}
func TestSync_DeletePromptIDGeneration(t *testing.T) {
// Test that we generate correct document IDs for prompt deletion
promptIDs := []int64{10, 20, 30}
ids := make([]string, len(promptIDs))
for i, promptID := range promptIDs {
ids[i] = fmt.Sprintf("prompt_%d", promptID)
}
assert.Len(t, ids, 3)
assert.Contains(t, ids, "prompt_10")
assert.Contains(t, ids, "prompt_20")
assert.Contains(t, ids, "prompt_30")
}
// Test metadata includes all expected fields
func TestSync_ObservationMetadataFields(t *testing.T) {
sync := testSync()
obs := &models.Observation{
ID: 1,
SDKSessionID: "sdk-123",
Project: "my-project",
Scope: models.ScopeGlobal,
Type: models.ObsTypeBugfix,
Title: sql.NullString{String: "Bug Fix", Valid: true},
Subtitle: sql.NullString{String: "Memory leak", Valid: true},
Narrative: sql.NullString{String: "Fixed the leak", Valid: true},
Concepts: models.JSONStringArray{"memory", "performance"},
FilesRead: models.JSONStringArray{"main.go"},
FilesModified: models.JSONStringArray{"fix.go"},
CreatedAtEpoch: 1234567890,
}
docs := sync.formatObservationDocs(obs)
require := assert.New(t)
require.Len(docs, 1) // Only narrative, no facts
meta := docs[0].Metadata
require.Equal(int64(1), meta["sqlite_id"])
require.Equal("observation", meta["doc_type"])
require.Equal("sdk-123", meta["sdk_session_id"])
require.Equal("my-project", meta["project"])
require.Equal("global", meta["scope"])
require.Equal("bugfix", meta["type"])
require.Equal("Bug Fix", meta["title"])
require.Equal("Memory leak", meta["subtitle"])
require.Equal("memory,performance", meta["concepts"])
require.Equal("main.go", meta["files_read"])
require.Equal("fix.go", meta["files_modified"])
require.Equal(int64(1234567890), meta["created_at_epoch"])
require.Equal("narrative", meta["field_type"])
}
+187
View File
@@ -0,0 +1,187 @@
// Package watcher provides file system watching utilities for detecting
// database file/directory deletions and triggering recreation.
package watcher
import (
"context"
"os"
"path/filepath"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"github.com/rs/zerolog/log"
)
// Watcher monitors a file or directory for deletion and calls onDelete when removed.
// It watches the parent directory since fsnotify cannot watch non-existent files.
type Watcher struct {
targetPath string // The file/directory to watch for deletion
parentPath string // Parent directory (what we actually watch)
onDelete func() // Callback when target is deleted
watcher *fsnotify.Watcher
ctx context.Context
cancel context.CancelFunc
mu sync.Mutex
running bool
debounce time.Duration
}
// New creates a new Watcher for the given target path.
// The onDelete callback is called when the target is deleted.
func New(targetPath string, onDelete func()) (*Watcher, error) {
fsw, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
return &Watcher{
targetPath: targetPath,
parentPath: filepath.Dir(targetPath),
onDelete: onDelete,
watcher: fsw,
ctx: ctx,
cancel: cancel,
debounce: 100 * time.Millisecond,
}, nil
}
// Start begins watching for file deletion events.
func (w *Watcher) Start() error {
w.mu.Lock()
if w.running {
w.mu.Unlock()
return nil
}
w.running = true
w.mu.Unlock()
// Add watch on parent directory
if err := w.addWatch(); err != nil {
log.Warn().Err(err).Str("path", w.parentPath).Msg("Failed to add initial watch")
// Continue anyway - we'll try to re-establish later
}
go w.watchLoop()
return nil
}
// Stop stops the watcher.
func (w *Watcher) Stop() error {
w.mu.Lock()
defer w.mu.Unlock()
if !w.running {
return nil
}
w.running = false
w.cancel()
return w.watcher.Close()
}
// addWatch adds the parent directory to the watch list.
func (w *Watcher) addWatch() error {
// Ensure parent exists
if _, err := os.Stat(w.parentPath); os.IsNotExist(err) {
return err
}
return w.watcher.Add(w.parentPath)
}
// watchLoop is the main event loop.
func (w *Watcher) watchLoop() {
var (
debounceTimer *time.Timer
pendingDelete bool
)
for {
select {
case <-w.ctx.Done():
if debounceTimer != nil {
debounceTimer.Stop()
}
return
case event, ok := <-w.watcher.Events:
if !ok {
return
}
// Check if this event is for our target
eventPath := filepath.Clean(event.Name)
targetPath := filepath.Clean(w.targetPath)
// Handle parent directory deletion (entire data dir removed)
if eventPath == w.parentPath && event.Op&fsnotify.Remove != 0 {
log.Info().Str("path", w.parentPath).Msg("Parent directory deleted")
pendingDelete = true
if debounceTimer != nil {
debounceTimer.Stop()
}
debounceTimer = time.AfterFunc(w.debounce, func() {
w.handleDeletion()
})
continue
}
// Handle target file/directory deletion
if eventPath == targetPath && event.Op&fsnotify.Remove != 0 {
log.Info().Str("path", w.targetPath).Msg("Target deleted")
pendingDelete = true
if debounceTimer != nil {
debounceTimer.Stop()
}
debounceTimer = time.AfterFunc(w.debounce, func() {
w.handleDeletion()
})
continue
}
// Handle parent directory recreation (re-establish watch)
if eventPath == w.parentPath && event.Op&fsnotify.Create != 0 {
log.Info().Str("path", w.parentPath).Msg("Parent directory recreated, re-establishing watch")
_ = w.addWatch()
continue
}
// If target was recreated after pending delete, cancel the callback
if pendingDelete && eventPath == targetPath && event.Op&fsnotify.Create != 0 {
log.Info().Str("path", w.targetPath).Msg("Target recreated, cancelling deletion callback")
pendingDelete = false
if debounceTimer != nil {
debounceTimer.Stop()
}
}
case err, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Error().Err(err).Msg("Watcher error")
}
}
}
// handleDeletion calls the onDelete callback and attempts to re-establish the watch.
func (w *Watcher) handleDeletion() {
log.Info().Str("path", w.targetPath).Msg("Triggering deletion callback")
// Call the callback
if w.onDelete != nil {
w.onDelete()
}
// Try to re-establish watch after a short delay (parent may have been recreated)
go func() {
time.Sleep(500 * time.Millisecond)
if err := w.addWatch(); err != nil {
log.Warn().Err(err).Str("path", w.parentPath).Msg("Failed to re-establish watch after deletion")
} else {
log.Info().Str("path", w.parentPath).Msg("Re-established watch after recreation")
}
}()
}
+14
View File
@@ -0,0 +1,14 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/lukaszraczylo/claude-mnemonic/pkg/similarity"
)
// clusterObservations groups similar observations and returns only one representative per cluster.
// Uses Jaccard similarity on extracted terms from title, narrative, and facts.
// Delegates to pkg/similarity for the actual clustering logic.
func clusterObservations(observations []*models.Observation, similarityThreshold float64) []*models.Observation {
return similarity.ClusterObservations(observations, similarityThreshold)
}
+705
View File
@@ -0,0 +1,705 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"encoding/json"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// Handler configuration constants
const (
// DefaultObservationsLimit is the default number of observations to return.
DefaultObservationsLimit = 100
// DefaultSummariesLimit is the default number of summaries to return.
DefaultSummariesLimit = 50
// DefaultPromptsLimit is the default number of prompts to return.
DefaultPromptsLimit = 100
// DefaultSearchLimit is the default number of search results to return.
DefaultSearchLimit = 50
// DefaultContextLimit is the default number of context observations to return.
DefaultContextLimit = 50
)
// writeJSON writes a JSON response with proper error handling.
func writeJSON(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(data); err != nil {
log.Error().Err(err).Msg("Failed to encode JSON response")
}
}
// handleHealth handles health check requests.
// Returns 200 OK immediately (even during init) so hooks can connect quickly.
// Use /api/ready for full readiness check.
func (s *Service) handleHealth(w http.ResponseWriter, r *http.Request) {
status := "starting"
if s.ready.Load() {
status = "ready"
} else if err := s.GetInitError(); err != nil {
status = "error"
}
writeJSON(w, map[string]interface{}{
"status": status,
"version": s.version,
})
}
// handleVersion returns the worker version for version checking.
func (s *Service) handleVersion(w http.ResponseWriter, r *http.Request) {
writeJSON(w, map[string]string{
"version": s.version,
})
}
// handleReady handles readiness check requests.
// Returns 200 only when fully initialized, 503 otherwise.
func (s *Service) handleReady(w http.ResponseWriter, r *http.Request) {
if !s.ready.Load() {
if err := s.GetInitError(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Error(w, "service initializing", http.StatusServiceUnavailable)
return
}
writeJSON(w, map[string]string{"status": "ready"})
}
// requireReady is middleware that returns 503 if service isn't ready.
func (s *Service) requireReady(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !s.ready.Load() {
if err := s.GetInitError(); err != nil {
http.Error(w, "service initialization failed: "+err.Error(), http.StatusInternalServerError)
return
}
http.Error(w, "service initializing", http.StatusServiceUnavailable)
return
}
next.ServeHTTP(w, r)
})
}
// SessionInitRequest is the request body for session initialization.
type SessionInitRequest struct {
ClaudeSessionID string `json:"claudeSessionId"`
Project string `json:"project"`
Prompt string `json:"prompt"`
MatchedObservations int `json:"matchedObservations"`
}
// SessionInitResponse is the response for session initialization.
type SessionInitResponse struct {
SessionDBID int64 `json:"sessionDbId"`
PromptNumber int `json:"promptNumber"`
Skipped bool `json:"skipped,omitempty"`
Reason string `json:"reason,omitempty"`
}
// handleSessionInit handles session initialization from user-prompt hook.
func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) {
var req SessionInitRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Privacy check
if privacy.IsEntirelyPrivate(req.Prompt) {
// Create session but skip processing
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
promptNum, _ := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
writeJSON(w, SessionInitResponse{
SessionDBID: sessionID,
PromptNumber: promptNum,
Skipped: true,
Reason: "private",
})
return
}
// Clean prompt and create session
cleanedPrompt := privacy.Clean(req.Prompt)
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Increment prompt counter
promptNum, err := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Save user prompt with matched observation count
promptID, err := s.promptStore.SaveUserPromptWithMatches(r.Context(), req.ClaudeSessionID, promptNum, cleanedPrompt, req.MatchedObservations)
if err != nil {
log.Warn().Err(err).Msg("Failed to save user prompt")
// Non-fatal: continue with session initialization
} else if s.chromaSync != nil {
// Sync to vector DB
now := time.Now()
promptWithSession := &models.UserPromptWithSession{
UserPrompt: models.UserPrompt{
ID: promptID,
ClaudeSessionID: req.ClaudeSessionID,
PromptNumber: promptNum,
PromptText: cleanedPrompt,
MatchedObservations: req.MatchedObservations,
CreatedAt: now.Format(time.RFC3339),
CreatedAtEpoch: now.UnixMilli(),
},
Project: req.Project,
SDKSessionID: req.ClaudeSessionID,
}
if err := s.chromaSync.SyncUserPrompt(r.Context(), promptWithSession); err != nil {
log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to ChromaDB")
}
}
log.Info().
Int64("sessionId", sessionID).
Int("promptNumber", promptNum).
Str("project", req.Project).
Msg("Session initialized")
// Broadcast prompt event for dashboard refresh
s.sseBroadcaster.Broadcast(map[string]interface{}{
"type": "prompt",
"action": "created",
"project": req.Project,
})
writeJSON(w, SessionInitResponse{
SessionDBID: sessionID,
PromptNumber: promptNum,
})
}
// SessionStartRequest is the request body for starting SDK agent.
type SessionStartRequest struct {
UserPrompt string `json:"userPrompt"`
PromptNumber int `json:"promptNumber"`
}
// handleSessionStart handles SDK agent session start.
func (s *Service) handleSessionStart(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid session id", http.StatusBadRequest)
return
}
var req SessionStartRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Initialize session in manager
sess, err := s.sessionManager.InitializeSession(r.Context(), id, req.UserPrompt, req.PromptNumber)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if sess == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
}
// Session is now registered. Observations will be processed
// asynchronously by the background queue processor (processQueue in service.go).
log.Info().
Int64("sessionId", id).
Int("promptNumber", req.PromptNumber).
Msg("SDK agent session initialized")
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// ObservationRequest is the request body for posting observations.
type ObservationRequest struct {
ClaudeSessionID string `json:"claudeSessionId"`
Project string `json:"project"`
ToolName string `json:"tool_name"`
ToolInput interface{} `json:"tool_input"`
ToolResponse interface{} `json:"tool_response"`
CWD string `json:"cwd"`
}
// handleObservation handles observation posting from post-tool-use hook.
func (s *Service) handleObservation(w http.ResponseWriter, r *http.Request) {
var req ObservationRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Find session
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if sess == nil {
// Create session on-the-fly with project from request
id, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
sess, _ = s.sessionStore.GetSessionByID(r.Context(), id)
}
// Queue observation
if err := s.sessionManager.QueueObservation(r.Context(), sess.ID, session.ObservationData{
ToolName: req.ToolName,
ToolInput: req.ToolInput,
ToolResponse: req.ToolResponse,
CWD: req.CWD,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// SubagentCompleteRequest is the request body for subagent completion.
type SubagentCompleteRequest struct {
ClaudeSessionID string `json:"claudeSessionId"`
Project string `json:"project"`
}
// handleSubagentComplete handles subagent/Task completion notifications.
// This triggers immediate processing of any queued observations from the subagent.
func (s *Service) handleSubagentComplete(w http.ResponseWriter, r *http.Request) {
var req SubagentCompleteRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Find session
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
if err != nil || sess == nil {
// Session not found - subagent may have been in a different context
log.Debug().
Str("claudeSessionId", req.ClaudeSessionID).
Msg("Subagent complete - no active session found")
w.WriteHeader(http.StatusOK)
return
}
// Trigger immediate processing of queued observations
messages := s.sessionManager.DrainMessages(sess.ID)
if len(messages) > 0 && s.processor != nil {
log.Info().
Int64("sessionId", sess.ID).
Int("messages", len(messages)).
Msg("Processing queued observations from subagent")
for _, msg := range messages {
if msg.Type == session.MessageTypeObservation && msg.Observation != nil {
err := s.processor.ProcessObservation(
r.Context(),
sess.SDKSessionID.String,
sess.Project,
msg.Observation.ToolName,
msg.Observation.ToolInput,
msg.Observation.ToolResponse,
msg.Observation.PromptNumber,
msg.Observation.CWD,
)
if err != nil {
log.Error().Err(err).
Str("tool", msg.Observation.ToolName).
Msg("Failed to process subagent observation")
}
}
}
}
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// handleGetSessionByClaudeID looks up a session by Claude session ID.
func (s *Service) handleGetSessionByClaudeID(w http.ResponseWriter, r *http.Request) {
claudeSessionID := r.URL.Query().Get("claudeSessionId")
if claudeSessionID == "" {
http.Error(w, "claudeSessionId required", http.StatusBadRequest)
return
}
session, err := s.sessionStore.FindAnySDKSession(r.Context(), claudeSessionID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if session == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
}
writeJSON(w, session)
}
// SummarizeRequest is the request body for summarize requests.
type SummarizeRequest struct {
LastUserMessage string `json:"lastUserMessage"`
LastAssistantMessage string `json:"lastAssistantMessage"`
}
// handleSummarize handles summarize requests from stop hook.
func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid session id", http.StatusBadRequest)
return
}
var req SummarizeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Queue summarize request
if err := s.sessionManager.QueueSummarize(r.Context(), id, req.LastUserMessage, req.LastAssistantMessage); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// handleGetObservations returns recent observations.
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
limit := DefaultObservationsLimit
if l := r.URL.Query().Get("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
project := r.URL.Query().Get("project")
var observations []*models.Observation
var err error
if project != "" {
// Filter by project - includes project-scoped and global observations
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
} else {
// All projects
observations, err = s.observationStore.GetAllRecentObservations(r.Context(), limit)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Ensure we return empty array, not null
if observations == nil {
observations = []*models.Observation{}
}
writeJSON(w, observations)
}
// handleGetSummaries returns recent summaries.
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
limit := DefaultSummariesLimit
if l := r.URL.Query().Get("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
project := r.URL.Query().Get("project")
var summaries []*models.SessionSummary
var err error
if project != "" {
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
} else {
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Ensure we return empty array, not null
if summaries == nil {
summaries = []*models.SessionSummary{}
}
writeJSON(w, summaries)
}
// handleGetPrompts returns recent user prompts.
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
limit := DefaultPromptsLimit
if l := r.URL.Query().Get("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
project := r.URL.Query().Get("project")
var prompts []*models.UserPromptWithSession
var err error
if project != "" {
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
} else {
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Ensure we return empty array, not null
if prompts == nil {
prompts = []*models.UserPromptWithSession{}
}
writeJSON(w, prompts)
}
// handleGetProjects returns all projects.
func (s *Service) handleGetProjects(w http.ResponseWriter, r *http.Request) {
projects, err := s.sessionStore.GetAllProjects(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, projects)
}
// handleGetStats returns worker statistics.
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
retrievalStats := s.GetRetrievalStats()
sessionsToday, _ := s.sessionStore.GetSessionsToday(r.Context())
writeJSON(w, map[string]interface{}{
"uptime": time.Since(s.startTime).String(),
"activeSessions": s.sessionManager.GetActiveSessionCount(),
"queueDepth": s.sessionManager.GetTotalQueueDepth(),
"isProcessing": s.sessionManager.IsAnySessionProcessing(),
"connectedClients": s.sseBroadcaster.ClientCount(),
"sessionsToday": sessionsToday,
"retrieval": retrievalStats,
})
}
// handleGetRetrievalStats returns detailed retrieval statistics.
func (s *Service) handleGetRetrievalStats(w http.ResponseWriter, r *http.Request) {
stats := s.GetRetrievalStats()
writeJSON(w, stats)
}
// handleContextCount returns the count of observations for a project.
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
return
}
count, err := s.observationStore.GetObservationCount(r.Context(), project)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]interface{}{
"project": project,
"count": count,
})
}
// handleSearchByPrompt searches observations relevant to a user prompt.
// IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return.
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
cwd := r.URL.Query().Get("cwd")
if project == "" || query == "" {
http.Error(w, "project and query required", http.StatusBadRequest)
return
}
limit := DefaultSearchLimit
if l := r.URL.Query().Get("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
// Search using FTS
observations, err := s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
if err != nil {
// FTS might fail if query has special chars, try without
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
// Fast staleness filter - NO verification (that's too slow for interactive use)
// Just check mtimes and exclude obviously stale observations
var staleCount int
freshObservations := make([]*models.Observation, 0, len(observations))
for _, obs := range observations {
if len(obs.FileMtimes) > 0 && cwd != "" {
var paths []string
for path := range obs.FileMtimes {
paths = append(paths, path)
}
currentMtimes := sdk.GetFileMtimes(paths, cwd)
if obs.CheckStaleness(currentMtimes) {
// Stale - exclude but don't verify (too slow)
// Queue for background verification instead
staleCount++
s.queueStaleVerification(obs.ID, cwd)
continue
}
}
freshObservations = append(freshObservations, obs)
}
// Cluster similar observations to remove duplicates
clusteredObservations := clusterObservations(freshObservations, 0.4)
// Record retrieval stats (no verification done, so verified=0, deleted=0)
s.recordRetrievalStats(int64(len(clusteredObservations)), 0, 0, true)
log.Info().
Str("project", project).
Str("query", query).
Int("found", len(clusteredObservations)).
Int("stale_excluded", staleCount).
Msg("Prompt-based observation search")
writeJSON(w, map[string]interface{}{
"project": project,
"query": query,
"observations": clusteredObservations,
})
}
// handleContextInject returns context for injection at session start.
// IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return.
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
return
}
cwd := r.URL.Query().Get("cwd")
if cwd == "" {
cwd = "/"
}
// Limit observations for fast startup (configurable, default 100)
limit := s.config.ContextObservations
if limit <= 0 {
limit = DefaultContextLimit
}
// Full count determines how many observations get full detail (configurable, default 25)
fullCount := s.config.ContextFullCount
if fullCount <= 0 {
fullCount = 25
}
// Get recent observations
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Fast staleness filter - NO verification (that's too slow for startup)
var staleCount int
freshObservations := make([]*models.Observation, 0, len(observations))
for _, obs := range observations {
if len(obs.FileMtimes) > 0 {
var paths []string
for path := range obs.FileMtimes {
paths = append(paths, path)
}
currentMtimes := sdk.GetFileMtimes(paths, cwd)
if obs.CheckStaleness(currentMtimes) {
// Stale - exclude but don't verify (too slow)
// Queue for background verification instead
staleCount++
s.queueStaleVerification(obs.ID, cwd)
continue
}
}
freshObservations = append(freshObservations, obs)
}
// Cluster similar observations to remove duplicates
clusteredObservations := clusterObservations(freshObservations, 0.4)
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
// Record retrieval stats (no verification done)
s.recordRetrievalStats(int64(len(clusteredObservations)), 0, 0, false)
log.Info().
Str("project", project).
Int("total", len(observations)).
Int("fresh", len(freshObservations)).
Int("clustered", len(clusteredObservations)).
Int("duplicates", duplicatesRemoved).
Int("stale_excluded", staleCount).
Msg("Context injection with clustering")
writeJSON(w, map[string]interface{}{
"project": project,
"observations": clusteredObservations,
"full_count": fullCount,
"stale_excluded": staleCount,
"duplicates_removed": duplicatesRemoved,
})
}
+553
View File
@@ -0,0 +1,553 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"context"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sse"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testService creates a Service with a test SQLite database including FTS5 for testing.
func testService(t *testing.T) (*Service, func()) {
t.Helper()
// Create test store (runs migrations to create all tables including FTS5)
store, dbCleanup := testStore(t)
// Create store wrappers
sessionStore := sqlite.NewSessionStore(store)
observationStore := sqlite.NewObservationStore(store)
summaryStore := sqlite.NewSummaryStore(store)
promptStore := sqlite.NewPromptStore(store)
// Create domain services
sessionManager := session.NewManager(sessionStore)
sseBroadcaster := sse.NewBroadcaster()
// Create context
ctx, cancel := context.WithCancel(context.Background())
// Create router
router := chi.NewRouter()
svc := &Service{
version: "test-version",
config: config.Get(),
store: store,
sessionStore: sessionStore,
observationStore: observationStore,
summaryStore: summaryStore,
promptStore: promptStore,
sessionManager: sessionManager,
sseBroadcaster: sseBroadcaster,
router: router,
ctx: ctx,
cancel: cancel,
startTime: time.Now(),
}
svc.setupRoutes()
// Mark service as ready for tests
svc.ready.Store(true)
cleanup := func() {
cancel()
store.Close()
dbCleanup()
}
return svc, cleanup
}
// createTestObservation creates a test observation in the database.
func createTestObservation(t *testing.T, store *sqlite.ObservationStore, project, title, narrative string, concepts []string) int64 {
t.Helper()
obs := &models.ParsedObservation{
Type: models.ObsTypeDiscovery,
Title: title,
Narrative: narrative,
Concepts: concepts,
}
id, _, err := store.StoreObservation(context.Background(), "test-session", project, obs, 1, 100)
require.NoError(t, err)
return id
}
func TestHandleSearchByPrompt_DefaultLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "test-project"
// Create 60 observations (more than the default limit of 50)
for i := 0; i < 60; i++ {
createTestObservation(t, svc.observationStore, project,
"Test observation about authentication",
"This observation is about authentication and security patterns",
[]string{"authentication", "security"})
// Small delay to ensure different timestamps
time.Sleep(time.Millisecond)
}
// Make request without limit parameter
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=authentication", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "observations should be an array")
// The default limit is now 50, not 5
// Note: clustering may reduce the count, but we should have more than 5
t.Logf("Got %d observations", len(observations))
// Just verify we got a reasonable number, accounting for clustering
assert.True(t, len(observations) >= 1, "should return at least one observation")
}
func TestHandleSearchByPrompt_CustomLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "test-project"
// Create 20 unique observations
for i := 0; i < 20; i++ {
createTestObservation(t, svc.observationStore, project,
"Unique observation "+string(rune('A'+i))+" about testing",
"This is unique observation number "+string(rune('A'+i)),
[]string{"unique-" + string(rune('a'+i))})
time.Sleep(time.Millisecond)
}
// Request with custom limit of 15
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=observation&limit=15", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
// Should respect the custom limit (accounting for clustering)
t.Logf("Got %d observations with limit=15", len(observations))
assert.LessOrEqual(t, len(observations), 15)
}
func TestHandleSearchByPrompt_NoHardcodedLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "test-project"
// Create observations with VERY different content to avoid clustering
// Each has unique words that won't match other observations
uniqueObservations := []struct {
title string
narrative string
concepts []string
}{
{"JWT tokens expire daily", "OAuth2 bearer tokens authentication", []string{"jwt"}},
{"PostgreSQL indexes optimize queries", "B-tree index on user table", []string{"postgres"}},
{"Redis caching TTL configuration", "Memory eviction policy LRU", []string{"redis"}},
{"Zerolog structured logging", "JSON output formatting levels", []string{"logging"}},
{"Pytest fixtures setup teardown", "Mock objects dependency injection", []string{"pytest"}},
{"Docker containers orchestration", "Compose multi-stage builds", []string{"docker"}},
{"Prometheus metrics collection", "Grafana dashboards alerting", []string{"prometheus"}},
{"OWASP vulnerability scanning", "SQL injection XSS prevention", []string{"owasp"}},
}
for _, obs := range uniqueObservations {
createTestObservation(t, svc.observationStore, project, obs.title, obs.narrative, obs.concepts)
time.Sleep(time.Millisecond)
}
// Search using a common keyword that should match most observations
// Using broader query to match multiple items
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=tokens+indexes+caching+logging&limit=10", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
// The key is that the limit is no longer hardcoded to 5
// With our new default of 50, we should be able to return more than 5
t.Logf("Got %d observations (limit=10)", len(observations))
// The test passes as long as the default limit (50) is being used instead of 5
// and we can request a custom limit
assert.LessOrEqual(t, len(observations), 10, "should respect the custom limit")
}
func TestHandleSearchByPrompt_RequiredParams(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
tests := []struct {
name string
query string
wantStatus int
}{
{
name: "missing project",
query: "/api/context/search?query=test",
wantStatus: http.StatusBadRequest,
},
{
name: "missing query",
query: "/api/context/search?project=test",
wantStatus: http.StatusBadRequest,
},
{
name: "both present",
query: "/api/context/search?project=test&query=test",
wantStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tt.query, nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, tt.wantStatus, rec.Code)
})
}
}
func TestHandleContextInject_NoHardcodedLimit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Set a higher context observations limit in config
svc.config.ContextObservations = 50
project := "test-project"
// Create observations with VERY different content to avoid clustering
uniqueObservations := []struct {
title string
narrative string
concepts []string
}{
{"JWT tokens expire daily", "OAuth2 bearer tokens authentication", []string{"jwt"}},
{"PostgreSQL indexes optimize queries", "B-tree index on user table", []string{"postgres"}},
{"Redis caching TTL configuration", "Memory eviction policy LRU", []string{"redis"}},
{"Zerolog structured logging", "JSON output formatting levels", []string{"logging"}},
{"Pytest fixtures setup teardown", "Mock objects dependency injection", []string{"pytest"}},
{"Docker containers orchestration", "Compose multi-stage builds", []string{"docker"}},
{"Prometheus metrics collection", "Grafana dashboards alerting", []string{"prometheus"}},
{"OWASP vulnerability scanning", "SQL injection XSS prevention", []string{"owasp"}},
}
for _, obs := range uniqueObservations {
createTestObservation(t, svc.observationStore, project, obs.title, obs.narrative, obs.concepts)
time.Sleep(time.Millisecond)
}
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project, nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
// With very different content, we should get multiple observations back
// The key verification is that the hardcoded limit of 5 has been removed
t.Logf("Got %d observations from context inject", len(observations))
// Should return more than old limit of 5 with unique observations
assert.GreaterOrEqual(t, len(observations), 1, "should return at least 1 observation")
}
func TestHandleContextInject_RequiresProject(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
req := httptest.NewRequest(http.MethodGet, "/api/context/inject", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestHandleGetObservations_Limit(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create 20 observations
for i := 0; i < 20; i++ {
createTestObservation(t, svc.observationStore, "project-"+string(rune('a'+i%5)),
"Observation "+string(rune('A'+i)),
"Content of observation "+string(rune('A'+i)),
[]string{"test"})
time.Sleep(time.Millisecond)
}
// Request with limit=10
req := httptest.NewRequest(http.MethodGet, "/api/observations?limit=10", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// Parse as generic JSON array since the model uses custom marshaling
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
require.NoError(t, err)
assert.Len(t, observations, 10)
}
func TestSearchObservations_GlobalScope(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Create a project-scoped observation
createTestObservation(t, svc.observationStore, "project-a",
"Project specific code",
"This is specific to project-a",
[]string{"project-specific"})
// Create a global-scoped observation (has a globalizable concept)
createTestObservation(t, svc.observationStore, "project-a",
"Security best practice",
"Always validate user input",
[]string{"security", "best-practice"})
// Search from project-b - should find global observation
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project=project-b&query=security", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok)
// Should find the global observation even though it was created in project-a
assert.GreaterOrEqual(t, len(observations), 1)
}
func TestClusterObservations_RemovesDuplicates(t *testing.T) {
// Create similar observations
obs1 := &models.Observation{
ID: 1,
Title: sql.NullString{String: "Authentication flow implementation", Valid: true},
Narrative: sql.NullString{String: "We implemented JWT-based authentication", Valid: true},
}
obs2 := &models.Observation{
ID: 2,
Title: sql.NullString{String: "Authentication flow update", Valid: true},
Narrative: sql.NullString{String: "Updated JWT-based authentication logic", Valid: true},
}
obs3 := &models.Observation{
ID: 3,
Title: sql.NullString{String: "Database migration guide", Valid: true},
Narrative: sql.NullString{String: "How to run database migrations", Valid: true},
}
observations := []*models.Observation{obs1, obs2, obs3}
// Cluster with 0.4 threshold
clustered := clusterObservations(observations, 0.4)
// obs1 and obs2 should be clustered together, obs3 is different
assert.LessOrEqual(t, len(clustered), 3)
assert.GreaterOrEqual(t, len(clustered), 1)
// The first observation in each cluster should be kept (obs1, obs3)
t.Logf("Clustered %d observations down to %d", len(observations), len(clustered))
}
func TestRetrievalStats(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
project := "test-project"
createTestObservation(t, svc.observationStore, project,
"Test observation",
"Test narrative",
[]string{"test"})
// Make a search request
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test", nil)
rec := httptest.NewRecorder()
svc.router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// Check stats
stats := svc.GetRetrievalStats()
assert.Equal(t, int64(1), stats.TotalRequests)
assert.Equal(t, int64(1), stats.SearchRequests)
assert.GreaterOrEqual(t, stats.ObservationsServed, int64(1))
}
func TestHandleHealth_ReturnsVersion(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.version = "test-version-1.2.3"
svc.ready.Store(true)
req := httptest.NewRequest(http.MethodGet, "/api/health", nil)
rec := httptest.NewRecorder()
svc.handleHealth(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "ready", response["status"])
assert.Equal(t, "test-version-1.2.3", response["version"])
}
func TestHandleVersion(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.version = "v2.0.0-beta"
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
rec := httptest.NewRecorder()
svc.handleVersion(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "v2.0.0-beta", response["version"])
}
func TestHandleReady_ServiceNotReady(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Reset ready state to simulate service not being ready
svc.ready.Store(false)
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
rec := httptest.NewRecorder()
svc.handleReady(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
}
func TestHandleReady_ServiceReady(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.ready.Store(true)
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
rec := httptest.NewRecorder()
svc.handleReady(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "ready", response["status"])
}
func TestRequireReadyMiddleware_Blocks(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
// Reset ready state to simulate service not being ready
svc.ready.Store(false)
handler := svc.requireReady(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
}
func TestRequireReadyMiddleware_Allows(t *testing.T) {
svc, cleanup := testService(t)
defer cleanup()
svc.ready.Store(true)
handler := svc.requireReady(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "success", rec.Body.String())
}
+179
View File
@@ -0,0 +1,179 @@
// Package sdk provides SDK agent integration for claude-mnemonic.
package sdk
import (
"regexp"
"strings"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
var (
// Observation parsing
observationRegex = regexp.MustCompile(`(?s)<observation>(.*?)</observation>`)
// Summary parsing
summaryRegex = regexp.MustCompile(`(?s)<summary>(.*?)</summary>`)
skipSummaryRegex = regexp.MustCompile(`<skip_summary\s+reason="([^"]+)"\s*/>`)
// Valid observation types
validObsTypes = map[string]bool{
"bugfix": true,
"feature": true,
"refactor": true,
"change": true,
"discovery": true,
"decision": true,
}
// Valid concepts (strict list - no custom tags allowed)
validConcepts = map[string]bool{
"how-it-works": true,
"why-it-exists": true,
"what-changed": true,
"problem-solution": true,
"gotcha": true,
"pattern": true,
"trade-off": true,
}
)
// ParseObservations parses observation XML blocks from SDK response text.
func ParseObservations(text string, correlationID string) []*models.ParsedObservation {
var observations []*models.ParsedObservation
matches := observationRegex.FindAllStringSubmatch(text, -1)
for _, match := range matches {
if len(match) < 2 {
continue
}
obsContent := match[1]
// Extract fields
obsType := extractField(obsContent, "type")
title := extractField(obsContent, "title")
subtitle := extractField(obsContent, "subtitle")
narrative := extractField(obsContent, "narrative")
facts := extractArrayElements(obsContent, "facts", "fact")
concepts := extractArrayElements(obsContent, "concepts", "concept")
filesRead := extractArrayElements(obsContent, "files_read", "file")
filesModified := extractArrayElements(obsContent, "files_modified", "file")
// Determine final type (default to "change" if invalid)
finalType := models.ObsTypeChange
if obsType != "" {
if validObsTypes[obsType] {
finalType = models.ObservationType(obsType)
} else {
log.Warn().
Str("correlationId", correlationID).
Str("invalidType", obsType).
Msg("Invalid observation type, using 'change'")
}
} else {
log.Warn().
Str("correlationId", correlationID).
Msg("Observation missing type field, using 'change'")
}
// Filter concepts: only keep valid ones from the strict list
cleanedConcepts := make([]string, 0, len(concepts))
var invalidConcepts []string
for _, c := range concepts {
c = strings.ToLower(strings.TrimSpace(c))
if c == string(finalType) {
continue // Skip type in concepts
}
if validConcepts[c] {
cleanedConcepts = append(cleanedConcepts, c)
} else {
invalidConcepts = append(invalidConcepts, c)
}
}
if len(invalidConcepts) > 0 {
log.Warn().
Str("correlationId", correlationID).
Strs("invalidConcepts", invalidConcepts).
Msg("Filtered out invalid concepts (not in allowed list)")
}
observations = append(observations, &models.ParsedObservation{
Type: finalType,
Title: title,
Subtitle: subtitle,
Facts: facts,
Narrative: narrative,
Concepts: cleanedConcepts,
FilesRead: filesRead,
FilesModified: filesModified,
})
}
return observations
}
// ParseSummary parses a summary XML block from SDK response text.
func ParseSummary(text string, sessionID int64) *models.ParsedSummary {
// Check for skip_summary first
if skipMatch := skipSummaryRegex.FindStringSubmatch(text); skipMatch != nil {
log.Info().
Int64("sessionId", sessionID).
Str("reason", skipMatch[1]).
Msg("Summary skipped")
return nil
}
// Find summary block
match := summaryRegex.FindStringSubmatch(text)
if len(match) < 2 {
return nil
}
summaryContent := match[1]
return &models.ParsedSummary{
Request: extractField(summaryContent, "request"),
Investigated: extractField(summaryContent, "investigated"),
Learned: extractField(summaryContent, "learned"),
Completed: extractField(summaryContent, "completed"),
NextSteps: extractField(summaryContent, "next_steps"),
Notes: extractField(summaryContent, "notes"),
}
}
// extractField extracts a simple field value from XML content.
func extractField(content, fieldName string) string {
pattern := regexp.MustCompile(`<` + fieldName + `>([^<]*)</` + fieldName + `>`)
match := pattern.FindStringSubmatch(content)
if len(match) < 2 {
return ""
}
return strings.TrimSpace(match[1])
}
// extractArrayElements extracts array elements from XML content.
func extractArrayElements(content, arrayName, elementName string) []string {
var elements []string
// Find the array block
arrayPattern := regexp.MustCompile(`(?s)<` + arrayName + `>(.*?)</` + arrayName + `>`)
arrayMatch := arrayPattern.FindStringSubmatch(content)
if len(arrayMatch) < 2 {
return elements
}
arrayContent := arrayMatch[1]
// Extract individual elements
elementPattern := regexp.MustCompile(`<` + elementName + `>([^<]+)</` + elementName + `>`)
elementMatches := elementPattern.FindAllStringSubmatch(arrayContent, -1)
for _, match := range elementMatches {
if len(match) >= 2 {
elements = append(elements, strings.TrimSpace(match[1]))
}
}
return elements
}
+678
View File
@@ -0,0 +1,678 @@
// Package sdk provides SDK agent integration for claude-mnemonic.
package sdk
import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
json "github.com/goccy/go-json"
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/lukaszraczylo/claude-mnemonic/pkg/similarity"
"github.com/rs/zerolog/log"
)
// BroadcastFunc is a callback for broadcasting events to SSE clients.
type BroadcastFunc func(event map[string]interface{})
// SyncObservationFunc is a callback for syncing observations to vector DB.
type SyncObservationFunc func(obs *models.Observation)
// SyncSummaryFunc is a callback for syncing summaries to vector DB.
type SyncSummaryFunc func(summary *models.SessionSummary)
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
type Processor struct {
claudePath string
model string
observationStore *sqlite.ObservationStore
summaryStore *sqlite.SummaryStore
broadcastFunc BroadcastFunc
syncObservationFunc SyncObservationFunc
syncSummaryFunc SyncSummaryFunc
mu sync.Mutex
}
// SetBroadcastFunc sets the broadcast callback for SSE events.
func (p *Processor) SetBroadcastFunc(fn BroadcastFunc) {
p.broadcastFunc = fn
}
// SetSyncObservationFunc sets the callback for syncing observations to vector DB.
func (p *Processor) SetSyncObservationFunc(fn SyncObservationFunc) {
p.syncObservationFunc = fn
}
// SetSyncSummaryFunc sets the callback for syncing summaries to vector DB.
func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) {
p.syncSummaryFunc = fn
}
// broadcast sends an event via the broadcast callback if set.
func (p *Processor) broadcast(event map[string]interface{}) {
if p.broadcastFunc != nil {
p.broadcastFunc(event)
}
}
// NewProcessor creates a new SDK processor.
func NewProcessor(observationStore *sqlite.ObservationStore, summaryStore *sqlite.SummaryStore) (*Processor, error) {
cfg := config.Get()
// Find Claude Code CLI
claudePath := cfg.ClaudeCodePath
if claudePath == "" {
// Try to find in PATH
path, err := exec.LookPath("claude")
if err != nil {
return nil, fmt.Errorf("claude CLI not found in PATH and CLAUDE_CODE_PATH not set")
}
claudePath = path
}
// Verify it exists
if _, err := os.Stat(claudePath); err != nil {
return nil, fmt.Errorf("claude CLI not found at %s: %w", claudePath, err)
}
return &Processor{
claudePath: claudePath,
model: cfg.Model,
observationStore: observationStore,
summaryStore: summaryStore,
}, nil
}
// ProcessObservation processes a single tool observation and extracts insights.
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error {
p.mu.Lock()
defer p.mu.Unlock()
// Skip certain tools that aren't worth processing
if shouldSkipTool(toolName) {
log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)")
return nil
}
log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI")
// Convert tool data to strings
inputStr := toJSONString(toolInput)
outputStr := toJSONString(toolResponse)
// Check if we already have observations for this file (skip if covered)
if filePath := extractFilePath(toolName, inputStr); filePath != "" {
exists, err := p.observationStore.ExistsSimilarObservation(ctx, project, []string{filePath}, nil)
if err == nil && exists {
log.Debug().
Str("tool", toolName).
Str("file", filePath).
Msg("Skipping - file already has observations")
return nil
}
}
// Build the prompt
exec := ToolExecution{
ToolName: toolName,
ToolInput: inputStr,
ToolOutput: outputStr,
CWD: cwd,
}
prompt := BuildObservationPrompt(exec)
// Call Claude Code CLI
response, err := p.callClaudeCLI(ctx, prompt)
if err != nil {
log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation")
return err
}
// Parse observations from response
observations := ParseObservations(response, sdkSessionID)
if len(observations) == 0 {
log.Info().Str("tool", toolName).Msg("No observations extracted (Claude deemed not significant)")
return nil
}
// Get existing observations for deduplication
existingObs, err := p.observationStore.GetRecentObservations(ctx, project, 50)
if err != nil {
log.Warn().Err(err).Msg("Failed to get existing observations for dedup check")
existingObs = nil // Continue without dedup
}
// Store each observation (with deduplication check)
const similarityThreshold = 0.4 // Same threshold as retrieval clustering
var storedCount, skippedCount int
for _, obs := range observations {
// Capture file modification times for staleness detection
obs.FileMtimes = captureFileMtimes(obs.FilesRead, obs.FilesModified, cwd)
// Convert to stored observation for similarity check
storedObs := obs.ToStoredObservation()
// Check if this observation is too similar to existing ones
if existingObs != nil && similarity.IsSimilarToAny(storedObs, existingObs, similarityThreshold) {
log.Debug().
Str("type", string(obs.Type)).
Str("title", obs.Title).
Msg("Skipping observation - too similar to existing")
skippedCount++
continue
}
id, createdAtEpoch, err := p.observationStore.StoreObservation(ctx, sdkSessionID, project, obs, promptNumber, 0)
if err != nil {
log.Error().Err(err).Msg("Failed to store observation")
continue
}
storedCount++
log.Info().
Int64("id", id).
Str("type", string(obs.Type)).
Str("title", obs.Title).
Int("trackedFiles", len(obs.FileMtimes)).
Msg("Observation stored")
// Sync to vector DB if callback is set
if p.syncObservationFunc != nil {
fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0)
fullObs.ID = id
fullObs.CreatedAtEpoch = createdAtEpoch
p.syncObservationFunc(fullObs)
}
// Broadcast new observation event for dashboard refresh
p.broadcast(map[string]interface{}{
"type": "observation",
"action": "created",
"id": id,
"project": project,
})
// Add to existing for subsequent dedup checks within same batch
if existingObs != nil {
existingObs = append(existingObs, storedObs)
}
}
if skippedCount > 0 {
log.Info().
Int("stored", storedCount).
Int("skipped", skippedCount).
Msg("Observation processing complete (duplicates skipped)")
}
return nil
}
// ProcessSummary processes a session summary request.
func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSessionID, project, userPrompt, lastUserMsg, lastAssistantMsg string) error {
p.mu.Lock()
defer p.mu.Unlock()
// Skip summary generation if there's no meaningful assistant response
// This prevents generic "initial session setup" summaries
if !hasMeaningfulContent(lastAssistantMsg) {
log.Info().
Int64("sessionId", sessionDBID).
Msg("Skipping summary - no meaningful assistant response")
return nil
}
// Build the summary prompt
req := SummaryRequest{
SessionDBID: sessionDBID,
SDKSessionID: sdkSessionID,
Project: project,
UserPrompt: userPrompt,
LastUserMessage: lastUserMsg,
LastAssistantMessage: lastAssistantMsg,
}
prompt := BuildSummaryPrompt(req)
// Call Claude Code CLI
response, err := p.callClaudeCLI(ctx, prompt)
if err != nil {
log.Error().Err(err).Int64("sessionId", sessionDBID).Msg("Failed to call Claude CLI for summary")
return err
}
// Parse summary from response
summary := ParseSummary(response, sessionDBID)
if summary == nil {
log.Info().Int64("sessionId", sessionDBID).Msg("No summary generated (skipped or empty)")
return nil
}
// Filter out summaries that describe the memory agent itself
if isSelfReferentialSummary(summary) {
log.Info().Int64("sessionId", sessionDBID).Msg("Skipping self-referential summary (describes agent, not user work)")
return nil
}
// Store the summary (promptNumber=0, discoveryTokens=0 for summaries)
id, createdAtEpoch, err := p.summaryStore.StoreSummary(ctx, sdkSessionID, project, summary, 0, 0)
if err != nil {
log.Error().Err(err).Msg("Failed to store summary")
return err
}
log.Info().
Int64("id", id).
Int64("sessionId", sessionDBID).
Msg("Summary stored")
// Sync to vector DB if callback is set
if p.syncSummaryFunc != nil {
fullSummary := models.NewSessionSummary(sdkSessionID, project, summary, 0, 0)
fullSummary.ID = id
fullSummary.CreatedAtEpoch = createdAtEpoch
p.syncSummaryFunc(fullSummary)
}
// Broadcast new summary event for dashboard refresh
p.broadcast(map[string]interface{}{
"type": "summary",
"action": "created",
"id": id,
"project": project,
})
return nil
}
// callClaudeCLI calls the Claude Code CLI with the given prompt.
func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) {
// Build the full prompt with system instructions
fullPrompt := systemPrompt + "\n\n" + prompt
// Create command with timeout
ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
// Use claude CLI with --print flag for non-interactive output
// and -p for prompt input
cmd := exec.CommandContext(ctx, p.claudePath, "--print", "-p", fullPrompt) // #nosec G204 -- claudePath is from config, fullPrompt is internal
// Set model if specified (use haiku for cost efficiency)
if p.model != "" {
cmd.Args = append([]string{cmd.Args[0], "--model", p.model}, cmd.Args[1:]...)
} else {
// Default to haiku for processing (cheap and fast)
cmd.Args = append([]string{cmd.Args[0], "--model", "haiku"}, cmd.Args[1:]...)
}
// Run from /tmp to avoid triggering our own hooks
// (hooks are triggered based on working directory)
cmd.Dir = "/tmp"
// Disable any plugin hooks by setting an env var that our hooks can check
cmd.Env = append(os.Environ(), "CLAUDE_MNEMONIC_INTERNAL=1")
// Capture output
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
// Run command
err := cmd.Run()
if err != nil {
log.Error().
Err(err).
Str("stderr", stderr.String()).
Msg("Claude CLI execution failed")
return "", fmt.Errorf("claude CLI failed: %w (stderr: %s)", err, stderr.String())
}
return stdout.String(), nil
}
// shouldSkipTool returns true for tools that aren't worth processing.
func shouldSkipTool(toolName string) bool {
// Only skip truly uninteresting tools
skipTools := map[string]bool{
"TodoWrite": true, // Skip TodoWrite - internal tracking
"Task": true, // Skip Task - sub-agent spawning
"TaskOutput": true, // Skip TaskOutput - sub-agent results
"Glob": true, // Skip Glob - just file listing
}
skip, found := skipTools[toolName]
if found {
return skip
}
return false // Process all other tools
}
// extractFilePath extracts the file path from tool input for deduplication.
func extractFilePath(toolName, inputStr string) string {
if inputStr == "" {
return ""
}
var input map[string]interface{}
if err := json.Unmarshal([]byte(inputStr), &input); err != nil {
return ""
}
// Handle different tool input formats
switch toolName {
case "Read":
if fp, ok := input["file_path"].(string); ok {
return fp
}
case "Grep", "Search":
if path, ok := input["path"].(string); ok {
return path
}
case "Edit", "Write":
if fp, ok := input["file_path"].(string); ok {
return fp
}
}
return ""
}
// toJSONString converts an interface to a JSON string.
func toJSONString(v interface{}) string {
if v == nil {
return ""
}
if s, ok := v.(string); ok {
return s
}
b, err := json.Marshal(v)
if err != nil {
return fmt.Sprintf("%v", v)
}
return string(b)
}
// captureFileMtimes captures current modification times for tracked files.
// Returns a map of absolute file paths to their mtime in epoch milliseconds.
func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 {
mtimes := make(map[string]int64)
// Helper to get mtime for a file path
getMtime := func(path string) (int64, bool) {
// Resolve relative paths against cwd
absPath := path
if !filepath.IsAbs(path) && cwd != "" {
absPath = filepath.Join(cwd, path)
}
info, err := os.Stat(absPath)
if err != nil {
return 0, false
}
return info.ModTime().UnixMilli(), true
}
// Capture mtimes for all read files
for _, path := range filesRead {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
}
}
// Capture mtimes for all modified files
for _, path := range filesModified {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
}
}
return mtimes
}
// GetFileMtimes returns current modification times for a list of file paths.
// This is used for staleness checking when injecting context.
func GetFileMtimes(paths []string, cwd string) map[string]int64 {
return captureFileMtimes(paths, nil, cwd)
}
// GetFileContent reads file content for verification purposes.
// Returns content and ok status.
func GetFileContent(path, cwd string) (string, bool) {
absPath := path
if !filepath.IsAbs(path) && cwd != "" {
absPath = filepath.Join(cwd, path)
}
content, err := os.ReadFile(absPath) // #nosec G304 -- intentional file read for verification
if err != nil {
return "", false
}
// Limit to first 2000 chars for verification (enough context, not too expensive)
if len(content) > 2000 {
return string(content[:2000]) + "\n...[truncated]", true
}
return string(content), true
}
// VerifyObservation checks if an observation is still valid given the current file contents.
// Returns true if the observation is still accurate, false if it should be deleted.
func (p *Processor) VerifyObservation(ctx context.Context, obs *models.Observation, cwd string) bool {
// Build file content context
var fileContents []string
var paths []string
// Combine files_read and files_modified
for _, path := range obs.FilesRead {
paths = append(paths, path)
}
for _, path := range obs.FilesModified {
paths = append(paths, path)
}
// Get current content of tracked files
for _, path := range paths {
if content, ok := GetFileContent(path, cwd); ok {
fileContents = append(fileContents, fmt.Sprintf("=== %s ===\n%s", path, content))
}
}
if len(fileContents) == 0 {
// No files available to verify against - keep the observation
return true
}
// Build verification prompt
prompt := fmt.Sprintf(`You are verifying if a previously recorded observation is still accurate.
OBSERVATION:
- Type: %s
- Title: %s
- Subtitle: %s
- Narrative: %s
- Facts: %v
CURRENT FILE CONTENTS:
%s
TASK: Check if the observation is still accurate given the current file contents.
Reply with ONLY one of:
- VALID - if the observation is still accurate
- INVALID - if the observation is no longer accurate (the code/behavior changed)
- UNCERTAIN - if you can't determine validity (files might be incomplete)
Your response:`,
obs.Type,
obs.Title.String,
obs.Subtitle.String,
obs.Narrative.String,
obs.Facts,
strings.Join(fileContents, "\n\n"),
)
// Call Claude CLI for quick verification
response, err := p.callClaudeCLI(ctx, prompt)
if err != nil {
log.Warn().Err(err).Msg("Failed to verify observation, keeping it")
return true // On error, keep the observation
}
response = strings.TrimSpace(strings.ToUpper(response))
// Parse response
if strings.Contains(response, "INVALID") {
log.Info().
Int64("id", obs.ID).
Str("title", obs.Title.String).
Msg("Observation verified as INVALID - will delete")
return false
}
// VALID or UNCERTAIN - keep the observation
log.Debug().
Int64("id", obs.ID).
Str("title", obs.Title.String).
Str("result", response).
Msg("Observation verified")
return true
}
// isSelfReferentialSummary checks if a summary describes the memory agent itself
// rather than actual user work. These summaries should be filtered out.
func isSelfReferentialSummary(summary *models.ParsedSummary) bool {
// Combine all summary fields for checking
content := strings.ToLower(summary.Request + " " + summary.Completed + " " + summary.Learned + " " + summary.NextSteps)
// Indicators that the summary is about the memory agent, not user work
selfReferentialPhrases := []string{
"memory extraction",
"memory agent",
"hook execution",
"hook mechanism",
"session initialization",
"session setup",
"agent initialization",
"no technical learnings",
"no code or project work",
"waiting for the user",
"waiting for user",
"awaiting actual",
"awaiting claude code",
"progress checkpoint",
"checkpoint request",
}
matchCount := 0
for _, phrase := range selfReferentialPhrases {
if strings.Contains(content, phrase) {
matchCount++
}
}
// If the summary mentions 2+ self-referential phrases, it's about the agent
return matchCount >= 2
}
// hasMeaningfulContent checks if the assistant response contains meaningful content
// worth generating a summary for. This filters out initial greetings, empty sessions,
// and sessions where only system messages were exchanged.
func hasMeaningfulContent(assistantMsg string) bool {
// Skip if empty or too short (need substantial content)
if len(strings.TrimSpace(assistantMsg)) < 200 {
return false
}
lowerMsg := strings.ToLower(assistantMsg)
// Skip messages that are primarily about system/hook status
skipIndicators := []string{
"hook success",
"callback hook",
"session start",
"sessionstart",
"system-reminder",
"memory extraction agent",
"memory agent",
"no technical learnings",
"waiting for",
"waiting to",
"no code or project work",
"no substantive",
}
skipCount := 0
for _, skip := range skipIndicators {
if strings.Contains(lowerMsg, skip) {
skipCount++
}
}
// If multiple skip indicators found, this is likely a system-only session
if skipCount >= 2 {
return false
}
// Check for indicators of actual work being done
workIndicators := []string{
// Concrete file operations (with paths)
".go", ".ts", ".js", ".py", ".md", ".json", ".yaml", ".yml",
// Code modifications
"edited", "modified", "created", "deleted", "updated", "changed",
"added", "removed", "fixed", "implemented", "refactored",
// Tool results
"```", "lines ", "function ", "const ", "var ", "let ",
"type ", "struct ", "class ", "def ", "func ",
}
matchCount := 0
for _, indicator := range workIndicators {
if strings.Contains(lowerMsg, strings.ToLower(indicator)) {
matchCount++
}
}
// Require at least 2 work indicators to generate a summary
return matchCount >= 2
}
const systemPrompt = `You are a memory extraction agent for Claude Code sessions. Your job is to analyze tool executions and extract meaningful observations that would be useful for future sessions.
GUIDELINES:
1. Only create observations for SIGNIFICANT learnings - not every tool call needs one
2. Focus on: decisions made, bugs fixed, patterns discovered, project structure learned
3. Skip trivial operations like simple file reads without insights
4. Be concise but informative in your observations
5. Use appropriate type tags: decision, bugfix, feature, refactor, discovery, change
OUTPUT FORMAT:
When you find something worth remembering, output:
<observation>
<type>decision|bugfix|feature|refactor|discovery|change</type>
<title>Short descriptive title</title>
<subtitle>One-line summary</subtitle>
<narrative>Detailed explanation</narrative>
<facts>
<fact>Specific fact 1</fact>
</facts>
<concepts>
<concept>tag1</concept>
</concepts>
<files_read>
<file>/path/to/file</file>
</files_read>
<files_modified>
<file>/path/to/file</file>
</files_modified>
</observation>
If the tool execution is not noteworthy, simply respond with:
<skip reason="not significant"/>`
+117
View File
@@ -0,0 +1,117 @@
// Package sdk provides SDK agent integration for claude-mnemonic.
package sdk
import (
"encoding/json"
"fmt"
"strings"
"time"
)
// ObservationTypes defines valid observation types.
var ObservationTypes = []string{"bugfix", "feature", "refactor", "change", "discovery", "decision"}
// ObservationConcepts defines valid observation concepts.
var ObservationConcepts = []string{
"how-it-works",
"why-it-exists",
"what-changed",
"problem-solution",
"gotcha",
"pattern",
"trade-off",
}
// ToolExecution represents a tool execution for observation.
type ToolExecution struct {
ID int64
ToolName string
ToolInput string
ToolOutput string
CreatedAtEpoch int64
CWD string
}
// BuildObservationPrompt builds a prompt for processing a tool observation.
func BuildObservationPrompt(exec ToolExecution) string {
// Safely parse tool_input and tool_output
var toolInput interface{}
var toolOutput interface{}
if err := json.Unmarshal([]byte(exec.ToolInput), &toolInput); err != nil {
toolInput = exec.ToolInput
}
if err := json.Unmarshal([]byte(exec.ToolOutput), &toolOutput); err != nil {
toolOutput = exec.ToolOutput
}
inputJSON, _ := json.MarshalIndent(toolInput, " ", " ")
outputJSON, _ := json.MarshalIndent(toolOutput, " ", " ")
timestamp := time.UnixMilli(exec.CreatedAtEpoch).Format(time.RFC3339)
var sb strings.Builder
sb.WriteString("<observed_from_primary_session>\n")
sb.WriteString(fmt.Sprintf(" <what_happened>%s</what_happened>\n", exec.ToolName))
sb.WriteString(fmt.Sprintf(" <occurred_at>%s</occurred_at>\n", timestamp))
if exec.CWD != "" {
sb.WriteString(fmt.Sprintf(" <working_directory>%s</working_directory>\n", exec.CWD))
}
sb.WriteString(fmt.Sprintf(" <parameters>%s</parameters>\n", truncate(string(inputJSON), 3000)))
sb.WriteString(fmt.Sprintf(" <outcome>%s</outcome>\n", truncate(string(outputJSON), 5000)))
sb.WriteString("</observed_from_primary_session>")
return sb.String()
}
// SummaryRequest contains data for building a summary prompt.
type SummaryRequest struct {
SessionDBID int64
SDKSessionID string
Project string
UserPrompt string
LastUserMessage string
LastAssistantMessage string
}
// BuildSummaryPrompt builds a prompt requesting a session summary.
func BuildSummaryPrompt(req SummaryRequest) string {
var sb strings.Builder
sb.WriteString("PROGRESS SUMMARY CHECKPOINT\n")
sb.WriteString("===========================\n")
sb.WriteString("Write progress notes of what was done, what was learned, and what's next. This is a checkpoint to capture progress so far. The session is ongoing - you may receive more requests and tool executions after this summary. Write \"next_steps\" as the current trajectory of work (what's actively being worked on or coming up next), not as post-session future work. Always write at least a minimal summary explaining current progress, even if work is still in early stages, so that users see a summary output tied to each request.\n\n")
if req.LastAssistantMessage != "" {
sb.WriteString("Claude's Full Response to User:\n")
sb.WriteString(truncate(req.LastAssistantMessage, 4000))
sb.WriteString("\n\n")
}
sb.WriteString(`Respond in this XML format:
<summary>
<request>[Short title capturing the user's request AND the substance of what was discussed/done]</request>
<investigated>[What has been explored so far? What was examined?]</investigated>
<learned>[What have you learned about how things work?]</learned>
<completed>[What work has been completed so far? What has shipped or changed?]</completed>
<next_steps>[What are you actively working on or planning to work on next in this session?]</next_steps>
<notes>[Additional insights or observations about the current progress]</notes>
</summary>
IMPORTANT! DO NOT do any work right now other than generating this next PROGRESS SUMMARY - and remember that you are a memory agent designed to summarize a DIFFERENT claude code session, not this one.
Never reference yourself or your own actions. Do not output anything other than the summary content formatted in the XML structure above. All other output is ignored by the system, and the system has been designed to be smart about token usage. Please spend your tokens wisely on useful summary content.
Thank you, this summary will be very useful for keeping track of our progress!`)
return sb.String()
}
// truncate truncates a string to the specified length.
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "... (truncated)"
}
+805
View File
@@ -0,0 +1,805 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"context"
"fmt"
"net/http"
"os"
"sync"
"sync/atomic"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/chroma"
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sse"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// Service configuration constants
const (
// DefaultHTTPTimeout is the default timeout for HTTP requests.
DefaultHTTPTimeout = 30 * time.Second
// ReadyPollInterval is how often WaitReady checks initialization status.
ReadyPollInterval = 50 * time.Millisecond
// StaleQueueSize is the buffer size for background stale verification.
StaleQueueSize = 100
// QueueProcessInterval is how often the background queue processor runs.
QueueProcessInterval = 2 * time.Second
)
// RetrievalStats tracks observation retrieval metrics.
type RetrievalStats struct {
TotalRequests int64 // Total retrieval requests (inject + search)
ObservationsServed int64 // Observations returned to clients
VerifiedStale int64 // Stale observations that passed verification
DeletedInvalid int64 // Invalid observations deleted
SearchRequests int64 // Semantic search requests
ContextInjections int64 // Session-start context injections
}
// Service is the main worker service orchestrator.
type Service struct {
// Version of the worker binary
version string
// Configuration
config *config.Config
// Database
store *sqlite.Store
sessionStore *sqlite.SessionStore
observationStore *sqlite.ObservationStore
summaryStore *sqlite.SummaryStore
promptStore *sqlite.PromptStore
// Domain services
sessionManager *session.Manager
sseBroadcaster *sse.Broadcaster
processor *sdk.Processor
// Vector database
chromaClient *chroma.Client
chromaSync *chroma.Sync
// HTTP server
router *chi.Mux
server *http.Server
startTime time.Time
// Retrieval statistics
retrievalStats RetrievalStats
// Lifecycle
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
// Initialization state (for deferred init)
ready atomic.Bool
initError error
initMu sync.RWMutex
// Background verification queue for stale observations
staleQueue chan staleVerifyRequest
staleQueueOnce sync.Once
// File watchers for auto-recreation on deletion
dbWatcher *watcher.Watcher
configWatcher *watcher.Watcher
}
// staleVerifyRequest represents a request to verify a stale observation in background
type staleVerifyRequest struct {
observationID int64
cwd string
}
// NewService creates a new worker service with deferred initialization.
// The service starts immediately with health endpoint available,
// while database and SDK initialization happens in the background.
func NewService(version string) (*Service, error) {
cfg := config.Get()
// Create context
ctx, cancel := context.WithCancel(context.Background())
// Create router and SSE broadcaster (lightweight, no dependencies)
router := chi.NewRouter()
sseBroadcaster := sse.NewBroadcaster()
svc := &Service{
version: version,
config: cfg,
sseBroadcaster: sseBroadcaster,
router: router,
ctx: ctx,
cancel: cancel,
startTime: time.Now(),
}
// Setup middleware and routes (health endpoint works immediately)
svc.setupMiddleware()
svc.setupRoutes()
// Start async initialization
go svc.initializeAsync()
return svc, nil
}
// initializeAsync performs heavy initialization in the background.
func (s *Service) initializeAsync() {
log.Info().Msg("Starting async initialization...")
// Ensure data directory, vector-db, and settings exist
if err := config.EnsureAll(); err != nil {
s.setInitError(fmt.Errorf("ensure data dir: %w", err))
return
}
// Initialize database (this includes migrations - can be slow)
store, err := sqlite.NewStore(sqlite.StoreConfig{
Path: s.config.DBPath,
MaxConns: s.config.MaxConns,
WALMode: true,
})
if err != nil {
s.setInitError(fmt.Errorf("init database: %w", err))
return
}
// Create store wrappers
sessionStore := sqlite.NewSessionStore(store)
observationStore := sqlite.NewObservationStore(store)
summaryStore := sqlite.NewSummaryStore(store)
promptStore := sqlite.NewPromptStore(store)
// Create session manager
sessionManager := session.NewManager(sessionStore)
// Create ChromaDB client for vector search (optional - will be nil if unavailable)
var chromaClient *chroma.Client
var chromaSync *chroma.Sync
chromaCfg := chroma.Config{
Project: "default", // Collection prefix
DataDir: s.config.VectorDBPath,
BatchSize: 100,
}
client, err := chroma.NewClient(chromaCfg)
if err != nil {
log.Warn().Err(err).Msg("ChromaDB client creation failed - vector sync disabled")
} else {
// Connect to ChromaDB (starts the MCP server)
if err := client.Connect(s.ctx); err != nil {
log.Warn().Err(err).Msg("ChromaDB connection failed - vector sync disabled")
} else {
chromaClient = client
chromaSync = chroma.NewSync(client)
log.Info().Msg("ChromaDB client connected - vector sync enabled")
}
}
// Create SDK processor (optional - will be nil if Claude CLI not available)
var processor *sdk.Processor
proc, err := sdk.NewProcessor(observationStore, summaryStore)
if err != nil {
log.Warn().Err(err).Msg("SDK processor not available - observations will be queued but not processed")
} else {
processor = proc
// Set broadcast callback for SSE events
processor.SetBroadcastFunc(func(event map[string]interface{}) {
s.sseBroadcaster.Broadcast(event)
})
log.Info().Msg("SDK processor initialized")
}
// Set all the initialized components
s.initMu.Lock()
s.store = store
s.sessionStore = sessionStore
s.observationStore = observationStore
s.summaryStore = summaryStore
s.promptStore = promptStore
s.sessionManager = sessionManager
s.processor = processor
s.chromaClient = chromaClient
s.chromaSync = chromaSync
s.initMu.Unlock()
// Set vector sync callbacks on processor if both are available
if processor != nil && chromaSync != nil {
processor.SetSyncObservationFunc(func(obs *models.Observation) {
if err := chromaSync.SyncObservation(s.ctx, obs); err != nil {
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to ChromaDB")
}
})
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
if err := chromaSync.SyncSummary(s.ctx, summary); err != nil {
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to ChromaDB")
}
})
}
// Set cleanup callback on observation store to sync deletes to ChromaDB
if observationStore != nil && chromaSync != nil {
observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
if err := chromaSync.DeleteObservations(ctx, deletedIDs); err != nil {
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from ChromaDB")
}
})
}
// Set cleanup callback on prompt store to sync deletes to ChromaDB
if promptStore != nil && chromaSync != nil {
promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
if err := chromaSync.DeleteUserPrompts(ctx, deletedIDs); err != nil {
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from ChromaDB")
}
})
}
// Set callback for session deletion
sessionManager.SetOnSessionDeleted(func(id int64) {
s.broadcastProcessingStatus()
})
// Mark as ready
s.ready.Store(true)
log.Info().Msg("Async initialization complete - service ready")
// Start queue processor if SDK processor is available
if processor != nil {
s.wg.Add(1)
go s.processQueue()
}
// Start file watchers for auto-recreation on deletion
s.startWatchers()
}
// startWatchers initializes and starts file watchers for database and config.
func (s *Service) startWatchers() {
// Watch database file for deletion
dbWatcher, err := watcher.New(s.config.DBPath, func() {
log.Warn().Str("path", s.config.DBPath).Msg("Database file deleted, reinitializing...")
s.reinitializeDatabase()
})
if err != nil {
log.Warn().Err(err).Msg("Failed to create database watcher")
} else {
s.dbWatcher = dbWatcher
if err := dbWatcher.Start(); err != nil {
log.Warn().Err(err).Msg("Failed to start database watcher")
} else {
log.Info().Str("path", s.config.DBPath).Msg("Database file watcher started")
}
}
// Watch config file for changes (triggers process exit for restart)
configPath := config.SettingsPath()
configWatcher, err := watcher.New(configPath, func() {
log.Warn().Str("path", configPath).Msg("Config file changed, reloading...")
s.reloadConfig()
})
if err != nil {
log.Warn().Err(err).Msg("Failed to create config watcher")
} else {
s.configWatcher = configWatcher
if err := configWatcher.Start(); err != nil {
log.Warn().Err(err).Msg("Failed to start config watcher")
} else {
log.Info().Str("path", configPath).Msg("Config file watcher started")
}
}
}
// reinitializeDatabase recreates the database after deletion.
func (s *Service) reinitializeDatabase() {
// Block new requests
s.ready.Store(false)
log.Info().Msg("Database reinitialization starting...")
// Get old store references
s.initMu.Lock()
oldStore := s.store
oldSessionManager := s.sessionManager
oldChromaClient := s.chromaClient
s.initMu.Unlock()
// Close old stores
if oldChromaClient != nil {
if err := oldChromaClient.Close(); err != nil {
log.Warn().Err(err).Msg("Error closing old ChromaDB client")
}
}
if oldStore != nil {
if err := oldStore.Close(); err != nil {
log.Warn().Err(err).Msg("Error closing old database")
}
}
// Clear in-memory sessions (they reference old DB IDs)
if oldSessionManager != nil {
oldSessionManager.ShutdownAll(s.ctx)
}
// Ensure data directory, vector-db, and settings exist (may have been deleted)
if err := config.EnsureAll(); err != nil {
s.setInitError(fmt.Errorf("ensure data dir on reinit: %w", err))
return
}
// Create new database
store, err := sqlite.NewStore(sqlite.StoreConfig{
Path: s.config.DBPath,
MaxConns: s.config.MaxConns,
WALMode: true,
})
if err != nil {
s.setInitError(fmt.Errorf("reinit database: %w", err))
return
}
// Create new store wrappers
sessionStore := sqlite.NewSessionStore(store)
observationStore := sqlite.NewObservationStore(store)
summaryStore := sqlite.NewSummaryStore(store)
promptStore := sqlite.NewPromptStore(store)
// Create new session manager
sessionManager := session.NewManager(sessionStore)
// Recreate ChromaDB client
var chromaClient *chroma.Client
var chromaSync *chroma.Sync
chromaCfg := chroma.Config{
Project: "default",
DataDir: s.config.VectorDBPath,
BatchSize: 100,
}
client, err := chroma.NewClient(chromaCfg)
if err != nil {
log.Warn().Err(err).Msg("ChromaDB client creation failed after reinit")
} else {
if err := client.Connect(s.ctx); err != nil {
log.Warn().Err(err).Msg("ChromaDB connection failed after reinit")
} else {
chromaClient = client
chromaSync = chroma.NewSync(client)
log.Info().Msg("ChromaDB client reconnected after reinit")
}
}
// Recreate SDK processor with new stores
var processor *sdk.Processor
proc, err := sdk.NewProcessor(observationStore, summaryStore)
if err != nil {
log.Warn().Err(err).Msg("SDK processor not available after reinit")
} else {
processor = proc
processor.SetBroadcastFunc(func(event map[string]interface{}) {
s.sseBroadcaster.Broadcast(event)
})
}
// Atomically swap all components
s.initMu.Lock()
s.store = store
s.sessionStore = sessionStore
s.observationStore = observationStore
s.summaryStore = summaryStore
s.promptStore = promptStore
s.sessionManager = sessionManager
s.processor = processor
s.chromaClient = chromaClient
s.chromaSync = chromaSync
s.initError = nil
s.initMu.Unlock()
// Set vector sync callbacks on processor if both are available
if processor != nil && chromaSync != nil {
processor.SetSyncObservationFunc(func(obs *models.Observation) {
if err := chromaSync.SyncObservation(s.ctx, obs); err != nil {
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to ChromaDB")
}
})
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
if err := chromaSync.SyncSummary(s.ctx, summary); err != nil {
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to ChromaDB")
}
})
}
// Set cleanup callback on observation store to sync deletes to ChromaDB
if observationStore != nil && chromaSync != nil {
observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
if err := chromaSync.DeleteObservations(ctx, deletedIDs); err != nil {
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from ChromaDB")
}
})
}
// Set cleanup callback on prompt store to sync deletes to ChromaDB
if promptStore != nil && chromaSync != nil {
promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
if err := chromaSync.DeleteUserPrompts(ctx, deletedIDs); err != nil {
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from ChromaDB")
}
})
}
// Set callback for session deletion
sessionManager.SetOnSessionDeleted(func(id int64) {
s.broadcastProcessingStatus()
})
// Mark as ready again
s.ready.Store(true)
log.Info().Msg("Database reinitialization complete")
// Broadcast status update
s.sseBroadcaster.Broadcast(map[string]interface{}{
"type": "database_reinitialized",
"message": "Database was recreated after deletion",
})
}
// reloadConfig reloads configuration from disk.
// For now, this triggers a graceful restart by exiting (hooks will restart us).
func (s *Service) reloadConfig() {
log.Info().Msg("Config changed, triggering graceful restart...")
// Broadcast notification
s.sseBroadcaster.Broadcast(map[string]interface{}{
"type": "config_changed",
"message": "Configuration changed, restarting worker...",
})
// Give SSE clients a moment to receive the message
time.Sleep(100 * time.Millisecond)
// Exit cleanly - hooks will restart us with new config
os.Exit(0)
}
// setInitError records an initialization error.
func (s *Service) setInitError(err error) {
s.initMu.Lock()
s.initError = err
s.initMu.Unlock()
log.Error().Err(err).Msg("Async initialization failed")
}
// GetInitError returns any initialization error.
func (s *Service) GetInitError() error {
s.initMu.RLock()
defer s.initMu.RUnlock()
return s.initError
}
// queueStaleVerification queues a stale observation for background verification.
// This is non-blocking - if the queue is full, the request is dropped.
func (s *Service) queueStaleVerification(observationID int64, cwd string) {
// Initialize queue on first use
s.staleQueueOnce.Do(func() {
s.staleQueue = make(chan staleVerifyRequest, StaleQueueSize)
s.wg.Add(1)
go s.processStaleQueue()
})
// Non-blocking send - drop if queue is full
select {
case s.staleQueue <- staleVerifyRequest{observationID: observationID, cwd: cwd}:
// Queued
default:
// Queue full, drop
log.Debug().Int64("id", observationID).Msg("Stale verification queue full, dropping")
}
}
// processStaleQueue processes stale observations in the background.
func (s *Service) processStaleQueue() {
defer s.wg.Done()
for {
select {
case <-s.ctx.Done():
return
case req := <-s.staleQueue:
s.verifyStaleObservation(req)
}
}
}
// verifyStaleObservation verifies a single stale observation in the background.
func (s *Service) verifyStaleObservation(req staleVerifyRequest) {
// Wait for service to be ready
if !s.ready.Load() {
return
}
// Get observation from DB
s.initMu.RLock()
store := s.observationStore
processor := s.processor
s.initMu.RUnlock()
if store == nil || processor == nil {
return
}
obs, err := store.GetObservationByID(s.ctx, req.observationID)
if err != nil || obs == nil {
return
}
// Verify with Claude CLI (this is slow but we're in background)
if !processor.VerifyObservation(s.ctx, obs, req.cwd) {
// Invalid - delete it
deleted, err := store.DeleteObservations(s.ctx, []int64{obs.ID})
if err == nil && deleted > 0 {
log.Info().
Int64("id", obs.ID).
Str("title", obs.Title.String).
Msg("Background verification: deleted invalid observation")
}
} else {
log.Debug().
Int64("id", obs.ID).
Msg("Background verification: observation still valid")
}
}
// setupMiddleware configures HTTP middleware.
func (s *Service) setupMiddleware() {
s.router.Use(middleware.Logger)
s.router.Use(middleware.Recoverer)
s.router.Use(middleware.Timeout(DefaultHTTPTimeout))
s.router.Use(middleware.RealIP)
}
// setupRoutes configures HTTP routes.
func (s *Service) setupRoutes() {
// Serve Vue dashboard from embedded static files
s.router.Get("/", serveIndex)
s.router.Get("/assets/*", serveAssets)
// Health check (both root and API-prefixed for compatibility)
// Returns 200 immediately so hooks can connect quickly during init
// Also returns version for stale worker detection
s.router.Get("/health", s.handleHealth)
s.router.Get("/api/health", s.handleHealth)
// Version endpoint for hooks to check if worker needs restart
s.router.Get("/api/version", s.handleVersion)
// Readiness check - returns 200 only when fully initialized
s.router.Get("/api/ready", s.handleReady)
// SSE endpoint (works before DB is ready)
s.router.Get("/api/events", s.sseBroadcaster.HandleSSE)
// Routes that require DB to be ready
s.router.Group(func(r chi.Router) {
r.Use(s.requireReady)
// Session routes
r.Post("/api/sessions/init", s.handleSessionInit)
r.Get("/api/sessions", s.handleGetSessionByClaudeID)
r.Post("/sessions/{id}/init", s.handleSessionStart)
r.Post("/api/sessions/observations", s.handleObservation)
r.Post("/api/sessions/subagent-complete", s.handleSubagentComplete)
r.Post("/sessions/{id}/summarize", s.handleSummarize)
// Data routes
r.Get("/api/observations", s.handleGetObservations)
r.Get("/api/summaries", s.handleGetSummaries)
r.Get("/api/prompts", s.handleGetPrompts)
r.Get("/api/projects", s.handleGetProjects)
r.Get("/api/stats", s.handleGetStats)
r.Get("/api/stats/retrieval", s.handleGetRetrievalStats)
// Context injection
r.Get("/api/context/count", s.handleContextCount)
r.Get("/api/context/inject", s.handleContextInject)
r.Get("/api/context/search", s.handleSearchByPrompt)
})
}
// recordRetrievalStats atomically updates retrieval statistics.
func (s *Service) recordRetrievalStats(served, verified, deleted int64, isSearch bool) {
atomic.AddInt64(&s.retrievalStats.TotalRequests, 1)
atomic.AddInt64(&s.retrievalStats.ObservationsServed, served)
atomic.AddInt64(&s.retrievalStats.VerifiedStale, verified)
atomic.AddInt64(&s.retrievalStats.DeletedInvalid, deleted)
if isSearch {
atomic.AddInt64(&s.retrievalStats.SearchRequests, 1)
} else {
atomic.AddInt64(&s.retrievalStats.ContextInjections, 1)
}
}
// GetRetrievalStats returns a copy of the retrieval stats.
func (s *Service) GetRetrievalStats() RetrievalStats {
return RetrievalStats{
TotalRequests: atomic.LoadInt64(&s.retrievalStats.TotalRequests),
ObservationsServed: atomic.LoadInt64(&s.retrievalStats.ObservationsServed),
VerifiedStale: atomic.LoadInt64(&s.retrievalStats.VerifiedStale),
DeletedInvalid: atomic.LoadInt64(&s.retrievalStats.DeletedInvalid),
SearchRequests: atomic.LoadInt64(&s.retrievalStats.SearchRequests),
ContextInjections: atomic.LoadInt64(&s.retrievalStats.ContextInjections),
}
}
// Start starts the worker service.
// The HTTP server starts immediately; database initialization happens async.
func (s *Service) Start() error {
port := config.GetWorkerPort()
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: s.router,
ReadHeaderTimeout: 10 * time.Second,
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
if err := s.server.ListenAndServe(); err != http.ErrServerClosed {
log.Error().Err(err).Msg("HTTP server error")
}
}()
// Note: Queue processor is started in initializeAsync() after DB is ready
log.Info().
Int("port", port).
Int("pid", getPID()).
Msg("Worker HTTP server started (initialization in progress)")
return nil
}
// processQueue processes the observation queue in the background.
func (s *Service) processQueue() {
defer s.wg.Done()
ticker := time.NewTicker(QueueProcessInterval)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
s.processAllSessions()
}
}
}
// processAllSessions processes pending messages for all active sessions.
func (s *Service) processAllSessions() {
// Get all sessions with pending messages
sessions := s.sessionManager.GetAllSessions()
for _, sess := range sessions {
// Get pending messages
messages := s.sessionManager.DrainMessages(sess.SessionDBID)
if len(messages) == 0 {
continue
}
// Process each message
for _, msg := range messages {
switch msg.Type {
case session.MessageTypeObservation:
if msg.Observation != nil {
err := s.processor.ProcessObservation(
s.ctx,
sess.SDKSessionID,
sess.Project,
msg.Observation.ToolName,
msg.Observation.ToolInput,
msg.Observation.ToolResponse,
msg.Observation.PromptNumber,
msg.Observation.CWD,
)
if err != nil {
log.Error().Err(err).
Str("tool", msg.Observation.ToolName).
Msg("Failed to process observation")
}
}
case session.MessageTypeSummarize:
if msg.Summarize != nil {
err := s.processor.ProcessSummary(
s.ctx,
sess.SessionDBID,
sess.SDKSessionID,
sess.Project,
sess.UserPrompt,
msg.Summarize.LastUserMessage,
msg.Summarize.LastAssistantMessage,
)
if err != nil {
log.Error().Err(err).
Int64("sessionId", sess.SessionDBID).
Msg("Failed to process summary")
}
// Delete session after summary
s.sessionManager.DeleteSession(sess.SessionDBID)
}
}
}
s.broadcastProcessingStatus()
}
}
// Shutdown gracefully shuts down the service.
func (s *Service) Shutdown(ctx context.Context) error {
s.cancel()
// Stop file watchers
if s.dbWatcher != nil {
_ = s.dbWatcher.Stop()
}
if s.configWatcher != nil {
_ = s.configWatcher.Stop()
}
// Shutdown all sessions
s.sessionManager.ShutdownAll(ctx)
// Shutdown HTTP server
if s.server != nil {
if err := s.server.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("HTTP server shutdown error")
}
}
// Close ChromaDB client
if s.chromaClient != nil {
if err := s.chromaClient.Close(); err != nil {
log.Error().Err(err).Msg("ChromaDB close error")
}
}
// Close database
if err := s.store.Close(); err != nil {
log.Error().Err(err).Msg("Database close error")
}
s.wg.Wait()
log.Info().Msg("Worker service shutdown complete")
return nil
}
// broadcastProcessingStatus broadcasts the current processing status.
func (s *Service) broadcastProcessingStatus() {
isProcessing := s.sessionManager.IsAnySessionProcessing()
queueDepth := s.sessionManager.GetTotalQueueDepth()
s.sseBroadcaster.Broadcast(map[string]interface{}{
"type": "processing_status",
"isProcessing": isProcessing,
"queueDepth": queueDepth,
})
}
func getPID() int {
return os.Getpid()
}
+346
View File
@@ -0,0 +1,346 @@
// Package session provides session lifecycle management for claude-mnemonic.
package session
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
"github.com/rs/zerolog/log"
)
// MessageType represents the type of pending message.
type MessageType int
const (
MessageTypeObservation MessageType = iota
MessageTypeSummarize
)
// ObservationData contains data for a tool observation.
type ObservationData struct {
ToolName string
ToolInput interface{}
ToolResponse interface{}
PromptNumber int
CWD string
}
// SummarizeData contains data for a summarize request.
type SummarizeData struct {
LastUserMessage string
LastAssistantMessage string
}
// PendingMessage represents a message queued for SDK processing.
type PendingMessage struct {
Type MessageType
Observation *ObservationData
Summarize *SummarizeData
}
// ActiveSession represents an in-memory active session being processed.
type ActiveSession struct {
SessionDBID int64
ClaudeSessionID string
SDKSessionID string
Project string
UserPrompt string
LastPromptNumber int
StartTime time.Time
CumulativeInputTokens int64
CumulativeOutputTokens int64
// Concurrency control
pendingMessages []PendingMessage
messageMu sync.Mutex
notify chan struct{}
ctx context.Context
cancel context.CancelFunc
generatorActive atomic.Bool
}
// Manager manages active session lifecycles.
type Manager struct {
sessionStore *sqlite.SessionStore
sessions map[int64]*ActiveSession
mu sync.RWMutex
onDeleted func(int64)
}
// NewManager creates a new session manager.
func NewManager(sessionStore *sqlite.SessionStore) *Manager {
return &Manager{
sessionStore: sessionStore,
sessions: make(map[int64]*ActiveSession),
}
}
// SetOnSessionDeleted sets a callback for when a session is deleted.
func (m *Manager) SetOnSessionDeleted(callback func(int64)) {
m.onDeleted = callback
}
// InitializeSession initializes a session, creating it if needed.
func (m *Manager) InitializeSession(ctx context.Context, sessionDBID int64, userPrompt string, promptNumber int) (*ActiveSession, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Check if already active
if session, ok := m.sessions[sessionDBID]; ok {
// Update user prompt for continuation
if userPrompt != "" {
session.UserPrompt = userPrompt
session.LastPromptNumber = promptNumber
}
return session, nil
}
// Fetch from database
dbSession, err := m.sessionStore.GetSessionByID(ctx, sessionDBID)
if err != nil {
return nil, err
}
if dbSession == nil {
return nil, nil
}
// Use provided userPrompt or fall back to database
prompt := userPrompt
if prompt == "" && dbSession.UserPrompt.Valid {
prompt = dbSession.UserPrompt.String
}
// Get prompt counter if not provided
if promptNumber <= 0 {
promptNumber, _ = m.sessionStore.GetPromptCounter(ctx, sessionDBID)
}
// Create session context
sessionCtx, cancel := context.WithCancel(context.Background())
session := &ActiveSession{
SessionDBID: sessionDBID,
ClaudeSessionID: dbSession.ClaudeSessionID,
SDKSessionID: dbSession.SDKSessionID.String,
Project: dbSession.Project,
UserPrompt: prompt,
LastPromptNumber: promptNumber,
StartTime: time.Now(),
pendingMessages: make([]PendingMessage, 0, 32),
notify: make(chan struct{}, 1),
ctx: sessionCtx,
cancel: cancel,
}
m.sessions[sessionDBID] = session
log.Info().
Int64("sessionId", sessionDBID).
Str("project", session.Project).
Str("claudeSessionId", session.ClaudeSessionID).
Msg("Session initialized")
return session, nil
}
// QueueObservation queues an observation for SDK processing.
func (m *Manager) QueueObservation(ctx context.Context, sessionDBID int64, data ObservationData) error {
m.mu.Lock()
session, ok := m.sessions[sessionDBID]
if !ok {
// Auto-initialize from database
m.mu.Unlock()
var err error
session, err = m.InitializeSession(ctx, sessionDBID, "", 0)
if err != nil || session == nil {
return err
}
} else {
m.mu.Unlock()
}
session.messageMu.Lock()
session.pendingMessages = append(session.pendingMessages, PendingMessage{
Type: MessageTypeObservation,
Observation: &data,
})
queueDepth := len(session.pendingMessages)
session.messageMu.Unlock()
// Non-blocking notification
select {
case session.notify <- struct{}{}:
default:
}
log.Info().
Int64("sessionId", sessionDBID).
Str("tool", data.ToolName).
Int("queueDepth", queueDepth).
Msg("Observation queued")
return nil
}
// QueueSummarize queues a summarize request for SDK processing.
func (m *Manager) QueueSummarize(ctx context.Context, sessionDBID int64, lastUserMessage, lastAssistantMessage string) error {
m.mu.Lock()
session, ok := m.sessions[sessionDBID]
if !ok {
// Auto-initialize from database
m.mu.Unlock()
var err error
session, err = m.InitializeSession(ctx, sessionDBID, "", 0)
if err != nil || session == nil {
return err
}
} else {
m.mu.Unlock()
}
session.messageMu.Lock()
session.pendingMessages = append(session.pendingMessages, PendingMessage{
Type: MessageTypeSummarize,
Summarize: &SummarizeData{
LastUserMessage: lastUserMessage,
LastAssistantMessage: lastAssistantMessage,
},
})
queueDepth := len(session.pendingMessages)
session.messageMu.Unlock()
// Non-blocking notification
select {
case session.notify <- struct{}{}:
default:
}
log.Info().
Int64("sessionId", sessionDBID).
Int("queueDepth", queueDepth).
Msg("Summarize request queued")
return nil
}
// DeleteSession removes a session and cleans up resources.
func (m *Manager) DeleteSession(sessionDBID int64) {
m.mu.Lock()
session, ok := m.sessions[sessionDBID]
if !ok {
m.mu.Unlock()
return
}
delete(m.sessions, sessionDBID)
m.mu.Unlock()
// Cancel context to stop generator
session.cancel()
duration := time.Since(session.StartTime)
log.Info().
Int64("sessionId", sessionDBID).
Str("project", session.Project).
Dur("duration", duration).
Msg("Session deleted")
// Trigger callback
if m.onDeleted != nil {
m.onDeleted(sessionDBID)
}
}
// ShutdownAll shuts down all active sessions.
func (m *Manager) ShutdownAll(ctx context.Context) {
m.mu.Lock()
sessionIDs := make([]int64, 0, len(m.sessions))
for id := range m.sessions {
sessionIDs = append(sessionIDs, id)
}
m.mu.Unlock()
for _, id := range sessionIDs {
m.DeleteSession(id)
}
log.Info().
Int("count", len(sessionIDs)).
Msg("All sessions shut down")
}
// GetActiveSessionCount returns the number of active sessions.
func (m *Manager) GetActiveSessionCount() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.sessions)
}
// GetTotalQueueDepth returns the total queue depth across all sessions.
func (m *Manager) GetTotalQueueDepth() int {
m.mu.RLock()
defer m.mu.RUnlock()
total := 0
for _, session := range m.sessions {
session.messageMu.Lock()
total += len(session.pendingMessages)
session.messageMu.Unlock()
}
return total
}
// IsAnySessionProcessing returns true if any session is actively processing.
func (m *Manager) IsAnySessionProcessing() bool {
m.mu.RLock()
defer m.mu.RUnlock()
for _, session := range m.sessions {
// Check for pending messages
session.messageMu.Lock()
hasPending := len(session.pendingMessages) > 0
session.messageMu.Unlock()
if hasPending {
return true
}
// Check for active generator
if session.generatorActive.Load() {
return true
}
}
return false
}
// GetAllSessions returns a copy of all active sessions.
func (m *Manager) GetAllSessions() []*ActiveSession {
m.mu.RLock()
defer m.mu.RUnlock()
sessions := make([]*ActiveSession, 0, len(m.sessions))
for _, session := range m.sessions {
sessions = append(sessions, session)
}
return sessions
}
// DrainMessages drains and returns all pending messages for a session.
func (m *Manager) DrainMessages(sessionDBID int64) []PendingMessage {
m.mu.RLock()
session, ok := m.sessions[sessionDBID]
m.mu.RUnlock()
if !ok {
return nil
}
session.messageMu.Lock()
messages := make([]PendingMessage, len(session.pendingMessages))
copy(messages, session.pendingMessages)
session.pendingMessages = session.pendingMessages[:0]
session.messageMu.Unlock()
return messages
}
+141
View File
@@ -0,0 +1,141 @@
// Package sse provides Server-Sent Events broadcasting for claude-mnemonic.
package sse
import (
"encoding/json"
"fmt"
"net/http"
"sync"
"github.com/rs/zerolog/log"
)
// Client represents a connected SSE client.
type Client struct {
ID string
Writer http.ResponseWriter
Flusher http.Flusher
Done chan struct{}
}
// Broadcaster manages SSE client connections and message broadcasting.
type Broadcaster struct {
clients map[string]*Client
mu sync.RWMutex
nextID int
}
// NewBroadcaster creates a new SSE broadcaster.
func NewBroadcaster() *Broadcaster {
return &Broadcaster{
clients: make(map[string]*Client),
}
}
// AddClient adds a new SSE client connection.
func (b *Broadcaster) AddClient(w http.ResponseWriter) (*Client, error) {
flusher, ok := w.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
b.mu.Lock()
b.nextID++
id := fmt.Sprintf("client-%d", b.nextID)
client := &Client{
ID: id,
Writer: w,
Flusher: flusher,
Done: make(chan struct{}),
}
b.clients[id] = client
clientCount := len(b.clients)
b.mu.Unlock()
log.Debug().
Str("clientId", id).
Int("totalClients", clientCount).
Msg("SSE client connected")
return client, nil
}
// RemoveClient removes a client connection.
func (b *Broadcaster) RemoveClient(client *Client) {
b.mu.Lock()
delete(b.clients, client.ID)
clientCount := len(b.clients)
b.mu.Unlock()
close(client.Done)
log.Debug().
Str("clientId", client.ID).
Int("totalClients", clientCount).
Msg("SSE client disconnected")
}
// Broadcast sends a message to all connected clients.
func (b *Broadcaster) Broadcast(data interface{}) {
jsonData, err := json.Marshal(data)
if err != nil {
log.Error().Err(err).Msg("Failed to marshal SSE data")
return
}
message := fmt.Sprintf("data: %s\n\n", jsonData)
b.mu.RLock()
clients := make([]*Client, 0, len(b.clients))
for _, client := range b.clients {
clients = append(clients, client)
}
b.mu.RUnlock()
for _, client := range clients {
select {
case <-client.Done:
continue
default:
_, err := client.Writer.Write([]byte(message))
if err != nil {
log.Debug().
Str("clientId", client.ID).
Err(err).
Msg("Failed to write to SSE client")
continue
}
client.Flusher.Flush()
}
}
}
// ClientCount returns the number of connected clients.
func (b *Broadcaster) ClientCount() int {
b.mu.RLock()
defer b.mu.RUnlock()
return len(b.clients)
}
// HandleSSE handles an SSE connection request.
func (b *Broadcaster) HandleSSE(w http.ResponseWriter, r *http.Request) {
// Set SSE headers
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
client, err := b.AddClient(w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer b.RemoveClient(client)
// Send initial connection message
fmt.Fprintf(w, "data: {\"type\":\"connected\",\"clientId\":\"%s\"}\n\n", client.ID)
client.Flusher.Flush()
// Wait for client disconnect
<-r.Context().Done()
}
+62
View File
@@ -0,0 +1,62 @@
package worker
import (
"embed"
"io/fs"
"net/http"
"strings"
)
//go:embed static/*
var staticFS embed.FS
// staticSubFS is the static subdirectory filesystem
var staticSubFS fs.FS
func init() {
var err error
staticSubFS, err = fs.Sub(staticFS, "static")
if err != nil {
panic("failed to create sub filesystem: " + err.Error())
}
}
// serveIndex serves the index.html file for the root path
func serveIndex(w http.ResponseWriter, r *http.Request) {
content, err := fs.ReadFile(staticSubFS, "index.html")
if err != nil {
http.Error(w, "Dashboard not found", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
_, _ = w.Write(content)
}
// serveAssets serves static assets from the embedded filesystem
func serveAssets(w http.ResponseWriter, r *http.Request) {
// Strip the /assets/ prefix and serve the file
path := strings.TrimPrefix(r.URL.Path, "/")
content, err := fs.ReadFile(staticSubFS, path)
if err != nil {
http.Error(w, "Asset not found", http.StatusNotFound)
return
}
// Set content type based on extension
if strings.HasSuffix(path, ".js") {
w.Header().Set("Content-Type", "application/javascript")
} else if strings.HasSuffix(path, ".css") {
w.Header().Set("Content-Type", "text/css")
}
// No caching - always serve fresh content
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
_, _ = w.Write(content)
}
View File
+71
View File
@@ -0,0 +1,71 @@
package worker
import (
"database/sql"
"os"
"testing"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
_ "github.com/mattn/go-sqlite3"
)
// testStore creates a sqlite.Store with a temporary database for testing.
// Uses sqlite.NewStore which runs migrations (requires FTS5).
// Skips the test if FTS5 is not available.
func testStore(t *testing.T) (*sqlite.Store, func()) {
t.Helper()
// First check if FTS5 is available
if !hasFTS5ForTest(t) {
t.Skip("FTS5 not available in this SQLite build")
}
tmpDir, err := os.MkdirTemp("", "claude-mnemonic-test-*")
if err != nil {
t.Fatalf("create temp dir: %v", err)
}
dbPath := tmpDir + "/test.db"
store, err := sqlite.NewStore(sqlite.StoreConfig{
Path: dbPath,
MaxConns: 1,
WALMode: true,
})
if err != nil {
_ = os.RemoveAll(tmpDir)
t.Fatalf("create store: %v", err)
}
cleanup := func() {
_ = store.Close()
_ = os.RemoveAll(tmpDir)
}
return store, cleanup
}
// hasFTS5ForTest checks if FTS5 is available in the SQLite build.
func hasFTS5ForTest(t *testing.T) bool {
t.Helper()
tmpDir, err := os.MkdirTemp("", "fts5-check-*")
if err != nil {
return false
}
defer func() { _ = os.RemoveAll(tmpDir) }()
dbPath := tmpDir + "/check.db"
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return false
}
defer func() { _ = db.Close() }()
_, err = db.Exec("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(content)")
if err != nil {
return false
}
_, _ = db.Exec("DROP TABLE IF EXISTS fts5_test")
return true
}