mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
Make things 'betterer' across the board
This commit is contained in:
@@ -122,6 +122,7 @@ func main() {
|
||||
startWatchers(ctx, dbPath)
|
||||
|
||||
// Create and run MCP server with all dependencies
|
||||
// Note: maintenanceService is nil because it runs in the worker process
|
||||
server := mcp.NewServer(
|
||||
searchMgr,
|
||||
Version,
|
||||
@@ -132,6 +133,7 @@ func main() {
|
||||
vectorClient,
|
||||
scoreCalculator,
|
||||
recalculator,
|
||||
nil, // maintenanceService - handled by worker
|
||||
)
|
||||
log.Info().Str("project", *project).Str("version", Version).Msg("Starting MCP server")
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ require (
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/sugarme/tokenizer v0.3.0
|
||||
github.com/yalue/onnxruntime_go v1.25.0
|
||||
golang.org/x/sync v0.19.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
)
|
||||
|
||||
@@ -52,6 +52,8 @@ github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c h1:pwb4kNSHb4K89
|
||||
github.com/sugarme/regexpset v0.0.0-20200920021344-4d4ec8eaf93c/go.mod h1:2gwkXLWbDGUQWeL3RtpCmcY4mzCtU13kb9UsAg9xMaw=
|
||||
github.com/yalue/onnxruntime_go v1.25.0 h1:nlhVau1BpLZ/BYr+WpPZCJRD/WES0qo6dK7aKyyAs3g=
|
||||
github.com/yalue/onnxruntime_go v1.25.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -71,6 +71,12 @@ type Config struct {
|
||||
ContextObsConcepts []string `json:"context_obs_concepts"`
|
||||
ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` // 0.0-1.0, minimum similarity for inclusion
|
||||
ContextMaxPromptResults int `json:"context_max_prompt_results"` // Max results per prompt (0 = threshold only)
|
||||
|
||||
// Maintenance settings
|
||||
MaintenanceEnabled bool `json:"maintenance_enabled"` // Enable scheduled maintenance
|
||||
MaintenanceIntervalHours int `json:"maintenance_interval_hours"` // How often to run maintenance (default 6 hours)
|
||||
ObservationRetentionDays int `json:"observation_retention_days"` // Delete observations older than N days (0 = no age-based deletion)
|
||||
CleanupStaleObservations bool `json:"cleanup_stale_observations"` // Auto-cleanup stale observations during maintenance
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -96,8 +102,9 @@ func SettingsPath() string {
|
||||
}
|
||||
|
||||
// EnsureDataDir creates the data directory if it doesn't exist.
|
||||
// Uses 0700 permissions (owner-only) for security.
|
||||
func EnsureDataDir() error {
|
||||
return os.MkdirAll(DataDir(), 0750)
|
||||
return os.MkdirAll(DataDir(), 0700)
|
||||
}
|
||||
|
||||
// EnsureSettings creates a default settings file if it doesn't exist.
|
||||
@@ -157,8 +164,12 @@ func Default() *Config {
|
||||
ContextShowLastSummary: true,
|
||||
ContextObsTypes: DefaultObservationTypes,
|
||||
ContextObsConcepts: DefaultObservationConcepts,
|
||||
ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include
|
||||
ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only)
|
||||
ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include
|
||||
ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only)
|
||||
MaintenanceEnabled: true, // Enable scheduled maintenance
|
||||
MaintenanceIntervalHours: 6, // Run every 6 hours
|
||||
ObservationRetentionDays: 0, // 0 = no age-based deletion (keep all)
|
||||
CleanupStaleObservations: false, // Don't auto-cleanup stale observations
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+54
-17
@@ -9,27 +9,13 @@ import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// EnsureSessionExists creates a session if it doesn't exist.
|
||||
// Uses INSERT OR IGNORE pattern for atomic idempotent creation (single query instead of COUNT + INSERT).
|
||||
// This is shared between stores to avoid duplication.
|
||||
func EnsureSessionExists(ctx context.Context, db *gorm.DB, sdkSessionID, project string) error {
|
||||
// Check if session exists
|
||||
var count int64
|
||||
err := db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("sdk_session_id = ?", sdkSessionID).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return nil // Session exists
|
||||
}
|
||||
|
||||
// Auto-create session
|
||||
now := time.Now()
|
||||
session := &SDKSession{
|
||||
ClaudeSessionID: sdkSessionID,
|
||||
@@ -41,7 +27,14 @@ func EnsureSessionExists(ctx context.Context, db *gorm.DB, sdkSessionID, project
|
||||
PromptCounter: 0,
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).Create(session).Error
|
||||
// Use INSERT OR IGNORE - single query, atomic operation
|
||||
// If session already exists (conflict on sdk_session_id), do nothing
|
||||
return db.WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "sdk_session_id"}},
|
||||
DoNothing: true,
|
||||
}).
|
||||
Create(session).Error
|
||||
}
|
||||
|
||||
// sqlNullString creates a sql.NullString from a string.
|
||||
@@ -52,8 +45,13 @@ func sqlNullString(s string) sql.NullString {
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
|
||||
// MaxPaginationLimit is the maximum allowed limit for pagination queries.
|
||||
// This protects against resource exhaustion from excessively large requests.
|
||||
const MaxPaginationLimit = 1000
|
||||
|
||||
// ParseLimitParam parses the "limit" query parameter from an HTTP request.
|
||||
// Returns defaultLimit if the parameter is missing or invalid.
|
||||
// Note: This does NOT enforce a maximum limit. Use ParseLimitParamWithMax for that.
|
||||
func ParseLimitParam(r *http.Request, defaultLimit int) int {
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
@@ -62,3 +60,42 @@ func ParseLimitParam(r *http.Request, defaultLimit int) int {
|
||||
}
|
||||
return defaultLimit
|
||||
}
|
||||
|
||||
// ParseLimitParamWithMax parses the "limit" query parameter with a maximum cap.
|
||||
// Returns min(parsed, maxLimit) or defaultLimit if missing/invalid.
|
||||
// If maxLimit is 0, uses MaxPaginationLimit (1000).
|
||||
func ParseLimitParamWithMax(r *http.Request, defaultLimit, maxLimit int) int {
|
||||
if maxLimit <= 0 {
|
||||
maxLimit = MaxPaginationLimit
|
||||
}
|
||||
limit := ParseLimitParam(r, defaultLimit)
|
||||
if limit > maxLimit {
|
||||
return maxLimit
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
// ParseOffsetParam reads the "offset" query parameter from an HTTP request.
//
// It returns the parsed value when the parameter is a non-negative integer,
// and 0 when the parameter is absent, empty, negative, or not a number.
func ParseOffsetParam(r *http.Request) int {
	raw := r.URL.Query().Get("offset")
	if raw == "" {
		return 0
	}

	offset, err := strconv.Atoi(raw)
	if err != nil || offset < 0 {
		return 0
	}
	return offset
}
|
||||
|
||||
// PaginationParams holds pagination parameters.
|
||||
type PaginationParams struct {
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// ParsePaginationParams parses both limit and offset from an HTTP request.
|
||||
func ParsePaginationParams(r *http.Request, defaultLimit int) PaginationParams {
|
||||
return PaginationParams{
|
||||
Limit: ParseLimitParam(r, defaultLimit),
|
||||
Offset: ParseOffsetParam(r),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,6 +322,244 @@ func runMigrations(db *gorm.DB, sqlDB *sql.DB) error {
|
||||
return tx.Migrator().DropTable("observation_relations")
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 012: Query optimization indexes
|
||||
// Adds covering and composite indexes for common query patterns
|
||||
{
|
||||
ID: "012_query_optimization_indexes",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Composite index for observation queries by project + scope + importance
|
||||
// Covers the common pattern: WHERE project = ? OR scope = 'global' ORDER BY importance_score DESC
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_project_scope_importance
|
||||
ON observations(project, scope, importance_score DESC, created_at_epoch DESC)`,
|
||||
|
||||
// Covering index for observation retrieval (includes most used columns)
|
||||
// Allows index-only scans for listing queries
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_project_covering
|
||||
ON observations(project, scope, is_superseded, importance_score DESC)
|
||||
WHERE is_superseded = 0 OR is_superseded IS NULL`,
|
||||
|
||||
// Index for session summary lookups
|
||||
`CREATE INDEX IF NOT EXISTS idx_summaries_project_importance
|
||||
ON session_summaries(project, importance_score DESC, created_at_epoch DESC)`,
|
||||
|
||||
// Index for prompt retrieval by session
|
||||
`CREATE INDEX IF NOT EXISTS idx_prompts_session_number
|
||||
ON user_prompts(claude_session_id, prompt_number)`,
|
||||
|
||||
// Index for pattern queries by frequency
|
||||
`CREATE INDEX IF NOT EXISTS idx_patterns_frequency
|
||||
ON patterns(frequency DESC, last_seen_at_epoch DESC)
|
||||
WHERE is_deprecated = 0`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: index may already exist or fail for benign reasons
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_observations_project_scope_importance",
|
||||
"DROP INDEX IF EXISTS idx_observations_project_covering",
|
||||
"DROP INDEX IF EXISTS idx_summaries_project_importance",
|
||||
"DROP INDEX IF EXISTS idx_prompts_session_number",
|
||||
"DROP INDEX IF EXISTS idx_patterns_frequency",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 013: Add archival columns to observations
|
||||
{
|
||||
ID: "013_observation_archival",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Add archival columns
|
||||
`ALTER TABLE observations ADD COLUMN is_archived INTEGER DEFAULT 0`,
|
||||
`ALTER TABLE observations ADD COLUMN archived_at_epoch INTEGER`,
|
||||
`ALTER TABLE observations ADD COLUMN archived_reason TEXT`,
|
||||
// Index for archived observations
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_archived ON observations(is_archived)`,
|
||||
// Composite index for filtering active (non-archived) observations
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_active
|
||||
ON observations(project, is_archived, is_superseded, importance_score DESC)
|
||||
WHERE (is_archived = 0 OR is_archived IS NULL)`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: column may already exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
// SQLite doesn't support DROP COLUMN in older versions
|
||||
// but for newer versions we can try
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_observations_active",
|
||||
"DROP INDEX IF EXISTS idx_observations_archived",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
_ = tx.Exec(s).Error
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 014: Add performance-critical indexes for common query patterns
|
||||
{
|
||||
ID: "014_performance_indexes",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Index for batch ID lookups (IN queries)
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_id_covering
|
||||
ON observations(id, project, scope, importance_score)`,
|
||||
|
||||
// Index for vector search result fetching
|
||||
`CREATE INDEX IF NOT EXISTS idx_vectors_doc_type_project
|
||||
ON vectors(doc_type, project, scope)`,
|
||||
|
||||
// Index for session summaries by project
|
||||
`CREATE INDEX IF NOT EXISTS idx_summaries_project_created
|
||||
ON session_summaries(project, created_at_epoch DESC)`,
|
||||
|
||||
// Index for user prompts by session
|
||||
`CREATE INDEX IF NOT EXISTS idx_prompts_session_created
|
||||
ON user_prompts(claude_session_id, created_at_epoch DESC)`,
|
||||
|
||||
// Index for patterns by type and project
|
||||
`CREATE INDEX IF NOT EXISTS idx_patterns_type_project
|
||||
ON patterns(type, project, frequency DESC)
|
||||
WHERE is_deprecated = 0`,
|
||||
|
||||
// Index for observation relations
|
||||
`CREATE INDEX IF NOT EXISTS idx_relations_source_type
|
||||
ON observation_relations(source_observation_id, relation_type)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_relations_target_type
|
||||
ON observation_relations(target_observation_id, relation_type)`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: index may already exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_observations_id_covering",
|
||||
"DROP INDEX IF EXISTS idx_vectors_doc_type_project",
|
||||
"DROP INDEX IF EXISTS idx_summaries_project_created",
|
||||
"DROP INDEX IF EXISTS idx_prompts_session_created",
|
||||
"DROP INDEX IF EXISTS idx_patterns_type_project",
|
||||
"DROP INDEX IF EXISTS idx_relations_source_type",
|
||||
"DROP INDEX IF EXISTS idx_relations_target_type",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
_ = tx.Exec(s).Error
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 015: Add optimized composite indexes for common query patterns
|
||||
{
|
||||
ID: "015_optimized_composite_indexes",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Composite index for GetRecentObservations with project+scope filtering
|
||||
// Covers: WHERE (project = ? OR scope = 'global') ORDER BY importance_score DESC, created_at_epoch DESC
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_project_scope_created
|
||||
ON observations(project, scope, created_at_epoch DESC, importance_score DESC)`,
|
||||
|
||||
// Index for scope='global' queries (common pattern in search)
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_global_scope
|
||||
ON observations(scope, importance_score DESC, created_at_epoch DESC)
|
||||
WHERE scope = 'global'`,
|
||||
|
||||
// Index for vector search result deduplication by observation
|
||||
`CREATE INDEX IF NOT EXISTS idx_vectors_observation_lookup
|
||||
ON vectors(doc_type, sqlite_id, project)
|
||||
WHERE doc_type = 'observation'`,
|
||||
|
||||
// Index for FTS search result ordering
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_fts_ordering
|
||||
ON observations(project, importance_score DESC)
|
||||
WHERE (is_archived = 0 OR is_archived IS NULL) AND (is_superseded = 0 OR is_superseded IS NULL)`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: index may already exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_observations_project_scope_created",
|
||||
"DROP INDEX IF EXISTS idx_observations_global_scope",
|
||||
"DROP INDEX IF EXISTS idx_vectors_observation_lookup",
|
||||
"DROP INDEX IF EXISTS idx_observations_fts_ordering",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
_ = tx.Exec(s).Error
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Migration 016: Add covering indexes for relation joins and active observations
|
||||
{
|
||||
ID: "016_relation_and_active_indexes",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
// Covering index for observation relation joins (common JOIN patterns)
|
||||
// Speeds up queries like: JOIN observation_relations ON source_id = obs.id WHERE relation_type = ?
|
||||
`CREATE INDEX IF NOT EXISTS idx_relations_source_type_target
|
||||
ON observation_relations(source_observation_id, relation_type, target_observation_id)`,
|
||||
|
||||
// Covering index for reverse relation lookups
|
||||
`CREATE INDEX IF NOT EXISTS idx_relations_target_type_source
|
||||
ON observation_relations(target_observation_id, relation_type, source_observation_id)`,
|
||||
|
||||
// Partial index for active (non-archived, non-superseded) observations
|
||||
// Optimizes activeObservationFilter queries
|
||||
`CREATE INDEX IF NOT EXISTS idx_observations_active
|
||||
ON observations(project, importance_score DESC, created_at_epoch DESC)
|
||||
WHERE (is_archived = 0 OR is_archived IS NULL) AND (is_superseded = 0 OR is_superseded IS NULL)`,
|
||||
}
|
||||
for _, s := range sqls {
|
||||
if err := tx.Exec(s).Error; err != nil {
|
||||
// Non-fatal: index may already exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
sqls := []string{
|
||||
"DROP INDEX IF EXISTS idx_relations_source_type_target",
|
||||
"DROP INDEX IF EXISTS idx_relations_target_type_source",
|
||||
"DROP INDEX IF EXISTS idx_observations_active",
|
||||
}
|
||||
for _, s := range sqls {
|
||||
_ = tx.Exec(s).Error
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if err := m.Migrate(); err != nil {
|
||||
|
||||
@@ -48,7 +48,7 @@ func (s *SDKSession) BeforeCreate(tx *gorm.DB) error {
|
||||
type Observation struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
SDKSessionID string `gorm:"index;not null"`
|
||||
Project string `gorm:"index;not null"`
|
||||
Project string `gorm:"index:idx_observations_project;index:idx_observations_project_created,priority:1;not null"`
|
||||
Scope models.ObservationScope `gorm:"type:text;default:'project';check:scope IN ('project', 'global');index:idx_observations_scope;index:idx_observations_project_scope,priority:2"`
|
||||
Type models.ObservationType `gorm:"type:text;check:type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change');index;not null"`
|
||||
|
||||
@@ -66,7 +66,7 @@ type Observation struct {
|
||||
PromptNumber sql.NullInt64
|
||||
DiscoveryTokens int64 `gorm:"default:0"`
|
||||
CreatedAt string `gorm:"not null"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;not null"`
|
||||
CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;index:idx_observations_project_created,priority:2,sort:desc;not null"`
|
||||
|
||||
// Importance scoring fields
|
||||
ImportanceScore float64 `gorm:"type:real;default:1.0;index:idx_observations_importance,priority:1,sort:desc"`
|
||||
@@ -74,7 +74,12 @@ type Observation struct {
|
||||
RetrievalCount int `gorm:"default:0"`
|
||||
LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
|
||||
ScoreUpdatedAt sql.NullInt64 `gorm:"column:score_updated_at_epoch;index:idx_observations_score_updated"`
|
||||
IsSuperseded int `gorm:"default:0;index:idx_observations_superseded,priority:1"`
|
||||
IsSuperseded int `gorm:"default:0;index:idx_observations_superseded;index:idx_observations_active,priority:2"`
|
||||
|
||||
// Archival fields
|
||||
IsArchived int `gorm:"default:0;index:idx_observations_archived;index:idx_observations_active,priority:1"`
|
||||
ArchivedAt sql.NullInt64 `gorm:"column:archived_at_epoch"`
|
||||
ArchivedReason sql.NullString
|
||||
}
|
||||
|
||||
func (Observation) TableName() string { return "observations" }
|
||||
|
||||
@@ -5,17 +5,36 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MaxObservationsPerProject is the maximum number of observations to keep per project.
|
||||
const MaxObservationsPerProject = 100
|
||||
|
||||
// cleanupQueueSize is the buffer size for the cleanup queue.
|
||||
const cleanupQueueSize = 100
|
||||
|
||||
// commonWords is a package-level set for O(1) lookup of stop words.
|
||||
// Created once at init time to avoid repeated map allocations.
|
||||
var commonWords = map[string]struct{}{
|
||||
"the": {}, "and": {}, "or": {}, "but": {}, "in": {},
|
||||
"on": {}, "at": {}, "to": {}, "for": {}, "of": {},
|
||||
"with": {}, "by": {}, "from": {}, "as": {}, "is": {},
|
||||
"was": {}, "are": {}, "were": {}, "be": {}, "been": {},
|
||||
"being": {}, "have": {}, "has": {}, "had": {}, "do": {},
|
||||
"does": {}, "did": {}, "will": {}, "would": {}, "should": {},
|
||||
"could": {}, "may": {}, "might": {}, "must": {}, "can": {},
|
||||
}
|
||||
|
||||
// CleanupFunc is a callback for when observations are cleaned up.
|
||||
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
|
||||
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
|
||||
@@ -25,19 +44,89 @@ type ObservationStore struct {
|
||||
db *gorm.DB
|
||||
rawDB *sql.DB
|
||||
cleanupFunc CleanupFunc
|
||||
conflictStore interface{} // Placeholder for ConflictStore (Phase 4)
|
||||
relationStore interface{} // Placeholder for RelationStore (Phase 4)
|
||||
conflictStore any // Placeholder for ConflictStore (Phase 4)
|
||||
relationStore any // Placeholder for RelationStore (Phase 4)
|
||||
|
||||
// Cleanup queue for async observation cleanup with proper error handling
|
||||
cleanupQueue chan string
|
||||
cleanupWg sync.WaitGroup
|
||||
cleanupOnce sync.Once
|
||||
cleanupStarted atomic.Bool // tracks if cleanup worker was started
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
// NewObservationStore creates a new observation store.
|
||||
// The conflictStore and relationStore parameters are optional (can be nil) and will be used in Phase 4.
|
||||
func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore interface{}) *ObservationStore {
|
||||
return &ObservationStore{
|
||||
func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore any) *ObservationStore {
|
||||
s := &ObservationStore{
|
||||
db: store.DB,
|
||||
rawDB: store.GetRawDB(),
|
||||
cleanupFunc: cleanupFunc,
|
||||
conflictStore: conflictStore,
|
||||
relationStore: relationStore,
|
||||
cleanupQueue: make(chan string, cleanupQueueSize),
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
// Start the cleanup worker
|
||||
s.startCleanupWorker()
|
||||
return s
|
||||
}
|
||||
|
||||
// startCleanupWorker starts the background cleanup worker.
|
||||
func (s *ObservationStore) startCleanupWorker() {
|
||||
s.cleanupOnce.Do(func() {
|
||||
s.cleanupStarted.Store(true)
|
||||
s.cleanupWg.Add(1)
|
||||
go s.cleanupWorker()
|
||||
})
|
||||
}
|
||||
|
||||
// cleanupWorker processes cleanup requests from the queue.
|
||||
func (s *ObservationStore) cleanupWorker() {
|
||||
defer s.cleanupWg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCleanup:
|
||||
// Drain remaining items in queue before exiting
|
||||
for {
|
||||
select {
|
||||
case project := <-s.cleanupQueue:
|
||||
s.processCleanup(project)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
case project := <-s.cleanupQueue:
|
||||
s.processCleanup(project)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processCleanup performs the actual cleanup for a project.
|
||||
func (s *ObservationStore) processCleanup(project string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
deletedIDs, err := s.CleanupOldObservations(ctx, project)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("project", project).Msg("Failed to cleanup old observations")
|
||||
return
|
||||
}
|
||||
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(ctx, deletedIDs)
|
||||
log.Debug().Str("project", project).Int("count", len(deletedIDs)).Msg("Cleaned up old observations")
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup worker and waits for it to finish.
|
||||
// Safe to call even if the worker was never started.
|
||||
func (s *ObservationStore) Close() {
|
||||
// Only stop if worker was actually started to avoid deadlock
|
||||
if s.cleanupStarted.Load() {
|
||||
close(s.stopCleanup)
|
||||
s.cleanupWg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,16 +175,15 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Cleanup old observations beyond the limit for this project (async to not block handler)
|
||||
// Queue cleanup of old observations beyond the limit for this project (async to not block handler)
|
||||
if project != "" {
|
||||
go func(proj string) {
|
||||
cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
deletedIDs, _ := s.CleanupOldObservations(cleanupCtx, proj)
|
||||
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
|
||||
s.cleanupFunc(cleanupCtx, deletedIDs)
|
||||
}
|
||||
}(project)
|
||||
select {
|
||||
case s.cleanupQueue <- project:
|
||||
// Successfully queued for cleanup
|
||||
default:
|
||||
// Queue is full, log a warning instead of silently dropping
|
||||
log.Warn().Str("project", project).Msg("Cleanup queue full, skipping cleanup for this observation")
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Conflict and relation detection intentionally omitted for now
|
||||
@@ -104,6 +192,89 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p
|
||||
return dbObs.ID, nowEpoch, nil
|
||||
}
|
||||
|
||||
// ObservationUpdate contains fields that can be updated on an observation.
|
||||
// Only non-nil fields will be updated.
|
||||
type ObservationUpdate struct {
|
||||
Title *string // New title
|
||||
Subtitle *string // New subtitle
|
||||
Narrative *string // New narrative
|
||||
Facts *[]string // New facts (replaces existing)
|
||||
Concepts *[]string // New concepts (replaces existing)
|
||||
FilesRead *[]string // New files read (replaces existing)
|
||||
FilesModified *[]string // New files modified (replaces existing)
|
||||
Scope *string // New scope (project or global)
|
||||
}
|
||||
|
||||
// UpdateObservation updates an existing observation with the provided fields.
|
||||
// Only non-nil fields in the update struct will be modified.
|
||||
// Returns the updated observation or an error.
|
||||
func (s *ObservationStore) UpdateObservation(ctx context.Context, id int64, update *ObservationUpdate) (*models.Observation, error) {
|
||||
if update == nil {
|
||||
return nil, fmt.Errorf("update cannot be nil")
|
||||
}
|
||||
|
||||
// First, verify the observation exists
|
||||
var dbObs Observation
|
||||
if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("observation not found: %d", id)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build update map with only provided fields
|
||||
updates := make(map[string]any)
|
||||
|
||||
if update.Title != nil {
|
||||
updates["title"] = sql.NullString{String: *update.Title, Valid: true}
|
||||
}
|
||||
if update.Subtitle != nil {
|
||||
updates["subtitle"] = sql.NullString{String: *update.Subtitle, Valid: true}
|
||||
}
|
||||
if update.Narrative != nil {
|
||||
updates["narrative"] = sql.NullString{String: *update.Narrative, Valid: true}
|
||||
}
|
||||
if update.Facts != nil {
|
||||
factsJSON, _ := json.Marshal(*update.Facts)
|
||||
updates["facts"] = string(factsJSON)
|
||||
}
|
||||
if update.Concepts != nil {
|
||||
conceptsJSON, _ := json.Marshal(*update.Concepts)
|
||||
updates["concepts"] = string(conceptsJSON)
|
||||
}
|
||||
if update.FilesRead != nil {
|
||||
filesReadJSON, _ := json.Marshal(*update.FilesRead)
|
||||
updates["files_read"] = string(filesReadJSON)
|
||||
}
|
||||
if update.FilesModified != nil {
|
||||
filesModifiedJSON, _ := json.Marshal(*update.FilesModified)
|
||||
updates["files_modified"] = string(filesModifiedJSON)
|
||||
}
|
||||
if update.Scope != nil {
|
||||
updates["scope"] = sql.NullString{String: *update.Scope, Valid: true}
|
||||
}
|
||||
|
||||
// Add updated_at timestamp
|
||||
updates["updated_at_epoch"] = sql.NullInt64{Int64: time.Now().Unix(), Valid: true}
|
||||
|
||||
if len(updates) == 0 {
|
||||
// Nothing to update, just return existing observation
|
||||
return toModelObservation(&dbObs), nil
|
||||
}
|
||||
|
||||
// Perform the update
|
||||
if err := s.db.WithContext(ctx).Model(&Observation{}).Where("id = ?", id).Updates(updates).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to update observation: %w", err)
|
||||
}
|
||||
|
||||
// Fetch the updated observation
|
||||
if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservation(&dbObs), nil
|
||||
}
|
||||
|
||||
// GetObservationByID retrieves an observation by its ID.
|
||||
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
|
||||
var dbObs Observation
|
||||
@@ -134,6 +305,8 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
|
||||
query = query.Order("created_at_epoch DESC")
|
||||
case "importance":
|
||||
query = query.Order("importance_score DESC, created_at_epoch DESC")
|
||||
case "score_desc":
|
||||
query = query.Order("importance_score DESC, created_at_epoch DESC")
|
||||
default:
|
||||
// Default: importance first, then recency
|
||||
query = query.Order("COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC")
|
||||
@@ -152,6 +325,60 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetObservationsByIDsPreserveOrder retrieves observations by IDs, preserving the input order.
|
||||
// This is useful when the caller has already sorted/ranked the IDs (e.g., by vector similarity).
|
||||
func (s *ObservationStore) GetObservationsByIDsPreserveOrder(ctx context.Context, ids []int64) ([]*models.Observation, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Fetch all observations in a single query
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&dbObservations).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build ID -> observation map for O(1) lookups
|
||||
obsMap := make(map[int64]*Observation, len(dbObservations))
|
||||
for i := range dbObservations {
|
||||
obsMap[int64(dbObservations[i].ID)] = &dbObservations[i]
|
||||
}
|
||||
|
||||
// Reconstruct in original order
|
||||
result := make([]*models.Observation, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if obs, ok := obsMap[id]; ok {
|
||||
result = append(result, toModelObservation(obs))
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BatchGetObservationsWithScores retrieves observations with associated scores.
|
||||
// Returns a map of ID -> observation for efficient lookup.
|
||||
func (s *ObservationStore) BatchGetObservationsWithScores(ctx context.Context, ids []int64) (map[int64]*models.Observation, error) {
|
||||
if len(ids) == 0 {
|
||||
return make(map[int64]*models.Observation), nil
|
||||
}
|
||||
|
||||
// Fetch all observations in a single query
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&dbObservations).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build result map
|
||||
result := make(map[int64]*models.Observation, len(dbObservations))
|
||||
for i := range dbObservations {
|
||||
result[int64(dbObservations[i].ID)] = toModelObservation(&dbObservations[i])
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetRecentObservations retrieves recent observations for a project.
|
||||
// This includes project-scoped observations for the specified project AND global observations.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
@@ -169,13 +396,13 @@ func (s *ObservationStore) GetRecentObservations(ctx context.Context, project st
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetActiveObservations retrieves recent non-superseded observations for a project.
|
||||
// This excludes observations that have been marked as superseded by newer ones.
|
||||
// GetActiveObservations retrieves recent non-superseded, non-archived observations for a project.
|
||||
// This excludes observations that have been marked as superseded or archived.
|
||||
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
|
||||
func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Scopes(projectScopeFilter(project), notSupersededFilter(), importanceOrdering()).
|
||||
Scopes(projectScopeFilter(project), activeObservationFilter(), importanceOrdering()).
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
@@ -245,7 +472,57 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetAllRecentObservationsPaginated retrieves recent observations with pagination.
|
||||
func (s *ObservationStore) GetAllRecentObservationsPaginated(ctx context.Context, limit, offset int) ([]*models.Observation, int64, error) {
|
||||
var dbObservations []Observation
|
||||
var total int64
|
||||
|
||||
// Get total count
|
||||
if err := s.db.WithContext(ctx).Model(&Observation{}).Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
err := s.db.WithContext(ctx).
|
||||
Scopes(importanceOrdering()).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), total, nil
|
||||
}
|
||||
|
||||
// GetObservationsByProjectStrictPaginated retrieves observations strictly from a project with pagination.
|
||||
func (s *ObservationStore) GetObservationsByProjectStrictPaginated(ctx context.Context, project string, limit, offset int) ([]*models.Observation, int64, error) {
|
||||
var dbObservations []Observation
|
||||
var total int64
|
||||
|
||||
// Get total count for project
|
||||
if err := s.db.WithContext(ctx).Model(&Observation{}).Where("project = ?", project).Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("project = ?", project).
|
||||
Scopes(importanceOrdering()).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), total, nil
|
||||
}
|
||||
|
||||
// GetAllObservations retrieves all observations (for vector rebuild).
|
||||
// Note: For large datasets, prefer GetAllObservationsIterator to avoid memory issues.
|
||||
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
@@ -259,6 +536,51 @@ func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Ob
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetAllObservationsIterator returns observations in batches to avoid loading all into memory.
|
||||
// The callback is called for each batch. Return false from callback to stop iteration.
|
||||
// batchSize controls how many observations are loaded at once (default 500 if <= 0).
|
||||
func (s *ObservationStore) GetAllObservationsIterator(ctx context.Context, batchSize int, callback func([]*models.Observation) bool) error {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 500
|
||||
}
|
||||
|
||||
var lastID int64 = 0
|
||||
for {
|
||||
// Check context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
var dbObservations []Observation
|
||||
err := s.db.WithContext(ctx).
|
||||
Where("id > ?", lastID).
|
||||
Order("id ASC").
|
||||
Limit(batchSize).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(dbObservations) == 0 {
|
||||
break // No more observations
|
||||
}
|
||||
|
||||
// Update cursor for next batch
|
||||
lastID = dbObservations[len(dbObservations)-1].ID
|
||||
|
||||
// Convert and call callback
|
||||
observations := toModelObservations(dbObservations)
|
||||
if !callback(observations) {
|
||||
break // Callback requested stop
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchObservationsFTS performs full-text search on observations using FTS5.
|
||||
// Falls back to LIKE search if FTS5 fails.
|
||||
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
|
||||
@@ -314,14 +636,26 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro
|
||||
}
|
||||
|
||||
// searchObservationsLike performs fallback LIKE search on observations using GORM.
|
||||
// Limits to 2 keywords to prevent expensive OR queries that SQLite optimizes poorly.
|
||||
// This is a fallback path when FTS returns no results, so we prioritize performance.
|
||||
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
|
||||
if len(keywords) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Limit keywords to prevent excessive OR conditions that hurt query planning.
|
||||
// SQLite performs significantly better with fewer LIKE conditions.
|
||||
// Using 2 instead of 5 reduces query complexity from O(15) to O(6) conditions
|
||||
// (each keyword creates 3 LIKE conditions for title, subtitle, narrative).
|
||||
const maxKeywords = 2
|
||||
if len(keywords) > maxKeywords {
|
||||
keywords = keywords[:maxKeywords]
|
||||
}
|
||||
|
||||
// Build LIKE conditions for each keyword
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
// Pre-allocate for efficiency: maxKeywords conditions × 3 args each + 1 project arg
|
||||
conditions := make([]string, 0, len(keywords))
|
||||
args := make([]any, 0, len(keywords)*3+1)
|
||||
|
||||
for _, kw := range keywords {
|
||||
pattern := "%" + kw + "%"
|
||||
@@ -358,6 +692,217 @@ func (s *ObservationStore) DeleteObservations(ctx context.Context, ids []int64)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// DeleteObservation deletes a single observation by ID.
|
||||
func (s *ObservationStore) DeleteObservation(ctx context.Context, id int64) error {
|
||||
result := s.db.WithContext(ctx).Delete(&Observation{}, id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("observation %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAsSuperseded marks an observation as superseded (stale).
|
||||
func (s *ObservationStore) MarkAsSuperseded(ctx context.Context, id int64) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id = ?", id).
|
||||
Update("is_superseded", 1)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("observation %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAsSupersededBatch marks multiple observations as superseded in a single query.
|
||||
// Returns the number of observations updated and any error.
|
||||
func (s *ObservationStore) MarkAsSupersededBatch(ctx context.Context, ids []int64) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id IN ?", ids).
|
||||
Update("is_superseded", 1)
|
||||
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// ArchiveObservation archives a single observation with an optional reason.
|
||||
func (s *ObservationStore) ArchiveObservation(ctx context.Context, id int64, reason string) error {
|
||||
updates := map[string]any{
|
||||
"is_archived": 1,
|
||||
"archived_at_epoch": time.Now().UnixMilli(),
|
||||
}
|
||||
if reason != "" {
|
||||
updates["archived_reason"] = reason
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("observation %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnarchiveObservation restores an archived observation.
|
||||
func (s *ObservationStore) UnarchiveObservation(ctx context.Context, id int64) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&Observation{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"is_archived": 0,
|
||||
"archived_at_epoch": nil,
|
||||
"archived_reason": nil,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("observation %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ArchiveOldObservations archives observations older than the specified age.
|
||||
// Returns the count of archived observations and their IDs.
|
||||
func (s *ObservationStore) ArchiveOldObservations(ctx context.Context, project string, maxAgeDays int, reason string) ([]int64, error) {
|
||||
if maxAgeDays <= 0 {
|
||||
maxAgeDays = 90 // Default: archive observations older than 90 days
|
||||
}
|
||||
|
||||
cutoffEpoch := time.Now().AddDate(0, 0, -maxAgeDays).UnixMilli()
|
||||
if reason == "" {
|
||||
reason = fmt.Sprintf("auto-archived: older than %d days", maxAgeDays)
|
||||
}
|
||||
|
||||
var idsToArchive []int64
|
||||
|
||||
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Find observations to archive (not already archived, older than cutoff)
|
||||
query := tx.Model(&Observation{}).
|
||||
Where("created_at_epoch < ?", cutoffEpoch).
|
||||
Where("COALESCE(is_archived, 0) = 0")
|
||||
|
||||
if project != "" {
|
||||
query = query.Where("project = ?", project)
|
||||
}
|
||||
|
||||
if err := query.Pluck("id", &idsToArchive).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(idsToArchive) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Archive the observations
|
||||
now := time.Now().UnixMilli()
|
||||
return tx.Model(&Observation{}).
|
||||
Where("id IN ?", idsToArchive).
|
||||
Updates(map[string]any{
|
||||
"is_archived": 1,
|
||||
"archived_at_epoch": now,
|
||||
"archived_reason": reason,
|
||||
}).Error
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return idsToArchive, nil
|
||||
}
|
||||
|
||||
// GetArchivedObservations retrieves archived observations for a project.
|
||||
func (s *ObservationStore) GetArchivedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
|
||||
var dbObservations []Observation
|
||||
query := s.db.WithContext(ctx).
|
||||
Where("COALESCE(is_archived, 0) = 1")
|
||||
|
||||
if project != "" {
|
||||
query = query.Where("project = ?", project)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Order("archived_at_epoch DESC").
|
||||
Limit(limit).
|
||||
Find(&dbObservations).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toModelObservations(dbObservations), nil
|
||||
}
|
||||
|
||||
// GetArchivalStats returns statistics about archived observations.
|
||||
// Optimized to use a single query instead of 4 separate queries.
|
||||
func (s *ObservationStore) GetArchivalStats(ctx context.Context, project string) (*ArchivalStats, error) {
|
||||
// Use a single query with conditional aggregation to get all stats at once.
|
||||
// This is much faster than 4 separate queries (saves 3 round trips).
|
||||
type statsResult struct {
|
||||
TotalCount int64
|
||||
ArchivedCount int64
|
||||
OldestEpoch *int64 // Pointer to handle NULL
|
||||
NewestEpoch *int64
|
||||
}
|
||||
|
||||
var result statsResult
|
||||
|
||||
query := s.db.WithContext(ctx).Model(&Observation{}).
|
||||
Select(`
|
||||
COUNT(*) as total_count,
|
||||
SUM(CASE WHEN COALESCE(is_archived, 0) = 1 THEN 1 ELSE 0 END) as archived_count,
|
||||
MIN(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as oldest_epoch,
|
||||
MAX(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as newest_epoch
|
||||
`)
|
||||
|
||||
if project != "" {
|
||||
query = query.Where("project = ?", project)
|
||||
}
|
||||
|
||||
if err := query.Scan(&result).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := &ArchivalStats{
|
||||
TotalCount: result.TotalCount,
|
||||
ArchivedCount: result.ArchivedCount,
|
||||
ActiveCount: result.TotalCount - result.ArchivedCount,
|
||||
}
|
||||
|
||||
if result.OldestEpoch != nil {
|
||||
stats.OldestArchivedEpoch = *result.OldestEpoch
|
||||
}
|
||||
if result.NewestEpoch != nil {
|
||||
stats.NewestArchivedEpoch = *result.NewestEpoch
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// ArchivalStats contains statistics about archived observations.
// It is the result type of GetArchivalStats; counts cover either one project
// or all projects depending on how that method was called.
|
||||
type ArchivalStats struct {
|
||||
// TotalCount is the number of observation rows, archived or not.
TotalCount int64 `json:"total_count"`
|
||||
// ActiveCount is TotalCount minus ArchivedCount.
ActiveCount int64 `json:"active_count"`
|
||||
// ArchivedCount is the number of rows whose is_archived flag is 1.
ArchivedCount int64 `json:"archived_count"`
|
||||
// OldestArchivedEpoch is the smallest archived_at_epoch among archived rows
// (Unix milliseconds as written by the archive methods); zero, and omitted
// from JSON, when there is none.
OldestArchivedEpoch int64 `json:"oldest_archived_epoch,omitempty"`
|
||||
// NewestArchivedEpoch is the largest archived_at_epoch among archived rows;
// zero, and omitted from JSON, when there is none.
NewestArchivedEpoch int64 `json:"newest_archived_epoch,omitempty"`
|
||||
}
|
||||
|
||||
// CleanupOldObservations removes observations beyond the limit for a project.
|
||||
// Returns the IDs of deleted observations.
|
||||
func (s *ObservationStore) CleanupOldObservations(ctx context.Context, project string) ([]int64, error) {
|
||||
@@ -418,10 +963,12 @@ func projectScopeFilter(project string) func(*gorm.DB) *gorm.DB {
|
||||
}
|
||||
}
|
||||
|
||||
// notSupersededFilter filters out superseded observations.
|
||||
func notSupersededFilter() func(*gorm.DB) *gorm.DB {
|
||||
// activeObservationFilter filters for active (non-archived, non-superseded) observations.
|
||||
// This is more efficient than chaining notSupersededFilter + notArchivedFilter
|
||||
// as it produces a single WHERE clause for the query optimizer.
|
||||
func activeObservationFilter() func(*gorm.DB) *gorm.DB {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
return db.Where("COALESCE(is_superseded, 0) = 0")
|
||||
return db.Where("COALESCE(is_archived, 0) = 0 AND COALESCE(is_superseded, 0) = 0")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -437,23 +984,17 @@ func importanceOrdering() func(*gorm.DB) *gorm.DB {
|
||||
// ====================
|
||||
|
||||
// extractKeywords extracts keywords from a search query.
|
||||
// Uses package-level commonWords map for O(1) stop word filtering.
|
||||
func extractKeywords(query string) []string {
|
||||
words := strings.Fields(strings.ToLower(query))
|
||||
var keywords []string
|
||||
|
||||
commonWords := map[string]bool{
|
||||
"the": true, "and": true, "or": true, "but": true, "in": true,
|
||||
"on": true, "at": true, "to": true, "for": true, "of": true,
|
||||
"with": true, "by": true, "from": true, "as": true, "is": true,
|
||||
"was": true, "are": true, "were": true, "be": true, "been": true,
|
||||
"being": true, "have": true, "has": true, "had": true, "do": true,
|
||||
"does": true, "did": true, "will": true, "would": true, "should": true,
|
||||
"could": true, "may": true, "might": true, "must": true, "can": true,
|
||||
}
|
||||
keywords := make([]string, 0, len(words)) // Pre-allocate for typical case
|
||||
|
||||
for _, word := range words {
|
||||
// Skip short words and common words
|
||||
if len(word) <= 3 || commonWords[word] {
|
||||
// Skip short words and common stop words
|
||||
if len(word) <= 3 {
|
||||
continue
|
||||
}
|
||||
if _, isCommon := commonWords[word]; isCommon {
|
||||
continue
|
||||
}
|
||||
keywords = append(keywords, word)
|
||||
@@ -464,7 +1005,8 @@ func extractKeywords(query string) []string {
|
||||
|
||||
// scanObservationRows scans multiple observations from raw SQL rows.
|
||||
func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
|
||||
var observations []*models.Observation
|
||||
// Pre-allocate with reasonable initial capacity to avoid repeated slice growth
|
||||
observations := make([]*models.Observation, 0, 64)
|
||||
for rows.Next() {
|
||||
obs, err := scanObservation(rows)
|
||||
if err != nil {
|
||||
@@ -476,7 +1018,7 @@ func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
|
||||
}
|
||||
|
||||
// scanObservation scans a single observation from a row scanner.
|
||||
func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.Observation, error) {
|
||||
func scanObservation(scanner interface{ Scan(...any) error }) (*models.Observation, error) {
|
||||
var obs models.Observation
|
||||
var factsJSON, conceptsJSON, filesReadJSON, filesModifiedJSON, fileMtimesJSON []byte
|
||||
var isSuperseded int
|
||||
|
||||
@@ -3,6 +3,8 @@ package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -59,12 +61,23 @@ func (s *ObservationStore) UpdateImportanceScore(ctx context.Context, id int64,
|
||||
}
|
||||
|
||||
// UpdateImportanceScores bulk updates importance scores for multiple observations.
|
||||
// This is more efficient than individual updates for batch recalculation.
|
||||
// Uses a single SQL statement with CASE/WHEN for efficient batch updates.
|
||||
func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
|
||||
if len(scores) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For small batches, use simple individual updates
|
||||
if len(scores) <= 5 {
|
||||
return s.updateScoresIndividually(ctx, scores)
|
||||
}
|
||||
|
||||
// For larger batches, use CASE/WHEN SQL for single-query update
|
||||
return s.updateScoresBatch(ctx, scores)
|
||||
}
|
||||
|
||||
// updateScoresIndividually updates scores one at a time (efficient for small batches).
|
||||
func (s *ObservationStore) updateScoresIndividually(ctx context.Context, scores map[int64]float64) error {
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
@@ -85,6 +98,36 @@ func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores ma
|
||||
})
|
||||
}
|
||||
|
||||
// updateScoresBatch updates multiple scores in a single SQL statement using CASE/WHEN.
|
||||
// This is much more efficient for large batches (O(1) queries instead of O(n)).
|
||||
func (s *ObservationStore) updateScoresBatch(ctx context.Context, scores map[int64]float64) error {
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
// Build CASE/WHEN clause for importance_score
|
||||
// UPDATE observations SET
|
||||
// importance_score = CASE id WHEN 1 THEN 0.5 WHEN 2 THEN 0.8 ... END,
|
||||
// score_updated_at_epoch = ?
|
||||
// WHERE id IN (1, 2, ...)
|
||||
|
||||
ids := make([]int64, 0, len(scores))
|
||||
caseBuilder := strings.Builder{}
|
||||
caseBuilder.WriteString("CASE id ")
|
||||
|
||||
for id, score := range scores {
|
||||
ids = append(ids, id)
|
||||
caseBuilder.WriteString(fmt.Sprintf("WHEN %d THEN %f ", id, score))
|
||||
}
|
||||
caseBuilder.WriteString("END")
|
||||
|
||||
// Use raw SQL for the batch update
|
||||
sql := fmt.Sprintf(
|
||||
"UPDATE observations SET importance_score = %s, score_updated_at_epoch = ? WHERE id IN ?",
|
||||
caseBuilder.String(),
|
||||
)
|
||||
|
||||
return s.db.WithContext(ctx).Exec(sql, now, ids).Error
|
||||
}
|
||||
|
||||
// GetObservationsNeedingScoreUpdate returns observations that need their importance score recalculated.
|
||||
// Returns observations where score_updated_at_epoch is NULL or older than the threshold.
|
||||
func (s *ObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
|
||||
|
||||
@@ -116,26 +116,44 @@ func (s *SessionStore) FindAnySDKSession(ctx context.Context, claudeSessionID st
|
||||
}
|
||||
|
||||
// IncrementPromptCounter increments the prompt counter and returns the new value.
|
||||
// Uses a single SQL query with RETURNING clause for optimal performance.
|
||||
func (s *SessionStore) IncrementPromptCounter(ctx context.Context, id int64) (int, error) {
|
||||
// Atomic increment using GORM expression
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("id = ?", id).
|
||||
Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error
|
||||
// Use raw SQL with RETURNING to get updated value in single query
|
||||
// SQLite supports RETURNING since version 3.35.0 (2021-03-12)
|
||||
var newCounter int
|
||||
err := s.db.WithContext(ctx).Raw(`
|
||||
UPDATE sdk_sessions
|
||||
SET prompt_counter = COALESCE(prompt_counter, 0) + 1
|
||||
WHERE id = ?
|
||||
RETURNING prompt_counter
|
||||
`, id).Scan(&newCounter).Error
|
||||
|
||||
if err != nil {
|
||||
// Fallback for older SQLite versions without RETURNING support
|
||||
if err.Error() == "near \"RETURNING\": syntax error" || newCounter == 0 {
|
||||
// Atomic increment
|
||||
updateErr := s.db.WithContext(ctx).
|
||||
Model(&SDKSession{}).
|
||||
Where("id = ?", id).
|
||||
Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error
|
||||
if updateErr != nil {
|
||||
return 0, updateErr
|
||||
}
|
||||
|
||||
// Fetch updated value
|
||||
var sess SDKSession
|
||||
fetchErr := s.db.WithContext(ctx).
|
||||
Select("prompt_counter").
|
||||
First(&sess, id).Error
|
||||
if fetchErr != nil {
|
||||
return 0, fetchErr
|
||||
}
|
||||
return sess.PromptCounter, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Fetch updated value
|
||||
var sess SDKSession
|
||||
err = s.db.WithContext(ctx).
|
||||
Select("prompt_counter").
|
||||
First(&sess, id).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return sess.PromptCounter, nil
|
||||
return newCounter, nil
|
||||
}
|
||||
|
||||
// GetPromptCounter returns the current prompt counter for a session.
|
||||
|
||||
+428
-9
@@ -2,11 +2,16 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
_ "github.com/mattn/go-sqlite3" // Import SQLite driver with FTS5 support
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -14,8 +19,15 @@ import (
|
||||
|
||||
// Store represents the GORM database connection with sqlite-vec support.
|
||||
type Store struct {
|
||||
DB *gorm.DB
|
||||
sqlDB *sql.DB // For FTS5 and sqlite-vec operations that require raw SQL
|
||||
DB *gorm.DB
|
||||
sqlDB *sql.DB // For FTS5 and sqlite-vec operations that require raw SQL
|
||||
metrics *PoolMetrics
|
||||
|
||||
// Health check caching to reduce database load from frequent monitoring
|
||||
healthCacheMu sync.RWMutex
|
||||
cachedHealth *HealthInfo
|
||||
healthCacheTime time.Time
|
||||
healthCacheTTL time.Duration // Default: 5 seconds
|
||||
}
|
||||
|
||||
// Config holds database configuration.
|
||||
@@ -71,8 +83,10 @@ func NewStore(cfg Config) (*Store, error) {
|
||||
}
|
||||
|
||||
store := &Store{
|
||||
DB: db,
|
||||
sqlDB: sqlDB,
|
||||
DB: db,
|
||||
sqlDB: sqlDB,
|
||||
metrics: NewPoolMetrics(100), // Track last 100 latency samples
|
||||
healthCacheTTL: 5 * time.Second, // Cache health checks for 5 seconds
|
||||
}
|
||||
|
||||
// 7. Run migrations FIRST (before PRAGMA commands)
|
||||
@@ -80,18 +94,56 @@ func NewStore(cfg Config) (*Store, error) {
|
||||
return nil, fmt.Errorf("run migrations: %w", err)
|
||||
}
|
||||
|
||||
// 8. CRITICAL: Set WAL mode and synchronous mode via raw SQL
|
||||
// 8. CRITICAL: Set WAL mode and other performance pragmas
|
||||
// Use raw sqlDB to avoid GORM transaction issues
|
||||
if _, err := sqlDB.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||||
return nil, fmt.Errorf("set WAL mode: %w", err)
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA synchronous=NORMAL",
|
||||
"PRAGMA cache_size=-64000", // 64MB cache (negative = KB)
|
||||
"PRAGMA temp_store=MEMORY", // Store temp tables in memory
|
||||
"PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O
|
||||
"PRAGMA page_size=4096", // 4KB pages (optimal for most systems)
|
||||
}
|
||||
if _, err := sqlDB.Exec("PRAGMA synchronous=NORMAL"); err != nil {
|
||||
return nil, fmt.Errorf("set synchronous mode: %w", err)
|
||||
for _, pragma := range pragmas {
|
||||
if _, err := sqlDB.Exec(pragma); err != nil {
|
||||
log.Warn().Str("pragma", pragma).Err(err).Msg("Failed to set pragma (non-fatal)")
|
||||
}
|
||||
}
|
||||
|
||||
// 9. Warm the connection pool
|
||||
store.WarmPool(maxConns)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// WarmPool pre-creates connections to avoid cold start latency.
|
||||
func (s *Store) WarmPool(numConns int) {
|
||||
if numConns <= 0 {
|
||||
numConns = 4
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < numConns; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := s.sqlDB.Conn(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Execute a simple query to ensure the connection is fully initialized
|
||||
_ = conn.PingContext(ctx)
|
||||
// Return connection to pool (don't close it)
|
||||
_ = conn.Close()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
log.Debug().Int("connections", numConns).Msg("Connection pool warmed")
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
func (s *Store) Close() error {
|
||||
return s.sqlDB.Close()
|
||||
@@ -115,3 +167,370 @@ func (s *Store) GetRawDB() *sql.DB {
|
||||
func (s *Store) GetDB() *gorm.DB {
|
||||
return s.DB
|
||||
}
|
||||
|
||||
// Stats returns database connection pool statistics.
|
||||
func (s *Store) Stats() sql.DBStats {
|
||||
return s.sqlDB.Stats()
|
||||
}
|
||||
|
||||
// Optimize runs VACUUM and ANALYZE to optimize the database.
|
||||
// Should be called periodically (e.g., daily) during low activity.
|
||||
func (s *Store) Optimize(ctx context.Context) error {
|
||||
log.Info().Msg("Starting database optimization")
|
||||
start := time.Now()
|
||||
|
||||
// ANALYZE updates statistics for query optimizer
|
||||
if _, err := s.sqlDB.ExecContext(ctx, "ANALYZE"); err != nil {
|
||||
return fmt.Errorf("analyze: %w", err)
|
||||
}
|
||||
|
||||
// PRAGMA optimize runs optimization based on query statistics
|
||||
if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA optimize"); err != nil {
|
||||
log.Warn().Err(err).Msg("PRAGMA optimize failed (non-fatal)")
|
||||
}
|
||||
|
||||
log.Info().Dur("duration", time.Since(start)).Msg("Database optimization complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck performs a comprehensive health check with latency measurement.
|
||||
// Returns detailed health information including connection pool stats and query latency.
|
||||
// Results are cached for healthCacheTTL (default 5 seconds) to reduce database load
|
||||
// from frequent monitoring calls.
|
||||
func (s *Store) HealthCheck(ctx context.Context) *HealthInfo {
|
||||
// Fast path: check cache with read lock
|
||||
s.healthCacheMu.RLock()
|
||||
if s.cachedHealth != nil && time.Since(s.healthCacheTime) < s.healthCacheTTL {
|
||||
cached := s.cachedHealth
|
||||
s.healthCacheMu.RUnlock()
|
||||
return cached
|
||||
}
|
||||
s.healthCacheMu.RUnlock()
|
||||
|
||||
// Slow path: perform actual health check
|
||||
info := s.performHealthCheck(ctx)
|
||||
|
||||
// Cache the result
|
||||
s.healthCacheMu.Lock()
|
||||
s.cachedHealth = info
|
||||
s.healthCacheTime = time.Now()
|
||||
s.healthCacheMu.Unlock()
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// HealthCheckForce performs a health check bypassing the cache.
|
||||
// Use this when you need real-time health data (e.g., debugging, alerting).
|
||||
func (s *Store) HealthCheckForce(ctx context.Context) *HealthInfo {
|
||||
info := s.performHealthCheck(ctx)
|
||||
|
||||
// Update the cache with fresh data
|
||||
s.healthCacheMu.Lock()
|
||||
s.cachedHealth = info
|
||||
s.healthCacheTime = time.Now()
|
||||
s.healthCacheMu.Unlock()
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// performHealthCheck does the actual health check work.
|
||||
func (s *Store) performHealthCheck(ctx context.Context) *HealthInfo {
|
||||
info := &HealthInfo{
|
||||
Status: "healthy",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
// Check pool stats
|
||||
stats := s.sqlDB.Stats()
|
||||
info.PoolStats = PoolStats{
|
||||
OpenConnections: stats.OpenConnections,
|
||||
InUse: stats.InUse,
|
||||
Idle: stats.Idle,
|
||||
WaitCount: stats.WaitCount,
|
||||
WaitDuration: stats.WaitDuration,
|
||||
MaxIdleClosed: stats.MaxIdleClosed,
|
||||
MaxLifetimeClosed: stats.MaxLifetimeClosed,
|
||||
}
|
||||
|
||||
// Record pool stats for metrics tracking
|
||||
if s.metrics != nil {
|
||||
s.metrics.RecordPoolStats(stats)
|
||||
}
|
||||
|
||||
// Measure query latency with a simple SELECT
|
||||
start := time.Now()
|
||||
var dummy int
|
||||
err := s.sqlDB.QueryRowContext(ctx, "SELECT 1").Scan(&dummy)
|
||||
info.QueryLatency = time.Since(start)
|
||||
|
||||
// Record latency for historical tracking
|
||||
if s.metrics != nil {
|
||||
s.metrics.RecordLatency(info.QueryLatency)
|
||||
info.HistoricalMetrics = s.metrics.GetMetricsSummary()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
info.Status = "unhealthy"
|
||||
info.Error = err.Error()
|
||||
return info
|
||||
}
|
||||
|
||||
// Check for connection saturation (degraded if pool is heavily used)
|
||||
if stats.InUse > 0 && float64(stats.InUse)/float64(stats.OpenConnections) > 0.8 {
|
||||
info.Status = "degraded"
|
||||
info.Warning = "Connection pool heavily utilized"
|
||||
}
|
||||
|
||||
// Check for wait contention
|
||||
if stats.WaitCount > 100 && stats.WaitDuration > 100*time.Millisecond {
|
||||
info.Status = "degraded"
|
||||
info.Warning = "Connection pool contention detected"
|
||||
}
|
||||
|
||||
// Check query latency (warn if > 10ms for simple query)
|
||||
if info.QueryLatency > 10*time.Millisecond {
|
||||
if info.Status == "healthy" {
|
||||
info.Status = "degraded"
|
||||
}
|
||||
info.Warning = fmt.Sprintf("Slow query latency: %v", info.QueryLatency)
|
||||
}
|
||||
|
||||
// Check historical latency trend (degraded if P95 is high)
|
||||
if s.metrics != nil && info.HistoricalMetrics.P95Latency > 50*time.Millisecond {
|
||||
if info.Status == "healthy" {
|
||||
info.Status = "degraded"
|
||||
}
|
||||
info.Warning = fmt.Sprintf("High P95 latency: %v", info.HistoricalMetrics.P95Latency)
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// HealthInfo contains database health check results.
// It is the value built and returned by performHealthCheck for a single probe.
|
||||
type HealthInfo struct {
|
||||
// Status is the overall grade of the probe.
Status string `json:"status"` // healthy, degraded, unhealthy
|
||||
// Timestamp is when the probe started.
Timestamp time.Time `json:"timestamp"`
|
||||
// QueryLatency is the measured round trip of the probe query; as a
// time.Duration it is encoded in JSON as an integer number of nanoseconds.
QueryLatency time.Duration `json:"query_latency_ns"`
|
||||
// PoolStats is the connection-pool snapshot taken for this probe.
PoolStats PoolStats `json:"pool_stats"`
|
||||
// HistoricalMetrics is the sliding-window summary; it stays at its zero
// value when no metrics collector is configured.
// NOTE(review): encoding/json never treats a struct value as empty, so
// omitempty on this field does not omit it - confirm whether that matters.
HistoricalMetrics MetricsSummary `json:"historical_metrics,omitempty"`
|
||||
// Error holds the probe query's error text when Status is "unhealthy".
Error string `json:"error,omitempty"`
|
||||
// Warning explains why Status is "degraded".
Warning string `json:"warning,omitempty"`
|
||||
}
|
||||
|
||||
// PoolStats contains connection pool statistics.
// performHealthCheck fills each field from the identically named field of
// the sql.DBStats snapshot returned by sqlDB.Stats().
|
||||
type PoolStats struct {
|
||||
// OpenConnections is copied from sql.DBStats.OpenConnections.
OpenConnections int `json:"open_connections"`
|
||||
// InUse is copied from sql.DBStats.InUse.
InUse int `json:"in_use"`
|
||||
// Idle is copied from sql.DBStats.Idle.
Idle int `json:"idle"`
|
||||
// WaitCount is copied from sql.DBStats.WaitCount.
WaitCount int64 `json:"wait_count"`
|
||||
// WaitDuration is copied from sql.DBStats.WaitDuration; as a
// time.Duration it is encoded in JSON as integer nanoseconds.
WaitDuration time.Duration `json:"wait_duration_ns"`
|
||||
// MaxIdleClosed is copied from sql.DBStats.MaxIdleClosed.
MaxIdleClosed int64 `json:"max_idle_closed"`
|
||||
// MaxLifetimeClosed is copied from sql.DBStats.MaxLifetimeClosed.
MaxLifetimeClosed int64 `json:"max_lifetime_closed"`
|
||||
}
|
||||
|
||||
// QueryTimeout constants for different query types.
// Each is a time.Duration; presumably they are meant to be passed as the
// timeout argument of the Store *WithTimeout helpers - confirm against callers.
|
||||
const (
|
||||
// DefaultQueryTimeout is the default timeout for regular queries.
|
||||
DefaultQueryTimeout = 5 * time.Second
|
||||
// FastQueryTimeout is for queries that should be very fast (health checks, etc).
|
||||
FastQueryTimeout = 1 * time.Second
|
||||
// SlowQueryTimeout is for queries that may take longer (bulk operations, rebuilds).
|
||||
SlowQueryTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// PoolMetrics tracks historical connection pool metrics with a sliding window.
//
// Latency samples are kept in a fixed-size circular buffer of windowSize
// entries; once the buffer is full the oldest sample is overwritten. Every
// field is guarded by mu, so a PoolMetrics must be shared by pointer and
// never copied.
type PoolMetrics struct {
	mu             sync.RWMutex
	latencySamples []time.Duration // Circular buffer of latency samples
	latencyIdx     int             // Next write position in the circular buffer
	latencyCount   int             // Number of valid samples (capped at windowSize)
	totalQueries   int64           // Total latency samples ever recorded
	totalWaitTime  time.Duration   // Largest cumulative connection wait time observed
	peakInUse      int             // Peak concurrent connections in use
	peakWaitCount  int64           // Peak wait count observed
	lastSampleTime time.Time       // Last time a latency sample was recorded
	windowSize     int             // Size of sliding window
}

// NewPoolMetrics creates a new pool metrics collector with the given window size.
// A non-positive windowSize falls back to the default of 100 samples.
func NewPoolMetrics(windowSize int) *PoolMetrics {
	if windowSize <= 0 {
		windowSize = 100 // Default: track last 100 samples
	}
	return &PoolMetrics{
		latencySamples: make([]time.Duration, windowSize),
		windowSize:     windowSize,
		lastSampleTime: time.Now(),
	}
}

// RecordLatency records a query latency sample, overwriting the oldest sample
// once the window is full.
func (m *PoolMetrics) RecordLatency(latency time.Duration) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.latencySamples[m.latencyIdx] = latency
	m.latencyIdx = (m.latencyIdx + 1) % m.windowSize
	if m.latencyCount < m.windowSize {
		m.latencyCount++
	}
	m.totalQueries++
	m.lastSampleTime = time.Now()
}

// RecordPoolStats records pool statistics for peak tracking.
//
// sql.DBStats.WaitDuration is already a running total maintained by
// database/sql, so it is tracked as a high-water mark. Adding it on every
// call would count the same wait time again with each snapshot.
func (m *PoolMetrics) RecordPoolStats(stats sql.DBStats) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.peakInUse = max(m.peakInUse, stats.InUse)
	m.peakWaitCount = max(m.peakWaitCount, stats.WaitCount)
	m.totalWaitTime = max(m.totalWaitTime, stats.WaitDuration)
}

// GetMetricsSummary returns a summary of collected metrics.
//
// Latency statistics are computed over the samples currently in the window.
// They are left at zero when no sample has been recorded, and P95Latency is
// only filled in once at least 20 samples are available.
func (m *PoolMetrics) GetMetricsSummary() MetricsSummary {
	m.mu.RLock()
	defer m.mu.RUnlock()

	summary := MetricsSummary{
		TotalQueries:   m.totalQueries,
		SampleCount:    m.latencyCount,
		PeakInUse:      m.peakInUse,
		PeakWaitCount:  m.peakWaitCount,
		TotalWaitTime:  m.totalWaitTime,
		LastSampleTime: m.lastSampleTime,
	}

	if m.latencyCount == 0 {
		return summary
	}

	// Valid samples always occupy the first latencyCount slots: the buffer
	// fills from index 0 and only wraps once it is completely full.
	window := m.latencySamples[:m.latencyCount]

	var total time.Duration
	lowest, highest := window[0], window[0]
	for _, sample := range window {
		total += sample
		lowest = min(lowest, sample)
		highest = max(highest, sample)
	}

	summary.AvgLatency = total / time.Duration(len(window))
	summary.MinLatency = lowest
	summary.MaxLatency = highest

	// Approximate P95 from a sorted copy; the live window must stay in
	// insertion order for the circular buffer to work.
	if len(window) >= 20 {
		sorted := slices.Clone(window)
		slices.Sort(sorted)
		p95Idx := int(float64(len(sorted)) * 0.95)
		summary.P95Latency = sorted[p95Idx]
	}

	return summary
}

// MetricsSummary contains aggregated pool metrics.
// Duration fields are encoded in JSON as integer nanoseconds.
type MetricsSummary struct {
	TotalQueries   int64         `json:"total_queries"`
	SampleCount    int           `json:"sample_count"`
	AvgLatency     time.Duration `json:"avg_latency_ns"`
	MinLatency     time.Duration `json:"min_latency_ns"`
	MaxLatency     time.Duration `json:"max_latency_ns"`
	P95Latency     time.Duration `json:"p95_latency_ns,omitempty"`
	PeakInUse      int           `json:"peak_in_use"`
	PeakWaitCount  int64         `json:"peak_wait_count"`
	TotalWaitTime  time.Duration `json:"total_wait_time_ns"`
	LastSampleTime time.Time     `json:"last_sample_time"`
}
|
||||
|
||||
// GetMetrics returns the current metrics without performing a health check.
|
||||
func (s *Store) GetMetrics() MetricsSummary {
|
||||
if s.metrics == nil {
|
||||
return MetricsSummary{}
|
||||
}
|
||||
return s.metrics.GetMetricsSummary()
|
||||
}
|
||||
|
||||
// ResetMetrics resets the metrics collector (useful for testing or after major changes).
|
||||
func (s *Store) ResetMetrics() {
|
||||
if s.metrics != nil {
|
||||
s.metrics = NewPoolMetrics(s.metrics.windowSize)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout wraps a context with the given timeout and logs slow queries.
|
||||
// Returns the wrapped context and a cancel function that should be called when done.
|
||||
func (s *Store) WithTimeout(ctx context.Context, timeout time.Duration, operation string) (context.Context, context.CancelFunc) {
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
start := time.Now()
|
||||
|
||||
// Return wrapped cancel that logs if query was slow
|
||||
return timeoutCtx, func() {
|
||||
elapsed := time.Since(start)
|
||||
cancel()
|
||||
|
||||
// Log slow queries (> 100ms)
|
||||
if elapsed > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Str("operation", operation).
|
||||
Dur("elapsed", elapsed).
|
||||
Dur("timeout", timeout).
|
||||
Msg("Slow database operation")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ExecWithTimeout executes a raw SQL query with timeout.
|
||||
// Returns error if query takes longer than timeout.
|
||||
func (s *Store) ExecWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) error {
|
||||
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "exec")
|
||||
defer cancel()
|
||||
|
||||
_, err := s.sqlDB.ExecContext(timeoutCtx, query, args...)
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded {
|
||||
return fmt.Errorf("query timeout after %v: %s", timeout, query)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryRowWithTimeout executes a row query with timeout.
// The returned row is bound to a context that expires after timeout.
// NOTE(review): the cancel function from WithTimeout is discarded, so the
// slow-operation logging wrapped into it never runs for this helper, and the
// derived context is only released when its deadline fires (or the parent is
// cancelled). Fixing that needs a signature that hands cancel to the caller.
|
||||
func (s *Store) QueryRowWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) *sql.Row {
|
||||
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "query_row")
|
||||
// cancel cannot be invoked here: cancelling before the caller has scanned
// the row would presumably abort the read - confirm against database/sql.
|
||||
_ = cancel // intentionally dropped; it is not returned, so the caller cannot call it either
|
||||
return s.sqlDB.QueryRowContext(timeoutCtx, query, args...)
|
||||
}
|
||||
|
||||
// TransactionWithTimeout wraps a transaction function with timeout handling.
|
||||
// The transaction is automatically rolled back if the context times out.
|
||||
func (s *Store) TransactionWithTimeout(ctx context.Context, timeout time.Duration, fn func(*gorm.DB) error) error {
|
||||
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "transaction")
|
||||
defer cancel()
|
||||
|
||||
return s.DB.WithContext(timeoutCtx).Transaction(func(tx *gorm.DB) error {
|
||||
// Check context before proceeding
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return timeoutCtx.Err()
|
||||
default:
|
||||
}
|
||||
return fn(tx)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
// Package maintenance provides scheduled maintenance tasks for claude-mnemonic.
|
||||
package maintenance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Service handles scheduled maintenance tasks.
|
||||
type Service struct {
|
||||
store *gorm.Store
|
||||
observationStore *gorm.ObservationStore
|
||||
summaryStore *gorm.SummaryStore
|
||||
promptStore *gorm.PromptStore
|
||||
vectorCleanupFn func(ctx context.Context, deletedIDs []int64)
|
||||
config *config.Config
|
||||
log zerolog.Logger
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
|
||||
// Metrics
|
||||
lastRunTime time.Time
|
||||
lastRunDuration time.Duration
|
||||
totalCleanedObs int64
|
||||
totalOptimizeRun int64
|
||||
}
|
||||
|
||||
// NewService creates a new maintenance service.
|
||||
func NewService(
|
||||
store *gorm.Store,
|
||||
observationStore *gorm.ObservationStore,
|
||||
summaryStore *gorm.SummaryStore,
|
||||
promptStore *gorm.PromptStore,
|
||||
vectorCleanupFn func(ctx context.Context, deletedIDs []int64),
|
||||
cfg *config.Config,
|
||||
log zerolog.Logger,
|
||||
) *Service {
|
||||
return &Service{
|
||||
store: store,
|
||||
observationStore: observationStore,
|
||||
summaryStore: summaryStore,
|
||||
promptStore: promptStore,
|
||||
vectorCleanupFn: vectorCleanupFn,
|
||||
config: cfg,
|
||||
log: log.With().Str("component", "maintenance").Logger(),
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the maintenance loop.
|
||||
func (s *Service) Start(ctx context.Context) {
|
||||
s.mu.Lock()
|
||||
if s.running {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
s.running = true
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.running = false
|
||||
s.mu.Unlock()
|
||||
close(s.doneCh)
|
||||
}()
|
||||
|
||||
if !s.config.MaintenanceEnabled {
|
||||
s.log.Info().Msg("Maintenance disabled, not starting scheduler")
|
||||
return
|
||||
}
|
||||
|
||||
interval := max(time.Duration(s.config.MaintenanceIntervalHours)*time.Hour, time.Hour)
|
||||
|
||||
s.log.Info().
|
||||
Dur("interval", interval).
|
||||
Int("retention_days", s.config.ObservationRetentionDays).
|
||||
Bool("cleanup_stale", s.config.CleanupStaleObservations).
|
||||
Msg("Starting maintenance scheduler")
|
||||
|
||||
// Initial run after 5 minutes (allow system to stabilize)
|
||||
time.Sleep(5 * time.Minute)
|
||||
s.runMaintenance(ctx)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.log.Info().Msg("Maintenance shutting down due to context cancellation")
|
||||
return
|
||||
case <-s.stopCh:
|
||||
s.log.Info().Msg("Maintenance shutting down due to stop signal")
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.runMaintenance(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop signals the maintenance service to stop.
|
||||
func (s *Service) Stop() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.running {
|
||||
return
|
||||
}
|
||||
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// Wait waits for the maintenance service to finish.
// It blocks until doneCh is closed, which Start does when it returns after
// having marked the service as running.
// NOTE(review): nothing else in the visible code closes doneCh, so Wait
// appears to block forever if Start is never called - confirm with callers.
|
||||
func (s *Service) Wait() {
|
||||
<-s.doneCh
|
||||
}
|
||||
|
||||
// runMaintenance executes all maintenance tasks.
|
||||
func (s *Service) runMaintenance(ctx context.Context) {
|
||||
start := time.Now()
|
||||
s.log.Info().Msg("Starting maintenance run")
|
||||
|
||||
var totalCleaned int64
|
||||
|
||||
// Task 1: Clean up old observations by age
|
||||
if s.config.ObservationRetentionDays > 0 {
|
||||
cleaned, err := s.cleanupOldObservations(ctx)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to cleanup old observations")
|
||||
} else {
|
||||
totalCleaned += cleaned
|
||||
s.log.Info().Int64("cleaned", cleaned).Msg("Cleaned old observations by age")
|
||||
}
|
||||
}
|
||||
|
||||
// Task 2: Clean up stale observations
|
||||
if s.config.CleanupStaleObservations {
|
||||
cleaned, err := s.cleanupStaleObservations(ctx)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to cleanup stale observations")
|
||||
} else {
|
||||
totalCleaned += cleaned
|
||||
s.log.Info().Int64("cleaned", cleaned).Msg("Cleaned stale observations")
|
||||
}
|
||||
}
|
||||
|
||||
// Task 3: Optimize database
|
||||
if err := s.store.Optimize(ctx); err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to optimize database")
|
||||
} else {
|
||||
s.totalOptimizeRun++
|
||||
}
|
||||
|
||||
// Task 4: Clean up old prompts (keep last 1000 per session)
|
||||
cleanedPrompts, err := s.cleanupOldPrompts(ctx)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msg("Failed to cleanup old prompts")
|
||||
} else if cleanedPrompts > 0 {
|
||||
s.log.Info().Int64("cleaned", cleanedPrompts).Msg("Cleaned old prompts")
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
s.mu.Lock()
|
||||
s.lastRunTime = time.Now()
|
||||
s.lastRunDuration = time.Since(start)
|
||||
s.totalCleanedObs += totalCleaned
|
||||
s.mu.Unlock()
|
||||
|
||||
s.log.Info().
|
||||
Dur("duration", time.Since(start)).
|
||||
Int64("observations_cleaned", totalCleaned).
|
||||
Msg("Maintenance run completed")
|
||||
}
|
||||
|
||||
// cleanupOldObservations deletes observations older than the retention period.
|
||||
func (s *Service) cleanupOldObservations(ctx context.Context) (int64, error) {
|
||||
cutoffEpoch := time.Now().AddDate(0, 0, -s.config.ObservationRetentionDays).Unix()
|
||||
|
||||
// Get IDs of old observations
|
||||
var deletedIDs []int64
|
||||
err := s.store.GetDB().WithContext(ctx).
|
||||
Model(&gorm.Observation{}).
|
||||
Where("created_at_epoch < ?", cutoffEpoch).
|
||||
Pluck("id", &deletedIDs).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(deletedIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Delete in batches to avoid long transactions
|
||||
batchSize := 100
|
||||
for i := 0; i < len(deletedIDs); i += batchSize {
|
||||
end := min(i+batchSize, len(deletedIDs))
|
||||
batch := deletedIDs[i:end]
|
||||
|
||||
if err := s.store.GetDB().WithContext(ctx).
|
||||
Where("id IN ?", batch).
|
||||
Delete(&gorm.Observation{}).Error; err != nil {
|
||||
return int64(i), err
|
||||
}
|
||||
|
||||
// Sync vector DB deletions
|
||||
if s.vectorCleanupFn != nil {
|
||||
s.vectorCleanupFn(ctx, batch)
|
||||
}
|
||||
}
|
||||
|
||||
return int64(len(deletedIDs)), nil
|
||||
}
|
||||
|
||||
// cleanupStaleObservations deletes observations marked as stale.
|
||||
func (s *Service) cleanupStaleObservations(ctx context.Context) (int64, error) {
|
||||
// Get IDs of stale observations (is_superseded = true)
|
||||
var deletedIDs []int64
|
||||
err := s.store.GetDB().WithContext(ctx).
|
||||
Model(&gorm.Observation{}).
|
||||
Where("is_superseded = ?", true).
|
||||
Pluck("id", &deletedIDs).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(deletedIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Delete in batches
|
||||
batchSize := 100
|
||||
for i := 0; i < len(deletedIDs); i += batchSize {
|
||||
end := min(i+batchSize, len(deletedIDs))
|
||||
batch := deletedIDs[i:end]
|
||||
|
||||
if err := s.store.GetDB().WithContext(ctx).
|
||||
Where("id IN ?", batch).
|
||||
Delete(&gorm.Observation{}).Error; err != nil {
|
||||
return int64(i), err
|
||||
}
|
||||
|
||||
// Sync vector DB deletions
|
||||
if s.vectorCleanupFn != nil {
|
||||
s.vectorCleanupFn(ctx, batch)
|
||||
}
|
||||
}
|
||||
|
||||
return int64(len(deletedIDs)), nil
|
||||
}
|
||||
|
||||
// cleanupOldPrompts removes old prompts keeping only the most recent per session.
|
||||
func (s *Service) cleanupOldPrompts(ctx context.Context) (int64, error) {
|
||||
// Delete prompts older than 30 days that aren't the most recent in their session
|
||||
cutoffEpoch := time.Now().AddDate(0, 0, -30).Unix()
|
||||
|
||||
result := s.store.GetDB().WithContext(ctx).
|
||||
Where("created_at_epoch < ?", cutoffEpoch).
|
||||
Delete(&gorm.UserPrompt{})
|
||||
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// Stats returns maintenance statistics.
|
||||
func (s *Service) Stats() map[string]any {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return map[string]any{
|
||||
"enabled": s.config.MaintenanceEnabled,
|
||||
"interval_hours": s.config.MaintenanceIntervalHours,
|
||||
"retention_days": s.config.ObservationRetentionDays,
|
||||
"cleanup_stale": s.config.CleanupStaleObservations,
|
||||
"last_run": s.lastRunTime,
|
||||
"last_duration_ms": s.lastRunDuration.Milliseconds(),
|
||||
"total_cleaned_obs": s.totalCleanedObs,
|
||||
"total_optimizes": s.totalOptimizeRun,
|
||||
"running": s.running,
|
||||
}
|
||||
}
|
||||
|
||||
// RunNow triggers an immediate maintenance run.
// The run executes on a new goroutine, so RunNow returns without waiting for
// it to finish and reports no result.
// NOTE(review): nothing visible here prevents this run from overlapping a
// scheduled run started by Start, and the goroutine is not tracked by Wait.
|
||||
func (s *Service) RunNow(ctx context.Context) {
|
||||
go s.runMaintenance(ctx)
|
||||
}
|
||||
+2640
-45
File diff suppressed because it is too large
Load Diff
+20
-20
@@ -24,7 +24,7 @@ func TestServerSuite(t *testing.T) {
|
||||
|
||||
// TestNewServer tests server creation.
|
||||
func (s *ServerSuite) TestNewServer() {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
s.NotNil(server)
|
||||
s.Nil(server.searchMgr)
|
||||
s.Equal("1.0.0", server.version)
|
||||
@@ -293,7 +293,7 @@ func TestTimelineParams(t *testing.T) {
|
||||
|
||||
// TestHandleInitialize tests the initialize handler.
|
||||
func TestHandleInitialize(t *testing.T) {
|
||||
server := NewServer(nil, "1.2.3", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.2.3", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -320,7 +320,7 @@ func TestHandleInitialize(t *testing.T) {
|
||||
|
||||
// TestHandleToolsList tests the tools/list handler.
|
||||
func TestHandleToolsList(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -361,7 +361,7 @@ func TestHandleToolsList(t *testing.T) {
|
||||
|
||||
// TestHandleRequest tests request routing.
|
||||
func TestHandleRequest(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
@@ -423,7 +423,7 @@ func TestHandleRequest(t *testing.T) {
|
||||
|
||||
// TestHandleToolsCall_InvalidParams tests tools/call with invalid params.
|
||||
func TestHandleToolsCall_InvalidParams(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &Request{
|
||||
@@ -442,7 +442,7 @@ func TestHandleToolsCall_InvalidParams(t *testing.T) {
|
||||
|
||||
// TestCallTool_UnknownTool tests callTool with unknown tool name.
|
||||
func TestCallTool_UnknownTool(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.callTool(ctx, "nonexistent_tool", json.RawMessage(`{}`))
|
||||
@@ -452,7 +452,7 @@ func TestCallTool_UnknownTool(t *testing.T) {
|
||||
|
||||
// TestCallTool_InvalidArgs tests callTool with invalid arguments.
|
||||
func TestCallTool_InvalidArgs(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.callTool(ctx, "search", json.RawMessage(`invalid json`))
|
||||
@@ -574,7 +574,7 @@ func TestJSONRPCErrorCodes(t *testing.T) {
|
||||
|
||||
// TestToolListContainsExpectedSchemas tests that tool schemas are valid.
|
||||
func TestToolListContainsExpectedSchemas(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -600,7 +600,7 @@ func TestToolListContainsExpectedSchemas(t *testing.T) {
|
||||
|
||||
// TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name.
|
||||
func TestHandleToolsCall_UnknownTool(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &Request{
|
||||
@@ -620,7 +620,7 @@ func TestHandleToolsCall_UnknownTool(t *testing.T) {
|
||||
func TestCallTool_ToolNameRecognition(t *testing.T) {
|
||||
// Note: This test verifies tool routing logic, not execution (which requires searchMgr)
|
||||
// All valid tool names should be in the handleToolsList response
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -782,7 +782,7 @@ func TestResponseIDTypes(t *testing.T) {
|
||||
|
||||
// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query.
|
||||
func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// Empty query should error
|
||||
@@ -793,7 +793,7 @@ func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
|
||||
|
||||
// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON.
|
||||
func TestHandleTimeline_InvalidJSON(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`))
|
||||
@@ -803,7 +803,7 @@ func TestHandleTimeline_InvalidJSON(t *testing.T) {
|
||||
|
||||
// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON.
|
||||
func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`))
|
||||
@@ -813,7 +813,7 @@ func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
|
||||
|
||||
// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query.
|
||||
func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// No anchor_id and no query should return empty result
|
||||
@@ -825,7 +825,7 @@ func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
|
||||
|
||||
// TestHandleTimeline_WithDefaults tests timeline default values are applied.
|
||||
func TestHandleTimeline_WithDefaults(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
// With anchor_id but no before/after, defaults should be applied
|
||||
@@ -839,7 +839,7 @@ func TestHandleTimeline_WithDefaults(t *testing.T) {
|
||||
|
||||
// TestServerFields tests Server struct fields.
|
||||
func TestServerFields(t *testing.T) {
|
||||
server := NewServer(nil, "2.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "2.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
assert.Equal(t, "2.0.0", server.version)
|
||||
assert.Nil(t, server.searchMgr)
|
||||
@@ -891,7 +891,7 @@ func TestErrorWithNilData(t *testing.T) {
|
||||
|
||||
// TestToolInputSchema tests that tool input schemas have required fields.
|
||||
func TestToolInputSchema(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
req := &Request{
|
||||
JSONRPC: "2.0",
|
||||
@@ -960,7 +960,7 @@ func TestToolCallParamsWithComplexArgs(t *testing.T) {
|
||||
|
||||
// TestCallTool_UnknownToolName tests callTool with various unknown tool names.
|
||||
func TestCallTool_UnknownToolName(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
unknownTools := []string{
|
||||
@@ -1009,7 +1009,7 @@ func TestTimelineParams_Validation(t *testing.T) {
|
||||
|
||||
// TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error.
|
||||
func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &Request{
|
||||
@@ -1031,7 +1031,7 @@ func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
|
||||
|
||||
// TestHandleToolsCall_EmptyParams tests tools/call with empty params.
|
||||
func TestHandleToolsCall_EmptyParams(t *testing.T) {
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
|
||||
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
req := &Request{
|
||||
|
||||
@@ -3,6 +3,8 @@ package pattern
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +23,8 @@ type DetectorConfig struct {
|
||||
AnalysisInterval time.Duration
|
||||
// MaxPatternsToTrack is the maximum number of active patterns.
|
||||
MaxPatternsToTrack int
|
||||
// MaxCandidates is the maximum number of candidates to track (LRU eviction).
|
||||
MaxCandidates int
|
||||
}
|
||||
|
||||
// DefaultConfig returns the default detector configuration.
|
||||
@@ -30,6 +34,7 @@ func DefaultConfig() DetectorConfig {
|
||||
MinFrequencyForPattern: 2, // At least 2 occurrences to form a pattern
|
||||
AnalysisInterval: 5 * time.Minute,
|
||||
MaxPatternsToTrack: 1000,
|
||||
MaxCandidates: 500, // Prevent unbounded growth
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,7 +206,24 @@ func (d *Detector) AnalyzeObservation(ctx context.Context, obs *models.Observati
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Create new candidate
|
||||
// Create new candidate - with immediate size check to prevent unbounded growth
|
||||
// between periodic cleanups (which run every 5 minutes)
|
||||
if d.config.MaxCandidates > 0 && len(d.candidates) >= d.config.MaxCandidates {
|
||||
// Evict oldest candidate immediately rather than waiting for periodic cleanup
|
||||
var oldestKey string
|
||||
var oldestTime int64 = time.Now().UnixMilli()
|
||||
for k, c := range d.candidates {
|
||||
if c.lastSeenEpoch < oldestTime {
|
||||
oldestTime = c.lastSeenEpoch
|
||||
oldestKey = k
|
||||
}
|
||||
}
|
||||
if oldestKey != "" {
|
||||
delete(d.candidates, oldestKey)
|
||||
log.Debug().Str("evicted_key", oldestKey).Msg("Evicted oldest candidate to make room")
|
||||
}
|
||||
}
|
||||
|
||||
patternType := models.DetectPatternType(obs.Concepts, obs.Title.String, obs.Narrative.String)
|
||||
d.candidates[candidateKey] = &candidatePattern{
|
||||
signature: signature,
|
||||
@@ -305,17 +327,53 @@ func (d *Detector) AnalyzeRecentObservations(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOldCandidates removes candidates that haven't been seen recently.
|
||||
// cleanupOldCandidates removes candidates that haven't been seen recently
|
||||
// and enforces the max candidates limit using LRU eviction.
|
||||
func (d *Detector) cleanupOldCandidates() {
|
||||
d.candidatesMu.Lock()
|
||||
defer d.candidatesMu.Unlock()
|
||||
|
||||
threshold := time.Now().Add(-7 * 24 * time.Hour).UnixMilli()
|
||||
|
||||
// First pass: remove expired candidates
|
||||
for key, candidate := range d.candidates {
|
||||
if candidate.lastSeenEpoch < threshold {
|
||||
delete(d.candidates, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: enforce max candidates limit using LRU eviction
|
||||
if d.config.MaxCandidates > 0 && len(d.candidates) > d.config.MaxCandidates {
|
||||
// Find oldest candidates to evict using O(n log n) sort instead of O(n²) selection sort
|
||||
type keyAge struct {
|
||||
key string
|
||||
age int64
|
||||
}
|
||||
candidates := make([]keyAge, 0, len(d.candidates))
|
||||
for k, c := range d.candidates {
|
||||
candidates = append(candidates, keyAge{k, c.lastSeenEpoch})
|
||||
}
|
||||
|
||||
// Sort by age ascending (oldest first) - O(n log n)
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].age < candidates[j].age
|
||||
})
|
||||
|
||||
// Delete oldest entries
|
||||
toEvict := len(d.candidates) - d.config.MaxCandidates
|
||||
for i := 0; i < toEvict; i++ {
|
||||
delete(d.candidates, candidates[i].key)
|
||||
}
|
||||
|
||||
log.Debug().Int("evicted", toEvict).Int("remaining", len(d.candidates)).Msg("Evicted old pattern candidates (LRU)")
|
||||
}
|
||||
}
|
||||
|
||||
// CandidateCount returns the current number of pattern candidates.
// The map length is read while holding candidatesMu for reading.
|
||||
func (d *Detector) CandidateCount() int {
|
||||
d.candidatesMu.RLock()
|
||||
defer d.candidatesMu.RUnlock()
|
||||
return len(d.candidates)
|
||||
}
|
||||
|
||||
// GetPatternInsight returns a formatted insight string for a pattern.
|
||||
// generateCandidateKey joins the signature elements into a single map key.
// Every element is followed by a '|' separator (including the last one), so
// []string{"a", "b"} yields "a|b|". An empty or nil signature yields "".
func generateCandidateKey(signature []string) string {
	if len(signature) == 0 {
		return ""
	}

	// Exact output size: one separator byte per element plus the element bytes.
	size := len(signature)
	for _, s := range signature {
		size += len(s)
	}

	// strings.Builder avoids the quadratic copying of repeated `key += ...`.
	var b strings.Builder
	b.Grow(size)
	for _, s := range signature {
		b.WriteString(s)
		b.WriteByte('|')
	}
	return b.String()
}
|
||||
|
||||
// generatePatternName creates a human-readable name for a pattern.
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
// Package privacy provides utilities for protecting sensitive data.
|
||||
package privacy
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// secretPatterns contains compiled regular expressions for detecting secrets.
// These patterns are designed to catch common secret formats with minimal false positives.
// RedactSecrets applies them in slice order, each one to the output of the previous.
var secretPatterns = []*regexp.Regexp{
	// API keys with common prefixes
	regexp.MustCompile(`(?i)(api[_-]?key|apikey)\s*[:=]\s*['"]?[a-zA-Z0-9_-]{20,}['"]?`),

	// Passwords in configuration (quoted value of at least 8 characters)
	regexp.MustCompile(`(?i)(password|passwd|pwd)\s*[:=]\s*['"][^'"]{8,}['"]`),

	// Secret tokens
	regexp.MustCompile(`(?i)(secret[_-]?key|secret[_-]?token|auth[_-]?token)\s*[:=]\s*['"]?[a-zA-Z0-9_-]{20,}['"]?`),

	// OpenAI API keys
	regexp.MustCompile(`sk-[a-zA-Z0-9]{20,}`),

	// Anthropic API keys
	regexp.MustCompile(`sk-ant-[a-zA-Z0-9-]{20,}`),

	// GitHub tokens
	regexp.MustCompile(`gh[pous]_[a-zA-Z0-9]{36,}`),
	regexp.MustCompile(`github_pat_[a-zA-Z0-9_]{22,}`),

	// AWS keys
	regexp.MustCompile(`AKIA[0-9A-Z]{16}`),
	regexp.MustCompile(`(?i)aws[_-]?secret[_-]?access[_-]?key\s*[:=]\s*['"]?[a-zA-Z0-9/+=]{40}['"]?`),

	// Private keys (PEM format indicators). Only the BEGIN line is matched,
	// so this pattern detects a key but does not cover its base64 body.
	regexp.MustCompile(`-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----`),

	// JWT tokens (base64.base64.base64 format)
	regexp.MustCompile(`eyJ[a-zA-Z0-9_-]+\.eyJ[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+`),

	// Bearer tokens
	regexp.MustCompile(`(?i)bearer\s+[a-zA-Z0-9_-]{20,}`),
}

// ContainsSecrets checks if the given text contains any patterns that look like secrets.
// Returns true if at least one entry of secretPatterns matches; empty text is never flagged.
func ContainsSecrets(text string) bool {
	if text == "" {
		return false
	}
	return slices.ContainsFunc(secretPatterns, func(re *regexp.Regexp) bool {
		return re.MatchString(text)
	})
}

// RedactSecrets replaces detected secrets with a redaction marker.
// This allows the text to be stored while protecting sensitive data.
//
// For "name = value" / "name: value" matches the name and separator are kept
// and the value becomes "[REDACTED]". Standalone secrets longer than 8 bytes
// keep their first 4 bytes followed by "...[REDACTED]"; shorter ones become
// "[REDACTED]". Text without matches is returned unchanged.
func RedactSecrets(text string) string {
	if text == "" {
		return text
	}

	result := text
	for _, pattern := range secretPatterns {
		result = pattern.ReplaceAllStringFunc(result, redactMatch)
	}
	return result
}

// redactMatch produces the replacement for a single pattern match.
func redactMatch(match string) string {
	// Cut at the FIRST '=' or ':'. Key names contain neither character, so the
	// first occurrence is the assignment separator. Searching for '=' anywhere
	// before considering ':' would cut inside values that contain '='
	// (e.g. `password: "a=b..."` or a base64 AWS secret ending in '=') and
	// leave part or all of the secret in the output.
	if idx := strings.IndexAny(match, "=:"); idx != -1 {
		return match[:idx+1] + "[REDACTED]"
	}
	// For standalone secrets, show just the prefix.
	if len(match) > 8 {
		return match[:4] + "...[REDACTED]"
	}
	return "[REDACTED]"
}

// SanitizeObservation checks multiple fields of an observation for secrets.
// Returns true if the narrative or any fact contains a secret.
// This function is used as a validation gate before storing observations;
// it only detects and does not modify its arguments.
func SanitizeObservation(narrative string, facts []string) bool {
	return ContainsSecrets(narrative) || slices.ContainsFunc(facts, ContainsSecrets)
}
|
||||
@@ -0,0 +1,213 @@
|
||||
package privacy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestContainsSecrets(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "normal text",
|
||||
input: "This is just some regular text about a bug fix",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "API key pattern",
|
||||
input: "api_key=abc123def456ghi789jkl012mno345pqr678",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "api-key with dash",
|
||||
input: `api-key: "abc123def456ghi789jkl012mno"`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "password in config",
|
||||
input: `password="super_secret_password_123"`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "OpenAI key format",
|
||||
input: "sk-abc123def456ghi789jkl012mno345pqr678",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Anthropic key format",
|
||||
input: "sk-ant-api03-abc123def456ghi789jkl012mno345",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "GitHub PAT",
|
||||
input: "ghp_1234567890abcdefghijklmnopqrstuvwxyz",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "GitHub PAT new format",
|
||||
input: "github_pat_12ABCDEFGHIJ3456789abc_defghijklmno",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "AWS access key",
|
||||
input: "AKIAIOSFODNN7EXAMPLE",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Private key header",
|
||||
input: "-----BEGIN RSA PRIVATE KEY-----",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "JWT token",
|
||||
input: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "bearer token",
|
||||
input: "Bearer abc123def456ghi789jkl012mno345",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "secret_key in code",
|
||||
input: `secret_key = "my_super_secret_token_here"`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "short password is not detected",
|
||||
input: `password="short"`,
|
||||
expected: false, // Too short to trigger
|
||||
},
|
||||
{
|
||||
name: "word password in sentence",
|
||||
input: "The password field should be validated",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "word api in code",
|
||||
input: "The API returns JSON data",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ContainsSecrets(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ContainsSecrets(%q) = %v, want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactSecrets(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "no secrets",
|
||||
input: "This is safe text",
|
||||
expected: "This is safe text",
|
||||
},
|
||||
{
|
||||
name: "API key gets redacted",
|
||||
input: "api_key=abc123def456ghi789jkl012mno345pqr678",
|
||||
expected: "api_key=[REDACTED]",
|
||||
},
|
||||
{
|
||||
name: "OpenAI key gets redacted",
|
||||
input: "The key is sk-abc123def456ghi789jkl012mno345pqr678",
|
||||
expected: "The key is sk-a...[REDACTED]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := RedactSecrets(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("RedactSecrets(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeObservation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
narrative string
|
||||
facts []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "clean observation",
|
||||
narrative: "Fixed a bug in the login flow",
|
||||
facts: []string{"Users can now log in", "Session management improved"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "secret in narrative",
|
||||
narrative: "Set API key api_key=abc123def456ghi789jkl012mno345",
|
||||
facts: []string{"Configuration updated"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "secret in facts",
|
||||
narrative: "Updated configuration",
|
||||
facts: []string{"Added api_key=abc123def456ghi789jkl012mno345"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty facts",
|
||||
narrative: "Clean narrative",
|
||||
facts: []string{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nil facts",
|
||||
narrative: "Clean narrative",
|
||||
facts: nil,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeObservation(tt.narrative, tt.facts)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeObservation() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContainsSecrets(b *testing.B) {
|
||||
text := "This is a normal piece of text that does not contain any secrets or sensitive information"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ContainsSecrets(text)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContainsSecretsWithSecret(b *testing.B) {
|
||||
text := "api_key=abc123def456ghi789jkl012mno345pqr678"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ContainsSecrets(text)
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,13 @@ func NewCalculator(config *models.ScoringConfig) *Calculator {
|
||||
// - ConceptContrib = sum(concept_weights) × concept_weight_factor
|
||||
// - RetrievalContrib = log2(retrieval_count + 1) × 0.1 × retrieval_weight
|
||||
func (c *Calculator) Calculate(obs *models.Observation, now time.Time) float64 {
// Delegates to CalculateComponents and returns only the FinalScore field of its result.
|
||||
return c.CalculateComponents(obs, now).FinalScore
|
||||
}
|
||||
|
||||
// CalculateComponents returns the individual components of the importance score.
|
||||
// Useful for debugging and explaining scores to users.
|
||||
// This is the core calculation method - Calculate() delegates to this.
|
||||
func (c *Calculator) CalculateComponents(obs *models.Observation, now time.Time) ScoreComponents {
|
||||
// 1. Get base type weight
|
||||
typeWeight := models.TypeBaseScore(obs.Type)
|
||||
|
||||
@@ -75,42 +82,6 @@ func (c *Calculator) Calculate(obs *models.Observation, now time.Time) float64 {
|
||||
finalScore = c.config.MinScore
|
||||
}
|
||||
|
||||
return finalScore
|
||||
}
|
||||
|
||||
// CalculateComponents returns the individual components of the importance score.
|
||||
// Useful for debugging and explaining scores to users.
|
||||
func (c *Calculator) CalculateComponents(obs *models.Observation, now time.Time) ScoreComponents {
|
||||
typeWeight := models.TypeBaseScore(obs.Type)
|
||||
|
||||
ageDays := now.Sub(time.UnixMilli(obs.CreatedAtEpoch)).Hours() / 24.0
|
||||
if ageDays < 0 {
|
||||
ageDays = 0
|
||||
}
|
||||
recencyDecay := math.Pow(0.5, ageDays/c.config.RecencyHalfLifeDays)
|
||||
|
||||
coreScore := 1.0 * typeWeight * recencyDecay
|
||||
feedbackContrib := float64(obs.UserFeedback) * c.config.FeedbackWeight
|
||||
|
||||
conceptBoost := 0.0
|
||||
for _, concept := range obs.Concepts {
|
||||
if weight, ok := c.config.ConceptWeights[concept]; ok {
|
||||
conceptBoost += weight
|
||||
}
|
||||
}
|
||||
conceptContrib := conceptBoost * c.config.ConceptWeight
|
||||
|
||||
retrievalContrib := 0.0
|
||||
if obs.RetrievalCount > 0 {
|
||||
retrievalBoost := math.Log2(float64(obs.RetrievalCount)+1) * 0.1
|
||||
retrievalContrib = retrievalBoost * c.config.RetrievalWeight
|
||||
}
|
||||
|
||||
finalScore := coreScore + feedbackContrib + conceptContrib + retrievalContrib
|
||||
if finalScore < c.config.MinScore {
|
||||
finalScore = c.config.MinScore
|
||||
}
|
||||
|
||||
return ScoreComponents{
|
||||
TypeWeight: typeWeight,
|
||||
RecencyDecay: recencyDecay,
|
||||
|
||||
@@ -3,7 +3,9 @@ package expansion
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -281,14 +283,10 @@ func (e *Expander) expandByVocabulary(ctx context.Context, query string, minSimi
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort by score (descending) using bubble sort
|
||||
for i := 0; i < len(similar)-1; i++ {
|
||||
for j := i + 1; j < len(similar); j++ {
|
||||
if similar[j].score > similar[i].score {
|
||||
similar[i], similar[j] = similar[j], similar[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sort by score (descending) using Go's standard sort - O(n log n)
|
||||
sort.Slice(similar, func(i, j int) bool {
|
||||
return similar[i].score > similar[j].score
|
||||
})
|
||||
|
||||
// Create expansion by combining top similar terms with query
|
||||
var expansions []ExpandedQuery
|
||||
@@ -427,16 +425,13 @@ func cosineSimilarity(a, b []float32) float64 {
|
||||
return dot / (sqrt(normA) * sqrt(normB))
|
||||
}
|
||||
|
||||
// sqrt returns the square root of x using the standard math.Sqrt.
// Non-positive inputs return 0 instead of NaN (kept from the original
// hand-rolled implementation for compatibility with existing callers).
//
// The previous fixed 10-step Newton iteration is removed: starting from
// z = x it had not converged after 10 steps for large inputs.
func sqrt(x float64) float64 {
	if x <= 0 {
		return 0
	}
	return math.Sqrt(x)
}
|
||||
|
||||
// truncate truncates a string to maxLen characters.
|
||||
|
||||
+679
-14
@@ -3,19 +3,156 @@ package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hash/fnv"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// multiSpaceRegex matches multiple consecutive whitespace characters.
|
||||
// Pre-compiled for performance in normalizeQuery.
|
||||
var multiSpaceRegex = regexp.MustCompile(`\s+`)
|
||||
|
||||
// Search configuration constants.
|
||||
const (
|
||||
// Cache configuration
|
||||
defaultCacheTTL = 30 * time.Second // Short TTL for freshness
|
||||
defaultCacheMaxSize = 200 // Max cached results
|
||||
cacheEvictionPercent = 10 // Evict 10% when cache is full
|
||||
cacheEvictionThreshold = 80 // Start eviction scan at 80% capacity
|
||||
|
||||
// Latency tracking
|
||||
latencyHistogramCap = 1000 // Max latency samples for histogram
|
||||
slowQueryThresholdNs = 100 * 1e6 // 100ms threshold for slow query logging
|
||||
|
||||
// Query frequency tracking
|
||||
maxFrequencyEntries = 1000 // Max queries to track for warming
|
||||
frequencyEvictionBatch = 100 // Remove 10% when frequency map is full
|
||||
staleQueryThreshold = 24 * time.Hour // Remove queries older than 24 hours
|
||||
recentQueryWindow = time.Hour // Only consider queries from last hour for warming
|
||||
|
||||
// Cache warming configuration
|
||||
cacheWarmingInitDelay = 30 * time.Second // Delay before starting warming
|
||||
cacheWarmingInterval = 20 * time.Second // Run warming cycle every 20 seconds
|
||||
frequencyCleanupInterval = 5 * time.Minute // Cleanup stale entries every 5 minutes
|
||||
cacheCleanupInterval = time.Minute // Cleanup expired cache every minute
|
||||
warmingQueryTimeout = 5 * time.Second // Timeout for warming queries
|
||||
warmingBatchSize = 5 // Warm top 5 queries per cycle
|
||||
minRecencyFactor = 0.1 // Minimum recency factor for scoring
|
||||
|
||||
// Default query limits
|
||||
defaultQueryLimit = 20
|
||||
maxQueryLimit = 100
|
||||
defaultOrderBy = "date_desc"
|
||||
|
||||
// Truncation lengths
|
||||
queryLogTruncateLen = 50 // Truncate query in logs
|
||||
titleTruncateLen = 100 // Truncate titles in results
|
||||
warmingLogTruncateLen = 30 // Truncate query in warming logs
|
||||
)
|
||||
|
||||
// SearchMetrics tracks search performance statistics.
// GetStats reads the exported counters with sync/atomic loads.
type SearchMetrics struct {
	TotalSearches     int64 // Total number of searches performed
	VectorSearches    int64 // Searches using vector search
	FilterSearches    int64 // Searches using filter/FTS search
	TotalLatencyNs    int64 // Cumulative latency in nanoseconds
	VectorLatencyNs   int64 // Cumulative vector search latency
	FilterLatencyNs   int64 // Cumulative filter search latency
	CacheHits         int64 // Number of result cache hits
	CoalescedRequests int64 // Number of requests served via singleflight coalescing
	SearchErrors      int64 // Number of search errors

	// Percentile tracking: recent latency samples.
	// NOTE(review): histogramMu presumably guards latencyHistogram - confirm at the use sites.
	latencyHistogram []int64
	histogramMu      sync.Mutex
}

// GetStats returns the current search statistics.
// The map holds the six raw counters plus three average latencies in
// milliseconds (cumulative nanoseconds / count / 1e6); an average is 0 when
// its count is 0.
func (m *SearchMetrics) GetStats() map[string]any {
	// meanMs converts a cumulative nanosecond total into a per-call average in ms.
	meanMs := func(totalNs, calls int64) float64 {
		if calls <= 0 {
			return 0
		}
		return float64(totalNs) / float64(calls) / 1e6
	}

	total := atomic.LoadInt64(&m.TotalSearches)
	totalNs := atomic.LoadInt64(&m.TotalLatencyNs)
	vector := atomic.LoadInt64(&m.VectorSearches)
	vectorNs := atomic.LoadInt64(&m.VectorLatencyNs)
	filter := atomic.LoadInt64(&m.FilterSearches)
	filterNs := atomic.LoadInt64(&m.FilterLatencyNs)

	stats := make(map[string]any, 9)
	stats["total_searches"] = total
	stats["vector_searches"] = vector
	stats["filter_searches"] = filter
	stats["cache_hits"] = atomic.LoadInt64(&m.CacheHits)
	stats["coalesced_requests"] = atomic.LoadInt64(&m.CoalescedRequests)
	stats["search_errors"] = atomic.LoadInt64(&m.SearchErrors)
	stats["avg_latency_ms"] = meanMs(totalNs, total)
	stats["avg_vector_latency_ms"] = meanMs(vectorNs, vector)
	stats["avg_filter_latency_ms"] = meanMs(filterNs, filter)
	return stats
}
|
||||
|
||||
// Manager provides unified search across SQLite and sqlite-vec.
|
||||
type Manager struct {
|
||||
observationStore *gorm.ObservationStore
|
||||
summaryStore *gorm.SummaryStore
|
||||
promptStore *gorm.PromptStore
|
||||
vectorClient *sqlitevec.Client
|
||||
metrics *SearchMetrics
|
||||
|
||||
// Context for graceful shutdown of background goroutines
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Request coalescing for concurrent identical queries
|
||||
searchGroup singleflight.Group
|
||||
|
||||
// Result cache for repeated queries (short TTL)
|
||||
resultCache map[string]*cachedResult
|
||||
resultCacheMu sync.RWMutex
|
||||
cacheTTL time.Duration
|
||||
cacheMaxSize int
|
||||
|
||||
// Query frequency tracking for cache warming
|
||||
queryFrequency map[string]*queryFrequencyInfo
|
||||
queryFrequencyMu sync.RWMutex
|
||||
}
|
||||
|
||||
// queryFrequencyInfo tracks how often a query is used.
|
||||
type queryFrequencyInfo struct {
|
||||
params SearchParams
|
||||
count int64
|
||||
lastUsed time.Time
|
||||
lastCached time.Time
|
||||
}
|
||||
|
||||
// cachedResult stores a cached search result with expiry.
|
||||
type cachedResult struct {
|
||||
result *UnifiedSearchResult
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// NewManager creates a new search manager.
|
||||
@@ -25,12 +162,463 @@ func NewManager(
|
||||
promptStore *gorm.PromptStore,
|
||||
vectorClient *sqlitevec.Client,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m := &Manager{
|
||||
observationStore: observationStore,
|
||||
summaryStore: summaryStore,
|
||||
promptStore: promptStore,
|
||||
vectorClient: vectorClient,
|
||||
metrics: &SearchMetrics{latencyHistogram: make([]int64, 0, latencyHistogramCap)},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
resultCache: make(map[string]*cachedResult),
|
||||
cacheTTL: defaultCacheTTL,
|
||||
cacheMaxSize: defaultCacheMaxSize,
|
||||
queryFrequency: make(map[string]*queryFrequencyInfo),
|
||||
}
|
||||
// Start cache cleanup goroutine
|
||||
go m.cleanupCacheLoop()
|
||||
// Start cache warming goroutine
|
||||
go m.cacheWarmingLoop()
|
||||
return m
|
||||
}
|
||||
|
||||
// Close stops background goroutines and cleans up resources.
|
||||
func (m *Manager) Close() {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupCacheLoop periodically removes expired cache entries.
|
||||
func (m *Manager) cleanupCacheLoop() {
|
||||
ticker := time.NewTicker(cacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanupExpiredCache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredCache removes expired entries from the cache.
|
||||
func (m *Manager) cleanupExpiredCache() {
|
||||
m.resultCacheMu.Lock()
|
||||
defer m.resultCacheMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, cached := range m.resultCache {
|
||||
if now.After(cached.expiresAt) {
|
||||
delete(m.resultCache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cacheWarmingLoop periodically warms the cache for frequently used queries.
|
||||
func (m *Manager) cacheWarmingLoop() {
|
||||
// Wait a bit before starting to allow system to stabilize
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-time.After(cacheWarmingInitDelay):
|
||||
}
|
||||
|
||||
warmingTicker := time.NewTicker(cacheWarmingInterval)
|
||||
cleanupTicker := time.NewTicker(frequencyCleanupInterval)
|
||||
defer warmingTicker.Stop()
|
||||
defer cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-warmingTicker.C:
|
||||
m.warmFrequentQueries()
|
||||
case <-cleanupTicker.C:
|
||||
m.cleanupStaleFrequencyEntries()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleFrequencyEntries removes query frequency entries older than staleQueryThreshold.
|
||||
// This prevents memory bloat from queries that haven't been used in a long time.
|
||||
func (m *Manager) cleanupStaleFrequencyEntries() {
|
||||
m.queryFrequencyMu.Lock()
|
||||
now := time.Now()
|
||||
var keysToDelete []string
|
||||
for k, v := range m.queryFrequency {
|
||||
if now.Sub(v.lastUsed) > staleQueryThreshold {
|
||||
keysToDelete = append(keysToDelete, k)
|
||||
}
|
||||
}
|
||||
for _, k := range keysToDelete {
|
||||
delete(m.queryFrequency, k)
|
||||
}
|
||||
m.queryFrequencyMu.Unlock()
|
||||
|
||||
if len(keysToDelete) > 0 {
|
||||
log.Debug().Int("removed", len(keysToDelete)).Msg("Cleaned up stale query frequency entries")
|
||||
}
|
||||
}
|
||||
|
||||
// warmFrequentQueries pre-executes frequently used queries to warm the cache.
|
||||
func (m *Manager) warmFrequentQueries() {
|
||||
m.queryFrequencyMu.RLock()
|
||||
// Find top N most frequent queries that aren't recently cached
|
||||
type queryScore struct {
|
||||
key string
|
||||
info *queryFrequencyInfo
|
||||
score float64
|
||||
}
|
||||
candidates := make([]queryScore, 0, len(m.queryFrequency))
|
||||
now := time.Now()
|
||||
|
||||
for key, info := range m.queryFrequency {
|
||||
// Only consider queries used recently
|
||||
if now.Sub(info.lastUsed) > recentQueryWindow {
|
||||
continue
|
||||
}
|
||||
// Only warm if not recently cached (cache about to expire or already expired)
|
||||
timeSinceLastCache := now.Sub(info.lastCached)
|
||||
if timeSinceLastCache < m.cacheTTL/2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Score: frequency * recency factor
|
||||
recencyFactor := 1.0 - (now.Sub(info.lastUsed).Seconds() / recentQueryWindow.Seconds())
|
||||
if recencyFactor < minRecencyFactor {
|
||||
recencyFactor = minRecencyFactor
|
||||
}
|
||||
score := float64(info.count) * recencyFactor
|
||||
|
||||
candidates = append(candidates, queryScore{key: key, info: info, score: score})
|
||||
}
|
||||
m.queryFrequencyMu.RUnlock()
|
||||
|
||||
// Sort by score descending using O(n log n) algorithm
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].score > candidates[j].score
|
||||
})
|
||||
|
||||
// Warm top queries
|
||||
warmCount := min(warmingBatchSize, len(candidates))
|
||||
for i := range warmCount {
|
||||
candidate := candidates[i]
|
||||
ctx, cancel := context.WithTimeout(context.Background(), warmingQueryTimeout)
|
||||
result, err := m.executeSearch(ctx, candidate.info.params)
|
||||
cancel()
|
||||
|
||||
if err == nil && result != nil {
|
||||
cacheKey := m.getCacheKey(candidate.info.params)
|
||||
m.putInCache(cacheKey, result)
|
||||
|
||||
// Update last cached time
|
||||
m.queryFrequencyMu.Lock()
|
||||
if info, ok := m.queryFrequency[candidate.key]; ok {
|
||||
info.lastCached = time.Now()
|
||||
}
|
||||
m.queryFrequencyMu.Unlock()
|
||||
|
||||
log.Debug().
|
||||
Str("query", truncate(candidate.info.params.Query, warmingLogTruncateLen)).
|
||||
Float64("score", candidate.score).
|
||||
Msg("Cache warmed for frequent query")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// trackQueryFrequency records query usage for cache warming decisions.
|
||||
func (m *Manager) trackQueryFrequency(params SearchParams) {
|
||||
key := m.getCacheKey(params)
|
||||
|
||||
m.queryFrequencyMu.Lock()
|
||||
|
||||
if info, ok := m.queryFrequency[key]; ok {
|
||||
info.count++
|
||||
info.lastUsed = time.Now()
|
||||
m.queryFrequencyMu.Unlock()
|
||||
return // Fast path: no eviction needed
|
||||
}
|
||||
|
||||
m.queryFrequency[key] = &queryFrequencyInfo{
|
||||
params: params,
|
||||
count: 1,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
|
||||
// Limit frequency map size to prevent memory bloat
|
||||
mapLen := len(m.queryFrequency)
|
||||
if mapLen <= maxFrequencyEntries {
|
||||
m.queryFrequencyMu.Unlock()
|
||||
return // Fast path: no eviction needed
|
||||
}
|
||||
|
||||
// Collect keys and times for eviction (still under lock, but fast)
|
||||
type entry struct {
|
||||
key string
|
||||
lastUsed time.Time
|
||||
}
|
||||
entries := make([]entry, 0, mapLen)
|
||||
for k, v := range m.queryFrequency {
|
||||
entries = append(entries, entry{key: k, lastUsed: v.lastUsed})
|
||||
}
|
||||
m.queryFrequencyMu.Unlock()
|
||||
|
||||
// Sort outside lock to reduce contention (O(n log n))
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].lastUsed.Before(entries[j].lastUsed)
|
||||
})
|
||||
|
||||
// Collect keys to delete
|
||||
deleteCount := min(frequencyEvictionBatch, len(entries))
|
||||
keysToDelete := make([]string, deleteCount)
|
||||
for i := range deleteCount {
|
||||
keysToDelete[i] = entries[i].key
|
||||
}
|
||||
|
||||
// Re-acquire lock only for deletion (brief critical section)
|
||||
m.queryFrequencyMu.Lock()
|
||||
for _, k := range keysToDelete {
|
||||
delete(m.queryFrequency, k)
|
||||
}
|
||||
m.queryFrequencyMu.Unlock()
|
||||
}
|
||||
|
||||
// RecentQuery represents a recently executed search query.
|
||||
type RecentQuery struct {
|
||||
Query string `json:"query"`
|
||||
Project string `json:"project,omitempty"`
|
||||
Type string `json:"type,omitempty"` // observations, sessions, prompts
|
||||
Count int64 `json:"count"` // Number of times executed
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
}
|
||||
|
||||
// GetRecentQueries returns the most recent search queries, sorted by last used time.
|
||||
func (m *Manager) GetRecentQueries(limit int) []RecentQuery {
|
||||
if limit <= 0 {
|
||||
limit = defaultQueryLimit
|
||||
}
|
||||
if limit > maxQueryLimit {
|
||||
limit = maxQueryLimit
|
||||
}
|
||||
|
||||
m.queryFrequencyMu.RLock()
|
||||
defer m.queryFrequencyMu.RUnlock()
|
||||
|
||||
// Collect all queries
|
||||
queries := make([]RecentQuery, 0, len(m.queryFrequency))
|
||||
for _, info := range m.queryFrequency {
|
||||
queries = append(queries, RecentQuery{
|
||||
Query: info.params.Query,
|
||||
Project: info.params.Project,
|
||||
Type: info.params.Type,
|
||||
Count: info.count,
|
||||
LastUsed: info.lastUsed,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by last used (most recent first)
|
||||
sort.Slice(queries, func(i, j int) bool {
|
||||
return queries[i].LastUsed.After(queries[j].LastUsed)
|
||||
})
|
||||
|
||||
// Limit results
|
||||
if len(queries) > limit {
|
||||
queries = queries[:limit]
|
||||
}
|
||||
|
||||
return queries
|
||||
}
|
||||
|
||||
// GetFrequentQueries returns the most frequently used search queries.
|
||||
func (m *Manager) GetFrequentQueries(limit int) []RecentQuery {
|
||||
if limit <= 0 {
|
||||
limit = defaultQueryLimit
|
||||
}
|
||||
if limit > maxQueryLimit {
|
||||
limit = maxQueryLimit
|
||||
}
|
||||
|
||||
m.queryFrequencyMu.RLock()
|
||||
defer m.queryFrequencyMu.RUnlock()
|
||||
|
||||
// Only include queries used recently
|
||||
now := time.Now()
|
||||
queries := make([]RecentQuery, 0, len(m.queryFrequency))
|
||||
for _, info := range m.queryFrequency {
|
||||
if now.Sub(info.lastUsed) > recentQueryWindow {
|
||||
continue
|
||||
}
|
||||
queries = append(queries, RecentQuery{
|
||||
Query: info.params.Query,
|
||||
Project: info.params.Project,
|
||||
Type: info.params.Type,
|
||||
Count: info.count,
|
||||
LastUsed: info.lastUsed,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by count (highest first)
|
||||
sort.Slice(queries, func(i, j int) bool {
|
||||
return queries[i].Count > queries[j].Count
|
||||
})
|
||||
|
||||
// Limit results
|
||||
if len(queries) > limit {
|
||||
queries = queries[:limit]
|
||||
}
|
||||
|
||||
return queries
|
||||
}
|
||||
|
||||
// normalizeQuery normalizes a search query for consistent cache keys.
|
||||
// Converts to lowercase, trims whitespace, and collapses multiple spaces.
|
||||
// Uses pre-compiled regex for O(n) performance instead of O(n*m) loop.
|
||||
func normalizeQuery(query string) string {
|
||||
// Lowercase for case-insensitive matching
|
||||
query = strings.ToLower(query)
|
||||
// Collapse multiple whitespace into single space using pre-compiled regex
|
||||
query = multiSpaceRegex.ReplaceAllString(query, " ")
|
||||
// Trim leading/trailing whitespace (after collapsing)
|
||||
return strings.TrimSpace(query)
|
||||
}
|
||||
|
||||
// getCacheKey generates a cache key from search params.
|
||||
// Uses direct string concatenation instead of JSON marshal for better performance.
|
||||
// Queries are normalized for consistent cache hits across whitespace variations.
|
||||
func (m *Manager) getCacheKey(params SearchParams) string {
|
||||
// Normalize query for consistent cache keys
|
||||
normalizedQuery := normalizeQuery(params.Query)
|
||||
|
||||
// Hash directly without intermediate string allocation.
|
||||
// FNV-64a is fast and collision-safe for cache keys.
|
||||
h := fnv.New64a()
|
||||
|
||||
// Write each field directly to the hasher with separators
|
||||
h.Write([]byte(normalizedQuery))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(params.Type))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(params.Project))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(params.ObsType))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(params.Concepts))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(params.Files))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(strconv.FormatInt(params.DateStart, 10)))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(strconv.FormatInt(params.DateEnd, 10)))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(params.OrderBy))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(strconv.Itoa(params.Limit)))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(strconv.Itoa(params.Offset)))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(params.Format))
|
||||
h.Write([]byte{'|'})
|
||||
h.Write([]byte(params.Scope))
|
||||
h.Write([]byte{'|'})
|
||||
if params.IncludeGlobal {
|
||||
h.Write([]byte{'1'})
|
||||
} else {
|
||||
h.Write([]byte{'0'})
|
||||
}
|
||||
h.Write([]byte{'|'})
|
||||
if params.ExcludeSuperseded {
|
||||
h.Write([]byte{'1'})
|
||||
} else {
|
||||
h.Write([]byte{'0'})
|
||||
}
|
||||
|
||||
return strconv.FormatUint(h.Sum64(), 36) // Base36 for compact representation
|
||||
}
|
||||
|
||||
// getFromCache retrieves a result from cache if valid.
|
||||
func (m *Manager) getFromCache(key string) (*UnifiedSearchResult, bool) {
|
||||
m.resultCacheMu.RLock()
|
||||
defer m.resultCacheMu.RUnlock()
|
||||
|
||||
if cached, ok := m.resultCache[key]; ok {
|
||||
if time.Now().Before(cached.expiresAt) {
|
||||
atomic.AddInt64(&m.metrics.CacheHits, 1)
|
||||
return cached.result, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// putInCache stores a result in the cache.
|
||||
// Optimized to skip expensive scans when cache is below capacity threshold.
|
||||
func (m *Manager) putInCache(key string, result *UnifiedSearchResult) {
|
||||
m.resultCacheMu.Lock()
|
||||
defer m.resultCacheMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cacheLen := len(m.resultCache)
|
||||
|
||||
// Only scan for expired entries when at threshold+ capacity (amortized cleanup)
|
||||
evictionThreshold := (m.cacheMaxSize * cacheEvictionThreshold) / 100
|
||||
if cacheLen >= evictionThreshold {
|
||||
// Evict expired entries
|
||||
for k, v := range m.resultCache {
|
||||
if now.After(v.expiresAt) {
|
||||
delete(m.resultCache, k)
|
||||
}
|
||||
}
|
||||
cacheLen = len(m.resultCache) // Update after eviction
|
||||
}
|
||||
|
||||
// If still at capacity after removing expired, use simple FIFO-style eviction
|
||||
// Go map iteration order is random, which provides good cache behavior
|
||||
if cacheLen >= m.cacheMaxSize {
|
||||
// Evict percentage using random-order iteration (O(n) single pass)
|
||||
evictCount := max(m.cacheMaxSize*cacheEvictionPercent/100, 1)
|
||||
evicted := 0
|
||||
for k := range m.resultCache {
|
||||
delete(m.resultCache, k)
|
||||
evicted++
|
||||
if evicted >= evictCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.resultCache[key] = &cachedResult{
|
||||
result: result,
|
||||
expiresAt: now.Add(m.cacheTTL),
|
||||
}
|
||||
}
|
||||
|
||||
// Metrics returns the search metrics for monitoring.
|
||||
func (m *Manager) Metrics() *SearchMetrics {
|
||||
return m.metrics
|
||||
}
|
||||
|
||||
// CacheStats returns current cache statistics.
|
||||
func (m *Manager) CacheStats() map[string]any {
|
||||
m.resultCacheMu.RLock()
|
||||
cacheSize := len(m.resultCache)
|
||||
m.resultCacheMu.RUnlock()
|
||||
|
||||
return map[string]any{
|
||||
"size": cacheSize,
|
||||
"max_size": m.cacheMaxSize,
|
||||
"ttl_sec": m.cacheTTL.Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// ClearCache clears the result cache. Useful for testing or after data changes.
|
||||
func (m *Manager) ClearCache() {
|
||||
m.resultCacheMu.Lock()
|
||||
m.resultCache = make(map[string]*cachedResult)
|
||||
m.resultCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// SearchParams contains parameters for unified search.
|
||||
@@ -62,7 +650,7 @@ type SearchResult struct {
|
||||
Scope string `json:"scope,omitempty"` // "project" or "global"
|
||||
CreatedAt int64 `json:"created_at_epoch"`
|
||||
Score float64 `json:"score,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// UnifiedSearchResult contains the combined search results.
|
||||
@@ -73,17 +661,69 @@ type UnifiedSearchResult struct {
|
||||
}
|
||||
|
||||
// UnifiedSearch performs a unified search across all document types.
|
||||
// Uses caching and request coalescing for optimal performance.
|
||||
func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
latency := time.Since(start).Nanoseconds()
|
||||
atomic.AddInt64(&m.metrics.TotalSearches, 1)
|
||||
atomic.AddInt64(&m.metrics.TotalLatencyNs, latency)
|
||||
|
||||
// Sample latency for histogram (reservoir sampling)
|
||||
m.metrics.histogramMu.Lock()
|
||||
if len(m.metrics.latencyHistogram) < latencyHistogramCap {
|
||||
m.metrics.latencyHistogram = append(m.metrics.latencyHistogram, latency)
|
||||
}
|
||||
m.metrics.histogramMu.Unlock()
|
||||
|
||||
// Log slow queries
|
||||
if latency > slowQueryThresholdNs {
|
||||
log.Warn().
|
||||
Str("query", truncate(params.Query, queryLogTruncateLen)).
|
||||
Dur("latency", time.Duration(latency)).
|
||||
Str("type", params.Type).
|
||||
Msg("Slow search query")
|
||||
}
|
||||
}()
|
||||
|
||||
if params.Limit <= 0 {
|
||||
params.Limit = 20
|
||||
params.Limit = defaultQueryLimit
|
||||
}
|
||||
if params.Limit > 100 {
|
||||
params.Limit = 100
|
||||
if params.Limit > maxQueryLimit {
|
||||
params.Limit = maxQueryLimit
|
||||
}
|
||||
if params.OrderBy == "" {
|
||||
params.OrderBy = "date_desc"
|
||||
params.OrderBy = defaultOrderBy
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
cacheKey := m.getCacheKey(params)
|
||||
if cached, ok := m.getFromCache(cacheKey); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// Use singleflight to coalesce concurrent identical requests
|
||||
result, err, _ := m.searchGroup.Do(cacheKey, func() (any, error) {
|
||||
return m.executeSearch(ctx, params)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
searchResult := result.(*UnifiedSearchResult)
|
||||
|
||||
// Cache the result
|
||||
m.putInCache(cacheKey, searchResult)
|
||||
|
||||
// Track query frequency for cache warming
|
||||
m.trackQueryFrequency(params)
|
||||
|
||||
return searchResult, nil
|
||||
}
|
||||
|
||||
// executeSearch performs the actual search without caching/coalescing.
|
||||
func (m *Manager) executeSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
|
||||
// If query is provided and vector client is available, use vector search
|
||||
if params.Query != "" && m.vectorClient != nil && m.vectorClient.IsConnected() {
|
||||
return m.vectorSearch(ctx, params)
|
||||
@@ -95,6 +735,13 @@ func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*Unif
|
||||
|
||||
// vectorSearch performs semantic search via sqlite-vec.
|
||||
func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
latency := time.Since(start).Nanoseconds()
|
||||
atomic.AddInt64(&m.metrics.VectorSearches, 1)
|
||||
atomic.AddInt64(&m.metrics.VectorLatencyNs, latency)
|
||||
}()
|
||||
|
||||
// Build where filter based on search type
|
||||
var docType sqlitevec.DocType
|
||||
switch params.Type {
|
||||
@@ -110,6 +757,7 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
|
||||
// Query sqlite-vec
|
||||
vectorResults, err := m.vectorClient.Query(ctx, params.Query, params.Limit*2, where)
|
||||
if err != nil {
|
||||
atomic.AddInt64(&m.metrics.SearchErrors, 1)
|
||||
// Fall back to filter search on error
|
||||
return m.filterSearch(ctx, params)
|
||||
}
|
||||
@@ -125,7 +773,9 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
|
||||
|
||||
if len(obsIDs) > 0 && (params.Type == "" || params.Type == "observations") {
|
||||
obs, err := m.observationStore.GetObservationsByIDs(ctx, obsIDs, params.OrderBy, 0)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Int("count", len(obsIDs)).Msg("Failed to fetch observations by IDs in vector search")
|
||||
} else {
|
||||
for _, o := range obs {
|
||||
// Skip superseded observations when requested
|
||||
if params.ExcludeSuperseded && o.IsSuperseded {
|
||||
@@ -138,7 +788,9 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
|
||||
|
||||
if len(summaryIDs) > 0 && (params.Type == "" || params.Type == "sessions") {
|
||||
summaries, err := m.summaryStore.GetSummariesByIDs(ctx, summaryIDs, params.OrderBy, 0)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Int("count", len(summaryIDs)).Msg("Failed to fetch summaries by IDs in vector search")
|
||||
} else {
|
||||
for _, s := range summaries {
|
||||
results = append(results, m.summaryToResult(s, params.Format))
|
||||
}
|
||||
@@ -147,7 +799,9 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
|
||||
|
||||
if len(promptIDs) > 0 && (params.Type == "" || params.Type == "prompts") {
|
||||
prompts, err := m.promptStore.GetPromptsByIDs(ctx, promptIDs, params.OrderBy, 0)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Int("count", len(promptIDs)).Msg("Failed to fetch prompts by IDs in vector search")
|
||||
} else {
|
||||
for _, p := range prompts {
|
||||
results = append(results, m.promptToResult(p, params.Format))
|
||||
}
|
||||
@@ -168,6 +822,13 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
|
||||
|
||||
// filterSearch performs structured filter search via SQLite.
|
||||
func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
latency := time.Since(start).Nanoseconds()
|
||||
atomic.AddInt64(&m.metrics.FilterSearches, 1)
|
||||
atomic.AddInt64(&m.metrics.FilterLatencyNs, latency)
|
||||
}()
|
||||
|
||||
var results []SearchResult
|
||||
|
||||
// Search observations
|
||||
@@ -182,7 +843,9 @@ func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*Unifi
|
||||
obs, err = m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("project", params.Project).Msg("Failed to fetch observations in filter search")
|
||||
} else {
|
||||
for _, o := range obs {
|
||||
results = append(results, m.observationToResult(o, params.Format))
|
||||
}
|
||||
@@ -192,7 +855,9 @@ func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*Unifi
|
||||
// Search summaries
|
||||
if params.Type == "" || params.Type == "sessions" {
|
||||
summaries, err := m.summaryStore.GetRecentSummaries(ctx, params.Project, params.Limit)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("project", params.Project).Msg("Failed to fetch summaries in filter search")
|
||||
} else {
|
||||
for _, s := range summaries {
|
||||
results = append(results, m.summaryToResult(s, params.Format))
|
||||
}
|
||||
@@ -249,7 +914,7 @@ func (m *Manager) observationToResult(obs *models.Observation, format string) Se
|
||||
Project: obs.Project,
|
||||
Scope: string(obs.Scope),
|
||||
CreatedAt: obs.CreatedAtEpoch,
|
||||
Metadata: map[string]interface{}{
|
||||
Metadata: map[string]any{
|
||||
"obs_type": string(obs.Type),
|
||||
"scope": string(obs.Scope),
|
||||
},
|
||||
@@ -275,7 +940,7 @@ func (m *Manager) summaryToResult(summary *models.SessionSummary, format string)
|
||||
}
|
||||
|
||||
if summary.Request.Valid {
|
||||
result.Title = truncate(summary.Request.String, 100)
|
||||
result.Title = truncate(summary.Request.String, titleTruncateLen)
|
||||
}
|
||||
|
||||
if format == "full" && summary.Learned.Valid {
|
||||
@@ -293,7 +958,7 @@ func (m *Manager) promptToResult(prompt *models.UserPromptWithSession, format st
|
||||
CreatedAt: prompt.CreatedAtEpoch,
|
||||
}
|
||||
|
||||
result.Title = truncate(prompt.PromptText, 100)
|
||||
result.Title = truncate(prompt.PromptText, titleTruncateLen)
|
||||
|
||||
if format == "full" {
|
||||
result.Content = prompt.PromptText
|
||||
|
||||
@@ -5,19 +5,117 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// embeddingCacheEntry is one cached query embedding together with the time it
// was stored.
type embeddingCacheEntry struct {
	embedding []float32 // the cached embedding vector
	timestamp int64     // time the entry was stored, in Unix nanoseconds
}
|
||||
|
||||
// resultCacheEntry stores cached query results.
|
||||
type resultCacheEntry struct {
|
||||
results []QueryResult
|
||||
timestamp int64 // Unix nano
|
||||
queryHash string // Hash of query + filters for validation
|
||||
}
|
||||
|
||||
// Client provides vector operations via sqlite-vec.
|
||||
type Client struct {
|
||||
db *sql.DB
|
||||
embedSvc *embedding.Service
|
||||
mu sync.Mutex
|
||||
|
||||
// Separate mutexes for read and write operations to reduce contention
|
||||
writeMu sync.Mutex // Protects write operations (AddDocuments, DeleteDocuments)
|
||||
readMu sync.RWMutex // Protects read operations (Query, Count)
|
||||
|
||||
// Embedding cache for query deduplication
|
||||
queryCache map[string]embeddingCacheEntry
|
||||
queryCacheMu sync.RWMutex
|
||||
cacheMaxSize int
|
||||
cacheTTLNano int64 // Cache TTL in nanoseconds
|
||||
|
||||
// Result cache for repeated searches
|
||||
resultCache map[string]resultCacheEntry
|
||||
resultCacheMu sync.RWMutex
|
||||
resultCacheMaxSize int
|
||||
resultCacheTTLNano int64 // Shorter TTL for results (data changes more often)
|
||||
|
||||
// Cache statistics
|
||||
stats CacheStats
|
||||
|
||||
// Background cleanup control
|
||||
stopCleanup chan struct{}
|
||||
cleanupWg sync.WaitGroup
|
||||
|
||||
// Singleflight to deduplicate concurrent embedding computations
|
||||
embeddingGroup singleflight.Group
|
||||
}
|
||||
|
||||
// CacheStats tracks cache performance using atomic counters, so updates need
// no lock. Because it contains atomic values it must not be copied; use
// Snapshot to obtain a plain copy of the numbers.
type CacheStats struct {
	embeddingHits      atomic.Int64
	embeddingMisses    atomic.Int64
	resultHits         atomic.Int64
	resultMisses       atomic.Int64
	embeddingEvictions atomic.Int64
	resultEvictions    atomic.Int64
}

// CacheStatsSnapshot is a plain-value, exported copy of CacheStats suitable for
// JSON marshaling.
type CacheStatsSnapshot struct {
	EmbeddingHits      int64 `json:"embedding_hits"`
	EmbeddingMisses    int64 `json:"embedding_misses"`
	ResultHits         int64 `json:"result_hits"`
	ResultMisses       int64 `json:"result_misses"`
	EmbeddingEvictions int64 `json:"embedding_evictions"`
	ResultEvictions    int64 `json:"result_evictions"`
}

// HitRate returns the combined hit rate of both caches as a percentage in
// [0, 100]: hits / (hits + misses) * 100. It returns 0 when no lookups have
// been recorded.
func (s CacheStatsSnapshot) HitRate() float64 {
	hits := s.EmbeddingHits + s.ResultHits
	total := hits + s.EmbeddingMisses + s.ResultMisses
	if total == 0 {
		return 0
	}
	return float64(hits) / float64(total) * 100
}

// HitRate returns the combined hit rate of both caches as a percentage. It is
// computed from a Snapshot so the formula lives in exactly one place.
func (s *CacheStats) HitRate() float64 {
	return s.Snapshot().HitRate()
}

// Snapshot returns a copy of the current counter values. Each counter is
// loaded individually, so the snapshot is not one atomic view across counters.
func (s *CacheStats) Snapshot() CacheStatsSnapshot {
	return CacheStatsSnapshot{
		EmbeddingHits:      s.embeddingHits.Load(),
		EmbeddingMisses:    s.embeddingMisses.Load(),
		ResultHits:         s.resultHits.Load(),
		ResultMisses:       s.resultMisses.Load(),
		EmbeddingEvictions: s.embeddingEvictions.Load(),
		ResultEvictions:    s.resultEvictions.Load(),
	}
}
|
||||
|
||||
// Config holds configuration for the client.
|
||||
@@ -34,10 +132,23 @@ func NewClient(cfg Config, embedSvc *embedding.Service) (*Client, error) {
|
||||
return nil, fmt.Errorf("embedding service required")
|
||||
}
|
||||
|
||||
return &Client{
|
||||
db: cfg.DB,
|
||||
embedSvc: embedSvc,
|
||||
}, nil
|
||||
c := &Client{
|
||||
db: cfg.DB,
|
||||
embedSvc: embedSvc,
|
||||
queryCache: make(map[string]embeddingCacheEntry),
|
||||
cacheMaxSize: 500, // Cache up to 500 query embeddings
|
||||
cacheTTLNano: 5 * 60 * 1e9, // 5 minute TTL for embeddings
|
||||
resultCache: make(map[string]resultCacheEntry),
|
||||
resultCacheMaxSize: 200, // Cache up to 200 search results
|
||||
resultCacheTTLNano: 60 * 1e9, // 1 minute TTL for results (shorter since data changes)
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start background cache cleanup goroutine
|
||||
c.cleanupWg.Add(1)
|
||||
go c.cacheCleanupLoop()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// AddDocuments adds documents with their embeddings to the vector store.
|
||||
@@ -46,8 +157,8 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
// Generate embeddings for all documents
|
||||
texts := make([]string, len(docs))
|
||||
@@ -75,7 +186,10 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
// Rollback failure is serious - indicates potential data corruption risk
|
||||
log.Error().Err(rbErr).Err(err).Msg("Failed to rollback transaction after error - data may be inconsistent")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -118,6 +232,9 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// Invalidate result cache since data changed
|
||||
c.InvalidateResultCache()
|
||||
|
||||
log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec")
|
||||
return nil
|
||||
}
|
||||
@@ -128,12 +245,12 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
// Build placeholder string
|
||||
placeholders := make([]string, len(ids))
|
||||
args := make([]interface{}, len(ids))
|
||||
args := make([]any, len(ids))
|
||||
for i, id := range ids {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
@@ -148,17 +265,25 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
|
||||
return fmt.Errorf("delete documents: %w", err)
|
||||
}
|
||||
|
||||
// Invalidate result cache since data changed
|
||||
c.InvalidateResultCache()
|
||||
|
||||
log.Debug().Int("count", len(ids)).Msg("Deleted documents from sqlite-vec")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query performs a vector similarity search.
|
||||
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// Build cache key from query + filters + limit
|
||||
cacheKey := c.buildResultCacheKey(query, limit, where)
|
||||
|
||||
// Generate query embedding
|
||||
queryEmb, err := c.embedSvc.Embed(query)
|
||||
// Check result cache first
|
||||
if results, ok := c.getResultFromCache(cacheKey); ok {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Generate query embedding OUTSIDE the lock for better concurrency
|
||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
@@ -169,9 +294,13 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||
}
|
||||
|
||||
// Now acquire read lock for the actual DB query
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
// Build query with filters
|
||||
// vec0 supports WHERE clauses on metadata columns
|
||||
args := []interface{}{queryBlob}
|
||||
args := []any{queryBlob}
|
||||
|
||||
sqlQuery := `
|
||||
SELECT
|
||||
@@ -232,6 +361,9 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
|
||||
return nil, fmt.Errorf("iterate rows: %w", err)
|
||||
}
|
||||
|
||||
// Cache the results
|
||||
c.cacheResults(cacheKey, results)
|
||||
|
||||
log.Debug().
|
||||
Str("query", truncateString(query, 50)).
|
||||
Int("results", len(results)).
|
||||
@@ -245,11 +377,196 @@ func (c *Client) IsConnected() bool {
|
||||
return c.db != nil
|
||||
}
|
||||
|
||||
// Close is a no-op (db managed externally).
|
||||
// Close stops the background cleanup goroutine (db managed externally).
|
||||
func (c *Client) Close() error {
|
||||
// Signal cleanup goroutine to stop
|
||||
close(c.stopCleanup)
|
||||
// Wait for cleanup to finish
|
||||
c.cleanupWg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// cacheCleanupLoop periodically removes expired cache entries.
|
||||
func (c *Client) cacheCleanupLoop() {
|
||||
defer c.cleanupWg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second) // Cleanup every 30 seconds
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCleanup:
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.cleanupExpiredCaches()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredCaches removes expired entries from both caches.
|
||||
func (c *Client) cleanupExpiredCaches() {
|
||||
now := time.Now().UnixNano()
|
||||
var embeddingExpired, resultExpired int64
|
||||
|
||||
// Cleanup embedding cache
|
||||
c.queryCacheMu.Lock()
|
||||
for key, entry := range c.queryCache {
|
||||
if now-entry.timestamp > c.cacheTTLNano {
|
||||
delete(c.queryCache, key)
|
||||
embeddingExpired++
|
||||
}
|
||||
}
|
||||
c.queryCacheMu.Unlock()
|
||||
|
||||
// Cleanup result cache
|
||||
c.resultCacheMu.Lock()
|
||||
for key, entry := range c.resultCache {
|
||||
if now-entry.timestamp > c.resultCacheTTLNano {
|
||||
delete(c.resultCache, key)
|
||||
resultExpired++
|
||||
}
|
||||
}
|
||||
c.resultCacheMu.Unlock()
|
||||
|
||||
// Update stats atomically
|
||||
if embeddingExpired > 0 || resultExpired > 0 {
|
||||
c.stats.embeddingEvictions.Add(embeddingExpired)
|
||||
c.stats.resultEvictions.Add(resultExpired)
|
||||
|
||||
log.Debug().
|
||||
Int64("embedding_expired", embeddingExpired).
|
||||
Int64("result_expired", resultExpired).
|
||||
Msg("Cache cleanup completed")
|
||||
}
|
||||
}
|
||||
|
||||
// BatchQueryResult holds results from a batch query operation.
|
||||
type BatchQueryResult struct {
|
||||
Query string // Original query string
|
||||
Results []QueryResult // Results for this query
|
||||
Error error // Error if query failed
|
||||
}
|
||||
|
||||
// QueryBatch performs multiple vector searches concurrently.
|
||||
// Returns results in the same order as input queries.
|
||||
// Uses a worker pool to limit concurrent queries.
|
||||
func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, where map[string]any) []BatchQueryResult {
|
||||
if len(queries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Limit concurrency to avoid overwhelming the database
|
||||
maxConcurrent := min(4, len(queries))
|
||||
|
||||
results := make([]BatchQueryResult, len(queries))
|
||||
sem := make(chan struct{}, maxConcurrent)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, query := range queries {
|
||||
wg.Add(1)
|
||||
go func(idx int, q string) {
|
||||
defer wg.Done()
|
||||
|
||||
// Acquire semaphore
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
defer func() { <-sem }()
|
||||
case <-ctx.Done():
|
||||
results[idx] = BatchQueryResult{
|
||||
Query: q,
|
||||
Error: ctx.Err(),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Execute query
|
||||
queryResults, err := c.Query(ctx, q, limit, where)
|
||||
results[idx] = BatchQueryResult{
|
||||
Query: q,
|
||||
Results: queryResults,
|
||||
Error: err,
|
||||
}
|
||||
}(i, query)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return results
|
||||
}
|
||||
|
||||
// QueryMultiField searches across multiple fields for a single query.
|
||||
// Combines results from different field types and deduplicates by document ID.
|
||||
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
|
||||
// Generate embedding once
|
||||
queryEmb, err := c.getOrComputeEmbedding(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
|
||||
// Serialize query embedding
|
||||
queryBlob, err := sqlite_vec.SerializeFloat32(queryEmb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize query embedding: %w", err)
|
||||
}
|
||||
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
// Query with field type aggregation - get best match per document
|
||||
sqlQuery := `
|
||||
WITH ranked_results AS (
|
||||
SELECT
|
||||
doc_id,
|
||||
distance,
|
||||
sqlite_id,
|
||||
doc_type,
|
||||
field_type,
|
||||
project,
|
||||
scope,
|
||||
ROW_NUMBER() OVER (PARTITION BY sqlite_id ORDER BY distance ASC) as rn
|
||||
FROM vectors
|
||||
WHERE embedding MATCH ?
|
||||
AND doc_type = ?
|
||||
AND (project = ? OR scope = 'global')
|
||||
)
|
||||
SELECT doc_id, distance, sqlite_id, doc_type, field_type, project, scope
|
||||
FROM ranked_results
|
||||
WHERE rn = 1
|
||||
ORDER BY distance
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
rows, err := c.db.QueryContext(ctx, sqlQuery, queryBlob, docType, project, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query vectors: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Pre-allocate with limit to avoid repeated slice growth
|
||||
results := make([]QueryResult, 0, limit)
|
||||
for rows.Next() {
|
||||
var r QueryResult
|
||||
var sqliteID int64
|
||||
var docTypeVal, fieldType, projectVal, scope sql.NullString
|
||||
|
||||
if err := rows.Scan(&r.ID, &r.Distance, &sqliteID, &docTypeVal, &fieldType, &projectVal, &scope); err != nil {
|
||||
return nil, fmt.Errorf("scan row: %w", err)
|
||||
}
|
||||
|
||||
r.Similarity = DistanceToSimilarity(r.Distance)
|
||||
r.Metadata = map[string]any{
|
||||
"sqlite_id": float64(sqliteID),
|
||||
"doc_type": docTypeVal.String,
|
||||
"field_type": fieldType.String,
|
||||
"project": projectVal.String,
|
||||
"scope": scope.String,
|
||||
}
|
||||
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// truncateString truncates a string to maxLen characters.
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
@@ -260,8 +577,8 @@ func truncateString(s string, maxLen int) string {
|
||||
|
||||
// Count returns the total number of vectors in the store.
|
||||
func (c *Client) Count(ctx context.Context) (int64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
var count int64
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count)
|
||||
@@ -281,8 +598,8 @@ func (c *Client) ModelVersion() string {
|
||||
// - The vectors table is empty
|
||||
// - Any vectors have a different model_version than the current model
|
||||
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
|
||||
@@ -329,8 +646,8 @@ type StaleVectorInfo struct {
|
||||
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
|
||||
// This enables granular rebuild - only re-embedding documents that need updating.
|
||||
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
currentModel := c.embedSvc.Version()
|
||||
|
||||
@@ -372,6 +689,134 @@ func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error)
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// VectorHealthStats contains comprehensive health information about the vector store.
|
||||
type VectorHealthStats struct {
|
||||
TotalVectors int64 `json:"total_vectors"`
|
||||
StaleVectors int64 `json:"stale_vectors"`
|
||||
CurrentModel string `json:"current_model"`
|
||||
NeedsRebuild bool `json:"needs_rebuild"`
|
||||
RebuildReason string `json:"rebuild_reason,omitempty"`
|
||||
CoverageByType map[string]int64 `json:"coverage_by_type"`
|
||||
ModelVersions map[string]int64 `json:"model_versions"`
|
||||
ProjectCounts map[string]int64 `json:"project_counts"`
|
||||
EmbeddingCache CacheStatsSnapshot `json:"embedding_cache"`
|
||||
}
|
||||
|
||||
// GetHealthStats returns comprehensive health statistics about the vector store.
|
||||
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
|
||||
c.readMu.RLock()
|
||||
defer c.readMu.RUnlock()
|
||||
|
||||
stats := &VectorHealthStats{
|
||||
CurrentModel: c.embedSvc.Version(),
|
||||
CoverageByType: make(map[string]int64),
|
||||
ModelVersions: make(map[string]int64),
|
||||
ProjectCounts: make(map[string]int64),
|
||||
EmbeddingCache: c.stats.Snapshot(),
|
||||
}
|
||||
|
||||
// Get total count
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&stats.TotalVectors)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count total vectors: %w", err)
|
||||
}
|
||||
|
||||
// Get stale count
|
||||
err = c.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
|
||||
stats.CurrentModel,
|
||||
).Scan(&stats.StaleVectors)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count stale vectors: %w", err)
|
||||
}
|
||||
|
||||
// Check if rebuild needed
|
||||
stats.NeedsRebuild, stats.RebuildReason = c.needsRebuildUnlocked(ctx, stats.CurrentModel)
|
||||
|
||||
// Get coverage by doc_type
|
||||
rows, err := c.db.QueryContext(ctx, "SELECT doc_type, COUNT(*) FROM vectors GROUP BY doc_type")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query doc types: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var docType sql.NullString
|
||||
var count int64
|
||||
if err := rows.Scan(&docType, &count); err != nil {
|
||||
return nil, fmt.Errorf("scan doc type: %w", err)
|
||||
}
|
||||
if docType.Valid {
|
||||
stats.CoverageByType[docType.String] = count
|
||||
} else {
|
||||
stats.CoverageByType["unknown"] = count
|
||||
}
|
||||
}
|
||||
|
||||
// Get model version distribution
|
||||
rows2, err := c.db.QueryContext(ctx, "SELECT COALESCE(model_version, 'unknown'), COUNT(*) FROM vectors GROUP BY model_version")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query model versions: %w", err)
|
||||
}
|
||||
defer rows2.Close()
|
||||
|
||||
for rows2.Next() {
|
||||
var version string
|
||||
var count int64
|
||||
if err := rows2.Scan(&version, &count); err != nil {
|
||||
return nil, fmt.Errorf("scan model version: %w", err)
|
||||
}
|
||||
stats.ModelVersions[version] = count
|
||||
}
|
||||
|
||||
// Get project counts (top 10)
|
||||
rows3, err := c.db.QueryContext(ctx,
|
||||
"SELECT COALESCE(project, 'global'), COUNT(*) FROM vectors GROUP BY project ORDER BY COUNT(*) DESC LIMIT 10")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query projects: %w", err)
|
||||
}
|
||||
defer rows3.Close()
|
||||
|
||||
for rows3.Next() {
|
||||
var project string
|
||||
var count int64
|
||||
if err := rows3.Scan(&project, &count); err != nil {
|
||||
return nil, fmt.Errorf("scan project: %w", err)
|
||||
}
|
||||
stats.ProjectCounts[project] = count
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// needsRebuildUnlocked checks if rebuild is needed without acquiring lock (caller must hold lock).
|
||||
func (c *Client) needsRebuildUnlocked(ctx context.Context, currentModel string) (bool, string) {
|
||||
var totalCount int64
|
||||
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if totalCount == 0 {
|
||||
return true, "empty"
|
||||
}
|
||||
|
||||
var staleCount int64
|
||||
err = c.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
|
||||
currentModel,
|
||||
).Scan(&staleCount)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if staleCount > 0 {
|
||||
return true, fmt.Sprintf("model_mismatch:%d", staleCount)
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// DeleteVectorsByDocIDs removes vectors by their doc_ids.
|
||||
// Used for granular rebuild - delete stale vectors before re-adding.
|
||||
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
|
||||
@@ -379,12 +824,12 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
// Build placeholder string
|
||||
placeholders := make([]string, len(docIDs))
|
||||
args := make([]interface{}, len(docIDs))
|
||||
args := make([]any, len(docIDs))
|
||||
for i, id := range docIDs {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
@@ -402,3 +847,249 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err
|
||||
log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id")
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByObservationID removes all vectors associated with an observation ID.
|
||||
// Vectors are stored with doc_ids that include the observation ID, e.g., "obs_123_narrative".
|
||||
func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
// Vectors have doc_ids like "obs_123_narrative", "obs_123_facts_0", etc.
|
||||
pattern := fmt.Sprintf("obs_%d_%%", obsID)
|
||||
|
||||
_, err := c.db.ExecContext(ctx, "DELETE FROM vectors WHERE doc_id LIKE ?", pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete vectors for observation %d: %w", obsID, err)
|
||||
}
|
||||
|
||||
log.Debug().Int64("observation_id", obsID).Msg("Deleted vectors for observation")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrComputeEmbedding returns a cached embedding or computes a new one.
|
||||
// Uses singleflight to prevent duplicate concurrent computations for the same query.
|
||||
func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
// Check cache first (read lock)
|
||||
c.queryCacheMu.RLock()
|
||||
if entry, ok := c.queryCache[query]; ok {
|
||||
if now-entry.timestamp < c.cacheTTLNano {
|
||||
c.queryCacheMu.RUnlock()
|
||||
c.stats.embeddingHits.Add(1)
|
||||
return entry.embedding, nil
|
||||
}
|
||||
}
|
||||
c.queryCacheMu.RUnlock()
|
||||
|
||||
// Cache miss - use singleflight to deduplicate concurrent embedding requests
|
||||
result, err, _ := c.embeddingGroup.Do(query, func() (any, error) {
|
||||
// Double-check cache inside singleflight (another goroutine may have just cached it)
|
||||
c.queryCacheMu.RLock()
|
||||
if entry, ok := c.queryCache[query]; ok {
|
||||
if time.Now().UnixNano()-entry.timestamp < c.cacheTTLNano {
|
||||
c.queryCacheMu.RUnlock()
|
||||
return entry.embedding, nil
|
||||
}
|
||||
}
|
||||
c.queryCacheMu.RUnlock()
|
||||
|
||||
// Record cache miss
|
||||
c.stats.embeddingMisses.Add(1)
|
||||
|
||||
// Compute embedding
|
||||
emb, err := c.embedSvc.Embed(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store in cache (write lock)
|
||||
c.queryCacheMu.Lock()
|
||||
nowCache := time.Now().UnixNano()
|
||||
// Evict old entries if cache is full or near capacity (80% threshold)
|
||||
evictionThreshold := (c.cacheMaxSize * 8) / 10
|
||||
if len(c.queryCache) >= evictionThreshold {
|
||||
// Phase 1: Remove ALL expired entries first (not just 10%)
|
||||
evicted := int64(0)
|
||||
for k, v := range c.queryCache {
|
||||
if nowCache-v.timestamp > c.cacheTTLNano {
|
||||
delete(c.queryCache, k)
|
||||
evicted++
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: If still at capacity, evict 10% using random iteration (O(n) instead of O(n log n))
|
||||
// Go map iteration order is randomized, providing good cache behavior without sorting
|
||||
if len(c.queryCache) >= c.cacheMaxSize {
|
||||
evictCount := max(c.cacheMaxSize/10, 1)
|
||||
for k := range c.queryCache {
|
||||
delete(c.queryCache, k)
|
||||
evicted++
|
||||
evictCount--
|
||||
if evictCount <= 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if evicted > 0 {
|
||||
c.stats.embeddingEvictions.Add(evicted)
|
||||
}
|
||||
}
|
||||
c.queryCache[query] = embeddingCacheEntry{
|
||||
embedding: emb,
|
||||
timestamp: nowCache,
|
||||
}
|
||||
c.queryCacheMu.Unlock()
|
||||
|
||||
return emb, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.([]float32), nil
|
||||
}
|
||||
|
||||
// ClearCache clears the embedding cache.
|
||||
func (c *Client) ClearCache() {
|
||||
c.queryCacheMu.Lock()
|
||||
c.queryCache = make(map[string]embeddingCacheEntry)
|
||||
c.queryCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// GetCacheStats returns comprehensive cache statistics.
|
||||
func (c *Client) GetCacheStats() CacheStatsSnapshot {
|
||||
return c.stats.Snapshot()
|
||||
}
|
||||
|
||||
// CacheStats returns basic cache size info for backward compatibility.
|
||||
// Deprecated: Use GetCacheStats() for comprehensive statistics.
|
||||
func (c *Client) CacheStats() (size int, maxSize int) {
|
||||
c.queryCacheMu.RLock()
|
||||
size = len(c.queryCache)
|
||||
c.queryCacheMu.RUnlock()
|
||||
return size, c.cacheMaxSize
|
||||
}
|
||||
|
||||
// EmbeddingCacheSize returns the current embedding cache size.
|
||||
func (c *Client) EmbeddingCacheSize() int {
|
||||
c.queryCacheMu.RLock()
|
||||
defer c.queryCacheMu.RUnlock()
|
||||
return len(c.queryCache)
|
||||
}
|
||||
|
||||
// ResultCacheSize returns the current result cache size.
|
||||
func (c *Client) ResultCacheSize() int {
|
||||
c.resultCacheMu.RLock()
|
||||
defer c.resultCacheMu.RUnlock()
|
||||
return len(c.resultCache)
|
||||
}
|
||||
|
||||
// buildResultCacheKey creates a unique key for caching query results.
|
||||
// Uses strings.Builder to avoid intermediate allocations.
|
||||
func (c *Client) buildResultCacheKey(query string, limit int, where map[string]any) string {
|
||||
// Pre-allocate with typical key size to avoid reallocation
|
||||
var b strings.Builder
|
||||
b.Grow(len(query) + 32) // query + typical prefix/suffix overhead
|
||||
|
||||
b.WriteString("q:")
|
||||
b.WriteString(query)
|
||||
b.WriteString(":l:")
|
||||
b.WriteString(strconv.Itoa(limit))
|
||||
|
||||
if docType, ok := where["doc_type"].(string); ok {
|
||||
b.WriteString(":dt:")
|
||||
b.WriteString(docType)
|
||||
}
|
||||
if project, ok := where["project"].(string); ok {
|
||||
b.WriteString(":p:")
|
||||
b.WriteString(project)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// getResultFromCache retrieves cached results if available and not expired.
|
||||
func (c *Client) getResultFromCache(cacheKey string) ([]QueryResult, bool) {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
c.resultCacheMu.RLock()
|
||||
entry, ok := c.resultCache[cacheKey]
|
||||
c.resultCacheMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
c.stats.resultMisses.Add(1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check if entry is expired
|
||||
if now-entry.timestamp > c.resultCacheTTLNano {
|
||||
c.stats.resultMisses.Add(1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c.stats.resultHits.Add(1)
|
||||
|
||||
// Return a copy to prevent mutation
|
||||
results := make([]QueryResult, len(entry.results))
|
||||
copy(results, entry.results)
|
||||
return results, true
|
||||
}
|
||||
|
||||
// cacheResults stores query results in the cache.
|
||||
func (c *Client) cacheResults(cacheKey string, results []QueryResult) {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
c.resultCacheMu.Lock()
|
||||
defer c.resultCacheMu.Unlock()
|
||||
|
||||
// Evict old entries if cache is full
|
||||
if len(c.resultCache) >= c.resultCacheMaxSize {
|
||||
// Two-phase eviction: (1) TTL-expired entries, (2) random if still over capacity
|
||||
evicted := 0
|
||||
targetSize := c.resultCacheMaxSize * 8 / 10 // Target 80% capacity
|
||||
|
||||
// Phase 1: Remove all TTL-expired entries
|
||||
for k, v := range c.resultCache {
|
||||
if now-v.timestamp > c.resultCacheTTLNano {
|
||||
delete(c.resultCache, k)
|
||||
evicted++
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: If still over target, remove random entries until at target
|
||||
if len(c.resultCache) >= targetSize {
|
||||
evictCount := len(c.resultCache) - targetSize + 1
|
||||
for k := range c.resultCache {
|
||||
delete(c.resultCache, k)
|
||||
evicted++
|
||||
evictCount--
|
||||
if evictCount <= 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if evicted > 0 {
|
||||
c.stats.resultEvictions.Add(int64(evicted))
|
||||
}
|
||||
}
|
||||
|
||||
// Make a copy of results to store
|
||||
resultsCopy := make([]QueryResult, len(results))
|
||||
copy(resultsCopy, results)
|
||||
|
||||
c.resultCache[cacheKey] = resultCacheEntry{
|
||||
results: resultsCopy,
|
||||
timestamp: now,
|
||||
queryHash: cacheKey,
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateResultCache clears the result cache.
|
||||
// Should be called after write operations that modify vectors.
|
||||
func (c *Client) InvalidateResultCache() {
|
||||
c.resultCacheMu.Lock()
|
||||
c.resultCache = make(map[string]resultCacheEntry)
|
||||
c.resultCacheMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -338,3 +338,187 @@ func (s *Sync) DeletePatterns(ctx context.Context, patternIDs []int64) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchSyncConfig holds the tuning knobs for the batch synchronization
// helpers.
type BatchSyncConfig struct {
	// BatchSize is the number of items per batch (default: 50).
	BatchSize int
	// ProgressLogFreq makes progress get logged every N items (default: 100).
	ProgressLogFreq int
}
|
||||
|
||||
// DefaultBatchSyncConfig returns sensible defaults for batch sync.
|
||||
func DefaultBatchSyncConfig() BatchSyncConfig {
|
||||
return BatchSyncConfig{
|
||||
BatchSize: 50,
|
||||
ProgressLogFreq: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// BatchSyncObservations syncs multiple observations efficiently in batches.
|
||||
// This reduces memory pressure during large rebuilds by processing in chunks.
|
||||
func (s *Sync) BatchSyncObservations(ctx context.Context, observations []*models.Observation, cfg BatchSyncConfig) (synced int, errors int) {
|
||||
if len(observations) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 50
|
||||
}
|
||||
if cfg.ProgressLogFreq <= 0 {
|
||||
cfg.ProgressLogFreq = 100
|
||||
}
|
||||
|
||||
for i := 0; i < len(observations); i += cfg.BatchSize {
|
||||
// Check context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Warn().Int("synced", synced).Int("remaining", len(observations)-i).Msg("Batch sync cancelled")
|
||||
return synced, errors
|
||||
default:
|
||||
}
|
||||
|
||||
end := min(i+cfg.BatchSize, len(observations))
|
||||
|
||||
batch := observations[i:end]
|
||||
var docs []Document
|
||||
|
||||
// Collect all documents for this batch
|
||||
for _, obs := range batch {
|
||||
docs = append(docs, s.formatObservationDocs(obs)...)
|
||||
}
|
||||
|
||||
// Add all documents in one call
|
||||
if len(docs) > 0 {
|
||||
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync observation batch")
|
||||
errors += len(batch)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
synced += len(batch)
|
||||
|
||||
// Log progress periodically
|
||||
if synced%cfg.ProgressLogFreq == 0 || synced == len(observations) {
|
||||
log.Debug().Int("synced", synced).Int("total", len(observations)).Msg("Observation batch sync progress")
|
||||
}
|
||||
}
|
||||
|
||||
return synced, errors
|
||||
}
|
||||
|
||||
// BatchSyncSummaries syncs multiple summaries efficiently in batches.
|
||||
func (s *Sync) BatchSyncSummaries(ctx context.Context, summaries []*models.SessionSummary, cfg BatchSyncConfig) (synced int, errors int) {
|
||||
if len(summaries) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 50
|
||||
}
|
||||
if cfg.ProgressLogFreq <= 0 {
|
||||
cfg.ProgressLogFreq = 100
|
||||
}
|
||||
|
||||
for i := 0; i < len(summaries); i += cfg.BatchSize {
|
||||
// Check context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Warn().Int("synced", synced).Int("remaining", len(summaries)-i).Msg("Batch sync cancelled")
|
||||
return synced, errors
|
||||
default:
|
||||
}
|
||||
|
||||
end := min(i+cfg.BatchSize, len(summaries))
|
||||
|
||||
batch := summaries[i:end]
|
||||
var docs []Document
|
||||
|
||||
// Collect all documents for this batch
|
||||
for _, summary := range batch {
|
||||
docs = append(docs, s.formatSummaryDocs(summary)...)
|
||||
}
|
||||
|
||||
// Add all documents in one call
|
||||
if len(docs) > 0 {
|
||||
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync summary batch")
|
||||
errors += len(batch)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
synced += len(batch)
|
||||
|
||||
// Log progress periodically
|
||||
if synced%cfg.ProgressLogFreq == 0 || synced == len(summaries) {
|
||||
log.Debug().Int("synced", synced).Int("total", len(summaries)).Msg("Summary batch sync progress")
|
||||
}
|
||||
}
|
||||
|
||||
return synced, errors
|
||||
}
|
||||
|
||||
// BatchSyncPrompts syncs multiple user prompts efficiently in batches.
|
||||
func (s *Sync) BatchSyncPrompts(ctx context.Context, prompts []*models.UserPromptWithSession, cfg BatchSyncConfig) (synced int, errors int) {
|
||||
if len(prompts) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
if cfg.BatchSize <= 0 {
|
||||
cfg.BatchSize = 50
|
||||
}
|
||||
if cfg.ProgressLogFreq <= 0 {
|
||||
cfg.ProgressLogFreq = 100
|
||||
}
|
||||
|
||||
for i := 0; i < len(prompts); i += cfg.BatchSize {
|
||||
// Check context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Warn().Int("synced", synced).Int("remaining", len(prompts)-i).Msg("Batch sync cancelled")
|
||||
return synced, errors
|
||||
default:
|
||||
}
|
||||
|
||||
end := min(i+cfg.BatchSize, len(prompts))
|
||||
|
||||
batch := prompts[i:end]
|
||||
docs := make([]Document, 0, len(batch))
|
||||
|
||||
// Collect all documents for this batch
|
||||
for _, prompt := range batch {
|
||||
docs = append(docs, Document{
|
||||
ID: fmt.Sprintf("prompt_%d", prompt.ID),
|
||||
Content: prompt.PromptText,
|
||||
Metadata: map[string]any{
|
||||
"sqlite_id": prompt.ID,
|
||||
"doc_type": "user_prompt",
|
||||
"sdk_session_id": prompt.SDKSessionID,
|
||||
"project": prompt.Project,
|
||||
"scope": "",
|
||||
"created_at_epoch": prompt.CreatedAtEpoch,
|
||||
"prompt_number": prompt.PromptNumber,
|
||||
"field_type": "prompt",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Add all documents in one call
|
||||
if len(docs) > 0 {
|
||||
if err := s.client.AddDocuments(ctx, docs); err != nil {
|
||||
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync prompt batch")
|
||||
errors += len(batch)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
synced += len(batch)
|
||||
|
||||
// Log progress periodically
|
||||
if synced%cfg.ProgressLogFreq == 0 || synced == len(prompts) {
|
||||
log.Debug().Int("synced", synced).Int("total", len(prompts)).Msg("Prompt batch sync progress")
|
||||
}
|
||||
}
|
||||
|
||||
return synced, errors
|
||||
}
|
||||
|
||||
+125
-1191
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,677 @@
|
||||
// Package worker provides context and search-related HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handleSearchByPrompt searches observations relevant to a user prompt.
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
cwd := r.URL.Query().Get("cwd")
|
||||
|
||||
if project == "" || query == "" {
|
||||
http.Error(w, "project and query required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
limit := gorm.ParseLimitParamWithMax(r, DefaultSearchLimit, 200)
|
||||
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
var usedVector bool
|
||||
similarityScores := make(map[int64]float64) // Track similarity per observation
|
||||
|
||||
// Get threshold settings from config
|
||||
threshold := s.config.ContextRelevanceThreshold
|
||||
maxResults := s.config.ContextMaxPromptResults
|
||||
|
||||
// Generate expanded queries if query expander is available
|
||||
// Use timeout context to prevent query expansion from blocking
|
||||
var expandedQueries []expansion.ExpandedQuery
|
||||
var detectedIntent string
|
||||
if s.queryExpander != nil {
|
||||
expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
cfg := expansion.DefaultConfig()
|
||||
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
|
||||
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
|
||||
expandCancel() // Cancel immediately after use (defer not needed - no panic possible between creation and here)
|
||||
if len(expandedQueries) > 0 {
|
||||
detectedIntent = string(expandedQueries[0].Intent)
|
||||
}
|
||||
}
|
||||
if len(expandedQueries) == 0 {
|
||||
// Fallback to just the original query
|
||||
expandedQueries = []expansion.ExpandedQuery{
|
||||
{Query: query, Weight: 1.0, Source: "original"},
|
||||
}
|
||||
}
|
||||
|
||||
// Try vector search first if available
|
||||
var vectorSearchFailed bool
|
||||
if s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
|
||||
// Search with each expanded query and merge results
|
||||
// Pre-allocate with estimated capacity to avoid repeated reallocation
|
||||
estimatedCapacity := len(expandedQueries) * limit * 2
|
||||
allVectorResults := make([]sqlitevec.QueryResult, 0, estimatedCapacity)
|
||||
queryWeights := make(map[string]float64, len(expandedQueries))
|
||||
var vectorErrors int
|
||||
|
||||
for _, eq := range expandedQueries {
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
|
||||
if vecErr != nil {
|
||||
vectorErrors++
|
||||
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
|
||||
} else if len(vectorResults) > 0 {
|
||||
// Apply weight to similarity scores before merging
|
||||
for i := range vectorResults {
|
||||
vectorResults[i].Similarity *= eq.Weight
|
||||
}
|
||||
allVectorResults = append(allVectorResults, vectorResults...)
|
||||
queryWeights[eq.Query] = eq.Weight
|
||||
}
|
||||
}
|
||||
|
||||
// Track if vector search had issues
|
||||
if vectorErrors > 0 && vectorErrors == len(expandedQueries) {
|
||||
vectorSearchFailed = true
|
||||
log.Warn().Int("errors", vectorErrors).Str("project", project).Msg("All vector queries failed, falling back to FTS")
|
||||
}
|
||||
|
||||
if len(allVectorResults) > 0 {
|
||||
// Filter by relevance threshold before extracting IDs
|
||||
// Use a slightly lower threshold for expanded queries
|
||||
effectiveThreshold := threshold * 0.9 // Allow slightly lower scores for expanded queries
|
||||
filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0)
|
||||
|
||||
// Build similarity map for filtered results (keeping highest weighted score per observation)
|
||||
for _, vr := range filteredResults {
|
||||
if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok {
|
||||
id := int64(sqliteID)
|
||||
// Keep the highest score for each observation
|
||||
if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing {
|
||||
similarityScores[id] = vr.Similarity
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract observation IDs with project/scope filtering using shared helper
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project)
|
||||
|
||||
if len(obsIDs) > 0 {
|
||||
// Fetch full observations from SQLite
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to FTS if vector search not available, failed, or returned no results
|
||||
if !usedVector || len(observations) == 0 {
|
||||
if vectorSearchFailed {
|
||||
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
|
||||
}
|
||||
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
|
||||
if err != nil {
|
||||
// FTS might fail if query has special chars, try without
|
||||
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
|
||||
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fast staleness filter - NO verification (that's too slow for interactive use)
|
||||
// Just check mtimes and exclude obviously stale observations
|
||||
var staleCount int
|
||||
freshObservations := make([]*models.Observation, 0, len(observations))
|
||||
|
||||
for _, obs := range observations {
|
||||
if len(obs.FileMtimes) > 0 && cwd != "" {
|
||||
var paths []string
|
||||
for path := range obs.FileMtimes {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
currentMtimes := sdk.GetFileMtimes(paths, cwd)
|
||||
|
||||
if obs.CheckStaleness(currentMtimes) {
|
||||
// Stale - exclude but don't verify (too slow)
|
||||
// Queue for background verification instead
|
||||
staleCount++
|
||||
s.queueStaleVerification(obs.ID, cwd)
|
||||
continue
|
||||
}
|
||||
}
|
||||
freshObservations = append(freshObservations, obs)
|
||||
}
|
||||
|
||||
// Apply cross-encoder reranking if available
|
||||
var reranked bool
|
||||
if s.reranker != nil && len(freshObservations) > 0 && usedVector {
|
||||
// Build candidates from observations with their bi-encoder scores
|
||||
candidates := make([]reranking.Candidate, len(freshObservations))
|
||||
for i, obs := range freshObservations {
|
||||
// Use strings.Builder for efficient concatenation
|
||||
var content string
|
||||
if obs.Narrative.Valid && obs.Narrative.String != "" {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(obs.Title.String) + 1 + len(obs.Narrative.String))
|
||||
sb.WriteString(obs.Title.String)
|
||||
sb.WriteByte(' ')
|
||||
sb.WriteString(obs.Narrative.String)
|
||||
content = sb.String()
|
||||
} else {
|
||||
content = obs.Title.String
|
||||
}
|
||||
candidates[i] = reranking.Candidate{
|
||||
ID: strconv.FormatInt(obs.ID, 10), // Faster than fmt.Sprintf
|
||||
Content: content,
|
||||
Score: similarityScores[obs.ID],
|
||||
Metadata: map[string]any{"obs_idx": i},
|
||||
}
|
||||
}
|
||||
|
||||
// Rerank using cross-encoder - use pure mode or combined scores
|
||||
var rerankResults []reranking.RerankResult
|
||||
var rerankErr error
|
||||
if s.config.RerankingPureMode {
|
||||
rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults)
|
||||
} else {
|
||||
rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults)
|
||||
}
|
||||
if rerankErr != nil {
|
||||
log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order")
|
||||
} else if len(rerankResults) > 0 {
|
||||
// Update similarity scores with reranked scores
|
||||
for _, rr := range rerankResults {
|
||||
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
|
||||
similarityScores[id] = rr.CombinedScore
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder observations based on rerank results
|
||||
reorderedObs := make([]*models.Observation, 0, len(rerankResults))
|
||||
obsMap := make(map[int64]*models.Observation)
|
||||
for _, obs := range freshObservations {
|
||||
obsMap[obs.ID] = obs
|
||||
}
|
||||
for _, rr := range rerankResults {
|
||||
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
|
||||
if obs, ok := obsMap[id]; ok {
|
||||
reorderedObs = append(reorderedObs, obs)
|
||||
}
|
||||
}
|
||||
}
|
||||
freshObservations = reorderedObs
|
||||
reranked = true
|
||||
|
||||
log.Debug().
|
||||
Int("candidates", len(candidates)).
|
||||
Int("returned", len(rerankResults)).
|
||||
Msg("Cross-encoder reranking complete")
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster similar observations to remove duplicates
|
||||
clusteredObservations := clusterObservations(freshObservations, 0.4)
|
||||
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
|
||||
|
||||
// Sort by similarity score (highest first) if we have scores and didn't rerank
|
||||
if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked {
|
||||
sort.Slice(clusteredObservations, func(i, j int) bool {
|
||||
scoreI := similarityScores[clusteredObservations[i].ID]
|
||||
scoreJ := similarityScores[clusteredObservations[j].ID]
|
||||
return scoreI > scoreJ
|
||||
})
|
||||
}
|
||||
|
||||
// Apply max results cap if configured
|
||||
if maxResults > 0 && len(clusteredObservations) > maxResults {
|
||||
clusteredObservations = clusteredObservations[:maxResults]
|
||||
}
|
||||
|
||||
// Record retrieval stats with staleness metrics
|
||||
s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0,
|
||||
int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), true)
|
||||
|
||||
// Increment retrieval counts for scoring (async, non-blocking)
|
||||
if len(clusteredObservations) > 0 {
|
||||
ids := make([]int64, len(clusteredObservations))
|
||||
for i, obs := range clusteredObservations {
|
||||
ids[i] = obs.ID
|
||||
}
|
||||
s.incrementRetrievalCounts(ids)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("project", project).
|
||||
Str("query", query).
|
||||
Str("intent", detectedIntent).
|
||||
Int("expansions", len(expandedQueries)).
|
||||
Int("found", len(clusteredObservations)).
|
||||
Int("stale_excluded", staleCount).
|
||||
Float64("threshold", threshold).
|
||||
Msg("Prompt-based observation search")
|
||||
|
||||
// Build response with similarity scores
|
||||
obsWithScores := make([]map[string]any, len(clusteredObservations))
|
||||
for i, obs := range clusteredObservations {
|
||||
obsMap := obs.ToMap()
|
||||
if score, ok := similarityScores[obs.ID]; ok {
|
||||
obsMap["similarity"] = score
|
||||
}
|
||||
obsWithScores[i] = obsMap
|
||||
}
|
||||
|
||||
// Build expansion info for response
|
||||
expansionInfo := make([]map[string]any, len(expandedQueries))
|
||||
for i, eq := range expandedQueries {
|
||||
expansionInfo[i] = map[string]any{
|
||||
"query": eq.Query,
|
||||
"weight": eq.Weight,
|
||||
"source": eq.Source,
|
||||
}
|
||||
}
|
||||
|
||||
// Track this search for analytics
|
||||
s.trackSearchQuery(query, project, "observations", len(clusteredObservations), usedVector)
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"project": project,
|
||||
"query": query,
|
||||
"intent": detectedIntent,
|
||||
"expansions": expansionInfo,
|
||||
"observations": obsWithScores,
|
||||
"threshold": threshold,
|
||||
"max_results": maxResults,
|
||||
})
|
||||
}
|
||||
|
||||
// handleFileContext returns observations relevant to specific files being worked on.
|
||||
// Uses vector similarity search to find observations that mention or relate to the given files.
|
||||
func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
filesParam := r.URL.Query().Get("files")
|
||||
if filesParam == "" {
|
||||
http.Error(w, "files required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse comma-separated file paths
|
||||
files := strings.Split(filesParam, ",")
|
||||
if len(files) == 0 {
|
||||
http.Error(w, "at least one file required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Limit to reasonable number of files
|
||||
maxFiles := 20
|
||||
if len(files) > maxFiles {
|
||||
files = files[:maxFiles]
|
||||
}
|
||||
|
||||
// Get limit parameter (default 10 per file)
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
limit := 10
|
||||
if limitStr != "" {
|
||||
if parsed, err := strconv.Atoi(limitStr); err == nil && parsed > 0 && parsed <= 50 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// Search for observations related to each file in parallel
|
||||
ctx := r.Context()
|
||||
|
||||
// Check if vector search is available
|
||||
if s.vectorClient == nil || !s.vectorClient.IsConnected() {
|
||||
writeJSON(w, map[string]any{
|
||||
"files": files,
|
||||
"results": map[string]any{},
|
||||
"count": 0,
|
||||
"error": "vector search not available",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare for parallel execution
|
||||
type fileResult struct {
|
||||
file string
|
||||
results []map[string]any
|
||||
obsIDs []int64 // Track observation IDs for deduplication
|
||||
}
|
||||
|
||||
resultsChan := make(chan fileResult, len(files))
|
||||
sem := make(chan struct{}, 5) // Limit concurrency to 5 parallel searches
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, file := range files {
|
||||
file = strings.TrimSpace(file)
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(file string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // Acquire semaphore
|
||||
defer func() { <-sem }() // Release semaphore
|
||||
|
||||
// Build search query from file path
|
||||
query := buildFileQuery(file)
|
||||
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
|
||||
if vecErr != nil {
|
||||
log.Warn().Err(vecErr).Str("file", file).Msg("Vector search failed for file context")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract observation IDs from vector results
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
||||
if len(obsIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch observations
|
||||
observations, err := s.observationStore.GetObservationsByIDs(ctx, obsIDs, "score_desc", limit*2)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("file", file).Msg("Failed to fetch observations for file context")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-build score map from vector results (O(n) instead of O(n²))
|
||||
scoreMap := make(map[int64]float64, len(vectorResults))
|
||||
var avgScore float64
|
||||
for _, vr := range vectorResults {
|
||||
avgScore += vr.Similarity
|
||||
// Parse observation ID from vector result ID (format: obs_{id}_{field})
|
||||
// Use index-based parsing to avoid slice allocation from strings.Split
|
||||
if len(vr.ID) > 4 && vr.ID[:4] == "obs_" {
|
||||
rest := vr.ID[4:] // Skip "obs_"
|
||||
underscoreIdx := strings.IndexByte(rest, '_')
|
||||
var idStr string
|
||||
if underscoreIdx >= 0 {
|
||||
idStr = rest[:underscoreIdx]
|
||||
} else {
|
||||
idStr = rest
|
||||
}
|
||||
if id, parseErr := strconv.ParseInt(idStr, 10, 64); parseErr == nil {
|
||||
// Keep highest score for each observation
|
||||
if existing, exists := scoreMap[id]; !exists || vr.Similarity > existing {
|
||||
scoreMap[id] = vr.Similarity
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(vectorResults) > 0 {
|
||||
avgScore /= float64(len(vectorResults))
|
||||
}
|
||||
|
||||
fileResults := make([]map[string]any, 0, limit)
|
||||
var usedIDs []int64
|
||||
for _, obs := range observations {
|
||||
// Check project scope
|
||||
if obs.Scope == "project" && obs.Project != project {
|
||||
continue
|
||||
}
|
||||
|
||||
// O(1) score lookup instead of O(n) nested loop
|
||||
score, found := scoreMap[obs.ID]
|
||||
if !found {
|
||||
// Use average score as fallback
|
||||
score = avgScore
|
||||
}
|
||||
|
||||
// Only include if score is above threshold
|
||||
if score < 0.3 {
|
||||
continue
|
||||
}
|
||||
|
||||
fileResults = append(fileResults, map[string]any{
|
||||
"id": obs.ID,
|
||||
"title": obs.Title.String,
|
||||
"type": obs.Type,
|
||||
"narrative": obs.Narrative.String,
|
||||
"facts": obs.Facts,
|
||||
"score": score,
|
||||
})
|
||||
usedIDs = append(usedIDs, obs.ID)
|
||||
|
||||
if len(fileResults) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(fileResults) > 0 {
|
||||
resultsChan <- fileResult{file: file, results: fileResults, obsIDs: usedIDs}
|
||||
}
|
||||
}(file)
|
||||
}
|
||||
|
||||
// Close channel when all goroutines complete
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultsChan)
|
||||
}()
|
||||
|
||||
// Collect results and deduplicate
|
||||
allResults := make(map[string]any)
|
||||
seenObservationIDs := make(map[int64]bool)
|
||||
|
||||
for res := range resultsChan {
|
||||
// Filter out duplicates that were already seen in other files
|
||||
dedupedResults := make([]map[string]any, 0, len(res.results))
|
||||
for i, r := range res.results {
|
||||
obsID := res.obsIDs[i]
|
||||
if !seenObservationIDs[obsID] {
|
||||
seenObservationIDs[obsID] = true
|
||||
dedupedResults = append(dedupedResults, r)
|
||||
}
|
||||
}
|
||||
if len(dedupedResults) > 0 {
|
||||
allResults[res.file] = dedupedResults
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"files": files,
|
||||
"results": allResults,
|
||||
"count": len(allResults),
|
||||
})
|
||||
}
|
||||
|
||||
// buildFileQuery extracts meaningful search terms from a file path.
|
||||
func buildFileQuery(filePath string) string {
|
||||
// Remove common prefixes and extensions
|
||||
path := strings.TrimPrefix(filePath, "/")
|
||||
|
||||
// Extract the filename and directory
|
||||
parts := strings.Split(path, "/")
|
||||
meaningful := make([]string, 0, len(parts))
|
||||
|
||||
for _, part := range parts {
|
||||
// Skip common directory names that aren't meaningful
|
||||
switch strings.ToLower(part) {
|
||||
case "src", "lib", "internal", "pkg", "cmd", "api", "app", "test", "tests", "spec", "specs":
|
||||
continue
|
||||
default:
|
||||
// Remove file extension
|
||||
if idx := strings.LastIndex(part, "."); idx > 0 {
|
||||
part = part[:idx]
|
||||
}
|
||||
// Convert camelCase/PascalCase to spaces
|
||||
part = splitCamelCase(part)
|
||||
// Convert snake_case to spaces
|
||||
part = strings.ReplaceAll(part, "_", " ")
|
||||
// Convert kebab-case to spaces
|
||||
part = strings.ReplaceAll(part, "-", " ")
|
||||
meaningful = append(meaningful, part)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(meaningful, " ")
|
||||
}
|
||||
|
||||
// splitCamelCase splits camelCase or PascalCase into separate words.
//
// A space is inserted before an ASCII uppercase letter when it starts a new
// word: either the previous character is not uppercase (and is not already a
// separator), or the letter ends an acronym, i.e. it follows an uppercase
// letter and is followed by a lowercase one. Acronyms therefore stay together:
// "HTTPServer" becomes "HTTP Server" and "parseURL" becomes "parse URL".
// Only ASCII letters are treated as case boundaries.
func splitCamelCase(s string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }
	isSeparator := func(r rune) bool { return r == ' ' || r == '_' || r == '-' }

	runes := []rune(s)
	var result strings.Builder
	result.Grow(len(s) + len(s)/2)

	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			// e.g. the 'B' in "fooBar"; skipped after a separator to avoid doubled gaps.
			startsWord := !isUpper(prev) && !isSeparator(prev)
			// e.g. the 'S' in "HTTPServer".
			endsAcronym := isUpper(prev) && i+1 < len(runes) && isLower(runes[i+1])
			if startsWord || endsAcronym {
				result.WriteByte(' ')
			}
		}
		result.WriteRune(r)
	}
	return result.String()
}
|
||||
|
||||
// handleContextInject returns context for injection at session start.
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cwd := r.URL.Query().Get("cwd")
|
||||
if cwd == "" {
|
||||
cwd = "/"
|
||||
}
|
||||
|
||||
// Limit observations for fast startup (configurable, default 100)
|
||||
limit := s.config.ContextObservations
|
||||
if limit <= 0 {
|
||||
limit = DefaultContextLimit
|
||||
}
|
||||
|
||||
// Full count determines how many observations get full detail (configurable, default 25)
|
||||
fullCount := s.config.ContextFullCount
|
||||
if fullCount <= 0 {
|
||||
fullCount = 25
|
||||
}
|
||||
|
||||
// Get recent observations
|
||||
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Fast staleness filter - NO verification (that's too slow for startup)
|
||||
var staleCount int
|
||||
freshObservations := make([]*models.Observation, 0, len(observations))
|
||||
|
||||
for _, obs := range observations {
|
||||
if len(obs.FileMtimes) > 0 {
|
||||
var paths []string
|
||||
for path := range obs.FileMtimes {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
currentMtimes := sdk.GetFileMtimes(paths, cwd)
|
||||
|
||||
if obs.CheckStaleness(currentMtimes) {
|
||||
// Stale - exclude but don't verify (too slow)
|
||||
// Queue for background verification instead
|
||||
staleCount++
|
||||
s.queueStaleVerification(obs.ID, cwd)
|
||||
continue
|
||||
}
|
||||
}
|
||||
freshObservations = append(freshObservations, obs)
|
||||
}
|
||||
|
||||
// Cluster similar observations to remove duplicates
|
||||
clusteredObservations := clusterObservations(freshObservations, 0.4)
|
||||
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
|
||||
|
||||
// Record retrieval stats with staleness metrics
|
||||
s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0,
|
||||
int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), false)
|
||||
|
||||
// Increment retrieval counts for scoring (async, non-blocking)
|
||||
if len(clusteredObservations) > 0 {
|
||||
ids := make([]int64, len(clusteredObservations))
|
||||
for i, obs := range clusteredObservations {
|
||||
ids[i] = obs.ID
|
||||
}
|
||||
s.incrementRetrievalCounts(ids)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("project", project).
|
||||
Int("total", len(observations)).
|
||||
Int("fresh", len(freshObservations)).
|
||||
Int("clustered", len(clusteredObservations)).
|
||||
Int("duplicates", duplicatesRemoved).
|
||||
Int("stale_excluded", staleCount).
|
||||
Msg("Context injection with clustering")
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"project": project,
|
||||
"observations": clusteredObservations,
|
||||
"full_count": fullCount,
|
||||
"stale_excluded": staleCount,
|
||||
"duplicates_removed": duplicatesRemoved,
|
||||
})
|
||||
}
|
||||
|
||||
// handleContextCount returns the count of observations for a project.
|
||||
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
count, err := s.getCachedObservationCount(r.Context(), project)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"project": project,
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,595 @@
|
||||
// Package worker provides data retrieval HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handleGetObservations returns recent observations.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
// Supports pagination via limit and offset query parameters.
|
||||
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
|
||||
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var observations []*models.Observation
|
||||
var total int64
|
||||
var err error
|
||||
var usedVector bool
|
||||
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
||||
if len(obsIDs) > 0 {
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
total = int64(len(observations)) // Vector search doesn't have total, use returned count
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
// Strict project filtering for dashboard - only observations from this project
|
||||
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset)
|
||||
} else {
|
||||
// All projects
|
||||
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return empty array, not null
|
||||
if observations == nil {
|
||||
observations = []*models.Observation{}
|
||||
}
|
||||
|
||||
// Track search if query was provided
|
||||
if query != "" {
|
||||
s.trackSearchQuery(query, project, "observations", len(observations), usedVector)
|
||||
}
|
||||
|
||||
// Return paginated response
|
||||
writeJSON(w, map[string]any{
|
||||
"observations": observations,
|
||||
"total": total,
|
||||
"limit": pagination.Limit,
|
||||
"offset": pagination.Offset,
|
||||
"hasMore": int64(pagination.Offset)+int64(len(observations)) < total,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetSummaries returns recent summaries.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var summaries []*models.SessionSummary
|
||||
var err error
|
||||
var usedVector bool
|
||||
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
|
||||
if len(summaryIDs) > 0 {
|
||||
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
|
||||
} else {
|
||||
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return empty array, not null
|
||||
if summaries == nil {
|
||||
summaries = []*models.SessionSummary{}
|
||||
}
|
||||
writeJSON(w, summaries)
|
||||
}
|
||||
|
||||
// handleGetPrompts returns recent user prompts.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var prompts []*models.UserPromptWithSession
|
||||
var err error
|
||||
var usedVector bool
|
||||
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
|
||||
if len(promptIDs) > 0 {
|
||||
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
|
||||
} else {
|
||||
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return empty array, not null
|
||||
if prompts == nil {
|
||||
prompts = []*models.UserPromptWithSession{}
|
||||
}
|
||||
writeJSON(w, prompts)
|
||||
}
|
||||
|
||||
// handleGetProjects returns all projects.
|
||||
// Response is cacheable for 5 minutes since project list changes infrequently.
|
||||
func (s *Service) handleGetProjects(w http.ResponseWriter, r *http.Request) {
|
||||
projects, err := s.sessionStore.GetAllProjects(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Cache for 5 minutes - project list changes infrequently
|
||||
w.Header().Set("Cache-Control", "public, max-age=300")
|
||||
writeJSON(w, projects)
|
||||
}
|
||||
|
||||
// handleGetTypes returns the canonical list of observation and concept types.
|
||||
// This provides a single source of truth for both backend and frontend.
|
||||
// Response is cacheable as these values never change at runtime.
|
||||
func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) {
|
||||
// Cache for 24 hours - these values are compile-time constants
|
||||
w.Header().Set("Cache-Control", "public, max-age=86400")
|
||||
writeJSON(w, map[string]any{
|
||||
"observation_types": ObservationTypes,
|
||||
"concept_types": ConceptTypes,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetModels returns available embedding models.
|
||||
// Response is cacheable as model list doesn't change without restart.
|
||||
func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) {
|
||||
// Cache for 1 hour - model list is static during runtime
|
||||
w.Header().Set("Cache-Control", "public, max-age=3600")
|
||||
|
||||
models := embedding.ListModels()
|
||||
defaultModel := embedding.GetDefaultModel()
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"models": models,
|
||||
"default": defaultModel,
|
||||
"current": s.embedSvc.Version(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetStats returns worker statistics.
|
||||
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
retrievalStats := s.GetRetrievalStats(project)
|
||||
sessionsToday, _ := s.sessionStore.GetSessionsToday(r.Context())
|
||||
|
||||
response := map[string]any{
|
||||
"uptime": time.Since(s.startTime).String(),
|
||||
"uptimeSeconds": time.Since(s.startTime).Seconds(),
|
||||
"activeSessions": s.sessionManager.GetActiveSessionCount(),
|
||||
"queueDepth": s.sessionManager.GetTotalQueueDepth(),
|
||||
"isProcessing": s.sessionManager.IsAnySessionProcessing(),
|
||||
"connectedClients": s.sseBroadcaster.ClientCount(),
|
||||
"sessionsToday": sessionsToday,
|
||||
"retrieval": retrievalStats,
|
||||
"ready": s.ready.Load(),
|
||||
}
|
||||
|
||||
// Add memory stats
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
response["memory"] = map[string]any{
|
||||
"alloc_mb": float64(memStats.Alloc) / 1024 / 1024,
|
||||
"total_alloc_mb": float64(memStats.TotalAlloc) / 1024 / 1024,
|
||||
"sys_mb": float64(memStats.Sys) / 1024 / 1024,
|
||||
"heap_alloc_mb": float64(memStats.HeapAlloc) / 1024 / 1024,
|
||||
"heap_inuse_mb": float64(memStats.HeapInuse) / 1024 / 1024,
|
||||
"heap_objects": memStats.HeapObjects,
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"gc_cycles": memStats.NumGC,
|
||||
"gc_pause_total_ms": float64(memStats.PauseTotalNs) / 1e6,
|
||||
}
|
||||
|
||||
// Add database health if available
|
||||
if s.store != nil {
|
||||
dbHealth := s.store.HealthCheck(r.Context())
|
||||
response["database"] = map[string]any{
|
||||
"status": dbHealth.Status,
|
||||
"query_latency_ms": float64(dbHealth.QueryLatency) / 1e6,
|
||||
"pool": dbHealth.PoolStats,
|
||||
"warning": dbHealth.Warning,
|
||||
}
|
||||
}
|
||||
|
||||
// Add embedding model info
|
||||
if s.embedSvc != nil {
|
||||
response["embeddingModel"] = map[string]any{
|
||||
"name": s.embedSvc.Name(),
|
||||
"version": s.embedSvc.Version(),
|
||||
"dimensions": s.embedSvc.Dimensions(),
|
||||
}
|
||||
}
|
||||
|
||||
// Add vector cache stats
|
||||
if s.vectorClient != nil {
|
||||
if count, err := s.vectorClient.Count(r.Context()); err == nil {
|
||||
response["vectorCount"] = count
|
||||
}
|
||||
cacheSize, cacheMax := s.vectorClient.CacheStats()
|
||||
response["vectorCache"] = map[string]any{
|
||||
"size": cacheSize,
|
||||
"max_size": cacheMax,
|
||||
}
|
||||
}
|
||||
|
||||
// Include project-specific observation count if project is specified
|
||||
if project != "" {
|
||||
count, err := s.getCachedObservationCount(r.Context(), project)
|
||||
if err == nil {
|
||||
response["projectObservations"] = count
|
||||
response["project"] = project
|
||||
}
|
||||
}
|
||||
|
||||
// Add rate limiter stats
|
||||
if s.rateLimiter != nil {
|
||||
response["rateLimiter"] = s.rateLimiter.Stats()
|
||||
}
|
||||
|
||||
// Add circuit breaker metrics
|
||||
if s.processor != nil {
|
||||
response["circuitBreaker"] = s.processor.CircuitBreakerMetrics()
|
||||
}
|
||||
|
||||
writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleGetRetrievalStats returns detailed retrieval statistics.
|
||||
func (s *Service) handleGetRetrievalStats(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
stats := s.GetRetrievalStats(project)
|
||||
writeJSON(w, stats)
|
||||
}
|
||||
|
||||
// handleGetRecentQueries returns recent search queries for analytics.
|
||||
func (s *Service) handleGetRecentQueries(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
limit := gorm.ParseLimitParam(r, 20)
|
||||
|
||||
queries := s.getRecentSearchQueries(project, limit)
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"queries": queries,
|
||||
"count": len(queries),
|
||||
"project": project,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetSearchAnalytics returns comprehensive search analytics and statistics.
|
||||
func (s *Service) handleGetSearchAnalytics(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
// Get all recent queries for analysis
|
||||
queries := s.getRecentSearchQueries(project, maxRecentQueries)
|
||||
|
||||
// Calculate analytics
|
||||
totalQueries := len(queries)
|
||||
vectorSearches := 0
|
||||
totalResults := 0
|
||||
zeroResultQueries := 0
|
||||
queryTypes := make(map[string]int)
|
||||
topKeywords := make(map[string]int)
|
||||
|
||||
for _, q := range queries {
|
||||
if q.UsedVector {
|
||||
vectorSearches++
|
||||
}
|
||||
totalResults += q.Results
|
||||
if q.Results == 0 {
|
||||
zeroResultQueries++
|
||||
}
|
||||
queryTypes[q.Type]++
|
||||
|
||||
// Extract keywords (simple word tokenization using iterator)
|
||||
for word := range strings.FieldsSeq(strings.ToLower(q.Query)) {
|
||||
if len(word) > 3 { // Skip short words
|
||||
topKeywords[word]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort keywords by frequency
|
||||
type keywordCount struct {
|
||||
Keyword string `json:"keyword"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
sortedKeywords := make([]keywordCount, 0, len(topKeywords))
|
||||
for kw, count := range topKeywords {
|
||||
sortedKeywords = append(sortedKeywords, keywordCount{Keyword: kw, Count: count})
|
||||
}
|
||||
sort.Slice(sortedKeywords, func(i, j int) bool {
|
||||
return sortedKeywords[i].Count > sortedKeywords[j].Count
|
||||
})
|
||||
if len(sortedKeywords) > 10 {
|
||||
sortedKeywords = sortedKeywords[:10]
|
||||
}
|
||||
|
||||
// Calculate averages
|
||||
avgResults := float64(0)
|
||||
vectorSearchRate := float64(0)
|
||||
zeroResultRate := float64(0)
|
||||
if totalQueries > 0 {
|
||||
avgResults = float64(totalResults) / float64(totalQueries)
|
||||
vectorSearchRate = float64(vectorSearches) / float64(totalQueries) * 100
|
||||
zeroResultRate = float64(zeroResultQueries) / float64(totalQueries) * 100
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"total_queries": totalQueries,
|
||||
"vector_search_rate": vectorSearchRate,
|
||||
"avg_results": avgResults,
|
||||
"zero_result_rate": zeroResultRate,
|
||||
"query_types": queryTypes,
|
||||
"top_keywords": sortedKeywords,
|
||||
"project": project,
|
||||
})
|
||||
}
|
||||
|
||||
// handleVectorHealth returns comprehensive health information about the vector database.
|
||||
func (s *Service) handleVectorHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if s.vectorClient == nil {
|
||||
http.Error(w, "vector client not initialized", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := s.vectorClient.GetHealthStats(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Add additional computed metrics
|
||||
healthScore := 100.0
|
||||
var warnings []string
|
||||
|
||||
// Penalize for stale vectors
|
||||
if stats.TotalVectors > 0 {
|
||||
staleRatio := float64(stats.StaleVectors) / float64(stats.TotalVectors)
|
||||
if staleRatio > 0 {
|
||||
healthScore -= staleRatio * 50 // Up to 50 points off for stale vectors
|
||||
warnings = append(warnings, formatWarning("%.1f%% vectors need rebuild", staleRatio*100))
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache effectiveness
|
||||
cacheHitRate := stats.EmbeddingCache.HitRate()
|
||||
if cacheHitRate < 20 && (stats.EmbeddingCache.EmbeddingHits+stats.EmbeddingCache.EmbeddingMisses) > 100 {
|
||||
healthScore -= 10
|
||||
warnings = append(warnings, formatWarning("Low cache hit rate: %.1f%%", cacheHitRate))
|
||||
}
|
||||
|
||||
// Penalize if rebuild is needed
|
||||
if stats.NeedsRebuild {
|
||||
healthScore -= 20
|
||||
warnings = append(warnings, "Vector rebuild recommended: "+stats.RebuildReason)
|
||||
}
|
||||
|
||||
if healthScore < 0 {
|
||||
healthScore = 0
|
||||
}
|
||||
|
||||
status := "healthy"
|
||||
if healthScore < 50 {
|
||||
status = "unhealthy"
|
||||
} else if healthScore < 80 {
|
||||
status = "degraded"
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"status": status,
|
||||
"health_score": healthScore,
|
||||
"warnings": warnings,
|
||||
"stats": stats,
|
||||
"cache_hit_rate": cacheHitRate,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateObservationRequest is the request body for updating an observation.
// Every field is optional: a nil pointer or nil slice means the field was not
// supplied in the JSON body.
type UpdateObservationRequest struct {
	// Title, when set, carries the new title.
	Title *string `json:"title,omitempty"`
	// Subtitle, when set, carries the new subtitle.
	Subtitle *string `json:"subtitle,omitempty"`
	// Narrative, when set, carries the new narrative text.
	Narrative *string `json:"narrative,omitempty"`
	// Facts, when non-nil, carries the new list of facts.
	Facts []string `json:"facts,omitempty"`
	// Concepts, when non-nil, carries the new list of concepts.
	Concepts []string `json:"concepts,omitempty"`
	// FilesRead, when non-nil, carries the new list of files read.
	FilesRead []string `json:"files_read,omitempty"`
	// FilesModified, when non-nil, carries the new list of files modified.
	FilesModified []string `json:"files_modified,omitempty"`
	// Scope, when set, carries the new scope value.
	Scope *string `json:"scope,omitempty"`
}
|
||||
|
||||
// handleUpdateObservation updates an existing observation.
|
||||
// PUT /api/observations/{id}
|
||||
func (s *Service) handleUpdateObservation(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse observation ID from URL
|
||||
id, ok := parseIDParam(w, r.PathValue("id"), "observation")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
var req UpdateObservationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Build update struct - only include fields that were provided
|
||||
update := &gorm.ObservationUpdate{}
|
||||
|
||||
if req.Title != nil {
|
||||
update.Title = req.Title
|
||||
}
|
||||
if req.Subtitle != nil {
|
||||
update.Subtitle = req.Subtitle
|
||||
}
|
||||
if req.Narrative != nil {
|
||||
update.Narrative = req.Narrative
|
||||
}
|
||||
if req.Facts != nil {
|
||||
update.Facts = &req.Facts
|
||||
}
|
||||
if req.Concepts != nil {
|
||||
update.Concepts = &req.Concepts
|
||||
}
|
||||
if req.FilesRead != nil {
|
||||
update.FilesRead = &req.FilesRead
|
||||
}
|
||||
if req.FilesModified != nil {
|
||||
update.FilesModified = &req.FilesModified
|
||||
}
|
||||
if req.Scope != nil {
|
||||
// Validate scope
|
||||
if *req.Scope != "project" && *req.Scope != "global" {
|
||||
http.Error(w, "scope must be 'project' or 'global'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
update.Scope = req.Scope
|
||||
}
|
||||
|
||||
// Update the observation
|
||||
updatedObs, err := s.observationStore.UpdateObservation(r.Context(), id, update)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "failed to update observation: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Trigger vector resync for the updated observation
|
||||
if s.vectorSync != nil {
|
||||
s.asyncVectorSync(func() {
|
||||
if err := s.vectorSync.SyncObservation(s.ctx, updatedObs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", id).Msg("Failed to resync observation vectors after update")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Broadcast update event
|
||||
s.sseBroadcaster.Broadcast(map[string]any{
|
||||
"type": "observation_updated",
|
||||
"id": id,
|
||||
})
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"observation": updatedObs,
|
||||
"message": "observation updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetObservationByID returns a single observation by ID.
|
||||
// GET /api/observations/{id}
|
||||
func (s *Service) handleGetObservationByID(w http.ResponseWriter, r *http.Request) {
|
||||
id, ok := parseIDParam(w, r.PathValue("id"), "observation")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
obs, err := s.observationStore.GetObservationByID(r.Context(), id)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get observation: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if obs == nil {
|
||||
http.Error(w, "observation not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, obs)
|
||||
}
|
||||
@@ -0,0 +1,680 @@
|
||||
// Package worker provides import, export, and archive HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/similarity"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// BulkImportRequest is the request body for bulk observation import.
|
||||
type BulkImportRequest struct {
|
||||
Project string `json:"project"`
|
||||
Observations []BulkObservationInput `json:"observations"`
|
||||
}
|
||||
|
||||
// BulkObservationInput describes one observation inside a bulk import batch.
type BulkObservationInput struct {
	// Type is the observation type: bugfix, feature, refactor, etc.
	Type          string   `json:"type"`
	Title         string   `json:"title"`
	Subtitle      string   `json:"subtitle,omitempty"`
	Facts         []string `json:"facts,omitempty"`
	Narrative     string   `json:"narrative,omitempty"`
	Concepts      []string `json:"concepts,omitempty"`
	FilesRead     []string `json:"files_read,omitempty"`
	FilesModified []string `json:"files_modified,omitempty"`
	// Scope is either "project" or "global".
	Scope string `json:"scope,omitempty"`
}
|
||||
|
||||
// BulkImportResponse summarises the outcome of a bulk import request.
type BulkImportResponse struct {
	// Imported counts observations that were stored.
	Imported int `json:"imported"`
	// Failed counts observations rejected by validation or by the store.
	Failed int `json:"failed"`
	// SkippedDuplicates counts observations dropped as duplicates of an
	// earlier observation in the same batch.
	SkippedDuplicates int `json:"skipped_duplicates,omitempty"`
	// Errors holds one message per failed observation.
	Errors []string `json:"errors,omitempty"`
}
|
||||
|
||||
// handleBulkImport handles bulk import of observations.
|
||||
// This is useful for migrating data or importing observations from external sources.
|
||||
func (s *Service) handleBulkImport(w http.ResponseWriter, r *http.Request) {
|
||||
// Rate limit bulk operations to prevent DoS
|
||||
if s.bulkOpLimiter != nil && !s.bulkOpLimiter.CanExecute() {
|
||||
remaining := s.bulkOpLimiter.CooldownRemaining()
|
||||
http.Error(w, fmt.Sprintf("bulk import rate limited, retry in %d seconds", remaining), http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
var req BulkImportRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Project == "" {
|
||||
http.Error(w, "project is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(req.Project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Observations) == 0 {
|
||||
http.Error(w, "at least one observation is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Limit batch size to prevent overwhelming the system
|
||||
maxBatchSize := 100
|
||||
if len(req.Observations) > maxBatchSize {
|
||||
http.Error(w, fmt.Sprintf("batch size exceeds maximum of %d", maxBatchSize), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a synthetic session for bulk import
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), fmt.Sprintf("bulk-import-%d", time.Now().UnixMilli()), req.Project, "bulk import")
|
||||
if err != nil {
|
||||
http.Error(w, "failed to create import session: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var imported, failed, skippedDupes int
|
||||
var errors []string
|
||||
|
||||
// Track imported observations for deduplication within the batch
|
||||
importedObs := make([]*models.Observation, 0, len(req.Observations))
|
||||
|
||||
// Deduplication threshold - observations more similar than this are considered duplicates
|
||||
const dedupThreshold = 0.7
|
||||
|
||||
for i, obsInput := range req.Observations {
|
||||
// Validate observation type using O(1) map lookup
|
||||
if !IsValidObservationType(obsInput.Type) {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("observation %d: invalid type '%s'", i, obsInput.Type))
|
||||
continue
|
||||
}
|
||||
|
||||
// Build parsed observation
|
||||
parsedObs := &models.ParsedObservation{
|
||||
Type: models.ObservationType(obsInput.Type),
|
||||
Title: obsInput.Title,
|
||||
Subtitle: obsInput.Subtitle,
|
||||
Facts: obsInput.Facts,
|
||||
Narrative: obsInput.Narrative,
|
||||
Concepts: obsInput.Concepts,
|
||||
FilesRead: obsInput.FilesRead,
|
||||
FilesModified: obsInput.FilesModified,
|
||||
Scope: models.ObservationScope(obsInput.Scope),
|
||||
}
|
||||
|
||||
// Convert to temporary observation for similarity check
|
||||
tempObs := &models.Observation{
|
||||
Title: sql.NullString{String: parsedObs.Title, Valid: parsedObs.Title != ""},
|
||||
Subtitle: sql.NullString{String: parsedObs.Subtitle, Valid: parsedObs.Subtitle != ""},
|
||||
Narrative: sql.NullString{String: parsedObs.Narrative, Valid: parsedObs.Narrative != ""},
|
||||
}
|
||||
|
||||
// Check for duplicates within this import batch
|
||||
if similarity.IsSimilarToAny(tempObs, importedObs, dedupThreshold) {
|
||||
skippedDupes++
|
||||
continue
|
||||
}
|
||||
|
||||
// Store observation
|
||||
obsID, _, err := s.observationStore.StoreObservation(
|
||||
r.Context(),
|
||||
fmt.Sprintf("bulk-import-%d", sessionID),
|
||||
req.Project,
|
||||
parsedObs,
|
||||
0, // prompt number
|
||||
0, // discovery tokens
|
||||
)
|
||||
if err != nil {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("observation %d: %v", i, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Sync to vector DB asynchronously with rate limiting
|
||||
if s.vectorSync != nil {
|
||||
s.asyncVectorSync(func() {
|
||||
// Use service context as parent to respect shutdown signals
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
obs, err := s.observationStore.GetObservationByID(ctx, obsID)
|
||||
if err == nil && obs != nil {
|
||||
if syncErr := s.vectorSync.SyncObservation(ctx, obs); syncErr != nil {
|
||||
if s.ctx.Err() == nil { // Don't log during shutdown
|
||||
log.Debug().Err(syncErr).Int64("id", obsID).Msg("Failed to sync observation during bulk import")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Track for deduplication of subsequent observations in this batch
|
||||
importedObs = append(importedObs, tempObs)
|
||||
imported++
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("project", req.Project).
|
||||
Int("imported", imported).
|
||||
Int("failed", failed).
|
||||
Int("skipped_duplicates", skippedDupes).
|
||||
Msg("Bulk import completed")
|
||||
|
||||
// Invalidate observation count cache after import
|
||||
if imported > 0 {
|
||||
if req.Project != "" {
|
||||
s.invalidateObsCountCache(req.Project)
|
||||
} else {
|
||||
s.invalidateAllObsCountCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast observation event for dashboard refresh
|
||||
s.sseBroadcaster.Broadcast(map[string]any{
|
||||
"type": "observation",
|
||||
"action": "bulk_import",
|
||||
"project": req.Project,
|
||||
"count": imported,
|
||||
})
|
||||
|
||||
writeJSON(w, BulkImportResponse{
|
||||
Imported: imported,
|
||||
Failed: failed,
|
||||
SkippedDuplicates: skippedDupes,
|
||||
Errors: errors,
|
||||
})
|
||||
}
|
||||
|
||||
// ArchiveRequest is the JSON payload for archiving observations, either by
// explicit ID list or by project/age.
type ArchiveRequest struct {
	// IDs lists specific observations to archive.
	IDs []int64 `json:"ids,omitempty"`
	// Project selects the project whose observations older than MaxAgeDays
	// are archived.
	Project string `json:"project,omitempty"`
	// MaxAgeDays is only used together with Project.
	MaxAgeDays int `json:"max_age_days,omitempty"`
	// Reason is an optional note recorded with the archival.
	Reason string `json:"reason,omitempty"`
}
|
||||
|
||||
// handleArchiveObservations archives observations by ID or by age.
|
||||
// Supports batch archival with error tracking per observation.
|
||||
func (s *Service) handleArchiveObservations(w http.ResponseWriter, r *http.Request) {
|
||||
var req ArchiveRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var archivedIDs []int64
|
||||
var failedIDs []int64
|
||||
var errors []string
|
||||
var err error
|
||||
|
||||
if len(req.IDs) > 0 {
|
||||
// Archive specific observations with parallel processing for large batches
|
||||
if len(req.IDs) > 5 {
|
||||
// Use parallel archival for batches larger than 5
|
||||
type archiveResult struct {
|
||||
id int64
|
||||
err error
|
||||
}
|
||||
results := make(chan archiveResult, len(req.IDs))
|
||||
|
||||
// Limit concurrency to avoid overwhelming the database
|
||||
sem := make(chan struct{}, 5)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, id := range req.IDs {
|
||||
wg.Add(1)
|
||||
go func(obsID int64) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // Acquire
|
||||
defer func() { <-sem }() // Release
|
||||
|
||||
archErr := s.observationStore.ArchiveObservation(r.Context(), obsID, req.Reason)
|
||||
results <- archiveResult{id: obsID, err: archErr}
|
||||
}(id)
|
||||
}
|
||||
|
||||
// Close results channel when all goroutines complete
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
// Collect results
|
||||
for res := range results {
|
||||
if res.err != nil {
|
||||
log.Warn().Err(res.err).Int64("id", res.id).Msg("Failed to archive observation")
|
||||
failedIDs = append(failedIDs, res.id)
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", res.id, res.err))
|
||||
} else {
|
||||
archivedIDs = append(archivedIDs, res.id)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Sequential for small batches
|
||||
for _, id := range req.IDs {
|
||||
if archErr := s.observationStore.ArchiveObservation(r.Context(), id, req.Reason); archErr != nil {
|
||||
log.Warn().Err(archErr).Int64("id", id).Msg("Failed to archive observation")
|
||||
failedIDs = append(failedIDs, id)
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", id, archErr))
|
||||
} else {
|
||||
archivedIDs = append(archivedIDs, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if req.Project != "" || req.MaxAgeDays > 0 {
|
||||
// Archive by age
|
||||
archivedIDs, err = s.observationStore.ArchiveOldObservations(r.Context(), req.Project, req.MaxAgeDays, req.Reason)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to archive: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "either 'ids' or 'project'/'max_age_days' is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("project", req.Project).
|
||||
Int("archived", len(archivedIDs)).
|
||||
Int("failed", len(failedIDs)).
|
||||
Msg("Observations archived")
|
||||
|
||||
// Invalidate cache if any observations were archived
|
||||
if len(archivedIDs) > 0 {
|
||||
if req.Project != "" {
|
||||
s.invalidateObsCountCache(req.Project)
|
||||
} else {
|
||||
s.invalidateAllObsCountCache()
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"archived_count": len(archivedIDs),
|
||||
"archived_ids": archivedIDs,
|
||||
}
|
||||
if len(failedIDs) > 0 {
|
||||
response["failed_count"] = len(failedIDs)
|
||||
response["failed_ids"] = failedIDs
|
||||
response["errors"] = errors
|
||||
}
|
||||
|
||||
writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleUnarchiveObservation restores an archived observation.
|
||||
func (s *Service) handleUnarchiveObservation(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid observation id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.observationStore.UnarchiveObservation(r.Context(), id); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate all caches since we don't know the project
|
||||
s.invalidateAllObsCountCache()
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"id": id,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetArchivedObservations returns archived observations.
|
||||
func (s *Service) handleGetArchivedObservations(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
limit := gorm.ParseLimitParam(r, DefaultObservationsLimit)
|
||||
|
||||
observations, err := s.observationStore.GetArchivedObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if observations == nil {
|
||||
observations = []*models.Observation{}
|
||||
}
|
||||
|
||||
writeJSON(w, observations)
|
||||
}
|
||||
|
||||
// handleGetArchivalStats returns archival statistics.
|
||||
func (s *Service) handleGetArchivalStats(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
stats, err := s.observationStore.GetArchivalStats(r.Context(), project)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, stats)
|
||||
}
|
||||
|
||||
// handleExportObservations exports observations in JSON or CSV format.
|
||||
// Supports query parameters: project, format (json/csv), scope, type, limit.
|
||||
func (s *Service) handleExportObservations(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
format := r.URL.Query().Get("format")
|
||||
if format == "" {
|
||||
format = "json"
|
||||
}
|
||||
scope := r.URL.Query().Get("scope") // project, global, or empty for all
|
||||
obsType := r.URL.Query().Get("type") // bugfix, feature, etc.
|
||||
limit := gorm.ParseLimitParamWithMax(r, 1000, 5000) // Higher limit for exports, capped at 5000
|
||||
|
||||
// Validate format
|
||||
if format != "json" && format != "csv" {
|
||||
http.Error(w, "format must be 'json' or 'csv'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get observations with filters
|
||||
ctx := r.Context()
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
|
||||
if project != "" {
|
||||
observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0)
|
||||
} else {
|
||||
observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply additional filters
|
||||
if scope != "" || obsType != "" {
|
||||
filtered := make([]*models.Observation, 0, len(observations))
|
||||
for _, obs := range observations {
|
||||
if scope != "" && string(obs.Scope) != scope {
|
||||
continue
|
||||
}
|
||||
if obsType != "" && string(obs.Type) != obsType {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, obs)
|
||||
}
|
||||
observations = filtered
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
filename := fmt.Sprintf("observations-%s.%s", timestamp, format)
|
||||
if project != "" {
|
||||
// Sanitize project name for filename
|
||||
sanitized := strings.ReplaceAll(project, "/", "_")
|
||||
sanitized = strings.ReplaceAll(sanitized, "\\", "_")
|
||||
if len(sanitized) > 50 {
|
||||
sanitized = sanitized[:50]
|
||||
}
|
||||
filename = fmt.Sprintf("observations-%s-%s.%s", sanitized, timestamp, format)
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "csv":
|
||||
w.Header().Set("Content-Type", "text/csv")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
|
||||
s.writeObservationsCSV(w, observations)
|
||||
default: // json
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
|
||||
writeJSON(w, map[string]any{
|
||||
"exported_at": time.Now().Format(time.RFC3339),
|
||||
"project": project,
|
||||
"count": len(observations),
|
||||
"observations": observations,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// writeObservationsCSV writes observations in CSV format.
|
||||
// Uses fmt.Fprintf directly to avoid intermediate string allocations.
|
||||
func (s *Service) writeObservationsCSV(w http.ResponseWriter, observations []*models.Observation) {
|
||||
// Write CSV header
|
||||
_, _ = io.WriteString(w, "id,type,scope,project,title,subtitle,narrative,concepts,facts,created_at,importance_score\n")
|
||||
|
||||
for _, obs := range observations {
|
||||
// Write directly to avoid string allocation per row
|
||||
_, _ = fmt.Fprintf(w, "%d,%s,%s,%s,%s,%s,%s,%s,%s,%s,%.2f\n",
|
||||
obs.ID,
|
||||
obs.Type,
|
||||
obs.Scope,
|
||||
escapeCsvField(obs.Project),
|
||||
escapeCsvField(obs.Title.String),
|
||||
escapeCsvField(obs.Subtitle.String),
|
||||
escapeCsvField(obs.Narrative.String),
|
||||
escapeCsvField(strings.Join(obs.Concepts, ";")),
|
||||
escapeCsvField(strings.Join(obs.Facts, ";")),
|
||||
obs.CreatedAt,
|
||||
obs.ImportanceScore,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// escapeCsvField returns s in a form that is safe to place in a single CSV
// column. Values without a comma, double quote, line feed or carriage return
// are returned unchanged; any other value is wrapped in double quotes with
// each embedded double quote doubled.
func escapeCsvField(s string) string {
	if strings.IndexAny(s, ",\"\n\r") < 0 {
		return s
	}

	var quoted strings.Builder
	quoted.Grow(len(s) + 2)
	quoted.WriteByte('"')
	for i := 0; i < len(s); i++ {
		if s[i] == '"' {
			// Double the quote character.
			quoted.WriteByte('"')
		}
		quoted.WriteByte(s[i])
	}
	quoted.WriteByte('"')
	return quoted.String()
}
|
||||
|
||||
// BulkStatusRequest is the JSON payload for applying one status action to
// several observations at once.
type BulkStatusRequest struct {
	// IDs lists the observations to update.
	IDs []int64 `json:"ids"`
	// Action is one of "supersede", "archive" or "set_feedback".
	Action string `json:"action"`
	// Reason is an optional note; it is used by the "archive" action.
	Reason string `json:"reason,omitempty"`
	// Feedback is -1, 0 or 1 and is used by the "set_feedback" action.
	Feedback int `json:"feedback,omitempty"`
}
|
||||
|
||||
// handleBulkStatusUpdate updates status for multiple observations in one request.
|
||||
func (s *Service) handleBulkStatusUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
var req BulkStatusRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.IDs) == 0 {
|
||||
http.Error(w, "ids is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.IDs) > 500 {
|
||||
http.Error(w, "maximum 500 ids per request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
var updated, failed int
|
||||
var errors []string
|
||||
|
||||
switch req.Action {
|
||||
case "supersede":
|
||||
for _, id := range req.IDs {
|
||||
if err := s.observationStore.MarkAsSuperseded(ctx, id); err != nil {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
|
||||
} else {
|
||||
updated++
|
||||
}
|
||||
}
|
||||
|
||||
case "archive":
|
||||
for _, id := range req.IDs {
|
||||
if err := s.observationStore.ArchiveObservation(ctx, id, req.Reason); err != nil {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
|
||||
} else {
|
||||
updated++
|
||||
}
|
||||
}
|
||||
|
||||
case "set_feedback":
|
||||
if req.Feedback < -1 || req.Feedback > 1 {
|
||||
http.Error(w, "feedback must be -1, 0, or 1", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
for _, id := range req.IDs {
|
||||
if err := s.observationStore.UpdateObservationFeedback(ctx, id, req.Feedback); err != nil {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
|
||||
} else {
|
||||
updated++
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
http.Error(w, "action must be 'supersede', 'archive', or 'set_feedback'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate cache for archive action (affects observation counts)
|
||||
if req.Action == "archive" && updated > 0 {
|
||||
// No project info available, invalidate all caches
|
||||
s.invalidateAllObsCountCache()
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"action": req.Action,
|
||||
"updated": updated,
|
||||
"failed": failed,
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
response["errors"] = errors
|
||||
}
|
||||
|
||||
writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleFindDuplicates finds potential duplicate observations using similarity clustering.
|
||||
// Returns groups of similar observations that may be candidates for merging or archival.
|
||||
func (s *Service) handleFindDuplicates(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
thresholdStr := r.URL.Query().Get("threshold")
|
||||
limit := gorm.ParseLimitParam(r, 100)
|
||||
|
||||
// Parse threshold (default 0.6 = 60% similarity)
|
||||
threshold := 0.6
|
||||
if thresholdStr != "" {
|
||||
if t, err := strconv.ParseFloat(thresholdStr, 64); err == nil && t > 0 && t < 1 {
|
||||
threshold = t
|
||||
}
|
||||
}
|
||||
|
||||
// Get recent observations
|
||||
ctx := r.Context()
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
|
||||
if project != "" {
|
||||
observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0)
|
||||
} else {
|
||||
observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if len(observations) < 2 {
|
||||
writeJSON(w, map[string]any{
|
||||
"duplicate_groups": []any{},
|
||||
"total_checked": len(observations),
|
||||
"threshold": threshold,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Find duplicates using similarity comparison
|
||||
type duplicateGroup struct {
|
||||
Observations []map[string]any `json:"observations"`
|
||||
Similarity float64 `json:"similarity"`
|
||||
}
|
||||
|
||||
groups := []duplicateGroup{}
|
||||
processed := make(map[int64]bool)
|
||||
|
||||
for i, obs1 := range observations {
|
||||
if processed[obs1.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
terms1 := similarity.ExtractObservationTerms(obs1)
|
||||
if len(terms1) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
group := duplicateGroup{
|
||||
Observations: []map[string]any{obs1.ToMap()},
|
||||
Similarity: 1.0,
|
||||
}
|
||||
|
||||
for j := i + 1; j < len(observations); j++ {
|
||||
obs2 := observations[j]
|
||||
if processed[obs2.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
terms2 := similarity.ExtractObservationTerms(obs2)
|
||||
sim := similarity.JaccardSimilarity(terms1, terms2)
|
||||
|
||||
if sim >= threshold {
|
||||
obsMap := obs2.ToMap()
|
||||
obsMap["similarity_to_first"] = sim
|
||||
group.Observations = append(group.Observations, obsMap)
|
||||
group.Similarity = min(group.Similarity, sim)
|
||||
processed[obs2.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(group.Observations) > 1 {
|
||||
processed[obs1.ID] = true
|
||||
groups = append(groups, group)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort groups by size (largest first)
|
||||
sort.Slice(groups, func(i, j int) bool {
|
||||
return len(groups[i].Observations) > len(groups[j].Observations)
|
||||
})
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"duplicate_groups": groups,
|
||||
"total_checked": len(observations),
|
||||
"groups_found": len(groups),
|
||||
"threshold": threshold,
|
||||
"project": project,
|
||||
})
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FeedbackRequest represents a user feedback submission.
|
||||
@@ -311,8 +312,7 @@ func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Requ
|
||||
// Run recalculation in background
|
||||
go func() {
|
||||
if err := recalculator.RecalculateNow(r.Context()); err != nil {
|
||||
// Log error but don't block response
|
||||
_ = err // Explicitly ignore - background operation
|
||||
log.Warn().Err(err).Msg("Background score recalculation failed")
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -345,14 +345,18 @@ func (s *Service) incrementRetrievalCounts(ids []int64) {
|
||||
}
|
||||
|
||||
// Increment in background to not block response
|
||||
// Use service context to respect shutdown signals
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
// Create a new context with timeout for the background operation
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer s.wg.Done()
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := store.IncrementRetrievalCount(ctx, ids); err != nil {
|
||||
// Log but don't fail - this is a background operation
|
||||
_ = err // Explicitly ignore - background operation
|
||||
if s.ctx.Err() == nil { // Don't log during shutdown
|
||||
log.Debug().Err(err).Msg("Failed to increment retrieval counts")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
// Package worker provides session-related HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// SessionInitRequest is the JSON payload sent to the session initialization
// endpoint.
type SessionInitRequest struct {
	// ClaudeSessionID identifies the Claude session the prompt belongs to.
	ClaudeSessionID string `json:"claudeSessionId"`
	// Project is the project the session is recorded under.
	Project string `json:"project"`
	// Prompt is the raw user prompt text.
	Prompt string `json:"prompt"`
	// MatchedObservations is the matched observation count saved with the prompt.
	MatchedObservations int `json:"matchedObservations"`
}
|
||||
|
||||
// SessionInitResponse is the JSON body returned by session initialization.
type SessionInitResponse struct {
	// SessionDBID is the database ID of the session.
	SessionDBID int64 `json:"sessionDbId"`
	// PromptNumber is the number assigned to the prompt within the session.
	PromptNumber int `json:"promptNumber"`
	// Skipped is true when the prompt was not processed.
	Skipped bool `json:"skipped,omitempty"`
	// Reason says why the prompt was skipped, e.g. "private".
	Reason string `json:"reason,omitempty"`
}
|
||||
|
||||
// DuplicatePromptWindowSeconds is how far back, in seconds, session
// initialization looks for an identical prompt in the same session. A match
// inside this window is treated as a repeated hook invocation rather than a
// new prompt.
const DuplicatePromptWindowSeconds = 10
|
||||
|
||||
// handleSessionInit handles session initialization from user-prompt hook.
|
||||
// This handler is idempotent - duplicate requests within a short time window
|
||||
// return the existing prompt data without creating duplicates.
|
||||
func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) {
|
||||
var req SessionInitRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Privacy check
|
||||
if privacy.IsEntirelyPrivate(req.Prompt) {
|
||||
// Create session but skip processing
|
||||
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
|
||||
promptNum, _ := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
|
||||
|
||||
writeJSON(w, SessionInitResponse{
|
||||
SessionDBID: sessionID,
|
||||
PromptNumber: promptNum,
|
||||
Skipped: true,
|
||||
Reason: "private",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Clean prompt
|
||||
cleanedPrompt := privacy.Clean(req.Prompt)
|
||||
|
||||
// DUPLICATE DETECTION: Check if this exact prompt was already saved recently.
|
||||
// This prevents the bug where the hook fires multiple times for the same user action,
|
||||
// creating many duplicate prompts with incrementing numbers.
|
||||
if existingID, existingNum, found := s.promptStore.FindRecentPromptByText(r.Context(), req.ClaudeSessionID, cleanedPrompt, DuplicatePromptWindowSeconds); found {
|
||||
// Get or create session (idempotent)
|
||||
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
|
||||
|
||||
log.Debug().
|
||||
Int64("sessionId", sessionID).
|
||||
Int("promptNumber", existingNum).
|
||||
Int64("promptId", existingID).
|
||||
Msg("Duplicate prompt detected - returning existing")
|
||||
|
||||
// Return existing prompt data without incrementing or saving again
|
||||
writeJSON(w, SessionInitResponse{
|
||||
SessionDBID: sessionID,
|
||||
PromptNumber: existingNum,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create session (idempotent)
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Increment prompt counter
|
||||
promptNum, err := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Save user prompt with matched observation count
|
||||
promptID, err := s.promptStore.SaveUserPromptWithMatches(r.Context(), req.ClaudeSessionID, promptNum, cleanedPrompt, req.MatchedObservations)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to save user prompt")
|
||||
// Non-fatal: continue with session initialization
|
||||
} else if s.vectorSync != nil {
|
||||
// Sync to vector DB asynchronously (non-blocking)
|
||||
now := time.Now()
|
||||
promptWithSession := &models.UserPromptWithSession{
|
||||
UserPrompt: models.UserPrompt{
|
||||
ID: promptID,
|
||||
ClaudeSessionID: req.ClaudeSessionID,
|
||||
PromptNumber: promptNum,
|
||||
PromptText: cleanedPrompt,
|
||||
MatchedObservations: req.MatchedObservations,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
Project: req.Project,
|
||||
SDKSessionID: req.ClaudeSessionID,
|
||||
}
|
||||
s.asyncVectorSync(func() {
|
||||
// Use service context as parent to respect shutdown signals
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
if err := s.vectorSync.SyncUserPrompt(ctx, promptWithSession); err != nil {
|
||||
if s.ctx.Err() == nil { // Don't log during shutdown
|
||||
log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to sqlite-vec")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int64("sessionId", sessionID).
|
||||
Int("promptNumber", promptNum).
|
||||
Str("project", req.Project).
|
||||
Msg("Session initialized")
|
||||
|
||||
// Broadcast prompt event for dashboard refresh
|
||||
s.sseBroadcaster.Broadcast(map[string]any{
|
||||
"type": "prompt",
|
||||
"action": "created",
|
||||
"project": req.Project,
|
||||
})
|
||||
|
||||
writeJSON(w, SessionInitResponse{
|
||||
SessionDBID: sessionID,
|
||||
PromptNumber: promptNum,
|
||||
})
|
||||
}
|
||||
|
||||
// SessionStartRequest is the JSON payload for starting an SDK agent session.
type SessionStartRequest struct {
	// UserPrompt is the prompt text handed to the session manager.
	UserPrompt string `json:"userPrompt"`
	// PromptNumber is the prompt's number within the session.
	PromptNumber int `json:"promptNumber"`
}
|
||||
|
||||
// handleSessionStart handles SDK agent session start.
|
||||
func (s *Service) handleSessionStart(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid session id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req SessionStartRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize session in manager
|
||||
sess, err := s.sessionManager.InitializeSession(r.Context(), id, req.UserPrompt, req.PromptNumber)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if sess == nil {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Session is now registered. Observations will be processed
|
||||
// asynchronously by the background queue processor (processQueue in service.go).
|
||||
log.Info().
|
||||
Int64("sessionId", id).
|
||||
Int("promptNumber", req.PromptNumber).
|
||||
Msg("SDK agent session initialized")
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// ObservationRequest is the JSON payload decoded by handleObservation.
//
// Note the mixed key casing: the session and project keys are camelCase while
// the tool fields use snake_case.
type ObservationRequest struct {
	// ClaudeSessionID is read from "claudeSessionId" and used to look up the session.
	ClaudeSessionID string `json:"claudeSessionId"`
	// Project is read from "project"; it is used when a session is created on the fly.
	Project string `json:"project"`
	// ToolName is read from "tool_name".
	ToolName string `json:"tool_name"`
	// ToolInput is read from "tool_input" and kept as arbitrary decoded JSON.
	ToolInput any `json:"tool_input"`
	// ToolResponse is read from "tool_response" and kept as arbitrary decoded JSON.
	ToolResponse any `json:"tool_response"`
	// CWD is read from "cwd".
	CWD string `json:"cwd"`
}
|
||||
|
||||
// handleObservation handles observation posting from post-tool-use hook.
|
||||
func (s *Service) handleObservation(w http.ResponseWriter, r *http.Request) {
|
||||
var req ObservationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Find session
|
||||
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if sess == nil {
|
||||
// Create session on-the-fly with project from request
|
||||
id, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sess, _ = s.sessionStore.GetSessionByID(r.Context(), id)
|
||||
}
|
||||
|
||||
// Queue observation
|
||||
if err := s.sessionManager.QueueObservation(r.Context(), sess.ID, session.ObservationData{
|
||||
ToolName: req.ToolName,
|
||||
ToolInput: req.ToolInput,
|
||||
ToolResponse: req.ToolResponse,
|
||||
CWD: req.CWD,
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// SubagentCompleteRequest is the JSON payload decoded by
// handleSubagentComplete.
type SubagentCompleteRequest struct {
	// ClaudeSessionID is read from "claudeSessionId" and used to look up the session.
	ClaudeSessionID string `json:"claudeSessionId"`
	// Project is read from "project".
	Project string `json:"project"`
}
|
||||
|
||||
// handleSubagentComplete handles subagent/Task completion notifications.
|
||||
// This triggers immediate processing of any queued observations from the subagent.
|
||||
func (s *Service) handleSubagentComplete(w http.ResponseWriter, r *http.Request) {
|
||||
var req SubagentCompleteRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Find session
|
||||
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
|
||||
if err != nil || sess == nil {
|
||||
// Session not found - subagent may have been in a different context
|
||||
log.Debug().
|
||||
Str("claudeSessionId", req.ClaudeSessionID).
|
||||
Msg("Subagent complete - no active session found")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Trigger immediate processing of queued observations
|
||||
messages := s.sessionManager.DrainMessages(sess.ID)
|
||||
if len(messages) > 0 && s.processor != nil {
|
||||
log.Info().
|
||||
Int64("sessionId", sess.ID).
|
||||
Int("messages", len(messages)).
|
||||
Msg("Processing queued observations from subagent")
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Type == session.MessageTypeObservation && msg.Observation != nil {
|
||||
err := s.processor.ProcessObservation(
|
||||
r.Context(),
|
||||
sess.SDKSessionID.String,
|
||||
sess.Project,
|
||||
msg.Observation.ToolName,
|
||||
msg.Observation.ToolInput,
|
||||
msg.Observation.ToolResponse,
|
||||
msg.Observation.PromptNumber,
|
||||
msg.Observation.CWD,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Str("tool", msg.Observation.ToolName).
|
||||
Msg("Failed to process subagent observation")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// handleGetSessionByClaudeID looks up a session by Claude session ID.
|
||||
func (s *Service) handleGetSessionByClaudeID(w http.ResponseWriter, r *http.Request) {
|
||||
claudeSessionID := r.URL.Query().Get("claudeSessionId")
|
||||
if claudeSessionID == "" {
|
||||
http.Error(w, "claudeSessionId required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := s.sessionStore.FindAnySDKSession(r.Context(), claudeSessionID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, session)
|
||||
}
|
||||
|
||||
// SummarizeRequest is the JSON payload decoded by handleSummarize.
type SummarizeRequest struct {
	// LastUserMessage is read from "lastUserMessage".
	LastUserMessage string `json:"lastUserMessage"`
	// LastAssistantMessage is read from "lastAssistantMessage".
	LastAssistantMessage string `json:"lastAssistantMessage"`
}
|
||||
|
||||
// handleSummarize handles summarize requests from stop hook.
|
||||
func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid session id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req SummarizeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Queue summarize request
|
||||
if err := s.sessionManager.QueueSummarize(r.Context(), id, req.LastUserMessage, req.LastAssistantMessage); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
@@ -0,0 +1,243 @@
|
||||
// Package worker provides update and restart HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handleUpdateCheck checks for available updates.
|
||||
func (s *Service) handleUpdateCheck(w http.ResponseWriter, r *http.Request) {
|
||||
info, err := s.updater.CheckForUpdate(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
writeJSON(w, info)
|
||||
}
|
||||
|
||||
// handleUpdateApply downloads and applies an available update.
|
||||
func (s *Service) handleUpdateApply(w http.ResponseWriter, r *http.Request) {
|
||||
// First check for update
|
||||
info, err := s.updater.CheckForUpdate(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !info.Available {
|
||||
writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"message": "No update available",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Apply update in background with tracking for graceful shutdown
|
||||
s.wg.Go(func() {
|
||||
if err := s.updater.ApplyUpdate(s.ctx, info); err != nil {
|
||||
log.Error().Err(err).Msg("Update failed")
|
||||
}
|
||||
})
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "Update started",
|
||||
"version": info.LatestVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// handleUpdateStatus returns the current update status.
|
||||
func (s *Service) handleUpdateStatus(w http.ResponseWriter, r *http.Request) {
|
||||
status := s.updater.GetStatus()
|
||||
writeJSON(w, status)
|
||||
}
|
||||
|
||||
// ComponentHealth is one entry in a self-check report: the health of a single
// component.
type ComponentHealth struct {
	// Name is the human-readable component name.
	Name string `json:"name"`
	// Status is one of "healthy", "degraded" or "unhealthy".
	Status string `json:"status"`
	// Message is optional detail; it is omitted from the JSON when empty.
	Message string `json:"message,omitempty"`
}
|
||||
|
||||
// SelfCheckResponse contains the health status of all components.
|
||||
type SelfCheckResponse struct {
|
||||
Overall string `json:"overall"` // "healthy", "degraded", "unhealthy"
|
||||
Version string `json:"version"`
|
||||
Uptime string `json:"uptime"`
|
||||
Components []ComponentHealth `json:"components"`
|
||||
}
|
||||
|
||||
// handleSelfCheck returns the health status of all components.
|
||||
func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) {
|
||||
components := []ComponentHealth{}
|
||||
overall := "healthy"
|
||||
|
||||
// Check Worker Service
|
||||
workerStatus := ComponentHealth{Name: "Worker Service", Status: "healthy"}
|
||||
if !s.ready.Load() {
|
||||
if err := s.GetInitError(); err != nil {
|
||||
workerStatus.Status = "unhealthy"
|
||||
workerStatus.Message = err.Error()
|
||||
overall = "unhealthy"
|
||||
} else {
|
||||
workerStatus.Status = "degraded"
|
||||
workerStatus.Message = "Initializing"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
}
|
||||
}
|
||||
components = append(components, workerStatus)
|
||||
|
||||
// Check SQLite Database
|
||||
dbStatus := ComponentHealth{Name: "SQLite Database", Status: "healthy"}
|
||||
if s.store == nil {
|
||||
dbStatus.Status = "unhealthy"
|
||||
dbStatus.Message = "Not initialized"
|
||||
overall = "unhealthy"
|
||||
} else if err := s.store.Ping(); err != nil {
|
||||
dbStatus.Status = "unhealthy"
|
||||
dbStatus.Message = err.Error()
|
||||
overall = "unhealthy"
|
||||
}
|
||||
components = append(components, dbStatus)
|
||||
|
||||
// Check Vector DB (sqlite-vec)
|
||||
vectorStatus := ComponentHealth{Name: "Vector DB", Status: "healthy"}
|
||||
if s.vectorClient == nil {
|
||||
vectorStatus.Status = "degraded"
|
||||
vectorStatus.Message = "Not configured"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else if !s.vectorClient.IsConnected() {
|
||||
vectorStatus.Status = "degraded"
|
||||
vectorStatus.Message = "Not connected"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
}
|
||||
components = append(components, vectorStatus)
|
||||
|
||||
// Check SDK Processor
|
||||
sdkStatus := ComponentHealth{Name: "SDK Processor", Status: "healthy"}
|
||||
if s.processor == nil {
|
||||
sdkStatus.Status = "degraded"
|
||||
sdkStatus.Message = "Not initialized"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else if !s.processor.IsAvailable() {
|
||||
sdkStatus.Status = "degraded"
|
||||
sdkStatus.Message = "Claude CLI not available"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
}
|
||||
components = append(components, sdkStatus)
|
||||
|
||||
// Check SSE Broadcaster
|
||||
sseStatus := ComponentHealth{Name: "SSE Broadcaster", Status: "healthy"}
|
||||
if s.sseBroadcaster == nil {
|
||||
sseStatus.Status = "unhealthy"
|
||||
sseStatus.Message = "Not initialized"
|
||||
overall = "unhealthy"
|
||||
}
|
||||
components = append(components, sseStatus)
|
||||
|
||||
// Check Cross-Encoder Reranker
|
||||
rerankerStatus := ComponentHealth{Name: "Cross-Encoder Reranker", Status: "healthy"}
|
||||
if !s.config.RerankingEnabled {
|
||||
rerankerStatus.Status = "degraded"
|
||||
rerankerStatus.Message = "Disabled in config"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else if s.reranker == nil {
|
||||
rerankerStatus.Status = "degraded"
|
||||
rerankerStatus.Message = "Not initialized"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else {
|
||||
// Verify reranker is functional using Score
|
||||
_, normalizedScore, err := s.reranker.Score("test query", "test document")
|
||||
if err != nil {
|
||||
rerankerStatus.Status = "unhealthy"
|
||||
rerankerStatus.Message = fmt.Sprintf("Score check failed: %v", err)
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else {
|
||||
rerankerStatus.Message = fmt.Sprintf("Score check passed (%.4f)", normalizedScore)
|
||||
}
|
||||
}
|
||||
components = append(components, rerankerStatus)
|
||||
|
||||
// Calculate uptime
|
||||
uptime := time.Since(s.startTime).Round(time.Second).String()
|
||||
|
||||
writeJSON(w, SelfCheckResponse{
|
||||
Overall: overall,
|
||||
Version: s.version,
|
||||
Uptime: uptime,
|
||||
Components: components,
|
||||
})
|
||||
}
|
||||
|
||||
// handleUpdateRestart restarts the worker with the new binary (after update).
|
||||
func (s *Service) handleUpdateRestart(w http.ResponseWriter, r *http.Request) {
|
||||
status := s.updater.GetStatus()
|
||||
if status.State != "done" {
|
||||
http.Error(w, "no update has been applied", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send response before restarting
|
||||
writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "Restarting worker...",
|
||||
})
|
||||
|
||||
// Flush the response
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// Restart in background after response is sent
|
||||
go func() {
|
||||
if err := s.updater.Restart(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to restart worker")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleRestart restarts the worker process (general restart, not tied to update).
|
||||
func (s *Service) handleRestart(w http.ResponseWriter, r *http.Request) {
|
||||
log.Info().Msg("Manual restart requested via API")
|
||||
|
||||
// Send response before restarting
|
||||
writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "Restarting worker...",
|
||||
"version": s.version,
|
||||
})
|
||||
|
||||
// Flush the response
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// Restart in background after response is sent
|
||||
go func() {
|
||||
// Small delay to ensure response is sent
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if err := s.updater.Restart(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to restart worker")
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,335 @@
|
||||
// Package worker provides the main worker service for claude-mnemonic.
|
||||
package worker
|
||||
|
||||
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"fmt"
	"net/http"
	"regexp"
	"strings"
	"sync"
	"time"
)
|
||||
|
||||
// requestIDKey is the context key type under which the request ID is stored.
// Using a dedicated unexported type means the key cannot collide with context
// keys defined by other packages.
type requestIDKey struct{}
|
||||
|
||||
// projectNamePattern is the character whitelist for project names: one or
// more ASCII letters, digits, underscores, dots, slashes or dashes. It allows
// "." and "/" on their own, so it does not reject ".." sequences by itself.
var projectNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_./-]+$`)
|
||||
|
||||
// allowedOrigins is the set of Origin header values that receive CORS
// headers. Lookups are exact string matches, so look-alike hosts such as
// "evil-localhost.com" and unlisted ports are not accepted.
var allowedOrigins = map[string]bool{
	// localhost
	"http://localhost":       true,
	"http://localhost:3000":  true,
	"http://localhost:5173":  true, // Vite dev server
	"http://localhost:37778": true, // Dashboard UI

	// 127.0.0.1 (same ports as above)
	"http://127.0.0.1":       true,
	"http://127.0.0.1:3000":  true,
	"http://127.0.0.1:5173":  true,
	"http://127.0.0.1:37778": true,
}
|
||||
|
||||
// SecurityHeaders middleware adds essential security headers to all responses.
|
||||
// These protect against common web vulnerabilities.
|
||||
func SecurityHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Prevent clickjacking
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
|
||||
// Prevent MIME type sniffing
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// Enable XSS filter
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Restrict referrer information
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Content Security Policy - restrict to self
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'")
|
||||
|
||||
// Permissions Policy - disable unnecessary features
|
||||
w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
|
||||
|
||||
// CORS: Use exact match whitelist to prevent bypass attacks
|
||||
origin := r.Header.Get("Origin")
|
||||
if allowedOrigins[origin] {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Auth-Token, Authorization, X-Request-ID")
|
||||
}
|
||||
|
||||
// Handle preflight requests
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// MaxBodySize returns middleware that caps request bodies at maxBytes.
//
// Requests whose declared Content-Length exceeds the cap are rejected up
// front with 413. All other requests have their body wrapped in
// http.MaxBytesReader, so a body that turns out to be larger than the cap
// fails when the downstream handler reads it.
func MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		limited := func(w http.ResponseWriter, r *http.Request) {
			// Fast path: the client already told us the body is too big.
			if declared := r.ContentLength; declared > maxBytes {
				http.Error(w, "request body too large", http.StatusRequestEntityTooLarge)
				return
			}

			// Enforce the cap on what is actually read.
			r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(limited)
	}
}
|
||||
|
||||
// TokenAuth provides simple token-based authentication for localhost services.
// The token is generated at construction time and must be supplied by clients
// in the X-Auth-Token header or as an "Authorization: Bearer <token>" header.
type TokenAuth struct {
	token   string
	enabled bool
	mu      sync.RWMutex

	// ExemptPaths are exact URL paths that skip authentication (e.g. health checks).
	ExemptPaths map[string]bool
}

// NewTokenAuth creates a TokenAuth whose exempt paths are "/health",
// "/api/health" and "/api/ready".
//
// When enabled is true, a token is generated from 32 bytes of crypto/rand
// output, hex-encoded; an error is returned if the random source fails. When
// enabled is false no token is generated and the middleware lets every
// request through (useful for development).
func NewTokenAuth(enabled bool) (*TokenAuth, error) {
	ta := &TokenAuth{
		enabled: enabled,
		ExemptPaths: map[string]bool{
			"/health":     true,
			"/api/health": true,
			"/api/ready":  true,
		},
	}

	if enabled {
		// Generate 32-byte random token
		tokenBytes := make([]byte, 32)
		if _, err := rand.Read(tokenBytes); err != nil {
			return nil, err
		}
		ta.token = hex.EncodeToString(tokenBytes)
	}

	return ta, nil
}

// Token returns the authentication token.
// It returns the empty string if authentication is disabled.
func (ta *TokenAuth) Token() string {
	ta.mu.RLock()
	defer ta.mu.RUnlock()
	return ta.token
}

// IsEnabled reports whether token authentication is enabled.
func (ta *TokenAuth) IsEnabled() bool {
	ta.mu.RLock()
	defer ta.mu.RUnlock()
	return ta.enabled
}

// Middleware returns HTTP middleware that enforces token authentication.
//
// Requests pass straight through when authentication is disabled or the
// request path is listed in ExemptPaths. Otherwise the token is taken from
// X-Auth-Token, falling back to a Bearer Authorization header, and the
// request is rejected with 401 unless it matches.
func (ta *TokenAuth) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Snapshot the configuration under the read lock.
		ta.mu.RLock()
		enabled := ta.enabled
		token := ta.token
		exempt := ta.ExemptPaths[r.URL.Path]
		ta.mu.RUnlock()

		// Skip auth if disabled or path is exempt
		if !enabled || exempt {
			next.ServeHTTP(w, r)
			return
		}

		// Check for token in header
		providedToken := r.Header.Get("X-Auth-Token")
		if providedToken == "" {
			// Also check Authorization header with Bearer scheme
			auth := r.Header.Get("Authorization")
			if bearer, found := strings.CutPrefix(auth, "Bearer "); found {
				providedToken = bearer
			}
		}

		// Compare in constant time so that response timing does not reveal
		// how many leading characters of a guessed token were correct.
		if subtle.ConstantTimeCompare([]byte(providedToken), []byte(token)) != 1 {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}

		next.ServeHTTP(w, r)
	})
}
|
||||
|
||||
// ExpensiveOperationLimiter provides stricter rate limiting for expensive operations.
|
||||
// It wraps the base per-client rate limiter with additional per-operation limits.
|
||||
type ExpensiveOperationLimiter struct {
|
||||
// Track last execution time per operation type
|
||||
lastRebuild int64 // Unix timestamp
|
||||
rebuildCooldown int64 // Minimum seconds between rebuilds
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewExpensiveOperationLimiter creates a limiter for expensive operations.
|
||||
func NewExpensiveOperationLimiter() *ExpensiveOperationLimiter {
|
||||
return &ExpensiveOperationLimiter{
|
||||
rebuildCooldown: 300, // 5 minutes between rebuilds
|
||||
}
|
||||
}
|
||||
|
||||
// CanRebuild checks if a vector rebuild operation is allowed.
|
||||
// Returns false if a rebuild was triggered too recently.
|
||||
func (eol *ExpensiveOperationLimiter) CanRebuild() bool {
|
||||
eol.mu.Lock()
|
||||
defer eol.mu.Unlock()
|
||||
|
||||
now := unixNow()
|
||||
if now-eol.lastRebuild < eol.rebuildCooldown {
|
||||
return false
|
||||
}
|
||||
eol.lastRebuild = now
|
||||
return true
|
||||
}
|
||||
|
||||
// unixNow returns the current time as a Unix timestamp in seconds.
// The limiters read the clock only through this function.
func unixNow() int64 {
	now := time.Now()
	return now.Unix()
}
|
||||
|
||||
// RequestID middleware adds a unique request ID to each request.
|
||||
// The ID is added to the context and response headers for tracing.
|
||||
func RequestID(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check for existing request ID from client
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
// Generate new request ID
|
||||
idBytes := make([]byte, 8)
|
||||
if _, err := rand.Read(idBytes); err == nil {
|
||||
requestID = hex.EncodeToString(idBytes)
|
||||
} else {
|
||||
requestID = fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
}
|
||||
|
||||
// Add to response header
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
// Add to context
|
||||
ctx := context.WithValue(r.Context(), requestIDKey{}, requestID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// GetRequestID retrieves the request ID from the context.
|
||||
func GetRequestID(ctx context.Context) string {
|
||||
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// RequireJSONContentType is middleware that rejects POST, PUT and PATCH
// requests whose Content-Type is set to something other than
// application/json, replying 415 Unsupported Media Type.
//
// Only the media type is compared: parameters after ";" (such as a charset)
// are ignored, and the comparison is case-insensitive because media type
// names are case-insensitive. An empty Content-Type is allowed so that
// requests without a body still pass. Other methods are never checked.
func RequireJSONContentType(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		// Only check methods that typically carry a body.
		case http.MethodPost, http.MethodPut, http.MethodPatch:
			if ct := r.Header.Get("Content-Type"); ct != "" {
				// Drop parameters, then match the bare media type exactly so
				// that look-alikes such as "application/jsonp" are rejected.
				mediaType, _, _ := strings.Cut(ct, ";")
				if !strings.EqualFold(strings.TrimSpace(mediaType), "application/json") {
					http.Error(w, "Content-Type must be application/json", http.StatusUnsupportedMediaType)
					return
				}
			}
		}
		next.ServeHTTP(w, r)
	})
}
|
||||
|
||||
// ValidateProjectName checks if a project name is safe to use.
|
||||
// Returns an error if the name contains path traversal or invalid characters.
|
||||
func ValidateProjectName(project string) error {
|
||||
if project == "" {
|
||||
return nil // Empty is allowed (means no filter)
|
||||
}
|
||||
|
||||
// Check for path traversal
|
||||
if strings.Contains(project, "..") {
|
||||
return fmt.Errorf("invalid project name: path traversal detected")
|
||||
}
|
||||
|
||||
// Check for valid characters
|
||||
if !projectNamePattern.MatchString(project) {
|
||||
return fmt.Errorf("invalid project name: only alphanumeric, underscore, dash, dot, and slash allowed")
|
||||
}
|
||||
|
||||
// Max length check
|
||||
if len(project) > 500 {
|
||||
return fmt.Errorf("project name too long (max 500 chars)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BulkOperationLimiter provides rate limiting for bulk operations.
|
||||
// Prevents DoS via repeated bulk requests.
|
||||
type BulkOperationLimiter struct {
|
||||
lastBulkOp int64 // Unix timestamp
|
||||
cooldown int64 // Minimum seconds between operations
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewBulkOperationLimiter creates a limiter for bulk operations.
|
||||
func NewBulkOperationLimiter(cooldownSeconds int64) *BulkOperationLimiter {
|
||||
return &BulkOperationLimiter{
|
||||
cooldown: cooldownSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
// CanExecute checks if a bulk operation is allowed.
|
||||
// Returns false if a bulk operation was triggered too recently.
|
||||
func (bol *BulkOperationLimiter) CanExecute() bool {
|
||||
bol.mu.Lock()
|
||||
defer bol.mu.Unlock()
|
||||
|
||||
now := unixNow()
|
||||
if now-bol.lastBulkOp < bol.cooldown {
|
||||
return false
|
||||
}
|
||||
bol.lastBulkOp = now
|
||||
return true
|
||||
}
|
||||
|
||||
// TimeSinceLastOp returns seconds since the last bulk operation.
|
||||
func (bol *BulkOperationLimiter) TimeSinceLastOp() int64 {
|
||||
bol.mu.Lock()
|
||||
defer bol.mu.Unlock()
|
||||
return unixNow() - bol.lastBulkOp
|
||||
}
|
||||
|
||||
// CooldownRemaining returns seconds remaining in the cooldown period.
|
||||
// Returns 0 if no cooldown is active.
|
||||
func (bol *BulkOperationLimiter) CooldownRemaining() int64 {
|
||||
bol.mu.Lock()
|
||||
defer bol.mu.Unlock()
|
||||
|
||||
remaining := bol.cooldown - (unixNow() - bol.lastBulkOp)
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
@@ -0,0 +1,515 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSecurityHeaders(t *testing.T) {
|
||||
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Check all security headers are set
|
||||
tests := []struct {
|
||||
header string
|
||||
expected string
|
||||
}{
|
||||
{"X-Frame-Options", "DENY"},
|
||||
{"X-Content-Type-Options", "nosniff"},
|
||||
{"X-XSS-Protection", "1; mode=block"},
|
||||
{"Referrer-Policy", "strict-origin-when-cross-origin"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := rr.Header().Get(tt.header); got != tt.expected {
|
||||
t.Errorf("SecurityHeaders() %s = %q, want %q", tt.header, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_CORS(t *testing.T) {
|
||||
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
expectCORS bool
|
||||
expectedOrigin string
|
||||
}{
|
||||
{
|
||||
name: "localhost:37778 origin allowed",
|
||||
origin: "http://localhost:37778",
|
||||
expectCORS: true,
|
||||
expectedOrigin: "http://localhost:37778",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1:5173 origin allowed",
|
||||
origin: "http://127.0.0.1:5173",
|
||||
expectCORS: true,
|
||||
expectedOrigin: "http://127.0.0.1:5173",
|
||||
},
|
||||
{
|
||||
name: "localhost without port allowed",
|
||||
origin: "http://localhost",
|
||||
expectCORS: true,
|
||||
expectedOrigin: "http://localhost",
|
||||
},
|
||||
{
|
||||
name: "external origin blocked",
|
||||
origin: "http://evil.com",
|
||||
expectCORS: false,
|
||||
},
|
||||
{
|
||||
name: "evil-localhost.com bypass attempt blocked",
|
||||
origin: "http://evil-localhost.com",
|
||||
expectCORS: false,
|
||||
},
|
||||
{
|
||||
name: "localhost subdomain bypass attempt blocked",
|
||||
origin: "http://localhost.evil.com",
|
||||
expectCORS: false,
|
||||
},
|
||||
{
|
||||
name: "unknown localhost port blocked",
|
||||
origin: "http://localhost:9999",
|
||||
expectCORS: false,
|
||||
},
|
||||
{
|
||||
name: "no origin header",
|
||||
origin: "",
|
||||
expectCORS: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if tt.origin != "" {
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
cors := rr.Header().Get("Access-Control-Allow-Origin")
|
||||
if tt.expectCORS {
|
||||
if cors != tt.expectedOrigin {
|
||||
t.Errorf("Expected CORS origin %q, got %q", tt.expectedOrigin, cors)
|
||||
}
|
||||
} else {
|
||||
if cors != "" {
|
||||
t.Errorf("Expected no CORS header, got %q", cors)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBodySize(t *testing.T) {
|
||||
maxSize := int64(100)
|
||||
handler := MaxBodySize(maxSize)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
contentLength int64
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "within limit",
|
||||
contentLength: 50,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "at limit",
|
||||
contentLength: 100,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "exceeds limit",
|
||||
contentLength: 150,
|
||||
expectedStatus: http.StatusRequestEntityTooLarge,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
req.ContentLength = tt.contentLength
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("MaxBodySize() status = %d, want %d", rr.Code, tt.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenAuth(t *testing.T) {
|
||||
t.Run("disabled auth allows all requests", func(t *testing.T) {
|
||||
ta, err := NewTokenAuth(false)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenAuth() error = %v", err)
|
||||
}
|
||||
|
||||
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected OK with disabled auth, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("enabled auth requires token", func(t *testing.T) {
|
||||
ta, err := NewTokenAuth(true)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenAuth() error = %v", err)
|
||||
}
|
||||
|
||||
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Without token
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected Unauthorized without token, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// With correct token in X-Auth-Token header
|
||||
req = httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Auth-Token", ta.Token())
|
||||
rr = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected OK with correct token, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// With correct token in Authorization header
|
||||
req = httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+ta.Token())
|
||||
rr = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected OK with Bearer token, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exempt paths skip auth", func(t *testing.T) {
|
||||
ta, err := NewTokenAuth(true)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenAuth() error = %v", err)
|
||||
}
|
||||
|
||||
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
exemptPaths := []string{"/health", "/api/health", "/api/ready"}
|
||||
for _, path := range exemptPaths {
|
||||
req := httptest.NewRequest("GET", path, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected OK for exempt path %s, got %d", path, rr.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpensiveOperationLimiter(t *testing.T) {
|
||||
limiter := NewExpensiveOperationLimiter()
|
||||
|
||||
// First rebuild should be allowed
|
||||
if !limiter.CanRebuild() {
|
||||
t.Error("First rebuild should be allowed")
|
||||
}
|
||||
|
||||
// Immediate second rebuild should be blocked
|
||||
if limiter.CanRebuild() {
|
||||
t.Error("Immediate second rebuild should be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
handler := RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request ID is in context
|
||||
id := GetRequestID(r.Context())
|
||||
if id == "" {
|
||||
t.Error("Request ID should be set in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
t.Run("generates new request ID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Header().Get("X-Request-ID") == "" {
|
||||
t.Error("X-Request-ID header should be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses existing request ID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Request-ID", "test-id-12345")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Header().Get("X-Request-ID") != "test-id-12345" {
|
||||
t.Errorf("Expected X-Request-ID to be test-id-12345, got %s", rr.Header().Get("X-Request-ID"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequireJSONContentType(t *testing.T) {
|
||||
handler := RequireJSONContentType(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
contentType string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "GET request without content-type",
|
||||
method: "GET",
|
||||
contentType: "",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST with application/json",
|
||||
method: "POST",
|
||||
contentType: "application/json",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST with application/json; charset=utf-8",
|
||||
method: "POST",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST without content-type (empty body)",
|
||||
method: "POST",
|
||||
contentType: "",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST with text/plain rejected",
|
||||
method: "POST",
|
||||
contentType: "text/plain",
|
||||
expectedStatus: http.StatusUnsupportedMediaType,
|
||||
},
|
||||
{
|
||||
name: "PUT with application/xml rejected",
|
||||
method: "PUT",
|
||||
contentType: "application/xml",
|
||||
expectedStatus: http.StatusUnsupportedMediaType,
|
||||
},
|
||||
{
|
||||
name: "PATCH with form-urlencoded rejected",
|
||||
method: "PATCH",
|
||||
contentType: "application/x-www-form-urlencoded",
|
||||
expectedStatus: http.StatusUnsupportedMediaType,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tt.method, "/test", nil)
|
||||
if tt.contentType != "" {
|
||||
req.Header.Set("Content-Type", tt.contentType)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProjectName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
project string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "empty project allowed",
|
||||
project: "",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "simple project name",
|
||||
project: "my-project",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "project with path",
|
||||
project: "org/my-project",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "project with underscore",
|
||||
project: "my_project_v2",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "project with dot",
|
||||
project: "my.project.name",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "path traversal attack",
|
||||
project: "../../../etc/passwd",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "hidden path traversal",
|
||||
project: "project/../../secret",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "shell injection attempt",
|
||||
project: "project; rm -rf /",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "backtick injection",
|
||||
project: "project`whoami`",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
project: "project$HOME",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "too long project name",
|
||||
project: string(make([]byte, 501)),
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateProjectName(tt.project)
|
||||
if tt.wantError && err == nil {
|
||||
t.Errorf("Expected error for project %q, got nil", tt.project)
|
||||
}
|
||||
if !tt.wantError && err != nil {
|
||||
t.Errorf("Unexpected error for project %q: %v", tt.project, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkOperationLimiter(t *testing.T) {
|
||||
limiter := NewBulkOperationLimiter(1) // 1 second cooldown for testing
|
||||
|
||||
// First operation should be allowed
|
||||
if !limiter.CanExecute() {
|
||||
t.Error("First bulk operation should be allowed")
|
||||
}
|
||||
|
||||
// Immediate second operation should be blocked
|
||||
if limiter.CanExecute() {
|
||||
t.Error("Immediate second bulk operation should be blocked")
|
||||
}
|
||||
|
||||
// Check cooldown remaining
|
||||
remaining := limiter.CooldownRemaining()
|
||||
if remaining <= 0 || remaining > 1 {
|
||||
t.Errorf("Expected cooldown remaining between 0-1 seconds, got %d", remaining)
|
||||
}
|
||||
|
||||
// Check time since last op
|
||||
since := limiter.TimeSinceLastOp()
|
||||
if since < 0 || since > 1 {
|
||||
t.Errorf("Expected time since last op between 0-1 seconds, got %d", since)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_CSP(t *testing.T) {
|
||||
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Check CSP header is set
|
||||
csp := rr.Header().Get("Content-Security-Policy")
|
||||
if csp == "" {
|
||||
t.Error("Content-Security-Policy header should be set")
|
||||
}
|
||||
if csp != "default-src 'self'" {
|
||||
t.Errorf("Expected CSP to be \"default-src 'self'\", got %q", csp)
|
||||
}
|
||||
|
||||
// Check Permissions-Policy header
|
||||
pp := rr.Header().Get("Permissions-Policy")
|
||||
if pp == "" {
|
||||
t.Error("Permissions-Policy header should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_Preflight(t *testing.T) {
|
||||
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("OPTIONS", "/test", nil)
|
||||
req.Header.Set("Origin", "http://localhost:3000")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// OPTIONS should return 204 No Content
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Errorf("Expected status 204 for OPTIONS, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// CORS headers should be set for allowed origin
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "http://localhost:3000" {
|
||||
t.Errorf("CORS origin should be set for allowed origin")
|
||||
}
|
||||
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
|
||||
t.Error("Access-Control-Allow-Methods should be set")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
// Package worker provides the main worker service for claude-mnemonic.
|
||||
package worker
|
||||
|
||||
import (
	"net"
	"net/http"
	"sync"
	"time"
)
|
||||
|
||||
// RateLimiter is a token-bucket rate limiter.
//
// The bucket holds at most burst tokens and is refilled at rate tokens per
// second. Each allowed request consumes one token. Methods lock mu, so a
// single RateLimiter may be shared between goroutines.
type RateLimiter struct {
	rate       float64    // refill speed in tokens per second
	burst      int        // bucket capacity
	mu         sync.Mutex // guards the fields below
	tokens     float64    // tokens currently in the bucket
	lastUpdate time.Time  // set at construction and on every Allow call
	requests   int64      // number of Allow calls seen
	rejected   int64      // number of Allow calls that returned false
}

// LastUpdateTime reports when the bucket was last refilled.
// It takes the limiter's lock for the duration of the read.
func (rl *RateLimiter) LastUpdateTime() time.Time {
	rl.mu.Lock()
	ts := rl.lastUpdateTimeUnlocked()
	rl.mu.Unlock()
	return ts
}

// lastUpdateTimeUnlocked is the lock-free variant of LastUpdateTime.
// The caller must already hold rl.mu.
func (rl *RateLimiter) lastUpdateTimeUnlocked() time.Time {
	return rl.lastUpdate
}

// NewRateLimiter returns a limiter that refills at rate tokens per second and
// holds at most burst tokens. The bucket starts full.
func NewRateLimiter(rate float64, burst int) *RateLimiter {
	rl := new(RateLimiter)
	rl.rate, rl.burst = rate, burst
	rl.tokens = float64(burst)
	rl.lastUpdate = time.Now()
	return rl
}

// Allow reports whether one more request fits within the rate limit,
// consuming a token when it does. Every call is counted in the statistics.
func (rl *RateLimiter) Allow() bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	rl.requests++

	// Credit the tokens earned since the previous call, capped at burst.
	now := time.Now()
	earned := now.Sub(rl.lastUpdate).Seconds() * rl.rate
	rl.tokens = min(rl.tokens+earned, float64(rl.burst))
	rl.lastUpdate = now

	if rl.tokens >= 1 {
		rl.tokens--
		return true
	}

	rl.rejected++
	return false
}

// Stats returns a snapshot of the limiter's configuration and counters.
func (rl *RateLimiter) Stats() map[string]any {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Clamp the denominator so a limiter that saw no traffic reports 0.
	seen := max(float64(rl.requests), 1)

	stats := make(map[string]any, 6)
	stats["rate"] = rl.rate
	stats["burst"] = rl.burst
	stats["current_tokens"] = rl.tokens
	stats["total_requests"] = rl.requests
	stats["rejected"] = rl.rejected
	stats["rejection_rate"] = float64(rl.rejected) / seen
	return stats
}

// RateLimitMiddleware returns HTTP middleware that answers 429 Too Many
// Requests whenever limiter refuses a request. The one limiter is shared by
// every request that passes through the middleware.
func RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		guarded := func(w http.ResponseWriter, r *http.Request) {
			if limiter.Allow() {
				next.ServeHTTP(w, r)
				return
			}
			http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
		}
		return http.HandlerFunc(guarded)
	}
}
|
||||
|
||||
// PerClientRateLimiter implements per-client rate limiting.
|
||||
type PerClientRateLimiter struct {
|
||||
rate float64
|
||||
burst int
|
||||
clients map[string]*RateLimiter
|
||||
mu sync.Mutex
|
||||
// Cleanup settings
|
||||
cleanupInterval time.Duration
|
||||
maxIdleTime time.Duration
|
||||
lastCleanup time.Time
|
||||
}
|
||||
|
||||
// NewPerClientRateLimiter creates a new per-client rate limiter.
|
||||
func NewPerClientRateLimiter(rate float64, burst int) *PerClientRateLimiter {
|
||||
return &PerClientRateLimiter{
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
clients: make(map[string]*RateLimiter),
|
||||
cleanupInterval: 5 * time.Minute,
|
||||
maxIdleTime: 10 * time.Minute,
|
||||
lastCleanup: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// getLimiter returns a rate limiter for the given client key.
|
||||
func (pcrl *PerClientRateLimiter) getLimiter(key string) *RateLimiter {
|
||||
pcrl.mu.Lock()
|
||||
defer pcrl.mu.Unlock()
|
||||
|
||||
// Periodic cleanup of idle clients
|
||||
if time.Since(pcrl.lastCleanup) > pcrl.cleanupInterval {
|
||||
pcrl.cleanupLocked()
|
||||
}
|
||||
|
||||
limiter, exists := pcrl.clients[key]
|
||||
if !exists {
|
||||
limiter = NewRateLimiter(pcrl.rate, pcrl.burst)
|
||||
pcrl.clients[key] = limiter
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupLocked removes idle limiters. Must be called with lock held.
|
||||
// Uses consistent lock ordering: always acquire limiter.mu while holding pcrl.mu.
|
||||
// This is safe because the limiter.mu critical section is brief (just reading lastUpdate).
|
||||
func (pcrl *PerClientRateLimiter) cleanupLocked() {
|
||||
now := time.Now()
|
||||
keysToDelete := make([]string, 0)
|
||||
|
||||
// Check each limiter while holding pcrl.mu
|
||||
// We briefly acquire limiter.mu but the critical section is minimal
|
||||
for key, limiter := range pcrl.clients {
|
||||
limiter.mu.Lock()
|
||||
lastUpdate := limiter.lastUpdateTimeUnlocked()
|
||||
limiter.mu.Unlock()
|
||||
|
||||
if now.Sub(lastUpdate) > pcrl.maxIdleTime {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete collected keys
|
||||
for _, key := range keysToDelete {
|
||||
delete(pcrl.clients, key)
|
||||
}
|
||||
pcrl.lastCleanup = now
|
||||
}
|
||||
|
||||
// Allow checks if a request from the given client should be allowed.
|
||||
func (pcrl *PerClientRateLimiter) Allow(clientKey string) bool {
|
||||
return pcrl.getLimiter(clientKey).Allow()
|
||||
}
|
||||
|
||||
// Stats returns aggregate statistics.
|
||||
// Uses two-phase approach to avoid nested lock acquisition.
|
||||
func (pcrl *PerClientRateLimiter) Stats() map[string]any {
|
||||
// Phase 1: Collect limiters under pcrl.mu
|
||||
pcrl.mu.Lock()
|
||||
rate := pcrl.rate
|
||||
burst := pcrl.burst
|
||||
activeClients := len(pcrl.clients)
|
||||
limiters := make([]*RateLimiter, 0, activeClients)
|
||||
for _, limiter := range pcrl.clients {
|
||||
limiters = append(limiters, limiter)
|
||||
}
|
||||
pcrl.mu.Unlock()
|
||||
|
||||
// Phase 2: Collect stats from each limiter (only acquiring limiter.mu, not pcrl.mu)
|
||||
var totalRequests, totalRejected int64
|
||||
for _, limiter := range limiters {
|
||||
limiter.mu.Lock()
|
||||
totalRequests += limiter.requests
|
||||
totalRejected += limiter.rejected
|
||||
limiter.mu.Unlock()
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"rate": rate,
|
||||
"burst": burst,
|
||||
"active_clients": activeClients,
|
||||
"total_requests": totalRequests,
|
||||
"total_rejected": totalRejected,
|
||||
}
|
||||
}
|
||||
|
||||
// PerClientRateLimitMiddleware creates middleware that applies per-client rate limiting.
|
||||
// Uses X-Forwarded-For or RemoteAddr to identify clients.
|
||||
func PerClientRateLimitMiddleware(limiter *PerClientRateLimiter) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get client identifier (prefer X-Real-IP from RealIP middleware)
|
||||
clientKey := r.RemoteAddr
|
||||
if xff := r.Header.Get("X-Real-IP"); xff != "" {
|
||||
clientKey = xff
|
||||
}
|
||||
|
||||
if !limiter.Allow(clientKey) {
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,11 +4,15 @@ package sdk
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
@@ -20,8 +24,178 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// CircuitBreaker is a simple circuit breaker for CLI calls. Its counters and
// state are read and written with sync/atomic; threshold and resetTimeout are
// set once by NewCircuitBreaker.
type CircuitBreaker struct {
	failures     int64 // failures counted since the last recorded success
	lastFailure  int64 // Unix time, in seconds, of the most recent failure
	threshold    int64 // failure count at which the breaker opens
	resetTimeout int64 // seconds an open breaker waits before admitting calls again
	state        int32 // one of circuitClosed, circuitOpen, circuitHalfOpen
}

// Breaker states stored in CircuitBreaker.state.
const (
	circuitClosed   int32 = iota // calls are admitted
	circuitOpen                  // calls are rejected until resetTimeout has passed
	circuitHalfOpen              // the timeout has passed; calls are admitted again
)

// NewCircuitBreaker returns a closed breaker. threshold is the failure count
// that opens it; resetTimeout is how many seconds it then stays open.
func NewCircuitBreaker(threshold int64, resetTimeout int64) *CircuitBreaker {
	cb := new(CircuitBreaker)
	cb.threshold = threshold
	cb.resetTimeout = resetTimeout
	return cb
}

// Allow reports whether a call may proceed.
//
// A closed breaker admits everything. An open breaker rejects calls until
// more than resetTimeout seconds have passed since the last failure; the
// first call after that moves it to half-open and is admitted. While
// half-open every call is admitted.
func (cb *CircuitBreaker) Allow() bool {
	switch atomic.LoadInt32(&cb.state) {
	case circuitClosed:
		return true
	case circuitOpen:
		lastFail := atomic.LoadInt64(&cb.lastFailure)
		if elapsed := time.Now().Unix() - lastFail; elapsed <= cb.resetTimeout {
			return false
		}
		// The swap may lose to a concurrent caller that already made the same
		// transition; either way this call goes through.
		atomic.CompareAndSwapInt32(&cb.state, circuitOpen, circuitHalfOpen)
		return true
	default:
		// Half-open.
		return true
	}
}

// RecordSuccess clears the failure count and closes the breaker.
func (cb *CircuitBreaker) RecordSuccess() {
	const noFailures int64 = 0
	atomic.StoreInt64(&cb.failures, noFailures)
	atomic.StoreInt32(&cb.state, circuitClosed)
}
|
||||
|
||||
// RecordFailure records a failed call.
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
failures := atomic.AddInt64(&cb.failures, 1)
|
||||
atomic.StoreInt64(&cb.lastFailure, time.Now().Unix())
|
||||
|
||||
if failures >= cb.threshold {
|
||||
atomic.StoreInt32(&cb.state, circuitOpen)
|
||||
log.Warn().Int64("failures", failures).Msg("Circuit breaker opened - Claude CLI calls temporarily disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the current state as a string.
|
||||
func (cb *CircuitBreaker) State() string {
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case circuitOpen:
|
||||
return "open"
|
||||
case circuitHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "closed"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerMetrics contains metrics about the circuit breaker state.
|
||||
type CircuitBreakerMetrics struct {
|
||||
State string `json:"state"`
|
||||
Failures int64 `json:"failures"`
|
||||
Threshold int64 `json:"threshold"`
|
||||
ResetTimeoutSecs int64 `json:"reset_timeout_secs"`
|
||||
LastFailureUnix int64 `json:"last_failure_unix,omitempty"`
|
||||
SecondsUntilReset int64 `json:"seconds_until_reset,omitempty"`
|
||||
}
|
||||
|
||||
// Metrics returns the current metrics of the circuit breaker.
|
||||
func (cb *CircuitBreaker) Metrics() CircuitBreakerMetrics {
|
||||
failures := atomic.LoadInt64(&cb.failures)
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailure)
|
||||
state := cb.State()
|
||||
|
||||
metrics := CircuitBreakerMetrics{
|
||||
State: state,
|
||||
Failures: failures,
|
||||
Threshold: cb.threshold,
|
||||
ResetTimeoutSecs: cb.resetTimeout,
|
||||
}
|
||||
|
||||
if lastFail > 0 {
|
||||
metrics.LastFailureUnix = lastFail
|
||||
if state == "open" {
|
||||
remaining := cb.resetTimeout - (time.Now().Unix() - lastFail)
|
||||
if remaining > 0 {
|
||||
metrics.SecondsUntilReset = remaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// RequestDeduplicator remembers recently seen request hashes so that the same
// request is not processed twice within a TTL window.
type RequestDeduplicator struct {
	seen    map[string]int64 // request hash -> Unix second it was recorded
	mu      sync.RWMutex     // guards seen
	ttlSecs int64            // how long, in seconds, a hash counts as a duplicate
	maxSize int              // upper bound on the number of tracked hashes
}

// NewRequestDeduplicator returns a deduplicator that treats a hash as a
// duplicate for ttlSecs seconds and tracks at most maxSize hashes.
func NewRequestDeduplicator(ttlSecs int64, maxSize int) *RequestDeduplicator {
	return &RequestDeduplicator{
		seen:    make(map[string]int64),
		ttlSecs: ttlSecs,
		maxSize: maxSize,
	}
}

// IsDuplicate reports whether hash was recorded less than ttlSecs seconds ago.
func (d *RequestDeduplicator) IsDuplicate(hash string) bool {
	now := time.Now().Unix()

	d.mu.RLock()
	ts, exists := d.seen[hash]
	d.mu.RUnlock()

	return exists && now-ts < d.ttlSecs
}

// Record marks hash as seen now.
//
// When the table is at capacity, expired entries are purged first. If that
// frees no room and hash is new, the oldest entries are evicted so the table
// never grows beyond maxSize. Purging expired entries alone would let it grow
// without bound under a burst of distinct, unexpired hashes.
func (d *RequestDeduplicator) Record(hash string) {
	now := time.Now().Unix()

	d.mu.Lock()
	defer d.mu.Unlock()

	if len(d.seen) >= d.maxSize {
		threshold := now - d.ttlSecs
		for k, ts := range d.seen {
			if ts < threshold {
				delete(d.seen, k)
			}
		}

		// Refreshing a hash that is already tracked does not grow the map,
		// so live entries are only evicted for a genuinely new hash.
		if _, tracked := d.seen[hash]; !tracked {
			for len(d.seen) > 0 && len(d.seen) >= d.maxSize {
				d.evictOldestLocked()
			}
		}
	}

	d.seen[hash] = now
}

// evictOldestLocked removes the entry with the smallest timestamp; among
// entries recorded in the same second the choice is arbitrary.
// The caller must hold d.mu for writing.
func (d *RequestDeduplicator) evictOldestLocked() {
	var (
		oldestKey string
		oldestTS  int64
		found     bool
	)
	for k, ts := range d.seen {
		if !found || ts < oldestTS {
			oldestKey, oldestTS, found = k, ts, true
		}
	}
	if found {
		delete(d.seen, oldestKey)
	}
}
|
||||
|
||||
// hashRequest returns a short fingerprint of a tool invocation for
// deduplication: the first 16 hex characters of a SHA-256 over the tool name,
// the input and the first 1000 bytes of the output.
//
// Each field is written with a length prefix. Plain concatenation would make
// ("ab", "c") and ("a", "bc") hash to the same value.
func hashRequest(toolName, input, output string) string {
	h := sha256.New()

	// Only the head of the output is hashed.
	head := output[:min(len(output), 1000)]

	for _, field := range [...]string{toolName, input, head} {
		// The write error is ignored: hash.Hash documents that Write never
		// returns an error.
		_, _ = fmt.Fprintf(h, "%d:%s", len(field), field)
	}

	return hex.EncodeToString(h.Sum(nil))[:16] // a short hash is sufficient
}
|
||||
|
||||
// BroadcastFunc is a callback for broadcasting events to SSE clients.
|
||||
type BroadcastFunc func(event map[string]interface{})
|
||||
type BroadcastFunc func(event map[string]any)
|
||||
|
||||
// SyncObservationFunc is a callback for syncing observations to vector DB.
|
||||
type SyncObservationFunc func(obs *models.Observation)
|
||||
@@ -29,6 +203,10 @@ type SyncObservationFunc func(obs *models.Observation)
|
||||
// SyncSummaryFunc is a callback for syncing summaries to vector DB.
|
||||
type SyncSummaryFunc func(summary *models.SessionSummary)
|
||||
|
||||
// MaxVectorSyncWorkers is the maximum number of concurrent vector sync operations.
|
||||
// This prevents unbounded goroutine spawning during high-volume observation ingestion.
|
||||
const MaxVectorSyncWorkers = 8
|
||||
|
||||
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
|
||||
type Processor struct {
|
||||
claudePath string
|
||||
@@ -40,6 +218,14 @@ type Processor struct {
|
||||
syncSummaryFunc SyncSummaryFunc
|
||||
// Semaphore to limit concurrent Claude CLI calls (prevents API overload)
|
||||
sem chan struct{}
|
||||
// Circuit breaker for CLI failures
|
||||
circuitBreaker *CircuitBreaker
|
||||
// Request deduplicator to prevent duplicate processing
|
||||
deduplicator *RequestDeduplicator
|
||||
// Bounded worker pool for vector sync operations
|
||||
vectorSyncChan chan *models.Observation
|
||||
vectorSyncWg sync.WaitGroup
|
||||
vectorSyncDone chan struct{}
|
||||
}
|
||||
|
||||
// SetBroadcastFunc sets the broadcast callback for SSE events.
|
||||
@@ -58,7 +244,7 @@ func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) {
|
||||
}
|
||||
|
||||
// broadcast sends an event via the broadcast callback if set.
|
||||
func (p *Processor) broadcast(event map[string]interface{}) {
|
||||
func (p *Processor) broadcast(event map[string]any) {
|
||||
if p.broadcastFunc != nil {
|
||||
p.broadcastFunc(event)
|
||||
}
|
||||
@@ -94,9 +280,65 @@ func NewProcessor(observationStore *gorm.ObservationStore, summaryStore *gorm.Su
|
||||
observationStore: observationStore,
|
||||
summaryStore: summaryStore,
|
||||
sem: make(chan struct{}, MaxConcurrentCLICalls),
|
||||
circuitBreaker: NewCircuitBreaker(5, 60), // Open after 5 failures, reset after 60s
|
||||
deduplicator: NewRequestDeduplicator(300, 1000), // 5-minute TTL, 1000 max entries
|
||||
vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2), // Buffered channel
|
||||
vectorSyncDone: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StartVectorSyncWorkers starts the bounded worker pool for vector sync operations.
|
||||
// Call this after setting the sync function via SetSyncObservationFunc.
|
||||
func (p *Processor) StartVectorSyncWorkers() {
|
||||
for i := 0; i < MaxVectorSyncWorkers; i++ {
|
||||
p.vectorSyncWg.Add(1)
|
||||
go p.vectorSyncWorker()
|
||||
}
|
||||
log.Info().Int("workers", MaxVectorSyncWorkers).Msg("Vector sync worker pool started")
|
||||
}
|
||||
|
||||
// StopVectorSyncWorkers gracefully stops the worker pool.
|
||||
func (p *Processor) StopVectorSyncWorkers() {
|
||||
close(p.vectorSyncDone)
|
||||
p.vectorSyncWg.Wait()
|
||||
log.Info().Msg("Vector sync worker pool stopped")
|
||||
}
|
||||
|
||||
// vectorSyncWorker is a worker goroutine that processes vector sync requests.
|
||||
func (p *Processor) vectorSyncWorker() {
|
||||
defer p.vectorSyncWg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-p.vectorSyncDone:
|
||||
// Drain remaining items before exiting
|
||||
for {
|
||||
select {
|
||||
case obs := <-p.vectorSyncChan:
|
||||
if p.syncObservationFunc != nil {
|
||||
p.syncObservationFunc(obs)
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
case obs := <-p.vectorSyncChan:
|
||||
if p.syncObservationFunc != nil {
|
||||
p.syncObservationFunc(obs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerState returns the current state of the circuit breaker.
|
||||
func (p *Processor) CircuitBreakerState() string {
|
||||
return p.circuitBreaker.State()
|
||||
}
|
||||
|
||||
// CircuitBreakerMetrics returns detailed metrics about the circuit breaker.
|
||||
func (p *Processor) CircuitBreakerMetrics() CircuitBreakerMetrics {
|
||||
return p.circuitBreaker.Metrics()
|
||||
}
|
||||
|
||||
// IsAvailable checks if the Claude CLI is available for processing.
|
||||
func (p *Processor) IsAvailable() bool {
|
||||
_, err := os.Stat(p.claudePath)
|
||||
@@ -104,7 +346,7 @@ func (p *Processor) IsAvailable() bool {
|
||||
}
|
||||
|
||||
// ProcessObservation processes a single tool observation and extracts insights.
|
||||
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error {
|
||||
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse any, promptNumber int, cwd string) error {
|
||||
// Skip certain tools that aren't worth processing
|
||||
if shouldSkipTool(toolName) {
|
||||
log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)")
|
||||
@@ -121,11 +363,23 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for duplicate request within TTL window
|
||||
reqHash := hashRequest(toolName, inputStr, outputStr)
|
||||
if p.deduplicator.IsDuplicate(reqHash) {
|
||||
log.Debug().Str("tool", toolName).Msg("Skipping duplicate request (dedup)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check circuit breaker before making CLI call
|
||||
if !p.circuitBreaker.Allow() {
|
||||
log.Warn().Str("tool", toolName).Msg("Circuit breaker open - skipping CLI call")
|
||||
return fmt.Errorf("circuit breaker open")
|
||||
}
|
||||
|
||||
log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI")
|
||||
|
||||
// Note: Removed the "file already has observations" check
|
||||
// Each tool execution can produce unique insights even for the same file
|
||||
// Similarity-based deduplication will handle true duplicates
|
||||
// Record this request to prevent duplicates
|
||||
p.deduplicator.Record(reqHash)
|
||||
|
||||
// Build the prompt
|
||||
exec := ToolExecution{
|
||||
@@ -147,9 +401,11 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
|
||||
// Call Claude Code CLI
|
||||
response, err := p.callClaudeCLI(ctx, prompt)
|
||||
if err != nil {
|
||||
p.circuitBreaker.RecordFailure()
|
||||
log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation")
|
||||
return err
|
||||
}
|
||||
p.circuitBreaker.RecordSuccess()
|
||||
|
||||
// Parse observations from response
|
||||
observations := ParseObservations(response, sdkSessionID)
|
||||
@@ -200,16 +456,26 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
|
||||
Int("trackedFiles", len(obs.FileMtimes)).
|
||||
Msg("Observation stored")
|
||||
|
||||
// Sync to vector DB if callback is set
|
||||
if p.syncObservationFunc != nil {
|
||||
// Sync to vector DB via bounded worker pool (non-blocking to reduce latency)
|
||||
if p.syncObservationFunc != nil && p.vectorSyncChan != nil {
|
||||
fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0)
|
||||
fullObs.ID = id
|
||||
fullObs.CreatedAtEpoch = createdAtEpoch
|
||||
p.syncObservationFunc(fullObs)
|
||||
// Non-blocking send to worker pool - drops if channel is full
|
||||
select {
|
||||
case p.vectorSyncChan <- fullObs:
|
||||
// Sent to worker pool
|
||||
default:
|
||||
// Channel full, fall back to direct sync in goroutine (bounded by channel buffer)
|
||||
log.Debug().Int64("obs_id", id).Msg("Vector sync channel full, using fallback goroutine")
|
||||
go func(obsToSync *models.Observation) {
|
||||
p.syncObservationFunc(obsToSync)
|
||||
}(fullObs)
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast new observation event for dashboard refresh
|
||||
p.broadcast(map[string]interface{}{
|
||||
p.broadcast(map[string]any{
|
||||
"type": "observation",
|
||||
"action": "created",
|
||||
"id": id,
|
||||
@@ -311,7 +577,7 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
|
||||
}
|
||||
|
||||
// Broadcast new summary event for dashboard refresh
|
||||
p.broadcast(map[string]interface{}{
|
||||
p.broadcast(map[string]any{
|
||||
"type": "summary",
|
||||
"action": "created",
|
||||
"id": id,
|
||||
@@ -321,8 +587,31 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaxPromptSize is the maximum size of a prompt that can be passed to the Claude CLI.
|
||||
// This prevents resource exhaustion from extremely large prompts.
|
||||
const MaxPromptSize = 100 * 1024 // 100KB
|
||||
|
||||
// sanitizePrompt strips control characters from a prompt.
//
// Newline, tab and carriage return are preserved because they are valid
// prompt formatting. Every other C0 control (U+0000–U+001F), DEL (U+007F) and
// the C1 controls (U+0080–U+009F) are removed. All remaining runes, including
// non-ASCII text and emoji, are returned unchanged.
func sanitizePrompt(s string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case r == '\n', r == '\t', r == '\r':
			// Whitespace controls that legitimately appear in prompts.
			return r
		case r < 0x20, r >= 0x7f && r <= 0x9f:
			// C0 controls, DEL and C1 controls: a negative result tells
			// strings.Map to drop the rune from the output.
			return -1
		default:
			return r
		}
	}, s)
}
|
||||
|
||||
// callClaudeCLI calls the Claude Code CLI with the given prompt.
|
||||
func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) {
|
||||
// Validate and sanitize prompt
|
||||
if len(prompt) > MaxPromptSize {
|
||||
return "", fmt.Errorf("prompt exceeds maximum size of %d bytes", MaxPromptSize)
|
||||
}
|
||||
prompt = sanitizePrompt(prompt)
|
||||
|
||||
// Build the full prompt with system instructions
|
||||
fullPrompt := systemPrompt + "\n\n" + prompt
|
||||
|
||||
@@ -419,8 +708,11 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip if output indicates an error or empty result
|
||||
// Pre-compute lowercase strings once to avoid repeated allocations
|
||||
lowerOutput := strings.ToLower(outputStr)
|
||||
lowerInput := strings.ToLower(inputStr)
|
||||
|
||||
// Skip if output indicates an error or empty result
|
||||
trivialOutputs := []string{
|
||||
"no matches found",
|
||||
"file not found",
|
||||
@@ -444,13 +736,13 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
// Skip reading config files that rarely contain project-specific insights
|
||||
boringFiles := []string{
|
||||
"package-lock.json", "yarn.lock", "pnpm-lock.yaml",
|
||||
"go.sum", "Cargo.lock", "Gemfile.lock", "poetry.lock",
|
||||
"go.sum", "cargo.lock", "gemfile.lock", "poetry.lock",
|
||||
".gitignore", ".dockerignore", ".eslintignore",
|
||||
"tsconfig.json", "jsconfig.json", "vite.config",
|
||||
"tailwind.config", "postcss.config",
|
||||
}
|
||||
for _, boring := range boringFiles {
|
||||
if strings.Contains(inputStr, boring) {
|
||||
if strings.Contains(lowerInput, boring) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -462,14 +754,14 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
}
|
||||
|
||||
case "Bash":
|
||||
// Skip simple status commands
|
||||
// Skip simple status commands (use pre-computed lowerInput)
|
||||
boringCommands := []string{
|
||||
"git status", "git diff", "git log", "git branch",
|
||||
"ls ", "pwd", "echo ", "cat ", "which ", "type ",
|
||||
"npm list", "npm outdated", "npm audit",
|
||||
}
|
||||
for _, boring := range boringCommands {
|
||||
if strings.Contains(strings.ToLower(inputStr), boring) {
|
||||
if strings.Contains(lowerInput, boring) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -479,7 +771,7 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
}
|
||||
|
||||
// toJSONString converts an interface to a JSON string.
|
||||
func toJSONString(v interface{}) string {
|
||||
func toJSONString(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
@@ -495,36 +787,85 @@ func toJSONString(v interface{}) string {
|
||||
|
||||
// captureFileMtimes captures current modification times for tracked files.
|
||||
// Returns a map of absolute file paths to their mtime in epoch milliseconds.
|
||||
// For large file lists (>10 files), uses parallel stat calls for better performance.
|
||||
func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 {
|
||||
mtimes := make(map[string]int64)
|
||||
// Combine all unique file paths
|
||||
allPaths := make(map[string]struct{}, len(filesRead)+len(filesModified))
|
||||
for _, path := range filesRead {
|
||||
allPaths[path] = struct{}{}
|
||||
}
|
||||
for _, path := range filesModified {
|
||||
allPaths[path] = struct{}{}
|
||||
}
|
||||
|
||||
// Helper to get mtime for a file path
|
||||
getMtime := func(path string) (int64, bool) {
|
||||
// Resolve relative paths against cwd
|
||||
// For small lists, use sequential processing (goroutine overhead not worth it)
|
||||
if len(allPaths) <= 10 {
|
||||
return captureFileMtimesSequential(allPaths, cwd)
|
||||
}
|
||||
|
||||
// For larger lists, parallelize with bounded concurrency
|
||||
return captureFileMtimesParallel(allPaths, cwd)
|
||||
}
|
||||
|
||||
// captureFileMtimesSequential captures mtimes sequentially (efficient for small lists).
|
||||
func captureFileMtimesSequential(paths map[string]struct{}, cwd string) map[string]int64 {
|
||||
mtimes := make(map[string]int64, len(paths))
|
||||
|
||||
for path := range paths {
|
||||
absPath := path
|
||||
if !filepath.IsAbs(path) && cwd != "" {
|
||||
absPath = filepath.Join(cwd, path)
|
||||
}
|
||||
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return info.ModTime().UnixMilli(), true
|
||||
}
|
||||
|
||||
// Capture mtimes for all read files
|
||||
for _, path := range filesRead {
|
||||
if mtime, ok := getMtime(path); ok {
|
||||
mtimes[path] = mtime
|
||||
if err == nil {
|
||||
mtimes[path] = info.ModTime().UnixMilli()
|
||||
}
|
||||
}
|
||||
|
||||
// Capture mtimes for all modified files
|
||||
for _, path := range filesModified {
|
||||
if mtime, ok := getMtime(path); ok {
|
||||
mtimes[path] = mtime
|
||||
}
|
||||
return mtimes
|
||||
}
|
||||
|
||||
// captureFileMtimesParallel captures mtimes in parallel with bounded concurrency.
|
||||
func captureFileMtimesParallel(paths map[string]struct{}, cwd string) map[string]int64 {
|
||||
type mtimeResult struct {
|
||||
path string
|
||||
mtime int64
|
||||
}
|
||||
|
||||
results := make(chan mtimeResult, len(paths))
|
||||
sem := make(chan struct{}, 8) // Limit to 8 concurrent stat calls
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for path := range paths {
|
||||
wg.Add(1)
|
||||
go func(p string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // Acquire
|
||||
defer func() { <-sem }() // Release
|
||||
|
||||
absPath := p
|
||||
if !filepath.IsAbs(p) && cwd != "" {
|
||||
absPath = filepath.Join(cwd, p)
|
||||
}
|
||||
|
||||
info, err := os.Stat(absPath)
|
||||
if err == nil {
|
||||
results <- mtimeResult{path: p, mtime: info.ModTime().UnixMilli()}
|
||||
}
|
||||
}(path)
|
||||
}
|
||||
|
||||
// Close results channel when all goroutines complete
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
// Collect results
|
||||
mtimes := make(map[string]int64, len(paths))
|
||||
for res := range results {
|
||||
mtimes[res.path] = res.mtime
|
||||
}
|
||||
|
||||
return mtimes
|
||||
|
||||
@@ -974,3 +974,110 @@ func TestSyncSummaryFuncType(t *testing.T) {
|
||||
}
|
||||
assert.NotNil(t, fn)
|
||||
}
|
||||
|
||||
// TestSanitizePrompt tests prompt sanitization for CLI safety.
|
||||
func TestSanitizePrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal text",
|
||||
input: "Hello, world!",
|
||||
expected: "Hello, world!",
|
||||
},
|
||||
{
|
||||
name: "text with newlines",
|
||||
input: "Line 1\nLine 2\nLine 3",
|
||||
expected: "Line 1\nLine 2\nLine 3",
|
||||
},
|
||||
{
|
||||
name: "text with tabs",
|
||||
input: "Key:\tValue",
|
||||
expected: "Key:\tValue",
|
||||
},
|
||||
{
|
||||
name: "text with carriage return",
|
||||
input: "Line 1\r\nLine 2",
|
||||
expected: "Line 1\r\nLine 2",
|
||||
},
|
||||
{
|
||||
name: "text with null bytes",
|
||||
input: "Hello\x00World",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with control characters",
|
||||
input: "Hello\x01\x02\x03World",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with bell character",
|
||||
input: "Hello\x07World",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with backspace",
|
||||
input: "Hello\x08World",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with form feed",
|
||||
input: "Hello\x0cWorld",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with escape",
|
||||
input: "Hello\x1bWorld",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "unicode text",
|
||||
input: "Hello 世界 🌍",
|
||||
expected: "Hello 世界 🌍",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "only control characters",
|
||||
input: "\x00\x01\x02\x03",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizePrompt(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMaxPromptSize tests that MaxPromptSize is reasonable.
|
||||
func TestMaxPromptSize(t *testing.T) {
|
||||
assert.Equal(t, 100*1024, MaxPromptSize)
|
||||
}
|
||||
|
||||
// BenchmarkSanitizePrompt benchmarks the sanitize function.
|
||||
func BenchmarkSanitizePrompt(b *testing.B) {
|
||||
prompt := "Analyze the following code:\n```go\nfunc main() {\n\tfmt.Println(\"Hello, World!\")\n}\n```\n\nPlease identify any issues."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sanitizePrompt(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSanitizePromptWithControlChars benchmarks sanitization with control characters.
|
||||
func BenchmarkSanitizePromptWithControlChars(b *testing.B) {
|
||||
prompt := "Hello\x00World\x01Test\x02Data\x03End"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sanitizePrompt(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
+673
-286
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,8 @@ import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
@@ -13,16 +15,22 @@ var staticFS embed.FS
|
||||
// staticSubFS is the static subdirectory filesystem
|
||||
var staticSubFS fs.FS
|
||||
|
||||
// staticInitErr stores any error from static filesystem initialization
|
||||
var staticInitErr error
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
staticSubFS, err = fs.Sub(staticFS, "static")
|
||||
if err != nil {
|
||||
panic("failed to create sub filesystem: " + err.Error())
|
||||
staticSubFS, staticInitErr = fs.Sub(staticFS, "static")
|
||||
if staticInitErr != nil {
|
||||
log.Warn().Err(staticInitErr).Msg("Static filesystem initialization failed - dashboard will be unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
// serveIndex serves the index.html file for the root path
|
||||
func serveIndex(w http.ResponseWriter, r *http.Request) {
|
||||
if staticInitErr != nil {
|
||||
http.Error(w, "Dashboard unavailable: static files not initialized", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
content, err := fs.ReadFile(staticSubFS, "index.html")
|
||||
if err != nil {
|
||||
http.Error(w, "Dashboard not found", http.StatusNotFound)
|
||||
@@ -38,6 +46,10 @@ func serveIndex(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// serveAssets serves static assets from the embedded filesystem
|
||||
func serveAssets(w http.ResponseWriter, r *http.Request) {
|
||||
if staticInitErr != nil {
|
||||
http.Error(w, "Assets unavailable: static files not initialized", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
// Strip the leading slash from the request path and serve the file
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package similarity
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
@@ -15,6 +16,17 @@ func ClusterObservations(observations []*models.Observation, similarityThreshold
|
||||
return observations
|
||||
}
|
||||
|
||||
// For small sets, use the simple O(n²) algorithm
|
||||
if len(observations) <= 50 {
|
||||
return clusterObservationsSimple(observations, similarityThreshold)
|
||||
}
|
||||
|
||||
// For larger sets, use an optimized approach with early termination
|
||||
return clusterObservationsOptimized(observations, similarityThreshold)
|
||||
}
|
||||
|
||||
// clusterObservationsSimple is the simple O(n²) algorithm for small sets.
|
||||
func clusterObservationsSimple(observations []*models.Observation, similarityThreshold float64) []*models.Observation {
|
||||
// Extract terms for each observation
|
||||
termSets := make([]map[string]bool, len(observations))
|
||||
for i, obs := range observations {
|
||||
@@ -51,6 +63,93 @@ func ClusterObservations(observations []*models.Observation, similarityThreshold
|
||||
return result
|
||||
}
|
||||
|
||||
// clusterObservationsOptimized uses MinHash-based approximation for large sets.
|
||||
// This reduces complexity from O(n²) to approximately O(n*k) where k is the number of hash functions.
|
||||
func clusterObservationsOptimized(observations []*models.Observation, similarityThreshold float64) []*models.Observation {
|
||||
n := len(observations)
|
||||
|
||||
// Extract terms for each observation and compute a signature
|
||||
type termSetWithSig struct {
|
||||
terms map[string]bool
|
||||
signature uint64 // Simple hash signature for fast comparison
|
||||
}
|
||||
|
||||
termSets := make([]termSetWithSig, n)
|
||||
for i, obs := range observations {
|
||||
terms := ExtractObservationTerms(obs)
|
||||
termSets[i] = termSetWithSig{
|
||||
terms: terms,
|
||||
signature: computeTermSignature(terms),
|
||||
}
|
||||
}
|
||||
|
||||
// Track which observations are already clustered
|
||||
clustered := make([]bool, n)
|
||||
result := make([]*models.Observation, 0, n/2) // Pre-allocate assuming ~50% are unique
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if clustered[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
// This observation becomes the representative of its cluster
|
||||
result = append(result, observations[i])
|
||||
clustered[i] = true
|
||||
|
||||
// Use signature for fast pre-filtering
|
||||
sigI := termSets[i].signature
|
||||
termsI := termSets[i].terms
|
||||
|
||||
// Find all similar observations and mark them as clustered
|
||||
for j := i + 1; j < n; j++ {
|
||||
if clustered[j] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Quick signature comparison - if signatures are very different, skip detailed comparison
|
||||
sigJ := termSets[j].signature
|
||||
sigDiff := sigI ^ sigJ
|
||||
popCount := popCount64(sigDiff)
|
||||
|
||||
// If signatures differ significantly, similarity is likely low
|
||||
// Skip detailed comparison for very different signatures
|
||||
if popCount > 32 { // More than half of bits differ
|
||||
continue
|
||||
}
|
||||
|
||||
// Full Jaccard comparison for candidates
|
||||
similarity := JaccardSimilarity(termsI, termSets[j].terms)
|
||||
if similarity >= similarityThreshold {
|
||||
clustered[j] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// computeTermSignature folds a term set into a single 64-bit value.
//
// Each key is hashed byte by byte (XOR the byte in, then multiply by the
// prime, starting from the offset basis) and the per-term hashes are XORed
// together, so the result does not depend on map iteration order. Keys are
// hashed regardless of the boolean stored for them. An empty or nil map
// yields 0.
func computeTermSignature(terms map[string]bool) uint64 {
	const (
		offsetBasis uint64 = 14695981039346656037
		prime       uint64 = 1099511628211
	)

	signature := uint64(0)
	for key := range terms {
		hash := offsetBasis
		for _, b := range []byte(key) {
			hash = (hash ^ uint64(b)) * prime
		}
		signature ^= hash
	}
	return signature
}
|
||||
|
||||
// popCount64 returns how many bits of x are set to 1.
// It delegates to bits.OnesCount64 from the standard library.
func popCount64(x uint64) int {
	setBits := bits.OnesCount64(x)
	return setBits
}
|
||||
|
||||
// IsSimilarToAny checks if a new observation is similar to any existing observation.
|
||||
// Returns true if similarity to any existing observation exceeds the threshold.
|
||||
func IsSimilarToAny(newObs *models.Observation, existing []*models.Observation, similarityThreshold float64) bool {
|
||||
|
||||
Generated
+2
-2
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "8fe9ea5-dirty",
|
||||
"version": "v0.10.5-1-g7ab4b07-dirty",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "8fe9ea5-dirty",
|
||||
"version": "v0.10.5-1-g7ab4b07-dirty",
|
||||
"dependencies": {
|
||||
"vis-data": "^7.1.9",
|
||||
"vis-network": "^9.1.9",
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "claude-mnemonic-dashboard",
|
||||
"version": "8fe9ea5-dirty",
|
||||
"version": "v0.10.5-1-g7ab4b07-dirty",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
||||
@@ -8,6 +8,7 @@ import Card from './Card.vue'
|
||||
import IconBox from './IconBox.vue'
|
||||
import Badge from './Badge.vue'
|
||||
import RelationGraph from './RelationGraph.vue'
|
||||
import ScoreBreakdown from './ScoreBreakdown.vue'
|
||||
import { computed, ref, onMounted } from 'vue'
|
||||
|
||||
const props = defineProps<{
|
||||
@@ -95,6 +96,9 @@ const relationsLoading = ref(false)
|
||||
const relationsExpanded = ref(false)
|
||||
const showGraph = ref(false)
|
||||
|
||||
// Score breakdown state
|
||||
const showScoreBreakdown = ref(false)
|
||||
|
||||
const hasRelations = computed(() => relations.value.length > 0)
|
||||
const relationCount = computed(() => relations.value.length)
|
||||
|
||||
@@ -350,14 +354,15 @@ const splitPath = (path: string, components = 3) => {
|
||||
<i class="fas fa-thumbs-up text-sm" />
|
||||
</button>
|
||||
|
||||
<span
|
||||
class="text-[10px] font-mono px-1.5 py-0.5 rounded bg-slate-800/50 text-slate-400 flex items-center gap-1 transition-all duration-300"
|
||||
<button
|
||||
@click="showScoreBreakdown = true"
|
||||
class="text-[10px] font-mono px-1.5 py-0.5 rounded bg-slate-800/50 text-slate-400 flex items-center gap-1 transition-all duration-300 hover:bg-purple-500/20 hover:text-purple-300 cursor-pointer"
|
||||
:class="{ 'text-green-400': localScore !== null && localScore > (observation.importance_score || 1), 'text-red-400': localScore !== null && localScore < (observation.importance_score || 1) }"
|
||||
:title="`Importance Score: ${currentScore.toFixed(3)}\nRetrieval Count: ${observation.retrieval_count || 0}`"
|
||||
:title="`Importance Score: ${currentScore.toFixed(3)}\nRetrieval Count: ${observation.retrieval_count || 0}\nClick for details`"
|
||||
>
|
||||
<i class="fas fa-scale-balanced text-amber-500/60" />
|
||||
<i class="fas fa-chart-bar text-purple-500/60" />
|
||||
{{ currentScore.toFixed(2) }}
|
||||
</span>
|
||||
</button>
|
||||
|
||||
<button
|
||||
@click="submitFeedback(-1)"
|
||||
@@ -383,5 +388,12 @@ const splitPath = (path: string, components = 3) => {
|
||||
@close="showGraph = false"
|
||||
@navigate-to="handleNavigateTo"
|
||||
/>
|
||||
|
||||
<!-- Score Breakdown Modal -->
|
||||
<ScoreBreakdown
|
||||
:observation-id="observation.id"
|
||||
:show="showScoreBreakdown"
|
||||
@close="showScoreBreakdown = false"
|
||||
/>
|
||||
</Card>
|
||||
</template>
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
<script setup lang="ts">
// Modal logic for the score breakdown of a single observation: fetches the
// breakdown whenever the modal is shown (or the observation changes while it
// is shown) and exposes loading / error / data state plus two presentation
// helpers to the template.
import { ref, onMounted, watch } from 'vue'
import { fetchObservationScore, type ScoreBreakdown } from '@/utils/api'
import Card from './Card.vue'

const props = defineProps<{
  observationId: number
  show: boolean
}>()

const emit = defineEmits<{
  close: []
}>()

const loading = ref(false)
const error = ref<string | null>(null)
const data = ref<ScoreBreakdown | null>(null)

// Identifies the most recently started request. A response is applied only if
// it belongs to that request, so a slow earlier fetch cannot overwrite the
// result of a later one or clear the loading flag while the later one is
// still in flight.
let latestRequest = 0

// Fetch the score breakdown for the current observation.
// Does nothing when observationId is falsy (e.g. 0).
const loadScore = async () => {
  if (!props.observationId) return

  const requestId = ++latestRequest
  loading.value = true
  error.value = null

  try {
    const result = await fetchObservationScore(props.observationId)
    if (requestId === latestRequest) {
      data.value = result
    }
  } catch (err) {
    if (requestId === latestRequest) {
      error.value = err instanceof Error ? err.message : 'Failed to load score breakdown'
    }
  } finally {
    if (requestId === latestRequest) {
      loading.value = false
    }
  }
}

// Load on mount (if already visible), when the modal is opened, and when the
// observation changes while the modal is open.
onMounted(() => {
  if (props.show) loadScore()
})

watch(() => props.show, (newVal) => {
  if (newVal) loadScore()
})

watch(() => props.observationId, () => {
  if (props.show) loadScore()
})

// Score bar helper: maps value onto a CSS percentage of max, clamped to
// 0%–100%. A non-finite value or a non-positive max yields '0%' instead of
// an invalid width such as 'NaN%'.
const getScoreBarWidth = (value: number, max: number = 2) => {
  if (!Number.isFinite(value) || !Number.isFinite(max) || max <= 0) return '0%'
  return `${Math.min(100, Math.max(0, (value / max) * 100))}%`
}

// Score color helper: picks the bar color class from the score band.
const getScoreColor = (value: number) => {
  if (value >= 1.5) return 'bg-green-500'
  if (value >= 1) return 'bg-amber-500'
  if (value >= 0.5) return 'bg-orange-500'
  return 'bg-red-500'
}
</script>
|
||||
|
||||
<template>
|
||||
<!-- Modal Backdrop -->
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="show"
|
||||
class="fixed inset-0 z-50 flex items-center justify-center"
|
||||
>
|
||||
<!-- Backdrop -->
|
||||
<div
|
||||
class="absolute inset-0 bg-black/60 backdrop-blur-sm"
|
||||
@click="emit('close')"
|
||||
/>
|
||||
|
||||
<!-- Modal Content -->
|
||||
<div class="relative w-full max-w-lg mx-4 max-h-[90vh] overflow-y-auto">
|
||||
<Card
|
||||
gradient="bg-gradient-to-br from-purple-500/10 to-indigo-500/5"
|
||||
border-class="border-purple-500/30"
|
||||
>
|
||||
<!-- Header -->
|
||||
<div class="flex items-center justify-between mb-4">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-chart-bar text-purple-400" />
|
||||
<h3 class="text-lg font-semibold text-purple-100">Score Breakdown</h3>
|
||||
</div>
|
||||
<button
|
||||
@click="emit('close')"
|
||||
class="p-1.5 text-slate-400 hover:text-slate-200 hover:bg-slate-700/50 rounded-lg transition-colors"
|
||||
>
|
||||
<i class="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Loading State -->
|
||||
<div v-if="loading" class="flex items-center justify-center py-8">
|
||||
<i class="fas fa-circle-notch fa-spin text-2xl text-purple-400" />
|
||||
</div>
|
||||
|
||||
<!-- Error State -->
|
||||
<div v-else-if="error" class="text-center py-8">
|
||||
<i class="fas fa-exclamation-triangle text-2xl text-red-400 mb-2" />
|
||||
<p class="text-red-300">{{ error }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Content -->
|
||||
<div v-else-if="data" class="space-y-4">
|
||||
<!-- Observation Info -->
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg">
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide mb-1">Observation</div>
|
||||
<div class="text-amber-100 font-medium">{{ data.observation.title || 'Untitled' }}</div>
|
||||
<div class="flex items-center gap-2 mt-1 text-xs text-slate-400">
|
||||
<span class="px-1.5 py-0.5 bg-slate-700/50 rounded">{{ data.observation.type }}</span>
|
||||
<span>{{ data.scoring.age_days.toFixed(1) }} days old</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Final Score -->
|
||||
<div class="p-4 bg-gradient-to-r from-purple-500/20 to-indigo-500/10 rounded-lg border border-purple-500/30">
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-sm text-slate-300">Final Score</span>
|
||||
<span class="text-2xl font-bold text-purple-300">{{ data.scoring.final_score.toFixed(3) }}</span>
|
||||
</div>
|
||||
<div class="mt-2 h-2 bg-slate-700 rounded-full overflow-hidden">
|
||||
<div
|
||||
class="h-full transition-all duration-500"
|
||||
:class="getScoreColor(data.scoring.final_score)"
|
||||
:style="{ width: getScoreBarWidth(data.scoring.final_score) }"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Score Components -->
|
||||
<div class="space-y-3">
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Score Components</div>
|
||||
|
||||
<!-- Type Weight -->
|
||||
<div class="flex items-center justify-between text-sm">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-tag text-blue-400 w-4" />
|
||||
<span class="text-slate-300">Type Weight</span>
|
||||
</div>
|
||||
<span class="font-mono text-blue-300">{{ data.scoring.type_weight.toFixed(2) }}</span>
|
||||
</div>
|
||||
<p class="text-xs text-slate-500 ml-6 -mt-2">{{ data.explanation.type_impact }}</p>
|
||||
|
||||
<!-- Recency Decay -->
|
||||
<div class="flex items-center justify-between text-sm">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-clock text-cyan-400 w-4" />
|
||||
<span class="text-slate-300">Recency Decay</span>
|
||||
</div>
|
||||
<span class="font-mono text-cyan-300">{{ data.scoring.recency_decay.toFixed(2) }}</span>
|
||||
</div>
|
||||
<p class="text-xs text-slate-500 ml-6 -mt-2">{{ data.explanation.recency_impact }}</p>
|
||||
|
||||
<!-- Core Score -->
|
||||
<div class="flex items-center justify-between text-sm">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-star text-amber-400 w-4" />
|
||||
<span class="text-slate-300">Core Score</span>
|
||||
</div>
|
||||
<span class="font-mono text-amber-300">{{ data.scoring.core_score.toFixed(3) }}</span>
|
||||
</div>
|
||||
|
||||
<!-- Feedback Contribution -->
|
||||
<div class="flex items-center justify-between text-sm">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-thumbs-up text-green-400 w-4" />
|
||||
<span class="text-slate-300">Feedback</span>
|
||||
</div>
|
||||
<span class="font-mono" :class="data.scoring.feedback_contrib >= 0 ? 'text-green-300' : 'text-red-300'">
|
||||
{{ data.scoring.feedback_contrib >= 0 ? '+' : '' }}{{ data.scoring.feedback_contrib.toFixed(3) }}
|
||||
</span>
|
||||
</div>
|
||||
<p class="text-xs text-slate-500 ml-6 -mt-2">{{ data.explanation.feedback_impact }}</p>
|
||||
|
||||
<!-- Concept Contribution -->
|
||||
<div class="flex items-center justify-between text-sm">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-tags text-purple-400 w-4" />
|
||||
<span class="text-slate-300">Concepts</span>
|
||||
</div>
|
||||
<span class="font-mono text-purple-300">+{{ data.scoring.concept_contrib.toFixed(3) }}</span>
|
||||
</div>
|
||||
<p class="text-xs text-slate-500 ml-6 -mt-2">{{ data.explanation.concept_impact }}</p>
|
||||
|
||||
<!-- Retrieval Contribution -->
|
||||
<div class="flex items-center justify-between text-sm">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-search text-indigo-400 w-4" />
|
||||
<span class="text-slate-300">Retrieval</span>
|
||||
</div>
|
||||
<span class="font-mono text-indigo-300">+{{ data.scoring.retrieval_contrib.toFixed(3) }}</span>
|
||||
</div>
|
||||
<p class="text-xs text-slate-500 ml-6 -mt-2">{{ data.explanation.retrieval_impact }}</p>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
@@ -0,0 +1,271 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, watch, computed } from 'vue'
|
||||
import { fetchSearchAnalytics, fetchRecentSearches, type SearchAnalytics, type RecentQuery } from '@/utils/api'
|
||||
import Card from './Card.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
show: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
close: []
|
||||
}>()
|
||||
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
const analytics = ref<SearchAnalytics | null>(null)
|
||||
const recentSearches = ref<RecentQuery[]>([])
|
||||
|
||||
const loadData = async () => {
|
||||
if (!props.show) return
|
||||
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
try {
|
||||
const [analyticsData, searchesData] = await Promise.all([
|
||||
fetchSearchAnalytics(),
|
||||
fetchRecentSearches(20)
|
||||
])
|
||||
analytics.value = analyticsData
|
||||
recentSearches.value = searchesData
|
||||
} catch (err) {
|
||||
error.value = err instanceof Error ? err.message : 'Failed to load search analytics'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Load on mount and when show changes
|
||||
onMounted(() => {
|
||||
if (props.show) loadData()
|
||||
})
|
||||
|
||||
watch(() => props.show, (newVal) => {
|
||||
if (newVal) loadData()
|
||||
})
|
||||
|
||||
// Computed stats
|
||||
const cacheHitRate = computed(() => {
|
||||
if (!analytics.value || analytics.value.total_searches === 0) return 0
|
||||
return (analytics.value.cache_hits / analytics.value.total_searches) * 100
|
||||
})
|
||||
|
||||
const coalescedRate = computed(() => {
|
||||
if (!analytics.value || analytics.value.total_searches === 0) return 0
|
||||
return (analytics.value.coalesced_requests / analytics.value.total_searches) * 100
|
||||
})
|
||||
|
||||
const errorRate = computed(() => {
|
||||
if (!analytics.value || analytics.value.total_searches === 0) return 0
|
||||
return (analytics.value.search_errors / analytics.value.total_searches) * 100
|
||||
})
|
||||
|
||||
// Helper for latency color
|
||||
const getLatencyColor = (ms: number) => {
|
||||
if (ms < 10) return 'text-green-400'
|
||||
if (ms < 50) return 'text-amber-400'
|
||||
return 'text-red-400'
|
||||
}
|
||||
|
||||
// Helper for formatting time ago
|
||||
const formatTimeAgo = (isoDate: string) => {
|
||||
const date = new Date(isoDate)
|
||||
const now = new Date()
|
||||
const diffMs = now.getTime() - date.getTime()
|
||||
const diffMins = Math.floor(diffMs / 60000)
|
||||
const diffHours = Math.floor(diffMs / 3600000)
|
||||
const diffDays = Math.floor(diffMs / 86400000)
|
||||
|
||||
if (diffMins < 1) return 'just now'
|
||||
if (diffMins < 60) return `${diffMins}m ago`
|
||||
if (diffHours < 24) return `${diffHours}h ago`
|
||||
return `${diffDays}d ago`
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<!-- Modal Backdrop -->
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="show"
|
||||
class="fixed inset-0 z-50 flex items-center justify-center"
|
||||
>
|
||||
<!-- Backdrop -->
|
||||
<div
|
||||
class="absolute inset-0 bg-black/60 backdrop-blur-sm"
|
||||
@click="emit('close')"
|
||||
/>
|
||||
|
||||
<!-- Modal Content -->
|
||||
<div class="relative w-full max-w-2xl mx-4 max-h-[90vh] overflow-y-auto">
|
||||
<Card
|
||||
gradient="bg-gradient-to-br from-cyan-500/10 to-blue-500/5"
|
||||
border-class="border-cyan-500/30"
|
||||
>
|
||||
<!-- Header -->
|
||||
<div class="flex items-center justify-between mb-4">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-chart-line text-cyan-400" />
|
||||
<h3 class="text-lg font-semibold text-cyan-100">Search Analytics</h3>
|
||||
</div>
|
||||
<button
|
||||
@click="emit('close')"
|
||||
class="p-1.5 text-slate-400 hover:text-slate-200 hover:bg-slate-700/50 rounded-lg transition-colors"
|
||||
>
|
||||
<i class="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Loading State -->
|
||||
<div v-if="loading" class="flex items-center justify-center py-8">
|
||||
<i class="fas fa-circle-notch fa-spin text-2xl text-cyan-400" />
|
||||
</div>
|
||||
|
||||
<!-- Error State -->
|
||||
<div v-else-if="error" class="text-center py-8">
|
||||
<i class="fas fa-exclamation-triangle text-2xl text-red-400 mb-2" />
|
||||
<p class="text-red-300">{{ error }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Content -->
|
||||
<div v-else-if="analytics" class="space-y-6">
|
||||
<!-- Overview Stats Grid -->
|
||||
<div class="grid grid-cols-2 md:grid-cols-4 gap-3">
|
||||
<!-- Total Searches -->
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg text-center">
|
||||
<div class="text-2xl font-bold text-cyan-300">{{ analytics.total_searches.toLocaleString() }}</div>
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Total Searches</div>
|
||||
</div>
|
||||
|
||||
<!-- Vector Searches -->
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg text-center">
|
||||
<div class="text-2xl font-bold text-purple-300">{{ analytics.vector_searches.toLocaleString() }}</div>
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Vector Searches</div>
|
||||
</div>
|
||||
|
||||
<!-- Filter Searches -->
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg text-center">
|
||||
<div class="text-2xl font-bold text-blue-300">{{ analytics.filter_searches.toLocaleString() }}</div>
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Filter Searches</div>
|
||||
</div>
|
||||
|
||||
<!-- Cache Hits -->
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg text-center">
|
||||
<div class="text-2xl font-bold text-green-300">{{ analytics.cache_hits.toLocaleString() }}</div>
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Cache Hits</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Performance Metrics -->
|
||||
<div class="space-y-3">
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Performance Metrics</div>
|
||||
|
||||
<!-- Cache Hit Rate -->
|
||||
<div class="flex items-center justify-between p-3 bg-slate-800/30 rounded-lg">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-database text-green-400 w-5" />
|
||||
<span class="text-slate-300">Cache Hit Rate</span>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="w-24 h-2 bg-slate-700 rounded-full overflow-hidden">
|
||||
<div
|
||||
class="h-full bg-green-500 transition-all"
|
||||
:style="{ width: `${cacheHitRate}%` }"
|
||||
/>
|
||||
</div>
|
||||
<span class="font-mono text-green-300 w-16 text-right">{{ cacheHitRate.toFixed(1) }}%</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Coalesced Rate -->
|
||||
<div class="flex items-center justify-between p-3 bg-slate-800/30 rounded-lg">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-compress-arrows-alt text-amber-400 w-5" />
|
||||
<span class="text-slate-300">Coalesced Requests</span>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="w-24 h-2 bg-slate-700 rounded-full overflow-hidden">
|
||||
<div
|
||||
class="h-full bg-amber-500 transition-all"
|
||||
:style="{ width: `${coalescedRate}%` }"
|
||||
/>
|
||||
</div>
|
||||
<span class="font-mono text-amber-300 w-16 text-right">{{ coalescedRate.toFixed(1) }}%</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Error Rate -->
|
||||
<div class="flex items-center justify-between p-3 bg-slate-800/30 rounded-lg">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-exclamation-circle text-red-400 w-5" />
|
||||
<span class="text-slate-300">Error Rate</span>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="w-24 h-2 bg-slate-700 rounded-full overflow-hidden">
|
||||
<div
|
||||
class="h-full bg-red-500 transition-all"
|
||||
:style="{ width: `${Math.min(100, errorRate)}%` }"
|
||||
/>
|
||||
</div>
|
||||
<span class="font-mono text-red-300 w-16 text-right">{{ errorRate.toFixed(2) }}%</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Latency Stats -->
|
||||
<div class="space-y-3">
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Latency</div>
|
||||
|
||||
<div class="grid grid-cols-3 gap-3">
|
||||
<!-- Average Latency -->
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg text-center">
|
||||
<div class="text-xl font-bold font-mono" :class="getLatencyColor(analytics.avg_latency_ms)">
|
||||
{{ analytics.avg_latency_ms.toFixed(1) }}ms
|
||||
</div>
|
||||
<div class="text-xs text-slate-500">Average</div>
|
||||
</div>
|
||||
|
||||
<!-- Vector Latency -->
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg text-center">
|
||||
<div class="text-xl font-bold font-mono" :class="getLatencyColor(analytics.avg_vector_latency_ms)">
|
||||
{{ analytics.avg_vector_latency_ms.toFixed(1) }}ms
|
||||
</div>
|
||||
<div class="text-xs text-slate-500">Vector</div>
|
||||
</div>
|
||||
|
||||
<!-- Filter Latency -->
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg text-center">
|
||||
<div class="text-xl font-bold font-mono" :class="getLatencyColor(analytics.avg_filter_latency_ms)">
|
||||
{{ analytics.avg_filter_latency_ms.toFixed(1) }}ms
|
||||
</div>
|
||||
<div class="text-xs text-slate-500">Filter</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Recent Searches -->
|
||||
<div v-if="recentSearches.length > 0" class="space-y-3">
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Recent Searches</div>
|
||||
|
||||
<div class="space-y-2 max-h-48 overflow-y-auto">
|
||||
<div
|
||||
v-for="(search, index) in recentSearches"
|
||||
:key="index"
|
||||
class="flex items-center gap-3 p-2 bg-slate-800/30 rounded-lg text-sm"
|
||||
>
|
||||
<i class="fas fa-search text-slate-500 text-xs" />
|
||||
<span class="flex-1 text-slate-300 truncate" :title="search.query">{{ search.query }}</span>
|
||||
<span v-if="search.project" class="text-xs text-amber-600/80 font-mono">{{ search.project.split('/').pop() }}</span>
|
||||
<span v-if="search.type" class="text-xs text-cyan-500 bg-cyan-500/10 px-1.5 py-0.5 rounded">{{ search.type }}</span>
|
||||
<span class="text-xs text-slate-500 font-mono">×{{ search.count }}</span>
|
||||
<span class="text-xs text-slate-600">{{ formatTimeAgo(search.last_used) }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
@@ -2,6 +2,9 @@
|
||||
import { ref, computed } from 'vue'
|
||||
import type { Stats, SelfCheckResponse } from '@/types'
|
||||
import ProjectFilter from './ProjectFilter.vue'
|
||||
import SearchAnalytics from './SearchAnalytics.vue'
|
||||
import SystemHealthDetails from './SystemHealthDetails.vue'
|
||||
import TopObservations from './TopObservations.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
stats: Stats | null
|
||||
@@ -19,6 +22,15 @@ defineEmits<{
|
||||
// Collapse state - persisted in localStorage
|
||||
const isCollapsed = ref(localStorage.getItem('sidebar-collapsed') === 'true')
|
||||
|
||||
// Search Analytics modal state
|
||||
const showSearchAnalytics = ref(false)
|
||||
|
||||
// System Health Details modal state
|
||||
const showHealthDetails = ref(false)
|
||||
|
||||
// Top Observations modal state
|
||||
const showTopObservations = ref(false)
|
||||
|
||||
function toggleCollapse() {
|
||||
isCollapsed.value = !isCollapsed.value
|
||||
localStorage.setItem('sidebar-collapsed', String(isCollapsed.value))
|
||||
@@ -96,9 +108,18 @@ function getStatusColor(status: string): string {
|
||||
|
||||
<!-- Component Health -->
|
||||
<div class="bg-slate-800/50 rounded-lg p-4 border border-slate-700/50">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<i :class="['fas', overallHealthIcon, overallHealthColor]" />
|
||||
<h3 class="text-sm font-semibold text-white">System Health</h3>
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<i :class="['fas', overallHealthIcon, overallHealthColor]" />
|
||||
<h3 class="text-sm font-semibold text-white">System Health</h3>
|
||||
</div>
|
||||
<button
|
||||
@click="showHealthDetails = true"
|
||||
class="text-xs text-emerald-400 hover:text-emerald-300 transition-colors"
|
||||
title="View detailed health status"
|
||||
>
|
||||
<i class="fas fa-expand" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div v-if="health" class="space-y-2">
|
||||
@@ -124,9 +145,18 @@ function getStatusColor(status: string): string {
|
||||
|
||||
<!-- Memory Stats -->
|
||||
<div class="bg-slate-800/50 rounded-lg p-4 border border-slate-700/50">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<i class="fas fa-brain text-purple-400" />
|
||||
<h3 class="text-sm font-semibold text-white">Memory Contents</h3>
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-brain text-purple-400" />
|
||||
<h3 class="text-sm font-semibold text-white">Memory Contents</h3>
|
||||
</div>
|
||||
<button
|
||||
@click="showTopObservations = true"
|
||||
class="text-xs text-amber-400 hover:text-amber-300 transition-colors"
|
||||
title="View top observations"
|
||||
>
|
||||
<i class="fas fa-trophy" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="space-y-3">
|
||||
@@ -161,9 +191,18 @@ function getStatusColor(status: string): string {
|
||||
|
||||
<!-- Retrieval Stats -->
|
||||
<div v-if="stats?.retrieval" class="bg-slate-800/50 rounded-lg p-4 border border-slate-700/50">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<i class="fas fa-search text-cyan-400" />
|
||||
<h3 class="text-sm font-semibold text-white">Retrieval Stats</h3>
|
||||
<div class="flex items-center justify-between mb-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-search text-cyan-400" />
|
||||
<h3 class="text-sm font-semibold text-white">Retrieval Stats</h3>
|
||||
</div>
|
||||
<button
|
||||
@click="showSearchAnalytics = true"
|
||||
class="text-xs text-cyan-400 hover:text-cyan-300 transition-colors"
|
||||
title="View detailed analytics"
|
||||
>
|
||||
<i class="fas fa-chart-line" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="space-y-3">
|
||||
@@ -261,5 +300,25 @@ function getStatusColor(status: string): string {
|
||||
<i class="fas fa-search text-cyan-400" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Search Analytics Modal -->
|
||||
<SearchAnalytics
|
||||
:show="showSearchAnalytics"
|
||||
@close="showSearchAnalytics = false"
|
||||
/>
|
||||
|
||||
<!-- System Health Details Modal -->
|
||||
<SystemHealthDetails
|
||||
:show="showHealthDetails"
|
||||
@close="showHealthDetails = false"
|
||||
/>
|
||||
|
||||
<!-- Top Observations Modal -->
|
||||
<TopObservations
|
||||
:show="showTopObservations"
|
||||
:current-project="currentProject"
|
||||
@close="showTopObservations = false"
|
||||
@navigate-to-observation="$emit('update:project', null)"
|
||||
/>
|
||||
</aside>
|
||||
</template>
|
||||
|
||||
@@ -0,0 +1,249 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, watch, computed } from 'vue'
|
||||
import { fetchSystemHealth, type SystemHealth } from '@/utils/api'
|
||||
import Card from './Card.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
show: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
close: []
|
||||
}>()
|
||||
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
const health = ref<SystemHealth | null>(null)
|
||||
|
||||
// Fetches the system health report while the modal is shown, driving the
// loading / error / health refs that the template switches on.
const loadData = async () => {
  const isVisible = props.show
  if (!isVisible) return

  loading.value = true
  error.value = null

  try {
    const report = await fetchSystemHealth()
    health.value = report
  } catch (err) {
    // Non-Error rejections get a generic message.
    const message = err instanceof Error ? err.message : 'Failed to load system health'
    error.value = message
  } finally {
    loading.value = false
  }
}
|
||||
|
||||
// Fetch right away if the modal is already open when the component mounts...
onMounted(() => {
  if (!props.show) return
  loadData()
})

// ...and again whenever `show` changes to a truthy value.
watch(
  () => props.show,
  (isOpen) => {
    if (isOpen) loadData()
  }
)
|
||||
|
||||
// Status helpers
|
||||
// Font Awesome icon class for a health status; any status other than
// healthy / degraded / unhealthy gets the question-mark icon.
const getStatusIcon = (status: string) => {
  if (status === 'healthy') return 'fa-circle-check'
  if (status === 'degraded') return 'fa-triangle-exclamation'
  if (status === 'unhealthy') return 'fa-circle-xmark'
  return 'fa-circle-question'
}
|
||||
|
||||
// Tailwind text color for a health status: green, amber or red for the three
// known statuses, neutral slate for anything else.
const getStatusColor = (status: string) =>
  status === 'healthy'
    ? 'text-green-400'
    : status === 'degraded'
      ? 'text-amber-400'
      : status === 'unhealthy'
        ? 'text-red-400'
        : 'text-slate-400'
|
||||
|
||||
// Tailwind background + border classes for a health status badge or panel;
// statuses outside healthy / degraded / unhealthy use the slate variant.
const getStatusBgColor = (status: string) => {
  if (status === 'healthy') {
    return 'bg-green-500/20 border-green-500/30'
  }
  if (status === 'degraded') {
    return 'bg-amber-500/20 border-amber-500/30'
  }
  if (status === 'unhealthy') {
    return 'bg-red-500/20 border-red-500/30'
  }
  return 'bg-slate-500/20 border-slate-500/30'
}
|
||||
|
||||
// Tailwind text color for a component latency in milliseconds.
// Missing (or NaN) readings are neutral slate; otherwise green below 10ms,
// amber below 50ms and red from 50ms up.
const getLatencyColor = (ms: number | undefined) => {
  // Test for absence explicitly: a plain truthiness check would also treat a
  // real 0ms reading as "no data", although the template still renders it
  // (it only hides the value when latency_ms is undefined).
  if (ms == null || Number.isNaN(ms)) return 'text-slate-400'
  if (ms < 10) return 'text-green-400'
  if (ms < 50) return 'text-amber-400'
  return 'text-red-400'
}
|
||||
|
||||
// Tally of components per status (healthy / degraded / unhealthy).
// All zeros until a health report has been loaded; components reporting any
// other status are not counted.
const componentCounts = computed(() => {
  const tally = { healthy: 0, degraded: 0, unhealthy: 0 }
  const report = health.value
  if (!report) return tally

  report.components.forEach((component) => {
    switch (component.status) {
      case 'healthy':
        tally.healthy++
        break
      case 'degraded':
        tally.degraded++
        break
      case 'unhealthy':
        tally.unhealthy++
        break
    }
  })
  return tally
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<!-- Modal Backdrop -->
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="show"
|
||||
class="fixed inset-0 z-50 flex items-center justify-center"
|
||||
>
|
||||
<!-- Backdrop -->
|
||||
<div
|
||||
class="absolute inset-0 bg-black/60 backdrop-blur-sm"
|
||||
@click="emit('close')"
|
||||
/>
|
||||
|
||||
<!-- Modal Content -->
|
||||
<div class="relative w-full max-w-xl mx-4 max-h-[90vh] overflow-y-auto">
|
||||
<Card
|
||||
gradient="bg-gradient-to-br from-emerald-500/10 to-green-500/5"
|
||||
border-class="border-emerald-500/30"
|
||||
>
|
||||
<!-- Header -->
|
||||
<div class="flex items-center justify-between mb-4">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-heartbeat text-emerald-400" />
|
||||
<h3 class="text-lg font-semibold text-emerald-100">System Health</h3>
|
||||
</div>
|
||||
<button
|
||||
@click="emit('close')"
|
||||
class="p-1.5 text-slate-400 hover:text-slate-200 hover:bg-slate-700/50 rounded-lg transition-colors"
|
||||
>
|
||||
<i class="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Loading State -->
|
||||
<div v-if="loading" class="flex items-center justify-center py-8">
|
||||
<i class="fas fa-circle-notch fa-spin text-2xl text-emerald-400" />
|
||||
</div>
|
||||
|
||||
<!-- Error State -->
|
||||
<div v-else-if="error" class="text-center py-8">
|
||||
<i class="fas fa-exclamation-triangle text-2xl text-red-400 mb-2" />
|
||||
<p class="text-red-300">{{ error }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Content -->
|
||||
<div v-else-if="health" class="space-y-5">
|
||||
<!-- Overall Status -->
|
||||
<div
|
||||
class="p-4 rounded-lg border"
|
||||
:class="getStatusBgColor(health.status)"
|
||||
>
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex items-center gap-3">
|
||||
<i
|
||||
class="fas text-3xl"
|
||||
:class="[getStatusIcon(health.status), getStatusColor(health.status)]"
|
||||
/>
|
||||
<div>
|
||||
<div class="text-lg font-semibold capitalize" :class="getStatusColor(health.status)">
|
||||
{{ health.status }}
|
||||
</div>
|
||||
<div class="text-xs text-slate-400">Overall System Status</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="text-right">
|
||||
<div class="text-sm text-slate-300 font-mono">{{ health.version }}</div>
|
||||
<div class="text-xs text-slate-500">Version</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Component Status Summary -->
|
||||
<div class="grid grid-cols-3 gap-3 text-center">
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg">
|
||||
<div class="text-xl font-bold text-green-400">{{ componentCounts.healthy }}</div>
|
||||
<div class="text-xs text-slate-500">Healthy</div>
|
||||
</div>
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg">
|
||||
<div class="text-xl font-bold text-amber-400">{{ componentCounts.degraded }}</div>
|
||||
<div class="text-xs text-slate-500">Degraded</div>
|
||||
</div>
|
||||
<div class="p-3 bg-slate-800/50 rounded-lg">
|
||||
<div class="text-xl font-bold text-red-400">{{ componentCounts.unhealthy }}</div>
|
||||
<div class="text-xs text-slate-500">Unhealthy</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Components List -->
|
||||
<div class="space-y-3">
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Components</div>
|
||||
|
||||
<div class="space-y-2">
|
||||
<div
|
||||
v-for="component in health.components"
|
||||
:key="component.name"
|
||||
class="flex items-center gap-3 p-3 bg-slate-800/30 rounded-lg"
|
||||
>
|
||||
<!-- Status Icon -->
|
||||
<i
|
||||
class="fas w-5 text-center"
|
||||
:class="[getStatusIcon(component.status), getStatusColor(component.status)]"
|
||||
/>
|
||||
|
||||
<!-- Name & Message -->
|
||||
<div class="flex-1 min-w-0">
|
||||
<div class="text-sm font-medium text-slate-200">{{ component.name }}</div>
|
||||
<div v-if="component.message" class="text-xs text-slate-500 truncate" :title="component.message">
|
||||
{{ component.message }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Latency -->
|
||||
<div v-if="component.latency_ms !== undefined" class="text-right">
|
||||
<span class="font-mono text-sm" :class="getLatencyColor(component.latency_ms)">
|
||||
{{ component.latency_ms.toFixed(1) }}ms
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Status Badge -->
|
||||
<span
|
||||
class="text-xs font-medium capitalize px-2 py-0.5 rounded"
|
||||
:class="[getStatusColor(component.status), getStatusBgColor(component.status)]"
|
||||
>
|
||||
{{ component.status }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Warnings -->
|
||||
<div v-if="health.warnings && health.warnings.length > 0" class="space-y-3">
|
||||
<div class="text-xs text-slate-500 uppercase tracking-wide">Warnings</div>
|
||||
|
||||
<div class="space-y-2">
|
||||
<div
|
||||
v-for="(warning, index) in health.warnings"
|
||||
:key="index"
|
||||
class="flex items-start gap-2 p-3 bg-amber-500/10 border border-amber-500/30 rounded-lg text-sm"
|
||||
>
|
||||
<i class="fas fa-exclamation-triangle text-amber-400 mt-0.5" />
|
||||
<span class="text-amber-200">{{ warning }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Refresh Button -->
|
||||
<button
|
||||
@click="loadData"
|
||||
:disabled="loading"
|
||||
class="w-full py-2 text-sm text-emerald-400 hover:text-emerald-300 bg-emerald-500/10 hover:bg-emerald-500/20 rounded-lg transition-colors flex items-center justify-center gap-2"
|
||||
>
|
||||
<i class="fas fa-sync-alt" :class="{ 'fa-spin': loading }" />
|
||||
Refresh Health Status
|
||||
</button>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
@@ -0,0 +1,248 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, watch } from 'vue'
|
||||
import { fetchTopObservations, fetchMostRetrievedObservations } from '@/utils/api'
|
||||
import type { Observation } from '@/types'
|
||||
import Card from './Card.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
show: boolean
|
||||
currentProject: string | null
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
close: []
|
||||
navigateToObservation: [id: number]
|
||||
}>()
|
||||
|
||||
const loading = ref(false)
|
||||
const error = ref<string | null>(null)
|
||||
const topObservations = ref<Observation[]>([])
|
||||
const mostRetrieved = ref<Observation[]>([])
|
||||
const activeTab = ref<'top' | 'retrieved'>('top')
|
||||
|
||||
// Loads both observation lists (highest scored and most retrieved) in
// parallel while the modal is shown, optionally scoped to the current project.
const loadData = async () => {
  if (!props.show) return

  loading.value = true
  error.value = null

  try {
    // A null/empty project means "no project filter".
    const project = props.currentProject || undefined
    const pendingTop = fetchTopObservations(project, 15)
    const pendingRetrieved = fetchMostRetrievedObservations(project, 15)

    const results = await Promise.all([pendingTop, pendingRetrieved])
    topObservations.value = results[0]
    mostRetrieved.value = results[1]
  } catch (err) {
    // Non-Error rejections get a generic message.
    const message = err instanceof Error ? err.message : 'Failed to load observations'
    error.value = message
  } finally {
    loading.value = false
  }
}
|
||||
|
||||
// Fetch right away if the modal is already open when the component mounts.
onMounted(() => {
  if (!props.show) return
  loadData()
})

// Fetch whenever `show` changes to a truthy value.
watch(
  () => props.show,
  (isOpen) => {
    if (isOpen) loadData()
  }
)

// Refetch on a project switch, but only while the modal is open.
watch(
  () => props.currentProject,
  () => {
    if (!props.show) return
    loadData()
  }
)
|
||||
|
||||
// Icon and color classes for each known observation type.
const typeConfig: Record<string, { icon: string; colorClass: string; bgClass: string }> = {
  discovery: { icon: 'fa-lightbulb', colorClass: 'text-amber-400', bgClass: 'bg-amber-500/20' },
  bugfix: { icon: 'fa-bug', colorClass: 'text-red-400', bgClass: 'bg-red-500/20' },
  change: { icon: 'fa-code-branch', colorClass: 'text-blue-400', bgClass: 'bg-blue-500/20' },
  refactor: { icon: 'fa-wrench', colorClass: 'text-purple-400', bgClass: 'bg-purple-500/20' },
  feature: { icon: 'fa-star', colorClass: 'text-green-400', bgClass: 'bg-green-500/20' },
  pattern: { icon: 'fa-puzzle-piece', colorClass: 'text-cyan-400', bgClass: 'bg-cyan-500/20' },
  architecture: { icon: 'fa-sitemap', colorClass: 'text-indigo-400', bgClass: 'bg-indigo-500/20' },
  preference: { icon: 'fa-heart', colorClass: 'text-pink-400', bgClass: 'bg-pink-500/20' }
}

// Returns the styling for an observation type, or a neutral slate circle for
// types that are not listed in typeConfig.
const getTypeConfig = (type: string) => {
  // Own-property check: a plain `typeConfig[type] || fallback` lookup would
  // resolve inherited keys such as "constructor" or "toString" to
  // Object.prototype members instead of the fallback.
  if (Object.prototype.hasOwnProperty.call(typeConfig, type)) {
    return typeConfig[type]
  }
  return { icon: 'fa-circle', colorClass: 'text-slate-400', bgClass: 'bg-slate-500/20' }
}
|
||||
|
||||
// Renders a score with exactly two decimal places.
const formatScore = (score: number) => score.toFixed(2)
|
||||
|
||||
// Observation list backing the active tab: the highest-scored list for 'top',
// the most-retrieved list otherwise.
const currentObservations = () => {
  if (activeTab.value === 'top') {
    return topObservations.value
  }
  return mostRetrieved.value
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<!-- Modal Backdrop -->
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="show"
|
||||
class="fixed inset-0 z-50 flex items-center justify-center"
|
||||
>
|
||||
<!-- Backdrop -->
|
||||
<div
|
||||
class="absolute inset-0 bg-black/60 backdrop-blur-sm"
|
||||
@click="emit('close')"
|
||||
/>
|
||||
|
||||
<!-- Modal Content -->
|
||||
<div class="relative w-full max-w-2xl mx-4 max-h-[90vh] overflow-y-auto">
|
||||
<Card
|
||||
gradient="bg-gradient-to-br from-amber-500/10 to-orange-500/5"
|
||||
border-class="border-amber-500/30"
|
||||
>
|
||||
<!-- Header -->
|
||||
<div class="flex items-center justify-between mb-4">
|
||||
<div class="flex items-center gap-2">
|
||||
<i class="fas fa-trophy text-amber-400" />
|
||||
<h3 class="text-lg font-semibold text-amber-100">Top Observations</h3>
|
||||
</div>
|
||||
<button
|
||||
@click="emit('close')"
|
||||
class="p-1.5 text-slate-400 hover:text-slate-200 hover:bg-slate-700/50 rounded-lg transition-colors"
|
||||
>
|
||||
<i class="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Tabs -->
|
||||
<div class="flex gap-2 mb-4">
|
||||
<button
|
||||
@click="activeTab = 'top'"
|
||||
:class="[
|
||||
'flex-1 py-2 px-4 text-sm font-medium rounded-lg transition-colors',
|
||||
activeTab === 'top'
|
||||
? 'bg-amber-500/20 text-amber-300 border border-amber-500/30'
|
||||
: 'text-slate-400 hover:text-slate-300 hover:bg-slate-700/50'
|
||||
]"
|
||||
>
|
||||
<i class="fas fa-star mr-2" />
|
||||
Highest Scored
|
||||
</button>
|
||||
<button
|
||||
@click="activeTab = 'retrieved'"
|
||||
:class="[
|
||||
'flex-1 py-2 px-4 text-sm font-medium rounded-lg transition-colors',
|
||||
activeTab === 'retrieved'
|
||||
? 'bg-cyan-500/20 text-cyan-300 border border-cyan-500/30'
|
||||
: 'text-slate-400 hover:text-slate-300 hover:bg-slate-700/50'
|
||||
]"
|
||||
>
|
||||
<i class="fas fa-search mr-2" />
|
||||
Most Retrieved
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Project Filter Indicator -->
|
||||
<div v-if="currentProject" class="flex items-center gap-2 mb-4 text-xs text-slate-500">
|
||||
<i class="fas fa-filter" />
|
||||
<span>Filtered by:</span>
|
||||
<span class="text-amber-600/80 font-mono">{{ currentProject.split('/').pop() }}</span>
|
||||
</div>
|
||||
|
||||
<!-- Loading State -->
|
||||
<div v-if="loading" class="flex items-center justify-center py-8">
|
||||
<i class="fas fa-circle-notch fa-spin text-2xl text-amber-400" />
|
||||
</div>
|
||||
|
||||
<!-- Error State -->
|
||||
<div v-else-if="error" class="text-center py-8">
|
||||
<i class="fas fa-exclamation-triangle text-2xl text-red-400 mb-2" />
|
||||
<p class="text-red-300">{{ error }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Empty State -->
|
||||
<div v-else-if="currentObservations().length === 0" class="text-center py-8">
|
||||
<i class="fas fa-inbox text-2xl text-slate-500 mb-2" />
|
||||
<p class="text-slate-400">No observations found</p>
|
||||
</div>
|
||||
|
||||
<!-- Content -->
|
||||
<div v-else class="space-y-2">
|
||||
<div
|
||||
v-for="(obs, index) in currentObservations()"
|
||||
:key="obs.id"
|
||||
@click="emit('navigateToObservation', obs.id); emit('close')"
|
||||
class="flex items-center gap-3 p-3 bg-slate-800/30 hover:bg-slate-800/50 rounded-lg cursor-pointer transition-colors group"
|
||||
>
|
||||
<!-- Rank -->
|
||||
<div
|
||||
class="w-7 h-7 rounded-full flex items-center justify-center text-xs font-bold"
|
||||
:class="[
|
||||
index < 3 ? 'bg-amber-500/30 text-amber-300' : 'bg-slate-700/50 text-slate-400'
|
||||
]"
|
||||
>
|
||||
{{ index + 1 }}
|
||||
</div>
|
||||
|
||||
<!-- Type Icon -->
|
||||
<div
|
||||
class="w-8 h-8 rounded-lg flex items-center justify-center"
|
||||
:class="getTypeConfig(obs.type).bgClass"
|
||||
>
|
||||
<i
|
||||
class="fas text-sm"
|
||||
:class="[getTypeConfig(obs.type).icon, getTypeConfig(obs.type).colorClass]"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Title & Meta -->
|
||||
<div class="flex-1 min-w-0">
|
||||
<div class="text-sm font-medium text-slate-200 truncate group-hover:text-amber-200 transition-colors">
|
||||
{{ obs.title || 'Untitled' }}
|
||||
</div>
|
||||
<div class="flex items-center gap-2 text-xs text-slate-500">
|
||||
<span class="capitalize">{{ obs.type }}</span>
|
||||
<span v-if="obs.project" class="text-amber-600/70 font-mono">{{ obs.project.split('/').pop() }}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Score / Retrieval Count -->
|
||||
<div class="text-right flex-shrink-0">
|
||||
<div
|
||||
v-if="activeTab === 'top'"
|
||||
class="text-sm font-mono font-bold"
|
||||
:class="obs.importance_score && obs.importance_score >= 1.5 ? 'text-green-400' : obs.importance_score && obs.importance_score >= 1 ? 'text-amber-400' : 'text-slate-400'"
|
||||
>
|
||||
{{ formatScore(obs.importance_score || 1) }}
|
||||
</div>
|
||||
<div
|
||||
v-else
|
||||
class="text-sm font-mono font-bold text-cyan-400"
|
||||
>
|
||||
{{ obs.retrieval_count || 0 }}×
|
||||
</div>
|
||||
<div class="text-xs text-slate-500">
|
||||
{{ activeTab === 'top' ? 'score' : 'retrieved' }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Arrow -->
|
||||
<i class="fas fa-chevron-right text-slate-600 group-hover:text-slate-400 transition-colors" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Refresh Button -->
|
||||
<button
|
||||
@click="loadData"
|
||||
:disabled="loading"
|
||||
class="w-full mt-4 py-2 text-sm text-amber-400 hover:text-amber-300 bg-amber-500/10 hover:bg-amber-500/20 rounded-lg transition-colors flex items-center justify-center gap-2"
|
||||
>
|
||||
<i class="fas fa-sync-alt" :class="{ 'fa-spin': loading }" />
|
||||
Refresh
|
||||
</button>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
@@ -164,3 +164,119 @@ export async function fetchRelatedObservations(observationId: number, minConfide
|
||||
export async function fetchRelationStats(signal?: AbortSignal): Promise<RelationStats> {
|
||||
return fetchWithRetry<RelationStats>(`${API_BASE}/relations/stats`, { signal })
|
||||
}
|
||||
|
||||
// Scoring API functions
|
||||
export interface ScoreBreakdown {
|
||||
observation: {
|
||||
id: number
|
||||
title: string
|
||||
type: string
|
||||
project: string
|
||||
created_at: number
|
||||
}
|
||||
scoring: {
|
||||
final_score: number
|
||||
type_weight: number
|
||||
recency_decay: number
|
||||
core_score: number
|
||||
feedback_contrib: number
|
||||
concept_contrib: number
|
||||
retrieval_contrib: number
|
||||
age_days: number
|
||||
}
|
||||
explanation: {
|
||||
type_impact: string
|
||||
recency_impact: string
|
||||
feedback_impact: string
|
||||
concept_impact: string
|
||||
retrieval_impact: string
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Requests the score breakdown of one observation from
 * `${API_BASE}/observations/<id>/score` through fetchWithRetry.
 *
 * @param observationId id of the observation to look up
 * @param signal optional AbortSignal forwarded to the request
 */
export async function fetchObservationScore(observationId: number, signal?: AbortSignal): Promise<ScoreBreakdown> {
  const url = `${API_BASE}/observations/${observationId}/score`
  return fetchWithRetry<ScoreBreakdown>(url, { signal })
}
|
||||
|
||||
export interface FeedbackStats {
|
||||
total: number
|
||||
positive: number
|
||||
negative: number
|
||||
neutral: number
|
||||
avg_score: number
|
||||
avg_retrieval: number
|
||||
}
|
||||
|
||||
export interface TopObservation {
|
||||
id: number
|
||||
title: string
|
||||
type: string
|
||||
importance_score: number
|
||||
retrieval_count?: number
|
||||
}
|
||||
|
||||
/**
 * Requests feedback/scoring statistics from `${API_BASE}/scoring/stats`
 * through fetchWithRetry.
 *
 * @param project optional project filter; the query string is omitted when empty
 * @param signal optional AbortSignal forwarded to the request
 */
export async function fetchScoringStats(project?: string, signal?: AbortSignal): Promise<FeedbackStats> {
  const search = new URLSearchParams()
  if (project) {
    search.append('project', project)
  }
  const suffix = search.toString()
  const url = suffix ? `${API_BASE}/scoring/stats?${suffix}` : `${API_BASE}/scoring/stats`
  return fetchWithRetry<FeedbackStats>(url, { signal })
}
|
||||
|
||||
/**
 * Requests observations from `${API_BASE}/observations/top` through
 * fetchWithRetry.
 *
 * @param project optional project filter, sent as the `project` query parameter
 * @param limit value of the `limit` query parameter (default 10)
 * @param signal optional AbortSignal forwarded to the request
 */
export async function fetchTopObservations(project?: string, limit: number = 10, signal?: AbortSignal): Promise<Observation[]> {
  const search = new URLSearchParams({ limit: String(limit) })
  if (project) {
    search.append('project', project)
  }
  const url = `${API_BASE}/observations/top?${search.toString()}`
  return fetchWithRetry<Observation[]>(url, { signal })
}
|
||||
|
||||
/**
 * Requests observations from `${API_BASE}/observations/most-retrieved`
 * through fetchWithRetry.
 *
 * @param project optional project filter, sent as the `project` query parameter
 * @param limit value of the `limit` query parameter (default 10)
 * @param signal optional AbortSignal forwarded to the request
 */
export async function fetchMostRetrievedObservations(project?: string, limit: number = 10, signal?: AbortSignal): Promise<Observation[]> {
  const search = new URLSearchParams({ limit: String(limit) })
  if (project) {
    search.append('project', project)
  }
  const url = `${API_BASE}/observations/most-retrieved?${search.toString()}`
  return fetchWithRetry<Observation[]>(url, { signal })
}
|
||||
|
||||
// Search Analytics API functions
|
||||
export interface RecentQuery {
|
||||
query: string
|
||||
project?: string
|
||||
type?: string
|
||||
count: number
|
||||
last_used: string
|
||||
}
|
||||
|
||||
export interface SearchAnalytics {
|
||||
total_searches: number
|
||||
vector_searches: number
|
||||
filter_searches: number
|
||||
cache_hits: number
|
||||
coalesced_requests: number
|
||||
search_errors: number
|
||||
avg_latency_ms: number
|
||||
avg_vector_latency_ms: number
|
||||
avg_filter_latency_ms: number
|
||||
}
|
||||
|
||||
/**
 * Requests search analytics from `${API_BASE}/search/analytics`.
 *
 * @param signal - Optional AbortSignal handed to fetchWithRetry in its options.
 * @returns Whatever fetchWithRetry resolves to, typed as SearchAnalytics.
 */
export async function fetchSearchAnalytics(signal?: AbortSignal): Promise<SearchAnalytics> {
  const url = `${API_BASE}/search/analytics`
  return fetchWithRetry<SearchAnalytics>(url, { signal })
}
|
||||
|
||||
/**
 * Requests recent searches from `${API_BASE}/search/recent`.
 *
 * @param limit - Value interpolated into the `limit` query parameter (defaults to 20).
 * @param signal - Optional AbortSignal handed to fetchWithRetry in its options.
 * @returns Whatever fetchWithRetry resolves to, typed as RecentQuery[].
 */
export async function fetchRecentSearches(limit: number = 20, signal?: AbortSignal): Promise<RecentQuery[]> {
  // The number is interpolated as-is (no URL encoding is applied to it).
  const url = `${API_BASE}/search/recent?limit=${limit}`
  return fetchWithRetry<RecentQuery[]>(url, { signal })
}
|
||||
|
||||
// System health API
|
||||
/**
 * Health report for a single component, used as the element type of
 * SystemHealth.components.
 */
export interface ComponentHealth {
  /** Component name. */
  name: string

  /** Health state; restricted to one of the three literal values. */
  status: 'healthy' | 'degraded' | 'unhealthy'

  /** Optional free-form message accompanying the status. */
  message?: string

  /** Optional latency figure; the `_ms` suffix suggests milliseconds — TODO confirm. */
  latency_ms?: number
}
|
||||
|
||||
/**
 * Overall health report, as returned by fetchSystemHealth.
 */
export interface SystemHealth {
  /** Overall health state; restricted to one of the three literal values. */
  status: 'healthy' | 'degraded' | 'unhealthy'

  /** Version string reported alongside the health status. */
  version: string

  /** Per-component health entries. */
  components: ComponentHealth[]

  /** Optional list of warning messages; may be absent. */
  warnings?: string[]
}
|
||||
|
||||
/**
 * Requests the system health report from `${API_BASE}/selfcheck`.
 *
 * @param signal - Optional AbortSignal handed to fetchWithRetry in its options.
 * @returns Whatever fetchWithRetry resolves to, typed as SystemHealth.
 */
export async function fetchSystemHealth(signal?: AbortSignal): Promise<SystemHealth> {
  const url = `${API_BASE}/selfcheck`
  return fetchWithRetry<SystemHealth>(url, { signal })
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/relation.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/relationgraph.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/timeline.vue"],"version":"5.7.3"}
|
||||
{"root":["./src/main.ts","./src/vite-env.d.ts","./src/components/index.ts","./src/composables/index.ts","./src/composables/usehealth.ts","./src/composables/usesse.ts","./src/composables/usestats.ts","./src/composables/usetimeline.ts","./src/composables/usetypes.ts","./src/composables/useupdate.ts","./src/types/api.ts","./src/types/index.ts","./src/types/observation.ts","./src/types/prompt.ts","./src/types/relation.ts","./src/types/summary.ts","./src/utils/api.ts","./src/utils/formatters.ts","./src/app.vue","./src/components/badge.vue","./src/components/card.vue","./src/components/filtertabs.vue","./src/components/header.vue","./src/components/iconbox.vue","./src/components/observationcard.vue","./src/components/projectfilter.vue","./src/components/promptcard.vue","./src/components/relationgraph.vue","./src/components/scorebreakdown.vue","./src/components/searchanalytics.vue","./src/components/sidebar.vue","./src/components/statscards.vue","./src/components/summarycard.vue","./src/components/systemhealthdetails.vue","./src/components/timeline.vue","./src/components/topobservations.vue"],"version":"5.7.3"}
|
||||
Reference in New Issue
Block a user