diff --git a/cmd/mcp/main.go b/cmd/mcp/main.go index 02602d2..5d5cd7f 100644 --- a/cmd/mcp/main.go +++ b/cmd/mcp/main.go @@ -122,6 +122,7 @@ func main() { startWatchers(ctx, dbPath) // Create and run MCP server with all dependencies + // Note: maintenanceService is nil because it runs in the worker process server := mcp.NewServer( searchMgr, Version, @@ -132,6 +133,7 @@ func main() { vectorClient, scoreCalculator, recalculator, + nil, // maintenanceService - handled by worker ) log.Info().Str("project", *project).Str("version", Version).Msg("Starting MCP server") diff --git a/go.mod b/go.mod index 28de2f9..aced04e 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/stretchr/testify v1.11.1 github.com/sugarme/tokenizer v0.3.0 github.com/yalue/onnxruntime_go v1.25.0 + golang.org/x/sync v0.19.0 gorm.io/driver/sqlite v1.6.0 gorm.io/gorm v1.31.1 ) diff --git a/go.sum b/go.sum index 19ddd7a..6a6b9d8 100644 --- a/go.sum +++ b/go.sum @@ -52,6 +52,8 @@ github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89 github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw= github.com/yalue/onnxruntime_go v1.25.0 h1:nlhVau1BpLZ/BYr+WpPZCJRD/WES0qo6dK7aKyyAs3g= github.com/yalue/onnxruntime_go v1.25.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/config/config.go b/internal/config/config.go index 13f952e..d97456f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -71,6 +71,12 @@ type Config struct { ContextObsConcepts []string `json:"context_obs_concepts"` ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` // 0.0-1.0, minimum similarity for inclusion ContextMaxPromptResults int `json:"context_max_prompt_results"` // Max results per prompt (0 = threshold only) + + // Maintenance settings + MaintenanceEnabled bool `json:"maintenance_enabled"` // Enable scheduled maintenance + MaintenanceIntervalHours int `json:"maintenance_interval_hours"` // How often to run maintenance (default 6 hours) + ObservationRetentionDays int `json:"observation_retention_days"` // Delete observations older than N days (0 = no age-based deletion) + CleanupStaleObservations bool `json:"cleanup_stale_observations"` // Auto-cleanup stale observations during maintenance } var ( @@ -96,8 +102,9 @@ func SettingsPath() string { } // EnsureDataDir creates the data directory if it doesn't exist. +// Uses 0700 permissions (owner-only) for security. func EnsureDataDir() error { - return os.MkdirAll(DataDir(), 0750) + return os.MkdirAll(DataDir(), 0700) } // EnsureSettings creates a default settings file if it doesn't exist. @@ -157,8 +164,12 @@ func Default() *Config { ContextShowLastSummary: true, ContextObsTypes: DefaultObservationTypes, ContextObsConcepts: DefaultObservationConcepts, - ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include - ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only) + ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include + ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only) + MaintenanceEnabled: true, // Enable scheduled maintenance + MaintenanceIntervalHours: 6, // Run every 6 hours + ObservationRetentionDays: 0, // 0 = no age-based deletion (keep all) + CleanupStaleObservations: false, // Don't auto-cleanup stale observations } } diff --git a/internal/db/gorm/helpers.go b/internal/db/gorm/helpers.go index ea3e43a..5db64ae 100644 --- a/internal/db/gorm/helpers.go +++ b/internal/db/gorm/helpers.go @@ -9,27 +9,13 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" ) // EnsureSessionExists creates a session if it doesn't exist. +// Uses INSERT OR IGNORE pattern for atomic idempotent creation (single query instead of COUNT + INSERT). // This is shared between stores to avoid duplication. func EnsureSessionExists(ctx context.Context, db *gorm.DB, sdkSessionID, project string) error { - // Check if session exists - var count int64 - err := db.WithContext(ctx). - Model(&SDKSession{}). - Where("sdk_session_id = ?", sdkSessionID). - Count(&count).Error - - if err != nil { - return err - } - - if count > 0 { - return nil // Session exists - } - - // Auto-create session now := time.Now() session := &SDKSession{ ClaudeSessionID: sdkSessionID, @@ -41,7 +27,14 @@ func EnsureSessionExists(ctx context.Context, db *gorm.DB, sdkSessionID, project PromptCounter: 0, } - return db.WithContext(ctx).Create(session).Error + // Use INSERT OR IGNORE - single query, atomic operation + // If session already exists (conflict on sdk_session_id), do nothing + return db.WithContext(ctx). + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "sdk_session_id"}}, + DoNothing: true, + }). + Create(session).Error } // sqlNullString creates a sql.NullString from a string. @@ -52,8 +45,13 @@ func sqlNullString(s string) sql.NullString { return sql.NullString{String: s, Valid: true} } +// MaxPaginationLimit is the maximum allowed limit for pagination queries. +// This protects against resource exhaustion from excessively large requests. +const MaxPaginationLimit = 1000 + // ParseLimitParam parses the "limit" query parameter from an HTTP request. // Returns defaultLimit if the parameter is missing or invalid. +// Note: This does NOT enforce a maximum limit. Use ParseLimitParamWithMax for that. func ParseLimitParam(r *http.Request, defaultLimit int) int { if l := r.URL.Query().Get("limit"); l != "" { if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 { @@ -62,3 +60,42 @@ func ParseLimitParam(r *http.Request, defaultLimit int) int { } return defaultLimit } + +// ParseLimitParamWithMax parses the "limit" query parameter with a maximum cap. +// Returns min(parsed, maxLimit) or defaultLimit if missing/invalid. +// If maxLimit is 0, uses MaxPaginationLimit (1000). +func ParseLimitParamWithMax(r *http.Request, defaultLimit, maxLimit int) int { + if maxLimit <= 0 { + maxLimit = MaxPaginationLimit + } + limit := ParseLimitParam(r, defaultLimit) + if limit > maxLimit { + return maxLimit + } + return limit +} + +// ParseOffsetParam parses the "offset" query parameter from an HTTP request. +// Returns 0 if the parameter is missing or invalid. +func ParseOffsetParam(r *http.Request) int { + if o := r.URL.Query().Get("offset"); o != "" { + if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 { + return parsed + } + } + return 0 +} + +// PaginationParams holds pagination parameters. +type PaginationParams struct { + Limit int + Offset int +} + +// ParsePaginationParams parses both limit and offset from an HTTP request. +func ParsePaginationParams(r *http.Request, defaultLimit int) PaginationParams { + return PaginationParams{ + Limit: ParseLimitParam(r, defaultLimit), + Offset: ParseOffsetParam(r), + } +} diff --git a/internal/db/gorm/migrations.go b/internal/db/gorm/migrations.go index aeb4f92..7a18f4e 100644 --- a/internal/db/gorm/migrations.go +++ b/internal/db/gorm/migrations.go @@ -322,6 +322,244 @@ func runMigrations(db *gorm.DB, sqlDB *sql.DB) error { return tx.Migrator().DropTable("observation_relations") }, }, + + // Migration 012: Query optimization indexes + // Adds covering and composite indexes for common query patterns + { + ID: "012_query_optimization_indexes", + Migrate: func(tx *gorm.DB) error { + sqls := []string{ + // Composite index for observation queries by project + scope + importance + // Covers the common pattern: WHERE project = ? OR scope = 'global' ORDER BY importance_score DESC + `CREATE INDEX IF NOT EXISTS idx_observations_project_scope_importance + ON observations(project, scope, importance_score DESC, created_at_epoch DESC)`, + + // Covering index for observation retrieval (includes most used columns) + // Allows index-only scans for listing queries + `CREATE INDEX IF NOT EXISTS idx_observations_project_covering + ON observations(project, scope, is_superseded, importance_score DESC) + WHERE is_superseded = 0 OR is_superseded IS NULL`, + + // Index for session summary lookups + `CREATE INDEX IF NOT EXISTS idx_summaries_project_importance + ON session_summaries(project, importance_score DESC, created_at_epoch DESC)`, + + // Index for prompt retrieval by session + `CREATE INDEX IF NOT EXISTS idx_prompts_session_number + ON user_prompts(claude_session_id, prompt_number)`, + + // Index for pattern queries by frequency + `CREATE INDEX IF NOT EXISTS idx_patterns_frequency + ON patterns(frequency DESC, last_seen_at_epoch DESC) + WHERE is_deprecated = 0`, + } + for _, s := range sqls { + if err := tx.Exec(s).Error; err != nil { + // Non-fatal: index may already exist or fail for benign reasons + continue + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + sqls := []string{ + "DROP INDEX IF EXISTS idx_observations_project_scope_importance", + "DROP INDEX IF EXISTS idx_observations_project_covering", + "DROP INDEX IF EXISTS idx_summaries_project_importance", + "DROP INDEX IF EXISTS idx_prompts_session_number", + "DROP INDEX IF EXISTS idx_patterns_frequency", + } + for _, s := range sqls { + if err := tx.Exec(s).Error; err != nil { + continue + } + } + return nil + }, + }, + + // Migration 013: Add archival columns to observations + { + ID: "013_observation_archival", + Migrate: func(tx *gorm.DB) error { + sqls := []string{ + // Add archival columns + `ALTER TABLE observations ADD COLUMN is_archived INTEGER DEFAULT 0`, + `ALTER TABLE observations ADD COLUMN archived_at_epoch INTEGER`, + `ALTER TABLE observations ADD COLUMN archived_reason TEXT`, + // Index for archived observations + `CREATE INDEX IF NOT EXISTS idx_observations_archived ON observations(is_archived)`, + // Composite index for filtering active (non-archived) observations + `CREATE INDEX IF NOT EXISTS idx_observations_active + ON observations(project, is_archived, is_superseded, importance_score DESC) + WHERE (is_archived = 0 OR is_archived IS NULL)`, + } + for _, s := range sqls { + if err := tx.Exec(s).Error; err != nil { + // Non-fatal: column may already exist + continue + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + // SQLite doesn't support DROP COLUMN in older versions + // but for newer versions we can try + sqls := []string{ + "DROP INDEX IF EXISTS idx_observations_active", + "DROP INDEX IF EXISTS idx_observations_archived", + } + for _, s := range sqls { + _ = tx.Exec(s).Error + } + return nil + }, + }, + + // Migration 014: Add performance-critical indexes for common query patterns + { + ID: "014_performance_indexes", + Migrate: func(tx *gorm.DB) error { + sqls := []string{ + // Index for batch ID lookups (IN queries) + `CREATE INDEX IF NOT EXISTS idx_observations_id_covering + ON observations(id, project, scope, importance_score)`, + + // Index for vector search result fetching + `CREATE INDEX IF NOT EXISTS idx_vectors_doc_type_project + ON vectors(doc_type, project, scope)`, + + // Index for session summaries by project + `CREATE INDEX IF NOT EXISTS idx_summaries_project_created + ON session_summaries(project, created_at_epoch DESC)`, + + // Index for user prompts by session + `CREATE INDEX IF NOT EXISTS idx_prompts_session_created + ON user_prompts(claude_session_id, created_at_epoch DESC)`, + + // Index for patterns by type and project + `CREATE INDEX IF NOT EXISTS idx_patterns_type_project + ON patterns(type, project, frequency DESC) + WHERE is_deprecated = 0`, + + // Index for observation relations + `CREATE INDEX IF NOT EXISTS idx_relations_source_type + ON observation_relations(source_observation_id, relation_type)`, + `CREATE INDEX IF NOT EXISTS idx_relations_target_type + ON observation_relations(target_observation_id, relation_type)`, + } + for _, s := range sqls { + if err := tx.Exec(s).Error; err != nil { + // Non-fatal: index may already exist + continue + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + sqls := []string{ + "DROP INDEX IF EXISTS idx_observations_id_covering", + "DROP INDEX IF EXISTS idx_vectors_doc_type_project", + "DROP INDEX IF EXISTS idx_summaries_project_created", + "DROP INDEX IF EXISTS idx_prompts_session_created", + "DROP INDEX IF EXISTS idx_patterns_type_project", + "DROP INDEX IF EXISTS idx_relations_source_type", + "DROP INDEX IF EXISTS idx_relations_target_type", + } + for _, s := range sqls { + _ = tx.Exec(s).Error + } + return nil + }, + }, + + // Migration 015: Add optimized composite indexes for common query patterns + { + ID: "015_optimized_composite_indexes", + Migrate: func(tx *gorm.DB) error { + sqls := []string{ + // Composite index for GetRecentObservations with project+scope filtering + // Covers: WHERE (project = ? OR scope = 'global') ORDER BY importance_score DESC, created_at_epoch DESC + `CREATE INDEX IF NOT EXISTS idx_observations_project_scope_created + ON observations(project, scope, created_at_epoch DESC, importance_score DESC)`, + + // Index for scope='global' queries (common pattern in search) + `CREATE INDEX IF NOT EXISTS idx_observations_global_scope + ON observations(scope, importance_score DESC, created_at_epoch DESC) + WHERE scope = 'global'`, + + // Index for vector search result deduplication by observation + `CREATE INDEX IF NOT EXISTS idx_vectors_observation_lookup + ON vectors(doc_type, sqlite_id, project) + WHERE doc_type = 'observation'`, + + // Index for FTS search result ordering + `CREATE INDEX IF NOT EXISTS idx_observations_fts_ordering + ON observations(project, importance_score DESC) + WHERE (is_archived = 0 OR is_archived IS NULL) AND (is_superseded = 0 OR is_superseded IS NULL)`, + } + for _, s := range sqls { + if err := tx.Exec(s).Error; err != nil { + // Non-fatal: index may already exist + continue + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + sqls := []string{ + "DROP INDEX IF EXISTS idx_observations_project_scope_created", + "DROP INDEX IF EXISTS idx_observations_global_scope", + "DROP INDEX IF EXISTS idx_vectors_observation_lookup", + "DROP INDEX IF EXISTS idx_observations_fts_ordering", + } + for _, s := range sqls { + _ = tx.Exec(s).Error + } + return nil + }, + }, + + // Migration 016: Add covering indexes for relation joins and active observations + { + ID: "016_relation_and_active_indexes", + Migrate: func(tx *gorm.DB) error { + sqls := []string{ + // Covering index for observation relation joins (common JOIN patterns) + // Speeds up queries like: JOIN observation_relations ON source_id = obs.id WHERE relation_type = ? + `CREATE INDEX IF NOT EXISTS idx_relations_source_type_target + ON observation_relations(source_observation_id, relation_type, target_observation_id)`, + + // Covering index for reverse relation lookups + `CREATE INDEX IF NOT EXISTS idx_relations_target_type_source + ON observation_relations(target_observation_id, relation_type, source_observation_id)`, + + // Partial index for active (non-archived, non-superseded) observations + // Optimizes activeObservationFilter queries + `CREATE INDEX IF NOT EXISTS idx_observations_active + ON observations(project, importance_score DESC, created_at_epoch DESC) + WHERE (is_archived = 0 OR is_archived IS NULL) AND (is_superseded = 0 OR is_superseded IS NULL)`, + } + for _, s := range sqls { + if err := tx.Exec(s).Error; err != nil { + // Non-fatal: index may already exist + continue + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + sqls := []string{ + "DROP INDEX IF EXISTS idx_relations_source_type_target", + "DROP INDEX IF EXISTS idx_relations_target_type_source", + "DROP INDEX IF EXISTS idx_observations_active", + } + for _, s := range sqls { + _ = tx.Exec(s).Error + } + return nil + }, + }, }) if err := m.Migrate(); err != nil { diff --git a/internal/db/gorm/models.go b/internal/db/gorm/models.go index e183a7a..c7e843b 100644 --- a/internal/db/gorm/models.go +++ b/internal/db/gorm/models.go @@ -48,7 +48,7 @@ func (s *SDKSession) BeforeCreate(tx *gorm.DB) error { type Observation struct { ID int64 `gorm:"primaryKey;autoIncrement"` SDKSessionID string `gorm:"index;not null"` - Project string `gorm:"index;not null"` + Project string `gorm:"index:idx_observations_project;index:idx_observations_project_created,priority:1;not null"` Scope models.ObservationScope `gorm:"type:text;default:'project';check:scope IN ('project', 'global');index:idx_observations_scope;index:idx_observations_project_scope,priority:2"` Type models.ObservationType `gorm:"type:text;check:type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change');index;not null"` @@ -66,7 +66,7 @@ type Observation struct { PromptNumber sql.NullInt64 DiscoveryTokens int64 `gorm:"default:0"` CreatedAt string `gorm:"not null"` - CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;not null"` + CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;index:idx_observations_project_created,priority:2,sort:desc;not null"` // Importance scoring fields ImportanceScore float64 `gorm:"type:real;default:1.0;index:idx_observations_importance,priority:1,sort:desc"` @@ -74,7 +74,12 @@ type Observation struct { RetrievalCount int `gorm:"default:0"` LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"` ScoreUpdatedAt sql.NullInt64 `gorm:"column:score_updated_at_epoch;index:idx_observations_score_updated"` - IsSuperseded int `gorm:"default:0;index:idx_observations_superseded,priority:1"` + IsSuperseded int `gorm:"default:0;index:idx_observations_superseded;index:idx_observations_active,priority:2"` + + // Archival fields + IsArchived int `gorm:"default:0;index:idx_observations_archived;index:idx_observations_active,priority:1"` + ArchivedAt sql.NullInt64 `gorm:"column:archived_at_epoch"` + ArchivedReason sql.NullString } func (Observation) TableName() string { return "observations" } diff --git a/internal/db/gorm/observation_store.go b/internal/db/gorm/observation_store.go index 3d5fb1d..9e8bbde 100644 --- a/internal/db/gorm/observation_store.go +++ b/internal/db/gorm/observation_store.go @@ -5,17 +5,36 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "strings" + "sync" + "sync/atomic" "time" "gorm.io/gorm" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" ) // MaxObservationsPerProject is the maximum number of observations to keep per project. const MaxObservationsPerProject = 100 +// cleanupQueueSize is the buffer size for the cleanup queue. +const cleanupQueueSize = 100 + +// commonWords is a package-level set for O(1) lookup of stop words. +// Created once at init time to avoid repeated map allocations. +var commonWords = map[string]struct{}{ + "the": {}, "and": {}, "or": {}, "but": {}, "in": {}, + "on": {}, "at": {}, "to": {}, "for": {}, "of": {}, + "with": {}, "by": {}, "from": {}, "as": {}, "is": {}, + "was": {}, "are": {}, "were": {}, "be": {}, "been": {}, + "being": {}, "have": {}, "has": {}, "had": {}, "do": {}, + "does": {}, "did": {}, "will": {}, "would": {}, "should": {}, + "could": {}, "may": {}, "might": {}, "must": {}, "can": {}, +} + // CleanupFunc is a callback for when observations are cleaned up. // Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB). type CleanupFunc func(ctx context.Context, deletedIDs []int64) @@ -25,19 +44,89 @@ type ObservationStore struct { db *gorm.DB rawDB *sql.DB cleanupFunc CleanupFunc - conflictStore interface{} // Placeholder for ConflictStore (Phase 4) - relationStore interface{} // Placeholder for RelationStore (Phase 4) + conflictStore any // Placeholder for ConflictStore (Phase 4) + relationStore any // Placeholder for RelationStore (Phase 4) + + // Cleanup queue for async observation cleanup with proper error handling + cleanupQueue chan string + cleanupWg sync.WaitGroup + cleanupOnce sync.Once + cleanupStarted atomic.Bool // tracks if cleanup worker was started + stopCleanup chan struct{} } // NewObservationStore creates a new observation store. // The conflictStore and relationStore parameters are optional (can be nil) and will be used in Phase 4. -func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore interface{}) *ObservationStore { - return &ObservationStore{ +func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore any) *ObservationStore { + s := &ObservationStore{ db: store.DB, rawDB: store.GetRawDB(), cleanupFunc: cleanupFunc, conflictStore: conflictStore, relationStore: relationStore, + cleanupQueue: make(chan string, cleanupQueueSize), + stopCleanup: make(chan struct{}), + } + // Start the cleanup worker + s.startCleanupWorker() + return s +} + +// startCleanupWorker starts the background cleanup worker. +func (s *ObservationStore) startCleanupWorker() { + s.cleanupOnce.Do(func() { + s.cleanupStarted.Store(true) + s.cleanupWg.Add(1) + go s.cleanupWorker() + }) +} + +// cleanupWorker processes cleanup requests from the queue. +func (s *ObservationStore) cleanupWorker() { + defer s.cleanupWg.Done() + + for { + select { + case <-s.stopCleanup: + // Drain remaining items in queue before exiting + for { + select { + case project := <-s.cleanupQueue: + s.processCleanup(project) + default: + return + } + } + case project := <-s.cleanupQueue: + s.processCleanup(project) + } + } +} + +// processCleanup performs the actual cleanup for a project. +func (s *ObservationStore) processCleanup(project string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + deletedIDs, err := s.CleanupOldObservations(ctx, project) + if err != nil { + log.Warn().Err(err).Str("project", project).Msg("Failed to cleanup old observations") + return + } + + if len(deletedIDs) > 0 && s.cleanupFunc != nil { + s.cleanupFunc(ctx, deletedIDs) + log.Debug().Str("project", project).Int("count", len(deletedIDs)).Msg("Cleaned up old observations") + } +} + +// Close stops the cleanup worker and waits for it to finish. +// Safe to call even if the worker was never started. +func (s *ObservationStore) Close() { + // Only stop if worker was actually started to avoid deadlock + if s.cleanupStarted.Load() { + close(s.stopCleanup) + s.cleanupWg.Wait() } } @@ -86,16 +175,15 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p return 0, 0, err } - // Cleanup old observations beyond the limit for this project (async to not block handler) + // Queue cleanup of old observations beyond the limit for this project (async to not block handler) if project != "" { - go func(proj string) { - cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - deletedIDs, _ := s.CleanupOldObservations(cleanupCtx, proj) - if len(deletedIDs) > 0 && s.cleanupFunc != nil { - s.cleanupFunc(cleanupCtx, deletedIDs) - } - }(project) + select { + case s.cleanupQueue <- project: + // Successfully queued for cleanup + default: + // Queue is full, log a warning instead of silently dropping + log.Warn().Str("project", project).Msg("Cleanup queue full, skipping cleanup for this observation") + } } // Note: Conflict and relation detection intentionally omitted for now @@ -104,6 +192,89 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p return dbObs.ID, nowEpoch, nil } +// ObservationUpdate contains fields that can be updated on an observation. +// Only non-nil fields will be updated. +type ObservationUpdate struct { + Title *string // New title + Subtitle *string // New subtitle + Narrative *string // New narrative + Facts *[]string // New facts (replaces existing) + Concepts *[]string // New concepts (replaces existing) + FilesRead *[]string // New files read (replaces existing) + FilesModified *[]string // New files modified (replaces existing) + Scope *string // New scope (project or global) +} + +// UpdateObservation updates an existing observation with the provided fields. +// Only non-nil fields in the update struct will be modified. +// Returns the updated observation or an error. +func (s *ObservationStore) UpdateObservation(ctx context.Context, id int64, update *ObservationUpdate) (*models.Observation, error) { + if update == nil { + return nil, fmt.Errorf("update cannot be nil") + } + + // First, verify the observation exists + var dbObs Observation + if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return nil, fmt.Errorf("observation not found: %d", id) + } + return nil, err + } + + // Build update map with only provided fields + updates := make(map[string]any) + + if update.Title != nil { + updates["title"] = sql.NullString{String: *update.Title, Valid: true} + } + if update.Subtitle != nil { + updates["subtitle"] = sql.NullString{String: *update.Subtitle, Valid: true} + } + if update.Narrative != nil { + updates["narrative"] = sql.NullString{String: *update.Narrative, Valid: true} + } + if update.Facts != nil { + factsJSON, _ := json.Marshal(*update.Facts) + updates["facts"] = string(factsJSON) + } + if update.Concepts != nil { + conceptsJSON, _ := json.Marshal(*update.Concepts) + updates["concepts"] = string(conceptsJSON) + } + if update.FilesRead != nil { + filesReadJSON, _ := json.Marshal(*update.FilesRead) + updates["files_read"] = string(filesReadJSON) + } + if update.FilesModified != nil { + filesModifiedJSON, _ := json.Marshal(*update.FilesModified) + updates["files_modified"] = string(filesModifiedJSON) + } + if update.Scope != nil { + updates["scope"] = sql.NullString{String: *update.Scope, Valid: true} + } + + // Add updated_at timestamp + updates["updated_at_epoch"] = sql.NullInt64{Int64: time.Now().Unix(), Valid: true} + + if len(updates) == 0 { + // Nothing to update, just return existing observation + return toModelObservation(&dbObs), nil + } + + // Perform the update + if err := s.db.WithContext(ctx).Model(&Observation{}).Where("id = ?", id).Updates(updates).Error; err != nil { + return nil, fmt.Errorf("failed to update observation: %w", err) + } + + // Fetch the updated observation + if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil { + return nil, err + } + + return toModelObservation(&dbObs), nil +} + // GetObservationByID retrieves an observation by its ID. func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) { var dbObs Observation @@ -134,6 +305,8 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64 query = query.Order("created_at_epoch DESC") case "importance": query = query.Order("importance_score DESC, created_at_epoch DESC") + case "score_desc": + query = query.Order("importance_score DESC, created_at_epoch DESC") default: // Default: importance first, then recency query = query.Order("COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC") @@ -152,6 +325,60 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64 return toModelObservations(dbObservations), nil } +// GetObservationsByIDsPreserveOrder retrieves observations by IDs, preserving the input order. +// This is useful when the caller has already sorted/ranked the IDs (e.g., by vector similarity). +func (s *ObservationStore) GetObservationsByIDsPreserveOrder(ctx context.Context, ids []int64) ([]*models.Observation, error) { + if len(ids) == 0 { + return nil, nil + } + + // Fetch all observations in a single query + var dbObservations []Observation + err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&dbObservations).Error + if err != nil { + return nil, err + } + + // Build ID -> observation map for O(1) lookups + obsMap := make(map[int64]*Observation, len(dbObservations)) + for i := range dbObservations { + obsMap[int64(dbObservations[i].ID)] = &dbObservations[i] + } + + // Reconstruct in original order + result := make([]*models.Observation, 0, len(ids)) + for _, id := range ids { + if obs, ok := obsMap[id]; ok { + result = append(result, toModelObservation(obs)) + } + } + + return result, nil +} + +// BatchGetObservationsWithScores retrieves observations with associated scores. +// Returns a map of ID -> observation for efficient lookup. +func (s *ObservationStore) BatchGetObservationsWithScores(ctx context.Context, ids []int64) (map[int64]*models.Observation, error) { + if len(ids) == 0 { + return make(map[int64]*models.Observation), nil + } + + // Fetch all observations in a single query + var dbObservations []Observation + err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&dbObservations).Error + if err != nil { + return nil, err + } + + // Build result map + result := make(map[int64]*models.Observation, len(dbObservations)) + for i := range dbObservations { + result[int64(dbObservations[i].ID)] = toModelObservation(&dbObservations[i]) + } + + return result, nil +} + // GetRecentObservations retrieves recent observations for a project. // This includes project-scoped observations for the specified project AND global observations. // Results are ordered by importance_score DESC, then created_at_epoch DESC. @@ -169,13 +396,13 @@ func (s *ObservationStore) GetRecentObservations(ctx context.Context, project st return toModelObservations(dbObservations), nil } -// GetActiveObservations retrieves recent non-superseded observations for a project. -// This excludes observations that have been marked as superseded by newer ones. +// GetActiveObservations retrieves recent non-superseded, non-archived observations for a project. +// This excludes observations that have been marked as superseded or archived. // Results are ordered by importance_score DESC, then created_at_epoch DESC. func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) { var dbObservations []Observation err := s.db.WithContext(ctx). - Scopes(projectScopeFilter(project), notSupersededFilter(), importanceOrdering()). + Scopes(projectScopeFilter(project), activeObservationFilter(), importanceOrdering()). Limit(limit). Find(&dbObservations).Error @@ -245,7 +472,57 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i return toModelObservations(dbObservations), nil } +// GetAllRecentObservationsPaginated retrieves recent observations with pagination. +func (s *ObservationStore) GetAllRecentObservationsPaginated(ctx context.Context, limit, offset int) ([]*models.Observation, int64, error) { + var dbObservations []Observation + var total int64 + + // Get total count + if err := s.db.WithContext(ctx).Model(&Observation{}).Count(&total).Error; err != nil { + return nil, 0, err + } + + // Get paginated results + err := s.db.WithContext(ctx). + Scopes(importanceOrdering()). + Limit(limit). + Offset(offset). + Find(&dbObservations).Error + + if err != nil { + return nil, 0, err + } + + return toModelObservations(dbObservations), total, nil +} + +// GetObservationsByProjectStrictPaginated retrieves observations strictly from a project with pagination. +func (s *ObservationStore) GetObservationsByProjectStrictPaginated(ctx context.Context, project string, limit, offset int) ([]*models.Observation, int64, error) { + var dbObservations []Observation + var total int64 + + // Get total count for project + if err := s.db.WithContext(ctx).Model(&Observation{}).Where("project = ?", project).Count(&total).Error; err != nil { + return nil, 0, err + } + + // Get paginated results + err := s.db.WithContext(ctx). + Where("project = ?", project). + Scopes(importanceOrdering()). + Limit(limit). + Offset(offset). + Find(&dbObservations).Error + + if err != nil { + return nil, 0, err + } + + return toModelObservations(dbObservations), total, nil +} + // GetAllObservations retrieves all observations (for vector rebuild). +// Note: For large datasets, prefer GetAllObservationsIterator to avoid memory issues. func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) { var dbObservations []Observation err := s.db.WithContext(ctx). @@ -259,6 +536,51 @@ func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Ob return toModelObservations(dbObservations), nil } +// GetAllObservationsIterator returns observations in batches to avoid loading all into memory. +// The callback is called for each batch. Return false from callback to stop iteration. +// batchSize controls how many observations are loaded at once (default 500 if <= 0). +func (s *ObservationStore) GetAllObservationsIterator(ctx context.Context, batchSize int, callback func([]*models.Observation) bool) error { + if batchSize <= 0 { + batchSize = 500 + } + + var lastID int64 = 0 + for { + // Check context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + var dbObservations []Observation + err := s.db.WithContext(ctx). + Where("id > ?", lastID). + Order("id ASC"). + Limit(batchSize). + Find(&dbObservations).Error + + if err != nil { + return err + } + + if len(dbObservations) == 0 { + break // No more observations + } + + // Update cursor for next batch + lastID = dbObservations[len(dbObservations)-1].ID + + // Convert and call callback + observations := toModelObservations(dbObservations) + if !callback(observations) { + break // Callback requested stop + } + } + + return nil +} + // SearchObservationsFTS performs full-text search on observations using FTS5. // Falls back to LIKE search if FTS5 fails. func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) { @@ -314,14 +636,26 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro } // searchObservationsLike performs fallback LIKE search on observations using GORM. +// Limits to 2 keywords to prevent expensive OR queries that SQLite optimizes poorly. +// This is a fallback path when FTS returns no results, so we prioritize performance. func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) { if len(keywords) == 0 { return nil, nil } + // Limit keywords to prevent excessive OR conditions that hurt query planning. + // SQLite performs significantly better with fewer LIKE conditions. + // Using 2 instead of 5 reduces query complexity from O(15) to O(6) conditions + // (each keyword creates 3 LIKE conditions for title, subtitle, narrative). + const maxKeywords = 2 + if len(keywords) > maxKeywords { + keywords = keywords[:maxKeywords] + } + // Build LIKE conditions for each keyword - var conditions []string - var args []interface{} + // Pre-allocate for efficiency: maxKeywords conditions × 3 args each + 1 project arg + conditions := make([]string, 0, len(keywords)) + args := make([]any, 0, len(keywords)*3+1) for _, kw := range keywords { pattern := "%" + kw + "%" @@ -358,6 +692,217 @@ func (s *ObservationStore) DeleteObservations(ctx context.Context, ids []int64) return result.RowsAffected, result.Error } +// DeleteObservation deletes a single observation by ID. +func (s *ObservationStore) DeleteObservation(ctx context.Context, id int64) error { + result := s.db.WithContext(ctx).Delete(&Observation{}, id) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("observation %d not found", id) + } + return nil +} + +// MarkAsSuperseded marks an observation as superseded (stale). +func (s *ObservationStore) MarkAsSuperseded(ctx context.Context, id int64) error { + result := s.db.WithContext(ctx). + Model(&Observation{}). + Where("id = ?", id). + Update("is_superseded", 1) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("observation %d not found", id) + } + return nil +} + +// MarkAsSupersededBatch marks multiple observations as superseded in a single query. +// Returns the number of observations updated and any error. +func (s *ObservationStore) MarkAsSupersededBatch(ctx context.Context, ids []int64) (int64, error) { + if len(ids) == 0 { + return 0, nil + } + + result := s.db.WithContext(ctx). + Model(&Observation{}). + Where("id IN ?", ids). + Update("is_superseded", 1) + + return result.RowsAffected, result.Error +} + +// ArchiveObservation archives a single observation with an optional reason. +func (s *ObservationStore) ArchiveObservation(ctx context.Context, id int64, reason string) error { + updates := map[string]any{ + "is_archived": 1, + "archived_at_epoch": time.Now().UnixMilli(), + } + if reason != "" { + updates["archived_reason"] = reason + } + + result := s.db.WithContext(ctx). + Model(&Observation{}). + Where("id = ?", id). + Updates(updates) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("observation %d not found", id) + } + return nil +} + +// UnarchiveObservation restores an archived observation. +func (s *ObservationStore) UnarchiveObservation(ctx context.Context, id int64) error { + result := s.db.WithContext(ctx). + Model(&Observation{}). + Where("id = ?", id). + Updates(map[string]any{ + "is_archived": 0, + "archived_at_epoch": nil, + "archived_reason": nil, + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return fmt.Errorf("observation %d not found", id) + } + return nil +} + +// ArchiveOldObservations archives observations older than the specified age. +// Returns the count of archived observations and their IDs. +func (s *ObservationStore) ArchiveOldObservations(ctx context.Context, project string, maxAgeDays int, reason string) ([]int64, error) { + if maxAgeDays <= 0 { + maxAgeDays = 90 // Default: archive observations older than 90 days + } + + cutoffEpoch := time.Now().AddDate(0, 0, -maxAgeDays).UnixMilli() + if reason == "" { + reason = fmt.Sprintf("auto-archived: older than %d days", maxAgeDays) + } + + var idsToArchive []int64 + + err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // Find observations to archive (not already archived, older than cutoff) + query := tx.Model(&Observation{}). + Where("created_at_epoch < ?", cutoffEpoch). + Where("COALESCE(is_archived, 0) = 0") + + if project != "" { + query = query.Where("project = ?", project) + } + + if err := query.Pluck("id", &idsToArchive).Error; err != nil { + return err + } + + if len(idsToArchive) == 0 { + return nil + } + + // Archive the observations + now := time.Now().UnixMilli() + return tx.Model(&Observation{}). + Where("id IN ?", idsToArchive). + Updates(map[string]any{ + "is_archived": 1, + "archived_at_epoch": now, + "archived_reason": reason, + }).Error + }) + + if err != nil { + return nil, err + } + + return idsToArchive, nil +} + +// GetArchivedObservations retrieves archived observations for a project. +func (s *ObservationStore) GetArchivedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) { + var dbObservations []Observation + query := s.db.WithContext(ctx). + Where("COALESCE(is_archived, 0) = 1") + + if project != "" { + query = query.Where("project = ?", project) + } + + err := query. + Order("archived_at_epoch DESC"). + Limit(limit). + Find(&dbObservations).Error + + if err != nil { + return nil, err + } + + return toModelObservations(dbObservations), nil +} + +// GetArchivalStats returns statistics about archived observations. +// Optimized to use a single query instead of 4 separate queries. +func (s *ObservationStore) GetArchivalStats(ctx context.Context, project string) (*ArchivalStats, error) { + // Use a single query with conditional aggregation to get all stats at once. + // This is much faster than 4 separate queries (saves 3 round trips). + type statsResult struct { + TotalCount int64 + ArchivedCount int64 + OldestEpoch *int64 // Pointer to handle NULL + NewestEpoch *int64 + } + + var result statsResult + + query := s.db.WithContext(ctx).Model(&Observation{}). + Select(` + COUNT(*) as total_count, + SUM(CASE WHEN COALESCE(is_archived, 0) = 1 THEN 1 ELSE 0 END) as archived_count, + MIN(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as oldest_epoch, + MAX(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as newest_epoch + `) + + if project != "" { + query = query.Where("project = ?", project) + } + + if err := query.Scan(&result).Error; err != nil { + return nil, err + } + + stats := &ArchivalStats{ + TotalCount: result.TotalCount, + ArchivedCount: result.ArchivedCount, + ActiveCount: result.TotalCount - result.ArchivedCount, + } + + if result.OldestEpoch != nil { + stats.OldestArchivedEpoch = *result.OldestEpoch + } + if result.NewestEpoch != nil { + stats.NewestArchivedEpoch = *result.NewestEpoch + } + + return stats, nil +} + +// ArchivalStats contains statistics about archived observations. +type ArchivalStats struct { + TotalCount int64 `json:"total_count"` + ActiveCount int64 `json:"active_count"` + ArchivedCount int64 `json:"archived_count"` + OldestArchivedEpoch int64 `json:"oldest_archived_epoch,omitempty"` + NewestArchivedEpoch int64 `json:"newest_archived_epoch,omitempty"` +} + // CleanupOldObservations removes observations beyond the limit for a project. // Returns the IDs of deleted observations. func (s *ObservationStore) CleanupOldObservations(ctx context.Context, project string) ([]int64, error) { @@ -418,10 +963,12 @@ func projectScopeFilter(project string) func(*gorm.DB) *gorm.DB { } } -// notSupersededFilter filters out superseded observations. -func notSupersededFilter() func(*gorm.DB) *gorm.DB { +// activeObservationFilter filters for active (non-archived, non-superseded) observations. +// This is more efficient than chaining notSupersededFilter + notArchivedFilter +// as it produces a single WHERE clause for the query optimizer. +func activeObservationFilter() func(*gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { - return db.Where("COALESCE(is_superseded, 0) = 0") + return db.Where("COALESCE(is_archived, 0) = 0 AND COALESCE(is_superseded, 0) = 0") } } @@ -437,23 +984,17 @@ func importanceOrdering() func(*gorm.DB) *gorm.DB { // ==================== // extractKeywords extracts keywords from a search query. +// Uses package-level commonWords map for O(1) stop word filtering. func extractKeywords(query string) []string { words := strings.Fields(strings.ToLower(query)) - var keywords []string - - commonWords := map[string]bool{ - "the": true, "and": true, "or": true, "but": true, "in": true, - "on": true, "at": true, "to": true, "for": true, "of": true, - "with": true, "by": true, "from": true, "as": true, "is": true, - "was": true, "are": true, "were": true, "be": true, "been": true, - "being": true, "have": true, "has": true, "had": true, "do": true, - "does": true, "did": true, "will": true, "would": true, "should": true, - "could": true, "may": true, "might": true, "must": true, "can": true, - } + keywords := make([]string, 0, len(words)) // Pre-allocate for typical case for _, word := range words { - // Skip short words and common words - if len(word) <= 3 || commonWords[word] { + // Skip short words and common stop words + if len(word) <= 3 { + continue + } + if _, isCommon := commonWords[word]; isCommon { continue } keywords = append(keywords, word) @@ -464,7 +1005,8 @@ func extractKeywords(query string) []string { // scanObservationRows scans multiple observations from raw SQL rows. func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) { - var observations []*models.Observation + // Pre-allocate with reasonable initial capacity to avoid repeated slice growth + observations := make([]*models.Observation, 0, 64) for rows.Next() { obs, err := scanObservation(rows) if err != nil { @@ -476,7 +1018,7 @@ func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) { } // scanObservation scans a single observation from a row scanner. -func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.Observation, error) { +func scanObservation(scanner interface{ Scan(...any) error }) (*models.Observation, error) { var obs models.Observation var factsJSON, conceptsJSON, filesReadJSON, filesModifiedJSON, fileMtimesJSON []byte var isSuperseded int diff --git a/internal/db/gorm/scoring_store.go b/internal/db/gorm/scoring_store.go index c21e309..197b22f 100644 --- a/internal/db/gorm/scoring_store.go +++ b/internal/db/gorm/scoring_store.go @@ -3,6 +3,8 @@ package gorm import ( "context" + "fmt" + "strings" "time" "gorm.io/gorm" @@ -59,12 +61,23 @@ func (s *ObservationStore) UpdateImportanceScore(ctx context.Context, id int64, } // UpdateImportanceScores bulk updates importance scores for multiple observations. -// This is more efficient than individual updates for batch recalculation. +// Uses a single SQL statement with CASE/WHEN for efficient batch updates. func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error { if len(scores) == 0 { return nil } + // For small batches, use simple individual updates + if len(scores) <= 5 { + return s.updateScoresIndividually(ctx, scores) + } + + // For larger batches, use CASE/WHEN SQL for single-query update + return s.updateScoresBatch(ctx, scores) +} + +// updateScoresIndividually updates scores one at a time (efficient for small batches). +func (s *ObservationStore) updateScoresIndividually(ctx context.Context, scores map[int64]float64) error { return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { now := time.Now().UnixMilli() @@ -85,6 +98,36 @@ func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores ma }) } +// updateScoresBatch updates multiple scores in a single SQL statement using CASE/WHEN. +// This is much more efficient for large batches (O(1) queries instead of O(n)). +func (s *ObservationStore) updateScoresBatch(ctx context.Context, scores map[int64]float64) error { + now := time.Now().UnixMilli() + + // Build CASE/WHEN clause for importance_score + // UPDATE observations SET + // importance_score = CASE id WHEN 1 THEN 0.5 WHEN 2 THEN 0.8 ... END, + // score_updated_at_epoch = ? + // WHERE id IN (1, 2, ...) + + ids := make([]int64, 0, len(scores)) + caseBuilder := strings.Builder{} + caseBuilder.WriteString("CASE id ") + + for id, score := range scores { + ids = append(ids, id) + caseBuilder.WriteString(fmt.Sprintf("WHEN %d THEN %f ", id, score)) + } + caseBuilder.WriteString("END") + + // Use raw SQL for the batch update + sql := fmt.Sprintf( + "UPDATE observations SET importance_score = %s, score_updated_at_epoch = ? WHERE id IN ?", + caseBuilder.String(), + ) + + return s.db.WithContext(ctx).Exec(sql, now, ids).Error +} + // GetObservationsNeedingScoreUpdate returns observations that need their importance score recalculated. // Returns observations where score_updated_at_epoch is NULL or older than the threshold. func (s *ObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) { diff --git a/internal/db/gorm/session_store.go b/internal/db/gorm/session_store.go index 0dc166c..c43e598 100644 --- a/internal/db/gorm/session_store.go +++ b/internal/db/gorm/session_store.go @@ -116,26 +116,44 @@ func (s *SessionStore) FindAnySDKSession(ctx context.Context, claudeSessionID st } // IncrementPromptCounter increments the prompt counter and returns the new value. +// Uses a single SQL query with RETURNING clause for optimal performance. func (s *SessionStore) IncrementPromptCounter(ctx context.Context, id int64) (int, error) { - // Atomic increment using GORM expression - err := s.db.WithContext(ctx). - Model(&SDKSession{}). - Where("id = ?", id). - Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error + // Use raw SQL with RETURNING to get updated value in single query + // SQLite supports RETURNING since version 3.35.0 (2021-03-12) + var newCounter int + err := s.db.WithContext(ctx).Raw(` + UPDATE sdk_sessions + SET prompt_counter = COALESCE(prompt_counter, 0) + 1 + WHERE id = ? + RETURNING prompt_counter + `, id).Scan(&newCounter).Error + if err != nil { + // Fallback for older SQLite versions without RETURNING support + if err.Error() == "near \"RETURNING\": syntax error" || newCounter == 0 { + // Atomic increment + updateErr := s.db.WithContext(ctx). + Model(&SDKSession{}). + Where("id = ?", id). + Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error + if updateErr != nil { + return 0, updateErr + } + + // Fetch updated value + var sess SDKSession + fetchErr := s.db.WithContext(ctx). + Select("prompt_counter"). + First(&sess, id).Error + if fetchErr != nil { + return 0, fetchErr + } + return sess.PromptCounter, nil + } return 0, err } - // Fetch updated value - var sess SDKSession - err = s.db.WithContext(ctx). - Select("prompt_counter"). - First(&sess, id).Error - if err != nil { - return 0, err - } - - return sess.PromptCounter, nil + return newCounter, nil } // GetPromptCounter returns the current prompt counter for a session. diff --git a/internal/db/gorm/store.go b/internal/db/gorm/store.go index aab5203..8772c09 100644 --- a/internal/db/gorm/store.go +++ b/internal/db/gorm/store.go @@ -2,11 +2,16 @@ package gorm import ( + "context" "database/sql" "fmt" + "slices" + "sync" + "time" sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo" _ "github.com/mattn/go-sqlite3" // Import SQLite driver with FTS5 support + "github.com/rs/zerolog/log" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -14,8 +19,15 @@ import ( // Store represents the GORM database connection with sqlite-vec support. type Store struct { - DB *gorm.DB - sqlDB *sql.DB // For FTS5 and sqlite-vec operations that require raw SQL + DB *gorm.DB + sqlDB *sql.DB // For FTS5 and sqlite-vec operations that require raw SQL + metrics *PoolMetrics + + // Health check caching to reduce database load from frequent monitoring + healthCacheMu sync.RWMutex + cachedHealth *HealthInfo + healthCacheTime time.Time + healthCacheTTL time.Duration // Default: 5 seconds } // Config holds database configuration. @@ -71,8 +83,10 @@ func NewStore(cfg Config) (*Store, error) { } store := &Store{ - DB: db, - sqlDB: sqlDB, + DB: db, + sqlDB: sqlDB, + metrics: NewPoolMetrics(100), // Track last 100 latency samples + healthCacheTTL: 5 * time.Second, // Cache health checks for 5 seconds } // 7. Run migrations FIRST (before PRAGMA commands) @@ -80,18 +94,56 @@ func NewStore(cfg Config) (*Store, error) { return nil, fmt.Errorf("run migrations: %w", err) } - // 8. CRITICAL: Set WAL mode and synchronous mode via raw SQL + // 8. CRITICAL: Set WAL mode and other performance pragmas // Use raw sqlDB to avoid GORM transaction issues - if _, err := sqlDB.Exec("PRAGMA journal_mode=WAL"); err != nil { - return nil, fmt.Errorf("set WAL mode: %w", err) + pragmas := []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA synchronous=NORMAL", + "PRAGMA cache_size=-64000", // 64MB cache (negative = KB) + "PRAGMA temp_store=MEMORY", // Store temp tables in memory + "PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O + "PRAGMA page_size=4096", // 4KB pages (optimal for most systems) } - if _, err := sqlDB.Exec("PRAGMA synchronous=NORMAL"); err != nil { - return nil, fmt.Errorf("set synchronous mode: %w", err) + for _, pragma := range pragmas { + if _, err := sqlDB.Exec(pragma); err != nil { + log.Warn().Str("pragma", pragma).Err(err).Msg("Failed to set pragma (non-fatal)") + } } + // 9. Warm the connection pool + store.WarmPool(maxConns) + return store, nil } +// WarmPool pre-creates connections to avoid cold start latency. +func (s *Store) WarmPool(numConns int) { + if numConns <= 0 { + numConns = 4 + } + + var wg sync.WaitGroup + for i := 0; i < numConns; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + conn, err := s.sqlDB.Conn(ctx) + if err != nil { + return + } + // Execute a simple query to ensure the connection is fully initialized + _ = conn.PingContext(ctx) + // Return connection to pool (don't close it) + _ = conn.Close() + }() + } + wg.Wait() + log.Debug().Int("connections", numConns).Msg("Connection pool warmed") +} + // Close closes the database connection. func (s *Store) Close() error { return s.sqlDB.Close() @@ -115,3 +167,370 @@ func (s *Store) GetRawDB() *sql.DB { func (s *Store) GetDB() *gorm.DB { return s.DB } + +// Stats returns database connection pool statistics. +func (s *Store) Stats() sql.DBStats { + return s.sqlDB.Stats() +} + +// Optimize runs VACUUM and ANALYZE to optimize the database. +// Should be called periodically (e.g., daily) during low activity. +func (s *Store) Optimize(ctx context.Context) error { + log.Info().Msg("Starting database optimization") + start := time.Now() + + // ANALYZE updates statistics for query optimizer + if _, err := s.sqlDB.ExecContext(ctx, "ANALYZE"); err != nil { + return fmt.Errorf("analyze: %w", err) + } + + // PRAGMA optimize runs optimization based on query statistics + if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA optimize"); err != nil { + log.Warn().Err(err).Msg("PRAGMA optimize failed (non-fatal)") + } + + log.Info().Dur("duration", time.Since(start)).Msg("Database optimization complete") + return nil +} + +// HealthCheck performs a comprehensive health check with latency measurement. +// Returns detailed health information including connection pool stats and query latency. +// Results are cached for healthCacheTTL (default 5 seconds) to reduce database load +// from frequent monitoring calls. +func (s *Store) HealthCheck(ctx context.Context) *HealthInfo { + // Fast path: check cache with read lock + s.healthCacheMu.RLock() + if s.cachedHealth != nil && time.Since(s.healthCacheTime) < s.healthCacheTTL { + cached := s.cachedHealth + s.healthCacheMu.RUnlock() + return cached + } + s.healthCacheMu.RUnlock() + + // Slow path: perform actual health check + info := s.performHealthCheck(ctx) + + // Cache the result + s.healthCacheMu.Lock() + s.cachedHealth = info + s.healthCacheTime = time.Now() + s.healthCacheMu.Unlock() + + return info +} + +// HealthCheckForce performs a health check bypassing the cache. +// Use this when you need real-time health data (e.g., debugging, alerting). +func (s *Store) HealthCheckForce(ctx context.Context) *HealthInfo { + info := s.performHealthCheck(ctx) + + // Update the cache with fresh data + s.healthCacheMu.Lock() + s.cachedHealth = info + s.healthCacheTime = time.Now() + s.healthCacheMu.Unlock() + + return info +} + +// performHealthCheck does the actual health check work. +func (s *Store) performHealthCheck(ctx context.Context) *HealthInfo { + info := &HealthInfo{ + Status: "healthy", + Timestamp: time.Now(), + } + + // Check pool stats + stats := s.sqlDB.Stats() + info.PoolStats = PoolStats{ + OpenConnections: stats.OpenConnections, + InUse: stats.InUse, + Idle: stats.Idle, + WaitCount: stats.WaitCount, + WaitDuration: stats.WaitDuration, + MaxIdleClosed: stats.MaxIdleClosed, + MaxLifetimeClosed: stats.MaxLifetimeClosed, + } + + // Record pool stats for metrics tracking + if s.metrics != nil { + s.metrics.RecordPoolStats(stats) + } + + // Measure query latency with a simple SELECT + start := time.Now() + var dummy int + err := s.sqlDB.QueryRowContext(ctx, "SELECT 1").Scan(&dummy) + info.QueryLatency = time.Since(start) + + // Record latency for historical tracking + if s.metrics != nil { + s.metrics.RecordLatency(info.QueryLatency) + info.HistoricalMetrics = s.metrics.GetMetricsSummary() + } + + if err != nil { + info.Status = "unhealthy" + info.Error = err.Error() + return info + } + + // Check for connection saturation (degraded if pool is heavily used) + if stats.InUse > 0 && float64(stats.InUse)/float64(stats.OpenConnections) > 0.8 { + info.Status = "degraded" + info.Warning = "Connection pool heavily utilized" + } + + // Check for wait contention + if stats.WaitCount > 100 && stats.WaitDuration > 100*time.Millisecond { + info.Status = "degraded" + info.Warning = "Connection pool contention detected" + } + + // Check query latency (warn if > 10ms for simple query) + if info.QueryLatency > 10*time.Millisecond { + if info.Status == "healthy" { + info.Status = "degraded" + } + info.Warning = fmt.Sprintf("Slow query latency: %v", info.QueryLatency) + } + + // Check historical latency trend (degraded if P95 is high) + if s.metrics != nil && info.HistoricalMetrics.P95Latency > 50*time.Millisecond { + if info.Status == "healthy" { + info.Status = "degraded" + } + info.Warning = fmt.Sprintf("High P95 latency: %v", info.HistoricalMetrics.P95Latency) + } + + return info +} + +// HealthInfo contains database health check results. +type HealthInfo struct { + Status string `json:"status"` // healthy, degraded, unhealthy + Timestamp time.Time `json:"timestamp"` + QueryLatency time.Duration `json:"query_latency_ns"` + PoolStats PoolStats `json:"pool_stats"` + HistoricalMetrics MetricsSummary `json:"historical_metrics,omitempty"` + Error string `json:"error,omitempty"` + Warning string `json:"warning,omitempty"` +} + +// PoolStats contains connection pool statistics. +type PoolStats struct { + OpenConnections int `json:"open_connections"` + InUse int `json:"in_use"` + Idle int `json:"idle"` + WaitCount int64 `json:"wait_count"` + WaitDuration time.Duration `json:"wait_duration_ns"` + MaxIdleClosed int64 `json:"max_idle_closed"` + MaxLifetimeClosed int64 `json:"max_lifetime_closed"` +} + +// QueryTimeout constants for different query types. +const ( + // DefaultQueryTimeout is the default timeout for regular queries. + DefaultQueryTimeout = 5 * time.Second + // FastQueryTimeout is for queries that should be very fast (health checks, etc). + FastQueryTimeout = 1 * time.Second + // SlowQueryTimeout is for queries that may take longer (bulk operations, rebuilds). + SlowQueryTimeout = 30 * time.Second +) + +// PoolMetrics tracks historical connection pool metrics with a sliding window. +type PoolMetrics struct { + mu sync.RWMutex + latencySamples []time.Duration // Circular buffer of latency samples + latencyIdx int // Current index in circular buffer + latencyCount int // Number of samples collected + totalQueries int64 // Total queries executed + totalWaitTime time.Duration // Cumulative wait time for connections + peakInUse int // Peak concurrent connections in use + peakWaitCount int64 // Peak wait count observed + lastSampleTime time.Time // Last time a sample was recorded + windowSize int // Size of sliding window +} + +// NewPoolMetrics creates a new pool metrics collector with the given window size. +func NewPoolMetrics(windowSize int) *PoolMetrics { + if windowSize <= 0 { + windowSize = 100 // Default: track last 100 samples + } + return &PoolMetrics{ + latencySamples: make([]time.Duration, windowSize), + windowSize: windowSize, + lastSampleTime: time.Now(), + } +} + +// RecordLatency records a query latency sample. +func (m *PoolMetrics) RecordLatency(latency time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + + m.latencySamples[m.latencyIdx] = latency + m.latencyIdx = (m.latencyIdx + 1) % m.windowSize + if m.latencyCount < m.windowSize { + m.latencyCount++ + } + m.totalQueries++ + m.lastSampleTime = time.Now() +} + +// RecordPoolStats records pool statistics for peak tracking. +func (m *PoolMetrics) RecordPoolStats(stats sql.DBStats) { + m.mu.Lock() + defer m.mu.Unlock() + + if stats.InUse > m.peakInUse { + m.peakInUse = stats.InUse + } + if stats.WaitCount > m.peakWaitCount { + m.peakWaitCount = stats.WaitCount + } + m.totalWaitTime += stats.WaitDuration +} + +// GetMetricsSummary returns a summary of collected metrics. +func (m *PoolMetrics) GetMetricsSummary() MetricsSummary { + m.mu.RLock() + defer m.mu.RUnlock() + + summary := MetricsSummary{ + TotalQueries: m.totalQueries, + SampleCount: m.latencyCount, + PeakInUse: m.peakInUse, + PeakWaitCount: m.peakWaitCount, + TotalWaitTime: m.totalWaitTime, + LastSampleTime: m.lastSampleTime, + } + + if m.latencyCount == 0 { + return summary + } + + // Calculate latency statistics + var total time.Duration + var min, max time.Duration = m.latencySamples[0], m.latencySamples[0] + + for i := 0; i < m.latencyCount; i++ { + sample := m.latencySamples[i] + total += sample + if sample < min { + min = sample + } + if sample > max { + max = sample + } + } + + summary.AvgLatency = total / time.Duration(m.latencyCount) + summary.MinLatency = min + summary.MaxLatency = max + + // Calculate P95 latency (approximate using sorted samples) + if m.latencyCount >= 20 { + // Copy samples for sorting + samples := make([]time.Duration, m.latencyCount) + copy(samples, m.latencySamples[:m.latencyCount]) + // Use slices.Sort for O(n log n) instead of O(n²) insertion sort + slices.Sort(samples) + p95Idx := int(float64(len(samples)) * 0.95) + summary.P95Latency = samples[p95Idx] + } + + return summary +} + +// MetricsSummary contains aggregated pool metrics. +type MetricsSummary struct { + TotalQueries int64 `json:"total_queries"` + SampleCount int `json:"sample_count"` + AvgLatency time.Duration `json:"avg_latency_ns"` + MinLatency time.Duration `json:"min_latency_ns"` + MaxLatency time.Duration `json:"max_latency_ns"` + P95Latency time.Duration `json:"p95_latency_ns,omitempty"` + PeakInUse int `json:"peak_in_use"` + PeakWaitCount int64 `json:"peak_wait_count"` + TotalWaitTime time.Duration `json:"total_wait_time_ns"` + LastSampleTime time.Time `json:"last_sample_time"` +} + +// GetMetrics returns the current metrics without performing a health check. +func (s *Store) GetMetrics() MetricsSummary { + if s.metrics == nil { + return MetricsSummary{} + } + return s.metrics.GetMetricsSummary() +} + +// ResetMetrics resets the metrics collector (useful for testing or after major changes). +func (s *Store) ResetMetrics() { + if s.metrics != nil { + s.metrics = NewPoolMetrics(s.metrics.windowSize) + } +} + +// WithTimeout wraps a context with the given timeout and logs slow queries. +// Returns the wrapped context and a cancel function that should be called when done. +func (s *Store) WithTimeout(ctx context.Context, timeout time.Duration, operation string) (context.Context, context.CancelFunc) { + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + start := time.Now() + + // Return wrapped cancel that logs if query was slow + return timeoutCtx, func() { + elapsed := time.Since(start) + cancel() + + // Log slow queries (> 100ms) + if elapsed > 100*time.Millisecond { + log.Warn(). + Str("operation", operation). + Dur("elapsed", elapsed). + Dur("timeout", timeout). + Msg("Slow database operation") + } + } +} + +// ExecWithTimeout executes a raw SQL query with timeout. +// Returns error if query takes longer than timeout. +func (s *Store) ExecWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) error { + timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "exec") + defer cancel() + + _, err := s.sqlDB.ExecContext(timeoutCtx, query, args...) + if err != nil { + if err == context.DeadlineExceeded { + return fmt.Errorf("query timeout after %v: %s", timeout, query) + } + return err + } + return nil +} + +// QueryRowWithTimeout executes a row query with timeout. +func (s *Store) QueryRowWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) *sql.Row { + timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "query_row") + // Note: cancel will be called when row.Scan() completes or errors + _ = cancel // Caller must ensure proper cleanup + return s.sqlDB.QueryRowContext(timeoutCtx, query, args...) +} + +// TransactionWithTimeout wraps a transaction function with timeout handling. +// The transaction is automatically rolled back if the context times out. +func (s *Store) TransactionWithTimeout(ctx context.Context, timeout time.Duration, fn func(*gorm.DB) error) error { + timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "transaction") + defer cancel() + + return s.DB.WithContext(timeoutCtx).Transaction(func(tx *gorm.DB) error { + // Check context before proceeding + select { + case <-timeoutCtx.Done(): + return timeoutCtx.Err() + default: + } + return fn(tx) + }) +} diff --git a/internal/maintenance/service.go b/internal/maintenance/service.go new file mode 100644 index 0000000..0456d85 --- /dev/null +++ b/internal/maintenance/service.go @@ -0,0 +1,292 @@ +// Package maintenance provides scheduled maintenance tasks for claude-mnemonic. +package maintenance + +import ( + "context" + "sync" + "time" + + "github.com/lukaszraczylo/claude-mnemonic/internal/config" + "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" + "github.com/rs/zerolog" +) + +// Service handles scheduled maintenance tasks. +type Service struct { + store *gorm.Store + observationStore *gorm.ObservationStore + summaryStore *gorm.SummaryStore + promptStore *gorm.PromptStore + vectorCleanupFn func(ctx context.Context, deletedIDs []int64) + config *config.Config + log zerolog.Logger + stopCh chan struct{} + doneCh chan struct{} + mu sync.Mutex + running bool + + // Metrics + lastRunTime time.Time + lastRunDuration time.Duration + totalCleanedObs int64 + totalOptimizeRun int64 +} + +// NewService creates a new maintenance service. +func NewService( + store *gorm.Store, + observationStore *gorm.ObservationStore, + summaryStore *gorm.SummaryStore, + promptStore *gorm.PromptStore, + vectorCleanupFn func(ctx context.Context, deletedIDs []int64), + cfg *config.Config, + log zerolog.Logger, +) *Service { + return &Service{ + store: store, + observationStore: observationStore, + summaryStore: summaryStore, + promptStore: promptStore, + vectorCleanupFn: vectorCleanupFn, + config: cfg, + log: log.With().Str("component", "maintenance").Logger(), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } +} + +// Start begins the maintenance loop. +func (s *Service) Start(ctx context.Context) { + s.mu.Lock() + if s.running { + s.mu.Unlock() + return + } + s.running = true + s.mu.Unlock() + + defer func() { + s.mu.Lock() + s.running = false + s.mu.Unlock() + close(s.doneCh) + }() + + if !s.config.MaintenanceEnabled { + s.log.Info().Msg("Maintenance disabled, not starting scheduler") + return + } + + interval := max(time.Duration(s.config.MaintenanceIntervalHours)*time.Hour, time.Hour) + + s.log.Info(). + Dur("interval", interval). + Int("retention_days", s.config.ObservationRetentionDays). + Bool("cleanup_stale", s.config.CleanupStaleObservations). + Msg("Starting maintenance scheduler") + + // Initial run after 5 minutes (allow system to stabilize) + time.Sleep(5 * time.Minute) + s.runMaintenance(ctx) + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + s.log.Info().Msg("Maintenance shutting down due to context cancellation") + return + case <-s.stopCh: + s.log.Info().Msg("Maintenance shutting down due to stop signal") + return + case <-ticker.C: + s.runMaintenance(ctx) + } + } +} + +// Stop signals the maintenance service to stop. +func (s *Service) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.running { + return + } + + close(s.stopCh) +} + +// Wait waits for the maintenance service to finish. +func (s *Service) Wait() { + <-s.doneCh +} + +// runMaintenance executes all maintenance tasks. +func (s *Service) runMaintenance(ctx context.Context) { + start := time.Now() + s.log.Info().Msg("Starting maintenance run") + + var totalCleaned int64 + + // Task 1: Clean up old observations by age + if s.config.ObservationRetentionDays > 0 { + cleaned, err := s.cleanupOldObservations(ctx) + if err != nil { + s.log.Error().Err(err).Msg("Failed to cleanup old observations") + } else { + totalCleaned += cleaned + s.log.Info().Int64("cleaned", cleaned).Msg("Cleaned old observations by age") + } + } + + // Task 2: Clean up stale observations + if s.config.CleanupStaleObservations { + cleaned, err := s.cleanupStaleObservations(ctx) + if err != nil { + s.log.Error().Err(err).Msg("Failed to cleanup stale observations") + } else { + totalCleaned += cleaned + s.log.Info().Int64("cleaned", cleaned).Msg("Cleaned stale observations") + } + } + + // Task 3: Optimize database + if err := s.store.Optimize(ctx); err != nil { + s.log.Error().Err(err).Msg("Failed to optimize database") + } else { + s.totalOptimizeRun++ + } + + // Task 4: Clean up old prompts (keep last 1000 per session) + cleanedPrompts, err := s.cleanupOldPrompts(ctx) + if err != nil { + s.log.Error().Err(err).Msg("Failed to cleanup old prompts") + } else if cleanedPrompts > 0 { + s.log.Info().Int64("cleaned", cleanedPrompts).Msg("Cleaned old prompts") + } + + // Update metrics + s.mu.Lock() + s.lastRunTime = time.Now() + s.lastRunDuration = time.Since(start) + s.totalCleanedObs += totalCleaned + s.mu.Unlock() + + s.log.Info(). + Dur("duration", time.Since(start)). + Int64("observations_cleaned", totalCleaned). + Msg("Maintenance run completed") +} + +// cleanupOldObservations deletes observations older than the retention period. +func (s *Service) cleanupOldObservations(ctx context.Context) (int64, error) { + cutoffEpoch := time.Now().AddDate(0, 0, -s.config.ObservationRetentionDays).Unix() + + // Get IDs of old observations + var deletedIDs []int64 + err := s.store.GetDB().WithContext(ctx). + Model(&gorm.Observation{}). + Where("created_at_epoch < ?", cutoffEpoch). + Pluck("id", &deletedIDs).Error + if err != nil { + return 0, err + } + + if len(deletedIDs) == 0 { + return 0, nil + } + + // Delete in batches to avoid long transactions + batchSize := 100 + for i := 0; i < len(deletedIDs); i += batchSize { + end := min(i+batchSize, len(deletedIDs)) + batch := deletedIDs[i:end] + + if err := s.store.GetDB().WithContext(ctx). + Where("id IN ?", batch). + Delete(&gorm.Observation{}).Error; err != nil { + return int64(i), err + } + + // Sync vector DB deletions + if s.vectorCleanupFn != nil { + s.vectorCleanupFn(ctx, batch) + } + } + + return int64(len(deletedIDs)), nil +} + +// cleanupStaleObservations deletes observations marked as stale. +func (s *Service) cleanupStaleObservations(ctx context.Context) (int64, error) { + // Get IDs of stale observations (is_superseded = true) + var deletedIDs []int64 + err := s.store.GetDB().WithContext(ctx). + Model(&gorm.Observation{}). + Where("is_superseded = ?", true). + Pluck("id", &deletedIDs).Error + if err != nil { + return 0, err + } + + if len(deletedIDs) == 0 { + return 0, nil + } + + // Delete in batches + batchSize := 100 + for i := 0; i < len(deletedIDs); i += batchSize { + end := min(i+batchSize, len(deletedIDs)) + batch := deletedIDs[i:end] + + if err := s.store.GetDB().WithContext(ctx). + Where("id IN ?", batch). + Delete(&gorm.Observation{}).Error; err != nil { + return int64(i), err + } + + // Sync vector DB deletions + if s.vectorCleanupFn != nil { + s.vectorCleanupFn(ctx, batch) + } + } + + return int64(len(deletedIDs)), nil +} + +// cleanupOldPrompts removes old prompts keeping only the most recent per session. +func (s *Service) cleanupOldPrompts(ctx context.Context) (int64, error) { + // Delete prompts older than 30 days that aren't the most recent in their session + cutoffEpoch := time.Now().AddDate(0, 0, -30).Unix() + + result := s.store.GetDB().WithContext(ctx). + Where("created_at_epoch < ?", cutoffEpoch). + Delete(&gorm.UserPrompt{}) + + return result.RowsAffected, result.Error +} + +// Stats returns maintenance statistics. +func (s *Service) Stats() map[string]any { + s.mu.Lock() + defer s.mu.Unlock() + + return map[string]any{ + "enabled": s.config.MaintenanceEnabled, + "interval_hours": s.config.MaintenanceIntervalHours, + "retention_days": s.config.ObservationRetentionDays, + "cleanup_stale": s.config.CleanupStaleObservations, + "last_run": s.lastRunTime, + "last_duration_ms": s.lastRunDuration.Milliseconds(), + "total_cleaned_obs": s.totalCleanedObs, + "total_optimizes": s.totalOptimizeRun, + "running": s.running, + } +} + +// RunNow triggers an immediate maintenance run. +func (s *Service) RunNow(ctx context.Context) { + go s.runMaintenance(ctx) +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 7deca69..268e47e 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -8,8 +8,11 @@ import ( "fmt" "io" "os" + "strings" + "time" "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" + "github.com/lukaszraczylo/claude-mnemonic/internal/maintenance" "github.com/lukaszraczylo/claude-mnemonic/internal/scoring" "github.com/lukaszraczylo/claude-mnemonic/internal/search" "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" @@ -25,13 +28,14 @@ type Server struct { stdout io.Writer // Store dependencies for enhanced tools - observationStore *gorm.ObservationStore - patternStore *gorm.PatternStore - relationStore *gorm.RelationStore - sessionStore *gorm.SessionStore - vectorClient *sqlitevec.Client - scoreCalculator *scoring.Calculator - recalculator *scoring.Recalculator + observationStore *gorm.ObservationStore + patternStore *gorm.PatternStore + relationStore *gorm.RelationStore + sessionStore *gorm.SessionStore + vectorClient *sqlitevec.Client + scoreCalculator *scoring.Calculator + recalculator *scoring.Recalculator + maintenanceService *maintenance.Service } // NewServer creates a new MCP server. @@ -45,19 +49,21 @@ func NewServer( vectorClient *sqlitevec.Client, scoreCalculator *scoring.Calculator, recalculator *scoring.Recalculator, + maintenanceService *maintenance.Service, ) *Server { return &Server{ - searchMgr: searchMgr, - version: version, - stdin: os.Stdin, - stdout: os.Stdout, - observationStore: observationStore, - patternStore: patternStore, - relationStore: relationStore, - sessionStore: sessionStore, - vectorClient: vectorClient, - scoreCalculator: scoreCalculator, - recalculator: recalculator, + searchMgr: searchMgr, + version: version, + stdin: os.Stdin, + stdout: os.Stdout, + observationStore: observationStore, + patternStore: patternStore, + relationStore: relationStore, + sessionStore: sessionStore, + vectorClient: vectorClient, + scoreCalculator: scoreCalculator, + recalculator: recalculator, + maintenanceService: maintenanceService, } } @@ -100,26 +106,47 @@ type Tool struct { // Run starts the MCP server loop. func (s *Server) Run(ctx context.Context) error { scanner := bufio.NewScanner(s.stdin) - for scanner.Scan() { - line := scanner.Text() - if line == "" { - continue + + // Channel to signal when scanner is done + scanDone := make(chan error, 1) + + go func() { + for scanner.Scan() { + // Check for context cancellation before processing + select { + case <-ctx.Done(): + scanDone <- ctx.Err() + return + default: + } + + line := scanner.Text() + if line == "" { + continue + } + + var req Request + if err := json.Unmarshal([]byte(line), &req); err != nil { + s.sendError(nil, -32700, "Parse error", err) + continue + } + + resp := s.handleRequest(ctx, &req) + s.sendResponse(resp) } + scanDone <- scanner.Err() + }() - var req Request - if err := json.Unmarshal([]byte(line), &req); err != nil { - s.sendError(nil, -32700, "Parse error", err) - continue + // Wait for either context cancellation or scanner completion + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-scanDone: + if err != nil { + return fmt.Errorf("scanner error: %w", err) } - - resp := s.handleRequest(ctx, &req) - s.sendResponse(resp) + return nil } - - if err := scanner.Err(); err != nil { - return fmt.Errorf("scanner error: %w", err) - } - return nil } // handleRequest dispatches the request to the appropriate handler. @@ -376,6 +403,306 @@ func (s *Server) handleToolsList(req *Request) *Response { }, }, }, + { + Name: "find_similar_observations", + Description: "Find observations semantically similar to a query or observation. Uses vector similarity search to find related content. Useful for detecting duplicates before creating new observations.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"query"}, + "properties": map[string]any{ + "query": map[string]any{"type": "string", "description": "Text to find similar observations for"}, + "project": map[string]any{"type": "string", "description": "Filter by project name"}, + "min_similarity": map[string]any{"type": "number", "default": 0.7, "minimum": 0.0, "maximum": 1.0, "description": "Minimum similarity threshold (0-1)"}, + "limit": map[string]any{"type": "number", "default": 10, "minimum": 1, "maximum": 50}, + }, + }, + }, + { + Name: "get_patterns", + Description: "Get detected patterns from observations. Patterns represent recurring themes, workflows, or practices discovered across observations.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "type": map[string]any{"type": "string", "enum": []string{"workflow", "preference", "best_practice", "anti_pattern", "tooling"}, "description": "Filter by pattern type"}, + "project": map[string]any{"type": "string", "description": "Filter by project"}, + "query": map[string]any{"type": "string", "description": "Search patterns by name/description"}, + "limit": map[string]any{"type": "number", "default": 20, "minimum": 1, "maximum": 100}, + }, + }, + }, + { + Name: "get_memory_stats", + Description: "Get statistics about the memory system including observation counts, vector stats, pattern counts, and search metrics.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, + }, + { + Name: "bulk_delete_observations", + Description: "Delete multiple observations by their IDs. Returns count of successfully deleted observations.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"ids"}, + "properties": map[string]any{ + "ids": map[string]any{"type": "array", "items": map[string]any{"type": "number"}, "description": "Array of observation IDs to delete"}, + "delete_vectors": map[string]any{"type": "boolean", "default": true, "description": "Also delete associated vectors"}, + }, + }, + }, + { + Name: "bulk_mark_superseded", + Description: "Mark multiple observations as superseded (stale). Useful for cleanup without permanent deletion.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"ids"}, + "properties": map[string]any{ + "ids": map[string]any{"type": "array", "items": map[string]any{"type": "number"}, "description": "Array of observation IDs to mark as superseded"}, + }, + }, + }, + { + Name: "bulk_boost_observations", + Description: "Boost or reduce the importance score of multiple observations. Positive values increase importance, negative decrease.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"ids", "boost"}, + "properties": map[string]any{ + "ids": map[string]any{"type": "array", "items": map[string]any{"type": "number"}, "description": "Array of observation IDs to boost"}, + "boost": map[string]any{"type": "number", "minimum": -1.0, "maximum": 1.0, "description": "Boost amount (-1.0 to 1.0)"}, + }, + }, + }, + { + Name: "trigger_maintenance", + Description: "Trigger an immediate maintenance run (cleanup old observations, optimize database).", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, + }, + { + Name: "get_maintenance_stats", + Description: "Get statistics about the maintenance system including last run time, cleanup counts, and configuration.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, + }, + { + Name: "merge_observations", + Description: "Merge two observations into one. The target observation is kept and boosted, the source is marked as superseded. Useful for deduplication without data loss.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"source_id", "target_id"}, + "properties": map[string]any{ + "source_id": map[string]any{"type": "number", "description": "ID of the observation to merge FROM (will be superseded)"}, + "target_id": map[string]any{"type": "number", "description": "ID of the observation to merge INTO (will be kept and boosted)"}, + "boost": map[string]any{"type": "number", "default": 0.1, "minimum": 0, "maximum": 0.5, "description": "Score boost for the target observation (default 0.1)"}, + }, + }, + }, + { + Name: "get_observation", + Description: "Get a single observation by its ID. Returns full observation details including all metadata.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"id"}, + "properties": map[string]any{ + "id": map[string]any{"type": "number", "description": "Observation ID to retrieve"}, + }, + }, + }, + { + Name: "edit_observation", + Description: "Edit an existing observation. Only provided fields will be updated, others remain unchanged. Useful for correcting errors, adding details, or updating scope.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"id"}, + "properties": map[string]any{ + "id": map[string]any{"type": "number", "description": "Observation ID to edit"}, + "title": map[string]any{"type": "string", "description": "New title (optional)"}, + "subtitle": map[string]any{"type": "string", "description": "New subtitle (optional)"}, + "narrative": map[string]any{"type": "string", "description": "New narrative text (optional)"}, + "facts": map[string]any{"type": "array", "items": map[string]any{"type": "string"}, "description": "New facts array (optional)"}, + "concepts": map[string]any{"type": "array", "items": map[string]any{"type": "string"}, "description": "New concept tags (optional)"}, + "files_read": map[string]any{"type": "array", "items": map[string]any{"type": "string"}, "description": "New files read list (optional)"}, + "files_modified": map[string]any{"type": "array", "items": map[string]any{"type": "string"}, "description": "New files modified list (optional)"}, + "scope": map[string]any{"type": "string", "enum": []string{"project", "global"}, "description": "New scope (optional)"}, + }, + }, + }, + { + Name: "get_observation_quality", + Description: "Get quality metrics for an observation. Returns completeness score, usage stats, and improvement suggestions.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"id"}, + "properties": map[string]any{ + "id": map[string]any{"type": "number", "description": "Observation ID to analyze"}, + }, + }, + }, + { + Name: "suggest_consolidations", + Description: "Find observations that could be merged or consolidated. Returns groups of similar observations with merge recommendations.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "project": map[string]any{"type": "string", "description": "Filter by project"}, + "min_similarity": map[string]any{"type": "number", "default": 0.8, "minimum": 0.5, "maximum": 1.0, "description": "Minimum similarity threshold for grouping"}, + "limit": map[string]any{"type": "number", "default": 10, "minimum": 1, "maximum": 50, "description": "Maximum groups to return"}, + }, + }, + }, + { + Name: "tag_observation", + Description: "Add or remove concept tags from an observation. Tags help with organization and filtering. Use mode 'add' to add new tags, 'remove' to remove specific tags, or 'set' to replace all tags.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"id", "tags"}, + "properties": map[string]any{ + "id": map[string]any{"type": "number", "description": "Observation ID to tag"}, + "tags": map[string]any{"type": "array", "items": map[string]any{"type": "string"}, "description": "Tags to add, remove, or set"}, + "mode": map[string]any{"type": "string", "enum": []string{"add", "remove", "set"}, "default": "add", "description": "Operation mode: 'add' appends tags, 'remove' removes specific tags, 'set' replaces all tags"}, + }, + }, + }, + { + Name: "get_observations_by_tag", + Description: "Find all observations that have a specific concept tag. Useful for browsing by category.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"tag"}, + "properties": map[string]any{ + "tag": map[string]any{"type": "string", "description": "Tag/concept to search for"}, + "project": map[string]any{"type": "string", "description": "Filter by project (optional)"}, + "limit": map[string]any{"type": "number", "default": 50, "minimum": 1, "maximum": 200, "description": "Maximum observations to return"}, + }, + }, + }, + { + Name: "get_temporal_trends", + Description: "Analyze observation creation patterns over time. Returns daily counts, peak activity times, and trend insights. Useful for understanding work patterns.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "project": map[string]any{"type": "string", "description": "Filter by project (optional)"}, + "days": map[string]any{"type": "number", "default": 30, "minimum": 1, "maximum": 365, "description": "Number of days to analyze"}, + "group_by": map[string]any{"type": "string", "enum": []string{"day", "week", "hour_of_day"}, "default": "day", "description": "How to group the data"}, + }, + }, + }, + { + Name: "get_data_quality_report", + Description: "Get a comprehensive quality assessment of observations. Shows completeness distribution, common issues, and improvement suggestions.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "project": map[string]any{"type": "string", "description": "Filter by project (optional)"}, + "limit": map[string]any{"type": "number", "default": 100, "minimum": 10, "maximum": 500, "description": "Number of observations to analyze"}, + }, + }, + }, + { + Name: "batch_tag_by_pattern", + Description: "Apply tags to observations matching a pattern. Useful for retroactive organization and categorization.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"pattern", "tags"}, + "properties": map[string]any{ + "pattern": map[string]any{"type": "string", "description": "Search pattern to match (searches title, narrative, facts)"}, + "tags": map[string]any{"type": "array", "items": map[string]any{"type": "string"}, "description": "Tags to add to matching observations"}, + "project": map[string]any{"type": "string", "description": "Filter by project (optional)"}, + "dry_run": map[string]any{"type": "boolean", "default": true, "description": "If true, only preview matches without applying tags"}, + "max_matches": map[string]any{"type": "number", "default": 100, "minimum": 1, "maximum": 500, "description": "Maximum observations to tag"}, + }, + }, + }, + { + Name: "explain_search_ranking", + Description: "Debug search results by showing score breakdown for top matches. Explains why each observation ranked where it did.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"query"}, + "properties": map[string]any{ + "query": map[string]any{"type": "string", "description": "Search query to analyze"}, + "project": map[string]any{"type": "string", "description": "Project context for search"}, + "top_n": map[string]any{"type": "number", "default": 5, "minimum": 1, "maximum": 20, "description": "Number of top results to explain"}, + }, + }, + }, + { + Name: "export_observations", + Description: "Export observations in various formats for backup or analysis.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "format": map[string]any{"type": "string", "enum": []string{"json", "jsonl", "markdown"}, "default": "json", "description": "Export format"}, + "project": map[string]any{"type": "string", "description": "Filter by project (optional)"}, + "limit": map[string]any{"type": "number", "default": 100, "minimum": 1, "maximum": 1000, "description": "Maximum observations to export"}, + "date_start": map[string]any{"type": "number", "description": "Filter by creation date (epoch milliseconds)"}, + "date_end": map[string]any{"type": "number", "description": "Filter by creation date (epoch milliseconds)"}, + "obs_type": map[string]any{"type": "string", "description": "Filter by observation type"}, + }, + }, + }, + { + Name: "check_system_health", + Description: "Comprehensive system health check. Returns status of all subsystems (database, vectors, cache, search) with actionable diagnostics.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, + }, + { + Name: "analyze_search_patterns", + Description: "Analyze search query patterns to identify common searches, missed queries, and optimization opportunities.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "days": map[string]any{"type": "number", "default": 7, "minimum": 1, "maximum": 30, "description": "Number of days to analyze"}, + "top_n": map[string]any{"type": "number", "default": 10, "minimum": 1, "maximum": 50, "description": "Number of top patterns to return"}, + }, + }, + }, + { + Name: "get_observation_relationships", + Description: "Get relationship graph for an observation. Shows how observations relate to each other (depends_on, extends, conflicts_with, supersedes). Useful for understanding dependencies and context.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"id"}, + "properties": map[string]any{ + "id": map[string]any{"type": "number", "description": "Observation ID to analyze relationships for"}, + "max_depth": map[string]any{"type": "number", "default": 2, "minimum": 1, "maximum": 5, "description": "How many hops to traverse (1=direct, 2=neighbors of neighbors)"}, + }, + }, + }, + { + Name: "get_observation_scoring_breakdown", + Description: "Get detailed scoring breakdown for an observation. Shows how importance scores are calculated including type weight, recency decay, feedback contribution, concept boost, and retrieval frequency. Useful for understanding why observations are ranked the way they are.", + InputSchema: map[string]any{ + "type": "object", + "required": []string{"id"}, + "properties": map[string]any{ + "id": map[string]any{"type": "number", "description": "Observation ID to get scoring breakdown for"}, + }, + }, + }, + { + Name: "analyze_observation_importance", + Description: "Analyze observation importance patterns in a project. Returns statistics on feedback distribution, top-scoring observations, most-retrieved observations, and concept weights. Useful for understanding what makes observations valuable.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "project": map[string]any{"type": "string", "description": "Project to analyze (optional, analyzes all if omitted)"}, + "include_top_scored": map[string]any{"type": "boolean", "default": true, "description": "Include top-scoring observations"}, + "include_most_retrieved": map[string]any{"type": "boolean", "default": true, "description": "Include most-retrieved observations"}, + "include_concept_weights": map[string]any{"type": "boolean", "default": true, "description": "Include concept weight analysis"}, + "limit": map[string]any{"type": "number", "default": 10, "minimum": 1, "maximum": 50, "description": "Number of top observations to include"}, + }, + }, + }, } return &Response{ @@ -431,9 +758,60 @@ func (s *Server) handleToolsCall(ctx context.Context, req *Request) *Response { // callTool dispatches to the appropriate tool handler. func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage) (string, error) { - // Relation discovery tool - if name == "find_related_observations" { + // Special handlers for non-search tools + switch name { + case "find_related_observations": return s.handleFindRelatedObservations(ctx, args) + case "find_similar_observations": + return s.handleFindSimilarObservations(ctx, args) + case "get_patterns": + return s.handleGetPatterns(ctx, args) + case "get_memory_stats": + return s.handleGetMemoryStats(ctx) + case "bulk_delete_observations": + return s.handleBulkDeleteObservations(ctx, args) + case "bulk_mark_superseded": + return s.handleBulkMarkSuperseded(ctx, args) + case "bulk_boost_observations": + return s.handleBulkBoostObservations(ctx, args) + case "trigger_maintenance": + return s.handleTriggerMaintenance(ctx) + case "get_maintenance_stats": + return s.handleGetMaintenanceStats(ctx) + case "merge_observations": + return s.handleMergeObservations(ctx, args) + case "get_observation": + return s.handleGetObservation(ctx, args) + case "edit_observation": + return s.handleEditObservation(ctx, args) + case "get_observation_quality": + return s.handleGetObservationQuality(ctx, args) + case "suggest_consolidations": + return s.handleSuggestConsolidations(ctx, args) + case "tag_observation": + return s.handleTagObservation(ctx, args) + case "get_observations_by_tag": + return s.handleGetObservationsByTag(ctx, args) + case "get_temporal_trends": + return s.handleGetTemporalTrends(ctx, args) + case "get_data_quality_report": + return s.handleGetDataQualityReport(ctx, args) + case "batch_tag_by_pattern": + return s.handleBatchTagByPattern(ctx, args) + case "explain_search_ranking": + return s.handleExplainSearchRanking(ctx, args) + case "export_observations": + return s.handleExportObservations(ctx, args) + case "check_system_health": + return s.handleCheckSystemHealth(ctx) + case "analyze_search_patterns": + return s.handleAnalyzeSearchPatterns(ctx, args) + case "get_observation_relationships": + return s.handleGetObservationRelationships(ctx, args) + case "get_observation_scoring_breakdown": + return s.handleGetObservationScoringBreakdown(ctx, args) + case "analyze_observation_importance": + return s.handleAnalyzeObservationImportance(ctx, args) } // Original search-based tools @@ -627,15 +1005,17 @@ func (s *Server) handleFindRelatedObservations(ctx context.Context, args json.Ra relatedIDs = relatedIDs[:params.Limit] } - // Fetch full observations - observations := make([]*models.Observation, 0, len(relatedIDs)) - for _, id := range relatedIDs { - obs, err := s.observationStore.GetObservationByID(ctx, id) - if err != nil { - continue // Skip errors for individual observations - } - if obs != nil { - observations = append(observations, obs) + // Fetch full observations in batch (avoids N+1 query problem) + observations, err := s.observationStore.GetObservationsByIDsPreserveOrder(ctx, relatedIDs) + if err != nil { + log.Warn().Err(err).Msg("Failed to batch fetch related observations, falling back to individual fetch") + // Fallback to individual fetch if batch fails + observations = make([]*models.Observation, 0, len(relatedIDs)) + for _, id := range relatedIDs { + obs, fetchErr := s.observationStore.GetObservationByID(ctx, id) + if fetchErr == nil && obs != nil { + observations = append(observations, obs) + } } } @@ -675,3 +1055,2218 @@ func (s *Server) sendError(id any, code int, message string, data any) { } s.sendResponse(resp) } + +// handleFindSimilarObservations finds observations semantically similar to a query. +func (s *Server) handleFindSimilarObservations(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Query string `json:"query"` + Project string `json:"project"` + MinSimilarity float64 `json:"min_similarity"` + Limit int `json:"limit"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.Query == "" { + return "", fmt.Errorf("query is required") + } + + if params.MinSimilarity == 0 { + params.MinSimilarity = 0.7 + } + if params.Limit == 0 { + params.Limit = 10 + } + if params.Limit > 50 { + params.Limit = 50 + } + + // Use vector search to find similar observations + if s.vectorClient == nil { + return "", fmt.Errorf("vector search not available") + } + + where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, params.Project) + results, err := s.vectorClient.Query(ctx, params.Query, params.Limit*2, where) + if err != nil { + return "", fmt.Errorf("vector search failed: %w", err) + } + + // Filter by similarity threshold + filtered := sqlitevec.FilterByThreshold(results, params.MinSimilarity, params.Limit) + + // Extract observation IDs and build similarity map + obsIDs := sqlitevec.ExtractObservationIDs(filtered, params.Project) + similarityMap := make(map[int64]float64, len(filtered)) + for _, r := range filtered { + if sqliteID, ok := r.Metadata["sqlite_id"].(float64); ok { + id := int64(sqliteID) + if _, exists := similarityMap[id]; !exists { + similarityMap[id] = r.Similarity + } + } + } + + // Fetch full observations in batch (avoids N+1 query problem) + observations, err := s.observationStore.GetObservationsByIDsPreserveOrder(ctx, obsIDs) + if err != nil { + log.Warn().Err(err).Msg("Failed to batch fetch similar observations, falling back to individual fetch") + observations = make([]*models.Observation, 0, len(obsIDs)) + for _, id := range obsIDs { + obs, fetchErr := s.observationStore.GetObservationByID(ctx, id) + if fetchErr == nil && obs != nil { + observations = append(observations, obs) + } + } + } + + // Build response with similarity scores + type SimilarObservation struct { + *models.Observation + Similarity float64 `json:"similarity"` + } + + similarObs := make([]SimilarObservation, 0, len(observations)) + for _, obs := range observations { + sim := similarityMap[obs.ID] + similarObs = append(similarObs, SimilarObservation{ + Observation: obs, + Similarity: sim, + }) + } + + response := map[string]any{ + "observations": similarObs, + "count": len(similarObs), + "min_similarity": params.MinSimilarity, + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleGetPatterns returns patterns from the pattern store. +func (s *Server) handleGetPatterns(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Type string `json:"type"` + Project string `json:"project"` + Query string `json:"query"` + Limit int `json:"limit"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.Limit == 0 { + params.Limit = 20 + } + if params.Limit > 100 { + params.Limit = 100 + } + + var patterns []*models.Pattern + var err error + + // Query patterns based on filters + if params.Query != "" { + // FTS search + patterns, err = s.patternStore.SearchPatternsFTS(ctx, params.Query, params.Limit) + } else if params.Type != "" { + // Filter by type + patterns, err = s.patternStore.GetPatternsByType(ctx, models.PatternType(params.Type), params.Limit) + } else if params.Project != "" { + // Filter by project + patterns, err = s.patternStore.GetPatternsByProject(ctx, params.Project, params.Limit) + } else { + // Get all active patterns + patterns, err = s.patternStore.GetActivePatterns(ctx, params.Limit) + } + + if err != nil { + return "", fmt.Errorf("failed to get patterns: %w", err) + } + + response := map[string]any{ + "patterns": patterns, + "count": len(patterns), + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleGetMemoryStats returns statistics about the memory system. +func (s *Server) handleGetMemoryStats(ctx context.Context) (string, error) { + stats := make(map[string]any, 8) // Pre-allocate for expected stats keys + + // Get vector count + if s.vectorClient != nil { + count, err := s.vectorClient.Count(ctx) + if err == nil { + stats["vector_count"] = count + } + + // Cache stats + cacheSize, cacheMax := s.vectorClient.CacheStats() + stats["embedding_cache"] = map[string]any{ + "size": cacheSize, + "max_size": cacheMax, + } + + // Model version + stats["embedding_model"] = s.vectorClient.ModelVersion() + } + + // Get pattern stats + if s.patternStore != nil { + patternStats, err := s.patternStore.GetPatternStats(ctx) + if err == nil && patternStats != nil { + stats["patterns"] = map[string]any{ + "total": patternStats.Total, + "active": patternStats.Active, + "deprecated": patternStats.Deprecated, + "merged": patternStats.Merged, + "total_occurrences": patternStats.TotalOccurrences, + "avg_confidence": patternStats.AvgConfidence, + } + } + } + + // Get search metrics + if s.searchMgr != nil { + searchMetrics := s.searchMgr.Metrics() + if searchMetrics != nil { + stats["search"] = searchMetrics.GetStats() + } + } + + output, err := json.Marshal(stats) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleBulkDeleteObservations deletes multiple observations by ID. +func (s *Server) handleBulkDeleteObservations(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + IDs []int64 `json:"ids"` + DeleteVectors bool `json:"delete_vectors"` + } + params.DeleteVectors = true // default + + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if len(params.IDs) == 0 { + return "", fmt.Errorf("ids is required") + } + + if len(params.IDs) > 1000 { + return "", fmt.Errorf("maximum 1000 IDs per request") + } + + var deleted int64 + var errors []string + + // Delete in batches + batchSize := 100 + for i := 0; i < len(params.IDs); i += batchSize { + end := min(i+batchSize, len(params.IDs)) + batch := params.IDs[i:end] + + for _, id := range batch { + if err := s.observationStore.DeleteObservation(ctx, id); err != nil { + errors = append(errors, fmt.Sprintf("id %d: %v", id, err)) + continue + } + deleted++ + + // Delete associated vectors if requested + if params.DeleteVectors && s.vectorClient != nil { + _ = s.vectorClient.DeleteByObservationID(ctx, id) + } + } + } + + response := map[string]any{ + "deleted": deleted, + "total": len(params.IDs), + } + if len(errors) > 0 { + response["errors"] = errors + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + // Return error if all deletions failed (complete failure) + if deleted == 0 && len(errors) > 0 { + return string(output), fmt.Errorf("bulk delete failed: %d errors, first: %s", len(errors), errors[0]) + } + + return string(output), nil +} + +// handleBulkMarkSuperseded marks multiple observations as superseded. +func (s *Server) handleBulkMarkSuperseded(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + IDs []int64 `json:"ids"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if len(params.IDs) == 0 { + return "", fmt.Errorf("ids is required") + } + + if len(params.IDs) > 1000 { + return "", fmt.Errorf("maximum 1000 IDs per request") + } + + // Use batch update for efficiency (single query instead of N queries) + updated, err := s.observationStore.MarkAsSupersededBatch(ctx, params.IDs) + if err != nil { + return "", fmt.Errorf("batch mark as superseded: %w", err) + } + + response := map[string]any{ + "updated": updated, + "total": len(params.IDs), + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleBulkBoostObservations boosts the importance score of multiple observations. +func (s *Server) handleBulkBoostObservations(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + IDs []int64 `json:"ids"` + Boost float64 `json:"boost"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if len(params.IDs) == 0 { + return "", fmt.Errorf("ids is required") + } + + if len(params.IDs) > 1000 { + return "", fmt.Errorf("maximum 1000 IDs per request") + } + + if params.Boost < -1.0 || params.Boost > 1.0 { + return "", fmt.Errorf("boost must be between -1.0 and 1.0") + } + + var boosted int64 + var errors []string + + // Batch fetch all observations in one query instead of N queries + observations, err := s.observationStore.GetObservationsByIDs(ctx, params.IDs, "", 0) + if err != nil { + return "", fmt.Errorf("batch fetch observations: %w", err) + } + + // Build a map for O(1) lookup + obsMap := make(map[int64]*models.Observation, len(observations)) + for _, obs := range observations { + obsMap[obs.ID] = obs + } + + // Calculate new scores and prepare batch update + scoresToUpdate := make(map[int64]float64, len(params.IDs)) + for _, id := range params.IDs { + obs, found := obsMap[id] + if !found { + errors = append(errors, fmt.Sprintf("id %d: not found", id)) + continue + } + + // Calculate new importance score (clamp between 0 and 1) + newScore := obs.ImportanceScore + params.Boost + if newScore < 0 { + newScore = 0 + } + if newScore > 1 { + newScore = 1 + } + scoresToUpdate[id] = newScore + } + + // Batch update all scores in one operation + if len(scoresToUpdate) > 0 { + if err := s.observationStore.UpdateImportanceScores(ctx, scoresToUpdate); err != nil { + return "", fmt.Errorf("batch update scores: %w", err) + } + boosted = int64(len(scoresToUpdate)) + } + + response := map[string]any{ + "boosted": boosted, + "total": len(params.IDs), + "boost_used": params.Boost, + } + if len(errors) > 0 { + response["errors"] = errors + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleTriggerMaintenance triggers an immediate maintenance run. +func (s *Server) handleTriggerMaintenance(ctx context.Context) (string, error) { + if s.maintenanceService == nil { + return "", fmt.Errorf("maintenance service not available") + } + + s.maintenanceService.RunNow(ctx) + + response := map[string]any{ + "status": "triggered", + "message": "Maintenance run started in background", + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleGetMaintenanceStats returns maintenance statistics. +func (s *Server) handleGetMaintenanceStats(_ context.Context) (string, error) { + if s.maintenanceService == nil { + return "", fmt.Errorf("maintenance service not available") + } + + stats := s.maintenanceService.Stats() + + output, err := json.Marshal(stats) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleMergeObservations merges two observations, keeping the target and superseding the source. +func (s *Server) handleMergeObservations(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + SourceID int64 `json:"source_id"` + TargetID int64 `json:"target_id"` + Boost float64 `json:"boost"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.SourceID == 0 || params.TargetID == 0 { + return "", fmt.Errorf("source_id and target_id are required") + } + + if params.SourceID == params.TargetID { + return "", fmt.Errorf("source_id and target_id cannot be the same") + } + + // Set default boost if not provided + if params.Boost == 0 { + params.Boost = 0.1 + } + if params.Boost < 0 || params.Boost > 0.5 { + return "", fmt.Errorf("boost must be between 0 and 0.5") + } + + // Get both observations to verify they exist + source, err := s.observationStore.GetObservationByID(ctx, params.SourceID) + if err != nil { + return "", fmt.Errorf("get source observation: %w", err) + } + if source == nil { + return "", fmt.Errorf("source observation %d not found", params.SourceID) + } + + target, err := s.observationStore.GetObservationByID(ctx, params.TargetID) + if err != nil { + return "", fmt.Errorf("get target observation: %w", err) + } + if target == nil { + return "", fmt.Errorf("target observation %d not found", params.TargetID) + } + + // Mark source as superseded + if err := s.observationStore.MarkAsSuperseded(ctx, params.SourceID); err != nil { + return "", fmt.Errorf("mark source as superseded: %w", err) + } + + // Boost target's importance score + newScore := target.ImportanceScore + params.Boost + if newScore > 1.0 { + newScore = 1.0 + } + if err := s.observationStore.UpdateImportanceScore(ctx, params.TargetID, newScore); err != nil { + return "", fmt.Errorf("update target score: %w", err) + } + + response := map[string]any{ + "merged": true, + "source_id": params.SourceID, + "source_title": source.Title.String, + "target_id": params.TargetID, + "target_title": target.Title.String, + "target_new_score": newScore, + "target_old_score": target.ImportanceScore, + "boost_applied": params.Boost, + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleGetObservation returns a single observation by ID. +func (s *Server) handleGetObservation(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + ID int64 `json:"id"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.ID == 0 { + return "", fmt.Errorf("id is required") + } + + obs, err := s.observationStore.GetObservationByID(ctx, params.ID) + if err != nil { + return "", fmt.Errorf("get observation: %w", err) + } + if obs == nil { + return "", fmt.Errorf("observation %d not found", params.ID) + } + + output, err := json.Marshal(obs) + if err != nil { + return "", fmt.Errorf("marshal observation: %w", err) + } + + return string(output), nil +} + +// handleEditObservation updates an existing observation with provided fields. +func (s *Server) handleEditObservation(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + ID int64 `json:"id"` + Title *string `json:"title,omitempty"` + Subtitle *string `json:"subtitle,omitempty"` + Narrative *string `json:"narrative,omitempty"` + Facts []string `json:"facts,omitempty"` + Concepts []string `json:"concepts,omitempty"` + FilesRead []string `json:"files_read,omitempty"` + FilesModified []string `json:"files_modified,omitempty"` + Scope *string `json:"scope,omitempty"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.ID == 0 { + return "", fmt.Errorf("id is required") + } + + // Validate scope if provided + if params.Scope != nil && *params.Scope != "project" && *params.Scope != "global" { + return "", fmt.Errorf("scope must be 'project' or 'global'") + } + + // Build update struct + update := &gorm.ObservationUpdate{} + if params.Title != nil { + update.Title = params.Title + } + if params.Subtitle != nil { + update.Subtitle = params.Subtitle + } + if params.Narrative != nil { + update.Narrative = params.Narrative + } + if params.Facts != nil { + update.Facts = ¶ms.Facts + } + if params.Concepts != nil { + update.Concepts = ¶ms.Concepts + } + if params.FilesRead != nil { + update.FilesRead = ¶ms.FilesRead + } + if params.FilesModified != nil { + update.FilesModified = ¶ms.FilesModified + } + if params.Scope != nil { + update.Scope = params.Scope + } + + // Update the observation + updatedObs, err := s.observationStore.UpdateObservation(ctx, params.ID, update) + if err != nil { + return "", fmt.Errorf("update observation: %w", err) + } + + // Note: Vector resync is handled by the worker service when available + // The MCP server doesn't have access to the embedding service + + response := map[string]any{ + "updated": true, + "observation": updatedObs, + "vector_resync": "deferred", + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleGetObservationQuality returns quality metrics for an observation. +func (s *Server) handleGetObservationQuality(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + ID int64 `json:"id"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.ID == 0 { + return "", fmt.Errorf("id is required") + } + + obs, err := s.observationStore.GetObservationByID(ctx, params.ID) + if err != nil { + return "", fmt.Errorf("get observation: %w", err) + } + if obs == nil { + return "", fmt.Errorf("observation %d not found", params.ID) + } + + // Calculate completeness score + completenessScore := 0.0 + maxScore := 5.0 + suggestions := []string{} + + // Check title (required, 1 point) + if obs.Title.Valid && obs.Title.String != "" { + completenessScore += 1.0 + } else { + suggestions = append(suggestions, "Add a descriptive title") + } + + // Check narrative (important, 1.5 points) + if obs.Narrative.Valid && len(obs.Narrative.String) > 50 { + completenessScore += 1.5 + } else if obs.Narrative.Valid && obs.Narrative.String != "" { + completenessScore += 0.5 + suggestions = append(suggestions, "Expand the narrative to provide more context (aim for 50+ characters)") + } else { + suggestions = append(suggestions, "Add a narrative explaining the observation") + } + + // Check facts (valuable, 1 point) + if len(obs.Facts) >= 2 { + completenessScore += 1.0 + } else if len(obs.Facts) == 1 { + completenessScore += 0.5 + suggestions = append(suggestions, "Add more key facts (aim for 2+)") + } else { + suggestions = append(suggestions, "Add key facts to capture important details") + } + + // Check concepts (useful, 0.75 points) + if len(obs.Concepts) >= 2 { + completenessScore += 0.75 + } else if len(obs.Concepts) == 1 { + completenessScore += 0.25 + suggestions = append(suggestions, "Add more concept tags for better discoverability") + } else { + suggestions = append(suggestions, "Add concept tags to categorize this observation") + } + + // Check file references (helpful, 0.75 points) + if len(obs.FilesRead) > 0 || len(obs.FilesModified) > 0 { + completenessScore += 0.75 + } else { + suggestions = append(suggestions, "Consider adding file references if applicable") + } + + // Determine quality tier + qualityTier := "poor" + switch { + case completenessScore >= 4.0: + qualityTier = "excellent" + case completenessScore >= 3.0: + qualityTier = "good" + case completenessScore >= 2.0: + qualityTier = "fair" + } + + response := map[string]any{ + "id": params.ID, + "completeness_score": completenessScore, + "max_score": maxScore, + "completeness_pct": (completenessScore / maxScore) * 100, + "quality_tier": qualityTier, + "importance_score": obs.ImportanceScore, + "retrieval_count": obs.RetrievalCount, + "is_superseded": obs.IsSuperseded, + "suggestions": suggestions, + "field_stats": map[string]any{ + "has_title": obs.Title.Valid && obs.Title.String != "", + "has_narrative": obs.Narrative.Valid && obs.Narrative.String != "", + "narrative_length": len(obs.Narrative.String), + "facts_count": len(obs.Facts), + "concepts_count": len(obs.Concepts), + "files_read_count": len(obs.FilesRead), + "files_modified_count": len(obs.FilesModified), + }, + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleSuggestConsolidations finds observations that could be merged. +func (s *Server) handleSuggestConsolidations(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Project string `json:"project"` + MinSimilarity float64 `json:"min_similarity"` + Limit int `json:"limit"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + // Set defaults + if params.MinSimilarity == 0 { + params.MinSimilarity = 0.8 + } + if params.Limit == 0 { + params.Limit = 10 + } + if params.MinSimilarity < 0.5 || params.MinSimilarity > 1.0 { + return "", fmt.Errorf("min_similarity must be between 0.5 and 1.0") + } + + // Get recent observations to analyze + obs, err := s.observationStore.GetRecentObservations(ctx, params.Project, 200) + if err != nil { + return "", fmt.Errorf("get observations: %w", err) + } + + if len(obs) < 2 { + response := map[string]any{ + "groups": []any{}, + "message": "Not enough observations to analyze", + } + output, _ := json.Marshal(response) + return string(output), nil + } + + // Find similar pairs using vector search if available + type consolidationGroup struct { + Primary *models.Observation `json:"primary"` + Similar []*models.Observation `json:"similar"` + Similarity float64 `json:"avg_similarity"` + Reason string `json:"reason"` + } + + groups := []consolidationGroup{} + seen := make(map[int64]bool) + + // For each observation, find similar ones + for _, primary := range obs { + if seen[primary.ID] { + continue + } + + // Build search text from observation + searchText := primary.Title.String + if primary.Narrative.Valid { + searchText += " " + primary.Narrative.String + } + + if searchText == "" || s.vectorClient == nil { + continue + } + + // Query for similar observations + where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, params.Project) + results, err := s.vectorClient.Query(ctx, searchText, 10, where) + if err != nil { + continue + } + + // Find similar observations above threshold + similar := []*models.Observation{} + totalSimilarity := 0.0 + + for _, r := range results { + // Extract observation ID from metadata + sqliteID, ok := r.Metadata["sqlite_id"].(float64) + if !ok { + continue + } + obsID := int64(sqliteID) + + if obsID == primary.ID || seen[obsID] { + continue + } + if r.Similarity >= params.MinSimilarity { + // Fetch the similar observation + simObs, err := s.observationStore.GetObservationByID(ctx, obsID) + if err != nil || simObs == nil { + continue + } + similar = append(similar, simObs) + totalSimilarity += r.Similarity + seen[obsID] = true + } + } + + if len(similar) > 0 { + seen[primary.ID] = true + avgSimilarity := totalSimilarity / float64(len(similar)) + + // Determine consolidation reason + reason := "Content similarity detected" + if len(primary.Concepts) > 0 && len(similar) > 0 { + // Check for concept overlap + conceptMap := make(map[string]bool) + for _, c := range primary.Concepts { + conceptMap[c] = true + } + for _, sim := range similar { + for _, c := range sim.Concepts { + if conceptMap[c] { + reason = "Similar content with shared concepts" + break + } + } + } + } + + groups = append(groups, consolidationGroup{ + Primary: primary, + Similar: similar, + Similarity: avgSimilarity, + Reason: reason, + }) + + if len(groups) >= params.Limit { + break + } + } + } + + response := map[string]any{ + "groups": groups, + "total_analyzed": len(obs), + "groups_found": len(groups), + "min_similarity": params.MinSimilarity, + "recommendation": "Review each group and use merge_observations to consolidate where appropriate", + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleTagObservation adds, removes, or sets tags on an observation. +func (s *Server) handleTagObservation(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + ID int64 `json:"id"` + Tags []string `json:"tags"` + Mode string `json:"mode"` + } + params.Mode = "add" // default + + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.ID == 0 { + return "", fmt.Errorf("id is required") + } + if len(params.Tags) == 0 { + return "", fmt.Errorf("tags is required") + } + if params.Mode != "add" && params.Mode != "remove" && params.Mode != "set" { + return "", fmt.Errorf("mode must be 'add', 'remove', or 'set'") + } + + // Get current observation + obs, err := s.observationStore.GetObservationByID(ctx, params.ID) + if err != nil { + return "", fmt.Errorf("get observation: %w", err) + } + if obs == nil { + return "", fmt.Errorf("observation %d not found", params.ID) + } + + // Compute new tags + var newTags []string + switch params.Mode { + case "set": + newTags = params.Tags + case "add": + // Add new tags, avoiding duplicates + tagSet := make(map[string]bool) + for _, t := range obs.Concepts { + tagSet[t] = true + newTags = append(newTags, t) + } + for _, t := range params.Tags { + if !tagSet[t] { + tagSet[t] = true + newTags = append(newTags, t) + } + } + case "remove": + // Remove specified tags + removeSet := make(map[string]bool) + for _, t := range params.Tags { + removeSet[t] = true + } + for _, t := range obs.Concepts { + if !removeSet[t] { + newTags = append(newTags, t) + } + } + } + + // Update using existing UpdateObservation method + update := &gorm.ObservationUpdate{ + Concepts: &newTags, + } + updatedObs, err := s.observationStore.UpdateObservation(ctx, params.ID, update) + if err != nil { + return "", fmt.Errorf("update observation: %w", err) + } + + response := map[string]any{ + "id": params.ID, + "mode": params.Mode, + "tags_applied": params.Tags, + "current_tags": updatedObs.Concepts, + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleGetObservationsByTag retrieves observations with a specific concept tag. +func (s *Server) handleGetObservationsByTag(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Tag string `json:"tag"` + Project string `json:"project"` + Limit int `json:"limit"` + } + params.Limit = 50 // default + + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.Tag == "" { + return "", fmt.Errorf("tag is required") + } + if params.Limit < 1 || params.Limit > 200 { + params.Limit = 50 + } + + // Use search with concept filter + searchParams := search.SearchParams{ + Query: params.Tag, + Type: "observations", + Project: params.Project, + Limit: params.Limit, + Concepts: params.Tag, + } + + result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) + if err != nil { + return "", fmt.Errorf("search: %w", err) + } + + // Filter results to only include observations with the exact tag in metadata + var filtered []search.SearchResult + for _, r := range result.Results { + if r.Type != "observation" { + continue + } + // Check if concepts metadata contains the tag + if concepts, ok := r.Metadata["concepts"].([]any); ok { + for _, c := range concepts { + if cs, ok := c.(string); ok && cs == params.Tag { + filtered = append(filtered, r) + break + } + } + } + } + + response := map[string]any{ + "tag": params.Tag, + "observations": filtered, + "count": len(filtered), + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleGetTemporalTrends analyzes observation creation patterns over time. +func (s *Server) handleGetTemporalTrends(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Project string `json:"project"` + Days int `json:"days"` + GroupBy string `json:"group_by"` + } + params.Days = 30 + params.GroupBy = "day" + + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.Days < 1 || params.Days > 365 { + params.Days = 30 + } + + // Get observations for analysis + obs, err := s.observationStore.GetRecentObservations(ctx, params.Project, params.Days*50) // Rough estimate + if err != nil { + return "", fmt.Errorf("get observations: %w", err) + } + + // Calculate time range + now := time.Now() + startTime := now.AddDate(0, 0, -params.Days) + startEpoch := startTime.UnixMilli() + + // Group observations by time bucket + buckets := make(map[string]int) + typeDistribution := make(map[string]int) + conceptCounts := make(map[string]int) + totalInRange := 0 + + for _, o := range obs { + if o.CreatedAtEpoch < startEpoch { + continue + } + totalInRange++ + + created := time.UnixMilli(o.CreatedAtEpoch) + var key string + switch params.GroupBy { + case "week": + year, week := created.ISOWeek() + key = fmt.Sprintf("%d-W%02d", year, week) + case "hour_of_day": + key = fmt.Sprintf("%02d:00", created.Hour()) + default: // day + key = created.Format("2006-01-02") + } + buckets[key]++ + + // Track type distribution + typeDistribution[string(o.Type)]++ + + // Track top concepts + for _, c := range o.Concepts { + conceptCounts[c]++ + } + } + + // Find peak period + peakPeriod := "" + peakCount := 0 + for k, v := range buckets { + if v > peakCount { + peakCount = v + peakPeriod = k + } + } + + // Sort and get top concepts + type conceptEntry struct { + name string + count int + } + var topConcepts []conceptEntry + for name, count := range conceptCounts { + topConcepts = append(topConcepts, conceptEntry{name, count}) + } + // Simple sort - just take top 10 + for i := 0; i < len(topConcepts) && i < 10; i++ { + for j := i + 1; j < len(topConcepts); j++ { + if topConcepts[j].count > topConcepts[i].count { + topConcepts[i], topConcepts[j] = topConcepts[j], topConcepts[i] + } + } + } + if len(topConcepts) > 10 { + topConcepts = topConcepts[:10] + } + topConceptsMap := make([]map[string]any, len(topConcepts)) + for i, c := range topConcepts { + topConceptsMap[i] = map[string]any{"concept": c.name, "count": c.count} + } + + response := map[string]any{ + "period": map[string]any{ + "start": startTime.Format("2006-01-02"), + "end": now.Format("2006-01-02"), + "days": params.Days, + "group_by": params.GroupBy, + }, + "summary": map[string]any{ + "total_observations": totalInRange, + "daily_average": float64(totalInRange) / float64(params.Days), + "peak_period": peakPeriod, + "peak_count": peakCount, + }, + "distribution": buckets, + "type_distribution": typeDistribution, + "top_concepts": topConceptsMap, + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleGetDataQualityReport generates a comprehensive quality assessment. +func (s *Server) handleGetDataQualityReport(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Project string `json:"project"` + Limit int `json:"limit"` + } + params.Limit = 100 + + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.Limit < 10 || params.Limit > 500 { + params.Limit = 100 + } + + // Get observations for analysis + obs, err := s.observationStore.GetRecentObservations(ctx, params.Project, params.Limit) + if err != nil { + return "", fmt.Errorf("get observations: %w", err) + } + + if len(obs) == 0 { + return `{"error": "no observations found", "analyzed": 0}`, nil + } + + // Quality analysis + qualityScores := make([]float64, 0, len(obs)) + issuesFound := make(map[string]int) + improvements := make(map[string]int) + scoreDistribution := map[string]int{"excellent": 0, "good": 0, "fair": 0, "poor": 0} + + for _, o := range obs { + score := 0.0 + maxScore := 5.0 + + // Check completeness + if o.Title.Valid && o.Title.String != "" { + score += 1.0 + } else { + issuesFound["missing_title"]++ + improvements["add_title"]++ + } + + if o.Narrative.Valid && o.Narrative.String != "" { + score += 1.0 + } else { + issuesFound["missing_narrative"]++ + improvements["add_narrative"]++ + } + + if len(o.Facts) > 0 { + score += 1.0 + if len(o.Facts) >= 3 { + score += 0.5 // Bonus for multiple facts + } + } else { + issuesFound["no_facts"]++ + improvements["add_facts"]++ + } + + if len(o.Concepts) > 0 { + score += 1.0 + } else { + issuesFound["no_concepts"]++ + improvements["add_concepts"]++ + } + + if len(o.FilesRead) > 0 || len(o.FilesModified) > 0 { + score += 0.5 + } + + normalized := (score / maxScore) * 100 + qualityScores = append(qualityScores, normalized) + + // Categorize + switch { + case normalized >= 80: + scoreDistribution["excellent"]++ + case normalized >= 60: + scoreDistribution["good"]++ + case normalized >= 40: + scoreDistribution["fair"]++ + default: + scoreDistribution["poor"]++ + } + } + + // Calculate average + var avgScore float64 + for _, s := range qualityScores { + avgScore += s + } + avgScore /= float64(len(qualityScores)) + + // Build top issues list + type issueEntry struct { + name string + count int + } + var topIssues []issueEntry + for name, count := range issuesFound { + topIssues = append(topIssues, issueEntry{name, count}) + } + for i := 0; i < len(topIssues) && i < 5; i++ { + for j := i + 1; j < len(topIssues); j++ { + if topIssues[j].count > topIssues[i].count { + topIssues[i], topIssues[j] = topIssues[j], topIssues[i] + } + } + } + if len(topIssues) > 5 { + topIssues = topIssues[:5] + } + + // Convert top issues to response format + topIssuesList := make([]map[string]any, 0, len(topIssues)) + for _, issue := range topIssues { + topIssuesList = append(topIssuesList, map[string]any{ + "issue": issue.name, + "count": issue.count, + }) + } + + response := map[string]any{ + "analyzed": len(obs), + "project": params.Project, + "quality_summary": map[string]any{ + "average_score": fmt.Sprintf("%.1f%%", avgScore), + "distribution": scoreDistribution, + }, + "issues_found": issuesFound, + "top_issues": topIssuesList, + "improvements": improvements, + "recommendations": []string{ + "Add titles to observations for better discoverability", + "Include narratives to provide context", + "Add concept tags for better organization", + "Include at least 2-3 key facts per observation", + }, + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleBatchTagByPattern applies tags to observations matching a pattern. +func (s *Server) handleBatchTagByPattern(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Pattern string `json:"pattern"` + Tags []string `json:"tags"` + Project string `json:"project"` + DryRun bool `json:"dry_run"` + MaxMatches int `json:"max_matches"` + } + params.DryRun = true + params.MaxMatches = 100 + + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.Pattern == "" { + return "", fmt.Errorf("pattern is required") + } + if len(params.Tags) == 0 { + return "", fmt.Errorf("tags is required") + } + if params.MaxMatches < 1 || params.MaxMatches > 500 { + params.MaxMatches = 100 + } + + // Search for matching observations using the pattern + searchParams := search.SearchParams{ + Query: params.Pattern, + Type: "observations", + Project: params.Project, + Limit: params.MaxMatches, + } + + result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) + if err != nil { + return "", fmt.Errorf("search: %w", err) + } + + // Collect matching observation IDs + var matches []map[string]any + var taggedCount int + + for _, r := range result.Results { + if r.Type != "observation" { + continue + } + + match := map[string]any{ + "id": r.ID, + "title": r.Title, + "score": r.Score, + } + matches = append(matches, match) + + // Apply tags if not dry run + if !params.DryRun { + obs, err := s.observationStore.GetObservationByID(ctx, r.ID) + if err != nil || obs == nil { + continue + } + + // Merge existing tags with new tags (avoid duplicates) + tagSet := make(map[string]bool) + newTags := make([]string, 0, len(obs.Concepts)+len(params.Tags)) + for _, t := range obs.Concepts { + tagSet[t] = true + newTags = append(newTags, t) + } + for _, t := range params.Tags { + if !tagSet[t] { + tagSet[t] = true + newTags = append(newTags, t) + } + } + + update := &gorm.ObservationUpdate{ + Concepts: &newTags, + } + _, err = s.observationStore.UpdateObservation(ctx, r.ID, update) + if err == nil { + taggedCount++ + } + } + } + + response := map[string]any{ + "pattern": params.Pattern, + "tags": params.Tags, + "dry_run": params.DryRun, + "matches_found": len(matches), + "matches": matches, + } + + if !params.DryRun { + response["tagged_count"] = taggedCount + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleExplainSearchRanking explains why each observation ranked where it did in search results. +func (s *Server) handleExplainSearchRanking(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Query string `json:"query"` + Project string `json:"project"` + TopN int `json:"top_n"` + } + params.TopN = 5 // default + + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.Query == "" { + return "", fmt.Errorf("query is required") + } + if params.TopN < 1 || params.TopN > 20 { + params.TopN = 5 + } + + // Perform search to get results + searchParams := search.SearchParams{ + Query: params.Query, + Type: "observations", + Project: params.Project, + Limit: params.TopN, + OrderBy: "relevance", + } + + result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) + if err != nil { + return "", fmt.Errorf("search: %w", err) + } + + // Build detailed explanations for each result + type RankExplanation struct { + Rank int `json:"rank"` + ID int64 `json:"id"` + Title string `json:"title"` + Type string `json:"type"` + Score float64 `json:"score"` + ScoreBreakdown map[string]float64 `json:"score_breakdown"` + MatchedFields []string `json:"matched_fields"` + Metadata map[string]any `json:"metadata,omitempty"` + } + + explanations := make([]RankExplanation, 0, len(result.Results)) + for i, r := range result.Results { + exp := RankExplanation{ + Rank: i + 1, + ID: r.ID, + Title: r.Title, + Type: r.Type, + Score: r.Score, + Metadata: r.Metadata, + } + + // Build score breakdown from available metadata + exp.ScoreBreakdown = make(map[string]float64) + if vs, ok := r.Metadata["vector_score"].(float64); ok { + exp.ScoreBreakdown["vector_similarity"] = vs + } + if is, ok := r.Metadata["importance_score"].(float64); ok { + exp.ScoreBreakdown["importance"] = is + } + if ts, ok := r.Metadata["text_score"].(float64); ok { + exp.ScoreBreakdown["text_match"] = ts + } + if rs, ok := r.Metadata["recency_score"].(float64); ok { + exp.ScoreBreakdown["recency"] = rs + } + // Add base score estimate if breakdown is incomplete + if len(exp.ScoreBreakdown) == 0 { + exp.ScoreBreakdown["combined_score"] = r.Score + } + + // Determine matched fields + exp.MatchedFields = []string{} + if r.Metadata["field_type"] != nil { + if ft, ok := r.Metadata["field_type"].(string); ok && ft != "" { + exp.MatchedFields = append(exp.MatchedFields, ft) + } + } + + explanations = append(explanations, exp) + } + + response := map[string]any{ + "query": params.Query, + "project": params.Project, + "result_count": len(explanations), + "explanations": explanations, + "tips": []string{ + "Higher vector_similarity indicates better semantic match with query", + "Importance score reflects user feedback and retrieval history", + "Recency boosts newer observations slightly", + "Use tag_observation to boost important observations", + }, + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + + return string(output), nil +} + +// handleExportObservations exports observations in various formats. +func (s *Server) handleExportObservations(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Format string `json:"format"` + Project string `json:"project"` + Limit int `json:"limit"` + DateStart int64 `json:"date_start"` + DateEnd int64 `json:"date_end"` + ObsType string `json:"obs_type"` + } + params.Format = "json" + params.Limit = 100 + + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.Limit < 1 || params.Limit > 1000 { + params.Limit = 100 + } + + // Build search params to fetch observations + searchParams := search.SearchParams{ + Type: "observations", + Project: params.Project, + Limit: params.Limit, + OrderBy: "date_desc", + DateStart: params.DateStart, + DateEnd: params.DateEnd, + ObsType: params.ObsType, + } + + result, err := s.searchMgr.UnifiedSearch(ctx, searchParams) + if err != nil { + return "", fmt.Errorf("search: %w", err) + } + + // Fetch full observation data for export + ids := make([]int64, 0, len(result.Results)) + for _, r := range result.Results { + if r.Type == "observation" { + ids = append(ids, r.ID) + } + } + + observations, err := s.observationStore.GetObservationsByIDs(ctx, ids, "", 0) + if err != nil { + return "", fmt.Errorf("get observations: %w", err) + } + + // Format output based on requested format + var output string + switch params.Format { + case "jsonl": + // JSON Lines format - one JSON object per line + var lines []string + for _, obs := range observations { + line, err := json.Marshal(obs) + if err != nil { + continue + } + lines = append(lines, string(line)) + } + output = fmt.Sprintf(`{"format":"jsonl","count":%d,"data":"%s"}`, + len(observations), escapeJSONString(strings.Join(lines, "\n"))) + + case "markdown": + // Markdown format for human reading + var md strings.Builder + md.WriteString("# Observations Export\n\n") + md.WriteString(fmt.Sprintf("Total: %d observations\n\n", len(observations))) + md.WriteString("---\n\n") + + for _, obs := range observations { + title := "" + if obs.Title.Valid { + title = obs.Title.String + } + md.WriteString(fmt.Sprintf("## [%s] %s\n\n", obs.Type, title)) + if obs.Subtitle.Valid && obs.Subtitle.String != "" { + md.WriteString(fmt.Sprintf("*%s*\n\n", obs.Subtitle.String)) + } + if obs.Narrative.Valid && obs.Narrative.String != "" { + md.WriteString(fmt.Sprintf("%s\n\n", obs.Narrative.String)) + } + if len(obs.Facts) > 0 { + md.WriteString("### Key Facts\n") + for _, fact := range obs.Facts { + md.WriteString(fmt.Sprintf("- %s\n", fact)) + } + md.WriteString("\n") + } + if len(obs.Concepts) > 0 { + md.WriteString(fmt.Sprintf("**Tags:** %s\n\n", strings.Join(obs.Concepts, ", "))) + } + md.WriteString(fmt.Sprintf("**ID:** %d | **Created:** %s | **Importance:** %.2f\n\n", + obs.ID, obs.CreatedAt, obs.ImportanceScore)) + md.WriteString("---\n\n") + } + + // Wrap markdown in JSON response + response := map[string]any{ + "format": "markdown", + "count": len(observations), + "data": md.String(), + } + outputBytes, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + output = string(outputBytes) + + default: // json + response := map[string]any{ + "format": "json", + "count": len(observations), + "observations": observations, + } + outputBytes, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + output = string(outputBytes) + } + + return output, nil +} + +// escapeJSONString escapes a string for use in JSON. +func escapeJSONString(s string) string { + b, _ := json.Marshal(s) + // Remove surrounding quotes + if len(b) >= 2 { + return string(b[1 : len(b)-1]) + } + return s +} + +// handleCheckSystemHealth performs comprehensive system health checks. +func (s *Server) handleCheckSystemHealth(ctx context.Context) (string, error) { + type SubsystemHealth struct { + Status string `json:"status"` // "healthy", "degraded", "unhealthy" + Message string `json:"message,omitempty"` + Metrics map[string]any `json:"metrics,omitempty"` + Warnings []string `json:"warnings,omitempty"` + } + + type HealthReport struct { + OverallStatus string `json:"overall_status"` + HealthScore int `json:"health_score"` + Timestamp time.Time `json:"timestamp"` + Subsystems map[string]*SubsystemHealth `json:"subsystems"` + Actions []string `json:"recommended_actions,omitempty"` + } + + report := &HealthReport{ + OverallStatus: "healthy", + HealthScore: 100, + Timestamp: time.Now(), + Subsystems: make(map[string]*SubsystemHealth), + Actions: []string{}, + } + + // Check database health + dbHealth := &SubsystemHealth{ + Status: "healthy", + Metrics: make(map[string]any), + } + if s.observationStore != nil { + // Count observations + count, err := s.observationStore.GetObservationCount(ctx, "") + if err != nil { + dbHealth.Status = "unhealthy" + dbHealth.Message = "Database query failed: " + err.Error() + report.HealthScore -= 30 + } else { + dbHealth.Metrics["total_observations"] = count + dbHealth.Message = "Database operational" + } + + // Check for recent activity + recent, err := s.observationStore.GetAllRecentObservations(ctx, 1) + if err == nil && len(recent) > 0 { + dbHealth.Metrics["last_observation"] = recent[0].CreatedAt + // Check epoch for staleness warning + if recent[0].CreatedAtEpoch > 0 { + lastActivityTime := time.UnixMilli(recent[0].CreatedAtEpoch) + if time.Since(lastActivityTime) > 7*24*time.Hour { + dbHealth.Warnings = append(dbHealth.Warnings, "No observations in the last 7 days") + } + } + } + } else { + dbHealth.Status = "unhealthy" + dbHealth.Message = "Observation store not initialized" + report.HealthScore -= 50 + } + report.Subsystems["database"] = dbHealth + + // Check vector store health + vectorHealth := &SubsystemHealth{ + Status: "healthy", + Metrics: make(map[string]any), + } + if s.vectorClient != nil { + stats, err := s.vectorClient.GetHealthStats(ctx) + if err != nil { + vectorHealth.Status = "degraded" + vectorHealth.Message = "Could not get vector stats: " + err.Error() + report.HealthScore -= 15 + } else { + vectorHealth.Metrics["total_vectors"] = stats.TotalVectors + vectorHealth.Metrics["stale_vectors"] = stats.StaleVectors + vectorHealth.Metrics["current_model"] = stats.CurrentModel + vectorHealth.Metrics["needs_rebuild"] = stats.NeedsRebuild + + if stats.NeedsRebuild { + vectorHealth.Status = "degraded" + vectorHealth.Warnings = append(vectorHealth.Warnings, "Vector rebuild recommended: "+stats.RebuildReason) + report.Actions = append(report.Actions, "Run vector rebuild to update embeddings") + report.HealthScore -= 10 + } + + // Check stale ratio + if stats.TotalVectors > 0 { + staleRatio := float64(stats.StaleVectors) / float64(stats.TotalVectors) + if staleRatio > 0.2 { + vectorHealth.Warnings = append(vectorHealth.Warnings, + fmt.Sprintf("%.1f%% of vectors are stale", staleRatio*100)) + report.HealthScore -= 5 + } + } + } + + // Check cache performance + cacheStats := s.vectorClient.GetCacheStats() + vectorHealth.Metrics["cache_hit_rate"] = fmt.Sprintf("%.1f%%", cacheStats.HitRate()) + vectorHealth.Metrics["embedding_hits"] = cacheStats.EmbeddingHits + vectorHealth.Metrics["embedding_misses"] = cacheStats.EmbeddingMisses + vectorHealth.Metrics["result_hits"] = cacheStats.ResultHits + vectorHealth.Metrics["result_misses"] = cacheStats.ResultMisses + + if cacheStats.HitRate() < 20 && (cacheStats.EmbeddingHits+cacheStats.EmbeddingMisses) > 100 { + vectorHealth.Warnings = append(vectorHealth.Warnings, "Low cache hit rate - consider cache tuning") + } + } else { + vectorHealth.Status = "unhealthy" + vectorHealth.Message = "Vector client not initialized" + report.HealthScore -= 30 + } + report.Subsystems["vectors"] = vectorHealth + + // Check pattern detection health + patternHealth := &SubsystemHealth{ + Status: "healthy", + Metrics: make(map[string]any), + } + if s.patternStore != nil { + patterns, err := s.patternStore.GetActivePatterns(ctx, 100) + if err != nil { + patternHealth.Status = "degraded" + patternHealth.Message = "Could not query patterns: " + err.Error() + } else { + patternHealth.Metrics["total_patterns"] = len(patterns) + + // Count by type + typeCounts := make(map[string]int) + for _, p := range patterns { + typeCounts[string(p.Type)]++ + } + patternHealth.Metrics["patterns_by_type"] = typeCounts + } + } + report.Subsystems["patterns"] = patternHealth + + // Check session store health + sessionHealth := &SubsystemHealth{ + Status: "healthy", + Metrics: make(map[string]any), + } + if s.sessionStore != nil { + sessionsToday, err := s.sessionStore.GetSessionsToday(ctx) + if err != nil { + sessionHealth.Status = "degraded" + sessionHealth.Message = "Could not query sessions: " + err.Error() + } else { + sessionHealth.Metrics["sessions_today"] = sessionsToday + } + } + report.Subsystems["sessions"] = sessionHealth + + // Determine overall status + unhealthyCount := 0 + degradedCount := 0 + for _, sub := range report.Subsystems { + switch sub.Status { + case "unhealthy": + unhealthyCount++ + case "degraded": + degradedCount++ + } + } + + if unhealthyCount > 0 { + report.OverallStatus = "unhealthy" + } else if degradedCount > 0 { + report.OverallStatus = "degraded" + } + + // Cap health score + if report.HealthScore < 0 { + report.HealthScore = 0 + } + + // Add recommended actions based on issues + if report.HealthScore < 70 { + report.Actions = append(report.Actions, "System needs attention - check subsystem details") + } + + output, err := json.Marshal(report) + if err != nil { + return "", fmt.Errorf("marshal health report: %w", err) + } + return string(output), nil +} + +// handleAnalyzeSearchPatterns analyzes search query patterns. +func (s *Server) handleAnalyzeSearchPatterns(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Days int `json:"days"` + TopN int `json:"top_n"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid params: %w", err) + } + + if params.Days <= 0 { + params.Days = 7 + } + if params.TopN <= 0 { + params.TopN = 10 + } + + type QueryPattern struct { + Query string `json:"query"` + Count int `json:"count"` + AvgResults float64 `json:"avg_results"` + ZeroResults int `json:"zero_result_count"` + LastUsed string `json:"last_used"` + } + + type PatternAnalysis struct { + Period string `json:"period"` + TotalSearches int `json:"total_searches"` + UniqueQueries int `json:"unique_queries"` + TopQueries []QueryPattern `json:"top_queries"` + ZeroResultQueries []string `json:"zero_result_queries,omitempty"` + Insights []string `json:"insights,omitempty"` + } + + analysis := &PatternAnalysis{ + Period: fmt.Sprintf("Last %d days", params.Days), + TopQueries: []QueryPattern{}, + ZeroResultQueries: []string{}, + Insights: []string{}, + } + + // Get search stats from the search manager if available + if s.searchMgr != nil { + metrics := s.searchMgr.Metrics() + if metrics != nil { + stats := metrics.GetStats() + if totalSearches, ok := stats["total_searches"].(int); ok && totalSearches > 0 { + analysis.TotalSearches = totalSearches + analysis.Insights = append(analysis.Insights, + fmt.Sprintf("Total searches: %d", totalSearches)) + } + if avgLatency, ok := stats["avg_latency_ms"].(float64); ok { + analysis.Insights = append(analysis.Insights, + fmt.Sprintf("Average search latency: %.2fms", avgLatency)) + } + } + + // Get cache stats + cacheStats := s.searchMgr.CacheStats() + if hitRate, ok := cacheStats["hit_rate"].(float64); ok { + analysis.Insights = append(analysis.Insights, + fmt.Sprintf("Cache hit rate: %.1f%%", hitRate*100)) + } + } + + // Analyze observation patterns to suggest search improvements + if s.observationStore != nil { + // Get recent observations to understand content patterns + observations, err := s.observationStore.GetAllRecentObservations(ctx, 100) + if err == nil { + analysis.UniqueQueries = len(observations) + + // Analyze observation types + typeCounts := make(map[string]int) + for _, obs := range observations { + typeCounts[string(obs.Type)]++ + } + + // Find most common types + mostCommon := "" + maxCount := 0 + for t, c := range typeCounts { + if c > maxCount { + mostCommon = t + maxCount = c + } + } + if mostCommon != "" { + analysis.Insights = append(analysis.Insights, + fmt.Sprintf("Most common observation type: %s (%d occurrences)", mostCommon, maxCount)) + } + + // Check for concept coverage + conceptCounts := make(map[string]int) + for _, obs := range observations { + for _, c := range obs.Concepts { + conceptCounts[c]++ + } + } + if len(conceptCounts) > 0 { + analysis.Insights = append(analysis.Insights, + fmt.Sprintf("%d unique concepts across %d observations", len(conceptCounts), len(observations))) + } + } + } + + // Add general recommendations + if len(analysis.Insights) == 0 { + analysis.Insights = append(analysis.Insights, "Insufficient data for pattern analysis") + } + + output, err := json.Marshal(analysis) + if err != nil { + return "", fmt.Errorf("marshal analysis: %w", err) + } + return string(output), nil +} + +// handleGetObservationRelationships returns the relationship graph for an observation. +func (s *Server) handleGetObservationRelationships(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + ID int64 `json:"id"` + MaxDepth int `json:"max_depth"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid params: %w", err) + } + + if params.ID <= 0 { + return "", fmt.Errorf("id is required and must be positive") + } + if params.MaxDepth <= 0 { + params.MaxDepth = 2 + } + if params.MaxDepth > 5 { + params.MaxDepth = 5 + } + + if s.relationStore == nil { + return "", fmt.Errorf("relation store not available") + } + + // Get the relationship graph + graph, err := s.relationStore.GetRelationGraph(ctx, params.ID, params.MaxDepth) + if err != nil { + return "", fmt.Errorf("get relation graph: %w", err) + } + + // Build response with additional context + type RelationInfo struct { + ID int64 `json:"id"` + SourceID int64 `json:"source_id"` + TargetID int64 `json:"target_id"` + Type string `json:"type"` + Confidence float64 `json:"confidence"` + SourceTitle string `json:"source_title,omitempty"` + TargetTitle string `json:"target_title,omitempty"` + SourceType string `json:"source_type,omitempty"` + TargetType string `json:"target_type,omitempty"` + } + + type GraphResponse struct { + CenterID int64 `json:"center_id"` + MaxDepth int `json:"max_depth"` + TotalRelations int `json:"total_relations"` + Relations []RelationInfo `json:"relations"` + UniqueNodes []int64 `json:"unique_nodes"` + } + + // Collect unique node IDs + nodeSet := make(map[int64]bool) + nodeSet[params.ID] = true + + relations := make([]RelationInfo, 0, len(graph.Relations)) + for _, r := range graph.Relations { + nodeSet[r.Relation.SourceID] = true + nodeSet[r.Relation.TargetID] = true + + relations = append(relations, RelationInfo{ + ID: r.Relation.ID, + SourceID: r.Relation.SourceID, + TargetID: r.Relation.TargetID, + Type: string(r.Relation.RelationType), + Confidence: r.Relation.Confidence, + SourceTitle: r.SourceTitle, + TargetTitle: r.TargetTitle, + SourceType: string(r.SourceType), + TargetType: string(r.TargetType), + }) + } + + // Convert node set to slice + nodes := make([]int64, 0, len(nodeSet)) + for id := range nodeSet { + nodes = append(nodes, id) + } + + response := GraphResponse{ + CenterID: params.ID, + MaxDepth: params.MaxDepth, + TotalRelations: len(relations), + Relations: relations, + UniqueNodes: nodes, + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + return string(output), nil +} + +// handleGetObservationScoringBreakdown returns detailed scoring breakdown for an observation. +func (s *Server) handleGetObservationScoringBreakdown(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + ID int64 `json:"id"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + if params.ID <= 0 { + return "", fmt.Errorf("id is required and must be positive") + } + + // Get the observation + obs, err := s.observationStore.GetObservationByID(ctx, params.ID) + if err != nil { + return "", fmt.Errorf("get observation: %w", err) + } + if obs == nil { + return "", fmt.Errorf("observation not found: %d", params.ID) + } + + // Calculate scoring components + if s.scoreCalculator == nil { + return "", fmt.Errorf("score calculator not initialized") + } + + components := s.scoreCalculator.CalculateComponents(obs, time.Now()) + + // Build response with observation context + response := map[string]any{ + "observation": map[string]any{ + "id": obs.ID, + "title": obs.Title.String, + "type": string(obs.Type), + "project": obs.Project, + "created_at": obs.CreatedAtEpoch, + }, + "scoring": map[string]any{ + "final_score": components.FinalScore, + "type_weight": components.TypeWeight, + "recency_decay": components.RecencyDecay, + "core_score": components.CoreScore, + "feedback_contrib": components.FeedbackContrib, + "concept_contrib": components.ConceptContrib, + "retrieval_contrib": components.RetrievalContrib, + "age_days": components.AgeDays, + }, + "explanation": map[string]any{ + "type_impact": fmt.Sprintf("Observation type '%s' has weight %.2f", obs.Type, components.TypeWeight), + "recency_impact": fmt.Sprintf("%.1f days old, decay factor %.2f", components.AgeDays, components.RecencyDecay), + "feedback_impact": fmt.Sprintf("User feedback contributes %.2f to score", components.FeedbackContrib), + "concept_impact": fmt.Sprintf("Concept tags contribute %.2f to score", components.ConceptContrib), + "retrieval_impact": fmt.Sprintf("Retrieval frequency contributes %.2f to score", components.RetrievalContrib), + }, + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + return string(output), nil +} + +// handleAnalyzeObservationImportance returns importance analysis for a project's observations. +func (s *Server) handleAnalyzeObservationImportance(ctx context.Context, args json.RawMessage) (string, error) { + var params struct { + Project string `json:"project"` + IncludeTopScored *bool `json:"include_top_scored"` + IncludeMostRetrieved *bool `json:"include_most_retrieved"` + IncludeConceptWeights *bool `json:"include_concept_weights"` + Limit int `json:"limit"` + } + if err := json.Unmarshal(args, ¶ms); err != nil { + return "", fmt.Errorf("invalid arguments: %w", err) + } + + // Set defaults + if params.Limit <= 0 { + params.Limit = 10 + } + if params.Limit > 50 { + params.Limit = 50 + } + includeTopScored := params.IncludeTopScored == nil || *params.IncludeTopScored + includeMostRetrieved := params.IncludeMostRetrieved == nil || *params.IncludeMostRetrieved + includeConceptWeights := params.IncludeConceptWeights == nil || *params.IncludeConceptWeights + + response := make(map[string]any) + response["project"] = params.Project + if params.Project == "" { + response["project"] = "(all projects)" + } + + // Get feedback statistics + stats, err := s.observationStore.GetObservationFeedbackStats(ctx, params.Project) + if err != nil { + return "", fmt.Errorf("get feedback stats: %w", err) + } + response["feedback_stats"] = stats + + // Get top-scoring observations + if includeTopScored { + topScored, err := s.observationStore.GetTopScoringObservations(ctx, params.Project, params.Limit) + if err != nil { + log.Warn().Err(err).Msg("Failed to get top-scoring observations") + } else { + topScoredSummary := make([]map[string]any, 0, len(topScored)) + for _, obs := range topScored { + topScoredSummary = append(topScoredSummary, map[string]any{ + "id": obs.ID, + "title": obs.Title.String, + "type": string(obs.Type), + "importance_score": obs.ImportanceScore, + }) + } + response["top_scoring_observations"] = topScoredSummary + } + } + + // Get most-retrieved observations + if includeMostRetrieved { + mostRetrieved, err := s.observationStore.GetMostRetrievedObservations(ctx, params.Project, params.Limit) + if err != nil { + log.Warn().Err(err).Msg("Failed to get most-retrieved observations") + } else { + mostRetrievedSummary := make([]map[string]any, 0, len(mostRetrieved)) + for _, obs := range mostRetrieved { + mostRetrievedSummary = append(mostRetrievedSummary, map[string]any{ + "id": obs.ID, + "title": obs.Title.String, + "type": string(obs.Type), + "retrieval_count": obs.RetrievalCount, + }) + } + response["most_retrieved_observations"] = mostRetrievedSummary + } + } + + // Get concept weights + if includeConceptWeights { + conceptWeights, err := s.observationStore.GetConceptWeights(ctx) + if err != nil { + log.Warn().Err(err).Msg("Failed to get concept weights") + } else if len(conceptWeights) > 0 { + response["concept_weights"] = conceptWeights + } + } + + // Generate insights + insights := []string{} + if stats != nil { + if stats.Positive > 0 { + insights = append(insights, fmt.Sprintf("%d observations marked as valuable (positive feedback)", stats.Positive)) + } + if stats.Negative > 0 { + insights = append(insights, fmt.Sprintf("%d observations marked as not helpful (negative feedback)", stats.Negative)) + } + if stats.AvgScore > 0 { + insights = append(insights, fmt.Sprintf("Average importance score: %.2f", stats.AvgScore)) + } + if stats.AvgRetrieval > 0 { + insights = append(insights, fmt.Sprintf("Average retrieval count: %.1f", stats.AvgRetrieval)) + } + } + if len(insights) > 0 { + response["insights"] = insights + } + + output, err := json.Marshal(response) + if err != nil { + return "", fmt.Errorf("marshal response: %w", err) + } + return string(output), nil +} diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 469dfc0..32e53a2 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -24,7 +24,7 @@ func TestServerSuite(t *testing.T) { // TestNewServer tests server creation. func (s *ServerSuite) TestNewServer() { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) s.NotNil(server) s.Nil(server.searchMgr) s.Equal("1.0.0", server.version) @@ -293,7 +293,7 @@ func TestTimelineParams(t *testing.T) { // TestHandleInitialize tests the initialize handler. func TestHandleInitialize(t *testing.T) { - server := NewServer(nil, "1.2.3", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.2.3", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ JSONRPC: "2.0", @@ -320,7 +320,7 @@ func TestHandleInitialize(t *testing.T) { // TestHandleToolsList tests the tools/list handler. func TestHandleToolsList(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ JSONRPC: "2.0", @@ -361,7 +361,7 @@ func TestHandleToolsList(t *testing.T) { // TestHandleRequest tests request routing. func TestHandleRequest(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() tests := []struct { @@ -423,7 +423,7 @@ func TestHandleRequest(t *testing.T) { // TestHandleToolsCall_InvalidParams tests tools/call with invalid params. func TestHandleToolsCall_InvalidParams(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() req := &Request{ @@ -442,7 +442,7 @@ func TestHandleToolsCall_InvalidParams(t *testing.T) { // TestCallTool_UnknownTool tests callTool with unknown tool name. func TestCallTool_UnknownTool(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() _, err := server.callTool(ctx, "nonexistent_tool", json.RawMessage(`{}`)) @@ -452,7 +452,7 @@ func TestCallTool_UnknownTool(t *testing.T) { // TestCallTool_InvalidArgs tests callTool with invalid arguments. func TestCallTool_InvalidArgs(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() _, err := server.callTool(ctx, "search", json.RawMessage(`invalid json`)) @@ -574,7 +574,7 @@ func TestJSONRPCErrorCodes(t *testing.T) { // TestToolListContainsExpectedSchemas tests that tool schemas are valid. func TestToolListContainsExpectedSchemas(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ JSONRPC: "2.0", @@ -600,7 +600,7 @@ func TestToolListContainsExpectedSchemas(t *testing.T) { // TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name. func TestHandleToolsCall_UnknownTool(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() req := &Request{ @@ -620,7 +620,7 @@ func TestHandleToolsCall_UnknownTool(t *testing.T) { func TestCallTool_ToolNameRecognition(t *testing.T) { // Note: This test verifies tool routing logic, not execution (which requires searchMgr) // All valid tool names should be in the handleToolsList response - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ JSONRPC: "2.0", @@ -782,7 +782,7 @@ func TestResponseIDTypes(t *testing.T) { // TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query. func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() // Empty query should error @@ -793,7 +793,7 @@ func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) { // TestHandleTimeline_InvalidJSON tests timeline with invalid JSON. func TestHandleTimeline_InvalidJSON(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() _, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`)) @@ -803,7 +803,7 @@ func TestHandleTimeline_InvalidJSON(t *testing.T) { // TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON. func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() _, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`)) @@ -813,7 +813,7 @@ func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) { // TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query. func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() // No anchor_id and no query should return empty result @@ -825,7 +825,7 @@ func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) { // TestHandleTimeline_WithDefaults tests timeline default values are applied. func TestHandleTimeline_WithDefaults(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() // With anchor_id but no before/after, defaults should be applied @@ -839,7 +839,7 @@ func TestHandleTimeline_WithDefaults(t *testing.T) { // TestServerFields tests Server struct fields. func TestServerFields(t *testing.T) { - server := NewServer(nil, "2.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "2.0.0", nil, nil, nil, nil, nil, nil, nil, nil) assert.Equal(t, "2.0.0", server.version) assert.Nil(t, server.searchMgr) @@ -891,7 +891,7 @@ func TestErrorWithNilData(t *testing.T) { // TestToolInputSchema tests that tool input schemas have required fields. func TestToolInputSchema(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) req := &Request{ JSONRPC: "2.0", @@ -960,7 +960,7 @@ func TestToolCallParamsWithComplexArgs(t *testing.T) { // TestCallTool_UnknownToolName tests callTool with various unknown tool names. func TestCallTool_UnknownToolName(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() unknownTools := []string{ @@ -1009,7 +1009,7 @@ func TestTimelineParams_Validation(t *testing.T) { // TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error. func TestHandleToolsCall_UnknownToolNameError(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() req := &Request{ @@ -1031,7 +1031,7 @@ func TestHandleToolsCall_UnknownToolNameError(t *testing.T) { // TestHandleToolsCall_EmptyParams tests tools/call with empty params. func TestHandleToolsCall_EmptyParams(t *testing.T) { - server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil) + server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil) ctx := context.Background() req := &Request{ diff --git a/internal/pattern/detector.go b/internal/pattern/detector.go index 57997c6..81a26d6 100644 --- a/internal/pattern/detector.go +++ b/internal/pattern/detector.go @@ -3,6 +3,8 @@ package pattern import ( "context" + "sort" + "strings" "sync" "time" @@ -21,6 +23,8 @@ type DetectorConfig struct { AnalysisInterval time.Duration // MaxPatternsToTrack is the maximum number of active patterns. MaxPatternsToTrack int + // MaxCandidates is the maximum number of candidates to track (LRU eviction). + MaxCandidates int } // DefaultConfig returns the default detector configuration. @@ -30,6 +34,7 @@ func DefaultConfig() DetectorConfig { MinFrequencyForPattern: 2, // At least 2 occurrences to form a pattern AnalysisInterval: 5 * time.Minute, MaxPatternsToTrack: 1000, + MaxCandidates: 500, // Prevent unbounded growth } } @@ -201,7 +206,24 @@ func (d *Detector) AnalyzeObservation(ctx context.Context, obs *models.Observati } } } else { - // Create new candidate + // Create new candidate - with immediate size check to prevent unbounded growth + // between periodic cleanups (which run every 5 minutes) + if d.config.MaxCandidates > 0 && len(d.candidates) >= d.config.MaxCandidates { + // Evict oldest candidate immediately rather than waiting for periodic cleanup + var oldestKey string + var oldestTime int64 = time.Now().UnixMilli() + for k, c := range d.candidates { + if c.lastSeenEpoch < oldestTime { + oldestTime = c.lastSeenEpoch + oldestKey = k + } + } + if oldestKey != "" { + delete(d.candidates, oldestKey) + log.Debug().Str("evicted_key", oldestKey).Msg("Evicted oldest candidate to make room") + } + } + patternType := models.DetectPatternType(obs.Concepts, obs.Title.String, obs.Narrative.String) d.candidates[candidateKey] = &candidatePattern{ signature: signature, @@ -305,17 +327,53 @@ func (d *Detector) AnalyzeRecentObservations(ctx context.Context) error { return nil } -// cleanupOldCandidates removes candidates that haven't been seen recently. +// cleanupOldCandidates removes candidates that haven't been seen recently +// and enforces the max candidates limit using LRU eviction. func (d *Detector) cleanupOldCandidates() { d.candidatesMu.Lock() defer d.candidatesMu.Unlock() threshold := time.Now().Add(-7 * 24 * time.Hour).UnixMilli() + + // First pass: remove expired candidates for key, candidate := range d.candidates { if candidate.lastSeenEpoch < threshold { delete(d.candidates, key) } } + + // Second pass: enforce max candidates limit using LRU eviction + if d.config.MaxCandidates > 0 && len(d.candidates) > d.config.MaxCandidates { + // Find oldest candidates to evict using O(n log n) sort instead of O(n²) selection sort + type keyAge struct { + key string + age int64 + } + candidates := make([]keyAge, 0, len(d.candidates)) + for k, c := range d.candidates { + candidates = append(candidates, keyAge{k, c.lastSeenEpoch}) + } + + // Sort by age ascending (oldest first) - O(n log n) + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].age < candidates[j].age + }) + + // Delete oldest entries + toEvict := len(d.candidates) - d.config.MaxCandidates + for i := 0; i < toEvict; i++ { + delete(d.candidates, candidates[i].key) + } + + log.Debug().Int("evicted", toEvict).Int("remaining", len(d.candidates)).Msg("Evicted old pattern candidates (LRU)") + } +} + +// CandidateCount returns the current number of pattern candidates. +func (d *Detector) CandidateCount() int { + d.candidatesMu.RLock() + defer d.candidatesMu.RUnlock() + return len(d.candidates) } // GetPatternInsight returns a formatted insight string for a pattern. @@ -340,11 +398,15 @@ func generateCandidateKey(signature []string) string { if len(signature) == 0 { return "" } - key := "" + // Use strings.Builder to avoid O(n²) string concatenation + var b strings.Builder + // Pre-allocate: estimate average signature element is 10 chars + separator + b.Grow(len(signature) * 11) for _, s := range signature { - key += s + "|" + b.WriteString(s) + b.WriteByte('|') } - return key + return b.String() } // generatePatternName creates a human-readable name for a pattern. diff --git a/internal/privacy/secrets.go b/internal/privacy/secrets.go new file mode 100644 index 0000000..2b52992 --- /dev/null +++ b/internal/privacy/secrets.go @@ -0,0 +1,98 @@ +// Package privacy provides utilities for protecting sensitive data. +package privacy + +import ( + "regexp" + "slices" + "strings" +) + +// secretPatterns contains compiled regular expressions for detecting secrets. +// These patterns are designed to catch common secret formats with minimal false positives. +var secretPatterns = []*regexp.Regexp{ + // API keys with common prefixes + regexp.MustCompile(`(?i)(api[_-]?key|apikey)\s*[:=]\s*['"]?[a-zA-Z0-9_-]{20,}['"]?`), + + // Passwords in configuration + regexp.MustCompile(`(?i)(password|passwd|pwd)\s*[:=]\s*['"][^'"]{8,}['"]`), + + // Secret tokens + regexp.MustCompile(`(?i)(secret[_-]?key|secret[_-]?token|auth[_-]?token)\s*[:=]\s*['"]?[a-zA-Z0-9_-]{20,}['"]?`), + + // OpenAI API keys + regexp.MustCompile(`sk-[a-zA-Z0-9]{20,}`), + + // Anthropic API keys + regexp.MustCompile(`sk-ant-[a-zA-Z0-9-]{20,}`), + + // GitHub tokens + regexp.MustCompile(`gh[pous]_[a-zA-Z0-9]{36,}`), + regexp.MustCompile(`github_pat_[a-zA-Z0-9_]{22,}`), + + // AWS keys + regexp.MustCompile(`AKIA[0-9A-Z]{16}`), + regexp.MustCompile(`(?i)aws[_-]?secret[_-]?access[_-]?key\s*[:=]\s*['"]?[a-zA-Z0-9/+=]{40}['"]?`), + + // Private keys (PEM format indicators) + regexp.MustCompile(`-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----`), + + // JWT tokens (base64.base64.base64 format) + regexp.MustCompile(`eyJ[a-zA-Z0-9_-]+\.eyJ[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+`), + + // Generic secret assignment patterns + regexp.MustCompile(`(?i)bearer\s+[a-zA-Z0-9_-]{20,}`), +} + +// ContainsSecrets checks if the given text contains any patterns that look like secrets. +// Returns true if potential secrets are detected. +func ContainsSecrets(text string) bool { + if text == "" { + return false + } + + for _, pattern := range secretPatterns { + if pattern.MatchString(text) { + return true + } + } + return false +} + +// RedactSecrets replaces detected secrets with a redaction marker. +// This allows the text to be stored while protecting sensitive data. +func RedactSecrets(text string) string { + if text == "" { + return text + } + + result := text + for _, pattern := range secretPatterns { + result = pattern.ReplaceAllStringFunc(result, func(match string) string { + // Preserve the key name, redact only the value + if idx := strings.Index(match, "="); idx != -1 { + return match[:idx+1] + "[REDACTED]" + } + if idx := strings.Index(match, ":"); idx != -1 { + return match[:idx+1] + "[REDACTED]" + } + // For standalone secrets, show just the prefix + if len(match) > 8 { + return match[:4] + "...[REDACTED]" + } + return "[REDACTED]" + }) + } + return result +} + +// SanitizeObservation checks multiple fields of an observation for secrets. +// Returns true if any secrets were found. +// This function is used as a validation gate before storing observations. +func SanitizeObservation(narrative string, facts []string) bool { + if ContainsSecrets(narrative) { + return true + } + return slices.ContainsFunc(facts, func(fact string) bool { + return ContainsSecrets(fact) + }) +} diff --git a/internal/privacy/secrets_test.go b/internal/privacy/secrets_test.go new file mode 100644 index 0000000..e3c2694 --- /dev/null +++ b/internal/privacy/secrets_test.go @@ -0,0 +1,213 @@ +package privacy + +import ( + "testing" +) + +func TestContainsSecrets(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + { + name: "empty string", + input: "", + expected: false, + }, + { + name: "normal text", + input: "This is just some regular text about a bug fix", + expected: false, + }, + { + name: "API key pattern", + input: "api_key=abc123def456ghi789jkl012mno345pqr678", + expected: true, + }, + { + name: "api-key with dash", + input: `api-key: "abc123def456ghi789jkl012mno"`, + expected: true, + }, + { + name: "password in config", + input: `password="super_secret_password_123"`, + expected: true, + }, + { + name: "OpenAI key format", + input: "sk-abc123def456ghi789jkl012mno345pqr678", + expected: true, + }, + { + name: "Anthropic key format", + input: "sk-ant-api03-abc123def456ghi789jkl012mno345", + expected: true, + }, + { + name: "GitHub PAT", + input: "ghp_1234567890abcdefghijklmnopqrstuvwxyz", + expected: true, + }, + { + name: "GitHub PAT new format", + input: "github_pat_12ABCDEFGHIJ3456789abc_defghijklmno", + expected: true, + }, + { + name: "AWS access key", + input: "AKIAIOSFODNN7EXAMPLE", + expected: true, + }, + { + name: "Private key header", + input: "-----BEGIN RSA PRIVATE KEY-----", + expected: true, + }, + { + name: "JWT token", + input: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U", + expected: true, + }, + { + name: "bearer token", + input: "Bearer abc123def456ghi789jkl012mno345", + expected: true, + }, + { + name: "secret_key in code", + input: `secret_key = "my_super_secret_token_here"`, + expected: true, + }, + { + name: "short password is not detected", + input: `password="short"`, + expected: false, // Too short to trigger + }, + { + name: "word password in sentence", + input: "The password field should be validated", + expected: false, + }, + { + name: "word api in code", + input: "The API returns JSON data", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ContainsSecrets(tt.input) + if result != tt.expected { + t.Errorf("ContainsSecrets(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestRedactSecrets(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "no secrets", + input: "This is safe text", + expected: "This is safe text", + }, + { + name: "API key gets redacted", + input: "api_key=abc123def456ghi789jkl012mno345pqr678", + expected: "api_key=[REDACTED]", + }, + { + name: "OpenAI key gets redacted", + input: "The key is sk-abc123def456ghi789jkl012mno345pqr678", + expected: "The key is sk-a...[REDACTED]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := RedactSecrets(tt.input) + if result != tt.expected { + t.Errorf("RedactSecrets(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestSanitizeObservation(t *testing.T) { + tests := []struct { + name string + narrative string + facts []string + expected bool + }{ + { + name: "clean observation", + narrative: "Fixed a bug in the login flow", + facts: []string{"Users can now log in", "Session management improved"}, + expected: false, + }, + { + name: "secret in narrative", + narrative: "Set API key api_key=abc123def456ghi789jkl012mno345", + facts: []string{"Configuration updated"}, + expected: true, + }, + { + name: "secret in facts", + narrative: "Updated configuration", + facts: []string{"Added api_key=abc123def456ghi789jkl012mno345"}, + expected: true, + }, + { + name: "empty facts", + narrative: "Clean narrative", + facts: []string{}, + expected: false, + }, + { + name: "nil facts", + narrative: "Clean narrative", + facts: nil, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeObservation(tt.narrative, tt.facts) + if result != tt.expected { + t.Errorf("SanitizeObservation() = %v, want %v", result, tt.expected) + } + }) + } +} + +func BenchmarkContainsSecrets(b *testing.B) { + text := "This is a normal piece of text that does not contain any secrets or sensitive information" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ContainsSecrets(text) + } +} + +func BenchmarkContainsSecretsWithSecret(b *testing.B) { + text := "api_key=abc123def456ghi789jkl012mno345pqr678" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ContainsSecrets(text) + } +} diff --git a/internal/scoring/calculator.go b/internal/scoring/calculator.go index 8236ea1..b20ac67 100644 --- a/internal/scoring/calculator.go +++ b/internal/scoring/calculator.go @@ -36,6 +36,13 @@ func NewCalculator(config *models.ScoringConfig) *Calculator { // - ConceptContrib = sum(concept_weights) × concept_weight_factor // - RetrievalContrib = log2(retrieval_count + 1) × 0.1 × retrieval_weight func (c *Calculator) Calculate(obs *models.Observation, now time.Time) float64 { + return c.CalculateComponents(obs, now).FinalScore +} + +// CalculateComponents returns the individual components of the importance score. +// Useful for debugging and explaining scores to users. +// This is the core calculation method - Calculate() delegates to this. +func (c *Calculator) CalculateComponents(obs *models.Observation, now time.Time) ScoreComponents { // 1. Get base type weight typeWeight := models.TypeBaseScore(obs.Type) @@ -75,42 +82,6 @@ func (c *Calculator) Calculate(obs *models.Observation, now time.Time) float64 { finalScore = c.config.MinScore } - return finalScore -} - -// CalculateComponents returns the individual components of the importance score. -// Useful for debugging and explaining scores to users. -func (c *Calculator) CalculateComponents(obs *models.Observation, now time.Time) ScoreComponents { - typeWeight := models.TypeBaseScore(obs.Type) - - ageDays := now.Sub(time.UnixMilli(obs.CreatedAtEpoch)).Hours() / 24.0 - if ageDays < 0 { - ageDays = 0 - } - recencyDecay := math.Pow(0.5, ageDays/c.config.RecencyHalfLifeDays) - - coreScore := 1.0 * typeWeight * recencyDecay - feedbackContrib := float64(obs.UserFeedback) * c.config.FeedbackWeight - - conceptBoost := 0.0 - for _, concept := range obs.Concepts { - if weight, ok := c.config.ConceptWeights[concept]; ok { - conceptBoost += weight - } - } - conceptContrib := conceptBoost * c.config.ConceptWeight - - retrievalContrib := 0.0 - if obs.RetrievalCount > 0 { - retrievalBoost := math.Log2(float64(obs.RetrievalCount)+1) * 0.1 - retrievalContrib = retrievalBoost * c.config.RetrievalWeight - } - - finalScore := coreScore + feedbackContrib + conceptContrib + retrievalContrib - if finalScore < c.config.MinScore { - finalScore = c.config.MinScore - } - return ScoreComponents{ TypeWeight: typeWeight, RecencyDecay: recencyDecay, diff --git a/internal/search/expansion/expander.go b/internal/search/expansion/expander.go index be1ab6e..b897319 100644 --- a/internal/search/expansion/expander.go +++ b/internal/search/expansion/expander.go @@ -3,7 +3,9 @@ package expansion import ( "context" + "math" "regexp" + "sort" "strings" "sync" @@ -281,14 +283,10 @@ func (e *Expander) expandByVocabulary(ctx context.Context, query string, minSimi return nil } - // Sort by score (descending) using bubble sort - for i := 0; i < len(similar)-1; i++ { - for j := i + 1; j < len(similar); j++ { - if similar[j].score > similar[i].score { - similar[i], similar[j] = similar[j], similar[i] - } - } - } + // Sort by score (descending) using Go's standard sort - O(n log n) + sort.Slice(similar, func(i, j int) bool { + return similar[i].score > similar[j].score + }) // Create expansion by combining top similar terms with query var expansions []ExpandedQuery @@ -427,16 +425,13 @@ func cosineSimilarity(a, b []float32) float64 { return dot / (sqrt(normA) * sqrt(normB)) } -// sqrt is a simple square root implementation. +// sqrt uses the standard math.Sqrt for better performance and accuracy. +// Returns 0 for non-positive values (original behavior for compatibility). func sqrt(x float64) float64 { if x <= 0 { return 0 } - z := x - for i := 0; i < 10; i++ { - z = (z + x/z) / 2 - } - return z + return math.Sqrt(x) } // truncate truncates a string to maxLen characters. diff --git a/internal/search/manager.go b/internal/search/manager.go index cc21626..ed0be5e 100644 --- a/internal/search/manager.go +++ b/internal/search/manager.go @@ -3,19 +3,156 @@ package search import ( "context" + "hash/fnv" + "regexp" + "sort" + "strconv" "strings" + "sync" + "sync/atomic" + "time" "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" + "golang.org/x/sync/singleflight" ) +// multiSpaceRegex matches multiple consecutive whitespace characters. +// Pre-compiled for performance in normalizeQuery. +var multiSpaceRegex = regexp.MustCompile(`\s+`) + +// Search configuration constants. +const ( + // Cache configuration + defaultCacheTTL = 30 * time.Second // Short TTL for freshness + defaultCacheMaxSize = 200 // Max cached results + cacheEvictionPercent = 10 // Evict 10% when cache is full + cacheEvictionThreshold = 80 // Start eviction scan at 80% capacity + + // Latency tracking + latencyHistogramCap = 1000 // Max latency samples for histogram + slowQueryThresholdNs = 100 * 1e6 // 100ms threshold for slow query logging + + // Query frequency tracking + maxFrequencyEntries = 1000 // Max queries to track for warming + frequencyEvictionBatch = 100 // Remove 10% when frequency map is full + staleQueryThreshold = 24 * time.Hour // Remove queries older than 24 hours + recentQueryWindow = time.Hour // Only consider queries from last hour for warming + + // Cache warming configuration + cacheWarmingInitDelay = 30 * time.Second // Delay before starting warming + cacheWarmingInterval = 20 * time.Second // Run warming cycle every 20 seconds + frequencyCleanupInterval = 5 * time.Minute // Cleanup stale entries every 5 minutes + cacheCleanupInterval = time.Minute // Cleanup expired cache every minute + warmingQueryTimeout = 5 * time.Second // Timeout for warming queries + warmingBatchSize = 5 // Warm top 5 queries per cycle + minRecencyFactor = 0.1 // Minimum recency factor for scoring + + // Default query limits + defaultQueryLimit = 20 + maxQueryLimit = 100 + defaultOrderBy = "date_desc" + + // Truncation lengths + queryLogTruncateLen = 50 // Truncate query in logs + titleTruncateLen = 100 // Truncate titles in results + warmingLogTruncateLen = 30 // Truncate query in warming logs +) + +// SearchMetrics tracks search performance statistics. +type SearchMetrics struct { + TotalSearches int64 // Total number of searches performed + VectorSearches int64 // Searches using vector search + FilterSearches int64 // Searches using filter/FTS search + TotalLatencyNs int64 // Cumulative latency in nanoseconds + VectorLatencyNs int64 // Cumulative vector search latency + FilterLatencyNs int64 // Cumulative filter search latency + CacheHits int64 // Number of result cache hits + CoalescedRequests int64 // Number of requests served via singleflight coalescing + SearchErrors int64 // Number of search errors + + // Percentile tracking (approximate using reservoir sampling) + latencyHistogram []int64 // Recent latency samples + histogramMu sync.Mutex +} + +// GetStats returns the current search statistics. +func (m *SearchMetrics) GetStats() map[string]any { + totalSearches := atomic.LoadInt64(&m.TotalSearches) + totalLatency := atomic.LoadInt64(&m.TotalLatencyNs) + vectorSearches := atomic.LoadInt64(&m.VectorSearches) + vectorLatency := atomic.LoadInt64(&m.VectorLatencyNs) + filterSearches := atomic.LoadInt64(&m.FilterSearches) + filterLatency := atomic.LoadInt64(&m.FilterLatencyNs) + + avgLatencyMs := float64(0) + if totalSearches > 0 { + avgLatencyMs = float64(totalLatency) / float64(totalSearches) / 1e6 + } + + avgVectorLatencyMs := float64(0) + if vectorSearches > 0 { + avgVectorLatencyMs = float64(vectorLatency) / float64(vectorSearches) / 1e6 + } + + avgFilterLatencyMs := float64(0) + if filterSearches > 0 { + avgFilterLatencyMs = float64(filterLatency) / float64(filterSearches) / 1e6 + } + + return map[string]any{ + "total_searches": totalSearches, + "vector_searches": vectorSearches, + "filter_searches": filterSearches, + "cache_hits": atomic.LoadInt64(&m.CacheHits), + "coalesced_requests": atomic.LoadInt64(&m.CoalescedRequests), + "search_errors": atomic.LoadInt64(&m.SearchErrors), + "avg_latency_ms": avgLatencyMs, + "avg_vector_latency_ms": avgVectorLatencyMs, + "avg_filter_latency_ms": avgFilterLatencyMs, + } +} + // Manager provides unified search across SQLite and sqlite-vec. type Manager struct { observationStore *gorm.ObservationStore summaryStore *gorm.SummaryStore promptStore *gorm.PromptStore vectorClient *sqlitevec.Client + metrics *SearchMetrics + + // Context for graceful shutdown of background goroutines + ctx context.Context + cancel context.CancelFunc + + // Request coalescing for concurrent identical queries + searchGroup singleflight.Group + + // Result cache for repeated queries (short TTL) + resultCache map[string]*cachedResult + resultCacheMu sync.RWMutex + cacheTTL time.Duration + cacheMaxSize int + + // Query frequency tracking for cache warming + queryFrequency map[string]*queryFrequencyInfo + queryFrequencyMu sync.RWMutex +} + +// queryFrequencyInfo tracks how often a query is used. +type queryFrequencyInfo struct { + params SearchParams + count int64 + lastUsed time.Time + lastCached time.Time +} + +// cachedResult stores a cached search result with expiry. +type cachedResult struct { + result *UnifiedSearchResult + expiresAt time.Time } // NewManager creates a new search manager. @@ -25,12 +162,463 @@ func NewManager( promptStore *gorm.PromptStore, vectorClient *sqlitevec.Client, ) *Manager { - return &Manager{ + ctx, cancel := context.WithCancel(context.Background()) + m := &Manager{ observationStore: observationStore, summaryStore: summaryStore, promptStore: promptStore, vectorClient: vectorClient, + metrics: &SearchMetrics{latencyHistogram: make([]int64, 0, latencyHistogramCap)}, + ctx: ctx, + cancel: cancel, + resultCache: make(map[string]*cachedResult), + cacheTTL: defaultCacheTTL, + cacheMaxSize: defaultCacheMaxSize, + queryFrequency: make(map[string]*queryFrequencyInfo), } + // Start cache cleanup goroutine + go m.cleanupCacheLoop() + // Start cache warming goroutine + go m.cacheWarmingLoop() + return m +} + +// Close stops background goroutines and cleans up resources. +func (m *Manager) Close() { + if m.cancel != nil { + m.cancel() + } +} + +// cleanupCacheLoop periodically removes expired cache entries. +func (m *Manager) cleanupCacheLoop() { + ticker := time.NewTicker(cacheCleanupInterval) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.cleanupExpiredCache() + } + } +} + +// cleanupExpiredCache removes expired entries from the cache. +func (m *Manager) cleanupExpiredCache() { + m.resultCacheMu.Lock() + defer m.resultCacheMu.Unlock() + + now := time.Now() + for key, cached := range m.resultCache { + if now.After(cached.expiresAt) { + delete(m.resultCache, key) + } + } +} + +// cacheWarmingLoop periodically warms the cache for frequently used queries. +func (m *Manager) cacheWarmingLoop() { + // Wait a bit before starting to allow system to stabilize + select { + case <-m.ctx.Done(): + return + case <-time.After(cacheWarmingInitDelay): + } + + warmingTicker := time.NewTicker(cacheWarmingInterval) + cleanupTicker := time.NewTicker(frequencyCleanupInterval) + defer warmingTicker.Stop() + defer cleanupTicker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-warmingTicker.C: + m.warmFrequentQueries() + case <-cleanupTicker.C: + m.cleanupStaleFrequencyEntries() + } + } +} + +// cleanupStaleFrequencyEntries removes query frequency entries older than staleQueryThreshold. +// This prevents memory bloat from queries that haven't been used in a long time. +func (m *Manager) cleanupStaleFrequencyEntries() { + m.queryFrequencyMu.Lock() + now := time.Now() + var keysToDelete []string + for k, v := range m.queryFrequency { + if now.Sub(v.lastUsed) > staleQueryThreshold { + keysToDelete = append(keysToDelete, k) + } + } + for _, k := range keysToDelete { + delete(m.queryFrequency, k) + } + m.queryFrequencyMu.Unlock() + + if len(keysToDelete) > 0 { + log.Debug().Int("removed", len(keysToDelete)).Msg("Cleaned up stale query frequency entries") + } +} + +// warmFrequentQueries pre-executes frequently used queries to warm the cache. +func (m *Manager) warmFrequentQueries() { + m.queryFrequencyMu.RLock() + // Find top N most frequent queries that aren't recently cached + type queryScore struct { + key string + info *queryFrequencyInfo + score float64 + } + candidates := make([]queryScore, 0, len(m.queryFrequency)) + now := time.Now() + + for key, info := range m.queryFrequency { + // Only consider queries used recently + if now.Sub(info.lastUsed) > recentQueryWindow { + continue + } + // Only warm if not recently cached (cache about to expire or already expired) + timeSinceLastCache := now.Sub(info.lastCached) + if timeSinceLastCache < m.cacheTTL/2 { + continue + } + + // Score: frequency * recency factor + recencyFactor := 1.0 - (now.Sub(info.lastUsed).Seconds() / recentQueryWindow.Seconds()) + if recencyFactor < minRecencyFactor { + recencyFactor = minRecencyFactor + } + score := float64(info.count) * recencyFactor + + candidates = append(candidates, queryScore{key: key, info: info, score: score}) + } + m.queryFrequencyMu.RUnlock() + + // Sort by score descending using O(n log n) algorithm + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].score > candidates[j].score + }) + + // Warm top queries + warmCount := min(warmingBatchSize, len(candidates)) + for i := range warmCount { + candidate := candidates[i] + ctx, cancel := context.WithTimeout(context.Background(), warmingQueryTimeout) + result, err := m.executeSearch(ctx, candidate.info.params) + cancel() + + if err == nil && result != nil { + cacheKey := m.getCacheKey(candidate.info.params) + m.putInCache(cacheKey, result) + + // Update last cached time + m.queryFrequencyMu.Lock() + if info, ok := m.queryFrequency[candidate.key]; ok { + info.lastCached = time.Now() + } + m.queryFrequencyMu.Unlock() + + log.Debug(). + Str("query", truncate(candidate.info.params.Query, warmingLogTruncateLen)). + Float64("score", candidate.score). + Msg("Cache warmed for frequent query") + } + } +} + +// trackQueryFrequency records query usage for cache warming decisions. +func (m *Manager) trackQueryFrequency(params SearchParams) { + key := m.getCacheKey(params) + + m.queryFrequencyMu.Lock() + + if info, ok := m.queryFrequency[key]; ok { + info.count++ + info.lastUsed = time.Now() + m.queryFrequencyMu.Unlock() + return // Fast path: no eviction needed + } + + m.queryFrequency[key] = &queryFrequencyInfo{ + params: params, + count: 1, + lastUsed: time.Now(), + } + + // Limit frequency map size to prevent memory bloat + mapLen := len(m.queryFrequency) + if mapLen <= maxFrequencyEntries { + m.queryFrequencyMu.Unlock() + return // Fast path: no eviction needed + } + + // Collect keys and times for eviction (still under lock, but fast) + type entry struct { + key string + lastUsed time.Time + } + entries := make([]entry, 0, mapLen) + for k, v := range m.queryFrequency { + entries = append(entries, entry{key: k, lastUsed: v.lastUsed}) + } + m.queryFrequencyMu.Unlock() + + // Sort outside lock to reduce contention (O(n log n)) + sort.Slice(entries, func(i, j int) bool { + return entries[i].lastUsed.Before(entries[j].lastUsed) + }) + + // Collect keys to delete + deleteCount := min(frequencyEvictionBatch, len(entries)) + keysToDelete := make([]string, deleteCount) + for i := range deleteCount { + keysToDelete[i] = entries[i].key + } + + // Re-acquire lock only for deletion (brief critical section) + m.queryFrequencyMu.Lock() + for _, k := range keysToDelete { + delete(m.queryFrequency, k) + } + m.queryFrequencyMu.Unlock() +} + +// RecentQuery represents a recently executed search query. +type RecentQuery struct { + Query string `json:"query"` + Project string `json:"project,omitempty"` + Type string `json:"type,omitempty"` // observations, sessions, prompts + Count int64 `json:"count"` // Number of times executed + LastUsed time.Time `json:"last_used"` +} + +// GetRecentQueries returns the most recent search queries, sorted by last used time. +func (m *Manager) GetRecentQueries(limit int) []RecentQuery { + if limit <= 0 { + limit = defaultQueryLimit + } + if limit > maxQueryLimit { + limit = maxQueryLimit + } + + m.queryFrequencyMu.RLock() + defer m.queryFrequencyMu.RUnlock() + + // Collect all queries + queries := make([]RecentQuery, 0, len(m.queryFrequency)) + for _, info := range m.queryFrequency { + queries = append(queries, RecentQuery{ + Query: info.params.Query, + Project: info.params.Project, + Type: info.params.Type, + Count: info.count, + LastUsed: info.lastUsed, + }) + } + + // Sort by last used (most recent first) + sort.Slice(queries, func(i, j int) bool { + return queries[i].LastUsed.After(queries[j].LastUsed) + }) + + // Limit results + if len(queries) > limit { + queries = queries[:limit] + } + + return queries +} + +// GetFrequentQueries returns the most frequently used search queries. +func (m *Manager) GetFrequentQueries(limit int) []RecentQuery { + if limit <= 0 { + limit = defaultQueryLimit + } + if limit > maxQueryLimit { + limit = maxQueryLimit + } + + m.queryFrequencyMu.RLock() + defer m.queryFrequencyMu.RUnlock() + + // Only include queries used recently + now := time.Now() + queries := make([]RecentQuery, 0, len(m.queryFrequency)) + for _, info := range m.queryFrequency { + if now.Sub(info.lastUsed) > recentQueryWindow { + continue + } + queries = append(queries, RecentQuery{ + Query: info.params.Query, + Project: info.params.Project, + Type: info.params.Type, + Count: info.count, + LastUsed: info.lastUsed, + }) + } + + // Sort by count (highest first) + sort.Slice(queries, func(i, j int) bool { + return queries[i].Count > queries[j].Count + }) + + // Limit results + if len(queries) > limit { + queries = queries[:limit] + } + + return queries +} + +// normalizeQuery normalizes a search query for consistent cache keys. +// Converts to lowercase, trims whitespace, and collapses multiple spaces. +// Uses pre-compiled regex for O(n) performance instead of O(n*m) loop. +func normalizeQuery(query string) string { + // Lowercase for case-insensitive matching + query = strings.ToLower(query) + // Collapse multiple whitespace into single space using pre-compiled regex + query = multiSpaceRegex.ReplaceAllString(query, " ") + // Trim leading/trailing whitespace (after collapsing) + return strings.TrimSpace(query) +} + +// getCacheKey generates a cache key from search params. +// Uses direct string concatenation instead of JSON marshal for better performance. +// Queries are normalized for consistent cache hits across whitespace variations. +func (m *Manager) getCacheKey(params SearchParams) string { + // Normalize query for consistent cache keys + normalizedQuery := normalizeQuery(params.Query) + + // Hash directly without intermediate string allocation. + // FNV-64a is fast and collision-safe for cache keys. + h := fnv.New64a() + + // Write each field directly to the hasher with separators + h.Write([]byte(normalizedQuery)) + h.Write([]byte{'|'}) + h.Write([]byte(params.Type)) + h.Write([]byte{'|'}) + h.Write([]byte(params.Project)) + h.Write([]byte{'|'}) + h.Write([]byte(params.ObsType)) + h.Write([]byte{'|'}) + h.Write([]byte(params.Concepts)) + h.Write([]byte{'|'}) + h.Write([]byte(params.Files)) + h.Write([]byte{'|'}) + h.Write([]byte(strconv.FormatInt(params.DateStart, 10))) + h.Write([]byte{'|'}) + h.Write([]byte(strconv.FormatInt(params.DateEnd, 10))) + h.Write([]byte{'|'}) + h.Write([]byte(params.OrderBy)) + h.Write([]byte{'|'}) + h.Write([]byte(strconv.Itoa(params.Limit))) + h.Write([]byte{'|'}) + h.Write([]byte(strconv.Itoa(params.Offset))) + h.Write([]byte{'|'}) + h.Write([]byte(params.Format)) + h.Write([]byte{'|'}) + h.Write([]byte(params.Scope)) + h.Write([]byte{'|'}) + if params.IncludeGlobal { + h.Write([]byte{'1'}) + } else { + h.Write([]byte{'0'}) + } + h.Write([]byte{'|'}) + if params.ExcludeSuperseded { + h.Write([]byte{'1'}) + } else { + h.Write([]byte{'0'}) + } + + return strconv.FormatUint(h.Sum64(), 36) // Base36 for compact representation +} + +// getFromCache retrieves a result from cache if valid. +func (m *Manager) getFromCache(key string) (*UnifiedSearchResult, bool) { + m.resultCacheMu.RLock() + defer m.resultCacheMu.RUnlock() + + if cached, ok := m.resultCache[key]; ok { + if time.Now().Before(cached.expiresAt) { + atomic.AddInt64(&m.metrics.CacheHits, 1) + return cached.result, true + } + } + return nil, false +} + +// putInCache stores a result in the cache. +// Optimized to skip expensive scans when cache is below capacity threshold. +func (m *Manager) putInCache(key string, result *UnifiedSearchResult) { + m.resultCacheMu.Lock() + defer m.resultCacheMu.Unlock() + + now := time.Now() + cacheLen := len(m.resultCache) + + // Only scan for expired entries when at threshold+ capacity (amortized cleanup) + evictionThreshold := (m.cacheMaxSize * cacheEvictionThreshold) / 100 + if cacheLen >= evictionThreshold { + // Evict expired entries + for k, v := range m.resultCache { + if now.After(v.expiresAt) { + delete(m.resultCache, k) + } + } + cacheLen = len(m.resultCache) // Update after eviction + } + + // If still at capacity after removing expired, use simple FIFO-style eviction + // Go map iteration order is random, which provides good cache behavior + if cacheLen >= m.cacheMaxSize { + // Evict percentage using random-order iteration (O(n) single pass) + evictCount := max(m.cacheMaxSize*cacheEvictionPercent/100, 1) + evicted := 0 + for k := range m.resultCache { + delete(m.resultCache, k) + evicted++ + if evicted >= evictCount { + break + } + } + } + + m.resultCache[key] = &cachedResult{ + result: result, + expiresAt: now.Add(m.cacheTTL), + } +} + +// Metrics returns the search metrics for monitoring. +func (m *Manager) Metrics() *SearchMetrics { + return m.metrics +} + +// CacheStats returns current cache statistics. +func (m *Manager) CacheStats() map[string]any { + m.resultCacheMu.RLock() + cacheSize := len(m.resultCache) + m.resultCacheMu.RUnlock() + + return map[string]any{ + "size": cacheSize, + "max_size": m.cacheMaxSize, + "ttl_sec": m.cacheTTL.Seconds(), + } +} + +// ClearCache clears the result cache. Useful for testing or after data changes. +func (m *Manager) ClearCache() { + m.resultCacheMu.Lock() + m.resultCache = make(map[string]*cachedResult) + m.resultCacheMu.Unlock() } // SearchParams contains parameters for unified search. @@ -62,7 +650,7 @@ type SearchResult struct { Scope string `json:"scope,omitempty"` // "project" or "global" CreatedAt int64 `json:"created_at_epoch"` Score float64 `json:"score,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // UnifiedSearchResult contains the combined search results. @@ -73,17 +661,69 @@ type UnifiedSearchResult struct { } // UnifiedSearch performs a unified search across all document types. +// Uses caching and request coalescing for optimal performance. func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) { + start := time.Now() + defer func() { + latency := time.Since(start).Nanoseconds() + atomic.AddInt64(&m.metrics.TotalSearches, 1) + atomic.AddInt64(&m.metrics.TotalLatencyNs, latency) + + // Sample latency for histogram (reservoir sampling) + m.metrics.histogramMu.Lock() + if len(m.metrics.latencyHistogram) < latencyHistogramCap { + m.metrics.latencyHistogram = append(m.metrics.latencyHistogram, latency) + } + m.metrics.histogramMu.Unlock() + + // Log slow queries + if latency > slowQueryThresholdNs { + log.Warn(). + Str("query", truncate(params.Query, queryLogTruncateLen)). + Dur("latency", time.Duration(latency)). + Str("type", params.Type). + Msg("Slow search query") + } + }() + if params.Limit <= 0 { - params.Limit = 20 + params.Limit = defaultQueryLimit } - if params.Limit > 100 { - params.Limit = 100 + if params.Limit > maxQueryLimit { + params.Limit = maxQueryLimit } if params.OrderBy == "" { - params.OrderBy = "date_desc" + params.OrderBy = defaultOrderBy } + // Check cache first + cacheKey := m.getCacheKey(params) + if cached, ok := m.getFromCache(cacheKey); ok { + return cached, nil + } + + // Use singleflight to coalesce concurrent identical requests + result, err, _ := m.searchGroup.Do(cacheKey, func() (any, error) { + return m.executeSearch(ctx, params) + }) + + if err != nil { + return nil, err + } + + searchResult := result.(*UnifiedSearchResult) + + // Cache the result + m.putInCache(cacheKey, searchResult) + + // Track query frequency for cache warming + m.trackQueryFrequency(params) + + return searchResult, nil +} + +// executeSearch performs the actual search without caching/coalescing. +func (m *Manager) executeSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) { // If query is provided and vector client is available, use vector search if params.Query != "" && m.vectorClient != nil && m.vectorClient.IsConnected() { return m.vectorSearch(ctx, params) @@ -95,6 +735,13 @@ func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*Unif // vectorSearch performs semantic search via sqlite-vec. func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) { + start := time.Now() + defer func() { + latency := time.Since(start).Nanoseconds() + atomic.AddInt64(&m.metrics.VectorSearches, 1) + atomic.AddInt64(&m.metrics.VectorLatencyNs, latency) + }() + // Build where filter based on search type var docType sqlitevec.DocType switch params.Type { @@ -110,6 +757,7 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi // Query sqlite-vec vectorResults, err := m.vectorClient.Query(ctx, params.Query, params.Limit*2, where) if err != nil { + atomic.AddInt64(&m.metrics.SearchErrors, 1) // Fall back to filter search on error return m.filterSearch(ctx, params) } @@ -125,7 +773,9 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi if len(obsIDs) > 0 && (params.Type == "" || params.Type == "observations") { obs, err := m.observationStore.GetObservationsByIDs(ctx, obsIDs, params.OrderBy, 0) - if err == nil { + if err != nil { + log.Warn().Err(err).Int("count", len(obsIDs)).Msg("Failed to fetch observations by IDs in vector search") + } else { for _, o := range obs { // Skip superseded observations when requested if params.ExcludeSuperseded && o.IsSuperseded { @@ -138,7 +788,9 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi if len(summaryIDs) > 0 && (params.Type == "" || params.Type == "sessions") { summaries, err := m.summaryStore.GetSummariesByIDs(ctx, summaryIDs, params.OrderBy, 0) - if err == nil { + if err != nil { + log.Warn().Err(err).Int("count", len(summaryIDs)).Msg("Failed to fetch summaries by IDs in vector search") + } else { for _, s := range summaries { results = append(results, m.summaryToResult(s, params.Format)) } @@ -147,7 +799,9 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi if len(promptIDs) > 0 && (params.Type == "" || params.Type == "prompts") { prompts, err := m.promptStore.GetPromptsByIDs(ctx, promptIDs, params.OrderBy, 0) - if err == nil { + if err != nil { + log.Warn().Err(err).Int("count", len(promptIDs)).Msg("Failed to fetch prompts by IDs in vector search") + } else { for _, p := range prompts { results = append(results, m.promptToResult(p, params.Format)) } @@ -168,6 +822,13 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi // filterSearch performs structured filter search via SQLite. func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) { + start := time.Now() + defer func() { + latency := time.Since(start).Nanoseconds() + atomic.AddInt64(&m.metrics.FilterSearches, 1) + atomic.AddInt64(&m.metrics.FilterLatencyNs, latency) + }() + var results []SearchResult // Search observations @@ -182,7 +843,9 @@ func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*Unifi obs, err = m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit) } - if err == nil { + if err != nil { + log.Warn().Err(err).Str("project", params.Project).Msg("Failed to fetch observations in filter search") + } else { for _, o := range obs { results = append(results, m.observationToResult(o, params.Format)) } @@ -192,7 +855,9 @@ func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*Unifi // Search summaries if params.Type == "" || params.Type == "sessions" { summaries, err := m.summaryStore.GetRecentSummaries(ctx, params.Project, params.Limit) - if err == nil { + if err != nil { + log.Warn().Err(err).Str("project", params.Project).Msg("Failed to fetch summaries in filter search") + } else { for _, s := range summaries { results = append(results, m.summaryToResult(s, params.Format)) } @@ -249,7 +914,7 @@ func (m *Manager) observationToResult(obs *models.Observation, format string) Se Project: obs.Project, Scope: string(obs.Scope), CreatedAt: obs.CreatedAtEpoch, - Metadata: map[string]interface{}{ + Metadata: map[string]any{ "obs_type": string(obs.Type), "scope": string(obs.Scope), }, @@ -275,7 +940,7 @@ func (m *Manager) summaryToResult(summary *models.SessionSummary, format string) } if summary.Request.Valid { - result.Title = truncate(summary.Request.String, 100) + result.Title = truncate(summary.Request.String, titleTruncateLen) } if format == "full" && summary.Learned.Valid { @@ -293,7 +958,7 @@ func (m *Manager) promptToResult(prompt *models.UserPromptWithSession, format st CreatedAt: prompt.CreatedAtEpoch, } - result.Title = truncate(prompt.PromptText, 100) + result.Title = truncate(prompt.PromptText, titleTruncateLen) if format == "full" { result.Content = prompt.PromptText diff --git a/internal/vector/sqlitevec/client.go b/internal/vector/sqlitevec/client.go index df2e836..971c344 100644 --- a/internal/vector/sqlitevec/client.go +++ b/internal/vector/sqlitevec/client.go @@ -5,19 +5,117 @@ import ( "context" "database/sql" "fmt" + "strconv" "strings" "sync" + "sync/atomic" + "time" sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo" "github.com/lukaszraczylo/claude-mnemonic/internal/embedding" "github.com/rs/zerolog/log" + "golang.org/x/sync/singleflight" ) +// embeddingCacheEntry stores a cached embedding with its timestamp. +type embeddingCacheEntry struct { + embedding []float32 + timestamp int64 // Unix nano +} + +// resultCacheEntry stores cached query results. +type resultCacheEntry struct { + results []QueryResult + timestamp int64 // Unix nano + queryHash string // Hash of query + filters for validation +} + // Client provides vector operations via sqlite-vec. type Client struct { db *sql.DB embedSvc *embedding.Service - mu sync.Mutex + + // Separate mutexes for read and write operations to reduce contention + writeMu sync.Mutex // Protects write operations (AddDocuments, DeleteDocuments) + readMu sync.RWMutex // Protects read operations (Query, Count) + + // Embedding cache for query deduplication + queryCache map[string]embeddingCacheEntry + queryCacheMu sync.RWMutex + cacheMaxSize int + cacheTTLNano int64 // Cache TTL in nanoseconds + + // Result cache for repeated searches + resultCache map[string]resultCacheEntry + resultCacheMu sync.RWMutex + resultCacheMaxSize int + resultCacheTTLNano int64 // Shorter TTL for results (data changes more often) + + // Cache statistics + stats CacheStats + + // Background cleanup control + stopCleanup chan struct{} + cleanupWg sync.WaitGroup + + // Singleflight to deduplicate concurrent embedding computations + embeddingGroup singleflight.Group +} + +// CacheStats tracks cache performance metrics using atomic counters for lock-free updates. +type CacheStats struct { + embeddingHits atomic.Int64 + embeddingMisses atomic.Int64 + resultHits atomic.Int64 + resultMisses atomic.Int64 + embeddingEvictions atomic.Int64 + resultEvictions atomic.Int64 +} + +// CacheStatsSnapshot is the exported version of CacheStats for JSON marshaling. +type CacheStatsSnapshot struct { + EmbeddingHits int64 `json:"embedding_hits"` + EmbeddingMisses int64 `json:"embedding_misses"` + ResultHits int64 `json:"result_hits"` + ResultMisses int64 `json:"result_misses"` + EmbeddingEvictions int64 `json:"embedding_evictions"` + ResultEvictions int64 `json:"result_evictions"` +} + +// HitRate returns the cache hit rate as a percentage. +func (s CacheStatsSnapshot) HitRate() float64 { + total := s.EmbeddingHits + s.EmbeddingMisses + s.ResultHits + s.ResultMisses + if total == 0 { + return 0 + } + hits := s.EmbeddingHits + s.ResultHits + return float64(hits) / float64(total) * 100 +} + +// HitRate returns the cache hit rate as a percentage. +func (s *CacheStats) HitRate() float64 { + embHits := s.embeddingHits.Load() + embMisses := s.embeddingMisses.Load() + resHits := s.resultHits.Load() + resMisses := s.resultMisses.Load() + total := embHits + embMisses + resHits + resMisses + if total == 0 { + return 0 + } + hits := embHits + resHits + return float64(hits) / float64(total) * 100 +} + +// Snapshot returns a copy of the current stats. +func (s *CacheStats) Snapshot() CacheStatsSnapshot { + return CacheStatsSnapshot{ + EmbeddingHits: s.embeddingHits.Load(), + EmbeddingMisses: s.embeddingMisses.Load(), + ResultHits: s.resultHits.Load(), + ResultMisses: s.resultMisses.Load(), + EmbeddingEvictions: s.embeddingEvictions.Load(), + ResultEvictions: s.resultEvictions.Load(), + } } // Config holds configuration for the client. @@ -34,10 +132,23 @@ func NewClient(cfg Config, embedSvc *embedding.Service) (*Client, error) { return nil, fmt.Errorf("embedding service required") } - return &Client{ - db: cfg.DB, - embedSvc: embedSvc, - }, nil + c := &Client{ + db: cfg.DB, + embedSvc: embedSvc, + queryCache: make(map[string]embeddingCacheEntry), + cacheMaxSize: 500, // Cache up to 500 query embeddings + cacheTTLNano: 5 * 60 * 1e9, // 5 minute TTL for embeddings + resultCache: make(map[string]resultCacheEntry), + resultCacheMaxSize: 200, // Cache up to 200 search results + resultCacheTTLNano: 60 * 1e9, // 1 minute TTL for results (shorter since data changes) + stopCleanup: make(chan struct{}), + } + + // Start background cache cleanup goroutine + c.cleanupWg.Add(1) + go c.cacheCleanupLoop() + + return c, nil } // AddDocuments adds documents with their embeddings to the vector store. @@ -46,8 +157,8 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error { return nil } - c.mu.Lock() - defer c.mu.Unlock() + c.writeMu.Lock() + defer c.writeMu.Unlock() // Generate embeddings for all documents texts := make([]string, len(docs)) @@ -75,7 +186,10 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error { } defer func() { if err != nil { - _ = tx.Rollback() + if rbErr := tx.Rollback(); rbErr != nil { + // Rollback failure is serious - indicates potential data corruption risk + log.Error().Err(rbErr).Err(err).Msg("Failed to rollback transaction after error - data may be inconsistent") + } } }() @@ -118,6 +232,9 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error { return fmt.Errorf("commit transaction: %w", err) } + // Invalidate result cache since data changed + c.InvalidateResultCache() + log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec") return nil } @@ -128,12 +245,12 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error { return nil } - c.mu.Lock() - defer c.mu.Unlock() + c.writeMu.Lock() + defer c.writeMu.Unlock() // Build placeholder string placeholders := make([]string, len(ids)) - args := make([]interface{}, len(ids)) + args := make([]any, len(ids)) for i, id := range ids { placeholders[i] = "?" args[i] = id @@ -148,17 +265,25 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error { return fmt.Errorf("delete documents: %w", err) } + // Invalidate result cache since data changed + c.InvalidateResultCache() + log.Debug().Int("count", len(ids)).Msg("Deleted documents from sqlite-vec") return nil } // Query performs a vector similarity search. func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) { - c.mu.Lock() - defer c.mu.Unlock() + // Build cache key from query + filters + limit + cacheKey := c.buildResultCacheKey(query, limit, where) - // Generate query embedding - queryEmb, err := c.embedSvc.Embed(query) + // Check result cache first + if results, ok := c.getResultFromCache(cacheKey); ok { + return results, nil + } + + // Generate query embedding OUTSIDE the lock for better concurrency + queryEmb, err := c.getOrComputeEmbedding(query) if err != nil { return nil, fmt.Errorf("embed query: %w", err) } @@ -169,9 +294,13 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s return nil, fmt.Errorf("serialize query embedding: %w", err) } + // Now acquire read lock for the actual DB query + c.readMu.RLock() + defer c.readMu.RUnlock() + // Build query with filters // vec0 supports WHERE clauses on metadata columns - args := []interface{}{queryBlob} + args := []any{queryBlob} sqlQuery := ` SELECT @@ -232,6 +361,9 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s return nil, fmt.Errorf("iterate rows: %w", err) } + // Cache the results + c.cacheResults(cacheKey, results) + log.Debug(). Str("query", truncateString(query, 50)). Int("results", len(results)). @@ -245,11 +377,196 @@ func (c *Client) IsConnected() bool { return c.db != nil } -// Close is a no-op (db managed externally). +// Close stops the background cleanup goroutine (db managed externally). func (c *Client) Close() error { + // Signal cleanup goroutine to stop + close(c.stopCleanup) + // Wait for cleanup to finish + c.cleanupWg.Wait() return nil } +// cacheCleanupLoop periodically removes expired cache entries. +func (c *Client) cacheCleanupLoop() { + defer c.cleanupWg.Done() + + ticker := time.NewTicker(30 * time.Second) // Cleanup every 30 seconds + defer ticker.Stop() + + for { + select { + case <-c.stopCleanup: + return + case <-ticker.C: + c.cleanupExpiredCaches() + } + } +} + +// cleanupExpiredCaches removes expired entries from both caches. +func (c *Client) cleanupExpiredCaches() { + now := time.Now().UnixNano() + var embeddingExpired, resultExpired int64 + + // Cleanup embedding cache + c.queryCacheMu.Lock() + for key, entry := range c.queryCache { + if now-entry.timestamp > c.cacheTTLNano { + delete(c.queryCache, key) + embeddingExpired++ + } + } + c.queryCacheMu.Unlock() + + // Cleanup result cache + c.resultCacheMu.Lock() + for key, entry := range c.resultCache { + if now-entry.timestamp > c.resultCacheTTLNano { + delete(c.resultCache, key) + resultExpired++ + } + } + c.resultCacheMu.Unlock() + + // Update stats atomically + if embeddingExpired > 0 || resultExpired > 0 { + c.stats.embeddingEvictions.Add(embeddingExpired) + c.stats.resultEvictions.Add(resultExpired) + + log.Debug(). + Int64("embedding_expired", embeddingExpired). + Int64("result_expired", resultExpired). + Msg("Cache cleanup completed") + } +} + +// BatchQueryResult holds results from a batch query operation. +type BatchQueryResult struct { + Query string // Original query string + Results []QueryResult // Results for this query + Error error // Error if query failed +} + +// QueryBatch performs multiple vector searches concurrently. +// Returns results in the same order as input queries. +// Uses a worker pool to limit concurrent queries. +func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, where map[string]any) []BatchQueryResult { + if len(queries) == 0 { + return nil + } + + // Limit concurrency to avoid overwhelming the database + maxConcurrent := min(4, len(queries)) + + results := make([]BatchQueryResult, len(queries)) + sem := make(chan struct{}, maxConcurrent) + var wg sync.WaitGroup + + for i, query := range queries { + wg.Add(1) + go func(idx int, q string) { + defer wg.Done() + + // Acquire semaphore + select { + case sem <- struct{}{}: + defer func() { <-sem }() + case <-ctx.Done(): + results[idx] = BatchQueryResult{ + Query: q, + Error: ctx.Err(), + } + return + } + + // Execute query + queryResults, err := c.Query(ctx, q, limit, where) + results[idx] = BatchQueryResult{ + Query: q, + Results: queryResults, + Error: err, + } + }(i, query) + } + + wg.Wait() + return results +} + +// QueryMultiField searches across multiple fields for a single query. +// Combines results from different field types and deduplicates by document ID. +func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) { + // Generate embedding once + queryEmb, err := c.getOrComputeEmbedding(query) + if err != nil { + return nil, fmt.Errorf("embed query: %w", err) + } + + // Serialize query embedding + queryBlob, err := sqlite_vec.SerializeFloat32(queryEmb) + if err != nil { + return nil, fmt.Errorf("serialize query embedding: %w", err) + } + + c.readMu.RLock() + defer c.readMu.RUnlock() + + // Query with field type aggregation - get best match per document + sqlQuery := ` + WITH ranked_results AS ( + SELECT + doc_id, + distance, + sqlite_id, + doc_type, + field_type, + project, + scope, + ROW_NUMBER() OVER (PARTITION BY sqlite_id ORDER BY distance ASC) as rn + FROM vectors + WHERE embedding MATCH ? + AND doc_type = ? + AND (project = ? OR scope = 'global') + ) + SELECT doc_id, distance, sqlite_id, doc_type, field_type, project, scope + FROM ranked_results + WHERE rn = 1 + ORDER BY distance + LIMIT ? + ` + + rows, err := c.db.QueryContext(ctx, sqlQuery, queryBlob, docType, project, limit) + if err != nil { + return nil, fmt.Errorf("query vectors: %w", err) + } + defer rows.Close() + + // Pre-allocate with limit to avoid repeated slice growth + results := make([]QueryResult, 0, limit) + for rows.Next() { + var r QueryResult + var sqliteID int64 + var docTypeVal, fieldType, projectVal, scope sql.NullString + + if err := rows.Scan(&r.ID, &r.Distance, &sqliteID, &docTypeVal, &fieldType, &projectVal, &scope); err != nil { + return nil, fmt.Errorf("scan row: %w", err) + } + + r.Similarity = DistanceToSimilarity(r.Distance) + r.Metadata = map[string]any{ + "sqlite_id": float64(sqliteID), + "doc_type": docTypeVal.String, + "field_type": fieldType.String, + "project": projectVal.String, + "scope": scope.String, + } + + results = append(results, r) + } + + return results, rows.Err() +} + // truncateString truncates a string to maxLen characters. func truncateString(s string, maxLen int) string { if len(s) <= maxLen { @@ -260,8 +577,8 @@ func truncateString(s string, maxLen int) string { // Count returns the total number of vectors in the store. func (c *Client) Count(ctx context.Context) (int64, error) { - c.mu.Lock() - defer c.mu.Unlock() + c.readMu.RLock() + defer c.readMu.RUnlock() var count int64 err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count) @@ -281,8 +598,8 @@ func (c *Client) ModelVersion() string { // - The vectors table is empty // - Any vectors have a different model_version than the current model func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) { - c.mu.Lock() - defer c.mu.Unlock() + c.readMu.RLock() + defer c.readMu.RUnlock() currentModel := c.embedSvc.Version() @@ -329,8 +646,8 @@ type StaleVectorInfo struct { // GetStaleVectors returns doc_ids of vectors with mismatched or null model versions. // This enables granular rebuild - only re-embedding documents that need updating. func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) { - c.mu.Lock() - defer c.mu.Unlock() + c.readMu.RLock() + defer c.readMu.RUnlock() currentModel := c.embedSvc.Version() @@ -372,6 +689,134 @@ func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) return results, nil } +// VectorHealthStats contains comprehensive health information about the vector store. +type VectorHealthStats struct { + TotalVectors int64 `json:"total_vectors"` + StaleVectors int64 `json:"stale_vectors"` + CurrentModel string `json:"current_model"` + NeedsRebuild bool `json:"needs_rebuild"` + RebuildReason string `json:"rebuild_reason,omitempty"` + CoverageByType map[string]int64 `json:"coverage_by_type"` + ModelVersions map[string]int64 `json:"model_versions"` + ProjectCounts map[string]int64 `json:"project_counts"` + EmbeddingCache CacheStatsSnapshot `json:"embedding_cache"` +} + +// GetHealthStats returns comprehensive health statistics about the vector store. +func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) { + c.readMu.RLock() + defer c.readMu.RUnlock() + + stats := &VectorHealthStats{ + CurrentModel: c.embedSvc.Version(), + CoverageByType: make(map[string]int64), + ModelVersions: make(map[string]int64), + ProjectCounts: make(map[string]int64), + EmbeddingCache: c.stats.Snapshot(), + } + + // Get total count + err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&stats.TotalVectors) + if err != nil { + return nil, fmt.Errorf("count total vectors: %w", err) + } + + // Get stale count + err = c.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL", + stats.CurrentModel, + ).Scan(&stats.StaleVectors) + if err != nil { + return nil, fmt.Errorf("count stale vectors: %w", err) + } + + // Check if rebuild needed + stats.NeedsRebuild, stats.RebuildReason = c.needsRebuildUnlocked(ctx, stats.CurrentModel) + + // Get coverage by doc_type + rows, err := c.db.QueryContext(ctx, "SELECT doc_type, COUNT(*) FROM vectors GROUP BY doc_type") + if err != nil { + return nil, fmt.Errorf("query doc types: %w", err) + } + defer rows.Close() + + for rows.Next() { + var docType sql.NullString + var count int64 + if err := rows.Scan(&docType, &count); err != nil { + return nil, fmt.Errorf("scan doc type: %w", err) + } + if docType.Valid { + stats.CoverageByType[docType.String] = count + } else { + stats.CoverageByType["unknown"] = count + } + } + + // Get model version distribution + rows2, err := c.db.QueryContext(ctx, "SELECT COALESCE(model_version, 'unknown'), COUNT(*) FROM vectors GROUP BY model_version") + if err != nil { + return nil, fmt.Errorf("query model versions: %w", err) + } + defer rows2.Close() + + for rows2.Next() { + var version string + var count int64 + if err := rows2.Scan(&version, &count); err != nil { + return nil, fmt.Errorf("scan model version: %w", err) + } + stats.ModelVersions[version] = count + } + + // Get project counts (top 10) + rows3, err := c.db.QueryContext(ctx, + "SELECT COALESCE(project, 'global'), COUNT(*) FROM vectors GROUP BY project ORDER BY COUNT(*) DESC LIMIT 10") + if err != nil { + return nil, fmt.Errorf("query projects: %w", err) + } + defer rows3.Close() + + for rows3.Next() { + var project string + var count int64 + if err := rows3.Scan(&project, &count); err != nil { + return nil, fmt.Errorf("scan project: %w", err) + } + stats.ProjectCounts[project] = count + } + + return stats, nil +} + +// needsRebuildUnlocked checks if rebuild is needed without acquiring lock (caller must hold lock). +func (c *Client) needsRebuildUnlocked(ctx context.Context, currentModel string) (bool, string) { + var totalCount int64 + err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount) + if err != nil { + return false, "" + } + + if totalCount == 0 { + return true, "empty" + } + + var staleCount int64 + err = c.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL", + currentModel, + ).Scan(&staleCount) + if err != nil { + return false, "" + } + + if staleCount > 0 { + return true, fmt.Sprintf("model_mismatch:%d", staleCount) + } + + return false, "" +} + // DeleteVectorsByDocIDs removes vectors by their doc_ids. // Used for granular rebuild - delete stale vectors before re-adding. func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error { @@ -379,12 +824,12 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err return nil } - c.mu.Lock() - defer c.mu.Unlock() + c.writeMu.Lock() + defer c.writeMu.Unlock() // Build placeholder string placeholders := make([]string, len(docIDs)) - args := make([]interface{}, len(docIDs)) + args := make([]any, len(docIDs)) for i, id := range docIDs { placeholders[i] = "?" args[i] = id @@ -402,3 +847,249 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id") return nil } + +// DeleteByObservationID removes all vectors associated with an observation ID. +// Vectors are stored with doc_ids that include the observation ID, e.g., "obs_123_narrative". +func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + // Vectors have doc_ids like "obs_123_narrative", "obs_123_facts_0", etc. + pattern := fmt.Sprintf("obs_%d_%%", obsID) + + _, err := c.db.ExecContext(ctx, "DELETE FROM vectors WHERE doc_id LIKE ?", pattern) + if err != nil { + return fmt.Errorf("delete vectors for observation %d: %w", obsID, err) + } + + log.Debug().Int64("observation_id", obsID).Msg("Deleted vectors for observation") + return nil +} + +// getOrComputeEmbedding returns a cached embedding or computes a new one. +// Uses singleflight to prevent duplicate concurrent computations for the same query. +func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) { + now := time.Now().UnixNano() + + // Check cache first (read lock) + c.queryCacheMu.RLock() + if entry, ok := c.queryCache[query]; ok { + if now-entry.timestamp < c.cacheTTLNano { + c.queryCacheMu.RUnlock() + c.stats.embeddingHits.Add(1) + return entry.embedding, nil + } + } + c.queryCacheMu.RUnlock() + + // Cache miss - use singleflight to deduplicate concurrent embedding requests + result, err, _ := c.embeddingGroup.Do(query, func() (any, error) { + // Double-check cache inside singleflight (another goroutine may have just cached it) + c.queryCacheMu.RLock() + if entry, ok := c.queryCache[query]; ok { + if time.Now().UnixNano()-entry.timestamp < c.cacheTTLNano { + c.queryCacheMu.RUnlock() + return entry.embedding, nil + } + } + c.queryCacheMu.RUnlock() + + // Record cache miss + c.stats.embeddingMisses.Add(1) + + // Compute embedding + emb, err := c.embedSvc.Embed(query) + if err != nil { + return nil, err + } + + // Store in cache (write lock) + c.queryCacheMu.Lock() + nowCache := time.Now().UnixNano() + // Evict old entries if cache is full or near capacity (80% threshold) + evictionThreshold := (c.cacheMaxSize * 8) / 10 + if len(c.queryCache) >= evictionThreshold { + // Phase 1: Remove ALL expired entries first (not just 10%) + evicted := int64(0) + for k, v := range c.queryCache { + if nowCache-v.timestamp > c.cacheTTLNano { + delete(c.queryCache, k) + evicted++ + } + } + + // Phase 2: If still at capacity, evict 10% using random iteration (O(n) instead of O(n log n)) + // Go map iteration order is randomized, providing good cache behavior without sorting + if len(c.queryCache) >= c.cacheMaxSize { + evictCount := max(c.cacheMaxSize/10, 1) + for k := range c.queryCache { + delete(c.queryCache, k) + evicted++ + evictCount-- + if evictCount <= 0 { + break + } + } + } + + if evicted > 0 { + c.stats.embeddingEvictions.Add(evicted) + } + } + c.queryCache[query] = embeddingCacheEntry{ + embedding: emb, + timestamp: nowCache, + } + c.queryCacheMu.Unlock() + + return emb, nil + }) + + if err != nil { + return nil, err + } + return result.([]float32), nil +} + +// ClearCache clears the embedding cache. +func (c *Client) ClearCache() { + c.queryCacheMu.Lock() + c.queryCache = make(map[string]embeddingCacheEntry) + c.queryCacheMu.Unlock() +} + +// GetCacheStats returns comprehensive cache statistics. +func (c *Client) GetCacheStats() CacheStatsSnapshot { + return c.stats.Snapshot() +} + +// CacheStats returns basic cache size info for backward compatibility. +// Deprecated: Use GetCacheStats() for comprehensive statistics. +func (c *Client) CacheStats() (size int, maxSize int) { + c.queryCacheMu.RLock() + size = len(c.queryCache) + c.queryCacheMu.RUnlock() + return size, c.cacheMaxSize +} + +// EmbeddingCacheSize returns the current embedding cache size. +func (c *Client) EmbeddingCacheSize() int { + c.queryCacheMu.RLock() + defer c.queryCacheMu.RUnlock() + return len(c.queryCache) +} + +// ResultCacheSize returns the current result cache size. +func (c *Client) ResultCacheSize() int { + c.resultCacheMu.RLock() + defer c.resultCacheMu.RUnlock() + return len(c.resultCache) +} + +// buildResultCacheKey creates a unique key for caching query results. +// Uses strings.Builder to avoid intermediate allocations. +func (c *Client) buildResultCacheKey(query string, limit int, where map[string]any) string { + // Pre-allocate with typical key size to avoid reallocation + var b strings.Builder + b.Grow(len(query) + 32) // query + typical prefix/suffix overhead + + b.WriteString("q:") + b.WriteString(query) + b.WriteString(":l:") + b.WriteString(strconv.Itoa(limit)) + + if docType, ok := where["doc_type"].(string); ok { + b.WriteString(":dt:") + b.WriteString(docType) + } + if project, ok := where["project"].(string); ok { + b.WriteString(":p:") + b.WriteString(project) + } + return b.String() +} + +// getResultFromCache retrieves cached results if available and not expired. +func (c *Client) getResultFromCache(cacheKey string) ([]QueryResult, bool) { + now := time.Now().UnixNano() + + c.resultCacheMu.RLock() + entry, ok := c.resultCache[cacheKey] + c.resultCacheMu.RUnlock() + + if !ok { + c.stats.resultMisses.Add(1) + return nil, false + } + + // Check if entry is expired + if now-entry.timestamp > c.resultCacheTTLNano { + c.stats.resultMisses.Add(1) + return nil, false + } + + c.stats.resultHits.Add(1) + + // Return a copy to prevent mutation + results := make([]QueryResult, len(entry.results)) + copy(results, entry.results) + return results, true +} + +// cacheResults stores query results in the cache. +func (c *Client) cacheResults(cacheKey string, results []QueryResult) { + now := time.Now().UnixNano() + + c.resultCacheMu.Lock() + defer c.resultCacheMu.Unlock() + + // Evict old entries if cache is full + if len(c.resultCache) >= c.resultCacheMaxSize { + // Two-phase eviction: (1) TTL-expired entries, (2) random if still over capacity + evicted := 0 + targetSize := c.resultCacheMaxSize * 8 / 10 // Target 80% capacity + + // Phase 1: Remove all TTL-expired entries + for k, v := range c.resultCache { + if now-v.timestamp > c.resultCacheTTLNano { + delete(c.resultCache, k) + evicted++ + } + } + + // Phase 2: If still over target, remove random entries until at target + if len(c.resultCache) >= targetSize { + evictCount := len(c.resultCache) - targetSize + 1 + for k := range c.resultCache { + delete(c.resultCache, k) + evicted++ + evictCount-- + if evictCount <= 0 { + break + } + } + } + + if evicted > 0 { + c.stats.resultEvictions.Add(int64(evicted)) + } + } + + // Make a copy of results to store + resultsCopy := make([]QueryResult, len(results)) + copy(resultsCopy, results) + + c.resultCache[cacheKey] = resultCacheEntry{ + results: resultsCopy, + timestamp: now, + queryHash: cacheKey, + } +} + +// InvalidateResultCache clears the result cache. +// Should be called after write operations that modify vectors. +func (c *Client) InvalidateResultCache() { + c.resultCacheMu.Lock() + c.resultCache = make(map[string]resultCacheEntry) + c.resultCacheMu.Unlock() +} diff --git a/internal/vector/sqlitevec/sync.go b/internal/vector/sqlitevec/sync.go index 21f37c9..fe824de 100644 --- a/internal/vector/sqlitevec/sync.go +++ b/internal/vector/sqlitevec/sync.go @@ -338,3 +338,187 @@ func (s *Sync) DeletePatterns(ctx context.Context, patternIDs []int64) error { return nil } + +// BatchSyncConfig configures batch synchronization behavior. +type BatchSyncConfig struct { + BatchSize int // Number of items per batch (default: 50) + ProgressLogFreq int // Log progress every N items (default: 100) +} + +// DefaultBatchSyncConfig returns sensible defaults for batch sync. +func DefaultBatchSyncConfig() BatchSyncConfig { + return BatchSyncConfig{ + BatchSize: 50, + ProgressLogFreq: 100, + } +} + +// BatchSyncObservations syncs multiple observations efficiently in batches. +// This reduces memory pressure during large rebuilds by processing in chunks. +func (s *Sync) BatchSyncObservations(ctx context.Context, observations []*models.Observation, cfg BatchSyncConfig) (synced int, errors int) { + if len(observations) == 0 { + return 0, 0 + } + + if cfg.BatchSize <= 0 { + cfg.BatchSize = 50 + } + if cfg.ProgressLogFreq <= 0 { + cfg.ProgressLogFreq = 100 + } + + for i := 0; i < len(observations); i += cfg.BatchSize { + // Check context cancellation + select { + case <-ctx.Done(): + log.Warn().Int("synced", synced).Int("remaining", len(observations)-i).Msg("Batch sync cancelled") + return synced, errors + default: + } + + end := min(i+cfg.BatchSize, len(observations)) + + batch := observations[i:end] + var docs []Document + + // Collect all documents for this batch + for _, obs := range batch { + docs = append(docs, s.formatObservationDocs(obs)...) + } + + // Add all documents in one call + if len(docs) > 0 { + if err := s.client.AddDocuments(ctx, docs); err != nil { + log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync observation batch") + errors += len(batch) + continue + } + } + + synced += len(batch) + + // Log progress periodically + if synced%cfg.ProgressLogFreq == 0 || synced == len(observations) { + log.Debug().Int("synced", synced).Int("total", len(observations)).Msg("Observation batch sync progress") + } + } + + return synced, errors +} + +// BatchSyncSummaries syncs multiple summaries efficiently in batches. +func (s *Sync) BatchSyncSummaries(ctx context.Context, summaries []*models.SessionSummary, cfg BatchSyncConfig) (synced int, errors int) { + if len(summaries) == 0 { + return 0, 0 + } + + if cfg.BatchSize <= 0 { + cfg.BatchSize = 50 + } + if cfg.ProgressLogFreq <= 0 { + cfg.ProgressLogFreq = 100 + } + + for i := 0; i < len(summaries); i += cfg.BatchSize { + // Check context cancellation + select { + case <-ctx.Done(): + log.Warn().Int("synced", synced).Int("remaining", len(summaries)-i).Msg("Batch sync cancelled") + return synced, errors + default: + } + + end := min(i+cfg.BatchSize, len(summaries)) + + batch := summaries[i:end] + var docs []Document + + // Collect all documents for this batch + for _, summary := range batch { + docs = append(docs, s.formatSummaryDocs(summary)...) + } + + // Add all documents in one call + if len(docs) > 0 { + if err := s.client.AddDocuments(ctx, docs); err != nil { + log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync summary batch") + errors += len(batch) + continue + } + } + + synced += len(batch) + + // Log progress periodically + if synced%cfg.ProgressLogFreq == 0 || synced == len(summaries) { + log.Debug().Int("synced", synced).Int("total", len(summaries)).Msg("Summary batch sync progress") + } + } + + return synced, errors +} + +// BatchSyncPrompts syncs multiple user prompts efficiently in batches. +func (s *Sync) BatchSyncPrompts(ctx context.Context, prompts []*models.UserPromptWithSession, cfg BatchSyncConfig) (synced int, errors int) { + if len(prompts) == 0 { + return 0, 0 + } + + if cfg.BatchSize <= 0 { + cfg.BatchSize = 50 + } + if cfg.ProgressLogFreq <= 0 { + cfg.ProgressLogFreq = 100 + } + + for i := 0; i < len(prompts); i += cfg.BatchSize { + // Check context cancellation + select { + case <-ctx.Done(): + log.Warn().Int("synced", synced).Int("remaining", len(prompts)-i).Msg("Batch sync cancelled") + return synced, errors + default: + } + + end := min(i+cfg.BatchSize, len(prompts)) + + batch := prompts[i:end] + docs := make([]Document, 0, len(batch)) + + // Collect all documents for this batch + for _, prompt := range batch { + docs = append(docs, Document{ + ID: fmt.Sprintf("prompt_%d", prompt.ID), + Content: prompt.PromptText, + Metadata: map[string]any{ + "sqlite_id": prompt.ID, + "doc_type": "user_prompt", + "sdk_session_id": prompt.SDKSessionID, + "project": prompt.Project, + "scope": "", + "created_at_epoch": prompt.CreatedAtEpoch, + "prompt_number": prompt.PromptNumber, + "field_type": "prompt", + }, + }) + } + + // Add all documents in one call + if len(docs) > 0 { + if err := s.client.AddDocuments(ctx, docs); err != nil { + log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync prompt batch") + errors += len(batch) + continue + } + } + + synced += len(batch) + + // Log progress periodically + if synced%cfg.ProgressLogFreq == 0 || synced == len(prompts) { + log.Debug().Int("synced", synced).Int("total", len(prompts)).Msg("Prompt batch sync progress") + } + } + + return synced, errors +} diff --git a/internal/worker/handlers.go b/internal/worker/handlers.go index 230a56d..056fffa 100644 --- a/internal/worker/handlers.go +++ b/internal/worker/handlers.go @@ -1,25 +1,19 @@ // Package worker provides the main worker service for claude-mnemonic. +// This file contains shared handler utilities and health/status endpoints. +// Domain-specific handlers are split into: +// - handlers_sessions.go: Session lifecycle (init, start, observation, summarize) +// - handlers_context.go: Context/search (search by prompt, file context, inject) +// - handlers_data.go: Data retrieval (observations, summaries, prompts, stats) +// - handlers_update.go: Updates and self-check (update check/apply, self-check) +// - handlers_import_export.go: Import/export/archive operations package worker import ( - "context" "encoding/json" "fmt" "net/http" - "sort" "strconv" - "time" - "github.com/go-chi/chi/v5" - "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" - "github.com/lukaszraczylo/claude-mnemonic/internal/embedding" - "github.com/lukaszraczylo/claude-mnemonic/internal/privacy" - "github.com/lukaszraczylo/claude-mnemonic/internal/reranking" - "github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion" - "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" - "github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk" - "github.com/lukaszraczylo/claude-mnemonic/internal/worker/session" - "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/rs/zerolog/log" ) @@ -29,6 +23,16 @@ const ( DefaultObservationsLimit = 100 // DefaultSummariesLimit is the default number of summaries to return. + DefaultSummariesLimit = 50 + + // DefaultPromptsLimit is the default number of prompts to return. + DefaultPromptsLimit = 100 + + // DefaultSearchLimit is the default number of search results to return. + DefaultSearchLimit = 50 + + // DefaultContextLimit is the default number of context observations to return. + DefaultContextLimit = 50 ) // ObservationTypes is the canonical list of observation types. @@ -42,6 +46,22 @@ var ObservationTypes = []string{ "change", } +// observationTypeSet is a pre-computed map for O(1) type validation. +// Initialized at package load time. +var observationTypeSet = func() map[string]struct{} { + m := make(map[string]struct{}, len(ObservationTypes)) + for _, t := range ObservationTypes { + m[t] = struct{}{} + } + return m +}() + +// IsValidObservationType returns true if the type is valid (O(1) lookup). +func IsValidObservationType(t string) bool { + _, ok := observationTypeSet[t] + return ok +} + // ConceptTypes is the canonical list of valid concept types. // Used by both Go backend and served to frontend. var ConceptTypes = []string{ @@ -75,27 +95,52 @@ var ConceptTypes = []string{ "validation", } -const ( - DefaultSummariesLimit = 50 +// conceptTypeSet is a pre-computed map for O(1) concept validation. +var conceptTypeSet = func() map[string]struct{} { + m := make(map[string]struct{}, len(ConceptTypes)) + for _, t := range ConceptTypes { + m[t] = struct{}{} + } + return m +}() - // DefaultPromptsLimit is the default number of prompts to return. - DefaultPromptsLimit = 100 - - // DefaultSearchLimit is the default number of search results to return. - DefaultSearchLimit = 50 - - // DefaultContextLimit is the default number of context observations to return. - DefaultContextLimit = 50 -) +// IsValidConceptType returns true if the concept type is valid (O(1) lookup). +func IsValidConceptType(t string) bool { + _, ok := conceptTypeSet[t] + return ok +} // writeJSON writes a JSON response with proper error handling. -func writeJSON(w http.ResponseWriter, data interface{}) { +func writeJSON(w http.ResponseWriter, data any) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(data); err != nil { log.Error().Err(err).Msg("Failed to encode JSON response") } } +// parseIDParam parses an ID parameter from a string. +// Returns the parsed ID and true on success, or writes an error response and returns false. +// The entityName is used in error messages (e.g., "observation", "session", "pattern"). +func parseIDParam(w http.ResponseWriter, idStr, entityName string) (int64, bool) { + if idStr == "" { + http.Error(w, entityName+" id required", http.StatusBadRequest) + return 0, false + } + + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + http.Error(w, "invalid "+entityName+" id", http.StatusBadRequest) + return 0, false + } + + return id, true +} + +// formatWarning formats a warning message for use in health responses. +func formatWarning(format string, args ...any) string { + return fmt.Sprintf(format, args...) +} + // handleHealth handles health check requests. // Returns 200 OK immediately (even during init) so hooks can connect quickly. // Use /api/ready for full readiness check. @@ -106,7 +151,7 @@ func (s *Service) handleHealth(w http.ResponseWriter, r *http.Request) { } else if err := s.GetInitError(); err != nil { status = "error" } - writeJSON(w, map[string]interface{}{ + writeJSON(w, map[string]any{ "status": status, "version": s.version, }) @@ -119,6 +164,60 @@ func (s *Service) handleVersion(w http.ResponseWriter, r *http.Request) { }) } +// handleRebuildStatus returns the current status of vector rebuild operations. +// This provides visibility into long-running rebuild operations. +func (s *Service) handleRebuildStatus(w http.ResponseWriter, _ *http.Request) { + s.rebuildStatusMu.RLock() + status := s.rebuildStatus + s.rebuildStatusMu.RUnlock() + + if status == nil { + writeJSON(w, map[string]any{ + "in_progress": false, + "message": "No rebuild operation has been started", + }) + return + } + + writeJSON(w, status) +} + +// handleTriggerVectorRebuild triggers a full vector rebuild operation. +// This rebuilds all vectors from observations, summaries, and prompts. +// Returns 409 Conflict if a rebuild is already in progress. +// Returns 429 Too Many Requests if called too frequently (5 minute cooldown). +func (s *Service) handleTriggerVectorRebuild(w http.ResponseWriter, _ *http.Request) { + // Check rate limiting for expensive operations + if s.expensiveOpLimiter != nil && !s.expensiveOpLimiter.CanRebuild() { + http.Error(w, "rebuild requested too recently, please wait 5 minutes", http.StatusTooManyRequests) + return + } + + // Check if rebuild is already in progress + s.rebuildStatusMu.RLock() + if s.rebuildStatus != nil && s.rebuildStatus.InProgress { + s.rebuildStatusMu.RUnlock() + http.Error(w, "rebuild already in progress", http.StatusConflict) + return + } + s.rebuildStatusMu.RUnlock() + + // Verify we have the necessary components + if s.vectorSync == nil || s.observationStore == nil || s.summaryStore == nil || s.promptStore == nil { + http.Error(w, "vector sync not initialized", http.StatusServiceUnavailable) + return + } + + // Start rebuild in background + s.wg.Add(1) + go s.rebuildAllVectors(s.observationStore, s.summaryStore, s.promptStore, s.vectorSync) + + writeJSON(w, map[string]any{ + "status": "started", + "message": "Vector rebuild started. Check /api/rebuild-status for progress.", + }) +} + // handleReady handles readiness check requests. // Returns 200 only when fully initialized, 503 otherwise. func (s *Service) handleReady(w http.ResponseWriter, r *http.Request) { @@ -147,1168 +246,3 @@ func (s *Service) requireReady(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } - -// SessionInitRequest is the request body for session initialization. -type SessionInitRequest struct { - ClaudeSessionID string `json:"claudeSessionId"` - Project string `json:"project"` - Prompt string `json:"prompt"` - MatchedObservations int `json:"matchedObservations"` -} - -// SessionInitResponse is the response for session initialization. -type SessionInitResponse struct { - SessionDBID int64 `json:"sessionDbId"` - PromptNumber int `json:"promptNumber"` - Skipped bool `json:"skipped,omitempty"` - Reason string `json:"reason,omitempty"` -} - -// DuplicatePromptWindowSeconds is the time window for detecting duplicate prompt submissions. -// If the same prompt text is seen within this window, it's considered a duplicate hook invocation. -const DuplicatePromptWindowSeconds = 10 - -// handleSessionInit handles session initialization from user-prompt hook. -// This handler is idempotent - duplicate requests within a short time window -// return the existing prompt data without creating duplicates. -func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) { - var req SessionInitRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - // Privacy check - if privacy.IsEntirelyPrivate(req.Prompt) { - // Create session but skip processing - sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "") - promptNum, _ := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID) - - writeJSON(w, SessionInitResponse{ - SessionDBID: sessionID, - PromptNumber: promptNum, - Skipped: true, - Reason: "private", - }) - return - } - - // Clean prompt - cleanedPrompt := privacy.Clean(req.Prompt) - - // DUPLICATE DETECTION: Check if this exact prompt was already saved recently. - // This prevents the bug where the hook fires multiple times for the same user action, - // creating many duplicate prompts with incrementing numbers. - if existingID, existingNum, found := s.promptStore.FindRecentPromptByText(r.Context(), req.ClaudeSessionID, cleanedPrompt, DuplicatePromptWindowSeconds); found { - // Get or create session (idempotent) - sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt) - - log.Debug(). - Int64("sessionId", sessionID). - Int("promptNumber", existingNum). - Int64("promptId", existingID). - Msg("Duplicate prompt detected - returning existing") - - // Return existing prompt data without incrementing or saving again - writeJSON(w, SessionInitResponse{ - SessionDBID: sessionID, - PromptNumber: existingNum, - }) - return - } - - // Create session (idempotent) - sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Increment prompt counter - promptNum, err := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Save user prompt with matched observation count - promptID, err := s.promptStore.SaveUserPromptWithMatches(r.Context(), req.ClaudeSessionID, promptNum, cleanedPrompt, req.MatchedObservations) - if err != nil { - log.Warn().Err(err).Msg("Failed to save user prompt") - // Non-fatal: continue with session initialization - } else if s.vectorSync != nil { - // Sync to vector DB asynchronously (non-blocking) - now := time.Now() - promptWithSession := &models.UserPromptWithSession{ - UserPrompt: models.UserPrompt{ - ID: promptID, - ClaudeSessionID: req.ClaudeSessionID, - PromptNumber: promptNum, - PromptText: cleanedPrompt, - MatchedObservations: req.MatchedObservations, - CreatedAt: now.Format(time.RFC3339), - CreatedAtEpoch: now.UnixMilli(), - }, - Project: req.Project, - SDKSessionID: req.ClaudeSessionID, - } - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := s.vectorSync.SyncUserPrompt(ctx, promptWithSession); err != nil { - log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to sqlite-vec") - } - }() - } - - log.Info(). - Int64("sessionId", sessionID). - Int("promptNumber", promptNum). - Str("project", req.Project). - Msg("Session initialized") - - // Broadcast prompt event for dashboard refresh - s.sseBroadcaster.Broadcast(map[string]interface{}{ - "type": "prompt", - "action": "created", - "project": req.Project, - }) - - writeJSON(w, SessionInitResponse{ - SessionDBID: sessionID, - PromptNumber: promptNum, - }) -} - -// SessionStartRequest is the request body for starting SDK agent. -type SessionStartRequest struct { - UserPrompt string `json:"userPrompt"` - PromptNumber int `json:"promptNumber"` -} - -// handleSessionStart handles SDK agent session start. -func (s *Service) handleSessionStart(w http.ResponseWriter, r *http.Request) { - idStr := chi.URLParam(r, "id") - id, err := strconv.ParseInt(idStr, 10, 64) - if err != nil { - http.Error(w, "invalid session id", http.StatusBadRequest) - return - } - - var req SessionStartRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - // Initialize session in manager - sess, err := s.sessionManager.InitializeSession(r.Context(), id, req.UserPrompt, req.PromptNumber) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if sess == nil { - http.Error(w, "session not found", http.StatusNotFound) - return - } - - // Session is now registered. Observations will be processed - // asynchronously by the background queue processor (processQueue in service.go). - log.Info(). - Int64("sessionId", id). - Int("promptNumber", req.PromptNumber). - Msg("SDK agent session initialized") - - s.broadcastProcessingStatus() - w.WriteHeader(http.StatusOK) -} - -// ObservationRequest is the request body for posting observations. -type ObservationRequest struct { - ClaudeSessionID string `json:"claudeSessionId"` - Project string `json:"project"` - ToolName string `json:"tool_name"` - ToolInput interface{} `json:"tool_input"` - ToolResponse interface{} `json:"tool_response"` - CWD string `json:"cwd"` -} - -// handleObservation handles observation posting from post-tool-use hook. -func (s *Service) handleObservation(w http.ResponseWriter, r *http.Request) { - var req ObservationRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - // Find session - sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if sess == nil { - // Create session on-the-fly with project from request - id, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "") - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - sess, _ = s.sessionStore.GetSessionByID(r.Context(), id) - } - - // Queue observation - if err := s.sessionManager.QueueObservation(r.Context(), sess.ID, session.ObservationData{ - ToolName: req.ToolName, - ToolInput: req.ToolInput, - ToolResponse: req.ToolResponse, - CWD: req.CWD, - }); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - s.broadcastProcessingStatus() - w.WriteHeader(http.StatusOK) -} - -// SubagentCompleteRequest is the request body for subagent completion. -type SubagentCompleteRequest struct { - ClaudeSessionID string `json:"claudeSessionId"` - Project string `json:"project"` -} - -// handleSubagentComplete handles subagent/Task completion notifications. -// This triggers immediate processing of any queued observations from the subagent. -func (s *Service) handleSubagentComplete(w http.ResponseWriter, r *http.Request) { - var req SubagentCompleteRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - // Find session - sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID) - if err != nil || sess == nil { - // Session not found - subagent may have been in a different context - log.Debug(). - Str("claudeSessionId", req.ClaudeSessionID). - Msg("Subagent complete - no active session found") - w.WriteHeader(http.StatusOK) - return - } - - // Trigger immediate processing of queued observations - messages := s.sessionManager.DrainMessages(sess.ID) - if len(messages) > 0 && s.processor != nil { - log.Info(). - Int64("sessionId", sess.ID). - Int("messages", len(messages)). - Msg("Processing queued observations from subagent") - - for _, msg := range messages { - if msg.Type == session.MessageTypeObservation && msg.Observation != nil { - err := s.processor.ProcessObservation( - r.Context(), - sess.SDKSessionID.String, - sess.Project, - msg.Observation.ToolName, - msg.Observation.ToolInput, - msg.Observation.ToolResponse, - msg.Observation.PromptNumber, - msg.Observation.CWD, - ) - if err != nil { - log.Error().Err(err). - Str("tool", msg.Observation.ToolName). - Msg("Failed to process subagent observation") - } - } - } - } - - s.broadcastProcessingStatus() - w.WriteHeader(http.StatusOK) -} - -// handleGetSessionByClaudeID looks up a session by Claude session ID. -func (s *Service) handleGetSessionByClaudeID(w http.ResponseWriter, r *http.Request) { - claudeSessionID := r.URL.Query().Get("claudeSessionId") - if claudeSessionID == "" { - http.Error(w, "claudeSessionId required", http.StatusBadRequest) - return - } - - session, err := s.sessionStore.FindAnySDKSession(r.Context(), claudeSessionID) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if session == nil { - http.Error(w, "session not found", http.StatusNotFound) - return - } - - writeJSON(w, session) -} - -// SummarizeRequest is the request body for summarize requests. -type SummarizeRequest struct { - LastUserMessage string `json:"lastUserMessage"` - LastAssistantMessage string `json:"lastAssistantMessage"` -} - -// handleSummarize handles summarize requests from stop hook. -func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) { - idStr := chi.URLParam(r, "id") - id, err := strconv.ParseInt(idStr, 10, 64) - if err != nil { - http.Error(w, "invalid session id", http.StatusBadRequest) - return - } - - var req SummarizeRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - // Queue summarize request - if err := s.sessionManager.QueueSummarize(r.Context(), id, req.LastUserMessage, req.LastAssistantMessage); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - s.broadcastProcessingStatus() - w.WriteHeader(http.StatusOK) -} - -// handleGetObservations returns recent observations. -// Supports optional query parameter for semantic search via sqlite-vec. -func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) { - limit := gorm.ParseLimitParam(r, DefaultObservationsLimit) - project := r.URL.Query().Get("project") - query := r.URL.Query().Get("query") - - var observations []*models.Observation - var err error - var usedVector bool - - // Use vector search if query is provided and vector client is available - if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { - where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "") - vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where) - if vecErr == nil && len(vectorResults) > 0 { - obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project) - if len(obsIDs) > 0 { - observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit) - if err == nil { - usedVector = true - } - } - } - } - - // Fall back to SQLite if vector search not used - if !usedVector { - if project != "" { - // Strict project filtering for dashboard - only observations from this project - observations, err = s.observationStore.GetObservationsByProjectStrict(r.Context(), project, limit) - } else { - // All projects - observations, err = s.observationStore.GetAllRecentObservations(r.Context(), limit) - } - } - - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Ensure we return empty array, not null - if observations == nil { - observations = []*models.Observation{} - } - writeJSON(w, observations) -} - -// handleGetSummaries returns recent summaries. -// Supports optional query parameter for semantic search via sqlite-vec. -func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) { - limit := gorm.ParseLimitParam(r, DefaultSummariesLimit) - project := r.URL.Query().Get("project") - query := r.URL.Query().Get("query") - - var summaries []*models.SessionSummary - var err error - var usedVector bool - - // Use vector search if query is provided and vector client is available - if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { - where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "") - vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where) - if vecErr == nil && len(vectorResults) > 0 { - summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project) - if len(summaryIDs) > 0 { - summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit) - if err == nil { - usedVector = true - } - } - } - } - - // Fall back to SQLite if vector search not used - if !usedVector { - if project != "" { - summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit) - } else { - summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit) - } - } - - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Ensure we return empty array, not null - if summaries == nil { - summaries = []*models.SessionSummary{} - } - writeJSON(w, summaries) -} - -// handleGetPrompts returns recent user prompts. -// Supports optional query parameter for semantic search via sqlite-vec. -func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) { - limit := gorm.ParseLimitParam(r, DefaultPromptsLimit) - project := r.URL.Query().Get("project") - query := r.URL.Query().Get("query") - - var prompts []*models.UserPromptWithSession - var err error - var usedVector bool - - // Use vector search if query is provided and vector client is available - if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { - where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "") - vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where) - if vecErr == nil && len(vectorResults) > 0 { - promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project) - if len(promptIDs) > 0 { - prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit) - if err == nil { - usedVector = true - } - } - } - } - - // Fall back to SQLite if vector search not used - if !usedVector { - if project != "" { - prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit) - } else { - prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit) - } - } - - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Ensure we return empty array, not null - if prompts == nil { - prompts = []*models.UserPromptWithSession{} - } - writeJSON(w, prompts) -} - -// handleGetProjects returns all projects. -func (s *Service) handleGetProjects(w http.ResponseWriter, r *http.Request) { - projects, err := s.sessionStore.GetAllProjects(r.Context()) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - writeJSON(w, projects) -} - -// handleGetTypes returns the canonical list of observation and concept types. -// This provides a single source of truth for both backend and frontend. -func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) { - writeJSON(w, map[string]interface{}{ - "observation_types": ObservationTypes, - "concept_types": ConceptTypes, - }) -} - -// handleGetModels returns available embedding models. -func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) { - models := embedding.ListModels() - defaultModel := embedding.GetDefaultModel() - - writeJSON(w, map[string]interface{}{ - "models": models, - "default": defaultModel, - "current": s.embedSvc.Version(), - }) -} - -// handleGetStats returns worker statistics. -func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - retrievalStats := s.GetRetrievalStats(project) - sessionsToday, _ := s.sessionStore.GetSessionsToday(r.Context()) - - response := map[string]interface{}{ - "uptime": time.Since(s.startTime).String(), - "activeSessions": s.sessionManager.GetActiveSessionCount(), - "queueDepth": s.sessionManager.GetTotalQueueDepth(), - "isProcessing": s.sessionManager.IsAnySessionProcessing(), - "connectedClients": s.sseBroadcaster.ClientCount(), - "sessionsToday": sessionsToday, - "retrieval": retrievalStats, - "ready": s.ready.Load(), - } - - // Add embedding model info - if s.embedSvc != nil { - response["embeddingModel"] = map[string]interface{}{ - "name": s.embedSvc.Name(), - "version": s.embedSvc.Version(), - "dimensions": s.embedSvc.Dimensions(), - } - } - - // Add vector count - if s.vectorClient != nil { - if count, err := s.vectorClient.Count(r.Context()); err == nil { - response["vectorCount"] = count - } - } - - // Include project-specific observation count if project is specified - if project != "" { - count, err := s.observationStore.GetObservationCount(r.Context(), project) - if err == nil { - response["projectObservations"] = count - response["project"] = project - } - } - - writeJSON(w, response) -} - -// handleGetRetrievalStats returns detailed retrieval statistics. -func (s *Service) handleGetRetrievalStats(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - stats := s.GetRetrievalStats(project) - writeJSON(w, stats) -} - -// handleContextCount returns the count of observations for a project. -func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - if project == "" { - http.Error(w, "project required", http.StatusBadRequest) - return - } - - count, err := s.observationStore.GetObservationCount(r.Context(), project) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - writeJSON(w, map[string]interface{}{ - "project": project, - "count": count, - }) -} - -// handleSearchByPrompt searches observations relevant to a user prompt. -// IMPORTANT: This is on the critical startup path - must be fast! -// No synchronous verification - just filter by staleness and return. -func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - query := r.URL.Query().Get("query") - cwd := r.URL.Query().Get("cwd") - - if project == "" || query == "" { - http.Error(w, "project and query required", http.StatusBadRequest) - return - } - - limit := gorm.ParseLimitParam(r, DefaultSearchLimit) - - var observations []*models.Observation - var err error - var usedVector bool - similarityScores := make(map[int64]float64) // Track similarity per observation - - // Get threshold settings from config - threshold := s.config.ContextRelevanceThreshold - maxResults := s.config.ContextMaxPromptResults - - // Generate expanded queries if query expander is available - var expandedQueries []expansion.ExpandedQuery - var detectedIntent string - if s.queryExpander != nil { - cfg := expansion.DefaultConfig() - cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional - expandedQueries = s.queryExpander.Expand(r.Context(), query, cfg) - if len(expandedQueries) > 0 { - detectedIntent = string(expandedQueries[0].Intent) - } - } - if len(expandedQueries) == 0 { - // Fallback to just the original query - expandedQueries = []expansion.ExpandedQuery{ - {Query: query, Weight: 1.0, Source: "original"}, - } - } - - // Try vector search first if available - if s.vectorClient != nil && s.vectorClient.IsConnected() { - where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "") - - // Search with each expanded query and merge results - allVectorResults := make([]sqlitevec.QueryResult, 0) - queryWeights := make(map[string]float64) // Track weights for score merging - - for _, eq := range expandedQueries { - vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where) - if vecErr == nil && len(vectorResults) > 0 { - // Apply weight to similarity scores before merging - for i := range vectorResults { - vectorResults[i].Similarity *= eq.Weight - } - allVectorResults = append(allVectorResults, vectorResults...) - queryWeights[eq.Query] = eq.Weight - } - } - - if len(allVectorResults) > 0 { - // Filter by relevance threshold before extracting IDs - // Use a slightly lower threshold for expanded queries - effectiveThreshold := threshold * 0.9 // Allow slightly lower scores for expanded queries - filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0) - - // Build similarity map for filtered results (keeping highest weighted score per observation) - for _, vr := range filteredResults { - if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok { - id := int64(sqliteID) - // Keep the highest score for each observation - if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing { - similarityScores[id] = vr.Similarity - } - } - } - - // Extract observation IDs with project/scope filtering using shared helper - obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project) - - if len(obsIDs) > 0 { - // Fetch full observations from SQLite - observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit) - if err == nil { - usedVector = true - } - } - } - } - - // Fall back to FTS if vector search not available or returned no results - if !usedVector || len(observations) == 0 { - observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit) - if err != nil { - // FTS might fail if query has special chars, try without - log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent") - observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - } - - // Fast staleness filter - NO verification (that's too slow for interactive use) - // Just check mtimes and exclude obviously stale observations - var staleCount int - freshObservations := make([]*models.Observation, 0, len(observations)) - - for _, obs := range observations { - if len(obs.FileMtimes) > 0 && cwd != "" { - var paths []string - for path := range obs.FileMtimes { - paths = append(paths, path) - } - currentMtimes := sdk.GetFileMtimes(paths, cwd) - - if obs.CheckStaleness(currentMtimes) { - // Stale - exclude but don't verify (too slow) - // Queue for background verification instead - staleCount++ - s.queueStaleVerification(obs.ID, cwd) - continue - } - } - freshObservations = append(freshObservations, obs) - } - - // Apply cross-encoder reranking if available - var reranked bool - if s.reranker != nil && len(freshObservations) > 0 && usedVector { - // Build candidates from observations with their bi-encoder scores - candidates := make([]reranking.Candidate, len(freshObservations)) - for i, obs := range freshObservations { - content := obs.Title.String - if obs.Narrative.Valid && obs.Narrative.String != "" { - content = content + " " + obs.Narrative.String - } - candidates[i] = reranking.Candidate{ - ID: fmt.Sprintf("%d", obs.ID), - Content: content, - Score: similarityScores[obs.ID], - Metadata: map[string]any{"obs_idx": i}, - } - } - - // Rerank using cross-encoder - use pure mode or combined scores - var rerankResults []reranking.RerankResult - var rerankErr error - if s.config.RerankingPureMode { - rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults) - } else { - rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults) - } - if rerankErr != nil { - log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order") - } else if len(rerankResults) > 0 { - // Update similarity scores with reranked scores - for _, rr := range rerankResults { - if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil { - similarityScores[id] = rr.CombinedScore - } - } - - // Reorder observations based on rerank results - reorderedObs := make([]*models.Observation, 0, len(rerankResults)) - obsMap := make(map[int64]*models.Observation) - for _, obs := range freshObservations { - obsMap[obs.ID] = obs - } - for _, rr := range rerankResults { - if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil { - if obs, ok := obsMap[id]; ok { - reorderedObs = append(reorderedObs, obs) - } - } - } - freshObservations = reorderedObs - reranked = true - - log.Debug(). - Int("candidates", len(candidates)). - Int("returned", len(rerankResults)). - Msg("Cross-encoder reranking complete") - } - } - - // Cluster similar observations to remove duplicates - clusteredObservations := clusterObservations(freshObservations, 0.4) - - // Sort by similarity score (highest first) if we have scores and didn't rerank - if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked { - sort.Slice(clusteredObservations, func(i, j int) bool { - scoreI := similarityScores[clusteredObservations[i].ID] - scoreJ := similarityScores[clusteredObservations[j].ID] - return scoreI > scoreJ - }) - } - - // Apply max results cap if configured - if maxResults > 0 && len(clusteredObservations) > maxResults { - clusteredObservations = clusteredObservations[:maxResults] - } - - // Record retrieval stats (no verification done, so verified=0, deleted=0) - s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, true) - - // Increment retrieval counts for scoring (async, non-blocking) - if len(clusteredObservations) > 0 { - ids := make([]int64, len(clusteredObservations)) - for i, obs := range clusteredObservations { - ids[i] = obs.ID - } - s.incrementRetrievalCounts(ids) - } - - log.Info(). - Str("project", project). - Str("query", query). - Str("intent", detectedIntent). - Int("expansions", len(expandedQueries)). - Int("found", len(clusteredObservations)). - Int("stale_excluded", staleCount). - Float64("threshold", threshold). - Msg("Prompt-based observation search") - - // Build response with similarity scores - obsWithScores := make([]map[string]interface{}, len(clusteredObservations)) - for i, obs := range clusteredObservations { - obsMap := obs.ToMap() - if score, ok := similarityScores[obs.ID]; ok { - obsMap["similarity"] = score - } - obsWithScores[i] = obsMap - } - - // Build expansion info for response - expansionInfo := make([]map[string]interface{}, len(expandedQueries)) - for i, eq := range expandedQueries { - expansionInfo[i] = map[string]interface{}{ - "query": eq.Query, - "weight": eq.Weight, - "source": eq.Source, - } - } - - writeJSON(w, map[string]interface{}{ - "project": project, - "query": query, - "intent": detectedIntent, - "expansions": expansionInfo, - "observations": obsWithScores, - "threshold": threshold, - "max_results": maxResults, - }) -} - -// handleContextInject returns context for injection at session start. -// IMPORTANT: This is on the critical startup path - must be fast! -// No synchronous verification - just filter by staleness and return. -func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - if project == "" { - http.Error(w, "project required", http.StatusBadRequest) - return - } - - cwd := r.URL.Query().Get("cwd") - if cwd == "" { - cwd = "/" - } - - // Limit observations for fast startup (configurable, default 100) - limit := s.config.ContextObservations - if limit <= 0 { - limit = DefaultContextLimit - } - - // Full count determines how many observations get full detail (configurable, default 25) - fullCount := s.config.ContextFullCount - if fullCount <= 0 { - fullCount = 25 - } - - // Get recent observations - observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Fast staleness filter - NO verification (that's too slow for startup) - var staleCount int - freshObservations := make([]*models.Observation, 0, len(observations)) - - for _, obs := range observations { - if len(obs.FileMtimes) > 0 { - var paths []string - for path := range obs.FileMtimes { - paths = append(paths, path) - } - currentMtimes := sdk.GetFileMtimes(paths, cwd) - - if obs.CheckStaleness(currentMtimes) { - // Stale - exclude but don't verify (too slow) - // Queue for background verification instead - staleCount++ - s.queueStaleVerification(obs.ID, cwd) - continue - } - } - freshObservations = append(freshObservations, obs) - } - - // Cluster similar observations to remove duplicates - clusteredObservations := clusterObservations(freshObservations, 0.4) - duplicatesRemoved := len(freshObservations) - len(clusteredObservations) - - // Record retrieval stats (no verification done) - s.recordRetrievalStats(project, int64(len(clusteredObservations)), 0, 0, false) - - // Increment retrieval counts for scoring (async, non-blocking) - if len(clusteredObservations) > 0 { - ids := make([]int64, len(clusteredObservations)) - for i, obs := range clusteredObservations { - ids[i] = obs.ID - } - s.incrementRetrievalCounts(ids) - } - - log.Info(). - Str("project", project). - Int("total", len(observations)). - Int("fresh", len(freshObservations)). - Int("clustered", len(clusteredObservations)). - Int("duplicates", duplicatesRemoved). - Int("stale_excluded", staleCount). - Msg("Context injection with clustering") - - writeJSON(w, map[string]interface{}{ - "project": project, - "observations": clusteredObservations, - "full_count": fullCount, - "stale_excluded": staleCount, - "duplicates_removed": duplicatesRemoved, - }) -} - -// handleUpdateCheck checks for available updates. -func (s *Service) handleUpdateCheck(w http.ResponseWriter, r *http.Request) { - info, err := s.updater.CheckForUpdate(r.Context()) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - writeJSON(w, info) -} - -// handleUpdateApply downloads and applies an available update. -func (s *Service) handleUpdateApply(w http.ResponseWriter, r *http.Request) { - // First check for update - info, err := s.updater.CheckForUpdate(r.Context()) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - if !info.Available { - writeJSON(w, map[string]interface{}{ - "success": false, - "message": "No update available", - }) - return - } - - // Apply update in background - go func() { - if err := s.updater.ApplyUpdate(s.ctx, info); err != nil { - log.Error().Err(err).Msg("Update failed") - } - }() - - writeJSON(w, map[string]interface{}{ - "success": true, - "message": "Update started", - "version": info.LatestVersion, - }) -} - -// handleUpdateStatus returns the current update status. -func (s *Service) handleUpdateStatus(w http.ResponseWriter, r *http.Request) { - status := s.updater.GetStatus() - writeJSON(w, status) -} - -// ComponentHealth represents the health status of a single component. -type ComponentHealth struct { - Name string `json:"name"` - Status string `json:"status"` // "healthy", "degraded", "unhealthy" - Message string `json:"message,omitempty"` -} - -// SelfCheckResponse contains the health status of all components. -type SelfCheckResponse struct { - Overall string `json:"overall"` // "healthy", "degraded", "unhealthy" - Version string `json:"version"` - Uptime string `json:"uptime"` - Components []ComponentHealth `json:"components"` -} - -// handleSelfCheck returns the health status of all components. -func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) { - components := []ComponentHealth{} - overall := "healthy" - - // Check Worker Service - workerStatus := ComponentHealth{Name: "Worker Service", Status: "healthy"} - if !s.ready.Load() { - if err := s.GetInitError(); err != nil { - workerStatus.Status = "unhealthy" - workerStatus.Message = err.Error() - overall = "unhealthy" - } else { - workerStatus.Status = "degraded" - workerStatus.Message = "Initializing" - if overall == "healthy" { - overall = "degraded" - } - } - } - components = append(components, workerStatus) - - // Check SQLite Database - dbStatus := ComponentHealth{Name: "SQLite Database", Status: "healthy"} - if s.store == nil { - dbStatus.Status = "unhealthy" - dbStatus.Message = "Not initialized" - overall = "unhealthy" - } else if err := s.store.Ping(); err != nil { - dbStatus.Status = "unhealthy" - dbStatus.Message = err.Error() - overall = "unhealthy" - } - components = append(components, dbStatus) - - // Check Vector DB (sqlite-vec) - vectorStatus := ComponentHealth{Name: "Vector DB", Status: "healthy"} - if s.vectorClient == nil { - vectorStatus.Status = "degraded" - vectorStatus.Message = "Not configured" - if overall == "healthy" { - overall = "degraded" - } - } else if !s.vectorClient.IsConnected() { - vectorStatus.Status = "degraded" - vectorStatus.Message = "Not connected" - if overall == "healthy" { - overall = "degraded" - } - } - components = append(components, vectorStatus) - - // Check SDK Processor - sdkStatus := ComponentHealth{Name: "SDK Processor", Status: "healthy"} - if s.processor == nil { - sdkStatus.Status = "degraded" - sdkStatus.Message = "Not initialized" - if overall == "healthy" { - overall = "degraded" - } - } else if !s.processor.IsAvailable() { - sdkStatus.Status = "degraded" - sdkStatus.Message = "Claude CLI not available" - if overall == "healthy" { - overall = "degraded" - } - } - components = append(components, sdkStatus) - - // Check SSE Broadcaster - sseStatus := ComponentHealth{Name: "SSE Broadcaster", Status: "healthy"} - if s.sseBroadcaster == nil { - sseStatus.Status = "unhealthy" - sseStatus.Message = "Not initialized" - overall = "unhealthy" - } - components = append(components, sseStatus) - - // Check Cross-Encoder Reranker - rerankerStatus := ComponentHealth{Name: "Cross-Encoder Reranker", Status: "healthy"} - if !s.config.RerankingEnabled { - rerankerStatus.Status = "degraded" - rerankerStatus.Message = "Disabled in config" - if overall == "healthy" { - overall = "degraded" - } - } else if s.reranker == nil { - rerankerStatus.Status = "degraded" - rerankerStatus.Message = "Not initialized" - if overall == "healthy" { - overall = "degraded" - } - } else { - // Verify reranker is functional using Score - _, normalizedScore, err := s.reranker.Score("test query", "test document") - if err != nil { - rerankerStatus.Status = "unhealthy" - rerankerStatus.Message = fmt.Sprintf("Score check failed: %v", err) - if overall == "healthy" { - overall = "degraded" - } - } else { - rerankerStatus.Message = fmt.Sprintf("Score check passed (%.4f)", normalizedScore) - } - } - components = append(components, rerankerStatus) - - // Calculate uptime - uptime := time.Since(s.startTime).Round(time.Second).String() - - writeJSON(w, SelfCheckResponse{ - Overall: overall, - Version: s.version, - Uptime: uptime, - Components: components, - }) -} - -// handleUpdateRestart restarts the worker with the new binary (after update). -func (s *Service) handleUpdateRestart(w http.ResponseWriter, r *http.Request) { - status := s.updater.GetStatus() - if status.State != "done" { - http.Error(w, "no update has been applied", http.StatusBadRequest) - return - } - - // Send response before restarting - writeJSON(w, map[string]interface{}{ - "success": true, - "message": "Restarting worker...", - }) - - // Flush the response - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - - // Restart in background after response is sent - go func() { - if err := s.updater.Restart(); err != nil { - log.Error().Err(err).Msg("Failed to restart worker") - } - }() -} - -// handleRestart restarts the worker process (general restart, not tied to update). -func (s *Service) handleRestart(w http.ResponseWriter, r *http.Request) { - log.Info().Msg("Manual restart requested via API") - - // Send response before restarting - writeJSON(w, map[string]interface{}{ - "success": true, - "message": "Restarting worker...", - "version": s.version, - }) - - // Flush the response - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - - // Restart in background after response is sent - go func() { - // Small delay to ensure response is sent - time.Sleep(100 * time.Millisecond) - if err := s.updater.Restart(); err != nil { - log.Error().Err(err).Msg("Failed to restart worker") - } - }() -} diff --git a/internal/worker/handlers_context.go b/internal/worker/handlers_context.go new file mode 100644 index 0000000..c69e905 --- /dev/null +++ b/internal/worker/handlers_context.go @@ -0,0 +1,677 @@ +// Package worker provides context and search-related HTTP handlers. +package worker + +import ( + "context" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" + "github.com/lukaszraczylo/claude-mnemonic/internal/reranking" + "github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion" + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" + "github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk" + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" +) + +// handleSearchByPrompt searches observations relevant to a user prompt. +// IMPORTANT: This is on the critical startup path - must be fast! +// No synchronous verification - just filter by staleness and return. +func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + query := r.URL.Query().Get("query") + cwd := r.URL.Query().Get("cwd") + + if project == "" || query == "" { + http.Error(w, "project and query required", http.StatusBadRequest) + return + } + + // Validate project name to prevent path traversal + if err := ValidateProjectName(project); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + limit := gorm.ParseLimitParamWithMax(r, DefaultSearchLimit, 200) + + var observations []*models.Observation + var err error + var usedVector bool + similarityScores := make(map[int64]float64) // Track similarity per observation + + // Get threshold settings from config + threshold := s.config.ContextRelevanceThreshold + maxResults := s.config.ContextMaxPromptResults + + // Generate expanded queries if query expander is available + // Use timeout context to prevent query expansion from blocking + var expandedQueries []expansion.ExpandedQuery + var detectedIntent string + if s.queryExpander != nil { + expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second) + cfg := expansion.DefaultConfig() + cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional + expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg) + expandCancel() // Cancel immediately after use (defer not needed - no panic possible between creation and here) + if len(expandedQueries) > 0 { + detectedIntent = string(expandedQueries[0].Intent) + } + } + if len(expandedQueries) == 0 { + // Fallback to just the original query + expandedQueries = []expansion.ExpandedQuery{ + {Query: query, Weight: 1.0, Source: "original"}, + } + } + + // Try vector search first if available + var vectorSearchFailed bool + if s.vectorClient != nil && s.vectorClient.IsConnected() { + where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "") + + // Search with each expanded query and merge results + // Pre-allocate with estimated capacity to avoid repeated reallocation + estimatedCapacity := len(expandedQueries) * limit * 2 + allVectorResults := make([]sqlitevec.QueryResult, 0, estimatedCapacity) + queryWeights := make(map[string]float64, len(expandedQueries)) + var vectorErrors int + + for _, eq := range expandedQueries { + vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where) + if vecErr != nil { + vectorErrors++ + log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed") + } else if len(vectorResults) > 0 { + // Apply weight to similarity scores before merging + for i := range vectorResults { + vectorResults[i].Similarity *= eq.Weight + } + allVectorResults = append(allVectorResults, vectorResults...) + queryWeights[eq.Query] = eq.Weight + } + } + + // Track if vector search had issues + if vectorErrors > 0 && vectorErrors == len(expandedQueries) { + vectorSearchFailed = true + log.Warn().Int("errors", vectorErrors).Str("project", project).Msg("All vector queries failed, falling back to FTS") + } + + if len(allVectorResults) > 0 { + // Filter by relevance threshold before extracting IDs + // Use a slightly lower threshold for expanded queries + effectiveThreshold := threshold * 0.9 // Allow slightly lower scores for expanded queries + filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0) + + // Build similarity map for filtered results (keeping highest weighted score per observation) + for _, vr := range filteredResults { + if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok { + id := int64(sqliteID) + // Keep the highest score for each observation + if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing { + similarityScores[id] = vr.Similarity + } + } + } + + // Extract observation IDs with project/scope filtering using shared helper + obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project) + + if len(obsIDs) > 0 { + // Fetch full observations from SQLite + observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit) + if err == nil { + usedVector = true + } + } + } + } + + // Fall back to FTS if vector search not available, failed, or returned no results + if !usedVector || len(observations) == 0 { + if vectorSearchFailed { + log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure") + } + observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit) + if err != nil { + // FTS might fail if query has special chars, try without + log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent") + observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + } + + // Fast staleness filter - NO verification (that's too slow for interactive use) + // Just check mtimes and exclude obviously stale observations + var staleCount int + freshObservations := make([]*models.Observation, 0, len(observations)) + + for _, obs := range observations { + if len(obs.FileMtimes) > 0 && cwd != "" { + var paths []string + for path := range obs.FileMtimes { + paths = append(paths, path) + } + currentMtimes := sdk.GetFileMtimes(paths, cwd) + + if obs.CheckStaleness(currentMtimes) { + // Stale - exclude but don't verify (too slow) + // Queue for background verification instead + staleCount++ + s.queueStaleVerification(obs.ID, cwd) + continue + } + } + freshObservations = append(freshObservations, obs) + } + + // Apply cross-encoder reranking if available + var reranked bool + if s.reranker != nil && len(freshObservations) > 0 && usedVector { + // Build candidates from observations with their bi-encoder scores + candidates := make([]reranking.Candidate, len(freshObservations)) + for i, obs := range freshObservations { + // Use strings.Builder for efficient concatenation + var content string + if obs.Narrative.Valid && obs.Narrative.String != "" { + var sb strings.Builder + sb.Grow(len(obs.Title.String) + 1 + len(obs.Narrative.String)) + sb.WriteString(obs.Title.String) + sb.WriteByte(' ') + sb.WriteString(obs.Narrative.String) + content = sb.String() + } else { + content = obs.Title.String + } + candidates[i] = reranking.Candidate{ + ID: strconv.FormatInt(obs.ID, 10), // Faster than fmt.Sprintf + Content: content, + Score: similarityScores[obs.ID], + Metadata: map[string]any{"obs_idx": i}, + } + } + + // Rerank using cross-encoder - use pure mode or combined scores + var rerankResults []reranking.RerankResult + var rerankErr error + if s.config.RerankingPureMode { + rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults) + } else { + rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults) + } + if rerankErr != nil { + log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order") + } else if len(rerankResults) > 0 { + // Update similarity scores with reranked scores + for _, rr := range rerankResults { + if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil { + similarityScores[id] = rr.CombinedScore + } + } + + // Reorder observations based on rerank results + reorderedObs := make([]*models.Observation, 0, len(rerankResults)) + obsMap := make(map[int64]*models.Observation) + for _, obs := range freshObservations { + obsMap[obs.ID] = obs + } + for _, rr := range rerankResults { + if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil { + if obs, ok := obsMap[id]; ok { + reorderedObs = append(reorderedObs, obs) + } + } + } + freshObservations = reorderedObs + reranked = true + + log.Debug(). + Int("candidates", len(candidates)). + Int("returned", len(rerankResults)). + Msg("Cross-encoder reranking complete") + } + } + + // Cluster similar observations to remove duplicates + clusteredObservations := clusterObservations(freshObservations, 0.4) + duplicatesRemoved := len(freshObservations) - len(clusteredObservations) + + // Sort by similarity score (highest first) if we have scores and didn't rerank + if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked { + sort.Slice(clusteredObservations, func(i, j int) bool { + scoreI := similarityScores[clusteredObservations[i].ID] + scoreJ := similarityScores[clusteredObservations[j].ID] + return scoreI > scoreJ + }) + } + + // Apply max results cap if configured + if maxResults > 0 && len(clusteredObservations) > maxResults { + clusteredObservations = clusteredObservations[:maxResults] + } + + // Record retrieval stats with staleness metrics + s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0, + int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), true) + + // Increment retrieval counts for scoring (async, non-blocking) + if len(clusteredObservations) > 0 { + ids := make([]int64, len(clusteredObservations)) + for i, obs := range clusteredObservations { + ids[i] = obs.ID + } + s.incrementRetrievalCounts(ids) + } + + log.Info(). + Str("project", project). + Str("query", query). + Str("intent", detectedIntent). + Int("expansions", len(expandedQueries)). + Int("found", len(clusteredObservations)). + Int("stale_excluded", staleCount). + Float64("threshold", threshold). + Msg("Prompt-based observation search") + + // Build response with similarity scores + obsWithScores := make([]map[string]any, len(clusteredObservations)) + for i, obs := range clusteredObservations { + obsMap := obs.ToMap() + if score, ok := similarityScores[obs.ID]; ok { + obsMap["similarity"] = score + } + obsWithScores[i] = obsMap + } + + // Build expansion info for response + expansionInfo := make([]map[string]any, len(expandedQueries)) + for i, eq := range expandedQueries { + expansionInfo[i] = map[string]any{ + "query": eq.Query, + "weight": eq.Weight, + "source": eq.Source, + } + } + + // Track this search for analytics + s.trackSearchQuery(query, project, "observations", len(clusteredObservations), usedVector) + + writeJSON(w, map[string]any{ + "project": project, + "query": query, + "intent": detectedIntent, + "expansions": expansionInfo, + "observations": obsWithScores, + "threshold": threshold, + "max_results": maxResults, + }) +} + +// handleFileContext returns observations relevant to specific files being worked on. +// Uses vector similarity search to find observations that mention or relate to the given files. +func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + if project == "" { + http.Error(w, "project required", http.StatusBadRequest) + return + } + + filesParam := r.URL.Query().Get("files") + if filesParam == "" { + http.Error(w, "files required", http.StatusBadRequest) + return + } + + // Parse comma-separated file paths + files := strings.Split(filesParam, ",") + if len(files) == 0 { + http.Error(w, "at least one file required", http.StatusBadRequest) + return + } + + // Limit to reasonable number of files + maxFiles := 20 + if len(files) > maxFiles { + files = files[:maxFiles] + } + + // Get limit parameter (default 10 per file) + limitStr := r.URL.Query().Get("limit") + limit := 10 + if limitStr != "" { + if parsed, err := strconv.Atoi(limitStr); err == nil && parsed > 0 && parsed <= 50 { + limit = parsed + } + } + + // Search for observations related to each file in parallel + ctx := r.Context() + + // Check if vector search is available + if s.vectorClient == nil || !s.vectorClient.IsConnected() { + writeJSON(w, map[string]any{ + "files": files, + "results": map[string]any{}, + "count": 0, + "error": "vector search not available", + }) + return + } + + // Prepare for parallel execution + type fileResult struct { + file string + results []map[string]any + obsIDs []int64 // Track observation IDs for deduplication + } + + resultsChan := make(chan fileResult, len(files)) + sem := make(chan struct{}, 5) // Limit concurrency to 5 parallel searches + var wg sync.WaitGroup + + for _, file := range files { + file = strings.TrimSpace(file) + if file == "" { + continue + } + + wg.Add(1) + go func(file string) { + defer wg.Done() + sem <- struct{}{} // Acquire semaphore + defer func() { <-sem }() // Release semaphore + + // Build search query from file path + query := buildFileQuery(file) + + where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "") + vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where) + if vecErr != nil { + log.Warn().Err(vecErr).Str("file", file).Msg("Vector search failed for file context") + return + } + + // Extract observation IDs from vector results + obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project) + if len(obsIDs) == 0 { + return + } + + // Fetch observations + observations, err := s.observationStore.GetObservationsByIDs(ctx, obsIDs, "score_desc", limit*2) + if err != nil { + log.Warn().Err(err).Str("file", file).Msg("Failed to fetch observations for file context") + return + } + + // Pre-build score map from vector results (O(n) instead of O(n²)) + scoreMap := make(map[int64]float64, len(vectorResults)) + var avgScore float64 + for _, vr := range vectorResults { + avgScore += vr.Similarity + // Parse observation ID from vector result ID (format: obs_{id}_{field}) + // Use index-based parsing to avoid slice allocation from strings.Split + if len(vr.ID) > 4 && vr.ID[:4] == "obs_" { + rest := vr.ID[4:] // Skip "obs_" + underscoreIdx := strings.IndexByte(rest, '_') + var idStr string + if underscoreIdx >= 0 { + idStr = rest[:underscoreIdx] + } else { + idStr = rest + } + if id, parseErr := strconv.ParseInt(idStr, 10, 64); parseErr == nil { + // Keep highest score for each observation + if existing, exists := scoreMap[id]; !exists || vr.Similarity > existing { + scoreMap[id] = vr.Similarity + } + } + } + } + if len(vectorResults) > 0 { + avgScore /= float64(len(vectorResults)) + } + + fileResults := make([]map[string]any, 0, limit) + var usedIDs []int64 + for _, obs := range observations { + // Check project scope + if obs.Scope == "project" && obs.Project != project { + continue + } + + // O(1) score lookup instead of O(n) nested loop + score, found := scoreMap[obs.ID] + if !found { + // Use average score as fallback + score = avgScore + } + + // Only include if score is above threshold + if score < 0.3 { + continue + } + + fileResults = append(fileResults, map[string]any{ + "id": obs.ID, + "title": obs.Title.String, + "type": obs.Type, + "narrative": obs.Narrative.String, + "facts": obs.Facts, + "score": score, + }) + usedIDs = append(usedIDs, obs.ID) + + if len(fileResults) >= limit { + break + } + } + + if len(fileResults) > 0 { + resultsChan <- fileResult{file: file, results: fileResults, obsIDs: usedIDs} + } + }(file) + } + + // Close channel when all goroutines complete + go func() { + wg.Wait() + close(resultsChan) + }() + + // Collect results and deduplicate + allResults := make(map[string]any) + seenObservationIDs := make(map[int64]bool) + + for res := range resultsChan { + // Filter out duplicates that were already seen in other files + dedupedResults := make([]map[string]any, 0, len(res.results)) + for i, r := range res.results { + obsID := res.obsIDs[i] + if !seenObservationIDs[obsID] { + seenObservationIDs[obsID] = true + dedupedResults = append(dedupedResults, r) + } + } + if len(dedupedResults) > 0 { + allResults[res.file] = dedupedResults + } + } + + writeJSON(w, map[string]any{ + "files": files, + "results": allResults, + "count": len(allResults), + }) +} + +// buildFileQuery extracts meaningful search terms from a file path. +func buildFileQuery(filePath string) string { + // Remove common prefixes and extensions + path := strings.TrimPrefix(filePath, "/") + + // Extract the filename and directory + parts := strings.Split(path, "/") + meaningful := make([]string, 0, len(parts)) + + for _, part := range parts { + // Skip common directory names that aren't meaningful + switch strings.ToLower(part) { + case "src", "lib", "internal", "pkg", "cmd", "api", "app", "test", "tests", "spec", "specs": + continue + default: + // Remove file extension + if idx := strings.LastIndex(part, "."); idx > 0 { + part = part[:idx] + } + // Convert camelCase/PascalCase to spaces + part = splitCamelCase(part) + // Convert snake_case to spaces + part = strings.ReplaceAll(part, "_", " ") + // Convert kebab-case to spaces + part = strings.ReplaceAll(part, "-", " ") + meaningful = append(meaningful, part) + } + } + + return strings.Join(meaningful, " ") +} + +// splitCamelCase splits camelCase or PascalCase into separate words. +func splitCamelCase(s string) string { + var result strings.Builder + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + result.WriteRune(' ') + } + result.WriteRune(r) + } + return result.String() +} + +// handleContextInject returns context for injection at session start. +// IMPORTANT: This is on the critical startup path - must be fast! +// No synchronous verification - just filter by staleness and return. +func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + if project == "" { + http.Error(w, "project required", http.StatusBadRequest) + return + } + + // Validate project name to prevent path traversal + if err := ValidateProjectName(project); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + cwd := r.URL.Query().Get("cwd") + if cwd == "" { + cwd = "/" + } + + // Limit observations for fast startup (configurable, default 100) + limit := s.config.ContextObservations + if limit <= 0 { + limit = DefaultContextLimit + } + + // Full count determines how many observations get full detail (configurable, default 25) + fullCount := s.config.ContextFullCount + if fullCount <= 0 { + fullCount = 25 + } + + // Get recent observations + observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Fast staleness filter - NO verification (that's too slow for startup) + var staleCount int + freshObservations := make([]*models.Observation, 0, len(observations)) + + for _, obs := range observations { + if len(obs.FileMtimes) > 0 { + var paths []string + for path := range obs.FileMtimes { + paths = append(paths, path) + } + currentMtimes := sdk.GetFileMtimes(paths, cwd) + + if obs.CheckStaleness(currentMtimes) { + // Stale - exclude but don't verify (too slow) + // Queue for background verification instead + staleCount++ + s.queueStaleVerification(obs.ID, cwd) + continue + } + } + freshObservations = append(freshObservations, obs) + } + + // Cluster similar observations to remove duplicates + clusteredObservations := clusterObservations(freshObservations, 0.4) + duplicatesRemoved := len(freshObservations) - len(clusteredObservations) + + // Record retrieval stats with staleness metrics + s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0, + int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), false) + + // Increment retrieval counts for scoring (async, non-blocking) + if len(clusteredObservations) > 0 { + ids := make([]int64, len(clusteredObservations)) + for i, obs := range clusteredObservations { + ids[i] = obs.ID + } + s.incrementRetrievalCounts(ids) + } + + log.Info(). + Str("project", project). + Int("total", len(observations)). + Int("fresh", len(freshObservations)). + Int("clustered", len(clusteredObservations)). + Int("duplicates", duplicatesRemoved). + Int("stale_excluded", staleCount). + Msg("Context injection with clustering") + + writeJSON(w, map[string]any{ + "project": project, + "observations": clusteredObservations, + "full_count": fullCount, + "stale_excluded": staleCount, + "duplicates_removed": duplicatesRemoved, + }) +} + +// handleContextCount returns the count of observations for a project. +func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + if project == "" { + http.Error(w, "project required", http.StatusBadRequest) + return + } + + count, err := s.getCachedObservationCount(r.Context(), project) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + writeJSON(w, map[string]any{ + "project": project, + "count": count, + }) +} diff --git a/internal/worker/handlers_data.go b/internal/worker/handlers_data.go new file mode 100644 index 0000000..a0b0863 --- /dev/null +++ b/internal/worker/handlers_data.go @@ -0,0 +1,595 @@ +// Package worker provides data retrieval HTTP handlers. +package worker + +import ( + "encoding/json" + "net/http" + "runtime" + "sort" + "strings" + "time" + + "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" + "github.com/lukaszraczylo/claude-mnemonic/internal/embedding" + "github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec" + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" +) + +// handleGetObservations returns recent observations. +// Supports optional query parameter for semantic search via sqlite-vec. +// Supports pagination via limit and offset query parameters. +func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) { + pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit) + project := r.URL.Query().Get("project") + query := r.URL.Query().Get("query") + + // Validate project name to prevent path traversal + if err := ValidateProjectName(project); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var observations []*models.Observation + var total int64 + var err error + var usedVector bool + + // Use vector search if query is provided and vector client is available + if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { + where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "") + vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where) + if vecErr == nil && len(vectorResults) > 0 { + obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project) + if len(obsIDs) > 0 { + observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit) + if err == nil { + usedVector = true + total = int64(len(observations)) // Vector search doesn't have total, use returned count + } + } + } + } + + // Fall back to SQLite if vector search not used + if !usedVector { + if project != "" { + // Strict project filtering for dashboard - only observations from this project + observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset) + } else { + // All projects + observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset) + } + } + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Ensure we return empty array, not null + if observations == nil { + observations = []*models.Observation{} + } + + // Track search if query was provided + if query != "" { + s.trackSearchQuery(query, project, "observations", len(observations), usedVector) + } + + // Return paginated response + writeJSON(w, map[string]any{ + "observations": observations, + "total": total, + "limit": pagination.Limit, + "offset": pagination.Offset, + "hasMore": int64(pagination.Offset)+int64(len(observations)) < total, + }) +} + +// handleGetSummaries returns recent summaries. +// Supports optional query parameter for semantic search via sqlite-vec. +func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) { + limit := gorm.ParseLimitParam(r, DefaultSummariesLimit) + project := r.URL.Query().Get("project") + query := r.URL.Query().Get("query") + + // Validate project name to prevent path traversal + if err := ValidateProjectName(project); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var summaries []*models.SessionSummary + var err error + var usedVector bool + + // Use vector search if query is provided and vector client is available + if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { + where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "") + vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where) + if vecErr == nil && len(vectorResults) > 0 { + summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project) + if len(summaryIDs) > 0 { + summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit) + if err == nil { + usedVector = true + } + } + } + } + + // Fall back to SQLite if vector search not used + if !usedVector { + if project != "" { + summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit) + } else { + summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit) + } + } + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Ensure we return empty array, not null + if summaries == nil { + summaries = []*models.SessionSummary{} + } + writeJSON(w, summaries) +} + +// handleGetPrompts returns recent user prompts. +// Supports optional query parameter for semantic search via sqlite-vec. +func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) { + limit := gorm.ParseLimitParam(r, DefaultPromptsLimit) + project := r.URL.Query().Get("project") + query := r.URL.Query().Get("query") + + // Validate project name to prevent path traversal + if err := ValidateProjectName(project); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var prompts []*models.UserPromptWithSession + var err error + var usedVector bool + + // Use vector search if query is provided and vector client is available + if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() { + where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "") + vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where) + if vecErr == nil && len(vectorResults) > 0 { + promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project) + if len(promptIDs) > 0 { + prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit) + if err == nil { + usedVector = true + } + } + } + } + + // Fall back to SQLite if vector search not used + if !usedVector { + if project != "" { + prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit) + } else { + prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit) + } + } + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Ensure we return empty array, not null + if prompts == nil { + prompts = []*models.UserPromptWithSession{} + } + writeJSON(w, prompts) +} + +// handleGetProjects returns all projects. +// Response is cacheable for 5 minutes since project list changes infrequently. +func (s *Service) handleGetProjects(w http.ResponseWriter, r *http.Request) { + projects, err := s.sessionStore.GetAllProjects(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Cache for 5 minutes - project list changes infrequently + w.Header().Set("Cache-Control", "public, max-age=300") + writeJSON(w, projects) +} + +// handleGetTypes returns the canonical list of observation and concept types. +// This provides a single source of truth for both backend and frontend. +// Response is cacheable as these values never change at runtime. +func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) { + // Cache for 24 hours - these values are compile-time constants + w.Header().Set("Cache-Control", "public, max-age=86400") + writeJSON(w, map[string]any{ + "observation_types": ObservationTypes, + "concept_types": ConceptTypes, + }) +} + +// handleGetModels returns available embedding models. +// Response is cacheable as model list doesn't change without restart. +func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) { + // Cache for 1 hour - model list is static during runtime + w.Header().Set("Cache-Control", "public, max-age=3600") + + models := embedding.ListModels() + defaultModel := embedding.GetDefaultModel() + + writeJSON(w, map[string]any{ + "models": models, + "default": defaultModel, + "current": s.embedSvc.Version(), + }) +} + +// handleGetStats returns worker statistics. +func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + + // Validate project name to prevent path traversal + if err := ValidateProjectName(project); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + retrievalStats := s.GetRetrievalStats(project) + sessionsToday, _ := s.sessionStore.GetSessionsToday(r.Context()) + + response := map[string]any{ + "uptime": time.Since(s.startTime).String(), + "uptimeSeconds": time.Since(s.startTime).Seconds(), + "activeSessions": s.sessionManager.GetActiveSessionCount(), + "queueDepth": s.sessionManager.GetTotalQueueDepth(), + "isProcessing": s.sessionManager.IsAnySessionProcessing(), + "connectedClients": s.sseBroadcaster.ClientCount(), + "sessionsToday": sessionsToday, + "retrieval": retrievalStats, + "ready": s.ready.Load(), + } + + // Add memory stats + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + response["memory"] = map[string]any{ + "alloc_mb": float64(memStats.Alloc) / 1024 / 1024, + "total_alloc_mb": float64(memStats.TotalAlloc) / 1024 / 1024, + "sys_mb": float64(memStats.Sys) / 1024 / 1024, + "heap_alloc_mb": float64(memStats.HeapAlloc) / 1024 / 1024, + "heap_inuse_mb": float64(memStats.HeapInuse) / 1024 / 1024, + "heap_objects": memStats.HeapObjects, + "goroutines": runtime.NumGoroutine(), + "gc_cycles": memStats.NumGC, + "gc_pause_total_ms": float64(memStats.PauseTotalNs) / 1e6, + } + + // Add database health if available + if s.store != nil { + dbHealth := s.store.HealthCheck(r.Context()) + response["database"] = map[string]any{ + "status": dbHealth.Status, + "query_latency_ms": float64(dbHealth.QueryLatency) / 1e6, + "pool": dbHealth.PoolStats, + "warning": dbHealth.Warning, + } + } + + // Add embedding model info + if s.embedSvc != nil { + response["embeddingModel"] = map[string]any{ + "name": s.embedSvc.Name(), + "version": s.embedSvc.Version(), + "dimensions": s.embedSvc.Dimensions(), + } + } + + // Add vector cache stats + if s.vectorClient != nil { + if count, err := s.vectorClient.Count(r.Context()); err == nil { + response["vectorCount"] = count + } + cacheSize, cacheMax := s.vectorClient.CacheStats() + response["vectorCache"] = map[string]any{ + "size": cacheSize, + "max_size": cacheMax, + } + } + + // Include project-specific observation count if project is specified + if project != "" { + count, err := s.getCachedObservationCount(r.Context(), project) + if err == nil { + response["projectObservations"] = count + response["project"] = project + } + } + + // Add rate limiter stats + if s.rateLimiter != nil { + response["rateLimiter"] = s.rateLimiter.Stats() + } + + // Add circuit breaker metrics + if s.processor != nil { + response["circuitBreaker"] = s.processor.CircuitBreakerMetrics() + } + + writeJSON(w, response) +} + +// handleGetRetrievalStats returns detailed retrieval statistics. +func (s *Service) handleGetRetrievalStats(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + stats := s.GetRetrievalStats(project) + writeJSON(w, stats) +} + +// handleGetRecentQueries returns recent search queries for analytics. +func (s *Service) handleGetRecentQueries(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + limit := gorm.ParseLimitParam(r, 20) + + queries := s.getRecentSearchQueries(project, limit) + + writeJSON(w, map[string]any{ + "queries": queries, + "count": len(queries), + "project": project, + }) +} + +// handleGetSearchAnalytics returns comprehensive search analytics and statistics. +func (s *Service) handleGetSearchAnalytics(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + + // Get all recent queries for analysis + queries := s.getRecentSearchQueries(project, maxRecentQueries) + + // Calculate analytics + totalQueries := len(queries) + vectorSearches := 0 + totalResults := 0 + zeroResultQueries := 0 + queryTypes := make(map[string]int) + topKeywords := make(map[string]int) + + for _, q := range queries { + if q.UsedVector { + vectorSearches++ + } + totalResults += q.Results + if q.Results == 0 { + zeroResultQueries++ + } + queryTypes[q.Type]++ + + // Extract keywords (simple word tokenization using iterator) + for word := range strings.FieldsSeq(strings.ToLower(q.Query)) { + if len(word) > 3 { // Skip short words + topKeywords[word]++ + } + } + } + + // Sort keywords by frequency + type keywordCount struct { + Keyword string `json:"keyword"` + Count int `json:"count"` + } + sortedKeywords := make([]keywordCount, 0, len(topKeywords)) + for kw, count := range topKeywords { + sortedKeywords = append(sortedKeywords, keywordCount{Keyword: kw, Count: count}) + } + sort.Slice(sortedKeywords, func(i, j int) bool { + return sortedKeywords[i].Count > sortedKeywords[j].Count + }) + if len(sortedKeywords) > 10 { + sortedKeywords = sortedKeywords[:10] + } + + // Calculate averages + avgResults := float64(0) + vectorSearchRate := float64(0) + zeroResultRate := float64(0) + if totalQueries > 0 { + avgResults = float64(totalResults) / float64(totalQueries) + vectorSearchRate = float64(vectorSearches) / float64(totalQueries) * 100 + zeroResultRate = float64(zeroResultQueries) / float64(totalQueries) * 100 + } + + writeJSON(w, map[string]any{ + "total_queries": totalQueries, + "vector_search_rate": vectorSearchRate, + "avg_results": avgResults, + "zero_result_rate": zeroResultRate, + "query_types": queryTypes, + "top_keywords": sortedKeywords, + "project": project, + }) +} + +// handleVectorHealth returns comprehensive health information about the vector database. +func (s *Service) handleVectorHealth(w http.ResponseWriter, r *http.Request) { + if s.vectorClient == nil { + http.Error(w, "vector client not initialized", http.StatusServiceUnavailable) + return + } + + stats, err := s.vectorClient.GetHealthStats(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Add additional computed metrics + healthScore := 100.0 + var warnings []string + + // Penalize for stale vectors + if stats.TotalVectors > 0 { + staleRatio := float64(stats.StaleVectors) / float64(stats.TotalVectors) + if staleRatio > 0 { + healthScore -= staleRatio * 50 // Up to 50 points off for stale vectors + warnings = append(warnings, formatWarning("%.1f%% vectors need rebuild", staleRatio*100)) + } + } + + // Check cache effectiveness + cacheHitRate := stats.EmbeddingCache.HitRate() + if cacheHitRate < 20 && (stats.EmbeddingCache.EmbeddingHits+stats.EmbeddingCache.EmbeddingMisses) > 100 { + healthScore -= 10 + warnings = append(warnings, formatWarning("Low cache hit rate: %.1f%%", cacheHitRate)) + } + + // Penalize if rebuild is needed + if stats.NeedsRebuild { + healthScore -= 20 + warnings = append(warnings, "Vector rebuild recommended: "+stats.RebuildReason) + } + + if healthScore < 0 { + healthScore = 0 + } + + status := "healthy" + if healthScore < 50 { + status = "unhealthy" + } else if healthScore < 80 { + status = "degraded" + } + + writeJSON(w, map[string]any{ + "status": status, + "health_score": healthScore, + "warnings": warnings, + "stats": stats, + "cache_hit_rate": cacheHitRate, + }) +} + +// UpdateObservationRequest is the request body for updating an observation. +type UpdateObservationRequest struct { + Title *string `json:"title,omitempty"` + Subtitle *string `json:"subtitle,omitempty"` + Narrative *string `json:"narrative,omitempty"` + Facts []string `json:"facts,omitempty"` + Concepts []string `json:"concepts,omitempty"` + FilesRead []string `json:"files_read,omitempty"` + FilesModified []string `json:"files_modified,omitempty"` + Scope *string `json:"scope,omitempty"` +} + +// handleUpdateObservation updates an existing observation. +// PUT /api/observations/{id} +func (s *Service) handleUpdateObservation(w http.ResponseWriter, r *http.Request) { + // Parse observation ID from URL + id, ok := parseIDParam(w, r.PathValue("id"), "observation") + if !ok { + return + } + + // Parse request body + var req UpdateObservationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest) + return + } + + // Build update struct - only include fields that were provided + update := &gorm.ObservationUpdate{} + + if req.Title != nil { + update.Title = req.Title + } + if req.Subtitle != nil { + update.Subtitle = req.Subtitle + } + if req.Narrative != nil { + update.Narrative = req.Narrative + } + if req.Facts != nil { + update.Facts = &req.Facts + } + if req.Concepts != nil { + update.Concepts = &req.Concepts + } + if req.FilesRead != nil { + update.FilesRead = &req.FilesRead + } + if req.FilesModified != nil { + update.FilesModified = &req.FilesModified + } + if req.Scope != nil { + // Validate scope + if *req.Scope != "project" && *req.Scope != "global" { + http.Error(w, "scope must be 'project' or 'global'", http.StatusBadRequest) + return + } + update.Scope = req.Scope + } + + // Update the observation + updatedObs, err := s.observationStore.UpdateObservation(r.Context(), id, update) + if err != nil { + if strings.Contains(err.Error(), "not found") { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + http.Error(w, "failed to update observation: "+err.Error(), http.StatusInternalServerError) + return + } + + // Trigger vector resync for the updated observation + if s.vectorSync != nil { + s.asyncVectorSync(func() { + if err := s.vectorSync.SyncObservation(s.ctx, updatedObs); err != nil { + log.Warn().Err(err).Int64("id", id).Msg("Failed to resync observation vectors after update") + } + }) + } + + // Broadcast update event + s.sseBroadcaster.Broadcast(map[string]any{ + "type": "observation_updated", + "id": id, + }) + + writeJSON(w, map[string]any{ + "observation": updatedObs, + "message": "observation updated successfully", + }) +} + +// handleGetObservationByID returns a single observation by ID. +// GET /api/observations/{id} +func (s *Service) handleGetObservationByID(w http.ResponseWriter, r *http.Request) { + id, ok := parseIDParam(w, r.PathValue("id"), "observation") + if !ok { + return + } + + obs, err := s.observationStore.GetObservationByID(r.Context(), id) + if err != nil { + http.Error(w, "failed to get observation: "+err.Error(), http.StatusInternalServerError) + return + } + + if obs == nil { + http.Error(w, "observation not found", http.StatusNotFound) + return + } + + writeJSON(w, obs) +} diff --git a/internal/worker/handlers_import_export.go b/internal/worker/handlers_import_export.go new file mode 100644 index 0000000..7430b46 --- /dev/null +++ b/internal/worker/handlers_import_export.go @@ -0,0 +1,680 @@ +// Package worker provides import, export, and archive HTTP handlers. +package worker + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-chi/chi/v5" + "github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm" + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/lukaszraczylo/claude-mnemonic/pkg/similarity" + "github.com/rs/zerolog/log" +) + +// BulkImportRequest is the request body for bulk observation import. +type BulkImportRequest struct { + Project string `json:"project"` + Observations []BulkObservationInput `json:"observations"` +} + +// BulkObservationInput represents a single observation in bulk import. +type BulkObservationInput struct { + Type string `json:"type"` // bugfix, feature, refactor, etc. + Title string `json:"title"` + Subtitle string `json:"subtitle,omitempty"` + Facts []string `json:"facts,omitempty"` + Narrative string `json:"narrative,omitempty"` + Concepts []string `json:"concepts,omitempty"` + FilesRead []string `json:"files_read,omitempty"` + FilesModified []string `json:"files_modified,omitempty"` + Scope string `json:"scope,omitempty"` // project or global +} + +// BulkImportResponse contains the result of a bulk import operation. +type BulkImportResponse struct { + Imported int `json:"imported"` + Failed int `json:"failed"` + SkippedDuplicates int `json:"skipped_duplicates,omitempty"` + Errors []string `json:"errors,omitempty"` +} + +// handleBulkImport handles bulk import of observations. +// This is useful for migrating data or importing observations from external sources. +func (s *Service) handleBulkImport(w http.ResponseWriter, r *http.Request) { + // Rate limit bulk operations to prevent DoS + if s.bulkOpLimiter != nil && !s.bulkOpLimiter.CanExecute() { + remaining := s.bulkOpLimiter.CooldownRemaining() + http.Error(w, fmt.Sprintf("bulk import rate limited, retry in %d seconds", remaining), http.StatusTooManyRequests) + return + } + + var req BulkImportRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest) + return + } + + if req.Project == "" { + http.Error(w, "project is required", http.StatusBadRequest) + return + } + + // Validate project name to prevent path traversal + if err := ValidateProjectName(req.Project); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if len(req.Observations) == 0 { + http.Error(w, "at least one observation is required", http.StatusBadRequest) + return + } + + // Limit batch size to prevent overwhelming the system + maxBatchSize := 100 + if len(req.Observations) > maxBatchSize { + http.Error(w, fmt.Sprintf("batch size exceeds maximum of %d", maxBatchSize), http.StatusBadRequest) + return + } + + // Create a synthetic session for bulk import + sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), fmt.Sprintf("bulk-import-%d", time.Now().UnixMilli()), req.Project, "bulk import") + if err != nil { + http.Error(w, "failed to create import session: "+err.Error(), http.StatusInternalServerError) + return + } + + var imported, failed, skippedDupes int + var errors []string + + // Track imported observations for deduplication within the batch + importedObs := make([]*models.Observation, 0, len(req.Observations)) + + // Deduplication threshold - observations more similar than this are considered duplicates + const dedupThreshold = 0.7 + + for i, obsInput := range req.Observations { + // Validate observation type using O(1) map lookup + if !IsValidObservationType(obsInput.Type) { + failed++ + errors = append(errors, fmt.Sprintf("observation %d: invalid type '%s'", i, obsInput.Type)) + continue + } + + // Build parsed observation + parsedObs := &models.ParsedObservation{ + Type: models.ObservationType(obsInput.Type), + Title: obsInput.Title, + Subtitle: obsInput.Subtitle, + Facts: obsInput.Facts, + Narrative: obsInput.Narrative, + Concepts: obsInput.Concepts, + FilesRead: obsInput.FilesRead, + FilesModified: obsInput.FilesModified, + Scope: models.ObservationScope(obsInput.Scope), + } + + // Convert to temporary observation for similarity check + tempObs := &models.Observation{ + Title: sql.NullString{String: parsedObs.Title, Valid: parsedObs.Title != ""}, + Subtitle: sql.NullString{String: parsedObs.Subtitle, Valid: parsedObs.Subtitle != ""}, + Narrative: sql.NullString{String: parsedObs.Narrative, Valid: parsedObs.Narrative != ""}, + } + + // Check for duplicates within this import batch + if similarity.IsSimilarToAny(tempObs, importedObs, dedupThreshold) { + skippedDupes++ + continue + } + + // Store observation + obsID, _, err := s.observationStore.StoreObservation( + r.Context(), + fmt.Sprintf("bulk-import-%d", sessionID), + req.Project, + parsedObs, + 0, // prompt number + 0, // discovery tokens + ) + if err != nil { + failed++ + errors = append(errors, fmt.Sprintf("observation %d: %v", i, err)) + continue + } + + // Sync to vector DB asynchronously with rate limiting + if s.vectorSync != nil { + s.asyncVectorSync(func() { + // Use service context as parent to respect shutdown signals + ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second) + defer cancel() + obs, err := s.observationStore.GetObservationByID(ctx, obsID) + if err == nil && obs != nil { + if syncErr := s.vectorSync.SyncObservation(ctx, obs); syncErr != nil { + if s.ctx.Err() == nil { // Don't log during shutdown + log.Debug().Err(syncErr).Int64("id", obsID).Msg("Failed to sync observation during bulk import") + } + } + } + }) + } + + // Track for deduplication of subsequent observations in this batch + importedObs = append(importedObs, tempObs) + imported++ + } + + log.Info(). + Str("project", req.Project). + Int("imported", imported). + Int("failed", failed). + Int("skipped_duplicates", skippedDupes). + Msg("Bulk import completed") + + // Invalidate observation count cache after import + if imported > 0 { + if req.Project != "" { + s.invalidateObsCountCache(req.Project) + } else { + s.invalidateAllObsCountCache() + } + } + + // Broadcast observation event for dashboard refresh + s.sseBroadcaster.Broadcast(map[string]any{ + "type": "observation", + "action": "bulk_import", + "project": req.Project, + "count": imported, + }) + + writeJSON(w, BulkImportResponse{ + Imported: imported, + Failed: failed, + SkippedDuplicates: skippedDupes, + Errors: errors, + }) +} + +// ArchiveRequest is the request body for archiving observations. +type ArchiveRequest struct { + IDs []int64 `json:"ids,omitempty"` // Specific IDs to archive + Project string `json:"project,omitempty"` // Archive all in project older than max_age_days + MaxAgeDays int `json:"max_age_days,omitempty"` // Only used with project + Reason string `json:"reason,omitempty"` // Optional reason for archival +} + +// handleArchiveObservations archives observations by ID or by age. +// Supports batch archival with error tracking per observation. +func (s *Service) handleArchiveObservations(w http.ResponseWriter, r *http.Request) { + var req ArchiveRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest) + return + } + + var archivedIDs []int64 + var failedIDs []int64 + var errors []string + var err error + + if len(req.IDs) > 0 { + // Archive specific observations with parallel processing for large batches + if len(req.IDs) > 5 { + // Use parallel archival for batches larger than 5 + type archiveResult struct { + id int64 + err error + } + results := make(chan archiveResult, len(req.IDs)) + + // Limit concurrency to avoid overwhelming the database + sem := make(chan struct{}, 5) + var wg sync.WaitGroup + + for _, id := range req.IDs { + wg.Add(1) + go func(obsID int64) { + defer wg.Done() + sem <- struct{}{} // Acquire + defer func() { <-sem }() // Release + + archErr := s.observationStore.ArchiveObservation(r.Context(), obsID, req.Reason) + results <- archiveResult{id: obsID, err: archErr} + }(id) + } + + // Close results channel when all goroutines complete + go func() { + wg.Wait() + close(results) + }() + + // Collect results + for res := range results { + if res.err != nil { + log.Warn().Err(res.err).Int64("id", res.id).Msg("Failed to archive observation") + failedIDs = append(failedIDs, res.id) + errors = append(errors, fmt.Sprintf("id %d: %v", res.id, res.err)) + } else { + archivedIDs = append(archivedIDs, res.id) + } + } + } else { + // Sequential for small batches + for _, id := range req.IDs { + if archErr := s.observationStore.ArchiveObservation(r.Context(), id, req.Reason); archErr != nil { + log.Warn().Err(archErr).Int64("id", id).Msg("Failed to archive observation") + failedIDs = append(failedIDs, id) + errors = append(errors, fmt.Sprintf("id %d: %v", id, archErr)) + } else { + archivedIDs = append(archivedIDs, id) + } + } + } + } else if req.Project != "" || req.MaxAgeDays > 0 { + // Archive by age + archivedIDs, err = s.observationStore.ArchiveOldObservations(r.Context(), req.Project, req.MaxAgeDays, req.Reason) + if err != nil { + http.Error(w, "failed to archive: "+err.Error(), http.StatusInternalServerError) + return + } + } else { + http.Error(w, "either 'ids' or 'project'/'max_age_days' is required", http.StatusBadRequest) + return + } + + log.Info(). + Str("project", req.Project). + Int("archived", len(archivedIDs)). + Int("failed", len(failedIDs)). + Msg("Observations archived") + + // Invalidate cache if any observations were archived + if len(archivedIDs) > 0 { + if req.Project != "" { + s.invalidateObsCountCache(req.Project) + } else { + s.invalidateAllObsCountCache() + } + } + + response := map[string]any{ + "archived_count": len(archivedIDs), + "archived_ids": archivedIDs, + } + if len(failedIDs) > 0 { + response["failed_count"] = len(failedIDs) + response["failed_ids"] = failedIDs + response["errors"] = errors + } + + writeJSON(w, response) +} + +// handleUnarchiveObservation restores an archived observation. +func (s *Service) handleUnarchiveObservation(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + http.Error(w, "invalid observation id", http.StatusBadRequest) + return + } + + if err := s.observationStore.UnarchiveObservation(r.Context(), id); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Invalidate all caches since we don't know the project + s.invalidateAllObsCountCache() + + writeJSON(w, map[string]any{ + "success": true, + "id": id, + }) +} + +// handleGetArchivedObservations returns archived observations. +func (s *Service) handleGetArchivedObservations(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + limit := gorm.ParseLimitParam(r, DefaultObservationsLimit) + + observations, err := s.observationStore.GetArchivedObservations(r.Context(), project, limit) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if observations == nil { + observations = []*models.Observation{} + } + + writeJSON(w, observations) +} + +// handleGetArchivalStats returns archival statistics. +func (s *Service) handleGetArchivalStats(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + + stats, err := s.observationStore.GetArchivalStats(r.Context(), project) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + writeJSON(w, stats) +} + +// handleExportObservations exports observations in JSON or CSV format. +// Supports query parameters: project, format (json/csv), scope, type, limit. +func (s *Service) handleExportObservations(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + format := r.URL.Query().Get("format") + if format == "" { + format = "json" + } + scope := r.URL.Query().Get("scope") // project, global, or empty for all + obsType := r.URL.Query().Get("type") // bugfix, feature, etc. + limit := gorm.ParseLimitParamWithMax(r, 1000, 5000) // Higher limit for exports, capped at 5000 + + // Validate format + if format != "json" && format != "csv" { + http.Error(w, "format must be 'json' or 'csv'", http.StatusBadRequest) + return + } + + // Get observations with filters + ctx := r.Context() + var observations []*models.Observation + var err error + + if project != "" { + observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0) + } else { + observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0) + } + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Apply additional filters + if scope != "" || obsType != "" { + filtered := make([]*models.Observation, 0, len(observations)) + for _, obs := range observations { + if scope != "" && string(obs.Scope) != scope { + continue + } + if obsType != "" && string(obs.Type) != obsType { + continue + } + filtered = append(filtered, obs) + } + observations = filtered + } + + // Generate filename + timestamp := time.Now().Format("20060102-150405") + filename := fmt.Sprintf("observations-%s.%s", timestamp, format) + if project != "" { + // Sanitize project name for filename + sanitized := strings.ReplaceAll(project, "/", "_") + sanitized = strings.ReplaceAll(sanitized, "\\", "_") + if len(sanitized) > 50 { + sanitized = sanitized[:50] + } + filename = fmt.Sprintf("observations-%s-%s.%s", sanitized, timestamp, format) + } + + switch format { + case "csv": + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename)) + s.writeObservationsCSV(w, observations) + default: // json + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename)) + writeJSON(w, map[string]any{ + "exported_at": time.Now().Format(time.RFC3339), + "project": project, + "count": len(observations), + "observations": observations, + }) + } +} + +// writeObservationsCSV writes observations in CSV format. +// Uses fmt.Fprintf directly to avoid intermediate string allocations. +func (s *Service) writeObservationsCSV(w http.ResponseWriter, observations []*models.Observation) { + // Write CSV header + _, _ = io.WriteString(w, "id,type,scope,project,title,subtitle,narrative,concepts,facts,created_at,importance_score\n") + + for _, obs := range observations { + // Write directly to avoid string allocation per row + _, _ = fmt.Fprintf(w, "%d,%s,%s,%s,%s,%s,%s,%s,%s,%s,%.2f\n", + obs.ID, + obs.Type, + obs.Scope, + escapeCsvField(obs.Project), + escapeCsvField(obs.Title.String), + escapeCsvField(obs.Subtitle.String), + escapeCsvField(obs.Narrative.String), + escapeCsvField(strings.Join(obs.Concepts, ";")), + escapeCsvField(strings.Join(obs.Facts, ";")), + obs.CreatedAt, + obs.ImportanceScore, + ) + } +} + +// escapeCsvField escapes a field for CSV output. +func escapeCsvField(s string) string { + // If field contains comma, quote, or newline, wrap in quotes and escape quotes + if strings.ContainsAny(s, ",\"\n\r") { + s = strings.ReplaceAll(s, "\"", "\"\"") + return "\"" + s + "\"" + } + return s +} + +// BulkStatusRequest represents a request to update status for multiple observations. +type BulkStatusRequest struct { + IDs []int64 `json:"ids"` + Action string `json:"action"` // "supersede", "archive", "set_feedback" + Reason string `json:"reason,omitempty"` + Feedback int `json:"feedback,omitempty"` // -1, 0, 1 for set_feedback action +} + +// handleBulkStatusUpdate updates status for multiple observations in one request. +func (s *Service) handleBulkStatusUpdate(w http.ResponseWriter, r *http.Request) { + var req BulkStatusRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest) + return + } + + if len(req.IDs) == 0 { + http.Error(w, "ids is required", http.StatusBadRequest) + return + } + + if len(req.IDs) > 500 { + http.Error(w, "maximum 500 ids per request", http.StatusBadRequest) + return + } + + ctx := r.Context() + var updated, failed int + var errors []string + + switch req.Action { + case "supersede": + for _, id := range req.IDs { + if err := s.observationStore.MarkAsSuperseded(ctx, id); err != nil { + failed++ + errors = append(errors, fmt.Sprintf("id %d: %v", id, err)) + } else { + updated++ + } + } + + case "archive": + for _, id := range req.IDs { + if err := s.observationStore.ArchiveObservation(ctx, id, req.Reason); err != nil { + failed++ + errors = append(errors, fmt.Sprintf("id %d: %v", id, err)) + } else { + updated++ + } + } + + case "set_feedback": + if req.Feedback < -1 || req.Feedback > 1 { + http.Error(w, "feedback must be -1, 0, or 1", http.StatusBadRequest) + return + } + for _, id := range req.IDs { + if err := s.observationStore.UpdateObservationFeedback(ctx, id, req.Feedback); err != nil { + failed++ + errors = append(errors, fmt.Sprintf("id %d: %v", id, err)) + } else { + updated++ + } + } + + default: + http.Error(w, "action must be 'supersede', 'archive', or 'set_feedback'", http.StatusBadRequest) + return + } + + // Invalidate cache for archive action (affects observation counts) + if req.Action == "archive" && updated > 0 { + // No project info available, invalidate all caches + s.invalidateAllObsCountCache() + } + + response := map[string]any{ + "action": req.Action, + "updated": updated, + "failed": failed, + } + if len(errors) > 0 { + response["errors"] = errors + } + + writeJSON(w, response) +} + +// handleFindDuplicates finds potential duplicate observations using similarity clustering. +// Returns groups of similar observations that may be candidates for merging or archival. +func (s *Service) handleFindDuplicates(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + thresholdStr := r.URL.Query().Get("threshold") + limit := gorm.ParseLimitParam(r, 100) + + // Parse threshold (default 0.6 = 60% similarity) + threshold := 0.6 + if thresholdStr != "" { + if t, err := strconv.ParseFloat(thresholdStr, 64); err == nil && t > 0 && t < 1 { + threshold = t + } + } + + // Get recent observations + ctx := r.Context() + var observations []*models.Observation + var err error + + if project != "" { + observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0) + } else { + observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0) + } + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if len(observations) < 2 { + writeJSON(w, map[string]any{ + "duplicate_groups": []any{}, + "total_checked": len(observations), + "threshold": threshold, + }) + return + } + + // Find duplicates using similarity comparison + type duplicateGroup struct { + Observations []map[string]any `json:"observations"` + Similarity float64 `json:"similarity"` + } + + groups := []duplicateGroup{} + processed := make(map[int64]bool) + + for i, obs1 := range observations { + if processed[obs1.ID] { + continue + } + + terms1 := similarity.ExtractObservationTerms(obs1) + if len(terms1) == 0 { + continue + } + + group := duplicateGroup{ + Observations: []map[string]any{obs1.ToMap()}, + Similarity: 1.0, + } + + for j := i + 1; j < len(observations); j++ { + obs2 := observations[j] + if processed[obs2.ID] { + continue + } + + terms2 := similarity.ExtractObservationTerms(obs2) + sim := similarity.JaccardSimilarity(terms1, terms2) + + if sim >= threshold { + obsMap := obs2.ToMap() + obsMap["similarity_to_first"] = sim + group.Observations = append(group.Observations, obsMap) + group.Similarity = min(group.Similarity, sim) + processed[obs2.ID] = true + } + } + + if len(group.Observations) > 1 { + processed[obs1.ID] = true + groups = append(groups, group) + } + } + + // Sort groups by size (largest first) + sort.Slice(groups, func(i, j int) bool { + return len(groups[i].Observations) > len(groups[j].Observations) + }) + + writeJSON(w, map[string]any{ + "duplicate_groups": groups, + "total_checked": len(observations), + "groups_found": len(groups), + "threshold": threshold, + "project": project, + }) +} diff --git a/internal/worker/handlers_scoring.go b/internal/worker/handlers_scoring.go index 5929e34..c7a1b82 100644 --- a/internal/worker/handlers_scoring.go +++ b/internal/worker/handlers_scoring.go @@ -10,6 +10,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" ) // FeedbackRequest represents a user feedback submission. @@ -311,8 +312,7 @@ func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Requ // Run recalculation in background go func() { if err := recalculator.RecalculateNow(r.Context()); err != nil { - // Log error but don't block response - _ = err // Explicitly ignore - background operation + log.Warn().Err(err).Msg("Background score recalculation failed") } }() @@ -345,14 +345,18 @@ func (s *Service) incrementRetrievalCounts(ids []int64) { } // Increment in background to not block response + // Use service context to respect shutdown signals + s.wg.Add(1) go func() { - // Create a new context with timeout for the background operation - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer s.wg.Done() + ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) defer cancel() if err := store.IncrementRetrievalCount(ctx, ids); err != nil { // Log but don't fail - this is a background operation - _ = err // Explicitly ignore - background operation + if s.ctx.Err() == nil { // Don't log during shutdown + log.Debug().Err(err).Msg("Failed to increment retrieval counts") + } } }() } diff --git a/internal/worker/handlers_sessions.go b/internal/worker/handlers_sessions.go new file mode 100644 index 0000000..837f7a3 --- /dev/null +++ b/internal/worker/handlers_sessions.go @@ -0,0 +1,354 @@ +// Package worker provides session-related HTTP handlers. +package worker + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + "time" + + "github.com/go-chi/chi/v5" + "github.com/lukaszraczylo/claude-mnemonic/internal/privacy" + "github.com/lukaszraczylo/claude-mnemonic/internal/worker/session" + "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" +) + +// SessionInitRequest is the request body for session initialization. +type SessionInitRequest struct { + ClaudeSessionID string `json:"claudeSessionId"` + Project string `json:"project"` + Prompt string `json:"prompt"` + MatchedObservations int `json:"matchedObservations"` +} + +// SessionInitResponse is the response for session initialization. +type SessionInitResponse struct { + SessionDBID int64 `json:"sessionDbId"` + PromptNumber int `json:"promptNumber"` + Skipped bool `json:"skipped,omitempty"` + Reason string `json:"reason,omitempty"` +} + +// DuplicatePromptWindowSeconds is the time window for detecting duplicate prompt submissions. +// If the same prompt text is seen within this window, it's considered a duplicate hook invocation. +const DuplicatePromptWindowSeconds = 10 + +// handleSessionInit handles session initialization from user-prompt hook. +// This handler is idempotent - duplicate requests within a short time window +// return the existing prompt data without creating duplicates. +func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) { + var req SessionInitRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Privacy check + if privacy.IsEntirelyPrivate(req.Prompt) { + // Create session but skip processing + sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "") + promptNum, _ := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID) + + writeJSON(w, SessionInitResponse{ + SessionDBID: sessionID, + PromptNumber: promptNum, + Skipped: true, + Reason: "private", + }) + return + } + + // Clean prompt + cleanedPrompt := privacy.Clean(req.Prompt) + + // DUPLICATE DETECTION: Check if this exact prompt was already saved recently. + // This prevents the bug where the hook fires multiple times for the same user action, + // creating many duplicate prompts with incrementing numbers. + if existingID, existingNum, found := s.promptStore.FindRecentPromptByText(r.Context(), req.ClaudeSessionID, cleanedPrompt, DuplicatePromptWindowSeconds); found { + // Get or create session (idempotent) + sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt) + + log.Debug(). + Int64("sessionId", sessionID). + Int("promptNumber", existingNum). + Int64("promptId", existingID). + Msg("Duplicate prompt detected - returning existing") + + // Return existing prompt data without incrementing or saving again + writeJSON(w, SessionInitResponse{ + SessionDBID: sessionID, + PromptNumber: existingNum, + }) + return + } + + // Create session (idempotent) + sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Increment prompt counter + promptNum, err := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Save user prompt with matched observation count + promptID, err := s.promptStore.SaveUserPromptWithMatches(r.Context(), req.ClaudeSessionID, promptNum, cleanedPrompt, req.MatchedObservations) + if err != nil { + log.Warn().Err(err).Msg("Failed to save user prompt") + // Non-fatal: continue with session initialization + } else if s.vectorSync != nil { + // Sync to vector DB asynchronously (non-blocking) + now := time.Now() + promptWithSession := &models.UserPromptWithSession{ + UserPrompt: models.UserPrompt{ + ID: promptID, + ClaudeSessionID: req.ClaudeSessionID, + PromptNumber: promptNum, + PromptText: cleanedPrompt, + MatchedObservations: req.MatchedObservations, + CreatedAt: now.Format(time.RFC3339), + CreatedAtEpoch: now.UnixMilli(), + }, + Project: req.Project, + SDKSessionID: req.ClaudeSessionID, + } + s.asyncVectorSync(func() { + // Use service context as parent to respect shutdown signals + ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second) + defer cancel() + if err := s.vectorSync.SyncUserPrompt(ctx, promptWithSession); err != nil { + if s.ctx.Err() == nil { // Don't log during shutdown + log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to sqlite-vec") + } + } + }) + } + + log.Info(). + Int64("sessionId", sessionID). + Int("promptNumber", promptNum). + Str("project", req.Project). + Msg("Session initialized") + + // Broadcast prompt event for dashboard refresh + s.sseBroadcaster.Broadcast(map[string]any{ + "type": "prompt", + "action": "created", + "project": req.Project, + }) + + writeJSON(w, SessionInitResponse{ + SessionDBID: sessionID, + PromptNumber: promptNum, + }) +} + +// SessionStartRequest is the request body for starting SDK agent. +type SessionStartRequest struct { + UserPrompt string `json:"userPrompt"` + PromptNumber int `json:"promptNumber"` +} + +// handleSessionStart handles SDK agent session start. +func (s *Service) handleSessionStart(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + http.Error(w, "invalid session id", http.StatusBadRequest) + return + } + + var req SessionStartRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Initialize session in manager + sess, err := s.sessionManager.InitializeSession(r.Context(), id, req.UserPrompt, req.PromptNumber) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if sess == nil { + http.Error(w, "session not found", http.StatusNotFound) + return + } + + // Session is now registered. Observations will be processed + // asynchronously by the background queue processor (processQueue in service.go). + log.Info(). + Int64("sessionId", id). + Int("promptNumber", req.PromptNumber). + Msg("SDK agent session initialized") + + s.broadcastProcessingStatus() + w.WriteHeader(http.StatusOK) +} + +// ObservationRequest is the request body for posting observations. +type ObservationRequest struct { + ClaudeSessionID string `json:"claudeSessionId"` + Project string `json:"project"` + ToolName string `json:"tool_name"` + ToolInput any `json:"tool_input"` + ToolResponse any `json:"tool_response"` + CWD string `json:"cwd"` +} + +// handleObservation handles observation posting from post-tool-use hook. +func (s *Service) handleObservation(w http.ResponseWriter, r *http.Request) { + var req ObservationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Find session + sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if sess == nil { + // Create session on-the-fly with project from request + id, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + sess, _ = s.sessionStore.GetSessionByID(r.Context(), id) + } + + // Queue observation + if err := s.sessionManager.QueueObservation(r.Context(), sess.ID, session.ObservationData{ + ToolName: req.ToolName, + ToolInput: req.ToolInput, + ToolResponse: req.ToolResponse, + CWD: req.CWD, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.broadcastProcessingStatus() + w.WriteHeader(http.StatusOK) +} + +// SubagentCompleteRequest is the request body for subagent completion. +type SubagentCompleteRequest struct { + ClaudeSessionID string `json:"claudeSessionId"` + Project string `json:"project"` +} + +// handleSubagentComplete handles subagent/Task completion notifications. +// This triggers immediate processing of any queued observations from the subagent. +func (s *Service) handleSubagentComplete(w http.ResponseWriter, r *http.Request) { + var req SubagentCompleteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Find session + sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID) + if err != nil || sess == nil { + // Session not found - subagent may have been in a different context + log.Debug(). + Str("claudeSessionId", req.ClaudeSessionID). + Msg("Subagent complete - no active session found") + w.WriteHeader(http.StatusOK) + return + } + + // Trigger immediate processing of queued observations + messages := s.sessionManager.DrainMessages(sess.ID) + if len(messages) > 0 && s.processor != nil { + log.Info(). + Int64("sessionId", sess.ID). + Int("messages", len(messages)). + Msg("Processing queued observations from subagent") + + for _, msg := range messages { + if msg.Type == session.MessageTypeObservation && msg.Observation != nil { + err := s.processor.ProcessObservation( + r.Context(), + sess.SDKSessionID.String, + sess.Project, + msg.Observation.ToolName, + msg.Observation.ToolInput, + msg.Observation.ToolResponse, + msg.Observation.PromptNumber, + msg.Observation.CWD, + ) + if err != nil { + log.Error().Err(err). + Str("tool", msg.Observation.ToolName). + Msg("Failed to process subagent observation") + } + } + } + } + + s.broadcastProcessingStatus() + w.WriteHeader(http.StatusOK) +} + +// handleGetSessionByClaudeID looks up a session by Claude session ID. +func (s *Service) handleGetSessionByClaudeID(w http.ResponseWriter, r *http.Request) { + claudeSessionID := r.URL.Query().Get("claudeSessionId") + if claudeSessionID == "" { + http.Error(w, "claudeSessionId required", http.StatusBadRequest) + return + } + + session, err := s.sessionStore.FindAnySDKSession(r.Context(), claudeSessionID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if session == nil { + http.Error(w, "session not found", http.StatusNotFound) + return + } + + writeJSON(w, session) +} + +// SummarizeRequest is the request body for summarize requests. +type SummarizeRequest struct { + LastUserMessage string `json:"lastUserMessage"` + LastAssistantMessage string `json:"lastAssistantMessage"` +} + +// handleSummarize handles summarize requests from stop hook. +func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + http.Error(w, "invalid session id", http.StatusBadRequest) + return + } + + var req SummarizeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Queue summarize request + if err := s.sessionManager.QueueSummarize(r.Context(), id, req.LastUserMessage, req.LastAssistantMessage); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.broadcastProcessingStatus() + w.WriteHeader(http.StatusOK) +} diff --git a/internal/worker/handlers_update.go b/internal/worker/handlers_update.go new file mode 100644 index 0000000..3f368e5 --- /dev/null +++ b/internal/worker/handlers_update.go @@ -0,0 +1,243 @@ +// Package worker provides update and restart HTTP handlers. +package worker + +import ( + "fmt" + "net/http" + "time" + + "github.com/rs/zerolog/log" +) + +// handleUpdateCheck checks for available updates. +func (s *Service) handleUpdateCheck(w http.ResponseWriter, r *http.Request) { + info, err := s.updater.CheckForUpdate(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, info) +} + +// handleUpdateApply downloads and applies an available update. +func (s *Service) handleUpdateApply(w http.ResponseWriter, r *http.Request) { + // First check for update + info, err := s.updater.CheckForUpdate(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if !info.Available { + writeJSON(w, map[string]any{ + "success": false, + "message": "No update available", + }) + return + } + + // Apply update in background with tracking for graceful shutdown + s.wg.Go(func() { + if err := s.updater.ApplyUpdate(s.ctx, info); err != nil { + log.Error().Err(err).Msg("Update failed") + } + }) + + writeJSON(w, map[string]any{ + "success": true, + "message": "Update started", + "version": info.LatestVersion, + }) +} + +// handleUpdateStatus returns the current update status. +func (s *Service) handleUpdateStatus(w http.ResponseWriter, r *http.Request) { + status := s.updater.GetStatus() + writeJSON(w, status) +} + +// ComponentHealth represents the health status of a single component. +type ComponentHealth struct { + Name string `json:"name"` + Status string `json:"status"` // "healthy", "degraded", "unhealthy" + Message string `json:"message,omitempty"` +} + +// SelfCheckResponse contains the health status of all components. +type SelfCheckResponse struct { + Overall string `json:"overall"` // "healthy", "degraded", "unhealthy" + Version string `json:"version"` + Uptime string `json:"uptime"` + Components []ComponentHealth `json:"components"` +} + +// handleSelfCheck returns the health status of all components. +func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) { + components := []ComponentHealth{} + overall := "healthy" + + // Check Worker Service + workerStatus := ComponentHealth{Name: "Worker Service", Status: "healthy"} + if !s.ready.Load() { + if err := s.GetInitError(); err != nil { + workerStatus.Status = "unhealthy" + workerStatus.Message = err.Error() + overall = "unhealthy" + } else { + workerStatus.Status = "degraded" + workerStatus.Message = "Initializing" + if overall == "healthy" { + overall = "degraded" + } + } + } + components = append(components, workerStatus) + + // Check SQLite Database + dbStatus := ComponentHealth{Name: "SQLite Database", Status: "healthy"} + if s.store == nil { + dbStatus.Status = "unhealthy" + dbStatus.Message = "Not initialized" + overall = "unhealthy" + } else if err := s.store.Ping(); err != nil { + dbStatus.Status = "unhealthy" + dbStatus.Message = err.Error() + overall = "unhealthy" + } + components = append(components, dbStatus) + + // Check Vector DB (sqlite-vec) + vectorStatus := ComponentHealth{Name: "Vector DB", Status: "healthy"} + if s.vectorClient == nil { + vectorStatus.Status = "degraded" + vectorStatus.Message = "Not configured" + if overall == "healthy" { + overall = "degraded" + } + } else if !s.vectorClient.IsConnected() { + vectorStatus.Status = "degraded" + vectorStatus.Message = "Not connected" + if overall == "healthy" { + overall = "degraded" + } + } + components = append(components, vectorStatus) + + // Check SDK Processor + sdkStatus := ComponentHealth{Name: "SDK Processor", Status: "healthy"} + if s.processor == nil { + sdkStatus.Status = "degraded" + sdkStatus.Message = "Not initialized" + if overall == "healthy" { + overall = "degraded" + } + } else if !s.processor.IsAvailable() { + sdkStatus.Status = "degraded" + sdkStatus.Message = "Claude CLI not available" + if overall == "healthy" { + overall = "degraded" + } + } + components = append(components, sdkStatus) + + // Check SSE Broadcaster + sseStatus := ComponentHealth{Name: "SSE Broadcaster", Status: "healthy"} + if s.sseBroadcaster == nil { + sseStatus.Status = "unhealthy" + sseStatus.Message = "Not initialized" + overall = "unhealthy" + } + components = append(components, sseStatus) + + // Check Cross-Encoder Reranker + rerankerStatus := ComponentHealth{Name: "Cross-Encoder Reranker", Status: "healthy"} + if !s.config.RerankingEnabled { + rerankerStatus.Status = "degraded" + rerankerStatus.Message = "Disabled in config" + if overall == "healthy" { + overall = "degraded" + } + } else if s.reranker == nil { + rerankerStatus.Status = "degraded" + rerankerStatus.Message = "Not initialized" + if overall == "healthy" { + overall = "degraded" + } + } else { + // Verify reranker is functional using Score + _, normalizedScore, err := s.reranker.Score("test query", "test document") + if err != nil { + rerankerStatus.Status = "unhealthy" + rerankerStatus.Message = fmt.Sprintf("Score check failed: %v", err) + if overall == "healthy" { + overall = "degraded" + } + } else { + rerankerStatus.Message = fmt.Sprintf("Score check passed (%.4f)", normalizedScore) + } + } + components = append(components, rerankerStatus) + + // Calculate uptime + uptime := time.Since(s.startTime).Round(time.Second).String() + + writeJSON(w, SelfCheckResponse{ + Overall: overall, + Version: s.version, + Uptime: uptime, + Components: components, + }) +} + +// handleUpdateRestart restarts the worker with the new binary (after update). +func (s *Service) handleUpdateRestart(w http.ResponseWriter, r *http.Request) { + status := s.updater.GetStatus() + if status.State != "done" { + http.Error(w, "no update has been applied", http.StatusBadRequest) + return + } + + // Send response before restarting + writeJSON(w, map[string]any{ + "success": true, + "message": "Restarting worker...", + }) + + // Flush the response + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + // Restart in background after response is sent + go func() { + if err := s.updater.Restart(); err != nil { + log.Error().Err(err).Msg("Failed to restart worker") + } + }() +} + +// handleRestart restarts the worker process (general restart, not tied to update). +func (s *Service) handleRestart(w http.ResponseWriter, r *http.Request) { + log.Info().Msg("Manual restart requested via API") + + // Send response before restarting + writeJSON(w, map[string]any{ + "success": true, + "message": "Restarting worker...", + "version": s.version, + }) + + // Flush the response + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + // Restart in background after response is sent + go func() { + // Small delay to ensure response is sent + time.Sleep(100 * time.Millisecond) + if err := s.updater.Restart(); err != nil { + log.Error().Err(err).Msg("Failed to restart worker") + } + }() +} diff --git a/internal/worker/middleware.go b/internal/worker/middleware.go new file mode 100644 index 0000000..cd42eb7 --- /dev/null +++ b/internal/worker/middleware.go @@ -0,0 +1,335 @@ +// Package worker provides the main worker service for claude-mnemonic. +package worker + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "net/http" + "regexp" + "strings" + "sync" + "time" +) + +// requestIDKey is the context key for request IDs. +type requestIDKey struct{} + +// projectNamePattern validates project names to prevent path traversal. +var projectNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_./-]+$`) + +// allowedOrigins is the whitelist of origins allowed for CORS. +// Uses exact matching to prevent bypass attacks like "evil-localhost.com". +var allowedOrigins = map[string]bool{ + "http://localhost": true, + "http://localhost:3000": true, + "http://localhost:5173": true, // Vite dev server + "http://localhost:37778": true, // Dashboard UI + "http://127.0.0.1": true, + "http://127.0.0.1:3000": true, + "http://127.0.0.1:5173": true, + "http://127.0.0.1:37778": true, +} + +// SecurityHeaders middleware adds essential security headers to all responses. +// These protect against common web vulnerabilities. +func SecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Prevent clickjacking + w.Header().Set("X-Frame-Options", "DENY") + + // Prevent MIME type sniffing + w.Header().Set("X-Content-Type-Options", "nosniff") + + // Enable XSS filter + w.Header().Set("X-XSS-Protection", "1; mode=block") + + // Restrict referrer information + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + + // Content Security Policy - restrict to self + w.Header().Set("Content-Security-Policy", "default-src 'self'") + + // Permissions Policy - disable unnecessary features + w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()") + + // CORS: Use exact match whitelist to prevent bypass attacks + origin := r.Header.Get("Origin") + if allowedOrigins[origin] { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Auth-Token, Authorization, X-Request-ID") + } + + // Handle preflight requests + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) +} + +// MaxBodySize middleware limits the size of incoming request bodies. +// This prevents denial of service attacks via large payloads. +func MaxBodySize(maxBytes int64) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength > maxBytes { + http.Error(w, "request body too large", http.StatusRequestEntityTooLarge) + return + } + r.Body = http.MaxBytesReader(w, r.Body, maxBytes) + next.ServeHTTP(w, r) + }) + } +} + +// TokenAuth provides simple token-based authentication for localhost services. +// The token is generated at startup and must be provided in the X-Auth-Token header. +type TokenAuth struct { + token string + enabled bool + mu sync.RWMutex + + // ExemptPaths are paths that don't require authentication (e.g., health checks) + ExemptPaths map[string]bool +} + +// NewTokenAuth creates a new TokenAuth with a randomly generated token. +// If enabled is false, authentication is skipped (useful for development). +func NewTokenAuth(enabled bool) (*TokenAuth, error) { + ta := &TokenAuth{ + enabled: enabled, + ExemptPaths: map[string]bool{ + "/health": true, + "/api/health": true, + "/api/ready": true, + }, + } + + if enabled { + // Generate 32-byte random token + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + return nil, err + } + ta.token = hex.EncodeToString(tokenBytes) + } + + return ta, nil +} + +// Token returns the authentication token. +// Returns empty string if authentication is disabled. +func (ta *TokenAuth) Token() string { + ta.mu.RLock() + defer ta.mu.RUnlock() + return ta.token +} + +// IsEnabled returns whether token authentication is enabled. +func (ta *TokenAuth) IsEnabled() bool { + ta.mu.RLock() + defer ta.mu.RUnlock() + return ta.enabled +} + +// Middleware returns HTTP middleware that enforces token authentication. +func (ta *TokenAuth) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ta.mu.RLock() + enabled := ta.enabled + token := ta.token + exempt := ta.ExemptPaths[r.URL.Path] + ta.mu.RUnlock() + + // Skip auth if disabled or path is exempt + if !enabled || exempt { + next.ServeHTTP(w, r) + return + } + + // Check for token in header + providedToken := r.Header.Get("X-Auth-Token") + if providedToken == "" { + // Also check Authorization header with Bearer scheme + auth := r.Header.Get("Authorization") + if bearer, found := strings.CutPrefix(auth, "Bearer "); found { + providedToken = bearer + } + } + + if providedToken != token { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} + +// ExpensiveOperationLimiter provides stricter rate limiting for expensive operations. +// It wraps the base per-client rate limiter with additional per-operation limits. +type ExpensiveOperationLimiter struct { + // Track last execution time per operation type + lastRebuild int64 // Unix timestamp + rebuildCooldown int64 // Minimum seconds between rebuilds + + mu sync.Mutex +} + +// NewExpensiveOperationLimiter creates a limiter for expensive operations. +func NewExpensiveOperationLimiter() *ExpensiveOperationLimiter { + return &ExpensiveOperationLimiter{ + rebuildCooldown: 300, // 5 minutes between rebuilds + } +} + +// CanRebuild checks if a vector rebuild operation is allowed. +// Returns false if a rebuild was triggered too recently. +func (eol *ExpensiveOperationLimiter) CanRebuild() bool { + eol.mu.Lock() + defer eol.mu.Unlock() + + now := unixNow() + if now-eol.lastRebuild < eol.rebuildCooldown { + return false + } + eol.lastRebuild = now + return true +} + +// unixNow returns current Unix timestamp. +// Separated for easier testing. +func unixNow() int64 { + return time.Now().Unix() +} + +// RequestID middleware adds a unique request ID to each request. +// The ID is added to the context and response headers for tracing. +func RequestID(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for existing request ID from client + requestID := r.Header.Get("X-Request-ID") + if requestID == "" { + // Generate new request ID + idBytes := make([]byte, 8) + if _, err := rand.Read(idBytes); err == nil { + requestID = hex.EncodeToString(idBytes) + } else { + requestID = fmt.Sprintf("%d", time.Now().UnixNano()) + } + } + + // Add to response header + w.Header().Set("X-Request-ID", requestID) + + // Add to context + ctx := context.WithValue(r.Context(), requestIDKey{}, requestID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetRequestID retrieves the request ID from the context. +func GetRequestID(ctx context.Context) string { + if id, ok := ctx.Value(requestIDKey{}).(string); ok { + return id + } + return "" +} + +// RequireJSONContentType middleware validates that POST/PUT/PATCH requests +// have application/json Content-Type header. +func RequireJSONContentType(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only check for methods that typically have bodies + if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" { + ct := r.Header.Get("Content-Type") + // Allow empty Content-Type for requests without body + if ct != "" && !strings.HasPrefix(ct, "application/json") { + http.Error(w, "Content-Type must be application/json", http.StatusUnsupportedMediaType) + return + } + } + next.ServeHTTP(w, r) + }) +} + +// ValidateProjectName checks if a project name is safe to use. +// Returns an error if the name contains path traversal or invalid characters. +func ValidateProjectName(project string) error { + if project == "" { + return nil // Empty is allowed (means no filter) + } + + // Check for path traversal + if strings.Contains(project, "..") { + return fmt.Errorf("invalid project name: path traversal detected") + } + + // Check for valid characters + if !projectNamePattern.MatchString(project) { + return fmt.Errorf("invalid project name: only alphanumeric, underscore, dash, dot, and slash allowed") + } + + // Max length check + if len(project) > 500 { + return fmt.Errorf("project name too long (max 500 chars)") + } + + return nil +} + +// BulkOperationLimiter provides rate limiting for bulk operations. +// Prevents DoS via repeated bulk requests. +type BulkOperationLimiter struct { + lastBulkOp int64 // Unix timestamp + cooldown int64 // Minimum seconds between operations + + mu sync.Mutex +} + +// NewBulkOperationLimiter creates a limiter for bulk operations. +func NewBulkOperationLimiter(cooldownSeconds int64) *BulkOperationLimiter { + return &BulkOperationLimiter{ + cooldown: cooldownSeconds, + } +} + +// CanExecute checks if a bulk operation is allowed. +// Returns false if a bulk operation was triggered too recently. +func (bol *BulkOperationLimiter) CanExecute() bool { + bol.mu.Lock() + defer bol.mu.Unlock() + + now := unixNow() + if now-bol.lastBulkOp < bol.cooldown { + return false + } + bol.lastBulkOp = now + return true +} + +// TimeSinceLastOp returns seconds since the last bulk operation. +func (bol *BulkOperationLimiter) TimeSinceLastOp() int64 { + bol.mu.Lock() + defer bol.mu.Unlock() + return unixNow() - bol.lastBulkOp +} + +// CooldownRemaining returns seconds remaining in the cooldown period. +// Returns 0 if no cooldown is active. +func (bol *BulkOperationLimiter) CooldownRemaining() int64 { + bol.mu.Lock() + defer bol.mu.Unlock() + + remaining := bol.cooldown - (unixNow() - bol.lastBulkOp) + if remaining < 0 { + return 0 + } + return remaining +} diff --git a/internal/worker/middleware_test.go b/internal/worker/middleware_test.go new file mode 100644 index 0000000..28477a4 --- /dev/null +++ b/internal/worker/middleware_test.go @@ -0,0 +1,515 @@ +package worker + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestSecurityHeaders(t *testing.T) { + handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + // Check all security headers are set + tests := []struct { + header string + expected string + }{ + {"X-Frame-Options", "DENY"}, + {"X-Content-Type-Options", "nosniff"}, + {"X-XSS-Protection", "1; mode=block"}, + {"Referrer-Policy", "strict-origin-when-cross-origin"}, + } + + for _, tt := range tests { + if got := rr.Header().Get(tt.header); got != tt.expected { + t.Errorf("SecurityHeaders() %s = %q, want %q", tt.header, got, tt.expected) + } + } +} + +func TestSecurityHeaders_CORS(t *testing.T) { + handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + origin string + expectCORS bool + expectedOrigin string + }{ + { + name: "localhost:37778 origin allowed", + origin: "http://localhost:37778", + expectCORS: true, + expectedOrigin: "http://localhost:37778", + }, + { + name: "127.0.0.1:5173 origin allowed", + origin: "http://127.0.0.1:5173", + expectCORS: true, + expectedOrigin: "http://127.0.0.1:5173", + }, + { + name: "localhost without port allowed", + origin: "http://localhost", + expectCORS: true, + expectedOrigin: "http://localhost", + }, + { + name: "external origin blocked", + origin: "http://evil.com", + expectCORS: false, + }, + { + name: "evil-localhost.com bypass attempt blocked", + origin: "http://evil-localhost.com", + expectCORS: false, + }, + { + name: "localhost subdomain bypass attempt blocked", + origin: "http://localhost.evil.com", + expectCORS: false, + }, + { + name: "unknown localhost port blocked", + origin: "http://localhost:9999", + expectCORS: false, + }, + { + name: "no origin header", + origin: "", + expectCORS: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + cors := rr.Header().Get("Access-Control-Allow-Origin") + if tt.expectCORS { + if cors != tt.expectedOrigin { + t.Errorf("Expected CORS origin %q, got %q", tt.expectedOrigin, cors) + } + } else { + if cors != "" { + t.Errorf("Expected no CORS header, got %q", cors) + } + } + }) + } +} + +func TestMaxBodySize(t *testing.T) { + maxSize := int64(100) + handler := MaxBodySize(maxSize)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + contentLength int64 + expectedStatus int + }{ + { + name: "within limit", + contentLength: 50, + expectedStatus: http.StatusOK, + }, + { + name: "at limit", + contentLength: 100, + expectedStatus: http.StatusOK, + }, + { + name: "exceeds limit", + contentLength: 150, + expectedStatus: http.StatusRequestEntityTooLarge, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/test", nil) + req.ContentLength = tt.contentLength + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("MaxBodySize() status = %d, want %d", rr.Code, tt.expectedStatus) + } + }) + } +} + +func TestTokenAuth(t *testing.T) { + t.Run("disabled auth allows all requests", func(t *testing.T) { + ta, err := NewTokenAuth(false) + if err != nil { + t.Fatalf("NewTokenAuth() error = %v", err) + } + + handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected OK with disabled auth, got %d", rr.Code) + } + }) + + t.Run("enabled auth requires token", func(t *testing.T) { + ta, err := NewTokenAuth(true) + if err != nil { + t.Fatalf("NewTokenAuth() error = %v", err) + } + + handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Without token + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("Expected Unauthorized without token, got %d", rr.Code) + } + + // With correct token in X-Auth-Token header + req = httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Auth-Token", ta.Token()) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected OK with correct token, got %d", rr.Code) + } + + // With correct token in Authorization header + req = httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+ta.Token()) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected OK with Bearer token, got %d", rr.Code) + } + }) + + t.Run("exempt paths skip auth", func(t *testing.T) { + ta, err := NewTokenAuth(true) + if err != nil { + t.Fatalf("NewTokenAuth() error = %v", err) + } + + handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + exemptPaths := []string{"/health", "/api/health", "/api/ready"} + for _, path := range exemptPaths { + req := httptest.NewRequest("GET", path, nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected OK for exempt path %s, got %d", path, rr.Code) + } + } + }) +} + +func TestExpensiveOperationLimiter(t *testing.T) { + limiter := NewExpensiveOperationLimiter() + + // First rebuild should be allowed + if !limiter.CanRebuild() { + t.Error("First rebuild should be allowed") + } + + // Immediate second rebuild should be blocked + if limiter.CanRebuild() { + t.Error("Immediate second rebuild should be blocked") + } +} + +func TestRequestID(t *testing.T) { + handler := RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request ID is in context + id := GetRequestID(r.Context()) + if id == "" { + t.Error("Request ID should be set in context") + } + w.WriteHeader(http.StatusOK) + })) + + t.Run("generates new request ID", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Header().Get("X-Request-ID") == "" { + t.Error("X-Request-ID header should be set") + } + }) + + t.Run("uses existing request ID", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Request-ID", "test-id-12345") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Header().Get("X-Request-ID") != "test-id-12345" { + t.Errorf("Expected X-Request-ID to be test-id-12345, got %s", rr.Header().Get("X-Request-ID")) + } + }) +} + +func TestRequireJSONContentType(t *testing.T) { + handler := RequireJSONContentType(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + method string + contentType string + expectedStatus int + }{ + { + name: "GET request without content-type", + method: "GET", + contentType: "", + expectedStatus: http.StatusOK, + }, + { + name: "POST with application/json", + method: "POST", + contentType: "application/json", + expectedStatus: http.StatusOK, + }, + { + name: "POST with application/json; charset=utf-8", + method: "POST", + contentType: "application/json; charset=utf-8", + expectedStatus: http.StatusOK, + }, + { + name: "POST without content-type (empty body)", + method: "POST", + contentType: "", + expectedStatus: http.StatusOK, + }, + { + name: "POST with text/plain rejected", + method: "POST", + contentType: "text/plain", + expectedStatus: http.StatusUnsupportedMediaType, + }, + { + name: "PUT with application/xml rejected", + method: "PUT", + contentType: "application/xml", + expectedStatus: http.StatusUnsupportedMediaType, + }, + { + name: "PATCH with form-urlencoded rejected", + method: "PATCH", + contentType: "application/x-www-form-urlencoded", + expectedStatus: http.StatusUnsupportedMediaType, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/test", nil) + if tt.contentType != "" { + req.Header.Set("Content-Type", tt.contentType) + } + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code) + } + }) + } +} + +func TestValidateProjectName(t *testing.T) { + tests := []struct { + name string + project string + wantError bool + }{ + { + name: "empty project allowed", + project: "", + wantError: false, + }, + { + name: "simple project name", + project: "my-project", + wantError: false, + }, + { + name: "project with path", + project: "org/my-project", + wantError: false, + }, + { + name: "project with underscore", + project: "my_project_v2", + wantError: false, + }, + { + name: "project with dot", + project: "my.project.name", + wantError: false, + }, + { + name: "path traversal attack", + project: "../../../etc/passwd", + wantError: true, + }, + { + name: "hidden path traversal", + project: "project/../../secret", + wantError: true, + }, + { + name: "shell injection attempt", + project: "project; rm -rf /", + wantError: true, + }, + { + name: "backtick injection", + project: "project`whoami`", + wantError: true, + }, + { + name: "special characters", + project: "project$HOME", + wantError: true, + }, + { + name: "too long project name", + project: string(make([]byte, 501)), + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateProjectName(tt.project) + if tt.wantError && err == nil { + t.Errorf("Expected error for project %q, got nil", tt.project) + } + if !tt.wantError && err != nil { + t.Errorf("Unexpected error for project %q: %v", tt.project, err) + } + }) + } +} + +func TestBulkOperationLimiter(t *testing.T) { + limiter := NewBulkOperationLimiter(1) // 1 second cooldown for testing + + // First operation should be allowed + if !limiter.CanExecute() { + t.Error("First bulk operation should be allowed") + } + + // Immediate second operation should be blocked + if limiter.CanExecute() { + t.Error("Immediate second bulk operation should be blocked") + } + + // Check cooldown remaining + remaining := limiter.CooldownRemaining() + if remaining <= 0 || remaining > 1 { + t.Errorf("Expected cooldown remaining between 0-1 seconds, got %d", remaining) + } + + // Check time since last op + since := limiter.TimeSinceLastOp() + if since < 0 || since > 1 { + t.Errorf("Expected time since last op between 0-1 seconds, got %d", since) + } +} + +func TestSecurityHeaders_CSP(t *testing.T) { + handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + // Check CSP header is set + csp := rr.Header().Get("Content-Security-Policy") + if csp == "" { + t.Error("Content-Security-Policy header should be set") + } + if csp != "default-src 'self'" { + t.Errorf("Expected CSP to be \"default-src 'self'\", got %q", csp) + } + + // Check Permissions-Policy header + pp := rr.Header().Get("Permissions-Policy") + if pp == "" { + t.Error("Permissions-Policy header should be set") + } +} + +func TestSecurityHeaders_Preflight(t *testing.T) { + handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("OPTIONS", "/test", nil) + req.Header.Set("Origin", "http://localhost:3000") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + // OPTIONS should return 204 No Content + if rr.Code != http.StatusNoContent { + t.Errorf("Expected status 204 for OPTIONS, got %d", rr.Code) + } + + // CORS headers should be set for allowed origin + if rr.Header().Get("Access-Control-Allow-Origin") != "http://localhost:3000" { + t.Errorf("CORS origin should be set for allowed origin") + } + if rr.Header().Get("Access-Control-Allow-Methods") == "" { + t.Error("Access-Control-Allow-Methods should be set") + } +} diff --git a/internal/worker/ratelimit.go b/internal/worker/ratelimit.go new file mode 100644 index 0000000..6e1ec63 --- /dev/null +++ b/internal/worker/ratelimit.go @@ -0,0 +1,227 @@ +// Package worker provides the main worker service for claude-mnemonic. +package worker + +import ( + "net/http" + "sync" + "time" +) + +// RateLimiter implements a token bucket rate limiter. +type RateLimiter struct { + rate float64 // tokens per second + burst int // maximum burst size + mu sync.Mutex // protects following fields + tokens float64 // current tokens + lastUpdate time.Time // last token update time + requests int64 // total requests + rejected int64 // rejected requests +} + +// LastUpdateTime returns the last update time. +// Thread-safe - acquires the limiter's lock. +func (rl *RateLimiter) LastUpdateTime() time.Time { + rl.mu.Lock() + defer rl.mu.Unlock() + return rl.lastUpdate +} + +// lastUpdateTimeUnlocked returns the last update time without locking. +// Caller must hold rl.mu. +func (rl *RateLimiter) lastUpdateTimeUnlocked() time.Time { + return rl.lastUpdate +} + +// NewRateLimiter creates a new rate limiter. +// rate is the number of requests per second to allow. +// burst is the maximum burst of requests to allow. +func NewRateLimiter(rate float64, burst int) *RateLimiter { + return &RateLimiter{ + rate: rate, + burst: burst, + tokens: float64(burst), + lastUpdate: time.Now(), + } +} + +// Allow checks if a request should be allowed. +// Returns true if the request is allowed, false if rate limited. +func (rl *RateLimiter) Allow() bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + rl.requests++ + + // Calculate tokens added since last update + now := time.Now() + elapsed := now.Sub(rl.lastUpdate).Seconds() + rl.tokens += elapsed * rl.rate + if rl.tokens > float64(rl.burst) { + rl.tokens = float64(rl.burst) + } + rl.lastUpdate = now + + // Check if we have a token available + if rl.tokens >= 1 { + rl.tokens-- + return true + } + + rl.rejected++ + return false +} + +// Stats returns rate limiter statistics. +func (rl *RateLimiter) Stats() map[string]any { + rl.mu.Lock() + defer rl.mu.Unlock() + + return map[string]any{ + "rate": rl.rate, + "burst": rl.burst, + "current_tokens": rl.tokens, + "total_requests": rl.requests, + "rejected": rl.rejected, + "rejection_rate": float64(rl.rejected) / max(float64(rl.requests), 1), + } +} + +// RateLimitMiddleware creates middleware that applies rate limiting. +// Uses a shared rate limiter for all requests. +func RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !limiter.Allow() { + http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// PerClientRateLimiter implements per-client rate limiting. +type PerClientRateLimiter struct { + rate float64 + burst int + clients map[string]*RateLimiter + mu sync.Mutex + // Cleanup settings + cleanupInterval time.Duration + maxIdleTime time.Duration + lastCleanup time.Time +} + +// NewPerClientRateLimiter creates a new per-client rate limiter. +func NewPerClientRateLimiter(rate float64, burst int) *PerClientRateLimiter { + return &PerClientRateLimiter{ + rate: rate, + burst: burst, + clients: make(map[string]*RateLimiter), + cleanupInterval: 5 * time.Minute, + maxIdleTime: 10 * time.Minute, + lastCleanup: time.Now(), + } +} + +// getLimiter returns a rate limiter for the given client key. +func (pcrl *PerClientRateLimiter) getLimiter(key string) *RateLimiter { + pcrl.mu.Lock() + defer pcrl.mu.Unlock() + + // Periodic cleanup of idle clients + if time.Since(pcrl.lastCleanup) > pcrl.cleanupInterval { + pcrl.cleanupLocked() + } + + limiter, exists := pcrl.clients[key] + if !exists { + limiter = NewRateLimiter(pcrl.rate, pcrl.burst) + pcrl.clients[key] = limiter + } + + return limiter +} + +// cleanupLocked removes idle limiters. Must be called with lock held. +// Uses consistent lock ordering: always acquire limiter.mu while holding pcrl.mu. +// This is safe because the limiter.mu critical section is brief (just reading lastUpdate). +func (pcrl *PerClientRateLimiter) cleanupLocked() { + now := time.Now() + keysToDelete := make([]string, 0) + + // Check each limiter while holding pcrl.mu + // We briefly acquire limiter.mu but the critical section is minimal + for key, limiter := range pcrl.clients { + limiter.mu.Lock() + lastUpdate := limiter.lastUpdateTimeUnlocked() + limiter.mu.Unlock() + + if now.Sub(lastUpdate) > pcrl.maxIdleTime { + keysToDelete = append(keysToDelete, key) + } + } + + // Delete collected keys + for _, key := range keysToDelete { + delete(pcrl.clients, key) + } + pcrl.lastCleanup = now +} + +// Allow checks if a request from the given client should be allowed. +func (pcrl *PerClientRateLimiter) Allow(clientKey string) bool { + return pcrl.getLimiter(clientKey).Allow() +} + +// Stats returns aggregate statistics. +// Uses two-phase approach to avoid nested lock acquisition. +func (pcrl *PerClientRateLimiter) Stats() map[string]any { + // Phase 1: Collect limiters under pcrl.mu + pcrl.mu.Lock() + rate := pcrl.rate + burst := pcrl.burst + activeClients := len(pcrl.clients) + limiters := make([]*RateLimiter, 0, activeClients) + for _, limiter := range pcrl.clients { + limiters = append(limiters, limiter) + } + pcrl.mu.Unlock() + + // Phase 2: Collect stats from each limiter (only acquiring limiter.mu, not pcrl.mu) + var totalRequests, totalRejected int64 + for _, limiter := range limiters { + limiter.mu.Lock() + totalRequests += limiter.requests + totalRejected += limiter.rejected + limiter.mu.Unlock() + } + + return map[string]any{ + "rate": rate, + "burst": burst, + "active_clients": activeClients, + "total_requests": totalRequests, + "total_rejected": totalRejected, + } +} + +// PerClientRateLimitMiddleware creates middleware that applies per-client rate limiting. +// Uses X-Forwarded-For or RemoteAddr to identify clients. +func PerClientRateLimitMiddleware(limiter *PerClientRateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get client identifier (prefer X-Real-IP from RealIP middleware) + clientKey := r.RemoteAddr + if xff := r.Header.Get("X-Real-IP"); xff != "" { + clientKey = xff + } + + if !limiter.Allow(clientKey) { + http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/worker/sdk/processor.go b/internal/worker/sdk/processor.go index 17e51d8..fc0c78c 100644 --- a/internal/worker/sdk/processor.go +++ b/internal/worker/sdk/processor.go @@ -4,11 +4,15 @@ package sdk import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "fmt" "os" "os/exec" "path/filepath" "strings" + "sync" + "sync/atomic" "time" json "github.com/goccy/go-json" @@ -20,8 +24,178 @@ import ( "github.com/rs/zerolog/log" ) +// CircuitBreaker implements a simple circuit breaker pattern for CLI calls. +type CircuitBreaker struct { + failures int64 // Current failure count + lastFailure int64 // Unix timestamp of last failure + threshold int64 // Number of failures before opening + resetTimeout int64 // Seconds to wait before trying again + state int32 // 0=closed, 1=open, 2=half-open +} + +const ( + circuitClosed int32 = 0 + circuitOpen int32 = 1 + circuitHalfOpen int32 = 2 +) + +// NewCircuitBreaker creates a new circuit breaker. +func NewCircuitBreaker(threshold int64, resetTimeout int64) *CircuitBreaker { + return &CircuitBreaker{ + threshold: threshold, + resetTimeout: resetTimeout, + } +} + +// Allow checks if a request should be allowed through. +func (cb *CircuitBreaker) Allow() bool { + state := atomic.LoadInt32(&cb.state) + if state == circuitClosed { + return true + } + + if state == circuitOpen { + // Check if reset timeout has passed + lastFail := atomic.LoadInt64(&cb.lastFailure) + if time.Now().Unix()-lastFail > cb.resetTimeout { + // Transition to half-open + atomic.CompareAndSwapInt32(&cb.state, circuitOpen, circuitHalfOpen) + return true + } + return false + } + + // Half-open: allow one request through + return true +} + +// RecordSuccess records a successful call. +func (cb *CircuitBreaker) RecordSuccess() { + atomic.StoreInt64(&cb.failures, 0) + atomic.StoreInt32(&cb.state, circuitClosed) +} + +// RecordFailure records a failed call. +func (cb *CircuitBreaker) RecordFailure() { + failures := atomic.AddInt64(&cb.failures, 1) + atomic.StoreInt64(&cb.lastFailure, time.Now().Unix()) + + if failures >= cb.threshold { + atomic.StoreInt32(&cb.state, circuitOpen) + log.Warn().Int64("failures", failures).Msg("Circuit breaker opened - Claude CLI calls temporarily disabled") + } +} + +// State returns the current state as a string. +func (cb *CircuitBreaker) State() string { + switch atomic.LoadInt32(&cb.state) { + case circuitOpen: + return "open" + case circuitHalfOpen: + return "half-open" + default: + return "closed" + } +} + +// CircuitBreakerMetrics contains metrics about the circuit breaker state. +type CircuitBreakerMetrics struct { + State string `json:"state"` + Failures int64 `json:"failures"` + Threshold int64 `json:"threshold"` + ResetTimeoutSecs int64 `json:"reset_timeout_secs"` + LastFailureUnix int64 `json:"last_failure_unix,omitempty"` + SecondsUntilReset int64 `json:"seconds_until_reset,omitempty"` +} + +// Metrics returns the current metrics of the circuit breaker. +func (cb *CircuitBreaker) Metrics() CircuitBreakerMetrics { + failures := atomic.LoadInt64(&cb.failures) + lastFail := atomic.LoadInt64(&cb.lastFailure) + state := cb.State() + + metrics := CircuitBreakerMetrics{ + State: state, + Failures: failures, + Threshold: cb.threshold, + ResetTimeoutSecs: cb.resetTimeout, + } + + if lastFail > 0 { + metrics.LastFailureUnix = lastFail + if state == "open" { + remaining := cb.resetTimeout - (time.Now().Unix() - lastFail) + if remaining > 0 { + metrics.SecondsUntilReset = remaining + } + } + } + + return metrics +} + +// RequestDeduplicator tracks recent requests to prevent duplicates. +type RequestDeduplicator struct { + seen map[string]int64 // hash -> timestamp + mu sync.RWMutex + ttlSecs int64 + maxSize int +} + +// NewRequestDeduplicator creates a new deduplicator. +func NewRequestDeduplicator(ttlSecs int64, maxSize int) *RequestDeduplicator { + return &RequestDeduplicator{ + seen: make(map[string]int64), + ttlSecs: ttlSecs, + maxSize: maxSize, + } +} + +// IsDuplicate checks if a request hash was seen recently. +func (d *RequestDeduplicator) IsDuplicate(hash string) bool { + now := time.Now().Unix() + + d.mu.RLock() + ts, exists := d.seen[hash] + d.mu.RUnlock() + + if exists && now-ts < d.ttlSecs { + return true + } + return false +} + +// Record marks a request hash as seen. +func (d *RequestDeduplicator) Record(hash string) { + now := time.Now().Unix() + + d.mu.Lock() + defer d.mu.Unlock() + + // Evict old entries if at capacity + if len(d.seen) >= d.maxSize { + threshold := now - d.ttlSecs + for k, ts := range d.seen { + if ts < threshold { + delete(d.seen, k) + } + } + } + + d.seen[hash] = now +} + +// hashRequest creates a hash of a request for deduplication. +func hashRequest(toolName, input, output string) string { + h := sha256.New() + h.Write([]byte(toolName)) + h.Write([]byte(input)) + h.Write([]byte(output[:min(len(output), 1000)])) // Only hash first 1000 chars of output + return hex.EncodeToString(h.Sum(nil))[:16] // Short hash is sufficient +} + // BroadcastFunc is a callback for broadcasting events to SSE clients. -type BroadcastFunc func(event map[string]interface{}) +type BroadcastFunc func(event map[string]any) // SyncObservationFunc is a callback for syncing observations to vector DB. type SyncObservationFunc func(obs *models.Observation) @@ -29,6 +203,10 @@ type SyncObservationFunc func(obs *models.Observation) // SyncSummaryFunc is a callback for syncing summaries to vector DB. type SyncSummaryFunc func(summary *models.SessionSummary) +// MaxVectorSyncWorkers is the maximum number of concurrent vector sync operations. +// This prevents unbounded goroutine spawning during high-volume observation ingestion. +const MaxVectorSyncWorkers = 8 + // Processor handles SDK agent processing of observations and summaries using Claude Code CLI. type Processor struct { claudePath string @@ -40,6 +218,14 @@ type Processor struct { syncSummaryFunc SyncSummaryFunc // Semaphore to limit concurrent Claude CLI calls (prevents API overload) sem chan struct{} + // Circuit breaker for CLI failures + circuitBreaker *CircuitBreaker + // Request deduplicator to prevent duplicate processing + deduplicator *RequestDeduplicator + // Bounded worker pool for vector sync operations + vectorSyncChan chan *models.Observation + vectorSyncWg sync.WaitGroup + vectorSyncDone chan struct{} } // SetBroadcastFunc sets the broadcast callback for SSE events. @@ -58,7 +244,7 @@ func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) { } // broadcast sends an event via the broadcast callback if set. -func (p *Processor) broadcast(event map[string]interface{}) { +func (p *Processor) broadcast(event map[string]any) { if p.broadcastFunc != nil { p.broadcastFunc(event) } @@ -94,9 +280,65 @@ func NewProcessor(observationStore *gorm.ObservationStore, summaryStore *gorm.Su observationStore: observationStore, summaryStore: summaryStore, sem: make(chan struct{}, MaxConcurrentCLICalls), + circuitBreaker: NewCircuitBreaker(5, 60), // Open after 5 failures, reset after 60s + deduplicator: NewRequestDeduplicator(300, 1000), // 5-minute TTL, 1000 max entries + vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2), // Buffered channel + vectorSyncDone: make(chan struct{}), }, nil } +// StartVectorSyncWorkers starts the bounded worker pool for vector sync operations. +// Call this after setting the sync function via SetSyncObservationFunc. +func (p *Processor) StartVectorSyncWorkers() { + for i := 0; i < MaxVectorSyncWorkers; i++ { + p.vectorSyncWg.Add(1) + go p.vectorSyncWorker() + } + log.Info().Int("workers", MaxVectorSyncWorkers).Msg("Vector sync worker pool started") +} + +// StopVectorSyncWorkers gracefully stops the worker pool. +func (p *Processor) StopVectorSyncWorkers() { + close(p.vectorSyncDone) + p.vectorSyncWg.Wait() + log.Info().Msg("Vector sync worker pool stopped") +} + +// vectorSyncWorker is a worker goroutine that processes vector sync requests. +func (p *Processor) vectorSyncWorker() { + defer p.vectorSyncWg.Done() + for { + select { + case <-p.vectorSyncDone: + // Drain remaining items before exiting + for { + select { + case obs := <-p.vectorSyncChan: + if p.syncObservationFunc != nil { + p.syncObservationFunc(obs) + } + default: + return + } + } + case obs := <-p.vectorSyncChan: + if p.syncObservationFunc != nil { + p.syncObservationFunc(obs) + } + } + } +} + +// CircuitBreakerState returns the current state of the circuit breaker. +func (p *Processor) CircuitBreakerState() string { + return p.circuitBreaker.State() +} + +// CircuitBreakerMetrics returns detailed metrics about the circuit breaker. +func (p *Processor) CircuitBreakerMetrics() CircuitBreakerMetrics { + return p.circuitBreaker.Metrics() +} + // IsAvailable checks if the Claude CLI is available for processing. func (p *Processor) IsAvailable() bool { _, err := os.Stat(p.claudePath) @@ -104,7 +346,7 @@ func (p *Processor) IsAvailable() bool { } // ProcessObservation processes a single tool observation and extracts insights. -func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error { +func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse any, promptNumber int, cwd string) error { // Skip certain tools that aren't worth processing if shouldSkipTool(toolName) { log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)") @@ -121,11 +363,23 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec return nil } + // Check for duplicate request within TTL window + reqHash := hashRequest(toolName, inputStr, outputStr) + if p.deduplicator.IsDuplicate(reqHash) { + log.Debug().Str("tool", toolName).Msg("Skipping duplicate request (dedup)") + return nil + } + + // Check circuit breaker before making CLI call + if !p.circuitBreaker.Allow() { + log.Warn().Str("tool", toolName).Msg("Circuit breaker open - skipping CLI call") + return fmt.Errorf("circuit breaker open") + } + log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI") - // Note: Removed the "file already has observations" check - // Each tool execution can produce unique insights even for the same file - // Similarity-based deduplication will handle true duplicates + // Record this request to prevent duplicates + p.deduplicator.Record(reqHash) // Build the prompt exec := ToolExecution{ @@ -147,9 +401,11 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec // Call Claude Code CLI response, err := p.callClaudeCLI(ctx, prompt) if err != nil { + p.circuitBreaker.RecordFailure() log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation") return err } + p.circuitBreaker.RecordSuccess() // Parse observations from response observations := ParseObservations(response, sdkSessionID) @@ -200,16 +456,26 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec Int("trackedFiles", len(obs.FileMtimes)). Msg("Observation stored") - // Sync to vector DB if callback is set - if p.syncObservationFunc != nil { + // Sync to vector DB via bounded worker pool (non-blocking to reduce latency) + if p.syncObservationFunc != nil && p.vectorSyncChan != nil { fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0) fullObs.ID = id fullObs.CreatedAtEpoch = createdAtEpoch - p.syncObservationFunc(fullObs) + // Non-blocking send to worker pool - drops if channel is full + select { + case p.vectorSyncChan <- fullObs: + // Sent to worker pool + default: + // Channel full, fall back to direct sync in goroutine (bounded by channel buffer) + log.Debug().Int64("obs_id", id).Msg("Vector sync channel full, using fallback goroutine") + go func(obsToSync *models.Observation) { + p.syncObservationFunc(obsToSync) + }(fullObs) + } } // Broadcast new observation event for dashboard refresh - p.broadcast(map[string]interface{}{ + p.broadcast(map[string]any{ "type": "observation", "action": "created", "id": id, @@ -311,7 +577,7 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe } // Broadcast new summary event for dashboard refresh - p.broadcast(map[string]interface{}{ + p.broadcast(map[string]any{ "type": "summary", "action": "created", "id": id, @@ -321,8 +587,31 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe return nil } +// MaxPromptSize is the maximum size of a prompt that can be passed to the Claude CLI. +// This prevents resource exhaustion from extremely large prompts. +const MaxPromptSize = 100 * 1024 // 100KB + +// sanitizePrompt removes null bytes and control characters from a prompt. +// Keeps newlines, tabs, and carriage returns as they're valid in prompts. +func sanitizePrompt(s string) string { + return strings.Map(func(r rune) rune { + // Keep printable ASCII, extended Unicode, and common whitespace + if r >= 32 || r == '\n' || r == '\t' || r == '\r' { + return r + } + // Remove null bytes and other control characters + return -1 + }, s) +} + // callClaudeCLI calls the Claude Code CLI with the given prompt. func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) { + // Validate and sanitize prompt + if len(prompt) > MaxPromptSize { + return "", fmt.Errorf("prompt exceeds maximum size of %d bytes", MaxPromptSize) + } + prompt = sanitizePrompt(prompt) + // Build the full prompt with system instructions fullPrompt := systemPrompt + "\n\n" + prompt @@ -419,8 +708,11 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool { return true } - // Skip if output indicates an error or empty result + // Pre-compute lowercase strings once to avoid repeated allocations lowerOutput := strings.ToLower(outputStr) + lowerInput := strings.ToLower(inputStr) + + // Skip if output indicates an error or empty result trivialOutputs := []string{ "no matches found", "file not found", @@ -444,13 +736,13 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool { // Skip reading config files that rarely contain project-specific insights boringFiles := []string{ "package-lock.json", "yarn.lock", "pnpm-lock.yaml", - "go.sum", "Cargo.lock", "Gemfile.lock", "poetry.lock", + "go.sum", "cargo.lock", "gemfile.lock", "poetry.lock", ".gitignore", ".dockerignore", ".eslintignore", "tsconfig.json", "jsconfig.json", "vite.config", "tailwind.config", "postcss.config", } for _, boring := range boringFiles { - if strings.Contains(inputStr, boring) { + if strings.Contains(lowerInput, boring) { return true } } @@ -462,14 +754,14 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool { } case "Bash": - // Skip simple status commands + // Skip simple status commands (use pre-computed lowerInput) boringCommands := []string{ "git status", "git diff", "git log", "git branch", "ls ", "pwd", "echo ", "cat ", "which ", "type ", "npm list", "npm outdated", "npm audit", } for _, boring := range boringCommands { - if strings.Contains(strings.ToLower(inputStr), boring) { + if strings.Contains(lowerInput, boring) { return true } } @@ -479,7 +771,7 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool { } // toJSONString converts an interface to a JSON string. -func toJSONString(v interface{}) string { +func toJSONString(v any) string { if v == nil { return "" } @@ -495,36 +787,85 @@ func toJSONString(v interface{}) string { // captureFileMtimes captures current modification times for tracked files. // Returns a map of absolute file paths to their mtime in epoch milliseconds. +// For large file lists (>10 files), uses parallel stat calls for better performance. func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 { - mtimes := make(map[string]int64) + // Combine all unique file paths + allPaths := make(map[string]struct{}, len(filesRead)+len(filesModified)) + for _, path := range filesRead { + allPaths[path] = struct{}{} + } + for _, path := range filesModified { + allPaths[path] = struct{}{} + } - // Helper to get mtime for a file path - getMtime := func(path string) (int64, bool) { - // Resolve relative paths against cwd + // For small lists, use sequential processing (goroutine overhead not worth it) + if len(allPaths) <= 10 { + return captureFileMtimesSequential(allPaths, cwd) + } + + // For larger lists, parallelize with bounded concurrency + return captureFileMtimesParallel(allPaths, cwd) +} + +// captureFileMtimesSequential captures mtimes sequentially (efficient for small lists). +func captureFileMtimesSequential(paths map[string]struct{}, cwd string) map[string]int64 { + mtimes := make(map[string]int64, len(paths)) + + for path := range paths { absPath := path if !filepath.IsAbs(path) && cwd != "" { absPath = filepath.Join(cwd, path) } info, err := os.Stat(absPath) - if err != nil { - return 0, false - } - return info.ModTime().UnixMilli(), true - } - - // Capture mtimes for all read files - for _, path := range filesRead { - if mtime, ok := getMtime(path); ok { - mtimes[path] = mtime + if err == nil { + mtimes[path] = info.ModTime().UnixMilli() } } - // Capture mtimes for all modified files - for _, path := range filesModified { - if mtime, ok := getMtime(path); ok { - mtimes[path] = mtime - } + return mtimes +} + +// captureFileMtimesParallel captures mtimes in parallel with bounded concurrency. +func captureFileMtimesParallel(paths map[string]struct{}, cwd string) map[string]int64 { + type mtimeResult struct { + path string + mtime int64 + } + + results := make(chan mtimeResult, len(paths)) + sem := make(chan struct{}, 8) // Limit to 8 concurrent stat calls + var wg sync.WaitGroup + + for path := range paths { + wg.Add(1) + go func(p string) { + defer wg.Done() + sem <- struct{}{} // Acquire + defer func() { <-sem }() // Release + + absPath := p + if !filepath.IsAbs(p) && cwd != "" { + absPath = filepath.Join(cwd, p) + } + + info, err := os.Stat(absPath) + if err == nil { + results <- mtimeResult{path: p, mtime: info.ModTime().UnixMilli()} + } + }(path) + } + + // Close results channel when all goroutines complete + go func() { + wg.Wait() + close(results) + }() + + // Collect results + mtimes := make(map[string]int64, len(paths)) + for res := range results { + mtimes[res.path] = res.mtime } return mtimes diff --git a/internal/worker/sdk/processor_test.go b/internal/worker/sdk/processor_test.go index f18e6c5..a57aa37 100644 --- a/internal/worker/sdk/processor_test.go +++ b/internal/worker/sdk/processor_test.go @@ -974,3 +974,110 @@ func TestSyncSummaryFuncType(t *testing.T) { } assert.NotNil(t, fn) } + +// TestSanitizePrompt tests prompt sanitization for CLI safety. +func TestSanitizePrompt(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "normal text", + input: "Hello, world!", + expected: "Hello, world!", + }, + { + name: "text with newlines", + input: "Line 1\nLine 2\nLine 3", + expected: "Line 1\nLine 2\nLine 3", + }, + { + name: "text with tabs", + input: "Key:\tValue", + expected: "Key:\tValue", + }, + { + name: "text with carriage return", + input: "Line 1\r\nLine 2", + expected: "Line 1\r\nLine 2", + }, + { + name: "text with null bytes", + input: "Hello\x00World", + expected: "HelloWorld", + }, + { + name: "text with control characters", + input: "Hello\x01\x02\x03World", + expected: "HelloWorld", + }, + { + name: "text with bell character", + input: "Hello\x07World", + expected: "HelloWorld", + }, + { + name: "text with backspace", + input: "Hello\x08World", + expected: "HelloWorld", + }, + { + name: "text with form feed", + input: "Hello\x0cWorld", + expected: "HelloWorld", + }, + { + name: "text with escape", + input: "Hello\x1bWorld", + expected: "HelloWorld", + }, + { + name: "unicode text", + input: "Hello 世界 🌍", + expected: "Hello 世界 🌍", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "only control characters", + input: "\x00\x01\x02\x03", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sanitizePrompt(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestMaxPromptSize tests that MaxPromptSize is reasonable. +func TestMaxPromptSize(t *testing.T) { + assert.Equal(t, 100*1024, MaxPromptSize) +} + +// BenchmarkSanitizePrompt benchmarks the sanitize function. +func BenchmarkSanitizePrompt(b *testing.B) { + prompt := "Analyze the following code:\n```go\nfunc main() {\n\tfmt.Println(\"Hello, World!\")\n}\n```\n\nPlease identify any issues." + + b.ResetTimer() + for i := 0; i < b.N; i++ { + sanitizePrompt(prompt) + } +} + +// BenchmarkSanitizePromptWithControlChars benchmarks sanitization with control characters. +func BenchmarkSanitizePromptWithControlChars(b *testing.B) { + prompt := "Hello\x00World\x01Test\x02Data\x03End" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + sanitizePrompt(prompt) + } +} diff --git a/internal/worker/service.go b/internal/worker/service.go index a6180a3..ba40292 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -42,18 +42,63 @@ const ( // QueueProcessInterval is how often the background queue processor runs. QueueProcessInterval = 2 * time.Second + + // VectorSyncMaxRetries is the maximum number of retries for vector sync operations. + VectorSyncMaxRetries = 3 + + // VectorSyncInitialBackoff is the initial backoff duration for retry. + VectorSyncInitialBackoff = 100 * time.Millisecond ) +// retryWithBackoff attempts the given function up to maxRetries times with exponential backoff. +// Returns nil on success, or the last error after all retries are exhausted. +func retryWithBackoff(ctx context.Context, maxRetries int, initialBackoff time.Duration, fn func() error) error { + var lastErr error + backoff := initialBackoff + + for i := 0; i < maxRetries; i++ { + if err := fn(); err == nil { + return nil + } else { + lastErr = err + } + + // Don't wait after the last attempt + if i < maxRetries-1 { + select { + case <-time.After(backoff): + backoff *= 2 // Exponential backoff + case <-ctx.Done(): + return lastErr + } + } + } + return lastErr +} + // RetrievalStats tracks observation retrieval metrics. type RetrievalStats struct { - TotalRequests int64 // Total retrieval requests (inject + search) - ObservationsServed int64 // Observations returned to clients - VerifiedStale int64 // Stale observations that passed verification - DeletedInvalid int64 // Invalid observations deleted - SearchRequests int64 // Semantic search requests - ContextInjections int64 // Session-start context injections + TotalRequests int64 // Total retrieval requests (inject + search) + ObservationsServed int64 // Observations returned to clients + VerifiedStale int64 // Stale observations that passed verification + DeletedInvalid int64 // Invalid observations deleted + SearchRequests int64 // Semantic search requests + ContextInjections int64 // Session-start context injections + StaleExcluded int64 // Observations excluded due to staleness check + FreshCount int64 // Observations that passed staleness check + DuplicatesRemoved int64 // Observations removed by clustering + LastUpdated int64 // Unix timestamp of last update (atomic) } +// maxRetrievalStatsProjects limits the number of projects tracked to prevent unbounded memory growth. +const maxRetrievalStatsProjects = 500 + +// retrievalStatsMaxAge is the maximum age for retrieval stats before cleanup (24 hours). +const retrievalStatsMaxAge = 24 * time.Hour + +// maxRecentQueries is the maximum number of recent queries to track. +const maxRecentQueries = 100 + // Service is the main worker service orchestrator. type Service struct { // Version of the worker binary @@ -124,6 +169,48 @@ type Service struct { // Self-updater updater *update.Updater + + // Rate limiting + rateLimiter *PerClientRateLimiter + expensiveOpLimiter *ExpensiveOperationLimiter + bulkOpLimiter *BulkOperationLimiter + + // Rebuild status tracking + rebuildStatus *RebuildStatus + rebuildStatusMu sync.RWMutex + + // Recent search query tracking (circular buffer for O(1) insertion) + recentQueriesBuf [maxRecentQueries]RecentSearchQuery // fixed-size circular buffer + recentQueriesHead int // index of most recent (newest) + recentQueriesLen int // current number of items + recentQueriesMu sync.RWMutex + + // Stats caching to reduce database load + cachedObsCounts map[string]cachedCount // per-project observation counts + cachedObsCountsMu sync.RWMutex + statsCacheTTL time.Duration + + // Vector sync worker pool - limits concurrent vector sync goroutines + vectorSyncSem chan struct{} // semaphore for rate limiting +} + +// cachedCount stores a cached count value with expiration. +type cachedCount struct { + count int + timestamp time.Time +} + +// RebuildStatus tracks the progress of vector rebuild operations. +type RebuildStatus struct { + InProgress bool `json:"in_progress"` + StartTime time.Time `json:"start_time,omitempty"` + Phase string `json:"phase,omitempty"` // "observations", "summaries", "prompts", "complete" + TotalSynced int `json:"total_synced"` + TotalErrors int `json:"total_errors"` + CurrentPhase int `json:"current_phase"` // 1, 2, 3 for the three phases + TotalPhases int `json:"total_phases"` // 3 + ElapsedMs int64 `json:"elapsed_ms,omitempty"` + EstimatedPct float64 `json:"estimated_pct,omitempty"` // 0-100 } // staleVerifyRequest represents a request to verify a stale observation in background @@ -132,6 +219,161 @@ type staleVerifyRequest struct { cwd string } +// RecentSearchQuery tracks a search query for analytics. +type RecentSearchQuery struct { + Query string `json:"query"` + Project string `json:"project,omitempty"` + Type string `json:"type,omitempty"` // observations, summaries, prompts + Results int `json:"results"` + UsedVector bool `json:"used_vector"` + Timestamp time.Time `json:"timestamp"` +} + +// asyncVectorSync executes a vector sync operation with rate limiting. +// This prevents goroutine explosion during bulk operations. +// All goroutines are tracked in s.wg for graceful shutdown. +func (s *Service) asyncVectorSync(fn func()) { + s.wg.Add(1) + if s.vectorSyncSem == nil { + // Fallback if semaphore not initialized + go func() { + defer s.wg.Done() + fn() + }() + return + } + + go func() { + defer s.wg.Done() + // Acquire semaphore slot + s.vectorSyncSem <- struct{}{} + defer func() { <-s.vectorSyncSem }() + + fn() + }() +} + +// setupVectorSyncCallbacks configures all vector sync related callbacks on stores and processors. +// This is extracted to avoid duplication between initializeAsync and reinitializeDatabase. +func (s *Service) setupVectorSyncCallbacks( + patternDetector *pattern.Detector, + patternStore *gorm.PatternStore, + observationStore *gorm.ObservationStore, + promptStore *gorm.PromptStore, + processor *sdk.Processor, + sessionManager *session.Manager, + vectorSync *sqlitevec.Sync, +) { + // Set pattern sync callback if vector sync is available + if patternDetector != nil && vectorSync != nil { + patternDetector.SetSyncFunc(func(p *models.Pattern) { + err := retryWithBackoff(s.ctx, VectorSyncMaxRetries, VectorSyncInitialBackoff, func() error { + return vectorSync.SyncPattern(s.ctx, p) + }) + if err != nil { + log.Warn().Err(err).Int64("id", p.ID).Msg("Failed to sync pattern to sqlite-vec after retries") + } + }) + } + + // Set cleanup callback for pattern deletions + if patternStore != nil && vectorSync != nil { + patternStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { + err := retryWithBackoff(ctx, VectorSyncMaxRetries, VectorSyncInitialBackoff, func() error { + return vectorSync.DeletePatterns(ctx, deletedIDs) + }) + if err != nil { + log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete patterns from sqlite-vec after retries") + } + }) + } + + // Set vector sync callbacks on processor if both are available + if processor != nil && vectorSync != nil { + processor.SetSyncObservationFunc(func(obs *models.Observation) { + err := retryWithBackoff(s.ctx, VectorSyncMaxRetries, VectorSyncInitialBackoff, func() error { + return vectorSync.SyncObservation(s.ctx, obs) + }) + if err != nil { + log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec after retries") + } + // Trigger pattern detection for the new observation + if patternDetector != nil { + s.wg.Add(1) // Track goroutine for graceful shutdown + go func(observation *models.Observation) { + defer s.wg.Done() + detectCtx, cancel := context.WithTimeout(s.ctx, 10*time.Second) + defer cancel() + if result, err := patternDetector.AnalyzeObservation(detectCtx, observation); err != nil { + // Don't log context canceled errors during shutdown + if s.ctx.Err() == nil { + log.Warn().Err(err).Int64("obs_id", observation.ID).Msg("Pattern detection failed") + } + } else if result.MatchedPattern != nil { + log.Debug(). + Int64("pattern_id", result.MatchedPattern.ID). + Str("pattern_name", result.MatchedPattern.Name). + Bool("is_new", result.IsNewPattern). + Msg("Pattern matched for observation") + } + }(obs) + } + }) + processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) { + err := retryWithBackoff(s.ctx, VectorSyncMaxRetries, VectorSyncInitialBackoff, func() error { + return vectorSync.SyncSummary(s.ctx, summary) + }) + if err != nil { + log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to sqlite-vec after retries") + } + }) + } + + // Set cleanup callback on observation store to sync deletes to vector store + if observationStore != nil && vectorSync != nil { + observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { + err := retryWithBackoff(ctx, VectorSyncMaxRetries, VectorSyncInitialBackoff, func() error { + return vectorSync.DeleteObservations(ctx, deletedIDs) + }) + if err != nil { + log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from sqlite-vec after retries") + } + }) + } + + // Set cleanup callback on prompt store to sync deletes to vector store + if promptStore != nil && vectorSync != nil { + promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { + err := retryWithBackoff(ctx, VectorSyncMaxRetries, VectorSyncInitialBackoff, func() error { + return vectorSync.DeleteUserPrompts(ctx, deletedIDs) + }) + if err != nil { + log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from sqlite-vec") + } + }) + } + + // Set callbacks for session lifecycle events + if sessionManager != nil { + sessionManager.SetOnSessionCreated(func(id int64) { + s.broadcastProcessingStatus() + s.sseBroadcaster.Broadcast(map[string]any{ + "type": "session", + "action": "created", + "id": id, + }) + }) + sessionManager.SetOnSessionDeleted(func(id int64) { + s.broadcastProcessingStatus() + s.sseBroadcaster.Broadcast(map[string]any{ + "type": "session", + "action": "deleted", + "id": id, + }) + }) + } +} + // NewService creates a new worker service with deferred initialization. // The service starts immediately with health endpoint available, // while database and SDK initialization happens in the background. @@ -149,16 +391,26 @@ func NewService(version string) (*Service, error) { homeDir, _ := os.UserHomeDir() installDir := fmt.Sprintf("%s/.claude/plugins/marketplaces/claude-mnemonic", homeDir) + // Create rate limiter with generous limits (100 req/sec, burst of 200) + // These limits are per-client and allow for intensive CLI usage + rateLimiter := NewPerClientRateLimiter(100.0, 200) + svc := &Service{ - version: version, - config: cfg, - sseBroadcaster: sseBroadcaster, - router: router, - ctx: ctx, - cancel: cancel, - startTime: time.Now(), - updater: update.New(version, installDir), - retrievalStats: make(map[string]*RetrievalStats), + version: version, + config: cfg, + sseBroadcaster: sseBroadcaster, + router: router, + ctx: ctx, + cancel: cancel, + startTime: time.Now(), + updater: update.New(version, installDir), + retrievalStats: make(map[string]*RetrievalStats), + rateLimiter: rateLimiter, + expensiveOpLimiter: NewExpensiveOperationLimiter(), + bulkOpLimiter: NewBulkOperationLimiter(60), // 60 second cooldown for bulk operations + cachedObsCounts: make(map[string]cachedCount), + statsCacheTTL: time.Minute, // Cache stats for 1 minute + vectorSyncSem: make(chan struct{}, 10), // Limit to 10 concurrent vector syncs } // Setup middleware and routes (health endpoint works immediately) @@ -263,7 +515,7 @@ func (s *Service) initializeAsync() { } else { processor = proc // Set broadcast callback for SSE events - processor.SetBroadcastFunc(func(event map[string]interface{}) { + processor.SetBroadcastFunc(func(event map[string]any) { s.sseBroadcaster.Broadcast(event) }) log.Info().Msg("SDK processor initialized") @@ -290,91 +542,12 @@ func (s *Service) initializeAsync() { // Initialize pattern detector patternDetector := pattern.NewDetector(patternStore, observationStore, pattern.DefaultConfig()) - // Set pattern sync callback if vector sync is available - if vectorSync != nil { - patternDetector.SetSyncFunc(func(p *models.Pattern) { - if err := vectorSync.SyncPattern(s.ctx, p); err != nil { - log.Warn().Err(err).Int64("id", p.ID).Msg("Failed to sync pattern to sqlite-vec") - } - }) - - // Set cleanup callback for pattern deletions - patternStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { - if err := vectorSync.DeletePatterns(ctx, deletedIDs); err != nil { - log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete patterns from sqlite-vec") - } - }) - } - s.initMu.Lock() s.patternDetector = patternDetector s.initMu.Unlock() - // Set vector sync callbacks on processor if both are available - if processor != nil && vectorSync != nil { - processor.SetSyncObservationFunc(func(obs *models.Observation) { - if err := vectorSync.SyncObservation(s.ctx, obs); err != nil { - log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec") - } - // Trigger pattern detection for the new observation - if patternDetector != nil { - go func(observation *models.Observation) { - detectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if result, err := patternDetector.AnalyzeObservation(detectCtx, observation); err != nil { - log.Warn().Err(err).Int64("obs_id", observation.ID).Msg("Pattern detection failed") - } else if result.MatchedPattern != nil { - log.Debug(). - Int64("pattern_id", result.MatchedPattern.ID). - Str("pattern_name", result.MatchedPattern.Name). - Bool("is_new", result.IsNewPattern). - Msg("Pattern matched for observation") - } - }(obs) - } - }) - processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) { - if err := vectorSync.SyncSummary(s.ctx, summary); err != nil { - log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to sqlite-vec") - } - }) - } - - // Set cleanup callback on observation store to sync deletes to vector store - if observationStore != nil && vectorSync != nil { - observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { - if err := vectorSync.DeleteObservations(ctx, deletedIDs); err != nil { - log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from sqlite-vec") - } - }) - } - - // Set cleanup callback on prompt store to sync deletes to vector store - if promptStore != nil && vectorSync != nil { - promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { - if err := vectorSync.DeleteUserPrompts(ctx, deletedIDs); err != nil { - log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from sqlite-vec") - } - }) - } - - // Set callbacks for session lifecycle events - sessionManager.SetOnSessionCreated(func(id int64) { - s.broadcastProcessingStatus() - s.sseBroadcaster.Broadcast(map[string]interface{}{ - "type": "session", - "action": "created", - "id": id, - }) - }) - sessionManager.SetOnSessionDeleted(func(id int64) { - s.broadcastProcessingStatus() - s.sseBroadcaster.Broadcast(map[string]interface{}{ - "type": "session", - "action": "deleted", - "id": id, - }) - }) + // Setup all vector sync callbacks using the extracted helper (avoids code duplication) + s.setupVectorSyncCallbacks(patternDetector, patternStore, observationStore, promptStore, processor, sessionManager, vectorSync) // Initialize importance scoring system scoringConfig := models.DefaultScoringConfig() @@ -598,7 +771,7 @@ func (s *Service) reinitializeDatabase() { log.Warn().Err(err).Msg("SDK processor not available after reinit") } else { processor = proc - processor.SetBroadcastFunc(func(event map[string]interface{}) { + processor.SetBroadcastFunc(func(event map[string]any) { s.sseBroadcaster.Broadcast(event) }) } @@ -611,22 +784,6 @@ func (s *Service) reinitializeDatabase() { // Create new pattern detector patternDetector := pattern.NewDetector(patternStore, observationStore, pattern.DefaultConfig()) - // Set pattern sync callback if vector sync is available - if vectorSync != nil { - patternDetector.SetSyncFunc(func(p *models.Pattern) { - if err := vectorSync.SyncPattern(s.ctx, p); err != nil { - log.Warn().Err(err).Int64("id", p.ID).Msg("Failed to sync pattern to sqlite-vec") - } - }) - - // Set cleanup callback for pattern deletions - patternStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { - if err := vectorSync.DeletePatterns(ctx, deletedIDs); err != nil { - log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete patterns from sqlite-vec") - } - }) - } - // Atomically swap all components s.initMu.Lock() s.store = store @@ -650,78 +807,15 @@ func (s *Service) reinitializeDatabase() { // Start pattern detector patternDetector.Start() - // Set vector sync callbacks on processor if both are available - if processor != nil && vectorSync != nil { - processor.SetSyncObservationFunc(func(obs *models.Observation) { - if err := vectorSync.SyncObservation(s.ctx, obs); err != nil { - log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to sqlite-vec") - } - // Trigger pattern detection for the new observation - if patternDetector != nil { - go func(observation *models.Observation) { - detectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if result, err := patternDetector.AnalyzeObservation(detectCtx, observation); err != nil { - log.Warn().Err(err).Int64("obs_id", observation.ID).Msg("Pattern detection failed") - } else if result.MatchedPattern != nil { - log.Debug(). - Int64("pattern_id", result.MatchedPattern.ID). - Str("pattern_name", result.MatchedPattern.Name). - Bool("is_new", result.IsNewPattern). - Msg("Pattern matched for observation") - } - }(obs) - } - }) - processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) { - if err := vectorSync.SyncSummary(s.ctx, summary); err != nil { - log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to sqlite-vec") - } - }) - } - - // Set cleanup callback on observation store to sync deletes to vector store - if observationStore != nil && vectorSync != nil { - observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { - if err := vectorSync.DeleteObservations(ctx, deletedIDs); err != nil { - log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from sqlite-vec") - } - }) - } - - // Set cleanup callback on prompt store to sync deletes to vector store - if promptStore != nil && vectorSync != nil { - promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) { - if err := vectorSync.DeleteUserPrompts(ctx, deletedIDs); err != nil { - log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from sqlite-vec") - } - }) - } - - // Set callbacks for session lifecycle events - sessionManager.SetOnSessionCreated(func(id int64) { - s.broadcastProcessingStatus() - s.sseBroadcaster.Broadcast(map[string]interface{}{ - "type": "session", - "action": "created", - "id": id, - }) - }) - sessionManager.SetOnSessionDeleted(func(id int64) { - s.broadcastProcessingStatus() - s.sseBroadcaster.Broadcast(map[string]interface{}{ - "type": "session", - "action": "deleted", - "id": id, - }) - }) + // Setup all vector sync callbacks using the extracted helper (avoids code duplication) + s.setupVectorSyncCallbacks(patternDetector, patternStore, observationStore, promptStore, processor, sessionManager, vectorSync) // Mark as ready again s.ready.Store(true) log.Info().Msg("Database reinitialization complete") // Broadcast status update - s.sseBroadcaster.Broadcast(map[string]interface{}{ + s.sseBroadcaster.Broadcast(map[string]any{ "type": "database_reinitialized", "message": "Database was recreated after deletion", }) @@ -733,7 +827,7 @@ func (s *Service) reloadConfig() { log.Info().Msg("Config changed, triggering graceful restart...") // Broadcast notification - s.sseBroadcaster.Broadcast(map[string]interface{}{ + s.sseBroadcaster.Broadcast(map[string]any{ "type": "config_changed", "message": "Configuration changed, restarting worker...", }) @@ -796,6 +890,8 @@ func (s *Service) processStaleQueue() { // rebuildAllVectors rebuilds all vectors from observations, summaries, and prompts. // Called when the vectors table is empty (e.g., after migration 20 drops all vectors). +// Uses batch processing for memory efficiency during large rebuilds. +// Progress is tracked in s.rebuildStatus for visibility via /api/rebuild-status. func (s *Service) rebuildAllVectors( observationStore *gorm.ObservationStore, summaryStore *gorm.SummaryStore, @@ -804,61 +900,91 @@ func (s *Service) rebuildAllVectors( ) { defer s.wg.Done() - log.Info().Msg("Starting full vector rebuild...") + log.Info().Msg("Starting full vector rebuild with batch processing...") start := time.Now() + // Initialize rebuild status + s.rebuildStatusMu.Lock() + s.rebuildStatus = &RebuildStatus{ + InProgress: true, + StartTime: start, + Phase: "observations", + CurrentPhase: 1, + TotalPhases: 3, + } + s.rebuildStatusMu.Unlock() + var totalSynced int var syncErrors int - // Rebuild observations + // Use batch sync config for efficient processing + cfg := sqlitevec.DefaultBatchSyncConfig() + + // Phase 1: Rebuild observations using batch sync observations, err := observationStore.GetAllObservations(s.ctx) if err != nil { log.Error().Err(err).Msg("Failed to fetch observations for vector rebuild") } else { - for _, obs := range observations { - if err := vectorSync.SyncObservation(s.ctx, obs); err != nil { - log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild") - syncErrors++ - } else { - totalSynced++ - } - } - log.Info().Int("count", len(observations)).Msg("Rebuilt observation vectors") + synced, errors := vectorSync.BatchSyncObservations(s.ctx, observations, cfg) + totalSynced += synced + syncErrors += errors + log.Info().Int("synced", synced).Int("errors", errors).Int("total", len(observations)).Msg("Rebuilt observation vectors") } - // Rebuild summaries + // Update status for phase 2 + s.rebuildStatusMu.Lock() + s.rebuildStatus.Phase = "summaries" + s.rebuildStatus.CurrentPhase = 2 + s.rebuildStatus.TotalSynced = totalSynced + s.rebuildStatus.TotalErrors = syncErrors + s.rebuildStatus.ElapsedMs = time.Since(start).Milliseconds() + s.rebuildStatus.EstimatedPct = 33.3 + s.rebuildStatusMu.Unlock() + + // Phase 2: Rebuild summaries using batch sync summaries, err := summaryStore.GetAllSummaries(s.ctx) if err != nil { log.Error().Err(err).Msg("Failed to fetch summaries for vector rebuild") } else { - for _, summary := range summaries { - if err := vectorSync.SyncSummary(s.ctx, summary); err != nil { - log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild") - syncErrors++ - } else { - totalSynced++ - } - } - log.Info().Int("count", len(summaries)).Msg("Rebuilt summary vectors") + synced, errors := vectorSync.BatchSyncSummaries(s.ctx, summaries, cfg) + totalSynced += synced + syncErrors += errors + log.Info().Int("synced", synced).Int("errors", errors).Int("total", len(summaries)).Msg("Rebuilt summary vectors") } - // Rebuild user prompts + // Update status for phase 3 + s.rebuildStatusMu.Lock() + s.rebuildStatus.Phase = "prompts" + s.rebuildStatus.CurrentPhase = 3 + s.rebuildStatus.TotalSynced = totalSynced + s.rebuildStatus.TotalErrors = syncErrors + s.rebuildStatus.ElapsedMs = time.Since(start).Milliseconds() + s.rebuildStatus.EstimatedPct = 66.6 + s.rebuildStatusMu.Unlock() + + // Phase 3: Rebuild user prompts using batch sync prompts, err := promptStore.GetAllPrompts(s.ctx) if err != nil { log.Error().Err(err).Msg("Failed to fetch prompts for vector rebuild") } else { - for _, prompt := range prompts { - if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil { - log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild") - syncErrors++ - } else { - totalSynced++ - } - } - log.Info().Int("count", len(prompts)).Msg("Rebuilt prompt vectors") + synced, errors := vectorSync.BatchSyncPrompts(s.ctx, prompts, cfg) + totalSynced += synced + syncErrors += errors + log.Info().Int("synced", synced).Int("errors", errors).Int("total", len(prompts)).Msg("Rebuilt prompt vectors") } elapsed := time.Since(start) + + // Mark rebuild as complete + s.rebuildStatusMu.Lock() + s.rebuildStatus.InProgress = false + s.rebuildStatus.Phase = "complete" + s.rebuildStatus.TotalSynced = totalSynced + s.rebuildStatus.TotalErrors = syncErrors + s.rebuildStatus.ElapsedMs = elapsed.Milliseconds() + s.rebuildStatus.EstimatedPct = 100 + s.rebuildStatusMu.Unlock() + log.Info(). Int("total_synced", totalSynced). Int("errors", syncErrors). @@ -918,11 +1044,20 @@ func (s *Service) rebuildStaleVectors( return } - var totalSynced int - var syncErrors int + var totalSynced atomic.Int64 + var syncErrors atomic.Int64 + var rebuildWg sync.WaitGroup + + // Rebuild all three document types in parallel + rebuildWg.Add(3) + + // Rebuild stale observations in parallel + go func() { + defer rebuildWg.Done() + if len(staleObsIDs) == 0 { + return + } - // Rebuild stale observations - if len(staleObsIDs) > 0 { ids := make([]int64, 0, len(staleObsIDs)) for id := range staleObsIDs { ids = append(ids, id) @@ -931,21 +1066,27 @@ func (s *Service) rebuildStaleVectors( observations, err := observationStore.GetObservationsByIDs(s.ctx, ids, "date_desc", 0) if err != nil { log.Error().Err(err).Msg("Failed to fetch observations for rebuild") - } else { - for _, obs := range observations { - if err := vectorSync.SyncObservation(s.ctx, obs); err != nil { - log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild") - syncErrors++ - } else { - totalSynced++ - } - } - log.Info().Int("count", len(observations)).Msg("Rebuilt stale observation vectors") + return + } + + for _, obs := range observations { + if err := vectorSync.SyncObservation(s.ctx, obs); err != nil { + log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild") + syncErrors.Add(1) + } else { + totalSynced.Add(1) + } + } + log.Info().Int("count", len(observations)).Msg("Rebuilt stale observation vectors") + }() + + // Rebuild stale summaries in parallel + go func() { + defer rebuildWg.Done() + if len(staleSummaryIDs) == 0 { + return } - } - // Rebuild stale summaries - if len(staleSummaryIDs) > 0 { ids := make([]int64, 0, len(staleSummaryIDs)) for id := range staleSummaryIDs { ids = append(ids, id) @@ -954,21 +1095,27 @@ func (s *Service) rebuildStaleVectors( summaries, err := summaryStore.GetSummariesByIDs(s.ctx, ids, "date_desc", 0) if err != nil { log.Error().Err(err).Msg("Failed to fetch summaries for rebuild") - } else { - for _, summary := range summaries { - if err := vectorSync.SyncSummary(s.ctx, summary); err != nil { - log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild") - syncErrors++ - } else { - totalSynced++ - } - } - log.Info().Int("count", len(summaries)).Msg("Rebuilt stale summary vectors") + return + } + + for _, summary := range summaries { + if err := vectorSync.SyncSummary(s.ctx, summary); err != nil { + log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild") + syncErrors.Add(1) + } else { + totalSynced.Add(1) + } + } + log.Info().Int("count", len(summaries)).Msg("Rebuilt stale summary vectors") + }() + + // Rebuild stale prompts in parallel + go func() { + defer rebuildWg.Done() + if len(stalePromptIDs) == 0 { + return } - } - // Rebuild stale prompts - if len(stalePromptIDs) > 0 { ids := make([]int64, 0, len(stalePromptIDs)) for id := range stalePromptIDs { ids = append(ids, id) @@ -977,23 +1124,27 @@ func (s *Service) rebuildStaleVectors( prompts, err := promptStore.GetPromptsByIDs(s.ctx, ids, "date_desc", 0) if err != nil { log.Error().Err(err).Msg("Failed to fetch prompts for rebuild") - } else { - for _, prompt := range prompts { - if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil { - log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild") - syncErrors++ - } else { - totalSynced++ - } - } - log.Info().Int("count", len(prompts)).Msg("Rebuilt stale prompt vectors") + return } - } + + for _, prompt := range prompts { + if err := vectorSync.SyncUserPrompt(s.ctx, prompt); err != nil { + log.Warn().Err(err).Int64("id", prompt.ID).Msg("Failed to sync prompt during rebuild") + syncErrors.Add(1) + } else { + totalSynced.Add(1) + } + } + log.Info().Int("count", len(prompts)).Msg("Rebuilt stale prompt vectors") + }() + + // Wait for all three phases to complete + rebuildWg.Wait() elapsed := time.Since(start) log.Info(). - Int("total_synced", totalSynced). - Int("errors", syncErrors). + Int64("total_synced", totalSynced.Load()). + Int64("errors", syncErrors.Load()). Dur("elapsed", elapsed). Msg("Granular vector rebuild complete") } @@ -1039,9 +1190,30 @@ func (s *Service) verifyStaleObservation(req staleVerifyRequest) { // setupMiddleware configures HTTP middleware. func (s *Service) setupMiddleware() { + // Add request ID first so all subsequent logs can include it + s.router.Use(RequestID) + s.router.Use(middleware.Logger) s.router.Use(middleware.Recoverer) s.router.Use(middleware.RealIP) + + // Add security headers (X-Frame-Options, X-Content-Type-Options, CSP, etc.) + s.router.Use(SecurityHeaders) + + // Add request body size limit (10MB) to prevent DoS via large payloads + s.router.Use(MaxBodySize(10 * 1024 * 1024)) + + // Require JSON Content-Type for POST/PUT/PATCH requests + s.router.Use(RequireJSONContentType) + + // Add gzip compression for responses >1KB (reduces bandwidth ~70% for JSON) + s.router.Use(middleware.Compress(5)) // Level 5 = good balance of speed vs compression + + // Apply per-client rate limiting (after RealIP so we get the real client IP) + if s.rateLimiter != nil { + s.router.Use(PerClientRateLimitMiddleware(s.rateLimiter)) + } + // Note: Timeout middleware is applied per-route, not globally, // to avoid killing SSE connections which need to stay open indefinitely } @@ -1061,6 +1233,13 @@ func (s *Service) setupRoutes() { // Version endpoint for hooks to check if worker needs restart s.router.Get("/api/version", s.handleVersion) + // Rebuild status endpoint for visibility into vector rebuild progress + s.router.Get("/api/rebuild-status", s.handleRebuildStatus) + + // Vector management endpoints + s.router.Post("/api/vectors/rebuild", s.handleTriggerVectorRebuild) + s.router.Get("/api/vectors/health", s.handleVectorHealth) + // Readiness check - returns 200 only when fully initialized s.router.Get("/api/ready", s.handleReady) @@ -1094,6 +1273,8 @@ func (s *Service) setupRoutes() { // Data routes r.Get("/api/observations", s.handleGetObservations) + r.Get("/api/observations/{id}", s.handleGetObservationByID) + r.Put("/api/observations/{id}", s.handleUpdateObservation) r.Get("/api/summaries", s.handleGetSummaries) r.Get("/api/prompts", s.handleGetPrompts) r.Get("/api/projects", s.handleGetProjects) @@ -1118,6 +1299,7 @@ func (s *Service) setupRoutes() { r.Get("/api/context/count", s.handleContextCount) r.Get("/api/context/inject", s.handleContextInject) r.Get("/api/context/search", s.handleSearchByPrompt) + r.Get("/api/context/files", s.handleFileContext) // Pattern routes r.Get("/api/patterns", s.handleGetPatterns) @@ -1136,14 +1318,43 @@ func (s *Service) setupRoutes() { r.Get("/api/observations/{id}/relations", s.handleGetRelations) r.Get("/api/observations/{id}/graph", s.handleGetRelationGraph) r.Get("/api/observations/{id}/related", s.handleGetRelatedObservations) + + // Bulk import, export, and archival routes + r.Post("/api/observations/bulk-import", s.handleBulkImport) + r.Get("/api/observations/export", s.handleExportObservations) + r.Post("/api/observations/archive", s.handleArchiveObservations) + r.Post("/api/observations/{id}/unarchive", s.handleUnarchiveObservation) + r.Get("/api/observations/archived", s.handleGetArchivedObservations) + r.Get("/api/observations/archival-stats", s.handleGetArchivalStats) + + // Search analytics + r.Get("/api/search/recent", s.handleGetRecentQueries) + r.Get("/api/search/analytics", s.handleGetSearchAnalytics) + + // Duplicate detection + r.Get("/api/observations/duplicates", s.handleFindDuplicates) + + // Bulk status operations + r.Post("/api/observations/bulk-status", s.handleBulkStatusUpdate) }) } // recordRetrievalStats atomically updates retrieval statistics for a project. func (s *Service) recordRetrievalStats(project string, served, verified, deleted int64, isSearch bool) { + s.recordRetrievalStatsExtended(project, served, verified, deleted, 0, 0, 0, isSearch) +} + +// recordRetrievalStatsExtended records retrieval stats including staleness metrics. +func (s *Service) recordRetrievalStatsExtended(project string, served, verified, deleted, staleExcluded, freshCount, duplicatesRemoved int64, isSearch bool) { + now := time.Now().Unix() + s.retrievalStatsMu.Lock() stats := s.retrievalStats[project] if stats == nil { + // Cleanup old entries if we're at capacity + if len(s.retrievalStats) >= maxRetrievalStatsProjects { + s.cleanupRetrievalStatsLocked() + } stats = &RetrievalStats{} s.retrievalStats[project] = stats } @@ -1153,6 +1364,10 @@ func (s *Service) recordRetrievalStats(project string, served, verified, deleted atomic.AddInt64(&stats.ObservationsServed, served) atomic.AddInt64(&stats.VerifiedStale, verified) atomic.AddInt64(&stats.DeletedInvalid, deleted) + atomic.AddInt64(&stats.StaleExcluded, staleExcluded) + atomic.AddInt64(&stats.FreshCount, freshCount) + atomic.AddInt64(&stats.DuplicatesRemoved, duplicatesRemoved) + atomic.StoreInt64(&stats.LastUpdated, now) if isSearch { atomic.AddInt64(&stats.SearchRequests, 1) } else { @@ -1160,6 +1375,17 @@ func (s *Service) recordRetrievalStats(project string, served, verified, deleted } } +// cleanupRetrievalStatsLocked removes stale entries from retrievalStats. +// Must be called with retrievalStatsMu held. +func (s *Service) cleanupRetrievalStatsLocked() { + cutoff := time.Now().Add(-retrievalStatsMaxAge).Unix() + for project, stats := range s.retrievalStats { + if atomic.LoadInt64(&stats.LastUpdated) < cutoff { + delete(s.retrievalStats, project) + } + } +} + // GetRetrievalStats returns a copy of the retrieval stats for a project. // If project is empty, returns aggregate stats across all projects. func (s *Service) GetRetrievalStats(project string) RetrievalStats { @@ -1179,6 +1405,9 @@ func (s *Service) GetRetrievalStats(project string) RetrievalStats { DeletedInvalid: atomic.LoadInt64(&stats.DeletedInvalid), SearchRequests: atomic.LoadInt64(&stats.SearchRequests), ContextInjections: atomic.LoadInt64(&stats.ContextInjections), + StaleExcluded: atomic.LoadInt64(&stats.StaleExcluded), + FreshCount: atomic.LoadInt64(&stats.FreshCount), + DuplicatesRemoved: atomic.LoadInt64(&stats.DuplicatesRemoved), } } @@ -1191,10 +1420,129 @@ func (s *Service) GetRetrievalStats(project string) RetrievalStats { result.DeletedInvalid += atomic.LoadInt64(&stats.DeletedInvalid) result.SearchRequests += atomic.LoadInt64(&stats.SearchRequests) result.ContextInjections += atomic.LoadInt64(&stats.ContextInjections) + result.StaleExcluded += atomic.LoadInt64(&stats.StaleExcluded) + result.FreshCount += atomic.LoadInt64(&stats.FreshCount) + result.DuplicatesRemoved += atomic.LoadInt64(&stats.DuplicatesRemoved) } return result } +// trackSearchQuery records a search query for analytics using a circular buffer. +// O(1) insertion - no memory allocation or copying on each insert. +func (s *Service) trackSearchQuery(query, project, queryType string, results int, usedVector bool) { + s.recentQueriesMu.Lock() + defer s.recentQueriesMu.Unlock() + + // Move head back (wrapping around) and insert at new head position + // This puts the newest item at the head + s.recentQueriesHead = (s.recentQueriesHead - 1 + maxRecentQueries) % maxRecentQueries + + s.recentQueriesBuf[s.recentQueriesHead] = RecentSearchQuery{ + Query: query, + Project: project, + Type: queryType, + Results: results, + UsedVector: usedVector, + Timestamp: time.Now(), + } + + // Increase length up to max + if s.recentQueriesLen < maxRecentQueries { + s.recentQueriesLen++ + } +} + +// getRecentSearchQueries returns recent search queries, optionally filtered by project. +// Returns newest first. +func (s *Service) getRecentSearchQueries(project string, limit int) []RecentSearchQuery { + if limit <= 0 { + limit = 20 + } + if limit > maxRecentQueries { + limit = maxRecentQueries + } + + s.recentQueriesMu.RLock() + defer s.recentQueriesMu.RUnlock() + + if s.recentQueriesLen == 0 { + return nil + } + + if project == "" { + // Return all queries up to limit (newest first from circular buffer) + count := s.recentQueriesLen + if count > limit { + count = limit + } + result := make([]RecentSearchQuery, count) + for i := 0; i < count; i++ { + idx := (s.recentQueriesHead + i) % maxRecentQueries + result[i] = s.recentQueriesBuf[idx] + } + return result + } + + // Filter by project (iterate from newest to oldest) + result := make([]RecentSearchQuery, 0, limit) + for i := 0; i < s.recentQueriesLen; i++ { + idx := (s.recentQueriesHead + i) % maxRecentQueries + q := s.recentQueriesBuf[idx] + if q.Project == project { + result = append(result, q) + if len(result) >= limit { + break + } + } + } + return result +} + +// getCachedObservationCount returns observation count for a project, using cache if available. +// Falls back to database query if cache is expired or missing. +func (s *Service) getCachedObservationCount(ctx context.Context, project string) (int, error) { + // Check cache first + s.cachedObsCountsMu.RLock() + if cached, ok := s.cachedObsCounts[project]; ok { + if time.Since(cached.timestamp) < s.statsCacheTTL { + s.cachedObsCountsMu.RUnlock() + return cached.count, nil + } + } + s.cachedObsCountsMu.RUnlock() + + // Cache miss or expired - query database + count, err := s.observationStore.GetObservationCount(ctx, project) + if err != nil { + return 0, err + } + + // Update cache + s.cachedObsCountsMu.Lock() + s.cachedObsCounts[project] = cachedCount{ + count: count, + timestamp: time.Now(), + } + s.cachedObsCountsMu.Unlock() + + return count, nil +} + +// invalidateObsCountCache invalidates the observation count cache for a project. +// Call this when observations are added, archived, or deleted. +func (s *Service) invalidateObsCountCache(project string) { + s.cachedObsCountsMu.Lock() + delete(s.cachedObsCounts, project) + s.cachedObsCountsMu.Unlock() +} + +// invalidateAllObsCountCache clears all observation count caches. +func (s *Service) invalidateAllObsCountCache() { + s.cachedObsCountsMu.Lock() + s.cachedObsCounts = make(map[string]cachedCount) + s.cachedObsCountsMu.Unlock() +} + // Start starts the worker service. // The HTTP server starts immediately; database initialization happens async. func (s *Service) Start() error { @@ -1348,9 +1696,34 @@ func (s *Service) processAllSessions() { // Shutdown gracefully shuts down the service. func (s *Service) Shutdown(ctx context.Context) error { + log.Info().Msg("Starting graceful shutdown...") + start := time.Now() + + // Cancel context to signal all background goroutines s.cancel() - // Stop file watchers + // Create error collector + var shutdownErrors []error + var mu sync.Mutex + collectError := func(name string, err error) { + if err != nil { + mu.Lock() + shutdownErrors = append(shutdownErrors, fmt.Errorf("%s: %w", name, err)) + mu.Unlock() + log.Error().Err(err).Str("component", name).Msg("Shutdown error") + } + } + + // Phase 1: Stop accepting new work (HTTP server shutdown first) + log.Debug().Msg("Phase 1: Stopping HTTP server...") + if s.server != nil { + if err := s.server.Shutdown(ctx); err != nil { + collectError("http_server", err) + } + } + + // Phase 2: Stop file watchers (prevent new DB recreation) + log.Debug().Msg("Phase 2: Stopping watchers...") if s.dbWatcher != nil { _ = s.dbWatcher.Stop() } @@ -1358,55 +1731,69 @@ func (s *Service) Shutdown(ctx context.Context) error { _ = s.configWatcher.Stop() } - // Stop background recalculator + // Phase 3: Stop background workers (drain queues) + log.Debug().Msg("Phase 3: Stopping background workers...") if s.recalculator != nil { s.recalculator.Stop() } - - // Stop pattern detector if s.patternDetector != nil { s.patternDetector.Stop() } - // Shutdown all sessions - s.sessionManager.ShutdownAll(ctx) - - // Shutdown HTTP server - if s.server != nil { - if err := s.server.Shutdown(ctx); err != nil { - log.Error().Err(err).Msg("HTTP server shutdown error") - } + // Phase 4: Shutdown sessions (flush pending work) + log.Debug().Msg("Phase 4: Shutting down sessions...") + if s.sessionManager != nil { + s.sessionManager.ShutdownAll(ctx) } - // Close reranking service + // Phase 5: Wait for goroutines with timeout + log.Debug().Msg("Phase 5: Waiting for goroutines...") + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Debug().Msg("All goroutines finished") + case <-ctx.Done(): + log.Warn().Msg("Timeout waiting for goroutines - forcing shutdown") + } + + // Phase 6: Close AI/ML services (close models) + log.Debug().Msg("Phase 6: Closing AI/ML services...") if s.reranker != nil { - if err := s.reranker.Close(); err != nil { - log.Error().Err(err).Msg("Reranking service close error") - } + collectError("reranker", s.reranker.Close()) } - - // Close embedding service if s.embedSvc != nil { - if err := s.embedSvc.Close(); err != nil { - log.Error().Err(err).Msg("Embedding service close error") - } + collectError("embedding_service", s.embedSvc.Close()) } - // Close vector client + // Phase 7: Close vector client (no external process) + log.Debug().Msg("Phase 7: Closing vector client...") if s.vectorClient != nil { - if err := s.vectorClient.Close(); err != nil { - log.Error().Err(err).Msg("Vector client close error") - } + collectError("vector_client", s.vectorClient.Close()) } - // Close database - if err := s.store.Close(); err != nil { - log.Error().Err(err).Msg("Database close error") + // Phase 8: Close database last (other components may need it) + log.Debug().Msg("Phase 8: Closing database...") + if s.store != nil { + collectError("database", s.store.Close()) } - s.wg.Wait() + elapsed := time.Since(start) + if len(shutdownErrors) > 0 { + log.Warn(). + Int("errors", len(shutdownErrors)). + Dur("elapsed", elapsed). + Msg("Worker shutdown completed with errors") + return shutdownErrors[0] + } - log.Info().Msg("Worker service shutdown complete") + log.Info(). + Dur("elapsed", elapsed). + Msg("Worker service shutdown complete") return nil } @@ -1415,7 +1802,7 @@ func (s *Service) broadcastProcessingStatus() { isProcessing := s.sessionManager.IsAnySessionProcessing() queueDepth := s.sessionManager.GetTotalQueueDepth() - s.sseBroadcaster.Broadcast(map[string]interface{}{ + s.sseBroadcaster.Broadcast(map[string]any{ "type": "processing_status", "isProcessing": isProcessing, "queueDepth": queueDepth, diff --git a/internal/worker/static.go b/internal/worker/static.go index 1598b74..a70da3a 100644 --- a/internal/worker/static.go +++ b/internal/worker/static.go @@ -5,6 +5,8 @@ import ( "io/fs" "net/http" "strings" + + "github.com/rs/zerolog/log" ) //go:embed static/* @@ -13,16 +15,22 @@ var staticFS embed.FS // staticSubFS is the static subdirectory filesystem var staticSubFS fs.FS +// staticInitErr stores any error from static filesystem initialization +var staticInitErr error + func init() { - var err error - staticSubFS, err = fs.Sub(staticFS, "static") - if err != nil { - panic("failed to create sub filesystem: " + err.Error()) + staticSubFS, staticInitErr = fs.Sub(staticFS, "static") + if staticInitErr != nil { + log.Warn().Err(staticInitErr).Msg("Static filesystem initialization failed - dashboard will be unavailable") } } // serveIndex serves the index.html file for the root path func serveIndex(w http.ResponseWriter, r *http.Request) { + if staticInitErr != nil { + http.Error(w, "Dashboard unavailable: static files not initialized", http.StatusServiceUnavailable) + return + } content, err := fs.ReadFile(staticSubFS, "index.html") if err != nil { http.Error(w, "Dashboard not found", http.StatusNotFound) @@ -38,6 +46,10 @@ func serveIndex(w http.ResponseWriter, r *http.Request) { // serveAssets serves static assets from the embedded filesystem func serveAssets(w http.ResponseWriter, r *http.Request) { + if staticInitErr != nil { + http.Error(w, "Assets unavailable: static files not initialized", http.StatusServiceUnavailable) + return + } // Strip the /assets/ prefix and serve the file path := strings.TrimPrefix(r.URL.Path, "/") diff --git a/pkg/similarity/clustering.go b/pkg/similarity/clustering.go index cd1e6cf..3970b69 100644 --- a/pkg/similarity/clustering.go +++ b/pkg/similarity/clustering.go @@ -2,6 +2,7 @@ package similarity import ( + "math/bits" "strings" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" @@ -15,6 +16,17 @@ func ClusterObservations(observations []*models.Observation, similarityThreshold return observations } + // For small sets, use the simple O(n²) algorithm + if len(observations) <= 50 { + return clusterObservationsSimple(observations, similarityThreshold) + } + + // For larger sets, use an optimized approach with early termination + return clusterObservationsOptimized(observations, similarityThreshold) +} + +// clusterObservationsSimple is the simple O(n²) algorithm for small sets. +func clusterObservationsSimple(observations []*models.Observation, similarityThreshold float64) []*models.Observation { // Extract terms for each observation termSets := make([]map[string]bool, len(observations)) for i, obs := range observations { @@ -51,6 +63,93 @@ func ClusterObservations(observations []*models.Observation, similarityThreshold return result } +// clusterObservationsOptimized uses MinHash-based approximation for large sets. +// This reduces complexity from O(n²) to approximately O(n*k) where k is the number of hash functions. +func clusterObservationsOptimized(observations []*models.Observation, similarityThreshold float64) []*models.Observation { + n := len(observations) + + // Extract terms for each observation and compute a signature + type termSetWithSig struct { + terms map[string]bool + signature uint64 // Simple hash signature for fast comparison + } + + termSets := make([]termSetWithSig, n) + for i, obs := range observations { + terms := ExtractObservationTerms(obs) + termSets[i] = termSetWithSig{ + terms: terms, + signature: computeTermSignature(terms), + } + } + + // Track which observations are already clustered + clustered := make([]bool, n) + result := make([]*models.Observation, 0, n/2) // Pre-allocate assuming ~50% are unique + + for i := 0; i < n; i++ { + if clustered[i] { + continue + } + + // This observation becomes the representative of its cluster + result = append(result, observations[i]) + clustered[i] = true + + // Use signature for fast pre-filtering + sigI := termSets[i].signature + termsI := termSets[i].terms + + // Find all similar observations and mark them as clustered + for j := i + 1; j < n; j++ { + if clustered[j] { + continue + } + + // Quick signature comparison - if signatures are very different, skip detailed comparison + sigJ := termSets[j].signature + sigDiff := sigI ^ sigJ + popCount := popCount64(sigDiff) + + // If signatures differ significantly, similarity is likely low + // Skip detailed comparison for very different signatures + if popCount > 32 { // More than half of bits differ + continue + } + + // Full Jaccard comparison for candidates + similarity := JaccardSimilarity(termsI, termSets[j].terms) + if similarity >= similarityThreshold { + clustered[j] = true + } + } + } + + return result +} + +// computeTermSignature creates a quick hash signature for term sets. +// Used for fast pre-filtering in the optimized clustering algorithm. +func computeTermSignature(terms map[string]bool) uint64 { + var sig uint64 + for term := range terms { + // Simple hash using FNV-1a inspired approach + h := uint64(14695981039346656037) + for i := 0; i < len(term); i++ { + h ^= uint64(term[i]) + h *= 1099511628211 + } + sig ^= h + } + return sig +} + +// popCount64 counts the number of set bits in a 64-bit integer. +// Uses the stdlib bits.OnesCount64 which may use CPU POPCNT instruction. +func popCount64(x uint64) int { + return bits.OnesCount64(x) +} + // IsSimilarToAny checks if a new observation is similar to any existing observation. // Returns true if similarity to any existing observation exceeds the threshold. func IsSimilarToAny(newObs *models.Observation, existing []*models.Observation, similarityThreshold float64) bool { diff --git a/ui/package-lock.json b/ui/package-lock.json index 6a20f38..9207c57 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "claude-mnemonic-dashboard", - "version": "8fe9ea5-dirty", + "version": "v0.10.5-1-g7ab4b07-dirty", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "claude-mnemonic-dashboard", - "version": "8fe9ea5-dirty", + "version": "v0.10.5-1-g7ab4b07-dirty", "dependencies": { "vis-data": "^7.1.9", "vis-network": "^9.1.9", diff --git a/ui/package.json b/ui/package.json index 09687eb..21b0206 100644 --- a/ui/package.json +++ b/ui/package.json @@ -1,6 +1,6 @@ { "name": "claude-mnemonic-dashboard", - "version": "8fe9ea5-dirty", + "version": "v0.10.5-1-g7ab4b07-dirty", "private": true, "type": "module", "scripts": { diff --git a/ui/src/components/ObservationCard.vue b/ui/src/components/ObservationCard.vue index 9e36d04..8e9f9c9 100644 --- a/ui/src/components/ObservationCard.vue +++ b/ui/src/components/ObservationCard.vue @@ -8,6 +8,7 @@ import Card from './Card.vue' import IconBox from './IconBox.vue' import Badge from './Badge.vue' import RelationGraph from './RelationGraph.vue' +import ScoreBreakdown from './ScoreBreakdown.vue' import { computed, ref, onMounted } from 'vue' const props = defineProps<{ @@ -95,6 +96,9 @@ const relationsLoading = ref(false) const relationsExpanded = ref(false) const showGraph = ref(false) +// Score breakdown state +const showScoreBreakdown = ref(false) + const hasRelations = computed(() => relations.value.length > 0) const relationCount = computed(() => relations.value.length) @@ -350,14 +354,15 @@ const splitPath = (path: string, components = 3) => { - - + {{ currentScore.toFixed(2) }} - +
@@ -124,9 +145,18 @@ function getStatusColor(status: string): string {
-
- -

Memory Contents

+
+
+ +

Memory Contents

+
+
@@ -161,9 +191,18 @@ function getStatusColor(status: string): string {
-
- -

Retrieval Stats

+
+
+ +

Retrieval Stats

+
+
@@ -261,5 +300,25 @@ function getStatusColor(status: string): string {
+ + + + + + + + + diff --git a/ui/src/components/SystemHealthDetails.vue b/ui/src/components/SystemHealthDetails.vue new file mode 100644 index 0000000..74a48df --- /dev/null +++ b/ui/src/components/SystemHealthDetails.vue @@ -0,0 +1,249 @@ + + + diff --git a/ui/src/components/TopObservations.vue b/ui/src/components/TopObservations.vue new file mode 100644 index 0000000..5aad2d2 --- /dev/null +++ b/ui/src/components/TopObservations.vue @@ -0,0 +1,248 @@ + + + diff --git a/ui/src/utils/api.ts b/ui/src/utils/api.ts index a709bfc..009d79e 100644 --- a/ui/src/utils/api.ts +++ b/ui/src/utils/api.ts @@ -164,3 +164,119 @@ export async function fetchRelatedObservations(observationId: number, minConfide export async function fetchRelationStats(signal?: AbortSignal): Promise { return fetchWithRetry(`${API_BASE}/relations/stats`, { signal }) } + +// Scoring API functions +export interface ScoreBreakdown { + observation: { + id: number + title: string + type: string + project: string + created_at: number + } + scoring: { + final_score: number + type_weight: number + recency_decay: number + core_score: number + feedback_contrib: number + concept_contrib: number + retrieval_contrib: number + age_days: number + } + explanation: { + type_impact: string + recency_impact: string + feedback_impact: string + concept_impact: string + retrieval_impact: string + } +} + +export async function fetchObservationScore(observationId: number, signal?: AbortSignal): Promise { + return fetchWithRetry(`${API_BASE}/observations/${observationId}/score`, { signal }) +} + +export interface FeedbackStats { + total: number + positive: number + negative: number + neutral: number + avg_score: number + avg_retrieval: number +} + +export interface TopObservation { + id: number + title: string + type: string + importance_score: number + retrieval_count?: number +} + +export async function fetchScoringStats(project?: string, signal?: AbortSignal): Promise { + const params = new URLSearchParams() + if (project) params.append('project', project) + const query = params.toString() + return fetchWithRetry(`${API_BASE}/scoring/stats${query ? '?' + query : ''}`, { signal }) +} + +export async function fetchTopObservations(project?: string, limit: number = 10, signal?: AbortSignal): Promise { + const params = new URLSearchParams({ limit: String(limit) }) + if (project) params.append('project', project) + return fetchWithRetry(`${API_BASE}/observations/top?${params}`, { signal }) +} + +export async function fetchMostRetrievedObservations(project?: string, limit: number = 10, signal?: AbortSignal): Promise { + const params = new URLSearchParams({ limit: String(limit) }) + if (project) params.append('project', project) + return fetchWithRetry(`${API_BASE}/observations/most-retrieved?${params}`, { signal }) +} + +// Search Analytics API functions +export interface RecentQuery { + query: string + project?: string + type?: string + count: number + last_used: string +} + +export interface SearchAnalytics { + total_searches: number + vector_searches: number + filter_searches: number + cache_hits: number + coalesced_requests: number + search_errors: number + avg_latency_ms: number + avg_vector_latency_ms: number + avg_filter_latency_ms: number +} + +export async function fetchSearchAnalytics(signal?: AbortSignal): Promise { + return fetchWithRetry(`${API_BASE}/search/analytics`, { signal }) +} + +export async function fetchRecentSearches(limit: number = 20, signal?: AbortSignal): Promise { + return fetchWithRetry(`${API_BASE}/search/recent?limit=${limit}`, { signal }) +} + +// System health API +export interface ComponentHealth { + name: string + status: 'healthy' | 'degraded' | 'unhealthy' + message?: string + latency_ms?: number +} + +export interface SystemHealth { + status: 'healthy' | 'degraded' | 'unhealthy' + version: string + components: ComponentHealth[] + warnings?: string[] +} + +export async function fetchSystemHealth(signal?: AbortSignal): Promise { + return fetchWithRetry(`${API_BASE}/selfcheck`, { signal }) +} diff --git a/ui/tsconfig.tsbuildinfo b/ui/tsconfig.tsbuildinfo index 90b5efb..4a422c9 100644 --- a/ui/tsconfig.tsbuildinfo +++ b/ui/tsconfig.tsbuildinfo @@ -1 +1 @@ -{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/relation.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/relationgraph.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"} \ No newline at end of file +{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/relation.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/relationgraph.vue","./src/components/scorebreakdown.vue","./src/components/searchanalytics.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/systemhealthdetails.vue","./src/components/timeline.vue","./src/components/topobservations.vue"],"version":"5.7.3"} \ No newline at end of file