Make things 'betterer' across the board (#23)

* Make things 'betterer' across the board

* fix: reorganize struct fields and config parameters for consistency

- [x] Reorder Config struct fields alphabetically and by related functionality
- [x] Reorganize Observation model fields with archival fields grouped together
- [x] Reorder ObservationStore fields to group related members
- [x] Reorder Store struct fields with health check caching grouped
- [x] Reorganize HealthInfo and PoolMetrics struct field order
- [x] Reorder maintenance Service struct fields logically
- [x] Reorganize MCP server handler parameter structs alphabetically
- [x] Reorder pattern detector candidate tracking fields
- [x] Reorganize search Manager struct fields by functionality
- [x] Reorder vector Client struct fields with mutex protections grouped
- [x] Reorganize handler request/response struct fields
- [x] Update handlers_test.go to expect wrapped response format
- [x] Reorder middleware TokenAuth and rate limiter fields
- [x] Reorganize Service struct fields with grouped functionality
- [x] Fix RateLimiter field ordering for clarity
- [x] Reorder CircuitBreaker metrics fields

* fix(security): improve JSON output safety and path traversal protection

- [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler
- [x] Remove escapeJSONString helper function in favor of standard JSON marshaling
- [x] Add safeResolvePath function to validate paths and prevent directory traversal
- [x] Apply path traversal validation in captureFileMtimes operations
- [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation

* fix(sdk): improve path traversal protection and allocation safety

- [x] Enhance safeResolvePath with stricter validation using filepath.Rel
- [x] Reject paths containing ".." after cleaning to prevent traversal
- [x] Validate absolute paths are within cwd when cwd is specified
- [x] Apply safeResolvePath validation to GetFileContent for consistency
- [x] Add comprehensive test coverage for path traversal protection
- [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
This commit is contained in:
2026-01-11 01:51:20 +00:00
committed by GitHub
parent 3107eddeb2
commit d04b60517a
46 changed files with 12710 additions and 2068 deletions
+18 -9
View File
@@ -46,28 +46,32 @@ type Config struct {
VectorStorageStrategy string `json:"vector_storage_strategy"`
ContextObsConcepts []string `json:"context_obs_concepts"`
ContextObsTypes []string `json:"context_obs_types"`
ContextMaxPromptResults int `json:"context_max_prompt_results"`
RerankingResults int `json:"reranking_results"`
ContextFullCount int `json:"context_full_count"`
GraphBranchFactor int `json:"graph_branch_factor"`
GraphEdgeWeight float64 `json:"graph_edge_weight"`
ContextRelevanceThreshold float64 `json:"context_relevance_threshold"`
RerankingCandidates int `json:"reranking_candidates"`
WorkerPort int `json:"worker_port"`
RerankingMinImprovement float64 `json:"reranking_min_improvement"`
ContextObservations int `json:"context_observations"`
ContextFullCount int `json:"context_full_count"`
ContextMaxPromptResults int `json:"context_max_prompt_results"`
ContextSessionCount int `json:"context_session_count"`
MaxConns int `json:"max_conns"`
RerankingAlpha float64 `json:"reranking_alpha"`
GraphMaxHops int `json:"graph_max_hops"`
GraphBranchFactor int `json:"graph_branch_factor"`
RerankingResults int `json:"reranking_results"`
GraphRebuildIntervalMin int `json:"graph_rebuild_interval_min"`
HubThreshold int `json:"hub_threshold"`
ContextShowLastSummary bool `json:"context_show_last_summary"`
RerankingEnabled bool `json:"reranking_enabled"`
ObservationRetentionDays int `json:"observation_retention_days"`
MaintenanceIntervalHours int `json:"maintenance_interval_hours"`
ContextShowWorkTokens bool `json:"context_show_work_tokens"`
ContextShowReadTokens bool `json:"context_show_read_tokens"`
RerankingPureMode bool `json:"reranking_pure_mode"`
GraphEnabled bool `json:"graph_enabled"`
MaintenanceEnabled bool `json:"maintenance_enabled"`
RerankingEnabled bool `json:"reranking_enabled"`
ContextShowLastSummary bool `json:"context_show_last_summary"`
CleanupStaleObservations bool `json:"cleanup_stale_observations"`
}
var (
@@ -93,8 +97,9 @@ func SettingsPath() string {
}
// EnsureDataDir creates the data directory if it doesn't exist.
// Uses 0700 permissions (owner-only) for security.
func EnsureDataDir() error {
return os.MkdirAll(DataDir(), 0750)
return os.MkdirAll(DataDir(), 0700)
}
// EnsureSettings creates a default settings file if it doesn't exist.
@@ -161,8 +166,12 @@ func Default() *Config {
ContextShowLastSummary: true,
ContextObsTypes: DefaultObservationTypes,
ContextObsConcepts: DefaultObservationConcepts,
ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include
ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only)
ContextRelevanceThreshold: 0.3, // Minimum 30% similarity to include
ContextMaxPromptResults: 10, // Cap at 10 results max (0 = no cap, threshold only)
MaintenanceEnabled: true, // Enable scheduled maintenance
MaintenanceIntervalHours: 6, // Run every 6 hours
ObservationRetentionDays: 0, // 0 = no age-based deletion (keep all)
CleanupStaleObservations: false, // Don't auto-cleanup stale observations
}
}
+54 -17
View File
@@ -9,27 +9,13 @@ import (
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// EnsureSessionExists creates a session if it doesn't exist.
// Uses INSERT OR IGNORE pattern for atomic idempotent creation (single query instead of COUNT + INSERT).
// This is shared between stores to avoid duplication.
func EnsureSessionExists(ctx context.Context, db *gorm.DB, sdkSessionID, project string) error {
// Check if session exists
var count int64
err := db.WithContext(ctx).
Model(&SDKSession{}).
Where("sdk_session_id = ?", sdkSessionID).
Count(&count).Error
if err != nil {
return err
}
if count > 0 {
return nil // Session exists
}
// Auto-create session
now := time.Now()
session := &SDKSession{
ClaudeSessionID: sdkSessionID,
@@ -41,7 +27,14 @@ func EnsureSessionExists(ctx context.Context, db *gorm.DB, sdkSessionID, project
PromptCounter: 0,
}
return db.WithContext(ctx).Create(session).Error
// Use INSERT OR IGNORE - single query, atomic operation
// If session already exists (conflict on sdk_session_id), do nothing
return db.WithContext(ctx).
Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "sdk_session_id"}},
DoNothing: true,
}).
Create(session).Error
}
// sqlNullString creates a sql.NullString from a string.
@@ -52,8 +45,13 @@ func sqlNullString(s string) sql.NullString {
return sql.NullString{String: s, Valid: true}
}
// MaxPaginationLimit is the maximum allowed limit for pagination queries.
// This protects against resource exhaustion from excessively large requests.
const MaxPaginationLimit = 1000
// ParseLimitParam parses the "limit" query parameter from an HTTP request.
// Returns defaultLimit if the parameter is missing or invalid.
// Note: This does NOT enforce a maximum limit. Use ParseLimitParamWithMax for that.
func ParseLimitParam(r *http.Request, defaultLimit int) int {
if l := r.URL.Query().Get("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
@@ -62,3 +60,42 @@ func ParseLimitParam(r *http.Request, defaultLimit int) int {
}
return defaultLimit
}
// ParseLimitParamWithMax parses the "limit" query parameter with a maximum cap.
// Returns min(parsed, maxLimit) or defaultLimit if missing/invalid.
// If maxLimit is 0, uses MaxPaginationLimit (1000).
func ParseLimitParamWithMax(r *http.Request, defaultLimit, maxLimit int) int {
if maxLimit <= 0 {
maxLimit = MaxPaginationLimit
}
limit := ParseLimitParam(r, defaultLimit)
if limit > maxLimit {
return maxLimit
}
return limit
}
// ParseOffsetParam parses the "offset" query parameter from an HTTP request.
// Returns 0 if the parameter is missing or invalid.
func ParseOffsetParam(r *http.Request) int {
if o := r.URL.Query().Get("offset"); o != "" {
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
return parsed
}
}
return 0
}
// PaginationParams holds pagination parameters.
type PaginationParams struct {
Limit int
Offset int
}
// ParsePaginationParams parses both limit and offset from an HTTP request.
func ParsePaginationParams(r *http.Request, defaultLimit int) PaginationParams {
return PaginationParams{
Limit: ParseLimitParam(r, defaultLimit),
Offset: ParseOffsetParam(r),
}
}
+238
View File
@@ -322,6 +322,244 @@ func runMigrations(db *gorm.DB, sqlDB *sql.DB) error {
return tx.Migrator().DropTable("observation_relations")
},
},
// Migration 012: Query optimization indexes
// Adds covering and composite indexes for common query patterns
{
ID: "012_query_optimization_indexes",
Migrate: func(tx *gorm.DB) error {
sqls := []string{
// Composite index for observation queries by project + scope + importance
// Covers the common pattern: WHERE project = ? OR scope = 'global' ORDER BY importance_score DESC
`CREATE INDEX IF NOT EXISTS idx_observations_project_scope_importance
ON observations(project, scope, importance_score DESC, created_at_epoch DESC)`,
// Covering index for observation retrieval (includes most used columns)
// Allows index-only scans for listing queries
`CREATE INDEX IF NOT EXISTS idx_observations_project_covering
ON observations(project, scope, is_superseded, importance_score DESC)
WHERE is_superseded = 0 OR is_superseded IS NULL`,
// Index for session summary lookups
`CREATE INDEX IF NOT EXISTS idx_summaries_project_importance
ON session_summaries(project, importance_score DESC, created_at_epoch DESC)`,
// Index for prompt retrieval by session
`CREATE INDEX IF NOT EXISTS idx_prompts_session_number
ON user_prompts(claude_session_id, prompt_number)`,
// Index for pattern queries by frequency
`CREATE INDEX IF NOT EXISTS idx_patterns_frequency
ON patterns(frequency DESC, last_seen_at_epoch DESC)
WHERE is_deprecated = 0`,
}
for _, s := range sqls {
if err := tx.Exec(s).Error; err != nil {
// Non-fatal: index may already exist or fail for benign reasons
continue
}
}
return nil
},
Rollback: func(tx *gorm.DB) error {
sqls := []string{
"DROP INDEX IF EXISTS idx_observations_project_scope_importance",
"DROP INDEX IF EXISTS idx_observations_project_covering",
"DROP INDEX IF EXISTS idx_summaries_project_importance",
"DROP INDEX IF EXISTS idx_prompts_session_number",
"DROP INDEX IF EXISTS idx_patterns_frequency",
}
for _, s := range sqls {
if err := tx.Exec(s).Error; err != nil {
continue
}
}
return nil
},
},
// Migration 013: Add archival columns to observations
{
ID: "013_observation_archival",
Migrate: func(tx *gorm.DB) error {
sqls := []string{
// Add archival columns
`ALTER TABLE observations ADD COLUMN is_archived INTEGER DEFAULT 0`,
`ALTER TABLE observations ADD COLUMN archived_at_epoch INTEGER`,
`ALTER TABLE observations ADD COLUMN archived_reason TEXT`,
// Index for archived observations
`CREATE INDEX IF NOT EXISTS idx_observations_archived ON observations(is_archived)`,
// Composite index for filtering active (non-archived) observations
`CREATE INDEX IF NOT EXISTS idx_observations_active
ON observations(project, is_archived, is_superseded, importance_score DESC)
WHERE (is_archived = 0 OR is_archived IS NULL)`,
}
for _, s := range sqls {
if err := tx.Exec(s).Error; err != nil {
// Non-fatal: column may already exist
continue
}
}
return nil
},
Rollback: func(tx *gorm.DB) error {
// SQLite doesn't support DROP COLUMN in older versions
// but for newer versions we can try
sqls := []string{
"DROP INDEX IF EXISTS idx_observations_active",
"DROP INDEX IF EXISTS idx_observations_archived",
}
for _, s := range sqls {
_ = tx.Exec(s).Error
}
return nil
},
},
// Migration 014: Add performance-critical indexes for common query patterns
{
ID: "014_performance_indexes",
Migrate: func(tx *gorm.DB) error {
sqls := []string{
// Index for batch ID lookups (IN queries)
`CREATE INDEX IF NOT EXISTS idx_observations_id_covering
ON observations(id, project, scope, importance_score)`,
// Index for vector search result fetching
`CREATE INDEX IF NOT EXISTS idx_vectors_doc_type_project
ON vectors(doc_type, project, scope)`,
// Index for session summaries by project
`CREATE INDEX IF NOT EXISTS idx_summaries_project_created
ON session_summaries(project, created_at_epoch DESC)`,
// Index for user prompts by session
`CREATE INDEX IF NOT EXISTS idx_prompts_session_created
ON user_prompts(claude_session_id, created_at_epoch DESC)`,
// Index for patterns by type and project
`CREATE INDEX IF NOT EXISTS idx_patterns_type_project
ON patterns(type, project, frequency DESC)
WHERE is_deprecated = 0`,
// Index for observation relations
`CREATE INDEX IF NOT EXISTS idx_relations_source_type
ON observation_relations(source_observation_id, relation_type)`,
`CREATE INDEX IF NOT EXISTS idx_relations_target_type
ON observation_relations(target_observation_id, relation_type)`,
}
for _, s := range sqls {
if err := tx.Exec(s).Error; err != nil {
// Non-fatal: index may already exist
continue
}
}
return nil
},
Rollback: func(tx *gorm.DB) error {
sqls := []string{
"DROP INDEX IF EXISTS idx_observations_id_covering",
"DROP INDEX IF EXISTS idx_vectors_doc_type_project",
"DROP INDEX IF EXISTS idx_summaries_project_created",
"DROP INDEX IF EXISTS idx_prompts_session_created",
"DROP INDEX IF EXISTS idx_patterns_type_project",
"DROP INDEX IF EXISTS idx_relations_source_type",
"DROP INDEX IF EXISTS idx_relations_target_type",
}
for _, s := range sqls {
_ = tx.Exec(s).Error
}
return nil
},
},
// Migration 015: Add optimized composite indexes for common query patterns
{
ID: "015_optimized_composite_indexes",
Migrate: func(tx *gorm.DB) error {
sqls := []string{
// Composite index for GetRecentObservations with project+scope filtering
// Covers: WHERE (project = ? OR scope = 'global') ORDER BY importance_score DESC, created_at_epoch DESC
`CREATE INDEX IF NOT EXISTS idx_observations_project_scope_created
ON observations(project, scope, created_at_epoch DESC, importance_score DESC)`,
// Index for scope='global' queries (common pattern in search)
`CREATE INDEX IF NOT EXISTS idx_observations_global_scope
ON observations(scope, importance_score DESC, created_at_epoch DESC)
WHERE scope = 'global'`,
// Index for vector search result deduplication by observation
`CREATE INDEX IF NOT EXISTS idx_vectors_observation_lookup
ON vectors(doc_type, sqlite_id, project)
WHERE doc_type = 'observation'`,
// Index for FTS search result ordering
`CREATE INDEX IF NOT EXISTS idx_observations_fts_ordering
ON observations(project, importance_score DESC)
WHERE (is_archived = 0 OR is_archived IS NULL) AND (is_superseded = 0 OR is_superseded IS NULL)`,
}
for _, s := range sqls {
if err := tx.Exec(s).Error; err != nil {
// Non-fatal: index may already exist
continue
}
}
return nil
},
Rollback: func(tx *gorm.DB) error {
sqls := []string{
"DROP INDEX IF EXISTS idx_observations_project_scope_created",
"DROP INDEX IF EXISTS idx_observations_global_scope",
"DROP INDEX IF EXISTS idx_vectors_observation_lookup",
"DROP INDEX IF EXISTS idx_observations_fts_ordering",
}
for _, s := range sqls {
_ = tx.Exec(s).Error
}
return nil
},
},
// Migration 016: Add covering indexes for relation joins and active observations
{
ID: "016_relation_and_active_indexes",
Migrate: func(tx *gorm.DB) error {
sqls := []string{
// Covering index for observation relation joins (common JOIN patterns)
// Speeds up queries like: JOIN observation_relations ON source_id = obs.id WHERE relation_type = ?
`CREATE INDEX IF NOT EXISTS idx_relations_source_type_target
ON observation_relations(source_observation_id, relation_type, target_observation_id)`,
// Covering index for reverse relation lookups
`CREATE INDEX IF NOT EXISTS idx_relations_target_type_source
ON observation_relations(target_observation_id, relation_type, source_observation_id)`,
// Partial index for active (non-archived, non-superseded) observations
// Optimizes activeObservationFilter queries
`CREATE INDEX IF NOT EXISTS idx_observations_active
ON observations(project, importance_score DESC, created_at_epoch DESC)
WHERE (is_archived = 0 OR is_archived IS NULL) AND (is_superseded = 0 OR is_superseded IS NULL)`,
}
for _, s := range sqls {
if err := tx.Exec(s).Error; err != nil {
// Non-fatal: index may already exist
continue
}
}
return nil
},
Rollback: func(tx *gorm.DB) error {
sqls := []string{
"DROP INDEX IF EXISTS idx_relations_source_type_target",
"DROP INDEX IF EXISTS idx_relations_target_type_source",
"DROP INDEX IF EXISTS idx_observations_active",
}
for _, s := range sqls {
_ = tx.Exec(s).Error
}
return nil
},
},
})
if err := m.Migrate(); err != nil {
+11 -7
View File
@@ -45,30 +45,34 @@ func (s *SDKSession) BeforeCreate(tx *gorm.DB) error {
}
// Observation represents a stored observation (learning).
// Field order optimized for memory alignment (fieldalignment).
type Observation struct {
FileMtimes models.JSONInt64Map `gorm:"type:text"`
SDKSessionID string `gorm:"index;not null"`
Project string `gorm:"index;not null"`
Project string `gorm:"index:idx_observations_project;index:idx_observations_project_created,priority:1;not null"`
Scope models.ObservationScope `gorm:"type:text;default:'project';check:scope IN ('project', 'global');index:idx_observations_scope;index:idx_observations_project_scope,priority:2"`
Type models.ObservationType `gorm:"type:text;check:type IN ('decision', 'bugfix', 'feature', 'refactor', 'discovery', 'change');index;not null"`
CreatedAt string `gorm:"not null"`
Title sql.NullString `gorm:"type:text"`
Facts models.JSONStringArray `gorm:"type:text"`
Narrative sql.NullString `gorm:"type:text"`
Concepts models.JSONStringArray `gorm:"type:text"`
FilesRead models.JSONStringArray `gorm:"type:text"`
FilesModified models.JSONStringArray `gorm:"type:text"`
Subtitle sql.NullString `gorm:"type:text"`
Facts models.JSONStringArray `gorm:"type:text"`
LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
PromptNumber sql.NullInt64
Title sql.NullString `gorm:"type:text"`
ArchivedReason sql.NullString
ScoreUpdatedAt sql.NullInt64 `gorm:"column:score_updated_at_epoch;index:idx_observations_score_updated"`
PromptNumber sql.NullInt64
ArchivedAt sql.NullInt64 `gorm:"column:archived_at_epoch"`
LastRetrievedAt sql.NullInt64 `gorm:"column:last_retrieved_at_epoch"`
ID int64 `gorm:"primaryKey;autoIncrement"`
ImportanceScore float64 `gorm:"type:real;default:1.0;index:idx_observations_importance,priority:1,sort:desc"`
UserFeedback int `gorm:"default:0"`
RetrievalCount int `gorm:"default:0"`
CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;not null"`
CreatedAtEpoch int64 `gorm:"index:idx_observations_created,sort:desc;index:idx_observations_project_created,priority:2,sort:desc;not null"`
DiscoveryTokens int64 `gorm:"default:0"`
IsSuperseded int `gorm:"default:0;index:idx_observations_superseded,priority:1"`
IsSuperseded int `gorm:"default:0;index:idx_observations_superseded;index:idx_observations_active,priority:2"`
IsArchived int `gorm:"default:0;index:idx_observations_archived;index:idx_observations_active,priority:1"`
}
func (Observation) TableName() string { return "observations" }
+579 -39
View File
@@ -5,39 +5,126 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"gorm.io/gorm"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// MaxObservationsPerProject is the maximum number of observations to keep per project.
const MaxObservationsPerProject = 100
// cleanupQueueSize is the buffer size for the cleanup queue.
const cleanupQueueSize = 100
// commonWords is a package-level set for O(1) lookup of stop words.
// Created once at init time to avoid repeated map allocations.
var commonWords = map[string]struct{}{
"the": {}, "and": {}, "or": {}, "but": {}, "in": {},
"on": {}, "at": {}, "to": {}, "for": {}, "of": {},
"with": {}, "by": {}, "from": {}, "as": {}, "is": {},
"was": {}, "are": {}, "were": {}, "be": {}, "been": {},
"being": {}, "have": {}, "has": {}, "had": {}, "do": {},
"does": {}, "did": {}, "will": {}, "would": {}, "should": {},
"could": {}, "may": {}, "might": {}, "must": {}, "can": {},
}
// CleanupFunc is a callback for when observations are cleaned up.
// Receives the IDs of deleted observations for downstream cleanup (e.g., vector DB).
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
// ObservationStore provides observation-related database operations using GORM.
type ObservationStore struct {
db *gorm.DB
rawDB *sql.DB
cleanupFunc CleanupFunc
conflictStore interface{} // Placeholder for ConflictStore (Phase 4)
relationStore interface{} // Placeholder for RelationStore (Phase 4)
conflictStore any
relationStore any
db *gorm.DB
rawDB *sql.DB
cleanupFunc CleanupFunc
cleanupQueue chan string
stopCleanup chan struct{}
cleanupWg sync.WaitGroup
cleanupOnce sync.Once
cleanupStarted atomic.Bool
}
// NewObservationStore creates a new observation store.
// The conflictStore and relationStore parameters are optional (can be nil) and will be used in Phase 4.
func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore interface{}) *ObservationStore {
return &ObservationStore{
func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore any) *ObservationStore {
s := &ObservationStore{
db: store.DB,
rawDB: store.GetRawDB(),
cleanupFunc: cleanupFunc,
conflictStore: conflictStore,
relationStore: relationStore,
cleanupQueue: make(chan string, cleanupQueueSize),
stopCleanup: make(chan struct{}),
}
// Start the cleanup worker
s.startCleanupWorker()
return s
}
// startCleanupWorker starts the background cleanup worker.
func (s *ObservationStore) startCleanupWorker() {
s.cleanupOnce.Do(func() {
s.cleanupStarted.Store(true)
s.cleanupWg.Add(1)
go s.cleanupWorker()
})
}
// cleanupWorker processes cleanup requests from the queue.
func (s *ObservationStore) cleanupWorker() {
defer s.cleanupWg.Done()
for {
select {
case <-s.stopCleanup:
// Drain remaining items in queue before exiting
for {
select {
case project := <-s.cleanupQueue:
s.processCleanup(project)
default:
return
}
}
case project := <-s.cleanupQueue:
s.processCleanup(project)
}
}
}
// processCleanup performs the actual cleanup for a project.
func (s *ObservationStore) processCleanup(project string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
deletedIDs, err := s.CleanupOldObservations(ctx, project)
if err != nil {
log.Warn().Err(err).Str("project", project).Msg("Failed to cleanup old observations")
return
}
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
s.cleanupFunc(ctx, deletedIDs)
log.Debug().Str("project", project).Int("count", len(deletedIDs)).Msg("Cleaned up old observations")
}
}
// Close stops the cleanup worker and waits for it to finish.
// Safe to call even if the worker was never started.
func (s *ObservationStore) Close() {
// Only stop if worker was actually started to avoid deadlock
if s.cleanupStarted.Load() {
close(s.stopCleanup)
s.cleanupWg.Wait()
}
}
@@ -86,16 +173,15 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p
return 0, 0, err
}
// Cleanup old observations beyond the limit for this project (async to not block handler)
// Queue cleanup of old observations beyond the limit for this project (async to not block handler)
if project != "" {
go func(proj string) {
cleanupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
deletedIDs, _ := s.CleanupOldObservations(cleanupCtx, proj)
if len(deletedIDs) > 0 && s.cleanupFunc != nil {
s.cleanupFunc(cleanupCtx, deletedIDs)
}
}(project)
select {
case s.cleanupQueue <- project:
// Successfully queued for cleanup
default:
// Queue is full, log a warning instead of silently dropping
log.Warn().Str("project", project).Msg("Cleanup queue full, skipping cleanup for this observation")
}
}
// Note: Conflict and relation detection intentionally omitted for now
@@ -104,6 +190,89 @@ func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, p
return dbObs.ID, nowEpoch, nil
}
// ObservationUpdate contains fields that can be updated on an observation.
// Only non-nil fields will be updated.
type ObservationUpdate struct {
Title *string // New title
Subtitle *string // New subtitle
Narrative *string // New narrative
Facts *[]string // New facts (replaces existing)
Concepts *[]string // New concepts (replaces existing)
FilesRead *[]string // New files read (replaces existing)
FilesModified *[]string // New files modified (replaces existing)
Scope *string // New scope (project or global)
}
// UpdateObservation updates an existing observation with the provided fields.
// Only non-nil fields in the update struct will be modified.
// Returns the updated observation or an error.
func (s *ObservationStore) UpdateObservation(ctx context.Context, id int64, update *ObservationUpdate) (*models.Observation, error) {
if update == nil {
return nil, fmt.Errorf("update cannot be nil")
}
// First, verify the observation exists
var dbObs Observation
if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("observation not found: %d", id)
}
return nil, err
}
// Build update map with only provided fields
updates := make(map[string]any)
if update.Title != nil {
updates["title"] = sql.NullString{String: *update.Title, Valid: true}
}
if update.Subtitle != nil {
updates["subtitle"] = sql.NullString{String: *update.Subtitle, Valid: true}
}
if update.Narrative != nil {
updates["narrative"] = sql.NullString{String: *update.Narrative, Valid: true}
}
if update.Facts != nil {
factsJSON, _ := json.Marshal(*update.Facts)
updates["facts"] = string(factsJSON)
}
if update.Concepts != nil {
conceptsJSON, _ := json.Marshal(*update.Concepts)
updates["concepts"] = string(conceptsJSON)
}
if update.FilesRead != nil {
filesReadJSON, _ := json.Marshal(*update.FilesRead)
updates["files_read"] = string(filesReadJSON)
}
if update.FilesModified != nil {
filesModifiedJSON, _ := json.Marshal(*update.FilesModified)
updates["files_modified"] = string(filesModifiedJSON)
}
if update.Scope != nil {
updates["scope"] = sql.NullString{String: *update.Scope, Valid: true}
}
// Add updated_at timestamp
updates["updated_at_epoch"] = sql.NullInt64{Int64: time.Now().Unix(), Valid: true}
if len(updates) == 0 {
// Nothing to update, just return existing observation
return toModelObservation(&dbObs), nil
}
// Perform the update
if err := s.db.WithContext(ctx).Model(&Observation{}).Where("id = ?", id).Updates(updates).Error; err != nil {
return nil, fmt.Errorf("failed to update observation: %w", err)
}
// Fetch the updated observation
if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil {
return nil, err
}
return toModelObservation(&dbObs), nil
}
// GetObservationByID retrieves an observation by its ID.
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
var dbObs Observation
@@ -134,6 +303,8 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
query = query.Order("created_at_epoch DESC")
case "importance":
query = query.Order("importance_score DESC, created_at_epoch DESC")
case "score_desc":
query = query.Order("importance_score DESC, created_at_epoch DESC")
default:
// Default: importance first, then recency
query = query.Order("COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC")
@@ -152,6 +323,60 @@ func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64
return toModelObservations(dbObservations), nil
}
// GetObservationsByIDsPreserveOrder retrieves observations by IDs, preserving the input order.
// This is useful when the caller has already sorted/ranked the IDs (e.g., by vector similarity).
func (s *ObservationStore) GetObservationsByIDsPreserveOrder(ctx context.Context, ids []int64) ([]*models.Observation, error) {
if len(ids) == 0 {
return nil, nil
}
// Fetch all observations in a single query
var dbObservations []Observation
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&dbObservations).Error
if err != nil {
return nil, err
}
// Build ID -> observation map for O(1) lookups
obsMap := make(map[int64]*Observation, len(dbObservations))
for i := range dbObservations {
obsMap[int64(dbObservations[i].ID)] = &dbObservations[i]
}
// Reconstruct in original order
result := make([]*models.Observation, 0, len(ids))
for _, id := range ids {
if obs, ok := obsMap[id]; ok {
result = append(result, toModelObservation(obs))
}
}
return result, nil
}
// BatchGetObservationsWithScores retrieves observations with associated scores.
// Returns a map of ID -> observation for efficient lookup.
func (s *ObservationStore) BatchGetObservationsWithScores(ctx context.Context, ids []int64) (map[int64]*models.Observation, error) {
if len(ids) == 0 {
return make(map[int64]*models.Observation), nil
}
// Fetch all observations in a single query
var dbObservations []Observation
err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&dbObservations).Error
if err != nil {
return nil, err
}
// Build result map
result := make(map[int64]*models.Observation, len(dbObservations))
for i := range dbObservations {
result[int64(dbObservations[i].ID)] = toModelObservation(&dbObservations[i])
}
return result, nil
}
// GetRecentObservations retrieves recent observations for a project.
// This includes project-scoped observations for the specified project AND global observations.
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
@@ -169,13 +394,13 @@ func (s *ObservationStore) GetRecentObservations(ctx context.Context, project st
return toModelObservations(dbObservations), nil
}
// GetActiveObservations retrieves recent non-superseded observations for a project.
// This excludes observations that have been marked as superseded by newer ones.
// GetActiveObservations retrieves recent non-superseded, non-archived observations for a project.
// This excludes observations that have been marked as superseded or archived.
// Results are ordered by importance_score DESC, then created_at_epoch DESC.
func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
var dbObservations []Observation
err := s.db.WithContext(ctx).
Scopes(projectScopeFilter(project), notSupersededFilter(), importanceOrdering()).
Scopes(projectScopeFilter(project), activeObservationFilter(), importanceOrdering()).
Limit(limit).
Find(&dbObservations).Error
@@ -245,7 +470,57 @@ func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit i
return toModelObservations(dbObservations), nil
}
// GetAllRecentObservationsPaginated retrieves recent observations with pagination.
func (s *ObservationStore) GetAllRecentObservationsPaginated(ctx context.Context, limit, offset int) ([]*models.Observation, int64, error) {
var dbObservations []Observation
var total int64
// Get total count
if err := s.db.WithContext(ctx).Model(&Observation{}).Count(&total).Error; err != nil {
return nil, 0, err
}
// Get paginated results
err := s.db.WithContext(ctx).
Scopes(importanceOrdering()).
Limit(limit).
Offset(offset).
Find(&dbObservations).Error
if err != nil {
return nil, 0, err
}
return toModelObservations(dbObservations), total, nil
}
// GetObservationsByProjectStrictPaginated retrieves observations strictly from a project with pagination.
func (s *ObservationStore) GetObservationsByProjectStrictPaginated(ctx context.Context, project string, limit, offset int) ([]*models.Observation, int64, error) {
var dbObservations []Observation
var total int64
// Get total count for project
if err := s.db.WithContext(ctx).Model(&Observation{}).Where("project = ?", project).Count(&total).Error; err != nil {
return nil, 0, err
}
// Get paginated results
err := s.db.WithContext(ctx).
Where("project = ?", project).
Scopes(importanceOrdering()).
Limit(limit).
Offset(offset).
Find(&dbObservations).Error
if err != nil {
return nil, 0, err
}
return toModelObservations(dbObservations), total, nil
}
// GetAllObservations retrieves all observations (for vector rebuild).
// Note: For large datasets, prefer GetAllObservationsIterator to avoid memory issues.
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
var dbObservations []Observation
err := s.db.WithContext(ctx).
@@ -259,6 +534,51 @@ func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Ob
return toModelObservations(dbObservations), nil
}
// GetAllObservationsIterator returns observations in batches to avoid loading all into memory.
// The callback is called for each batch. Return false from callback to stop iteration.
// batchSize controls how many observations are loaded at once (default 500 if <= 0).
func (s *ObservationStore) GetAllObservationsIterator(ctx context.Context, batchSize int, callback func([]*models.Observation) bool) error {
if batchSize <= 0 {
batchSize = 500
}
var lastID int64 = 0
for {
// Check context cancellation
select {
case <-ctx.Done():
return ctx.Err()
default:
}
var dbObservations []Observation
err := s.db.WithContext(ctx).
Where("id > ?", lastID).
Order("id ASC").
Limit(batchSize).
Find(&dbObservations).Error
if err != nil {
return err
}
if len(dbObservations) == 0 {
break // No more observations
}
// Update cursor for next batch
lastID = dbObservations[len(dbObservations)-1].ID
// Convert and call callback
observations := toModelObservations(dbObservations)
if !callback(observations) {
break // Callback requested stop
}
}
return nil
}
// SearchObservationsFTS performs full-text search on observations using FTS5.
// Falls back to LIKE search if FTS5 fails.
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
@@ -314,14 +634,26 @@ func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, pro
}
// searchObservationsLike performs fallback LIKE search on observations using GORM.
// Limits to 2 keywords to prevent expensive OR queries that SQLite optimizes poorly.
// This is a fallback path when FTS returns no results, so we prioritize performance.
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
if len(keywords) == 0 {
return nil, nil
}
// Limit keywords to prevent excessive OR conditions that hurt query planning.
// SQLite performs significantly better with fewer LIKE conditions.
// Using 2 instead of 5 reduces query complexity from O(15) to O(6) conditions
// (each keyword creates 3 LIKE conditions for title, subtitle, narrative).
const maxKeywords = 2
if len(keywords) > maxKeywords {
keywords = keywords[:maxKeywords]
}
// Build LIKE conditions for each keyword
var conditions []string
var args []interface{}
// Pre-allocate for efficiency: maxKeywords conditions × 3 args each + 1 project arg
conditions := make([]string, 0, len(keywords))
args := make([]any, 0, len(keywords)*3+1)
for _, kw := range keywords {
pattern := "%" + kw + "%"
@@ -358,6 +690,217 @@ func (s *ObservationStore) DeleteObservations(ctx context.Context, ids []int64)
return result.RowsAffected, result.Error
}
// DeleteObservation deletes a single observation by ID.
func (s *ObservationStore) DeleteObservation(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).Delete(&Observation{}, id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("observation %d not found", id)
}
return nil
}
// MarkAsSuperseded marks an observation as superseded (stale).
func (s *ObservationStore) MarkAsSuperseded(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).
Model(&Observation{}).
Where("id = ?", id).
Update("is_superseded", 1)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("observation %d not found", id)
}
return nil
}
// MarkAsSupersededBatch marks multiple observations as superseded in a single query.
// Returns the number of observations updated and any error.
func (s *ObservationStore) MarkAsSupersededBatch(ctx context.Context, ids []int64) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
result := s.db.WithContext(ctx).
Model(&Observation{}).
Where("id IN ?", ids).
Update("is_superseded", 1)
return result.RowsAffected, result.Error
}
// ArchiveObservation archives a single observation with an optional reason.
func (s *ObservationStore) ArchiveObservation(ctx context.Context, id int64, reason string) error {
updates := map[string]any{
"is_archived": 1,
"archived_at_epoch": time.Now().UnixMilli(),
}
if reason != "" {
updates["archived_reason"] = reason
}
result := s.db.WithContext(ctx).
Model(&Observation{}).
Where("id = ?", id).
Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("observation %d not found", id)
}
return nil
}
// UnarchiveObservation restores an archived observation.
func (s *ObservationStore) UnarchiveObservation(ctx context.Context, id int64) error {
result := s.db.WithContext(ctx).
Model(&Observation{}).
Where("id = ?", id).
Updates(map[string]any{
"is_archived": 0,
"archived_at_epoch": nil,
"archived_reason": nil,
})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("observation %d not found", id)
}
return nil
}
// ArchiveOldObservations archives observations older than the specified age.
// Returns the count of archived observations and their IDs.
func (s *ObservationStore) ArchiveOldObservations(ctx context.Context, project string, maxAgeDays int, reason string) ([]int64, error) {
if maxAgeDays <= 0 {
maxAgeDays = 90 // Default: archive observations older than 90 days
}
cutoffEpoch := time.Now().AddDate(0, 0, -maxAgeDays).UnixMilli()
if reason == "" {
reason = fmt.Sprintf("auto-archived: older than %d days", maxAgeDays)
}
var idsToArchive []int64
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Find observations to archive (not already archived, older than cutoff)
query := tx.Model(&Observation{}).
Where("created_at_epoch < ?", cutoffEpoch).
Where("COALESCE(is_archived, 0) = 0")
if project != "" {
query = query.Where("project = ?", project)
}
if err := query.Pluck("id", &idsToArchive).Error; err != nil {
return err
}
if len(idsToArchive) == 0 {
return nil
}
// Archive the observations
now := time.Now().UnixMilli()
return tx.Model(&Observation{}).
Where("id IN ?", idsToArchive).
Updates(map[string]any{
"is_archived": 1,
"archived_at_epoch": now,
"archived_reason": reason,
}).Error
})
if err != nil {
return nil, err
}
return idsToArchive, nil
}
// GetArchivedObservations retrieves archived observations for a project.
func (s *ObservationStore) GetArchivedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
var dbObservations []Observation
query := s.db.WithContext(ctx).
Where("COALESCE(is_archived, 0) = 1")
if project != "" {
query = query.Where("project = ?", project)
}
err := query.
Order("archived_at_epoch DESC").
Limit(limit).
Find(&dbObservations).Error
if err != nil {
return nil, err
}
return toModelObservations(dbObservations), nil
}
// GetArchivalStats returns statistics about archived observations.
// Optimized to use a single query instead of 4 separate queries.
func (s *ObservationStore) GetArchivalStats(ctx context.Context, project string) (*ArchivalStats, error) {
// Use a single query with conditional aggregation to get all stats at once.
// This is much faster than 4 separate queries (saves 3 round trips).
type statsResult struct {
OldestEpoch *int64
NewestEpoch *int64
TotalCount int64
ArchivedCount int64
}
var result statsResult
query := s.db.WithContext(ctx).Model(&Observation{}).
Select(`
COUNT(*) as total_count,
SUM(CASE WHEN COALESCE(is_archived, 0) = 1 THEN 1 ELSE 0 END) as archived_count,
MIN(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as oldest_epoch,
MAX(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as newest_epoch
`)
if project != "" {
query = query.Where("project = ?", project)
}
if err := query.Scan(&result).Error; err != nil {
return nil, err
}
stats := &ArchivalStats{
TotalCount: result.TotalCount,
ArchivedCount: result.ArchivedCount,
ActiveCount: result.TotalCount - result.ArchivedCount,
}
if result.OldestEpoch != nil {
stats.OldestArchivedEpoch = *result.OldestEpoch
}
if result.NewestEpoch != nil {
stats.NewestArchivedEpoch = *result.NewestEpoch
}
return stats, nil
}
// ArchivalStats contains statistics about archived observations.
type ArchivalStats struct {
TotalCount int64 `json:"total_count"`
ActiveCount int64 `json:"active_count"`
ArchivedCount int64 `json:"archived_count"`
OldestArchivedEpoch int64 `json:"oldest_archived_epoch,omitempty"`
NewestArchivedEpoch int64 `json:"newest_archived_epoch,omitempty"`
}
// CleanupOldObservations removes observations beyond the limit for a project.
// Returns the IDs of deleted observations.
func (s *ObservationStore) CleanupOldObservations(ctx context.Context, project string) ([]int64, error) {
@@ -418,10 +961,12 @@ func projectScopeFilter(project string) func(*gorm.DB) *gorm.DB {
}
}
// notSupersededFilter filters out superseded observations.
func notSupersededFilter() func(*gorm.DB) *gorm.DB {
// activeObservationFilter filters for active (non-archived, non-superseded) observations.
// This is more efficient than chaining notSupersededFilter + notArchivedFilter
// as it produces a single WHERE clause for the query optimizer.
func activeObservationFilter() func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
return db.Where("COALESCE(is_superseded, 0) = 0")
return db.Where("COALESCE(is_archived, 0) = 0 AND COALESCE(is_superseded, 0) = 0")
}
}
@@ -437,23 +982,17 @@ func importanceOrdering() func(*gorm.DB) *gorm.DB {
// ====================
// extractKeywords extracts keywords from a search query.
// Uses package-level commonWords map for O(1) stop word filtering.
func extractKeywords(query string) []string {
words := strings.Fields(strings.ToLower(query))
var keywords []string
commonWords := map[string]bool{
"the": true, "and": true, "or": true, "but": true, "in": true,
"on": true, "at": true, "to": true, "for": true, "of": true,
"with": true, "by": true, "from": true, "as": true, "is": true,
"was": true, "are": true, "were": true, "be": true, "been": true,
"being": true, "have": true, "has": true, "had": true, "do": true,
"does": true, "did": true, "will": true, "would": true, "should": true,
"could": true, "may": true, "might": true, "must": true, "can": true,
}
keywords := make([]string, 0, len(words)) // Pre-allocate for typical case
for _, word := range words {
// Skip short words and common words
if len(word) <= 3 || commonWords[word] {
// Skip short words and common stop words
if len(word) <= 3 {
continue
}
if _, isCommon := commonWords[word]; isCommon {
continue
}
keywords = append(keywords, word)
@@ -464,7 +1003,8 @@ func extractKeywords(query string) []string {
// scanObservationRows scans multiple observations from raw SQL rows.
func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
var observations []*models.Observation
// Pre-allocate with reasonable initial capacity to avoid repeated slice growth
observations := make([]*models.Observation, 0, 64)
for rows.Next() {
obs, err := scanObservation(rows)
if err != nil {
@@ -476,7 +1016,7 @@ func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
}
// scanObservation scans a single observation from a row scanner.
func scanObservation(scanner interface{ Scan(...interface{}) error }) (*models.Observation, error) {
func scanObservation(scanner interface{ Scan(...any) error }) (*models.Observation, error) {
var obs models.Observation
var factsJSON, conceptsJSON, filesReadJSON, filesModifiedJSON, fileMtimesJSON []byte
var isSuperseded int
+44 -1
View File
@@ -3,6 +3,8 @@ package gorm
import (
"context"
"fmt"
"strings"
"time"
"gorm.io/gorm"
@@ -59,12 +61,23 @@ func (s *ObservationStore) UpdateImportanceScore(ctx context.Context, id int64,
}
// UpdateImportanceScores bulk updates importance scores for multiple observations.
// This is more efficient than individual updates for batch recalculation.
// Uses a single SQL statement with CASE/WHEN for efficient batch updates.
func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores map[int64]float64) error {
if len(scores) == 0 {
return nil
}
// For small batches, use simple individual updates
if len(scores) <= 5 {
return s.updateScoresIndividually(ctx, scores)
}
// For larger batches, use CASE/WHEN SQL for single-query update
return s.updateScoresBatch(ctx, scores)
}
// updateScoresIndividually updates scores one at a time (efficient for small batches).
func (s *ObservationStore) updateScoresIndividually(ctx context.Context, scores map[int64]float64) error {
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
now := time.Now().UnixMilli()
@@ -85,6 +98,36 @@ func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores ma
})
}
// updateScoresBatch updates multiple scores in a single SQL statement using CASE/WHEN.
// This is much more efficient for large batches (O(1) queries instead of O(n)).
func (s *ObservationStore) updateScoresBatch(ctx context.Context, scores map[int64]float64) error {
now := time.Now().UnixMilli()
// Build CASE/WHEN clause for importance_score
// UPDATE observations SET
// importance_score = CASE id WHEN 1 THEN 0.5 WHEN 2 THEN 0.8 ... END,
// score_updated_at_epoch = ?
// WHERE id IN (1, 2, ...)
ids := make([]int64, 0, len(scores))
caseBuilder := strings.Builder{}
caseBuilder.WriteString("CASE id ")
for id, score := range scores {
ids = append(ids, id)
caseBuilder.WriteString(fmt.Sprintf("WHEN %d THEN %f ", id, score))
}
caseBuilder.WriteString("END")
// Use raw SQL for the batch update
sql := fmt.Sprintf(
"UPDATE observations SET importance_score = %s, score_updated_at_epoch = ? WHERE id IN ?",
caseBuilder.String(),
)
return s.db.WithContext(ctx).Exec(sql, now, ids).Error
}
// GetObservationsNeedingScoreUpdate returns observations that need their importance score recalculated.
// Returns observations where score_updated_at_epoch is NULL or older than the threshold.
func (s *ObservationStore) GetObservationsNeedingScoreUpdate(ctx context.Context, threshold time.Duration, limit int) ([]*models.Observation, error) {
+33 -15
View File
@@ -119,26 +119,44 @@ func (s *SessionStore) FindAnySDKSession(ctx context.Context, claudeSessionID st
}
// IncrementPromptCounter increments the prompt counter and returns the new value.
// Uses a single SQL query with RETURNING clause for optimal performance.
func (s *SessionStore) IncrementPromptCounter(ctx context.Context, id int64) (int, error) {
// Atomic increment using GORM expression
err := s.db.WithContext(ctx).
Model(&SDKSession{}).
Where("id = ?", id).
Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error
// Use raw SQL with RETURNING to get updated value in single query
// SQLite supports RETURNING since version 3.35.0 (2021-03-12)
var newCounter int
err := s.db.WithContext(ctx).Raw(`
UPDATE sdk_sessions
SET prompt_counter = COALESCE(prompt_counter, 0) + 1
WHERE id = ?
RETURNING prompt_counter
`, id).Scan(&newCounter).Error
if err != nil {
// Fallback for older SQLite versions without RETURNING support
if err.Error() == "near \"RETURNING\": syntax error" || newCounter == 0 {
// Atomic increment
updateErr := s.db.WithContext(ctx).
Model(&SDKSession{}).
Where("id = ?", id).
Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error
if updateErr != nil {
return 0, updateErr
}
// Fetch updated value
var sess SDKSession
fetchErr := s.db.WithContext(ctx).
Select("prompt_counter").
First(&sess, id).Error
if fetchErr != nil {
return 0, fetchErr
}
return sess.PromptCounter, nil
}
return 0, err
}
// Fetch updated value
var sess SDKSession
err = s.db.WithContext(ctx).
Select("prompt_counter").
First(&sess, id).Error
if err != nil {
return 0, err
}
return sess.PromptCounter, nil
return newCounter, nil
}
// GetPromptCounter returns the current prompt counter for a session.
+426 -9
View File
@@ -2,11 +2,16 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"slices"
"sync"
"time"
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
_ "github.com/mattn/go-sqlite3" // Import SQLite driver with FTS5 support
"github.com/rs/zerolog/log"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -14,8 +19,13 @@ import (
// Store represents the GORM database connection with sqlite-vec support.
type Store struct {
DB *gorm.DB
sqlDB *sql.DB // For FTS5 and sqlite-vec operations that require raw SQL
healthCacheTime time.Time
DB *gorm.DB
sqlDB *sql.DB
metrics *PoolMetrics
cachedHealth *HealthInfo
healthCacheTTL time.Duration
healthCacheMu sync.RWMutex
}
// Config holds database configuration.
@@ -71,8 +81,10 @@ func NewStore(cfg Config) (*Store, error) {
}
store := &Store{
DB: db,
sqlDB: sqlDB,
DB: db,
sqlDB: sqlDB,
metrics: NewPoolMetrics(100), // Track last 100 latency samples
healthCacheTTL: 5 * time.Second, // Cache health checks for 5 seconds
}
// 7. Run migrations FIRST (before PRAGMA commands)
@@ -80,13 +92,20 @@ func NewStore(cfg Config) (*Store, error) {
return nil, fmt.Errorf("run migrations: %w", err)
}
// 8. CRITICAL: Set WAL mode and synchronous mode via raw SQL
// 8. CRITICAL: Set WAL mode and other performance pragmas
// Use raw sqlDB to avoid GORM transaction issues
if _, err := sqlDB.Exec("PRAGMA journal_mode=WAL"); err != nil {
return nil, fmt.Errorf("set WAL mode: %w", err)
pragmas := []string{
"PRAGMA journal_mode=WAL",
"PRAGMA synchronous=NORMAL",
"PRAGMA cache_size=-64000", // 64MB cache (negative = KB)
"PRAGMA temp_store=MEMORY", // Store temp tables in memory
"PRAGMA mmap_size=268435456", // 256MB memory-mapped I/O
"PRAGMA page_size=4096", // 4KB pages (optimal for most systems)
}
if _, err := sqlDB.Exec("PRAGMA synchronous=NORMAL"); err != nil {
return nil, fmt.Errorf("set synchronous mode: %w", err)
for _, pragma := range pragmas {
if _, err := sqlDB.Exec(pragma); err != nil {
log.Warn().Str("pragma", pragma).Err(err).Msg("Failed to set pragma (non-fatal)")
}
}
// Set busy timeout to 5 seconds to handle concurrent writes
// This allows SQLite to retry when database is locked instead of failing immediately
@@ -94,9 +113,40 @@ func NewStore(cfg Config) (*Store, error) {
return nil, fmt.Errorf("set busy timeout: %w", err)
}
// 9. Warm the connection pool
store.WarmPool(maxConns)
return store, nil
}
// WarmPool pre-creates connections to avoid cold start latency.
func (s *Store) WarmPool(numConns int) {
if numConns <= 0 {
numConns = 4
}
var wg sync.WaitGroup
for i := 0; i < numConns; i++ {
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
conn, err := s.sqlDB.Conn(ctx)
if err != nil {
return
}
// Execute a simple query to ensure the connection is fully initialized
_ = conn.PingContext(ctx)
// Return connection to pool (don't close it)
_ = conn.Close()
}()
}
wg.Wait()
log.Debug().Int("connections", numConns).Msg("Connection pool warmed")
}
// Close closes the database connection.
func (s *Store) Close() error {
return s.sqlDB.Close()
@@ -120,3 +170,370 @@ func (s *Store) GetRawDB() *sql.DB {
func (s *Store) GetDB() *gorm.DB {
return s.DB
}
// Stats returns database connection pool statistics.
func (s *Store) Stats() sql.DBStats {
return s.sqlDB.Stats()
}
// Optimize runs VACUUM and ANALYZE to optimize the database.
// Should be called periodically (e.g., daily) during low activity.
func (s *Store) Optimize(ctx context.Context) error {
log.Info().Msg("Starting database optimization")
start := time.Now()
// ANALYZE updates statistics for query optimizer
if _, err := s.sqlDB.ExecContext(ctx, "ANALYZE"); err != nil {
return fmt.Errorf("analyze: %w", err)
}
// PRAGMA optimize runs optimization based on query statistics
if _, err := s.sqlDB.ExecContext(ctx, "PRAGMA optimize"); err != nil {
log.Warn().Err(err).Msg("PRAGMA optimize failed (non-fatal)")
}
log.Info().Dur("duration", time.Since(start)).Msg("Database optimization complete")
return nil
}
// HealthCheck performs a comprehensive health check with latency measurement.
// Returns detailed health information including connection pool stats and query latency.
// Results are cached for healthCacheTTL (default 5 seconds) to reduce database load
// from frequent monitoring calls.
func (s *Store) HealthCheck(ctx context.Context) *HealthInfo {
// Fast path: check cache with read lock
s.healthCacheMu.RLock()
if s.cachedHealth != nil && time.Since(s.healthCacheTime) < s.healthCacheTTL {
cached := s.cachedHealth
s.healthCacheMu.RUnlock()
return cached
}
s.healthCacheMu.RUnlock()
// Slow path: perform actual health check
info := s.performHealthCheck(ctx)
// Cache the result
s.healthCacheMu.Lock()
s.cachedHealth = info
s.healthCacheTime = time.Now()
s.healthCacheMu.Unlock()
return info
}
// HealthCheckForce performs a health check bypassing the cache.
// Use this when you need real-time health data (e.g., debugging, alerting).
func (s *Store) HealthCheckForce(ctx context.Context) *HealthInfo {
info := s.performHealthCheck(ctx)
// Update the cache with fresh data
s.healthCacheMu.Lock()
s.cachedHealth = info
s.healthCacheTime = time.Now()
s.healthCacheMu.Unlock()
return info
}
// performHealthCheck does the actual health check work.
func (s *Store) performHealthCheck(ctx context.Context) *HealthInfo {
info := &HealthInfo{
Status: "healthy",
Timestamp: time.Now(),
}
// Check pool stats
stats := s.sqlDB.Stats()
info.PoolStats = PoolStats{
OpenConnections: stats.OpenConnections,
InUse: stats.InUse,
Idle: stats.Idle,
WaitCount: stats.WaitCount,
WaitDuration: stats.WaitDuration,
MaxIdleClosed: stats.MaxIdleClosed,
MaxLifetimeClosed: stats.MaxLifetimeClosed,
}
// Record pool stats for metrics tracking
if s.metrics != nil {
s.metrics.RecordPoolStats(stats)
}
// Measure query latency with a simple SELECT
start := time.Now()
var dummy int
err := s.sqlDB.QueryRowContext(ctx, "SELECT 1").Scan(&dummy)
info.QueryLatency = time.Since(start)
// Record latency for historical tracking
if s.metrics != nil {
s.metrics.RecordLatency(info.QueryLatency)
info.HistoricalMetrics = s.metrics.GetMetricsSummary()
}
if err != nil {
info.Status = "unhealthy"
info.Error = err.Error()
return info
}
// Check for connection saturation (degraded if pool is heavily used)
if stats.InUse > 0 && float64(stats.InUse)/float64(stats.OpenConnections) > 0.8 {
info.Status = "degraded"
info.Warning = "Connection pool heavily utilized"
}
// Check for wait contention
if stats.WaitCount > 100 && stats.WaitDuration > 100*time.Millisecond {
info.Status = "degraded"
info.Warning = "Connection pool contention detected"
}
// Check query latency (warn if > 10ms for simple query)
if info.QueryLatency > 10*time.Millisecond {
if info.Status == "healthy" {
info.Status = "degraded"
}
info.Warning = fmt.Sprintf("Slow query latency: %v", info.QueryLatency)
}
// Check historical latency trend (degraded if P95 is high)
if s.metrics != nil && info.HistoricalMetrics.P95Latency > 50*time.Millisecond {
if info.Status == "healthy" {
info.Status = "degraded"
}
info.Warning = fmt.Sprintf("High P95 latency: %v", info.HistoricalMetrics.P95Latency)
}
return info
}
// HealthInfo contains database health check results.
type HealthInfo struct {
Timestamp time.Time `json:"timestamp"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
Warning string `json:"warning,omitempty"`
HistoricalMetrics MetricsSummary `json:"historical_metrics,omitempty"`
PoolStats PoolStats `json:"pool_stats"`
QueryLatency time.Duration `json:"query_latency_ns"`
}
// PoolStats contains connection pool statistics.
type PoolStats struct {
OpenConnections int `json:"open_connections"`
InUse int `json:"in_use"`
Idle int `json:"idle"`
WaitCount int64 `json:"wait_count"`
WaitDuration time.Duration `json:"wait_duration_ns"`
MaxIdleClosed int64 `json:"max_idle_closed"`
MaxLifetimeClosed int64 `json:"max_lifetime_closed"`
}
// QueryTimeout constants for different query types.
const (
// DefaultQueryTimeout is the default timeout for regular queries.
DefaultQueryTimeout = 5 * time.Second
// FastQueryTimeout is for queries that should be very fast (health checks, etc).
FastQueryTimeout = 1 * time.Second
// SlowQueryTimeout is for queries that may take longer (bulk operations, rebuilds).
SlowQueryTimeout = 30 * time.Second
)
// PoolMetrics tracks historical connection pool metrics with a sliding window.
type PoolMetrics struct {
lastSampleTime time.Time
latencySamples []time.Duration
latencyIdx int
latencyCount int
totalQueries int64
totalWaitTime time.Duration
peakInUse int
peakWaitCount int64
windowSize int
mu sync.RWMutex
}
// NewPoolMetrics creates a new pool metrics collector with the given window size.
func NewPoolMetrics(windowSize int) *PoolMetrics {
if windowSize <= 0 {
windowSize = 100 // Default: track last 100 samples
}
return &PoolMetrics{
latencySamples: make([]time.Duration, windowSize),
windowSize: windowSize,
lastSampleTime: time.Now(),
}
}
// RecordLatency records a query latency sample.
func (m *PoolMetrics) RecordLatency(latency time.Duration) {
m.mu.Lock()
defer m.mu.Unlock()
m.latencySamples[m.latencyIdx] = latency
m.latencyIdx = (m.latencyIdx + 1) % m.windowSize
if m.latencyCount < m.windowSize {
m.latencyCount++
}
m.totalQueries++
m.lastSampleTime = time.Now()
}
// RecordPoolStats records pool statistics for peak tracking.
func (m *PoolMetrics) RecordPoolStats(stats sql.DBStats) {
m.mu.Lock()
defer m.mu.Unlock()
if stats.InUse > m.peakInUse {
m.peakInUse = stats.InUse
}
if stats.WaitCount > m.peakWaitCount {
m.peakWaitCount = stats.WaitCount
}
m.totalWaitTime += stats.WaitDuration
}
// GetMetricsSummary returns a summary of collected metrics.
func (m *PoolMetrics) GetMetricsSummary() MetricsSummary {
m.mu.RLock()
defer m.mu.RUnlock()
summary := MetricsSummary{
TotalQueries: m.totalQueries,
SampleCount: m.latencyCount,
PeakInUse: m.peakInUse,
PeakWaitCount: m.peakWaitCount,
TotalWaitTime: m.totalWaitTime,
LastSampleTime: m.lastSampleTime,
}
if m.latencyCount == 0 {
return summary
}
// Calculate latency statistics
var total time.Duration
var min, max time.Duration = m.latencySamples[0], m.latencySamples[0]
for i := 0; i < m.latencyCount; i++ {
sample := m.latencySamples[i]
total += sample
if sample < min {
min = sample
}
if sample > max {
max = sample
}
}
summary.AvgLatency = total / time.Duration(m.latencyCount)
summary.MinLatency = min
summary.MaxLatency = max
// Calculate P95 latency (approximate using sorted samples)
if m.latencyCount >= 20 {
// Copy samples for sorting
samples := make([]time.Duration, m.latencyCount)
copy(samples, m.latencySamples[:m.latencyCount])
// Use slices.Sort for O(n log n) instead of O(n²) insertion sort
slices.Sort(samples)
p95Idx := int(float64(len(samples)) * 0.95)
summary.P95Latency = samples[p95Idx]
}
return summary
}
// MetricsSummary contains aggregated pool metrics.
type MetricsSummary struct {
LastSampleTime time.Time `json:"last_sample_time"`
TotalQueries int64 `json:"total_queries"`
SampleCount int `json:"sample_count"`
AvgLatency time.Duration `json:"avg_latency_ns"`
MinLatency time.Duration `json:"min_latency_ns"`
MaxLatency time.Duration `json:"max_latency_ns"`
P95Latency time.Duration `json:"p95_latency_ns,omitempty"`
PeakInUse int `json:"peak_in_use"`
PeakWaitCount int64 `json:"peak_wait_count"`
TotalWaitTime time.Duration `json:"total_wait_time_ns"`
}
// GetMetrics returns the current metrics without performing a health check.
func (s *Store) GetMetrics() MetricsSummary {
if s.metrics == nil {
return MetricsSummary{}
}
return s.metrics.GetMetricsSummary()
}
// ResetMetrics resets the metrics collector (useful for testing or after major changes).
func (s *Store) ResetMetrics() {
if s.metrics != nil {
s.metrics = NewPoolMetrics(s.metrics.windowSize)
}
}
// WithTimeout wraps a context with the given timeout and logs slow queries.
// Returns the wrapped context and a cancel function that should be called when done.
func (s *Store) WithTimeout(ctx context.Context, timeout time.Duration, operation string) (context.Context, context.CancelFunc) {
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
start := time.Now()
// Return wrapped cancel that logs if query was slow
return timeoutCtx, func() {
elapsed := time.Since(start)
cancel()
// Log slow queries (> 100ms)
if elapsed > 100*time.Millisecond {
log.Warn().
Str("operation", operation).
Dur("elapsed", elapsed).
Dur("timeout", timeout).
Msg("Slow database operation")
}
}
}
// ExecWithTimeout executes a raw SQL query with timeout.
// Returns error if query takes longer than timeout.
func (s *Store) ExecWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) error {
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "exec")
defer cancel()
_, err := s.sqlDB.ExecContext(timeoutCtx, query, args...)
if err != nil {
if err == context.DeadlineExceeded {
return fmt.Errorf("query timeout after %v: %s", timeout, query)
}
return err
}
return nil
}
// QueryRowWithTimeout executes a row query with timeout.
func (s *Store) QueryRowWithTimeout(ctx context.Context, timeout time.Duration, query string, args ...any) *sql.Row {
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "query_row")
// Note: cancel will be called when row.Scan() completes or errors
_ = cancel // Caller must ensure proper cleanup
return s.sqlDB.QueryRowContext(timeoutCtx, query, args...)
}
// TransactionWithTimeout wraps a transaction function with timeout handling.
// The transaction is automatically rolled back if the context times out.
func (s *Store) TransactionWithTimeout(ctx context.Context, timeout time.Duration, fn func(*gorm.DB) error) error {
timeoutCtx, cancel := s.WithTimeout(ctx, timeout, "transaction")
defer cancel()
return s.DB.WithContext(timeoutCtx).Transaction(func(tx *gorm.DB) error {
// Check context before proceeding
select {
case <-timeoutCtx.Done():
return timeoutCtx.Err()
default:
}
return fn(tx)
})
}
+290
View File
@@ -0,0 +1,290 @@
// Package maintenance provides scheduled maintenance tasks for claude-mnemonic.
package maintenance
import (
"context"
"sync"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/rs/zerolog"
)
// Service handles scheduled maintenance tasks.
type Service struct {
log zerolog.Logger
lastRunTime time.Time
promptStore *gorm.PromptStore
store *gorm.Store
vectorCleanupFn func(ctx context.Context, deletedIDs []int64)
config *config.Config
summaryStore *gorm.SummaryStore
stopCh chan struct{}
doneCh chan struct{}
observationStore *gorm.ObservationStore
lastRunDuration time.Duration
totalCleanedObs int64
totalOptimizeRun int64
mu sync.Mutex
running bool
}
// NewService creates a new maintenance service.
func NewService(
store *gorm.Store,
observationStore *gorm.ObservationStore,
summaryStore *gorm.SummaryStore,
promptStore *gorm.PromptStore,
vectorCleanupFn func(ctx context.Context, deletedIDs []int64),
cfg *config.Config,
log zerolog.Logger,
) *Service {
return &Service{
store: store,
observationStore: observationStore,
summaryStore: summaryStore,
promptStore: promptStore,
vectorCleanupFn: vectorCleanupFn,
config: cfg,
log: log.With().Str("component", "maintenance").Logger(),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
}
// Start begins the maintenance loop.
func (s *Service) Start(ctx context.Context) {
s.mu.Lock()
if s.running {
s.mu.Unlock()
return
}
s.running = true
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.running = false
s.mu.Unlock()
close(s.doneCh)
}()
if !s.config.MaintenanceEnabled {
s.log.Info().Msg("Maintenance disabled, not starting scheduler")
return
}
interval := max(time.Duration(s.config.MaintenanceIntervalHours)*time.Hour, time.Hour)
s.log.Info().
Dur("interval", interval).
Int("retention_days", s.config.ObservationRetentionDays).
Bool("cleanup_stale", s.config.CleanupStaleObservations).
Msg("Starting maintenance scheduler")
// Initial run after 5 minutes (allow system to stabilize)
time.Sleep(5 * time.Minute)
s.runMaintenance(ctx)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
s.log.Info().Msg("Maintenance shutting down due to context cancellation")
return
case <-s.stopCh:
s.log.Info().Msg("Maintenance shutting down due to stop signal")
return
case <-ticker.C:
s.runMaintenance(ctx)
}
}
}
// Stop signals the maintenance service to stop.
func (s *Service) Stop() {
s.mu.Lock()
defer s.mu.Unlock()
if !s.running {
return
}
close(s.stopCh)
}
// Wait waits for the maintenance service to finish.
func (s *Service) Wait() {
<-s.doneCh
}
// runMaintenance executes all maintenance tasks.
func (s *Service) runMaintenance(ctx context.Context) {
start := time.Now()
s.log.Info().Msg("Starting maintenance run")
var totalCleaned int64
// Task 1: Clean up old observations by age
if s.config.ObservationRetentionDays > 0 {
cleaned, err := s.cleanupOldObservations(ctx)
if err != nil {
s.log.Error().Err(err).Msg("Failed to cleanup old observations")
} else {
totalCleaned += cleaned
s.log.Info().Int64("cleaned", cleaned).Msg("Cleaned old observations by age")
}
}
// Task 2: Clean up stale observations
if s.config.CleanupStaleObservations {
cleaned, err := s.cleanupStaleObservations(ctx)
if err != nil {
s.log.Error().Err(err).Msg("Failed to cleanup stale observations")
} else {
totalCleaned += cleaned
s.log.Info().Int64("cleaned", cleaned).Msg("Cleaned stale observations")
}
}
// Task 3: Optimize database
if err := s.store.Optimize(ctx); err != nil {
s.log.Error().Err(err).Msg("Failed to optimize database")
} else {
s.totalOptimizeRun++
}
// Task 4: Clean up old prompts (keep last 1000 per session)
cleanedPrompts, err := s.cleanupOldPrompts(ctx)
if err != nil {
s.log.Error().Err(err).Msg("Failed to cleanup old prompts")
} else if cleanedPrompts > 0 {
s.log.Info().Int64("cleaned", cleanedPrompts).Msg("Cleaned old prompts")
}
// Update metrics
s.mu.Lock()
s.lastRunTime = time.Now()
s.lastRunDuration = time.Since(start)
s.totalCleanedObs += totalCleaned
s.mu.Unlock()
s.log.Info().
Dur("duration", time.Since(start)).
Int64("observations_cleaned", totalCleaned).
Msg("Maintenance run completed")
}
// cleanupOldObservations deletes observations older than the retention period.
func (s *Service) cleanupOldObservations(ctx context.Context) (int64, error) {
cutoffEpoch := time.Now().AddDate(0, 0, -s.config.ObservationRetentionDays).Unix()
// Get IDs of old observations
var deletedIDs []int64
err := s.store.GetDB().WithContext(ctx).
Model(&gorm.Observation{}).
Where("created_at_epoch < ?", cutoffEpoch).
Pluck("id", &deletedIDs).Error
if err != nil {
return 0, err
}
if len(deletedIDs) == 0 {
return 0, nil
}
// Delete in batches to avoid long transactions
batchSize := 100
for i := 0; i < len(deletedIDs); i += batchSize {
end := min(i+batchSize, len(deletedIDs))
batch := deletedIDs[i:end]
if err := s.store.GetDB().WithContext(ctx).
Where("id IN ?", batch).
Delete(&gorm.Observation{}).Error; err != nil {
return int64(i), err
}
// Sync vector DB deletions
if s.vectorCleanupFn != nil {
s.vectorCleanupFn(ctx, batch)
}
}
return int64(len(deletedIDs)), nil
}
// cleanupStaleObservations deletes observations marked as stale.
func (s *Service) cleanupStaleObservations(ctx context.Context) (int64, error) {
// Get IDs of stale observations (is_superseded = true)
var deletedIDs []int64
err := s.store.GetDB().WithContext(ctx).
Model(&gorm.Observation{}).
Where("is_superseded = ?", true).
Pluck("id", &deletedIDs).Error
if err != nil {
return 0, err
}
if len(deletedIDs) == 0 {
return 0, nil
}
// Delete in batches
batchSize := 100
for i := 0; i < len(deletedIDs); i += batchSize {
end := min(i+batchSize, len(deletedIDs))
batch := deletedIDs[i:end]
if err := s.store.GetDB().WithContext(ctx).
Where("id IN ?", batch).
Delete(&gorm.Observation{}).Error; err != nil {
return int64(i), err
}
// Sync vector DB deletions
if s.vectorCleanupFn != nil {
s.vectorCleanupFn(ctx, batch)
}
}
return int64(len(deletedIDs)), nil
}
// cleanupOldPrompts removes old prompts keeping only the most recent per session.
func (s *Service) cleanupOldPrompts(ctx context.Context) (int64, error) {
// Delete prompts older than 30 days that aren't the most recent in their session
cutoffEpoch := time.Now().AddDate(0, 0, -30).Unix()
result := s.store.GetDB().WithContext(ctx).
Where("created_at_epoch < ?", cutoffEpoch).
Delete(&gorm.UserPrompt{})
return result.RowsAffected, result.Error
}
// Stats returns maintenance statistics.
func (s *Service) Stats() map[string]any {
s.mu.Lock()
defer s.mu.Unlock()
return map[string]any{
"enabled": s.config.MaintenanceEnabled,
"interval_hours": s.config.MaintenanceIntervalHours,
"retention_days": s.config.ObservationRetentionDays,
"cleanup_stale": s.config.CleanupStaleObservations,
"last_run": s.lastRunTime,
"last_duration_ms": s.lastRunDuration.Milliseconds(),
"total_cleaned_obs": s.totalCleanedObs,
"total_optimizes": s.totalOptimizeRun,
"running": s.running,
}
}
// RunNow triggers an immediate maintenance run.
func (s *Service) RunNow(ctx context.Context) {
go s.runMaintenance(ctx)
}
+2648 -49
View File
File diff suppressed because it is too large Load Diff
+20 -20
View File
@@ -24,7 +24,7 @@ func TestServerSuite(t *testing.T) {
// TestNewServer tests server creation.
func (s *ServerSuite) TestNewServer() {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
s.NotNil(server)
s.Nil(server.searchMgr)
s.Equal("1.0.0", server.version)
@@ -293,7 +293,7 @@ func TestTimelineParams(t *testing.T) {
// TestHandleInitialize tests the initialize handler.
func TestHandleInitialize(t *testing.T) {
server := NewServer(nil, "1.2.3", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.2.3", nil, nil, nil, nil, nil, nil, nil, nil)
req := &Request{
JSONRPC: "2.0",
@@ -320,7 +320,7 @@ func TestHandleInitialize(t *testing.T) {
// TestHandleToolsList tests the tools/list handler.
func TestHandleToolsList(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
req := &Request{
JSONRPC: "2.0",
@@ -361,7 +361,7 @@ func TestHandleToolsList(t *testing.T) {
// TestHandleRequest tests request routing.
func TestHandleRequest(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
tests := []struct {
@@ -423,7 +423,7 @@ func TestHandleRequest(t *testing.T) {
// TestHandleToolsCall_InvalidParams tests tools/call with invalid params.
func TestHandleToolsCall_InvalidParams(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
req := &Request{
@@ -442,7 +442,7 @@ func TestHandleToolsCall_InvalidParams(t *testing.T) {
// TestCallTool_UnknownTool tests callTool with unknown tool name.
func TestCallTool_UnknownTool(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
_, err := server.callTool(ctx, "nonexistent_tool", json.RawMessage(`{}`))
@@ -452,7 +452,7 @@ func TestCallTool_UnknownTool(t *testing.T) {
// TestCallTool_InvalidArgs tests callTool with invalid arguments.
func TestCallTool_InvalidArgs(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
_, err := server.callTool(ctx, "search", json.RawMessage(`invalid json`))
@@ -574,7 +574,7 @@ func TestJSONRPCErrorCodes(t *testing.T) {
// TestToolListContainsExpectedSchemas tests that tool schemas are valid.
func TestToolListContainsExpectedSchemas(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
req := &Request{
JSONRPC: "2.0",
@@ -600,7 +600,7 @@ func TestToolListContainsExpectedSchemas(t *testing.T) {
// TestHandleToolsCall_UnknownTool tests tools/call with unknown tool name.
func TestHandleToolsCall_UnknownTool(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
req := &Request{
@@ -620,7 +620,7 @@ func TestHandleToolsCall_UnknownTool(t *testing.T) {
func TestCallTool_ToolNameRecognition(t *testing.T) {
// Note: This test verifies tool routing logic, not execution (which requires searchMgr)
// All valid tool names should be in the handleToolsList response
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
req := &Request{
JSONRPC: "2.0",
@@ -782,7 +782,7 @@ func TestResponseIDTypes(t *testing.T) {
// TestHandleTimelineByQuery_EmptyQuery tests timeline by query with empty query.
func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
// Empty query should error
@@ -793,7 +793,7 @@ func TestHandleTimelineByQuery_EmptyQuery(t *testing.T) {
// TestHandleTimeline_InvalidJSON tests timeline with invalid JSON.
func TestHandleTimeline_InvalidJSON(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
_, err := server.handleTimeline(ctx, json.RawMessage(`{invalid`))
@@ -803,7 +803,7 @@ func TestHandleTimeline_InvalidJSON(t *testing.T) {
// TestHandleTimelineByQuery_InvalidJSON tests timeline by query with invalid JSON.
func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
_, err := server.handleTimelineByQuery(ctx, json.RawMessage(`{invalid`))
@@ -813,7 +813,7 @@ func TestHandleTimelineByQuery_InvalidJSON(t *testing.T) {
// TestHandleTimeline_NoAnchorNoQuery tests timeline with no anchor and no query.
func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
// No anchor_id and no query should return empty result
@@ -825,7 +825,7 @@ func TestHandleTimeline_NoAnchorNoQuery(t *testing.T) {
// TestHandleTimeline_WithDefaults tests timeline default values are applied.
func TestHandleTimeline_WithDefaults(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
// With anchor_id but no before/after, defaults should be applied
@@ -839,7 +839,7 @@ func TestHandleTimeline_WithDefaults(t *testing.T) {
// TestServerFields tests Server struct fields.
func TestServerFields(t *testing.T) {
server := NewServer(nil, "2.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "2.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
assert.Equal(t, "2.0.0", server.version)
assert.Nil(t, server.searchMgr)
@@ -891,7 +891,7 @@ func TestErrorWithNilData(t *testing.T) {
// TestToolInputSchema tests that tool input schemas have required fields.
func TestToolInputSchema(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
req := &Request{
JSONRPC: "2.0",
@@ -960,7 +960,7 @@ func TestToolCallParamsWithComplexArgs(t *testing.T) {
// TestCallTool_UnknownToolName tests callTool with various unknown tool names.
func TestCallTool_UnknownToolName(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
unknownTools := []string{
@@ -1009,7 +1009,7 @@ func TestTimelineParams_Validation(t *testing.T) {
// TestHandleToolsCall_UnknownToolNameError tests tools/call with unknown tool returns error.
func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
req := &Request{
@@ -1031,7 +1031,7 @@ func TestHandleToolsCall_UnknownToolNameError(t *testing.T) {
// TestHandleToolsCall_EmptyParams tests tools/call with empty params.
func TestHandleToolsCall_EmptyParams(t *testing.T) {
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil)
server := NewServer(nil, "1.0.0", nil, nil, nil, nil, nil, nil, nil, nil)
ctx := context.Background()
req := &Request{
+67 -5
View File
@@ -3,6 +3,8 @@ package pattern
import (
"context"
"sort"
"strings"
"sync"
"time"
@@ -21,6 +23,8 @@ type DetectorConfig struct {
AnalysisInterval time.Duration
// MaxPatternsToTrack is the maximum number of active patterns.
MaxPatternsToTrack int
// MaxCandidates is the maximum number of candidates to track (LRU eviction).
MaxCandidates int
}
// DefaultConfig returns the default detector configuration.
@@ -30,6 +34,7 @@ func DefaultConfig() DetectorConfig {
MinFrequencyForPattern: 2, // At least 2 occurrences to form a pattern
AnalysisInterval: 5 * time.Minute,
MaxPatternsToTrack: 1000,
MaxCandidates: 500, // Prevent unbounded growth
}
}
@@ -195,7 +200,24 @@ func (d *Detector) AnalyzeObservation(ctx context.Context, obs *models.Observati
}
}
} else {
// Create new candidate
// Create new candidate - with immediate size check to prevent unbounded growth
// between periodic cleanups (which run every 5 minutes)
if d.config.MaxCandidates > 0 && len(d.candidates) >= d.config.MaxCandidates {
// Evict oldest candidate immediately rather than waiting for periodic cleanup
var oldestKey string
var oldestTime int64 = time.Now().UnixMilli()
for k, c := range d.candidates {
if c.lastSeenEpoch < oldestTime {
oldestTime = c.lastSeenEpoch
oldestKey = k
}
}
if oldestKey != "" {
delete(d.candidates, oldestKey)
log.Debug().Str("evicted_key", oldestKey).Msg("Evicted oldest candidate to make room")
}
}
patternType := models.DetectPatternType(obs.Concepts, obs.Title.String, obs.Narrative.String)
d.candidates[candidateKey] = &candidatePattern{
signature: signature,
@@ -299,17 +321,53 @@ func (d *Detector) AnalyzeRecentObservations(ctx context.Context) error {
return nil
}
// cleanupOldCandidates removes candidates that haven't been seen recently.
// cleanupOldCandidates removes candidates that haven't been seen recently
// and enforces the max candidates limit using LRU eviction.
func (d *Detector) cleanupOldCandidates() {
d.candidatesMu.Lock()
defer d.candidatesMu.Unlock()
threshold := time.Now().Add(-7 * 24 * time.Hour).UnixMilli()
// First pass: remove expired candidates
for key, candidate := range d.candidates {
if candidate.lastSeenEpoch < threshold {
delete(d.candidates, key)
}
}
// Second pass: enforce max candidates limit using LRU eviction
if d.config.MaxCandidates > 0 && len(d.candidates) > d.config.MaxCandidates {
// Find oldest candidates to evict using O(n log n) sort instead of O(n²) selection sort
type keyAge struct {
key string
age int64
}
candidates := make([]keyAge, 0, len(d.candidates))
for k, c := range d.candidates {
candidates = append(candidates, keyAge{k, c.lastSeenEpoch})
}
// Sort by age ascending (oldest first) - O(n log n)
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].age < candidates[j].age
})
// Delete oldest entries
toEvict := len(d.candidates) - d.config.MaxCandidates
for i := 0; i < toEvict; i++ {
delete(d.candidates, candidates[i].key)
}
log.Debug().Int("evicted", toEvict).Int("remaining", len(d.candidates)).Msg("Evicted old pattern candidates (LRU)")
}
}
// CandidateCount returns the current number of pattern candidates.
func (d *Detector) CandidateCount() int {
d.candidatesMu.RLock()
defer d.candidatesMu.RUnlock()
return len(d.candidates)
}
// GetPatternInsight returns a formatted insight string for a pattern.
@@ -334,11 +392,15 @@ func generateCandidateKey(signature []string) string {
if len(signature) == 0 {
return ""
}
key := ""
// Use strings.Builder to avoid O(n²) string concatenation
var b strings.Builder
// Pre-allocate: estimate average signature element is 10 chars + separator
b.Grow(len(signature) * 11)
for _, s := range signature {
key += s + "|"
b.WriteString(s)
b.WriteByte('|')
}
return key
return b.String()
}
// generatePatternName creates a human-readable name for a pattern.
+98
View File
@@ -0,0 +1,98 @@
// Package privacy provides utilities for protecting sensitive data.
package privacy
import (
"regexp"
"slices"
"strings"
)
// secretPatterns contains compiled regular expressions for detecting secrets.
// These patterns are designed to catch common secret formats with minimal false positives.
var secretPatterns = []*regexp.Regexp{
// API keys with common prefixes
regexp.MustCompile(`(?i)(api[_-]?key|apikey)\s*[:=]\s*['"]?[a-zA-Z0-9_-]{20,}['"]?`),
// Passwords in configuration
regexp.MustCompile(`(?i)(password|passwd|pwd)\s*[:=]\s*['"][^'"]{8,}['"]`),
// Secret tokens
regexp.MustCompile(`(?i)(secret[_-]?key|secret[_-]?token|auth[_-]?token)\s*[:=]\s*['"]?[a-zA-Z0-9_-]{20,}['"]?`),
// OpenAI API keys
regexp.MustCompile(`sk-[a-zA-Z0-9]{20,}`),
// Anthropic API keys
regexp.MustCompile(`sk-ant-[a-zA-Z0-9-]{20,}`),
// GitHub tokens
regexp.MustCompile(`gh[pous]_[a-zA-Z0-9]{36,}`),
regexp.MustCompile(`github_pat_[a-zA-Z0-9_]{22,}`),
// AWS keys
regexp.MustCompile(`AKIA[0-9A-Z]{16}`),
regexp.MustCompile(`(?i)aws[_-]?secret[_-]?access[_-]?key\s*[:=]\s*['"]?[a-zA-Z0-9/+=]{40}['"]?`),
// Private keys (PEM format indicators)
regexp.MustCompile(`-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----`),
// JWT tokens (base64.base64.base64 format)
regexp.MustCompile(`eyJ[a-zA-Z0-9_-]+\.eyJ[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+`),
// Generic secret assignment patterns
regexp.MustCompile(`(?i)bearer\s+[a-zA-Z0-9_-]{20,}`),
}
// ContainsSecrets checks if the given text contains any patterns that look like secrets.
// Returns true if potential secrets are detected.
func ContainsSecrets(text string) bool {
if text == "" {
return false
}
for _, pattern := range secretPatterns {
if pattern.MatchString(text) {
return true
}
}
return false
}
// RedactSecrets replaces detected secrets with a redaction marker.
// This allows the text to be stored while protecting sensitive data.
func RedactSecrets(text string) string {
if text == "" {
return text
}
result := text
for _, pattern := range secretPatterns {
result = pattern.ReplaceAllStringFunc(result, func(match string) string {
// Preserve the key name, redact only the value
if idx := strings.Index(match, "="); idx != -1 {
return match[:idx+1] + "[REDACTED]"
}
if idx := strings.Index(match, ":"); idx != -1 {
return match[:idx+1] + "[REDACTED]"
}
// For standalone secrets, show just the prefix
if len(match) > 8 {
return match[:4] + "...[REDACTED]"
}
return "[REDACTED]"
})
}
return result
}
// SanitizeObservation checks multiple fields of an observation for secrets.
// Returns true if any secrets were found.
// This function is used as a validation gate before storing observations.
func SanitizeObservation(narrative string, facts []string) bool {
if ContainsSecrets(narrative) {
return true
}
return slices.ContainsFunc(facts, func(fact string) bool {
return ContainsSecrets(fact)
})
}
+213
View File
@@ -0,0 +1,213 @@
package privacy
import (
"testing"
)
func TestContainsSecrets(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "empty string",
input: "",
expected: false,
},
{
name: "normal text",
input: "This is just some regular text about a bug fix",
expected: false,
},
{
name: "API key pattern",
input: "api_key=abc123def456ghi789jkl012mno345pqr678",
expected: true,
},
{
name: "api-key with dash",
input: `api-key: "abc123def456ghi789jkl012mno"`,
expected: true,
},
{
name: "password in config",
input: `password="super_secret_password_123"`,
expected: true,
},
{
name: "OpenAI key format",
input: "sk-abc123def456ghi789jkl012mno345pqr678",
expected: true,
},
{
name: "Anthropic key format",
input: "sk-ant-api03-abc123def456ghi789jkl012mno345",
expected: true,
},
{
name: "GitHub PAT",
input: "ghp_1234567890abcdefghijklmnopqrstuvwxyz",
expected: true,
},
{
name: "GitHub PAT new format",
input: "github_pat_12ABCDEFGHIJ3456789abc_defghijklmno",
expected: true,
},
{
name: "AWS access key",
input: "AKIAIOSFODNN7EXAMPLE",
expected: true,
},
{
name: "Private key header",
input: "-----BEGIN RSA PRIVATE KEY-----",
expected: true,
},
{
name: "JWT token",
input: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U",
expected: true,
},
{
name: "bearer token",
input: "Bearer abc123def456ghi789jkl012mno345",
expected: true,
},
{
name: "secret_key in code",
input: `secret_key = "my_super_secret_token_here"`,
expected: true,
},
{
name: "short password is not detected",
input: `password="short"`,
expected: false, // Too short to trigger
},
{
name: "word password in sentence",
input: "The password field should be validated",
expected: false,
},
{
name: "word api in code",
input: "The API returns JSON data",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ContainsSecrets(tt.input)
if result != tt.expected {
t.Errorf("ContainsSecrets(%q) = %v, want %v", tt.input, result, tt.expected)
}
})
}
}
func TestRedactSecrets(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string",
input: "",
expected: "",
},
{
name: "no secrets",
input: "This is safe text",
expected: "This is safe text",
},
{
name: "API key gets redacted",
input: "api_key=abc123def456ghi789jkl012mno345pqr678",
expected: "api_key=[REDACTED]",
},
{
name: "OpenAI key gets redacted",
input: "The key is sk-abc123def456ghi789jkl012mno345pqr678",
expected: "The key is sk-a...[REDACTED]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := RedactSecrets(tt.input)
if result != tt.expected {
t.Errorf("RedactSecrets(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestSanitizeObservation(t *testing.T) {
tests := []struct {
name string
narrative string
facts []string
expected bool
}{
{
name: "clean observation",
narrative: "Fixed a bug in the login flow",
facts: []string{"Users can now log in", "Session management improved"},
expected: false,
},
{
name: "secret in narrative",
narrative: "Set API key api_key=abc123def456ghi789jkl012mno345",
facts: []string{"Configuration updated"},
expected: true,
},
{
name: "secret in facts",
narrative: "Updated configuration",
facts: []string{"Added api_key=abc123def456ghi789jkl012mno345"},
expected: true,
},
{
name: "empty facts",
narrative: "Clean narrative",
facts: []string{},
expected: false,
},
{
name: "nil facts",
narrative: "Clean narrative",
facts: nil,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeObservation(tt.narrative, tt.facts)
if result != tt.expected {
t.Errorf("SanitizeObservation() = %v, want %v", result, tt.expected)
}
})
}
}
func BenchmarkContainsSecrets(b *testing.B) {
text := "This is a normal piece of text that does not contain any secrets or sensitive information"
b.ResetTimer()
for i := 0; i < b.N; i++ {
ContainsSecrets(text)
}
}
func BenchmarkContainsSecretsWithSecret(b *testing.B) {
text := "api_key=abc123def456ghi789jkl012mno345pqr678"
b.ResetTimer()
for i := 0; i < b.N; i++ {
ContainsSecrets(text)
}
}
+7 -36
View File
@@ -36,6 +36,13 @@ func NewCalculator(config *models.ScoringConfig) *Calculator {
// - ConceptContrib = sum(concept_weights) × concept_weight_factor
// - RetrievalContrib = log2(retrieval_count + 1) × 0.1 × retrieval_weight
func (c *Calculator) Calculate(obs *models.Observation, now time.Time) float64 {
return c.CalculateComponents(obs, now).FinalScore
}
// CalculateComponents returns the individual components of the importance score.
// Useful for debugging and explaining scores to users.
// This is the core calculation method - Calculate() delegates to this.
func (c *Calculator) CalculateComponents(obs *models.Observation, now time.Time) ScoreComponents {
// 1. Get base type weight
typeWeight := models.TypeBaseScore(obs.Type)
@@ -75,42 +82,6 @@ func (c *Calculator) Calculate(obs *models.Observation, now time.Time) float64 {
finalScore = c.config.MinScore
}
return finalScore
}
// CalculateComponents returns the individual components of the importance score.
// Useful for debugging and explaining scores to users.
func (c *Calculator) CalculateComponents(obs *models.Observation, now time.Time) ScoreComponents {
typeWeight := models.TypeBaseScore(obs.Type)
ageDays := now.Sub(time.UnixMilli(obs.CreatedAtEpoch)).Hours() / 24.0
if ageDays < 0 {
ageDays = 0
}
recencyDecay := math.Pow(0.5, ageDays/c.config.RecencyHalfLifeDays)
coreScore := 1.0 * typeWeight * recencyDecay
feedbackContrib := float64(obs.UserFeedback) * c.config.FeedbackWeight
conceptBoost := 0.0
for _, concept := range obs.Concepts {
if weight, ok := c.config.ConceptWeights[concept]; ok {
conceptBoost += weight
}
}
conceptContrib := conceptBoost * c.config.ConceptWeight
retrievalContrib := 0.0
if obs.RetrievalCount > 0 {
retrievalBoost := math.Log2(float64(obs.RetrievalCount)+1) * 0.1
retrievalContrib = retrievalBoost * c.config.RetrievalWeight
}
finalScore := coreScore + feedbackContrib + conceptContrib + retrievalContrib
if finalScore < c.config.MinScore {
finalScore = c.config.MinScore
}
return ScoreComponents{
TypeWeight: typeWeight,
RecencyDecay: recencyDecay,
+9 -14
View File
@@ -3,7 +3,9 @@ package expansion
import (
"context"
"math"
"regexp"
"sort"
"strings"
"sync"
@@ -281,14 +283,10 @@ func (e *Expander) expandByVocabulary(ctx context.Context, query string, minSimi
return nil
}
// Sort by score (descending) using bubble sort
for i := 0; i < len(similar)-1; i++ {
for j := i + 1; j < len(similar); j++ {
if similar[j].score > similar[i].score {
similar[i], similar[j] = similar[j], similar[i]
}
}
}
// Sort by score (descending) using Go's standard sort - O(n log n)
sort.Slice(similar, func(i, j int) bool {
return similar[i].score > similar[j].score
})
// Create expansion by combining top similar terms with query
var expansions []ExpandedQuery
@@ -427,16 +425,13 @@ func cosineSimilarity(a, b []float32) float64 {
return dot / (sqrt(normA) * sqrt(normB))
}
// sqrt is a simple square root implementation.
// sqrt uses the standard math.Sqrt for better performance and accuracy.
// Returns 0 for non-positive values (original behavior for compatibility).
func sqrt(x float64) float64 {
if x <= 0 {
return 0
}
z := x
for i := 0; i < 10; i++ {
z = (z + x/z) / 2
}
return z
return math.Sqrt(x)
}
// truncate truncates a string to maxLen characters.
+680 -24
View File
@@ -3,19 +3,146 @@ package search
import (
"context"
"hash/fnv"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
"golang.org/x/sync/singleflight"
)
// multiSpaceRegex matches multiple consecutive whitespace characters.
// Pre-compiled for performance in normalizeQuery.
var multiSpaceRegex = regexp.MustCompile(`\s+`)
// Search configuration constants.
const (
// Cache configuration
defaultCacheTTL = 30 * time.Second // Short TTL for freshness
defaultCacheMaxSize = 200 // Max cached results
cacheEvictionPercent = 10 // Evict 10% when cache is full
cacheEvictionThreshold = 80 // Start eviction scan at 80% capacity
// Latency tracking
latencyHistogramCap = 1000 // Max latency samples for histogram
slowQueryThresholdNs = 100 * 1e6 // 100ms threshold for slow query logging
// Query frequency tracking
maxFrequencyEntries = 1000 // Max queries to track for warming
frequencyEvictionBatch = 100 // Remove 10% when frequency map is full
staleQueryThreshold = 24 * time.Hour // Remove queries older than 24 hours
recentQueryWindow = time.Hour // Only consider queries from last hour for warming
// Cache warming configuration
cacheWarmingInitDelay = 30 * time.Second // Delay before starting warming
cacheWarmingInterval = 20 * time.Second // Run warming cycle every 20 seconds
frequencyCleanupInterval = 5 * time.Minute // Cleanup stale entries every 5 minutes
cacheCleanupInterval = time.Minute // Cleanup expired cache every minute
warmingQueryTimeout = 5 * time.Second // Timeout for warming queries
warmingBatchSize = 5 // Warm top 5 queries per cycle
minRecencyFactor = 0.1 // Minimum recency factor for scoring
// Default query limits
defaultQueryLimit = 20
maxQueryLimit = 100
defaultOrderBy = "date_desc"
// Truncation lengths
queryLogTruncateLen = 50 // Truncate query in logs
titleTruncateLen = 100 // Truncate titles in results
warmingLogTruncateLen = 30 // Truncate query in warming logs
)
// SearchMetrics tracks search performance statistics.
type SearchMetrics struct {
latencyHistogram []int64
TotalSearches int64
VectorSearches int64
FilterSearches int64
TotalLatencyNs int64
VectorLatencyNs int64
FilterLatencyNs int64
CacheHits int64
CoalescedRequests int64
SearchErrors int64
histogramMu sync.Mutex
}
// GetStats returns the current search statistics.
func (m *SearchMetrics) GetStats() map[string]any {
totalSearches := atomic.LoadInt64(&m.TotalSearches)
totalLatency := atomic.LoadInt64(&m.TotalLatencyNs)
vectorSearches := atomic.LoadInt64(&m.VectorSearches)
vectorLatency := atomic.LoadInt64(&m.VectorLatencyNs)
filterSearches := atomic.LoadInt64(&m.FilterSearches)
filterLatency := atomic.LoadInt64(&m.FilterLatencyNs)
avgLatencyMs := float64(0)
if totalSearches > 0 {
avgLatencyMs = float64(totalLatency) / float64(totalSearches) / 1e6
}
avgVectorLatencyMs := float64(0)
if vectorSearches > 0 {
avgVectorLatencyMs = float64(vectorLatency) / float64(vectorSearches) / 1e6
}
avgFilterLatencyMs := float64(0)
if filterSearches > 0 {
avgFilterLatencyMs = float64(filterLatency) / float64(filterSearches) / 1e6
}
return map[string]any{
"total_searches": totalSearches,
"vector_searches": vectorSearches,
"filter_searches": filterSearches,
"cache_hits": atomic.LoadInt64(&m.CacheHits),
"coalesced_requests": atomic.LoadInt64(&m.CoalescedRequests),
"search_errors": atomic.LoadInt64(&m.SearchErrors),
"avg_latency_ms": avgLatencyMs,
"avg_vector_latency_ms": avgVectorLatencyMs,
"avg_filter_latency_ms": avgFilterLatencyMs,
}
}
// Manager provides unified search across SQLite and sqlite-vec.
type Manager struct {
ctx context.Context
searchGroup singleflight.Group
cancel context.CancelFunc
vectorClient *sqlitevec.Client
metrics *SearchMetrics
promptStore *gorm.PromptStore
observationStore *gorm.ObservationStore
summaryStore *gorm.SummaryStore
promptStore *gorm.PromptStore
vectorClient *sqlitevec.Client
resultCache map[string]*cachedResult
queryFrequency map[string]*queryFrequencyInfo
cacheTTL time.Duration
cacheMaxSize int
resultCacheMu sync.RWMutex
queryFrequencyMu sync.RWMutex
}
// queryFrequencyInfo tracks how often a query is used.
type queryFrequencyInfo struct {
lastUsed time.Time
lastCached time.Time
params SearchParams
count int64
}
// cachedResult stores a cached search result with expiry.
type cachedResult struct {
result *UnifiedSearchResult
expiresAt time.Time
}
// NewManager creates a new search manager.
@@ -25,12 +152,463 @@ func NewManager(
promptStore *gorm.PromptStore,
vectorClient *sqlitevec.Client,
) *Manager {
return &Manager{
ctx, cancel := context.WithCancel(context.Background())
m := &Manager{
observationStore: observationStore,
summaryStore: summaryStore,
promptStore: promptStore,
vectorClient: vectorClient,
metrics: &SearchMetrics{latencyHistogram: make([]int64, 0, latencyHistogramCap)},
ctx: ctx,
cancel: cancel,
resultCache: make(map[string]*cachedResult),
cacheTTL: defaultCacheTTL,
cacheMaxSize: defaultCacheMaxSize,
queryFrequency: make(map[string]*queryFrequencyInfo),
}
// Start cache cleanup goroutine
go m.cleanupCacheLoop()
// Start cache warming goroutine
go m.cacheWarmingLoop()
return m
}
// Close stops background goroutines and cleans up resources.
func (m *Manager) Close() {
if m.cancel != nil {
m.cancel()
}
}
// cleanupCacheLoop periodically removes expired cache entries.
func (m *Manager) cleanupCacheLoop() {
ticker := time.NewTicker(cacheCleanupInterval)
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.cleanupExpiredCache()
}
}
}
// cleanupExpiredCache removes expired entries from the cache.
func (m *Manager) cleanupExpiredCache() {
m.resultCacheMu.Lock()
defer m.resultCacheMu.Unlock()
now := time.Now()
for key, cached := range m.resultCache {
if now.After(cached.expiresAt) {
delete(m.resultCache, key)
}
}
}
// cacheWarmingLoop periodically warms the cache for frequently used queries.
func (m *Manager) cacheWarmingLoop() {
// Wait a bit before starting to allow system to stabilize
select {
case <-m.ctx.Done():
return
case <-time.After(cacheWarmingInitDelay):
}
warmingTicker := time.NewTicker(cacheWarmingInterval)
cleanupTicker := time.NewTicker(frequencyCleanupInterval)
defer warmingTicker.Stop()
defer cleanupTicker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-warmingTicker.C:
m.warmFrequentQueries()
case <-cleanupTicker.C:
m.cleanupStaleFrequencyEntries()
}
}
}
// cleanupStaleFrequencyEntries removes query frequency entries older than staleQueryThreshold.
// This prevents memory bloat from queries that haven't been used in a long time.
func (m *Manager) cleanupStaleFrequencyEntries() {
m.queryFrequencyMu.Lock()
now := time.Now()
var keysToDelete []string
for k, v := range m.queryFrequency {
if now.Sub(v.lastUsed) > staleQueryThreshold {
keysToDelete = append(keysToDelete, k)
}
}
for _, k := range keysToDelete {
delete(m.queryFrequency, k)
}
m.queryFrequencyMu.Unlock()
if len(keysToDelete) > 0 {
log.Debug().Int("removed", len(keysToDelete)).Msg("Cleaned up stale query frequency entries")
}
}
// warmFrequentQueries pre-executes frequently used queries to warm the cache.
func (m *Manager) warmFrequentQueries() {
m.queryFrequencyMu.RLock()
// Find top N most frequent queries that aren't recently cached
type queryScore struct {
info *queryFrequencyInfo
key string
score float64
}
candidates := make([]queryScore, 0, len(m.queryFrequency))
now := time.Now()
for key, info := range m.queryFrequency {
// Only consider queries used recently
if now.Sub(info.lastUsed) > recentQueryWindow {
continue
}
// Only warm if not recently cached (cache about to expire or already expired)
timeSinceLastCache := now.Sub(info.lastCached)
if timeSinceLastCache < m.cacheTTL/2 {
continue
}
// Score: frequency * recency factor
recencyFactor := 1.0 - (now.Sub(info.lastUsed).Seconds() / recentQueryWindow.Seconds())
if recencyFactor < minRecencyFactor {
recencyFactor = minRecencyFactor
}
score := float64(info.count) * recencyFactor
candidates = append(candidates, queryScore{key: key, info: info, score: score})
}
m.queryFrequencyMu.RUnlock()
// Sort by score descending using O(n log n) algorithm
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].score > candidates[j].score
})
// Warm top queries
warmCount := min(warmingBatchSize, len(candidates))
for i := range warmCount {
candidate := candidates[i]
ctx, cancel := context.WithTimeout(context.Background(), warmingQueryTimeout)
result, err := m.executeSearch(ctx, candidate.info.params)
cancel()
if err == nil && result != nil {
cacheKey := m.getCacheKey(candidate.info.params)
m.putInCache(cacheKey, result)
// Update last cached time
m.queryFrequencyMu.Lock()
if info, ok := m.queryFrequency[candidate.key]; ok {
info.lastCached = time.Now()
}
m.queryFrequencyMu.Unlock()
log.Debug().
Str("query", truncate(candidate.info.params.Query, warmingLogTruncateLen)).
Float64("score", candidate.score).
Msg("Cache warmed for frequent query")
}
}
}
// trackQueryFrequency records query usage for cache warming decisions.
func (m *Manager) trackQueryFrequency(params SearchParams) {
key := m.getCacheKey(params)
m.queryFrequencyMu.Lock()
if info, ok := m.queryFrequency[key]; ok {
info.count++
info.lastUsed = time.Now()
m.queryFrequencyMu.Unlock()
return // Fast path: no eviction needed
}
m.queryFrequency[key] = &queryFrequencyInfo{
params: params,
count: 1,
lastUsed: time.Now(),
}
// Limit frequency map size to prevent memory bloat
mapLen := len(m.queryFrequency)
if mapLen <= maxFrequencyEntries {
m.queryFrequencyMu.Unlock()
return // Fast path: no eviction needed
}
// Collect keys and times for eviction (still under lock, but fast)
type entry struct {
lastUsed time.Time
key string
}
entries := make([]entry, 0, mapLen)
for k, v := range m.queryFrequency {
entries = append(entries, entry{key: k, lastUsed: v.lastUsed})
}
m.queryFrequencyMu.Unlock()
// Sort outside lock to reduce contention (O(n log n))
sort.Slice(entries, func(i, j int) bool {
return entries[i].lastUsed.Before(entries[j].lastUsed)
})
// Collect keys to delete
deleteCount := min(frequencyEvictionBatch, len(entries))
keysToDelete := make([]string, deleteCount)
for i := range deleteCount {
keysToDelete[i] = entries[i].key
}
// Re-acquire lock only for deletion (brief critical section)
m.queryFrequencyMu.Lock()
for _, k := range keysToDelete {
delete(m.queryFrequency, k)
}
m.queryFrequencyMu.Unlock()
}
// RecentQuery represents a recently executed search query.
type RecentQuery struct {
LastUsed time.Time `json:"last_used"`
Query string `json:"query"`
Project string `json:"project,omitempty"`
Type string `json:"type,omitempty"`
Count int64 `json:"count"`
}
// GetRecentQueries returns the most recent search queries, sorted by last used time.
func (m *Manager) GetRecentQueries(limit int) []RecentQuery {
if limit <= 0 {
limit = defaultQueryLimit
}
if limit > maxQueryLimit {
limit = maxQueryLimit
}
m.queryFrequencyMu.RLock()
defer m.queryFrequencyMu.RUnlock()
// Collect all queries
queries := make([]RecentQuery, 0, len(m.queryFrequency))
for _, info := range m.queryFrequency {
queries = append(queries, RecentQuery{
Query: info.params.Query,
Project: info.params.Project,
Type: info.params.Type,
Count: info.count,
LastUsed: info.lastUsed,
})
}
// Sort by last used (most recent first)
sort.Slice(queries, func(i, j int) bool {
return queries[i].LastUsed.After(queries[j].LastUsed)
})
// Limit results
if len(queries) > limit {
queries = queries[:limit]
}
return queries
}
// GetFrequentQueries returns the most frequently used search queries.
func (m *Manager) GetFrequentQueries(limit int) []RecentQuery {
if limit <= 0 {
limit = defaultQueryLimit
}
if limit > maxQueryLimit {
limit = maxQueryLimit
}
m.queryFrequencyMu.RLock()
defer m.queryFrequencyMu.RUnlock()
// Only include queries used recently
now := time.Now()
queries := make([]RecentQuery, 0, len(m.queryFrequency))
for _, info := range m.queryFrequency {
if now.Sub(info.lastUsed) > recentQueryWindow {
continue
}
queries = append(queries, RecentQuery{
Query: info.params.Query,
Project: info.params.Project,
Type: info.params.Type,
Count: info.count,
LastUsed: info.lastUsed,
})
}
// Sort by count (highest first)
sort.Slice(queries, func(i, j int) bool {
return queries[i].Count > queries[j].Count
})
// Limit results
if len(queries) > limit {
queries = queries[:limit]
}
return queries
}
// normalizeQuery normalizes a search query for consistent cache keys.
// Converts to lowercase, trims whitespace, and collapses multiple spaces.
// Uses pre-compiled regex for O(n) performance instead of O(n*m) loop.
func normalizeQuery(query string) string {
// Lowercase for case-insensitive matching
query = strings.ToLower(query)
// Collapse multiple whitespace into single space using pre-compiled regex
query = multiSpaceRegex.ReplaceAllString(query, " ")
// Trim leading/trailing whitespace (after collapsing)
return strings.TrimSpace(query)
}
// getCacheKey generates a cache key from search params.
// Uses direct string concatenation instead of JSON marshal for better performance.
// Queries are normalized for consistent cache hits across whitespace variations.
func (m *Manager) getCacheKey(params SearchParams) string {
// Normalize query for consistent cache keys
normalizedQuery := normalizeQuery(params.Query)
// Hash directly without intermediate string allocation.
// FNV-64a is fast and collision-safe for cache keys.
h := fnv.New64a()
// Write each field directly to the hasher with separators
h.Write([]byte(normalizedQuery))
h.Write([]byte{'|'})
h.Write([]byte(params.Type))
h.Write([]byte{'|'})
h.Write([]byte(params.Project))
h.Write([]byte{'|'})
h.Write([]byte(params.ObsType))
h.Write([]byte{'|'})
h.Write([]byte(params.Concepts))
h.Write([]byte{'|'})
h.Write([]byte(params.Files))
h.Write([]byte{'|'})
h.Write([]byte(strconv.FormatInt(params.DateStart, 10)))
h.Write([]byte{'|'})
h.Write([]byte(strconv.FormatInt(params.DateEnd, 10)))
h.Write([]byte{'|'})
h.Write([]byte(params.OrderBy))
h.Write([]byte{'|'})
h.Write([]byte(strconv.Itoa(params.Limit)))
h.Write([]byte{'|'})
h.Write([]byte(strconv.Itoa(params.Offset)))
h.Write([]byte{'|'})
h.Write([]byte(params.Format))
h.Write([]byte{'|'})
h.Write([]byte(params.Scope))
h.Write([]byte{'|'})
if params.IncludeGlobal {
h.Write([]byte{'1'})
} else {
h.Write([]byte{'0'})
}
h.Write([]byte{'|'})
if params.ExcludeSuperseded {
h.Write([]byte{'1'})
} else {
h.Write([]byte{'0'})
}
return strconv.FormatUint(h.Sum64(), 36) // Base36 for compact representation
}
// getFromCache retrieves a result from cache if valid.
func (m *Manager) getFromCache(key string) (*UnifiedSearchResult, bool) {
m.resultCacheMu.RLock()
defer m.resultCacheMu.RUnlock()
if cached, ok := m.resultCache[key]; ok {
if time.Now().Before(cached.expiresAt) {
atomic.AddInt64(&m.metrics.CacheHits, 1)
return cached.result, true
}
}
return nil, false
}
// putInCache stores a result in the cache.
// Optimized to skip expensive scans when cache is below capacity threshold.
func (m *Manager) putInCache(key string, result *UnifiedSearchResult) {
m.resultCacheMu.Lock()
defer m.resultCacheMu.Unlock()
now := time.Now()
cacheLen := len(m.resultCache)
// Only scan for expired entries when at threshold+ capacity (amortized cleanup)
evictionThreshold := (m.cacheMaxSize * cacheEvictionThreshold) / 100
if cacheLen >= evictionThreshold {
// Evict expired entries
for k, v := range m.resultCache {
if now.After(v.expiresAt) {
delete(m.resultCache, k)
}
}
cacheLen = len(m.resultCache) // Update after eviction
}
// If still at capacity after removing expired, use simple FIFO-style eviction
// Go map iteration order is random, which provides good cache behavior
if cacheLen >= m.cacheMaxSize {
// Evict percentage using random-order iteration (O(n) single pass)
evictCount := max(m.cacheMaxSize*cacheEvictionPercent/100, 1)
evicted := 0
for k := range m.resultCache {
delete(m.resultCache, k)
evicted++
if evicted >= evictCount {
break
}
}
}
m.resultCache[key] = &cachedResult{
result: result,
expiresAt: now.Add(m.cacheTTL),
}
}
// Metrics returns the search metrics for monitoring.
func (m *Manager) Metrics() *SearchMetrics {
return m.metrics
}
// CacheStats returns current cache statistics.
func (m *Manager) CacheStats() map[string]any {
m.resultCacheMu.RLock()
cacheSize := len(m.resultCache)
m.resultCacheMu.RUnlock()
return map[string]any{
"size": cacheSize,
"max_size": m.cacheMaxSize,
"ttl_sec": m.cacheTTL.Seconds(),
}
}
// ClearCache clears the result cache. Useful for testing or after data changes.
func (m *Manager) ClearCache() {
m.resultCacheMu.Lock()
m.resultCache = make(map[string]*cachedResult)
m.resultCacheMu.Unlock()
}
// SearchParams contains parameters for unified search.
@@ -53,16 +631,17 @@ type SearchParams struct {
}
// SearchResult represents a unified search result.
// Field order optimized for memory alignment (fieldalignment).
type SearchResult struct {
Metadata map[string]interface{} `json:"metadata,omitempty"`
Type string `json:"type"`
Title string `json:"title,omitempty"`
Content string `json:"content,omitempty"`
Project string `json:"project"`
Scope string `json:"scope,omitempty"`
ID int64 `json:"id"`
CreatedAt int64 `json:"created_at_epoch"`
Score float64 `json:"score,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Type string `json:"type"`
Title string `json:"title,omitempty"`
Content string `json:"content,omitempty"`
Project string `json:"project"`
Scope string `json:"scope,omitempty"`
ID int64 `json:"id"`
CreatedAt int64 `json:"created_at_epoch"`
Score float64 `json:"score,omitempty"`
}
// UnifiedSearchResult contains the combined search results.
@@ -73,17 +652,69 @@ type UnifiedSearchResult struct {
}
// UnifiedSearch performs a unified search across all document types.
// Uses caching and request coalescing for optimal performance.
func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
start := time.Now()
defer func() {
latency := time.Since(start).Nanoseconds()
atomic.AddInt64(&m.metrics.TotalSearches, 1)
atomic.AddInt64(&m.metrics.TotalLatencyNs, latency)
// Sample latency for histogram (reservoir sampling)
m.metrics.histogramMu.Lock()
if len(m.metrics.latencyHistogram) < latencyHistogramCap {
m.metrics.latencyHistogram = append(m.metrics.latencyHistogram, latency)
}
m.metrics.histogramMu.Unlock()
// Log slow queries
if latency > slowQueryThresholdNs {
log.Warn().
Str("query", truncate(params.Query, queryLogTruncateLen)).
Dur("latency", time.Duration(latency)).
Str("type", params.Type).
Msg("Slow search query")
}
}()
if params.Limit <= 0 {
params.Limit = 20
params.Limit = defaultQueryLimit
}
if params.Limit > 100 {
params.Limit = 100
if params.Limit > maxQueryLimit {
params.Limit = maxQueryLimit
}
if params.OrderBy == "" {
params.OrderBy = "date_desc"
params.OrderBy = defaultOrderBy
}
// Check cache first
cacheKey := m.getCacheKey(params)
if cached, ok := m.getFromCache(cacheKey); ok {
return cached, nil
}
// Use singleflight to coalesce concurrent identical requests
result, err, _ := m.searchGroup.Do(cacheKey, func() (any, error) {
return m.executeSearch(ctx, params)
})
if err != nil {
return nil, err
}
searchResult := result.(*UnifiedSearchResult)
// Cache the result
m.putInCache(cacheKey, searchResult)
// Track query frequency for cache warming
m.trackQueryFrequency(params)
return searchResult, nil
}
// executeSearch performs the actual search without caching/coalescing.
func (m *Manager) executeSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
// If query is provided and vector client is available, use vector search
if params.Query != "" && m.vectorClient != nil && m.vectorClient.IsConnected() {
return m.vectorSearch(ctx, params)
@@ -95,6 +726,13 @@ func (m *Manager) UnifiedSearch(ctx context.Context, params SearchParams) (*Unif
// vectorSearch performs semantic search via sqlite-vec.
func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
start := time.Now()
defer func() {
latency := time.Since(start).Nanoseconds()
atomic.AddInt64(&m.metrics.VectorSearches, 1)
atomic.AddInt64(&m.metrics.VectorLatencyNs, latency)
}()
// Build where filter based on search type
var docType sqlitevec.DocType
switch params.Type {
@@ -110,6 +748,7 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
// Query sqlite-vec
vectorResults, err := m.vectorClient.Query(ctx, params.Query, params.Limit*2, where)
if err != nil {
atomic.AddInt64(&m.metrics.SearchErrors, 1)
// Fall back to filter search on error
return m.filterSearch(ctx, params)
}
@@ -125,7 +764,9 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
if len(obsIDs) > 0 && (params.Type == "" || params.Type == "observations") {
obs, err := m.observationStore.GetObservationsByIDs(ctx, obsIDs, params.OrderBy, 0)
if err == nil {
if err != nil {
log.Warn().Err(err).Int("count", len(obsIDs)).Msg("Failed to fetch observations by IDs in vector search")
} else {
for _, o := range obs {
// Skip superseded observations when requested
if params.ExcludeSuperseded && o.IsSuperseded {
@@ -138,7 +779,9 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
if len(summaryIDs) > 0 && (params.Type == "" || params.Type == "sessions") {
summaries, err := m.summaryStore.GetSummariesByIDs(ctx, summaryIDs, params.OrderBy, 0)
if err == nil {
if err != nil {
log.Warn().Err(err).Int("count", len(summaryIDs)).Msg("Failed to fetch summaries by IDs in vector search")
} else {
for _, s := range summaries {
results = append(results, m.summaryToResult(s, params.Format))
}
@@ -147,7 +790,9 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
if len(promptIDs) > 0 && (params.Type == "" || params.Type == "prompts") {
prompts, err := m.promptStore.GetPromptsByIDs(ctx, promptIDs, params.OrderBy, 0)
if err == nil {
if err != nil {
log.Warn().Err(err).Int("count", len(promptIDs)).Msg("Failed to fetch prompts by IDs in vector search")
} else {
for _, p := range prompts {
results = append(results, m.promptToResult(p, params.Format))
}
@@ -168,6 +813,13 @@ func (m *Manager) vectorSearch(ctx context.Context, params SearchParams) (*Unifi
// filterSearch performs structured filter search via SQLite.
func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*UnifiedSearchResult, error) {
start := time.Now()
defer func() {
latency := time.Since(start).Nanoseconds()
atomic.AddInt64(&m.metrics.FilterSearches, 1)
atomic.AddInt64(&m.metrics.FilterLatencyNs, latency)
}()
var results []SearchResult
// Search observations
@@ -182,7 +834,9 @@ func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*Unifi
obs, err = m.observationStore.GetRecentObservations(ctx, params.Project, params.Limit)
}
if err == nil {
if err != nil {
log.Warn().Err(err).Str("project", params.Project).Msg("Failed to fetch observations in filter search")
} else {
for _, o := range obs {
results = append(results, m.observationToResult(o, params.Format))
}
@@ -192,7 +846,9 @@ func (m *Manager) filterSearch(ctx context.Context, params SearchParams) (*Unifi
// Search summaries
if params.Type == "" || params.Type == "sessions" {
summaries, err := m.summaryStore.GetRecentSummaries(ctx, params.Project, params.Limit)
if err == nil {
if err != nil {
log.Warn().Err(err).Str("project", params.Project).Msg("Failed to fetch summaries in filter search")
} else {
for _, s := range summaries {
results = append(results, m.summaryToResult(s, params.Format))
}
@@ -249,7 +905,7 @@ func (m *Manager) observationToResult(obs *models.Observation, format string) Se
Project: obs.Project,
Scope: string(obs.Scope),
CreatedAt: obs.CreatedAtEpoch,
Metadata: map[string]interface{}{
Metadata: map[string]any{
"obs_type": string(obs.Type),
"scope": string(obs.Scope),
},
@@ -275,7 +931,7 @@ func (m *Manager) summaryToResult(summary *models.SessionSummary, format string)
}
if summary.Request.Valid {
result.Title = truncate(summary.Request.String, 100)
result.Title = truncate(summary.Request.String, titleTruncateLen)
}
if format == "full" && summary.Learned.Valid {
@@ -293,7 +949,7 @@ func (m *Manager) promptToResult(prompt *models.UserPromptWithSession, format st
CreatedAt: prompt.CreatedAtEpoch,
}
result.Title = truncate(prompt.PromptText, 100)
result.Title = truncate(prompt.PromptText, titleTruncateLen)
if format == "full" {
result.Content = prompt.PromptText
+707 -28
View File
@@ -5,19 +5,105 @@ import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/rs/zerolog/log"
"golang.org/x/sync/singleflight"
)
// embeddingCacheEntry stores a cached embedding with its timestamp.
type embeddingCacheEntry struct {
embedding []float32
timestamp int64 // Unix nano
}
// resultCacheEntry stores cached query results.
type resultCacheEntry struct {
queryHash string
results []QueryResult
timestamp int64
}
// Client provides vector operations via sqlite-vec.
type Client struct {
db *sql.DB
embedSvc *embedding.Service
mu sync.Mutex
embeddingGroup singleflight.Group
resultCache map[string]resultCacheEntry
db *sql.DB
embedSvc *embedding.Service
queryCache map[string]embeddingCacheEntry
stopCleanup chan struct{}
stats CacheStats
cleanupWg sync.WaitGroup
resultCacheTTLNano int64
cacheTTLNano int64
resultCacheMaxSize int
cacheMaxSize int
resultCacheMu sync.RWMutex
queryCacheMu sync.RWMutex
readMu sync.RWMutex
writeMu sync.Mutex
}
// CacheStats tracks cache performance metrics using atomic counters for lock-free updates.
type CacheStats struct {
embeddingHits atomic.Int64
embeddingMisses atomic.Int64
resultHits atomic.Int64
resultMisses atomic.Int64
embeddingEvictions atomic.Int64
resultEvictions atomic.Int64
}
// CacheStatsSnapshot is the exported version of CacheStats for JSON marshaling.
type CacheStatsSnapshot struct {
EmbeddingHits int64 `json:"embedding_hits"`
EmbeddingMisses int64 `json:"embedding_misses"`
ResultHits int64 `json:"result_hits"`
ResultMisses int64 `json:"result_misses"`
EmbeddingEvictions int64 `json:"embedding_evictions"`
ResultEvictions int64 `json:"result_evictions"`
}
// HitRate returns the cache hit rate as a percentage.
func (s CacheStatsSnapshot) HitRate() float64 {
total := s.EmbeddingHits + s.EmbeddingMisses + s.ResultHits + s.ResultMisses
if total == 0 {
return 0
}
hits := s.EmbeddingHits + s.ResultHits
return float64(hits) / float64(total) * 100
}
// HitRate returns the cache hit rate as a percentage.
func (s *CacheStats) HitRate() float64 {
embHits := s.embeddingHits.Load()
embMisses := s.embeddingMisses.Load()
resHits := s.resultHits.Load()
resMisses := s.resultMisses.Load()
total := embHits + embMisses + resHits + resMisses
if total == 0 {
return 0
}
hits := embHits + resHits
return float64(hits) / float64(total) * 100
}
// Snapshot returns a copy of the current stats.
func (s *CacheStats) Snapshot() CacheStatsSnapshot {
return CacheStatsSnapshot{
EmbeddingHits: s.embeddingHits.Load(),
EmbeddingMisses: s.embeddingMisses.Load(),
ResultHits: s.resultHits.Load(),
ResultMisses: s.resultMisses.Load(),
EmbeddingEvictions: s.embeddingEvictions.Load(),
ResultEvictions: s.resultEvictions.Load(),
}
}
// Config holds configuration for the client.
@@ -34,10 +120,23 @@ func NewClient(cfg Config, embedSvc *embedding.Service) (*Client, error) {
return nil, fmt.Errorf("embedding service required")
}
return &Client{
db: cfg.DB,
embedSvc: embedSvc,
}, nil
c := &Client{
db: cfg.DB,
embedSvc: embedSvc,
queryCache: make(map[string]embeddingCacheEntry),
cacheMaxSize: 500, // Cache up to 500 query embeddings
cacheTTLNano: 5 * 60 * 1e9, // 5 minute TTL for embeddings
resultCache: make(map[string]resultCacheEntry),
resultCacheMaxSize: 200, // Cache up to 200 search results
resultCacheTTLNano: 60 * 1e9, // 1 minute TTL for results (shorter since data changes)
stopCleanup: make(chan struct{}),
}
// Start background cache cleanup goroutine
c.cleanupWg.Add(1)
go c.cacheCleanupLoop()
return c, nil
}
// AddDocuments adds documents with their embeddings to the vector store.
@@ -46,8 +145,8 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Generate embeddings for all documents
texts := make([]string, len(docs))
@@ -75,7 +174,10 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
}
defer func() {
if err != nil {
_ = tx.Rollback()
if rbErr := tx.Rollback(); rbErr != nil {
// Rollback failure is serious - indicates potential data corruption risk
log.Error().Err(rbErr).Err(err).Msg("Failed to rollback transaction after error - data may be inconsistent")
}
}
}()
@@ -118,6 +220,9 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error {
return fmt.Errorf("commit transaction: %w", err)
}
// Invalidate result cache since data changed
c.InvalidateResultCache()
log.Debug().Int("count", len(docs)).Str("model", modelVersion).Msg("Added documents to sqlite-vec")
return nil
}
@@ -128,12 +233,12 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Build placeholder string
placeholders := make([]string, len(ids))
args := make([]interface{}, len(ids))
args := make([]any, len(ids))
for i, id := range ids {
placeholders[i] = "?"
args[i] = id
@@ -148,17 +253,25 @@ func (c *Client) DeleteDocuments(ctx context.Context, ids []string) error {
return fmt.Errorf("delete documents: %w", err)
}
// Invalidate result cache since data changed
c.InvalidateResultCache()
log.Debug().Int("count", len(ids)).Msg("Deleted documents from sqlite-vec")
return nil
}
// Query performs a vector similarity search.
func (c *Client) Query(ctx context.Context, query string, limit int, where map[string]any) ([]QueryResult, error) {
c.mu.Lock()
defer c.mu.Unlock()
// Build cache key from query + filters + limit
cacheKey := c.buildResultCacheKey(query, limit, where)
// Generate query embedding
queryEmb, err := c.embedSvc.Embed(query)
// Check result cache first
if results, ok := c.getResultFromCache(cacheKey); ok {
return results, nil
}
// Generate query embedding OUTSIDE the lock for better concurrency
queryEmb, err := c.getOrComputeEmbedding(query)
if err != nil {
return nil, fmt.Errorf("embed query: %w", err)
}
@@ -169,9 +282,13 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
return nil, fmt.Errorf("serialize query embedding: %w", err)
}
// Now acquire read lock for the actual DB query
c.readMu.RLock()
defer c.readMu.RUnlock()
// Build query with filters
// vec0 supports WHERE clauses on metadata columns
args := []interface{}{queryBlob}
args := []any{queryBlob}
sqlQuery := `
SELECT
@@ -232,6 +349,9 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s
return nil, fmt.Errorf("iterate rows: %w", err)
}
// Cache the results
c.cacheResults(cacheKey, results)
log.Debug().
Str("query", truncateString(query, 50)).
Int("results", len(results)).
@@ -245,11 +365,196 @@ func (c *Client) IsConnected() bool {
return c.db != nil
}
// Close is a no-op (db managed externally).
// Close stops the background cleanup goroutine (db managed externally).
func (c *Client) Close() error {
// Signal cleanup goroutine to stop
close(c.stopCleanup)
// Wait for cleanup to finish
c.cleanupWg.Wait()
return nil
}
// cacheCleanupLoop periodically removes expired cache entries.
func (c *Client) cacheCleanupLoop() {
defer c.cleanupWg.Done()
ticker := time.NewTicker(30 * time.Second) // Cleanup every 30 seconds
defer ticker.Stop()
for {
select {
case <-c.stopCleanup:
return
case <-ticker.C:
c.cleanupExpiredCaches()
}
}
}
// cleanupExpiredCaches removes expired entries from both caches.
func (c *Client) cleanupExpiredCaches() {
now := time.Now().UnixNano()
var embeddingExpired, resultExpired int64
// Cleanup embedding cache
c.queryCacheMu.Lock()
for key, entry := range c.queryCache {
if now-entry.timestamp > c.cacheTTLNano {
delete(c.queryCache, key)
embeddingExpired++
}
}
c.queryCacheMu.Unlock()
// Cleanup result cache
c.resultCacheMu.Lock()
for key, entry := range c.resultCache {
if now-entry.timestamp > c.resultCacheTTLNano {
delete(c.resultCache, key)
resultExpired++
}
}
c.resultCacheMu.Unlock()
// Update stats atomically
if embeddingExpired > 0 || resultExpired > 0 {
c.stats.embeddingEvictions.Add(embeddingExpired)
c.stats.resultEvictions.Add(resultExpired)
log.Debug().
Int64("embedding_expired", embeddingExpired).
Int64("result_expired", resultExpired).
Msg("Cache cleanup completed")
}
}
// BatchQueryResult holds results from a batch query operation.
type BatchQueryResult struct {
Error error
Query string
Results []QueryResult
}
// QueryBatch performs multiple vector searches concurrently.
// Returns results in the same order as input queries.
// Uses a worker pool to limit concurrent queries.
func (c *Client) QueryBatch(ctx context.Context, queries []string, limit int, where map[string]any) []BatchQueryResult {
if len(queries) == 0 {
return nil
}
// Limit concurrency to avoid overwhelming the database
maxConcurrent := min(4, len(queries))
results := make([]BatchQueryResult, len(queries))
sem := make(chan struct{}, maxConcurrent)
var wg sync.WaitGroup
for i, query := range queries {
wg.Add(1)
go func(idx int, q string) {
defer wg.Done()
// Acquire semaphore
select {
case sem <- struct{}{}:
defer func() { <-sem }()
case <-ctx.Done():
results[idx] = BatchQueryResult{
Query: q,
Error: ctx.Err(),
}
return
}
// Execute query
queryResults, err := c.Query(ctx, q, limit, where)
results[idx] = BatchQueryResult{
Query: q,
Results: queryResults,
Error: err,
}
}(i, query)
}
wg.Wait()
return results
}
// QueryMultiField searches across multiple fields for a single query.
// Combines results from different field types and deduplicates by document ID.
func (c *Client) QueryMultiField(ctx context.Context, query string, limit int, docType string, project string) ([]QueryResult, error) {
// Generate embedding once
queryEmb, err := c.getOrComputeEmbedding(query)
if err != nil {
return nil, fmt.Errorf("embed query: %w", err)
}
// Serialize query embedding
queryBlob, err := sqlite_vec.SerializeFloat32(queryEmb)
if err != nil {
return nil, fmt.Errorf("serialize query embedding: %w", err)
}
c.readMu.RLock()
defer c.readMu.RUnlock()
// Query with field type aggregation - get best match per document
sqlQuery := `
WITH ranked_results AS (
SELECT
doc_id,
distance,
sqlite_id,
doc_type,
field_type,
project,
scope,
ROW_NUMBER() OVER (PARTITION BY sqlite_id ORDER BY distance ASC) as rn
FROM vectors
WHERE embedding MATCH ?
AND doc_type = ?
AND (project = ? OR scope = 'global')
)
SELECT doc_id, distance, sqlite_id, doc_type, field_type, project, scope
FROM ranked_results
WHERE rn = 1
ORDER BY distance
LIMIT ?
`
rows, err := c.db.QueryContext(ctx, sqlQuery, queryBlob, docType, project, limit)
if err != nil {
return nil, fmt.Errorf("query vectors: %w", err)
}
defer rows.Close()
// Pre-allocate with limit to avoid repeated slice growth
results := make([]QueryResult, 0, limit)
for rows.Next() {
var r QueryResult
var sqliteID int64
var docTypeVal, fieldType, projectVal, scope sql.NullString
if err := rows.Scan(&r.ID, &r.Distance, &sqliteID, &docTypeVal, &fieldType, &projectVal, &scope); err != nil {
return nil, fmt.Errorf("scan row: %w", err)
}
r.Similarity = DistanceToSimilarity(r.Distance)
r.Metadata = map[string]any{
"sqlite_id": float64(sqliteID),
"doc_type": docTypeVal.String,
"field_type": fieldType.String,
"project": projectVal.String,
"scope": scope.String,
}
results = append(results, r)
}
return results, rows.Err()
}
// truncateString truncates a string to maxLen characters.
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
@@ -260,8 +565,8 @@ func truncateString(s string, maxLen int) string {
// Count returns the total number of vectors in the store.
func (c *Client) Count(ctx context.Context) (int64, error) {
c.mu.Lock()
defer c.mu.Unlock()
c.readMu.RLock()
defer c.readMu.RUnlock()
var count int64
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&count)
@@ -281,8 +586,8 @@ func (c *Client) ModelVersion() string {
// - The vectors table is empty
// - Any vectors have a different model_version than the current model
func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) {
c.mu.Lock()
defer c.mu.Unlock()
c.readMu.RLock()
defer c.readMu.RUnlock()
currentModel := c.embedSvc.Version()
@@ -329,8 +634,8 @@ type StaleVectorInfo struct {
// GetStaleVectors returns doc_ids of vectors with mismatched or null model versions.
// This enables granular rebuild - only re-embedding documents that need updating.
func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) {
c.mu.Lock()
defer c.mu.Unlock()
c.readMu.RLock()
defer c.readMu.RUnlock()
currentModel := c.embedSvc.Version()
@@ -372,6 +677,134 @@ func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error)
return results, nil
}
// VectorHealthStats contains comprehensive health information about the vector store.
type VectorHealthStats struct {
CoverageByType map[string]int64 `json:"coverage_by_type"`
ModelVersions map[string]int64 `json:"model_versions"`
ProjectCounts map[string]int64 `json:"project_counts"`
CurrentModel string `json:"current_model"`
RebuildReason string `json:"rebuild_reason,omitempty"`
EmbeddingCache CacheStatsSnapshot `json:"embedding_cache"`
TotalVectors int64 `json:"total_vectors"`
StaleVectors int64 `json:"stale_vectors"`
NeedsRebuild bool `json:"needs_rebuild"`
}
// GetHealthStats returns comprehensive health statistics about the vector store.
func (c *Client) GetHealthStats(ctx context.Context) (*VectorHealthStats, error) {
c.readMu.RLock()
defer c.readMu.RUnlock()
stats := &VectorHealthStats{
CurrentModel: c.embedSvc.Version(),
CoverageByType: make(map[string]int64),
ModelVersions: make(map[string]int64),
ProjectCounts: make(map[string]int64),
EmbeddingCache: c.stats.Snapshot(),
}
// Get total count
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&stats.TotalVectors)
if err != nil {
return nil, fmt.Errorf("count total vectors: %w", err)
}
// Get stale count
err = c.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
stats.CurrentModel,
).Scan(&stats.StaleVectors)
if err != nil {
return nil, fmt.Errorf("count stale vectors: %w", err)
}
// Check if rebuild needed
stats.NeedsRebuild, stats.RebuildReason = c.needsRebuildUnlocked(ctx, stats.CurrentModel)
// Get coverage by doc_type
rows, err := c.db.QueryContext(ctx, "SELECT doc_type, COUNT(*) FROM vectors GROUP BY doc_type")
if err != nil {
return nil, fmt.Errorf("query doc types: %w", err)
}
defer rows.Close()
for rows.Next() {
var docType sql.NullString
var count int64
if err := rows.Scan(&docType, &count); err != nil {
return nil, fmt.Errorf("scan doc type: %w", err)
}
if docType.Valid {
stats.CoverageByType[docType.String] = count
} else {
stats.CoverageByType["unknown"] = count
}
}
// Get model version distribution
rows2, err := c.db.QueryContext(ctx, "SELECT COALESCE(model_version, 'unknown'), COUNT(*) FROM vectors GROUP BY model_version")
if err != nil {
return nil, fmt.Errorf("query model versions: %w", err)
}
defer rows2.Close()
for rows2.Next() {
var version string
var count int64
if err := rows2.Scan(&version, &count); err != nil {
return nil, fmt.Errorf("scan model version: %w", err)
}
stats.ModelVersions[version] = count
}
// Get project counts (top 10)
rows3, err := c.db.QueryContext(ctx,
"SELECT COALESCE(project, 'global'), COUNT(*) FROM vectors GROUP BY project ORDER BY COUNT(*) DESC LIMIT 10")
if err != nil {
return nil, fmt.Errorf("query projects: %w", err)
}
defer rows3.Close()
for rows3.Next() {
var project string
var count int64
if err := rows3.Scan(&project, &count); err != nil {
return nil, fmt.Errorf("scan project: %w", err)
}
stats.ProjectCounts[project] = count
}
return stats, nil
}
// needsRebuildUnlocked checks if rebuild is needed without acquiring lock (caller must hold lock).
func (c *Client) needsRebuildUnlocked(ctx context.Context, currentModel string) (bool, string) {
var totalCount int64
err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM vectors").Scan(&totalCount)
if err != nil {
return false, ""
}
if totalCount == 0 {
return true, "empty"
}
var staleCount int64
err = c.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM vectors WHERE model_version != ? OR model_version IS NULL",
currentModel,
).Scan(&staleCount)
if err != nil {
return false, ""
}
if staleCount > 0 {
return true, fmt.Sprintf("model_mismatch:%d", staleCount)
}
return false, ""
}
// DeleteVectorsByDocIDs removes vectors by their doc_ids.
// Used for granular rebuild - delete stale vectors before re-adding.
func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) error {
@@ -379,12 +812,12 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Build placeholder string
placeholders := make([]string, len(docIDs))
args := make([]interface{}, len(docIDs))
args := make([]any, len(docIDs))
for i, id := range docIDs {
placeholders[i] = "?"
args[i] = id
@@ -402,3 +835,249 @@ func (c *Client) DeleteVectorsByDocIDs(ctx context.Context, docIDs []string) err
log.Debug().Int("count", len(docIDs)).Msg("Deleted stale vectors by doc_id")
return nil
}
// DeleteByObservationID removes all vectors associated with an observation ID.
// Vectors are stored with doc_ids that include the observation ID, e.g., "obs_123_narrative".
func (c *Client) DeleteByObservationID(ctx context.Context, obsID int64) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
// Vectors have doc_ids like "obs_123_narrative", "obs_123_facts_0", etc.
pattern := fmt.Sprintf("obs_%d_%%", obsID)
_, err := c.db.ExecContext(ctx, "DELETE FROM vectors WHERE doc_id LIKE ?", pattern)
if err != nil {
return fmt.Errorf("delete vectors for observation %d: %w", obsID, err)
}
log.Debug().Int64("observation_id", obsID).Msg("Deleted vectors for observation")
return nil
}
// getOrComputeEmbedding returns a cached embedding or computes a new one.
// Uses singleflight to prevent duplicate concurrent computations for the same query.
func (c *Client) getOrComputeEmbedding(query string) ([]float32, error) {
now := time.Now().UnixNano()
// Check cache first (read lock)
c.queryCacheMu.RLock()
if entry, ok := c.queryCache[query]; ok {
if now-entry.timestamp < c.cacheTTLNano {
c.queryCacheMu.RUnlock()
c.stats.embeddingHits.Add(1)
return entry.embedding, nil
}
}
c.queryCacheMu.RUnlock()
// Cache miss - use singleflight to deduplicate concurrent embedding requests
result, err, _ := c.embeddingGroup.Do(query, func() (any, error) {
// Double-check cache inside singleflight (another goroutine may have just cached it)
c.queryCacheMu.RLock()
if entry, ok := c.queryCache[query]; ok {
if time.Now().UnixNano()-entry.timestamp < c.cacheTTLNano {
c.queryCacheMu.RUnlock()
return entry.embedding, nil
}
}
c.queryCacheMu.RUnlock()
// Record cache miss
c.stats.embeddingMisses.Add(1)
// Compute embedding
emb, err := c.embedSvc.Embed(query)
if err != nil {
return nil, err
}
// Store in cache (write lock)
c.queryCacheMu.Lock()
nowCache := time.Now().UnixNano()
// Evict old entries if cache is full or near capacity (80% threshold)
evictionThreshold := (c.cacheMaxSize * 8) / 10
if len(c.queryCache) >= evictionThreshold {
// Phase 1: Remove ALL expired entries first (not just 10%)
evicted := int64(0)
for k, v := range c.queryCache {
if nowCache-v.timestamp > c.cacheTTLNano {
delete(c.queryCache, k)
evicted++
}
}
// Phase 2: If still at capacity, evict 10% using random iteration (O(n) instead of O(n log n))
// Go map iteration order is randomized, providing good cache behavior without sorting
if len(c.queryCache) >= c.cacheMaxSize {
evictCount := max(c.cacheMaxSize/10, 1)
for k := range c.queryCache {
delete(c.queryCache, k)
evicted++
evictCount--
if evictCount <= 0 {
break
}
}
}
if evicted > 0 {
c.stats.embeddingEvictions.Add(evicted)
}
}
c.queryCache[query] = embeddingCacheEntry{
embedding: emb,
timestamp: nowCache,
}
c.queryCacheMu.Unlock()
return emb, nil
})
if err != nil {
return nil, err
}
return result.([]float32), nil
}
// ClearCache clears the embedding cache.
func (c *Client) ClearCache() {
c.queryCacheMu.Lock()
c.queryCache = make(map[string]embeddingCacheEntry)
c.queryCacheMu.Unlock()
}
// GetCacheStats returns comprehensive cache statistics.
func (c *Client) GetCacheStats() CacheStatsSnapshot {
return c.stats.Snapshot()
}
// CacheStats returns basic cache size info for backward compatibility.
// Deprecated: Use GetCacheStats() for comprehensive statistics.
func (c *Client) CacheStats() (size int, maxSize int) {
c.queryCacheMu.RLock()
size = len(c.queryCache)
c.queryCacheMu.RUnlock()
return size, c.cacheMaxSize
}
// EmbeddingCacheSize returns the current embedding cache size.
func (c *Client) EmbeddingCacheSize() int {
c.queryCacheMu.RLock()
defer c.queryCacheMu.RUnlock()
return len(c.queryCache)
}
// ResultCacheSize returns the current result cache size.
func (c *Client) ResultCacheSize() int {
c.resultCacheMu.RLock()
defer c.resultCacheMu.RUnlock()
return len(c.resultCache)
}
// buildResultCacheKey creates a unique key for caching query results.
// Uses strings.Builder to avoid intermediate allocations.
func (c *Client) buildResultCacheKey(query string, limit int, where map[string]any) string {
// Pre-allocate with typical key size to avoid reallocation
var b strings.Builder
b.Grow(len(query) + 32) // query + typical prefix/suffix overhead
b.WriteString("q:")
b.WriteString(query)
b.WriteString(":l:")
b.WriteString(strconv.Itoa(limit))
if docType, ok := where["doc_type"].(string); ok {
b.WriteString(":dt:")
b.WriteString(docType)
}
if project, ok := where["project"].(string); ok {
b.WriteString(":p:")
b.WriteString(project)
}
return b.String()
}
// getResultFromCache retrieves cached results if available and not expired.
func (c *Client) getResultFromCache(cacheKey string) ([]QueryResult, bool) {
now := time.Now().UnixNano()
c.resultCacheMu.RLock()
entry, ok := c.resultCache[cacheKey]
c.resultCacheMu.RUnlock()
if !ok {
c.stats.resultMisses.Add(1)
return nil, false
}
// Check if entry is expired
if now-entry.timestamp > c.resultCacheTTLNano {
c.stats.resultMisses.Add(1)
return nil, false
}
c.stats.resultHits.Add(1)
// Return a copy to prevent mutation
results := make([]QueryResult, len(entry.results))
copy(results, entry.results)
return results, true
}
// cacheResults stores query results in the cache.
func (c *Client) cacheResults(cacheKey string, results []QueryResult) {
now := time.Now().UnixNano()
c.resultCacheMu.Lock()
defer c.resultCacheMu.Unlock()
// Evict old entries if cache is full
if len(c.resultCache) >= c.resultCacheMaxSize {
// Two-phase eviction: (1) TTL-expired entries, (2) random if still over capacity
evicted := 0
targetSize := c.resultCacheMaxSize * 8 / 10 // Target 80% capacity
// Phase 1: Remove all TTL-expired entries
for k, v := range c.resultCache {
if now-v.timestamp > c.resultCacheTTLNano {
delete(c.resultCache, k)
evicted++
}
}
// Phase 2: If still over target, remove random entries until at target
if len(c.resultCache) >= targetSize {
evictCount := len(c.resultCache) - targetSize + 1
for k := range c.resultCache {
delete(c.resultCache, k)
evicted++
evictCount--
if evictCount <= 0 {
break
}
}
}
if evicted > 0 {
c.stats.resultEvictions.Add(int64(evicted))
}
}
// Make a copy of results to store
resultsCopy := make([]QueryResult, len(results))
copy(resultsCopy, results)
c.resultCache[cacheKey] = resultCacheEntry{
results: resultsCopy,
timestamp: now,
queryHash: cacheKey,
}
}
// InvalidateResultCache clears the result cache.
// Should be called after write operations that modify vectors.
func (c *Client) InvalidateResultCache() {
c.resultCacheMu.Lock()
c.resultCache = make(map[string]resultCacheEntry)
c.resultCacheMu.Unlock()
}
+184
View File
@@ -338,3 +338,187 @@ func (s *Sync) DeletePatterns(ctx context.Context, patternIDs []int64) error {
return nil
}
// BatchSyncConfig configures batch synchronization behavior.
type BatchSyncConfig struct {
BatchSize int // Number of items per batch (default: 50)
ProgressLogFreq int // Log progress every N items (default: 100)
}
// DefaultBatchSyncConfig returns sensible defaults for batch sync.
func DefaultBatchSyncConfig() BatchSyncConfig {
return BatchSyncConfig{
BatchSize: 50,
ProgressLogFreq: 100,
}
}
// BatchSyncObservations syncs multiple observations efficiently in batches.
// This reduces memory pressure during large rebuilds by processing in chunks.
func (s *Sync) BatchSyncObservations(ctx context.Context, observations []*models.Observation, cfg BatchSyncConfig) (synced int, errors int) {
if len(observations) == 0 {
return 0, 0
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 50
}
if cfg.ProgressLogFreq <= 0 {
cfg.ProgressLogFreq = 100
}
for i := 0; i < len(observations); i += cfg.BatchSize {
// Check context cancellation
select {
case <-ctx.Done():
log.Warn().Int("synced", synced).Int("remaining", len(observations)-i).Msg("Batch sync cancelled")
return synced, errors
default:
}
end := min(i+cfg.BatchSize, len(observations))
batch := observations[i:end]
var docs []Document
// Collect all documents for this batch
for _, obs := range batch {
docs = append(docs, s.formatObservationDocs(obs)...)
}
// Add all documents in one call
if len(docs) > 0 {
if err := s.client.AddDocuments(ctx, docs); err != nil {
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync observation batch")
errors += len(batch)
continue
}
}
synced += len(batch)
// Log progress periodically
if synced%cfg.ProgressLogFreq == 0 || synced == len(observations) {
log.Debug().Int("synced", synced).Int("total", len(observations)).Msg("Observation batch sync progress")
}
}
return synced, errors
}
// BatchSyncSummaries syncs multiple summaries efficiently in batches.
func (s *Sync) BatchSyncSummaries(ctx context.Context, summaries []*models.SessionSummary, cfg BatchSyncConfig) (synced int, errors int) {
if len(summaries) == 0 {
return 0, 0
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 50
}
if cfg.ProgressLogFreq <= 0 {
cfg.ProgressLogFreq = 100
}
for i := 0; i < len(summaries); i += cfg.BatchSize {
// Check context cancellation
select {
case <-ctx.Done():
log.Warn().Int("synced", synced).Int("remaining", len(summaries)-i).Msg("Batch sync cancelled")
return synced, errors
default:
}
end := min(i+cfg.BatchSize, len(summaries))
batch := summaries[i:end]
var docs []Document
// Collect all documents for this batch
for _, summary := range batch {
docs = append(docs, s.formatSummaryDocs(summary)...)
}
// Add all documents in one call
if len(docs) > 0 {
if err := s.client.AddDocuments(ctx, docs); err != nil {
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync summary batch")
errors += len(batch)
continue
}
}
synced += len(batch)
// Log progress periodically
if synced%cfg.ProgressLogFreq == 0 || synced == len(summaries) {
log.Debug().Int("synced", synced).Int("total", len(summaries)).Msg("Summary batch sync progress")
}
}
return synced, errors
}
// BatchSyncPrompts syncs multiple user prompts efficiently in batches.
func (s *Sync) BatchSyncPrompts(ctx context.Context, prompts []*models.UserPromptWithSession, cfg BatchSyncConfig) (synced int, errors int) {
if len(prompts) == 0 {
return 0, 0
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 50
}
if cfg.ProgressLogFreq <= 0 {
cfg.ProgressLogFreq = 100
}
for i := 0; i < len(prompts); i += cfg.BatchSize {
// Check context cancellation
select {
case <-ctx.Done():
log.Warn().Int("synced", synced).Int("remaining", len(prompts)-i).Msg("Batch sync cancelled")
return synced, errors
default:
}
end := min(i+cfg.BatchSize, len(prompts))
batch := prompts[i:end]
docs := make([]Document, 0, len(batch))
// Collect all documents for this batch
for _, prompt := range batch {
docs = append(docs, Document{
ID: fmt.Sprintf("prompt_%d", prompt.ID),
Content: prompt.PromptText,
Metadata: map[string]any{
"sqlite_id": prompt.ID,
"doc_type": "user_prompt",
"sdk_session_id": prompt.SDKSessionID,
"project": prompt.Project,
"scope": "",
"created_at_epoch": prompt.CreatedAtEpoch,
"prompt_number": prompt.PromptNumber,
"field_type": "prompt",
},
})
}
// Add all documents in one call
if len(docs) > 0 {
if err := s.client.AddDocuments(ctx, docs); err != nil {
log.Warn().Err(err).Int("batchStart", i).Int("batchSize", len(batch)).Msg("Failed to sync prompt batch")
errors += len(batch)
continue
}
}
synced += len(batch)
// Log progress periodically
if synced%cfg.ProgressLogFreq == 0 || synced == len(prompts) {
log.Debug().Int("synced", synced).Int("total", len(prompts)).Msg("Prompt batch sync progress")
}
}
return synced, errors
}
+125 -1273
View File
File diff suppressed because it is too large Load Diff
+677
View File
@@ -0,0 +1,677 @@
// Package worker provides context and search-related HTTP handlers.
package worker
import (
"context"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// handleSearchByPrompt searches observations relevant to a user prompt.
// IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return.
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
cwd := r.URL.Query().Get("cwd")
if project == "" || query == "" {
http.Error(w, "project and query required", http.StatusBadRequest)
return
}
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
limit := gorm.ParseLimitParamWithMax(r, DefaultSearchLimit, 200)
var observations []*models.Observation
var err error
var usedVector bool
similarityScores := make(map[int64]float64) // Track similarity per observation
// Get threshold settings from config
threshold := s.config.ContextRelevanceThreshold
maxResults := s.config.ContextMaxPromptResults
// Generate expanded queries if query expander is available
// Use timeout context to prevent query expansion from blocking
var expandedQueries []expansion.ExpandedQuery
var detectedIntent string
if s.queryExpander != nil {
expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second)
cfg := expansion.DefaultConfig()
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
expandCancel() // Cancel immediately after use (defer not needed - no panic possible between creation and here)
if len(expandedQueries) > 0 {
detectedIntent = string(expandedQueries[0].Intent)
}
}
if len(expandedQueries) == 0 {
// Fallback to just the original query
expandedQueries = []expansion.ExpandedQuery{
{Query: query, Weight: 1.0, Source: "original"},
}
}
// Try vector search first if available
var vectorSearchFailed bool
if s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
// Search with each expanded query and merge results
// Pre-allocate with estimated capacity to avoid repeated reallocation
estimatedCapacity := len(expandedQueries) * limit * 2
allVectorResults := make([]sqlitevec.QueryResult, 0, estimatedCapacity)
queryWeights := make(map[string]float64, len(expandedQueries))
var vectorErrors int
for _, eq := range expandedQueries {
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
if vecErr != nil {
vectorErrors++
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
} else if len(vectorResults) > 0 {
// Apply weight to similarity scores before merging
for i := range vectorResults {
vectorResults[i].Similarity *= eq.Weight
}
allVectorResults = append(allVectorResults, vectorResults...)
queryWeights[eq.Query] = eq.Weight
}
}
// Track if vector search had issues
if vectorErrors > 0 && vectorErrors == len(expandedQueries) {
vectorSearchFailed = true
log.Warn().Int("errors", vectorErrors).Str("project", project).Msg("All vector queries failed, falling back to FTS")
}
if len(allVectorResults) > 0 {
// Filter by relevance threshold before extracting IDs
// Use a slightly lower threshold for expanded queries
effectiveThreshold := threshold * 0.9 // Allow slightly lower scores for expanded queries
filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0)
// Build similarity map for filtered results (keeping highest weighted score per observation)
for _, vr := range filteredResults {
if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok {
id := int64(sqliteID)
// Keep the highest score for each observation
if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing {
similarityScores[id] = vr.Similarity
}
}
}
// Extract observation IDs with project/scope filtering using shared helper
obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project)
if len(obsIDs) > 0 {
// Fetch full observations from SQLite
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
if err == nil {
usedVector = true
}
}
}
}
// Fall back to FTS if vector search not available, failed, or returned no results
if !usedVector || len(observations) == 0 {
if vectorSearchFailed {
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
}
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
if err != nil {
// FTS might fail if query has special chars, try without
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
}
// Fast staleness filter - NO verification (that's too slow for interactive use)
// Just check mtimes and exclude obviously stale observations
var staleCount int
freshObservations := make([]*models.Observation, 0, len(observations))
for _, obs := range observations {
if len(obs.FileMtimes) > 0 && cwd != "" {
var paths []string
for path := range obs.FileMtimes {
paths = append(paths, path)
}
currentMtimes := sdk.GetFileMtimes(paths, cwd)
if obs.CheckStaleness(currentMtimes) {
// Stale - exclude but don't verify (too slow)
// Queue for background verification instead
staleCount++
s.queueStaleVerification(obs.ID, cwd)
continue
}
}
freshObservations = append(freshObservations, obs)
}
// Apply cross-encoder reranking if available
var reranked bool
if s.reranker != nil && len(freshObservations) > 0 && usedVector {
// Build candidates from observations with their bi-encoder scores
candidates := make([]reranking.Candidate, len(freshObservations))
for i, obs := range freshObservations {
// Use strings.Builder for efficient concatenation
var content string
if obs.Narrative.Valid && obs.Narrative.String != "" {
var sb strings.Builder
sb.Grow(len(obs.Title.String) + 1 + len(obs.Narrative.String))
sb.WriteString(obs.Title.String)
sb.WriteByte(' ')
sb.WriteString(obs.Narrative.String)
content = sb.String()
} else {
content = obs.Title.String
}
candidates[i] = reranking.Candidate{
ID: strconv.FormatInt(obs.ID, 10), // Faster than fmt.Sprintf
Content: content,
Score: similarityScores[obs.ID],
Metadata: map[string]any{"obs_idx": i},
}
}
// Rerank using cross-encoder - use pure mode or combined scores
var rerankResults []reranking.RerankResult
var rerankErr error
if s.config.RerankingPureMode {
rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults)
} else {
rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults)
}
if rerankErr != nil {
log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order")
} else if len(rerankResults) > 0 {
// Update similarity scores with reranked scores
for _, rr := range rerankResults {
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
similarityScores[id] = rr.CombinedScore
}
}
// Reorder observations based on rerank results
reorderedObs := make([]*models.Observation, 0, len(rerankResults))
obsMap := make(map[int64]*models.Observation)
for _, obs := range freshObservations {
obsMap[obs.ID] = obs
}
for _, rr := range rerankResults {
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
if obs, ok := obsMap[id]; ok {
reorderedObs = append(reorderedObs, obs)
}
}
}
freshObservations = reorderedObs
reranked = true
log.Debug().
Int("candidates", len(candidates)).
Int("returned", len(rerankResults)).
Msg("Cross-encoder reranking complete")
}
}
// Cluster similar observations to remove duplicates
clusteredObservations := clusterObservations(freshObservations, 0.4)
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
// Sort by similarity score (highest first) if we have scores and didn't rerank
if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked {
sort.Slice(clusteredObservations, func(i, j int) bool {
scoreI := similarityScores[clusteredObservations[i].ID]
scoreJ := similarityScores[clusteredObservations[j].ID]
return scoreI > scoreJ
})
}
// Apply max results cap if configured
if maxResults > 0 && len(clusteredObservations) > maxResults {
clusteredObservations = clusteredObservations[:maxResults]
}
// Record retrieval stats with staleness metrics
s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0,
int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), true)
// Increment retrieval counts for scoring (async, non-blocking)
if len(clusteredObservations) > 0 {
ids := make([]int64, len(clusteredObservations))
for i, obs := range clusteredObservations {
ids[i] = obs.ID
}
s.incrementRetrievalCounts(ids)
}
log.Info().
Str("project", project).
Str("query", query).
Str("intent", detectedIntent).
Int("expansions", len(expandedQueries)).
Int("found", len(clusteredObservations)).
Int("stale_excluded", staleCount).
Float64("threshold", threshold).
Msg("Prompt-based observation search")
// Build response with similarity scores
obsWithScores := make([]map[string]any, len(clusteredObservations))
for i, obs := range clusteredObservations {
obsMap := obs.ToMap()
if score, ok := similarityScores[obs.ID]; ok {
obsMap["similarity"] = score
}
obsWithScores[i] = obsMap
}
// Build expansion info for response
expansionInfo := make([]map[string]any, len(expandedQueries))
for i, eq := range expandedQueries {
expansionInfo[i] = map[string]any{
"query": eq.Query,
"weight": eq.Weight,
"source": eq.Source,
}
}
// Track this search for analytics
s.trackSearchQuery(query, project, "observations", len(clusteredObservations), usedVector)
writeJSON(w, map[string]any{
"project": project,
"query": query,
"intent": detectedIntent,
"expansions": expansionInfo,
"observations": obsWithScores,
"threshold": threshold,
"max_results": maxResults,
})
}
// handleFileContext returns observations relevant to specific files being worked on.
// Uses vector similarity search to find observations that mention or relate to the given files.
func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
return
}
filesParam := r.URL.Query().Get("files")
if filesParam == "" {
http.Error(w, "files required", http.StatusBadRequest)
return
}
// Parse comma-separated file paths
files := strings.Split(filesParam, ",")
if len(files) == 0 {
http.Error(w, "at least one file required", http.StatusBadRequest)
return
}
// Limit to reasonable number of files
maxFiles := 20
if len(files) > maxFiles {
files = files[:maxFiles]
}
// Get limit parameter (default 10 per file)
limitStr := r.URL.Query().Get("limit")
limit := 10
if limitStr != "" {
if parsed, err := strconv.Atoi(limitStr); err == nil && parsed > 0 && parsed <= 50 {
limit = parsed
}
}
// Search for observations related to each file in parallel
ctx := r.Context()
// Check if vector search is available
if s.vectorClient == nil || !s.vectorClient.IsConnected() {
writeJSON(w, map[string]any{
"files": files,
"results": map[string]any{},
"count": 0,
"error": "vector search not available",
})
return
}
// Prepare for parallel execution
type fileResult struct {
file string
results []map[string]any
obsIDs []int64 // Track observation IDs for deduplication
}
resultsChan := make(chan fileResult, len(files))
sem := make(chan struct{}, 5) // Limit concurrency to 5 parallel searches
var wg sync.WaitGroup
for _, file := range files {
file = strings.TrimSpace(file)
if file == "" {
continue
}
wg.Add(1)
go func(file string) {
defer wg.Done()
sem <- struct{}{} // Acquire semaphore
defer func() { <-sem }() // Release semaphore
// Build search query from file path
query := buildFileQuery(file)
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
if vecErr != nil {
log.Warn().Err(vecErr).Str("file", file).Msg("Vector search failed for file context")
return
}
// Extract observation IDs from vector results
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
if len(obsIDs) == 0 {
return
}
// Fetch observations
observations, err := s.observationStore.GetObservationsByIDs(ctx, obsIDs, "score_desc", limit*2)
if err != nil {
log.Warn().Err(err).Str("file", file).Msg("Failed to fetch observations for file context")
return
}
// Pre-build score map from vector results (O(n) instead of O(n²))
scoreMap := make(map[int64]float64, len(vectorResults))
var avgScore float64
for _, vr := range vectorResults {
avgScore += vr.Similarity
// Parse observation ID from vector result ID (format: obs_{id}_{field})
// Use index-based parsing to avoid slice allocation from strings.Split
if len(vr.ID) > 4 && vr.ID[:4] == "obs_" {
rest := vr.ID[4:] // Skip "obs_"
underscoreIdx := strings.IndexByte(rest, '_')
var idStr string
if underscoreIdx >= 0 {
idStr = rest[:underscoreIdx]
} else {
idStr = rest
}
if id, parseErr := strconv.ParseInt(idStr, 10, 64); parseErr == nil {
// Keep highest score for each observation
if existing, exists := scoreMap[id]; !exists || vr.Similarity > existing {
scoreMap[id] = vr.Similarity
}
}
}
}
if len(vectorResults) > 0 {
avgScore /= float64(len(vectorResults))
}
fileResults := make([]map[string]any, 0, limit)
var usedIDs []int64
for _, obs := range observations {
// Check project scope
if obs.Scope == "project" && obs.Project != project {
continue
}
// O(1) score lookup instead of O(n) nested loop
score, found := scoreMap[obs.ID]
if !found {
// Use average score as fallback
score = avgScore
}
// Only include if score is above threshold
if score < 0.3 {
continue
}
fileResults = append(fileResults, map[string]any{
"id": obs.ID,
"title": obs.Title.String,
"type": obs.Type,
"narrative": obs.Narrative.String,
"facts": obs.Facts,
"score": score,
})
usedIDs = append(usedIDs, obs.ID)
if len(fileResults) >= limit {
break
}
}
if len(fileResults) > 0 {
resultsChan <- fileResult{file: file, results: fileResults, obsIDs: usedIDs}
}
}(file)
}
// Close channel when all goroutines complete
go func() {
wg.Wait()
close(resultsChan)
}()
// Collect results and deduplicate
allResults := make(map[string]any)
seenObservationIDs := make(map[int64]bool)
for res := range resultsChan {
// Filter out duplicates that were already seen in other files
dedupedResults := make([]map[string]any, 0, len(res.results))
for i, r := range res.results {
obsID := res.obsIDs[i]
if !seenObservationIDs[obsID] {
seenObservationIDs[obsID] = true
dedupedResults = append(dedupedResults, r)
}
}
if len(dedupedResults) > 0 {
allResults[res.file] = dedupedResults
}
}
writeJSON(w, map[string]any{
"files": files,
"results": allResults,
"count": len(allResults),
})
}
// buildFileQuery extracts meaningful search terms from a file path.
func buildFileQuery(filePath string) string {
// Remove common prefixes and extensions
path := strings.TrimPrefix(filePath, "/")
// Extract the filename and directory
parts := strings.Split(path, "/")
meaningful := make([]string, 0, len(parts))
for _, part := range parts {
// Skip common directory names that aren't meaningful
switch strings.ToLower(part) {
case "src", "lib", "internal", "pkg", "cmd", "api", "app", "test", "tests", "spec", "specs":
continue
default:
// Remove file extension
if idx := strings.LastIndex(part, "."); idx > 0 {
part = part[:idx]
}
// Convert camelCase/PascalCase to spaces
part = splitCamelCase(part)
// Convert snake_case to spaces
part = strings.ReplaceAll(part, "_", " ")
// Convert kebab-case to spaces
part = strings.ReplaceAll(part, "-", " ")
meaningful = append(meaningful, part)
}
}
return strings.Join(meaningful, " ")
}
// splitCamelCase splits camelCase or PascalCase into separate words.
func splitCamelCase(s string) string {
var result strings.Builder
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune(' ')
}
result.WriteRune(r)
}
return result.String()
}
// handleContextInject returns context for injection at session start.
// IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return.
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
return
}
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
cwd := r.URL.Query().Get("cwd")
if cwd == "" {
cwd = "/"
}
// Limit observations for fast startup (configurable, default 100)
limit := s.config.ContextObservations
if limit <= 0 {
limit = DefaultContextLimit
}
// Full count determines how many observations get full detail (configurable, default 25)
fullCount := s.config.ContextFullCount
if fullCount <= 0 {
fullCount = 25
}
// Get recent observations
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Fast staleness filter - NO verification (that's too slow for startup)
var staleCount int
freshObservations := make([]*models.Observation, 0, len(observations))
for _, obs := range observations {
if len(obs.FileMtimes) > 0 {
var paths []string
for path := range obs.FileMtimes {
paths = append(paths, path)
}
currentMtimes := sdk.GetFileMtimes(paths, cwd)
if obs.CheckStaleness(currentMtimes) {
// Stale - exclude but don't verify (too slow)
// Queue for background verification instead
staleCount++
s.queueStaleVerification(obs.ID, cwd)
continue
}
}
freshObservations = append(freshObservations, obs)
}
// Cluster similar observations to remove duplicates
clusteredObservations := clusterObservations(freshObservations, 0.4)
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
// Record retrieval stats with staleness metrics
s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0,
int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), false)
// Increment retrieval counts for scoring (async, non-blocking)
if len(clusteredObservations) > 0 {
ids := make([]int64, len(clusteredObservations))
for i, obs := range clusteredObservations {
ids[i] = obs.ID
}
s.incrementRetrievalCounts(ids)
}
log.Info().
Str("project", project).
Int("total", len(observations)).
Int("fresh", len(freshObservations)).
Int("clustered", len(clusteredObservations)).
Int("duplicates", duplicatesRemoved).
Int("stale_excluded", staleCount).
Msg("Context injection with clustering")
writeJSON(w, map[string]any{
"project": project,
"observations": clusteredObservations,
"full_count": fullCount,
"stale_excluded": staleCount,
"duplicates_removed": duplicatesRemoved,
})
}
// handleContextCount returns the count of observations for a project.
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
return
}
count, err := s.getCachedObservationCount(r.Context(), project)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]any{
"project": project,
"count": count,
})
}
+595
View File
@@ -0,0 +1,595 @@
// Package worker provides data retrieval HTTP handlers.
package worker
import (
"encoding/json"
"net/http"
"runtime"
"sort"
"strings"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// handleGetObservations returns recent observations.
// Supports optional query parameter for semantic search via sqlite-vec.
// Supports pagination via limit and offset query parameters.
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var observations []*models.Observation
var total int64
var err error
var usedVector bool
// Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
if len(obsIDs) > 0 {
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit)
if err == nil {
usedVector = true
total = int64(len(observations)) // Vector search doesn't have total, use returned count
}
}
}
}
// Fall back to SQLite if vector search not used
if !usedVector {
if project != "" {
// Strict project filtering for dashboard - only observations from this project
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset)
} else {
// All projects
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset)
}
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Ensure we return empty array, not null
if observations == nil {
observations = []*models.Observation{}
}
// Track search if query was provided
if query != "" {
s.trackSearchQuery(query, project, "observations", len(observations), usedVector)
}
// Return paginated response
writeJSON(w, map[string]any{
"observations": observations,
"total": total,
"limit": pagination.Limit,
"offset": pagination.Offset,
"hasMore": int64(pagination.Offset)+int64(len(observations)) < total,
})
}
// handleGetSummaries returns recent summaries.
// Supports optional query parameter for semantic search via sqlite-vec.
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var summaries []*models.SessionSummary
var err error
var usedVector bool
// Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
if len(summaryIDs) > 0 {
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit)
if err == nil {
usedVector = true
}
}
}
}
// Fall back to SQLite if vector search not used
if !usedVector {
if project != "" {
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
} else {
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
}
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Ensure we return empty array, not null
if summaries == nil {
summaries = []*models.SessionSummary{}
}
writeJSON(w, summaries)
}
// handleGetPrompts returns recent user prompts.
// Supports optional query parameter for semantic search via sqlite-vec.
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var prompts []*models.UserPromptWithSession
var err error
var usedVector bool
// Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
if len(promptIDs) > 0 {
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit)
if err == nil {
usedVector = true
}
}
}
}
// Fall back to SQLite if vector search not used
if !usedVector {
if project != "" {
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
} else {
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
}
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Ensure we return empty array, not null
if prompts == nil {
prompts = []*models.UserPromptWithSession{}
}
writeJSON(w, prompts)
}
// handleGetProjects returns all projects.
// Response is cacheable for 5 minutes since project list changes infrequently.
func (s *Service) handleGetProjects(w http.ResponseWriter, r *http.Request) {
projects, err := s.sessionStore.GetAllProjects(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Cache for 5 minutes - project list changes infrequently
w.Header().Set("Cache-Control", "public, max-age=300")
writeJSON(w, projects)
}
// handleGetTypes returns the canonical list of observation and concept types.
// This provides a single source of truth for both backend and frontend.
// Response is cacheable as these values never change at runtime.
func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) {
// Cache for 24 hours - these values are compile-time constants
w.Header().Set("Cache-Control", "public, max-age=86400")
writeJSON(w, map[string]any{
"observation_types": ObservationTypes,
"concept_types": ConceptTypes,
})
}
// handleGetModels returns available embedding models.
// Response is cacheable as model list doesn't change without restart.
func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) {
// Cache for 1 hour - model list is static during runtime
w.Header().Set("Cache-Control", "public, max-age=3600")
models := embedding.ListModels()
defaultModel := embedding.GetDefaultModel()
writeJSON(w, map[string]any{
"models": models,
"default": defaultModel,
"current": s.embedSvc.Version(),
})
}
// handleGetStats returns worker statistics.
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
retrievalStats := s.GetRetrievalStats(project)
sessionsToday, _ := s.sessionStore.GetSessionsToday(r.Context())
response := map[string]any{
"uptime": time.Since(s.startTime).String(),
"uptimeSeconds": time.Since(s.startTime).Seconds(),
"activeSessions": s.sessionManager.GetActiveSessionCount(),
"queueDepth": s.sessionManager.GetTotalQueueDepth(),
"isProcessing": s.sessionManager.IsAnySessionProcessing(),
"connectedClients": s.sseBroadcaster.ClientCount(),
"sessionsToday": sessionsToday,
"retrieval": retrievalStats,
"ready": s.ready.Load(),
}
// Add memory stats
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
response["memory"] = map[string]any{
"alloc_mb": float64(memStats.Alloc) / 1024 / 1024,
"total_alloc_mb": float64(memStats.TotalAlloc) / 1024 / 1024,
"sys_mb": float64(memStats.Sys) / 1024 / 1024,
"heap_alloc_mb": float64(memStats.HeapAlloc) / 1024 / 1024,
"heap_inuse_mb": float64(memStats.HeapInuse) / 1024 / 1024,
"heap_objects": memStats.HeapObjects,
"goroutines": runtime.NumGoroutine(),
"gc_cycles": memStats.NumGC,
"gc_pause_total_ms": float64(memStats.PauseTotalNs) / 1e6,
}
// Add database health if available
if s.store != nil {
dbHealth := s.store.HealthCheck(r.Context())
response["database"] = map[string]any{
"status": dbHealth.Status,
"query_latency_ms": float64(dbHealth.QueryLatency) / 1e6,
"pool": dbHealth.PoolStats,
"warning": dbHealth.Warning,
}
}
// Add embedding model info
if s.embedSvc != nil {
response["embeddingModel"] = map[string]any{
"name": s.embedSvc.Name(),
"version": s.embedSvc.Version(),
"dimensions": s.embedSvc.Dimensions(),
}
}
// Add vector cache stats
if s.vectorClient != nil {
if count, err := s.vectorClient.Count(r.Context()); err == nil {
response["vectorCount"] = count
}
cacheSize, cacheMax := s.vectorClient.CacheStats()
response["vectorCache"] = map[string]any{
"size": cacheSize,
"max_size": cacheMax,
}
}
// Include project-specific observation count if project is specified
if project != "" {
count, err := s.getCachedObservationCount(r.Context(), project)
if err == nil {
response["projectObservations"] = count
response["project"] = project
}
}
// Add rate limiter stats
if s.rateLimiter != nil {
response["rateLimiter"] = s.rateLimiter.Stats()
}
// Add circuit breaker metrics
if s.processor != nil {
response["circuitBreaker"] = s.processor.CircuitBreakerMetrics()
}
writeJSON(w, response)
}
// handleGetRetrievalStats returns detailed retrieval statistics.
func (s *Service) handleGetRetrievalStats(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
stats := s.GetRetrievalStats(project)
writeJSON(w, stats)
}
// handleGetRecentQueries returns recent search queries for analytics.
func (s *Service) handleGetRecentQueries(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
limit := gorm.ParseLimitParam(r, 20)
queries := s.getRecentSearchQueries(project, limit)
writeJSON(w, map[string]any{
"queries": queries,
"count": len(queries),
"project": project,
})
}
// handleGetSearchAnalytics returns comprehensive search analytics and statistics.
func (s *Service) handleGetSearchAnalytics(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
// Get all recent queries for analysis
queries := s.getRecentSearchQueries(project, maxRecentQueries)
// Calculate analytics
totalQueries := len(queries)
vectorSearches := 0
totalResults := 0
zeroResultQueries := 0
queryTypes := make(map[string]int)
topKeywords := make(map[string]int)
for _, q := range queries {
if q.UsedVector {
vectorSearches++
}
totalResults += q.Results
if q.Results == 0 {
zeroResultQueries++
}
queryTypes[q.Type]++
// Extract keywords (simple word tokenization using iterator)
for word := range strings.FieldsSeq(strings.ToLower(q.Query)) {
if len(word) > 3 { // Skip short words
topKeywords[word]++
}
}
}
// Sort keywords by frequency
type keywordCount struct {
Keyword string `json:"keyword"`
Count int `json:"count"`
}
sortedKeywords := make([]keywordCount, 0, len(topKeywords))
for kw, count := range topKeywords {
sortedKeywords = append(sortedKeywords, keywordCount{Keyword: kw, Count: count})
}
sort.Slice(sortedKeywords, func(i, j int) bool {
return sortedKeywords[i].Count > sortedKeywords[j].Count
})
if len(sortedKeywords) > 10 {
sortedKeywords = sortedKeywords[:10]
}
// Calculate averages
avgResults := float64(0)
vectorSearchRate := float64(0)
zeroResultRate := float64(0)
if totalQueries > 0 {
avgResults = float64(totalResults) / float64(totalQueries)
vectorSearchRate = float64(vectorSearches) / float64(totalQueries) * 100
zeroResultRate = float64(zeroResultQueries) / float64(totalQueries) * 100
}
writeJSON(w, map[string]any{
"total_queries": totalQueries,
"vector_search_rate": vectorSearchRate,
"avg_results": avgResults,
"zero_result_rate": zeroResultRate,
"query_types": queryTypes,
"top_keywords": sortedKeywords,
"project": project,
})
}
// handleVectorHealth returns comprehensive health information about the vector database.
func (s *Service) handleVectorHealth(w http.ResponseWriter, r *http.Request) {
if s.vectorClient == nil {
http.Error(w, "vector client not initialized", http.StatusServiceUnavailable)
return
}
stats, err := s.vectorClient.GetHealthStats(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Add additional computed metrics
healthScore := 100.0
var warnings []string
// Penalize for stale vectors
if stats.TotalVectors > 0 {
staleRatio := float64(stats.StaleVectors) / float64(stats.TotalVectors)
if staleRatio > 0 {
healthScore -= staleRatio * 50 // Up to 50 points off for stale vectors
warnings = append(warnings, formatWarning("%.1f%% vectors need rebuild", staleRatio*100))
}
}
// Check cache effectiveness
cacheHitRate := stats.EmbeddingCache.HitRate()
if cacheHitRate < 20 && (stats.EmbeddingCache.EmbeddingHits+stats.EmbeddingCache.EmbeddingMisses) > 100 {
healthScore -= 10
warnings = append(warnings, formatWarning("Low cache hit rate: %.1f%%", cacheHitRate))
}
// Penalize if rebuild is needed
if stats.NeedsRebuild {
healthScore -= 20
warnings = append(warnings, "Vector rebuild recommended: "+stats.RebuildReason)
}
if healthScore < 0 {
healthScore = 0
}
status := "healthy"
if healthScore < 50 {
status = "unhealthy"
} else if healthScore < 80 {
status = "degraded"
}
writeJSON(w, map[string]any{
"status": status,
"health_score": healthScore,
"warnings": warnings,
"stats": stats,
"cache_hit_rate": cacheHitRate,
})
}
// UpdateObservationRequest is the request body for updating an observation.
type UpdateObservationRequest struct {
Title *string `json:"title,omitempty"`
Subtitle *string `json:"subtitle,omitempty"`
Narrative *string `json:"narrative,omitempty"`
Scope *string `json:"scope,omitempty"`
Facts []string `json:"facts,omitempty"`
Concepts []string `json:"concepts,omitempty"`
FilesRead []string `json:"files_read,omitempty"`
FilesModified []string `json:"files_modified,omitempty"`
}
// handleUpdateObservation updates an existing observation.
// PUT /api/observations/{id}
func (s *Service) handleUpdateObservation(w http.ResponseWriter, r *http.Request) {
// Parse observation ID from URL
id, ok := parseIDParam(w, r.PathValue("id"), "observation")
if !ok {
return
}
// Parse request body
var req UpdateObservationRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
// Build update struct - only include fields that were provided
update := &gorm.ObservationUpdate{}
if req.Title != nil {
update.Title = req.Title
}
if req.Subtitle != nil {
update.Subtitle = req.Subtitle
}
if req.Narrative != nil {
update.Narrative = req.Narrative
}
if req.Facts != nil {
update.Facts = &req.Facts
}
if req.Concepts != nil {
update.Concepts = &req.Concepts
}
if req.FilesRead != nil {
update.FilesRead = &req.FilesRead
}
if req.FilesModified != nil {
update.FilesModified = &req.FilesModified
}
if req.Scope != nil {
// Validate scope
if *req.Scope != "project" && *req.Scope != "global" {
http.Error(w, "scope must be 'project' or 'global'", http.StatusBadRequest)
return
}
update.Scope = req.Scope
}
// Update the observation
updatedObs, err := s.observationStore.UpdateObservation(r.Context(), id, update)
if err != nil {
if strings.Contains(err.Error(), "not found") {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
http.Error(w, "failed to update observation: "+err.Error(), http.StatusInternalServerError)
return
}
// Trigger vector resync for the updated observation
if s.vectorSync != nil {
s.asyncVectorSync(func() {
if err := s.vectorSync.SyncObservation(s.ctx, updatedObs); err != nil {
log.Warn().Err(err).Int64("id", id).Msg("Failed to resync observation vectors after update")
}
})
}
// Broadcast update event
s.sseBroadcaster.Broadcast(map[string]any{
"type": "observation_updated",
"id": id,
})
writeJSON(w, map[string]any{
"observation": updatedObs,
"message": "observation updated successfully",
})
}
// handleGetObservationByID returns a single observation by ID.
// GET /api/observations/{id}
func (s *Service) handleGetObservationByID(w http.ResponseWriter, r *http.Request) {
id, ok := parseIDParam(w, r.PathValue("id"), "observation")
if !ok {
return
}
obs, err := s.observationStore.GetObservationByID(r.Context(), id)
if err != nil {
http.Error(w, "failed to get observation: "+err.Error(), http.StatusInternalServerError)
return
}
if obs == nil {
http.Error(w, "observation not found", http.StatusNotFound)
return
}
writeJSON(w, obs)
}
+680
View File
@@ -0,0 +1,680 @@
// Package worker provides import, export, and archive HTTP handlers.
package worker
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/lukaszraczylo/claude-mnemonic/pkg/similarity"
"github.com/rs/zerolog/log"
)
// BulkImportRequest is the request body for bulk observation import.
type BulkImportRequest struct {
Project string `json:"project"`
Observations []BulkObservationInput `json:"observations"`
}
// BulkObservationInput represents a single observation in bulk import.
type BulkObservationInput struct {
Type string `json:"type"`
Title string `json:"title"`
Subtitle string `json:"subtitle,omitempty"`
Narrative string `json:"narrative,omitempty"`
Scope string `json:"scope,omitempty"`
Facts []string `json:"facts,omitempty"`
Concepts []string `json:"concepts,omitempty"`
FilesRead []string `json:"files_read,omitempty"`
FilesModified []string `json:"files_modified,omitempty"`
}
// BulkImportResponse contains the result of a bulk import operation.
type BulkImportResponse struct {
Errors []string `json:"errors,omitempty"`
Imported int `json:"imported"`
Failed int `json:"failed"`
SkippedDuplicates int `json:"skipped_duplicates,omitempty"`
}
// handleBulkImport handles bulk import of observations.
// This is useful for migrating data or importing observations from external sources.
func (s *Service) handleBulkImport(w http.ResponseWriter, r *http.Request) {
// Rate limit bulk operations to prevent DoS
if s.bulkOpLimiter != nil && !s.bulkOpLimiter.CanExecute() {
remaining := s.bulkOpLimiter.CooldownRemaining()
http.Error(w, fmt.Sprintf("bulk import rate limited, retry in %d seconds", remaining), http.StatusTooManyRequests)
return
}
var req BulkImportRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
if req.Project == "" {
http.Error(w, "project is required", http.StatusBadRequest)
return
}
// Validate project name to prevent path traversal
if err := ValidateProjectName(req.Project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(req.Observations) == 0 {
http.Error(w, "at least one observation is required", http.StatusBadRequest)
return
}
// Limit batch size to prevent overwhelming the system
maxBatchSize := 100
if len(req.Observations) > maxBatchSize {
http.Error(w, fmt.Sprintf("batch size exceeds maximum of %d", maxBatchSize), http.StatusBadRequest)
return
}
// Create a synthetic session for bulk import
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), fmt.Sprintf("bulk-import-%d", time.Now().UnixMilli()), req.Project, "bulk import")
if err != nil {
http.Error(w, "failed to create import session: "+err.Error(), http.StatusInternalServerError)
return
}
var imported, failed, skippedDupes int
var errors []string
// Track imported observations for deduplication within the batch
importedObs := make([]*models.Observation, 0, len(req.Observations))
// Deduplication threshold - observations more similar than this are considered duplicates
const dedupThreshold = 0.7
for i, obsInput := range req.Observations {
// Validate observation type using O(1) map lookup
if !IsValidObservationType(obsInput.Type) {
failed++
errors = append(errors, fmt.Sprintf("observation %d: invalid type '%s'", i, obsInput.Type))
continue
}
// Build parsed observation
parsedObs := &models.ParsedObservation{
Type: models.ObservationType(obsInput.Type),
Title: obsInput.Title,
Subtitle: obsInput.Subtitle,
Facts: obsInput.Facts,
Narrative: obsInput.Narrative,
Concepts: obsInput.Concepts,
FilesRead: obsInput.FilesRead,
FilesModified: obsInput.FilesModified,
Scope: models.ObservationScope(obsInput.Scope),
}
// Convert to temporary observation for similarity check
tempObs := &models.Observation{
Title: sql.NullString{String: parsedObs.Title, Valid: parsedObs.Title != ""},
Subtitle: sql.NullString{String: parsedObs.Subtitle, Valid: parsedObs.Subtitle != ""},
Narrative: sql.NullString{String: parsedObs.Narrative, Valid: parsedObs.Narrative != ""},
}
// Check for duplicates within this import batch
if similarity.IsSimilarToAny(tempObs, importedObs, dedupThreshold) {
skippedDupes++
continue
}
// Store observation
obsID, _, err := s.observationStore.StoreObservation(
r.Context(),
fmt.Sprintf("bulk-import-%d", sessionID),
req.Project,
parsedObs,
0, // prompt number
0, // discovery tokens
)
if err != nil {
failed++
errors = append(errors, fmt.Sprintf("observation %d: %v", i, err))
continue
}
// Sync to vector DB asynchronously with rate limiting
if s.vectorSync != nil {
s.asyncVectorSync(func() {
// Use service context as parent to respect shutdown signals
ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second)
defer cancel()
obs, err := s.observationStore.GetObservationByID(ctx, obsID)
if err == nil && obs != nil {
if syncErr := s.vectorSync.SyncObservation(ctx, obs); syncErr != nil {
if s.ctx.Err() == nil { // Don't log during shutdown
log.Debug().Err(syncErr).Int64("id", obsID).Msg("Failed to sync observation during bulk import")
}
}
}
})
}
// Track for deduplication of subsequent observations in this batch
importedObs = append(importedObs, tempObs)
imported++
}
log.Info().
Str("project", req.Project).
Int("imported", imported).
Int("failed", failed).
Int("skipped_duplicates", skippedDupes).
Msg("Bulk import completed")
// Invalidate observation count cache after import
if imported > 0 {
if req.Project != "" {
s.invalidateObsCountCache(req.Project)
} else {
s.invalidateAllObsCountCache()
}
}
// Broadcast observation event for dashboard refresh
s.sseBroadcaster.Broadcast(map[string]any{
"type": "observation",
"action": "bulk_import",
"project": req.Project,
"count": imported,
})
writeJSON(w, BulkImportResponse{
Imported: imported,
Failed: failed,
SkippedDuplicates: skippedDupes,
Errors: errors,
})
}
// ArchiveRequest is the request body for archiving observations.
type ArchiveRequest struct {
Project string `json:"project,omitempty"`
Reason string `json:"reason,omitempty"`
IDs []int64 `json:"ids,omitempty"`
MaxAgeDays int `json:"max_age_days,omitempty"`
}
// handleArchiveObservations archives observations by ID or by age.
// Supports batch archival with error tracking per observation.
func (s *Service) handleArchiveObservations(w http.ResponseWriter, r *http.Request) {
var req ArchiveRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
var archivedIDs []int64
var failedIDs []int64
var errors []string
var err error
if len(req.IDs) > 0 {
// Archive specific observations with parallel processing for large batches
if len(req.IDs) > 5 {
// Use parallel archival for batches larger than 5
type archiveResult struct {
err error
id int64
}
results := make(chan archiveResult, len(req.IDs))
// Limit concurrency to avoid overwhelming the database
sem := make(chan struct{}, 5)
var wg sync.WaitGroup
for _, id := range req.IDs {
wg.Add(1)
go func(obsID int64) {
defer wg.Done()
sem <- struct{}{} // Acquire
defer func() { <-sem }() // Release
archErr := s.observationStore.ArchiveObservation(r.Context(), obsID, req.Reason)
results <- archiveResult{id: obsID, err: archErr}
}(id)
}
// Close results channel when all goroutines complete
go func() {
wg.Wait()
close(results)
}()
// Collect results
for res := range results {
if res.err != nil {
log.Warn().Err(res.err).Int64("id", res.id).Msg("Failed to archive observation")
failedIDs = append(failedIDs, res.id)
errors = append(errors, fmt.Sprintf("id %d: %v", res.id, res.err))
} else {
archivedIDs = append(archivedIDs, res.id)
}
}
} else {
// Sequential for small batches
for _, id := range req.IDs {
if archErr := s.observationStore.ArchiveObservation(r.Context(), id, req.Reason); archErr != nil {
log.Warn().Err(archErr).Int64("id", id).Msg("Failed to archive observation")
failedIDs = append(failedIDs, id)
errors = append(errors, fmt.Sprintf("id %d: %v", id, archErr))
} else {
archivedIDs = append(archivedIDs, id)
}
}
}
} else if req.Project != "" || req.MaxAgeDays > 0 {
// Archive by age
archivedIDs, err = s.observationStore.ArchiveOldObservations(r.Context(), req.Project, req.MaxAgeDays, req.Reason)
if err != nil {
http.Error(w, "failed to archive: "+err.Error(), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "either 'ids' or 'project'/'max_age_days' is required", http.StatusBadRequest)
return
}
log.Info().
Str("project", req.Project).
Int("archived", len(archivedIDs)).
Int("failed", len(failedIDs)).
Msg("Observations archived")
// Invalidate cache if any observations were archived
if len(archivedIDs) > 0 {
if req.Project != "" {
s.invalidateObsCountCache(req.Project)
} else {
s.invalidateAllObsCountCache()
}
}
response := map[string]any{
"archived_count": len(archivedIDs),
"archived_ids": archivedIDs,
}
if len(failedIDs) > 0 {
response["failed_count"] = len(failedIDs)
response["failed_ids"] = failedIDs
response["errors"] = errors
}
writeJSON(w, response)
}
// handleUnarchiveObservation restores an archived observation.
func (s *Service) handleUnarchiveObservation(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid observation id", http.StatusBadRequest)
return
}
if err := s.observationStore.UnarchiveObservation(r.Context(), id); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Invalidate all caches since we don't know the project
s.invalidateAllObsCountCache()
writeJSON(w, map[string]any{
"success": true,
"id": id,
})
}
// handleGetArchivedObservations returns archived observations.
func (s *Service) handleGetArchivedObservations(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
limit := gorm.ParseLimitParam(r, DefaultObservationsLimit)
observations, err := s.observationStore.GetArchivedObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if observations == nil {
observations = []*models.Observation{}
}
writeJSON(w, observations)
}
// handleGetArchivalStats returns archival statistics.
func (s *Service) handleGetArchivalStats(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
stats, err := s.observationStore.GetArchivalStats(r.Context(), project)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, stats)
}
// handleExportObservations exports observations in JSON or CSV format.
// Supports query parameters: project, format (json/csv), scope, type, limit.
func (s *Service) handleExportObservations(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
format := r.URL.Query().Get("format")
if format == "" {
format = "json"
}
scope := r.URL.Query().Get("scope") // project, global, or empty for all
obsType := r.URL.Query().Get("type") // bugfix, feature, etc.
limit := gorm.ParseLimitParamWithMax(r, 1000, 5000) // Higher limit for exports, capped at 5000
// Validate format
if format != "json" && format != "csv" {
http.Error(w, "format must be 'json' or 'csv'", http.StatusBadRequest)
return
}
// Get observations with filters
ctx := r.Context()
var observations []*models.Observation
var err error
if project != "" {
observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0)
} else {
observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Apply additional filters
if scope != "" || obsType != "" {
filtered := make([]*models.Observation, 0, len(observations))
for _, obs := range observations {
if scope != "" && string(obs.Scope) != scope {
continue
}
if obsType != "" && string(obs.Type) != obsType {
continue
}
filtered = append(filtered, obs)
}
observations = filtered
}
// Generate filename
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("observations-%s.%s", timestamp, format)
if project != "" {
// Sanitize project name for filename
sanitized := strings.ReplaceAll(project, "/", "_")
sanitized = strings.ReplaceAll(sanitized, "\\", "_")
if len(sanitized) > 50 {
sanitized = sanitized[:50]
}
filename = fmt.Sprintf("observations-%s-%s.%s", sanitized, timestamp, format)
}
switch format {
case "csv":
w.Header().Set("Content-Type", "text/csv")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
s.writeObservationsCSV(w, observations)
default: // json
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
writeJSON(w, map[string]any{
"exported_at": time.Now().Format(time.RFC3339),
"project": project,
"count": len(observations),
"observations": observations,
})
}
}
// writeObservationsCSV writes observations in CSV format.
// Uses fmt.Fprintf directly to avoid intermediate string allocations.
func (s *Service) writeObservationsCSV(w http.ResponseWriter, observations []*models.Observation) {
// Write CSV header
_, _ = io.WriteString(w, "id,type,scope,project,title,subtitle,narrative,concepts,facts,created_at,importance_score\n")
for _, obs := range observations {
// Write directly to avoid string allocation per row
_, _ = fmt.Fprintf(w, "%d,%s,%s,%s,%s,%s,%s,%s,%s,%s,%.2f\n",
obs.ID,
obs.Type,
obs.Scope,
escapeCsvField(obs.Project),
escapeCsvField(obs.Title.String),
escapeCsvField(obs.Subtitle.String),
escapeCsvField(obs.Narrative.String),
escapeCsvField(strings.Join(obs.Concepts, ";")),
escapeCsvField(strings.Join(obs.Facts, ";")),
obs.CreatedAt,
obs.ImportanceScore,
)
}
}
// escapeCsvField escapes a field for CSV output.
func escapeCsvField(s string) string {
// If field contains comma, quote, or newline, wrap in quotes and escape quotes
if strings.ContainsAny(s, ",\"\n\r") {
s = strings.ReplaceAll(s, "\"", "\"\"")
return "\"" + s + "\""
}
return s
}
// BulkStatusRequest represents a request to update status for multiple observations.
type BulkStatusRequest struct {
Action string `json:"action"`
Reason string `json:"reason,omitempty"`
IDs []int64 `json:"ids"`
Feedback int `json:"feedback,omitempty"`
}
// handleBulkStatusUpdate updates status for multiple observations in one request.
func (s *Service) handleBulkStatusUpdate(w http.ResponseWriter, r *http.Request) {
var req BulkStatusRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
if len(req.IDs) == 0 {
http.Error(w, "ids is required", http.StatusBadRequest)
return
}
if len(req.IDs) > 500 {
http.Error(w, "maximum 500 ids per request", http.StatusBadRequest)
return
}
ctx := r.Context()
var updated, failed int
var errors []string
switch req.Action {
case "supersede":
for _, id := range req.IDs {
if err := s.observationStore.MarkAsSuperseded(ctx, id); err != nil {
failed++
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
} else {
updated++
}
}
case "archive":
for _, id := range req.IDs {
if err := s.observationStore.ArchiveObservation(ctx, id, req.Reason); err != nil {
failed++
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
} else {
updated++
}
}
case "set_feedback":
if req.Feedback < -1 || req.Feedback > 1 {
http.Error(w, "feedback must be -1, 0, or 1", http.StatusBadRequest)
return
}
for _, id := range req.IDs {
if err := s.observationStore.UpdateObservationFeedback(ctx, id, req.Feedback); err != nil {
failed++
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
} else {
updated++
}
}
default:
http.Error(w, "action must be 'supersede', 'archive', or 'set_feedback'", http.StatusBadRequest)
return
}
// Invalidate cache for archive action (affects observation counts)
if req.Action == "archive" && updated > 0 {
// No project info available, invalidate all caches
s.invalidateAllObsCountCache()
}
response := map[string]any{
"action": req.Action,
"updated": updated,
"failed": failed,
}
if len(errors) > 0 {
response["errors"] = errors
}
writeJSON(w, response)
}
// handleFindDuplicates finds potential duplicate observations using similarity clustering.
// Returns groups of similar observations that may be candidates for merging or archival.
func (s *Service) handleFindDuplicates(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
thresholdStr := r.URL.Query().Get("threshold")
limit := gorm.ParseLimitParam(r, 100)
// Parse threshold (default 0.6 = 60% similarity)
threshold := 0.6
if thresholdStr != "" {
if t, err := strconv.ParseFloat(thresholdStr, 64); err == nil && t > 0 && t < 1 {
threshold = t
}
}
// Get recent observations
ctx := r.Context()
var observations []*models.Observation
var err error
if project != "" {
observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0)
} else {
observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if len(observations) < 2 {
writeJSON(w, map[string]any{
"duplicate_groups": []any{},
"total_checked": len(observations),
"threshold": threshold,
})
return
}
// Find duplicates using similarity comparison
type duplicateGroup struct {
Observations []map[string]any `json:"observations"`
Similarity float64 `json:"similarity"`
}
groups := []duplicateGroup{}
processed := make(map[int64]bool)
for i, obs1 := range observations {
if processed[obs1.ID] {
continue
}
terms1 := similarity.ExtractObservationTerms(obs1)
if len(terms1) == 0 {
continue
}
group := duplicateGroup{
Observations: []map[string]any{obs1.ToMap()},
Similarity: 1.0,
}
for j := i + 1; j < len(observations); j++ {
obs2 := observations[j]
if processed[obs2.ID] {
continue
}
terms2 := similarity.ExtractObservationTerms(obs2)
sim := similarity.JaccardSimilarity(terms1, terms2)
if sim >= threshold {
obsMap := obs2.ToMap()
obsMap["similarity_to_first"] = sim
group.Observations = append(group.Observations, obsMap)
group.Similarity = min(group.Similarity, sim)
processed[obs2.ID] = true
}
}
if len(group.Observations) > 1 {
processed[obs1.ID] = true
groups = append(groups, group)
}
}
// Sort groups by size (largest first)
sort.Slice(groups, func(i, j int) bool {
return len(groups[i].Observations) > len(groups[j].Observations)
})
writeJSON(w, map[string]any{
"duplicate_groups": groups,
"total_checked": len(observations),
"groups_found": len(groups),
"threshold": threshold,
"project": project,
})
}
+9 -5
View File
@@ -10,6 +10,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// FeedbackRequest represents a user feedback submission.
@@ -311,8 +312,7 @@ func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Requ
// Run recalculation in background
go func() {
if err := recalculator.RecalculateNow(r.Context()); err != nil {
// Log error but don't block response
_ = err // Explicitly ignore - background operation
log.Warn().Err(err).Msg("Background score recalculation failed")
}
}()
@@ -345,14 +345,18 @@ func (s *Service) incrementRetrievalCounts(ids []int64) {
}
// Increment in background to not block response
// Use service context to respect shutdown signals
s.wg.Add(1)
go func() {
// Create a new context with timeout for the background operation
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer s.wg.Done()
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := store.IncrementRetrievalCount(ctx, ids); err != nil {
// Log but don't fail - this is a background operation
_ = err // Explicitly ignore - background operation
if s.ctx.Err() == nil { // Don't log during shutdown
log.Debug().Err(err).Msg("Failed to increment retrieval counts")
}
}
}()
}
+354
View File
@@ -0,0 +1,354 @@
// Package worker provides session-related HTTP handlers.
package worker
import (
"context"
"encoding/json"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// SessionInitRequest is the request body for session initialization.
type SessionInitRequest struct {
ClaudeSessionID string `json:"claudeSessionId"`
Project string `json:"project"`
Prompt string `json:"prompt"`
MatchedObservations int `json:"matchedObservations"`
}
// SessionInitResponse is the response for session initialization.
type SessionInitResponse struct {
Reason string `json:"reason,omitempty"`
SessionDBID int64 `json:"sessionDbId"`
PromptNumber int `json:"promptNumber"`
Skipped bool `json:"skipped,omitempty"`
}
// DuplicatePromptWindowSeconds is the time window for detecting duplicate prompt submissions.
// If the same prompt text is seen within this window, it's considered a duplicate hook invocation.
const DuplicatePromptWindowSeconds = 10
// handleSessionInit handles session initialization from user-prompt hook.
// This handler is idempotent - duplicate requests within a short time window
// return the existing prompt data without creating duplicates.
func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) {
var req SessionInitRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Privacy check
if privacy.IsEntirelyPrivate(req.Prompt) {
// Create session but skip processing
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
promptNum, _ := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
writeJSON(w, SessionInitResponse{
SessionDBID: sessionID,
PromptNumber: promptNum,
Skipped: true,
Reason: "private",
})
return
}
// Clean prompt
cleanedPrompt := privacy.Clean(req.Prompt)
// DUPLICATE DETECTION: Check if this exact prompt was already saved recently.
// This prevents the bug where the hook fires multiple times for the same user action,
// creating many duplicate prompts with incrementing numbers.
if existingID, existingNum, found := s.promptStore.FindRecentPromptByText(r.Context(), req.ClaudeSessionID, cleanedPrompt, DuplicatePromptWindowSeconds); found {
// Get or create session (idempotent)
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
log.Debug().
Int64("sessionId", sessionID).
Int("promptNumber", existingNum).
Int64("promptId", existingID).
Msg("Duplicate prompt detected - returning existing")
// Return existing prompt data without incrementing or saving again
writeJSON(w, SessionInitResponse{
SessionDBID: sessionID,
PromptNumber: existingNum,
})
return
}
// Create session (idempotent)
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Increment prompt counter
promptNum, err := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Save user prompt with matched observation count
promptID, err := s.promptStore.SaveUserPromptWithMatches(r.Context(), req.ClaudeSessionID, promptNum, cleanedPrompt, req.MatchedObservations)
if err != nil {
log.Warn().Err(err).Msg("Failed to save user prompt")
// Non-fatal: continue with session initialization
} else if s.vectorSync != nil {
// Sync to vector DB asynchronously (non-blocking)
now := time.Now()
promptWithSession := &models.UserPromptWithSession{
UserPrompt: models.UserPrompt{
ID: promptID,
ClaudeSessionID: req.ClaudeSessionID,
PromptNumber: promptNum,
PromptText: cleanedPrompt,
MatchedObservations: req.MatchedObservations,
CreatedAt: now.Format(time.RFC3339),
CreatedAtEpoch: now.UnixMilli(),
},
Project: req.Project,
SDKSessionID: req.ClaudeSessionID,
}
s.asyncVectorSync(func() {
// Use service context as parent to respect shutdown signals
ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second)
defer cancel()
if err := s.vectorSync.SyncUserPrompt(ctx, promptWithSession); err != nil {
if s.ctx.Err() == nil { // Don't log during shutdown
log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to sqlite-vec")
}
}
})
}
log.Info().
Int64("sessionId", sessionID).
Int("promptNumber", promptNum).
Str("project", req.Project).
Msg("Session initialized")
// Broadcast prompt event for dashboard refresh
s.sseBroadcaster.Broadcast(map[string]any{
"type": "prompt",
"action": "created",
"project": req.Project,
})
writeJSON(w, SessionInitResponse{
SessionDBID: sessionID,
PromptNumber: promptNum,
})
}
// SessionStartRequest is the request body for starting SDK agent.
type SessionStartRequest struct {
UserPrompt string `json:"userPrompt"`
PromptNumber int `json:"promptNumber"`
}
// handleSessionStart handles SDK agent session start.
func (s *Service) handleSessionStart(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid session id", http.StatusBadRequest)
return
}
var req SessionStartRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Initialize session in manager
sess, err := s.sessionManager.InitializeSession(r.Context(), id, req.UserPrompt, req.PromptNumber)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if sess == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
}
// Session is now registered. Observations will be processed
// asynchronously by the background queue processor (processQueue in service.go).
log.Info().
Int64("sessionId", id).
Int("promptNumber", req.PromptNumber).
Msg("SDK agent session initialized")
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// ObservationRequest is the request body for posting observations.
type ObservationRequest struct {
ClaudeSessionID string `json:"claudeSessionId"`
Project string `json:"project"`
ToolName string `json:"tool_name"`
ToolInput any `json:"tool_input"`
ToolResponse any `json:"tool_response"`
CWD string `json:"cwd"`
}
// handleObservation handles observation posting from post-tool-use hook.
func (s *Service) handleObservation(w http.ResponseWriter, r *http.Request) {
var req ObservationRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Find session
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if sess == nil {
// Create session on-the-fly with project from request
id, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
sess, _ = s.sessionStore.GetSessionByID(r.Context(), id)
}
// Queue observation
if err := s.sessionManager.QueueObservation(r.Context(), sess.ID, session.ObservationData{
ToolName: req.ToolName,
ToolInput: req.ToolInput,
ToolResponse: req.ToolResponse,
CWD: req.CWD,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// SubagentCompleteRequest is the request body for subagent completion.
type SubagentCompleteRequest struct {
ClaudeSessionID string `json:"claudeSessionId"`
Project string `json:"project"`
}
// handleSubagentComplete handles subagent/Task completion notifications.
// This triggers immediate processing of any queued observations from the subagent.
func (s *Service) handleSubagentComplete(w http.ResponseWriter, r *http.Request) {
var req SubagentCompleteRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Find session
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
if err != nil || sess == nil {
// Session not found - subagent may have been in a different context
log.Debug().
Str("claudeSessionId", req.ClaudeSessionID).
Msg("Subagent complete - no active session found")
w.WriteHeader(http.StatusOK)
return
}
// Trigger immediate processing of queued observations
messages := s.sessionManager.DrainMessages(sess.ID)
if len(messages) > 0 && s.processor != nil {
log.Info().
Int64("sessionId", sess.ID).
Int("messages", len(messages)).
Msg("Processing queued observations from subagent")
for _, msg := range messages {
if msg.Type == session.MessageTypeObservation && msg.Observation != nil {
err := s.processor.ProcessObservation(
r.Context(),
sess.SDKSessionID.String,
sess.Project,
msg.Observation.ToolName,
msg.Observation.ToolInput,
msg.Observation.ToolResponse,
msg.Observation.PromptNumber,
msg.Observation.CWD,
)
if err != nil {
log.Error().Err(err).
Str("tool", msg.Observation.ToolName).
Msg("Failed to process subagent observation")
}
}
}
}
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// handleGetSessionByClaudeID looks up a session by Claude session ID.
func (s *Service) handleGetSessionByClaudeID(w http.ResponseWriter, r *http.Request) {
claudeSessionID := r.URL.Query().Get("claudeSessionId")
if claudeSessionID == "" {
http.Error(w, "claudeSessionId required", http.StatusBadRequest)
return
}
session, err := s.sessionStore.FindAnySDKSession(r.Context(), claudeSessionID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if session == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
}
writeJSON(w, session)
}
// SummarizeRequest is the request body for summarize requests.
type SummarizeRequest struct {
LastUserMessage string `json:"lastUserMessage"`
LastAssistantMessage string `json:"lastAssistantMessage"`
}
// handleSummarize handles summarize requests from stop hook.
func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid session id", http.StatusBadRequest)
return
}
var req SummarizeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Queue summarize request
if err := s.sessionManager.QueueSummarize(r.Context(), id, req.LastUserMessage, req.LastAssistantMessage); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
+25 -12
View File
@@ -66,6 +66,7 @@ func testService(t *testing.T) (*Service, func()) {
cancel: cancel,
startTime: time.Now(),
retrievalStats: make(map[string]*RetrievalStats),
cachedObsCounts: make(map[string]cachedCount),
}
svc.setupRoutes()
@@ -345,11 +346,13 @@ func TestHandleGetObservations_Limit(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
// Parse as generic JSON array since the model uses custom marshaling
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
// Parse as object with observations key (API returns wrapped response)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
assert.Len(t, observations, 10)
}
@@ -1135,10 +1138,13 @@ func TestHandleGetObservations_DefaultLimit(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
// Should return default limit (100)
assert.LessOrEqual(t, len(observations), DefaultObservationsLimit)
}
@@ -1159,10 +1165,12 @@ func TestHandleGetObservations_FilterByProject(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
assert.Len(t, observations, 2)
}
@@ -1412,10 +1420,12 @@ func TestHandleGetObservations(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
assert.GreaterOrEqual(t, len(observations), 2)
}
@@ -2697,10 +2707,13 @@ func TestHandleGetObservations_EmptyResult(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
// Should return empty array, not null
var obs []interface{}
err := json.Unmarshal(rec.Body.Bytes(), &obs)
// Should return empty array within observations key, not null
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
obs, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
assert.NotNil(t, obs)
}
+243
View File
@@ -0,0 +1,243 @@
// Package worker provides update and restart HTTP handlers.
package worker
import (
"fmt"
"net/http"
"time"
"github.com/rs/zerolog/log"
)
// handleUpdateCheck checks for available updates.
func (s *Service) handleUpdateCheck(w http.ResponseWriter, r *http.Request) {
info, err := s.updater.CheckForUpdate(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, info)
}
// handleUpdateApply downloads and applies an available update.
func (s *Service) handleUpdateApply(w http.ResponseWriter, r *http.Request) {
// First check for update
info, err := s.updater.CheckForUpdate(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if !info.Available {
writeJSON(w, map[string]any{
"success": false,
"message": "No update available",
})
return
}
// Apply update in background with tracking for graceful shutdown
s.wg.Go(func() {
if err := s.updater.ApplyUpdate(s.ctx, info); err != nil {
log.Error().Err(err).Msg("Update failed")
}
})
writeJSON(w, map[string]any{
"success": true,
"message": "Update started",
"version": info.LatestVersion,
})
}
// handleUpdateStatus returns the current update status.
func (s *Service) handleUpdateStatus(w http.ResponseWriter, r *http.Request) {
status := s.updater.GetStatus()
writeJSON(w, status)
}
// ComponentHealth represents the health status of a single component.
type ComponentHealth struct {
Name string `json:"name"`
Status string `json:"status"` // "healthy", "degraded", "unhealthy"
Message string `json:"message,omitempty"`
}
// SelfCheckResponse contains the health status of all components.
type SelfCheckResponse struct {
Overall string `json:"overall"` // "healthy", "degraded", "unhealthy"
Version string `json:"version"`
Uptime string `json:"uptime"`
Components []ComponentHealth `json:"components"`
}
// handleSelfCheck returns the health status of all components.
func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) {
components := []ComponentHealth{}
overall := "healthy"
// Check Worker Service
workerStatus := ComponentHealth{Name: "Worker Service", Status: "healthy"}
if !s.ready.Load() {
if err := s.GetInitError(); err != nil {
workerStatus.Status = "unhealthy"
workerStatus.Message = err.Error()
overall = "unhealthy"
} else {
workerStatus.Status = "degraded"
workerStatus.Message = "Initializing"
if overall == "healthy" {
overall = "degraded"
}
}
}
components = append(components, workerStatus)
// Check SQLite Database
dbStatus := ComponentHealth{Name: "SQLite Database", Status: "healthy"}
if s.store == nil {
dbStatus.Status = "unhealthy"
dbStatus.Message = "Not initialized"
overall = "unhealthy"
} else if err := s.store.Ping(); err != nil {
dbStatus.Status = "unhealthy"
dbStatus.Message = err.Error()
overall = "unhealthy"
}
components = append(components, dbStatus)
// Check Vector DB (sqlite-vec)
vectorStatus := ComponentHealth{Name: "Vector DB", Status: "healthy"}
if s.vectorClient == nil {
vectorStatus.Status = "degraded"
vectorStatus.Message = "Not configured"
if overall == "healthy" {
overall = "degraded"
}
} else if !s.vectorClient.IsConnected() {
vectorStatus.Status = "degraded"
vectorStatus.Message = "Not connected"
if overall == "healthy" {
overall = "degraded"
}
}
components = append(components, vectorStatus)
// Check SDK Processor
sdkStatus := ComponentHealth{Name: "SDK Processor", Status: "healthy"}
if s.processor == nil {
sdkStatus.Status = "degraded"
sdkStatus.Message = "Not initialized"
if overall == "healthy" {
overall = "degraded"
}
} else if !s.processor.IsAvailable() {
sdkStatus.Status = "degraded"
sdkStatus.Message = "Claude CLI not available"
if overall == "healthy" {
overall = "degraded"
}
}
components = append(components, sdkStatus)
// Check SSE Broadcaster
sseStatus := ComponentHealth{Name: "SSE Broadcaster", Status: "healthy"}
if s.sseBroadcaster == nil {
sseStatus.Status = "unhealthy"
sseStatus.Message = "Not initialized"
overall = "unhealthy"
}
components = append(components, sseStatus)
// Check Cross-Encoder Reranker
rerankerStatus := ComponentHealth{Name: "Cross-Encoder Reranker", Status: "healthy"}
if !s.config.RerankingEnabled {
rerankerStatus.Status = "degraded"
rerankerStatus.Message = "Disabled in config"
if overall == "healthy" {
overall = "degraded"
}
} else if s.reranker == nil {
rerankerStatus.Status = "degraded"
rerankerStatus.Message = "Not initialized"
if overall == "healthy" {
overall = "degraded"
}
} else {
// Verify reranker is functional using Score
_, normalizedScore, err := s.reranker.Score("test query", "test document")
if err != nil {
rerankerStatus.Status = "unhealthy"
rerankerStatus.Message = fmt.Sprintf("Score check failed: %v", err)
if overall == "healthy" {
overall = "degraded"
}
} else {
rerankerStatus.Message = fmt.Sprintf("Score check passed (%.4f)", normalizedScore)
}
}
components = append(components, rerankerStatus)
// Calculate uptime
uptime := time.Since(s.startTime).Round(time.Second).String()
writeJSON(w, SelfCheckResponse{
Overall: overall,
Version: s.version,
Uptime: uptime,
Components: components,
})
}
// handleUpdateRestart restarts the worker with the new binary (after update).
func (s *Service) handleUpdateRestart(w http.ResponseWriter, r *http.Request) {
status := s.updater.GetStatus()
if status.State != "done" {
http.Error(w, "no update has been applied", http.StatusBadRequest)
return
}
// Send response before restarting
writeJSON(w, map[string]any{
"success": true,
"message": "Restarting worker...",
})
// Flush the response
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// Restart in background after response is sent
go func() {
if err := s.updater.Restart(); err != nil {
log.Error().Err(err).Msg("Failed to restart worker")
}
}()
}
// handleRestart restarts the worker process (general restart, not tied to update).
func (s *Service) handleRestart(w http.ResponseWriter, r *http.Request) {
log.Info().Msg("Manual restart requested via API")
// Send response before restarting
writeJSON(w, map[string]any{
"success": true,
"message": "Restarting worker...",
"version": s.version,
})
// Flush the response
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// Restart in background after response is sent
go func() {
// Small delay to ensure response is sent
time.Sleep(100 * time.Millisecond)
if err := s.updater.Restart(); err != nil {
log.Error().Err(err).Msg("Failed to restart worker")
}
}()
}
+333
View File
@@ -0,0 +1,333 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
"regexp"
"strings"
"sync"
"time"
)
// requestIDKey is the context key for request IDs.
type requestIDKey struct{}
// projectNamePattern validates project names to prevent path traversal.
var projectNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_./-]+$`)
// allowedOrigins is the whitelist of origins allowed for CORS.
// Uses exact matching to prevent bypass attacks like "evil-localhost.com".
var allowedOrigins = map[string]bool{
"http://localhost": true,
"http://localhost:3000": true,
"http://localhost:5173": true, // Vite dev server
"http://localhost:37778": true, // Dashboard UI
"http://127.0.0.1": true,
"http://127.0.0.1:3000": true,
"http://127.0.0.1:5173": true,
"http://127.0.0.1:37778": true,
}
// SecurityHeaders middleware adds essential security headers to all responses.
// These protect against common web vulnerabilities.
func SecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Prevent clickjacking
w.Header().Set("X-Frame-Options", "DENY")
// Prevent MIME type sniffing
w.Header().Set("X-Content-Type-Options", "nosniff")
// Enable XSS filter
w.Header().Set("X-XSS-Protection", "1; mode=block")
// Restrict referrer information
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Content Security Policy - restrict to self
w.Header().Set("Content-Security-Policy", "default-src 'self'")
// Permissions Policy - disable unnecessary features
w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
// CORS: Use exact match whitelist to prevent bypass attacks
origin := r.Header.Get("Origin")
if allowedOrigins[origin] {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Auth-Token, Authorization, X-Request-ID")
}
// Handle preflight requests
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}
// MaxBodySize middleware limits the size of incoming request bodies.
// This prevents denial of service attacks via large payloads.
func MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ContentLength > maxBytes {
http.Error(w, "request body too large", http.StatusRequestEntityTooLarge)
return
}
r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
next.ServeHTTP(w, r)
})
}
}
// TokenAuth provides simple token-based authentication for localhost services.
// The token is generated at startup and must be provided in the X-Auth-Token header.
type TokenAuth struct {
ExemptPaths map[string]bool
token string
mu sync.RWMutex
enabled bool
}
// NewTokenAuth creates a new TokenAuth with a randomly generated token.
// If enabled is false, authentication is skipped (useful for development).
func NewTokenAuth(enabled bool) (*TokenAuth, error) {
ta := &TokenAuth{
enabled: enabled,
ExemptPaths: map[string]bool{
"/health": true,
"/api/health": true,
"/api/ready": true,
},
}
if enabled {
// Generate 32-byte random token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return nil, err
}
ta.token = hex.EncodeToString(tokenBytes)
}
return ta, nil
}
// Token returns the authentication token.
// Returns empty string if authentication is disabled.
func (ta *TokenAuth) Token() string {
ta.mu.RLock()
defer ta.mu.RUnlock()
return ta.token
}
// IsEnabled returns whether token authentication is enabled.
func (ta *TokenAuth) IsEnabled() bool {
ta.mu.RLock()
defer ta.mu.RUnlock()
return ta.enabled
}
// Middleware returns HTTP middleware that enforces token authentication.
func (ta *TokenAuth) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ta.mu.RLock()
enabled := ta.enabled
token := ta.token
exempt := ta.ExemptPaths[r.URL.Path]
ta.mu.RUnlock()
// Skip auth if disabled or path is exempt
if !enabled || exempt {
next.ServeHTTP(w, r)
return
}
// Check for token in header
providedToken := r.Header.Get("X-Auth-Token")
if providedToken == "" {
// Also check Authorization header with Bearer scheme
auth := r.Header.Get("Authorization")
if bearer, found := strings.CutPrefix(auth, "Bearer "); found {
providedToken = bearer
}
}
if providedToken != token {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
// ExpensiveOperationLimiter provides stricter rate limiting for expensive operations.
// It wraps the base per-client rate limiter with additional per-operation limits.
type ExpensiveOperationLimiter struct {
// Track last execution time per operation type
lastRebuild int64 // Unix timestamp
rebuildCooldown int64 // Minimum seconds between rebuilds
mu sync.Mutex
}
// NewExpensiveOperationLimiter creates a limiter for expensive operations.
func NewExpensiveOperationLimiter() *ExpensiveOperationLimiter {
return &ExpensiveOperationLimiter{
rebuildCooldown: 300, // 5 minutes between rebuilds
}
}
// CanRebuild checks if a vector rebuild operation is allowed.
// Returns false if a rebuild was triggered too recently.
func (eol *ExpensiveOperationLimiter) CanRebuild() bool {
eol.mu.Lock()
defer eol.mu.Unlock()
now := unixNow()
if now-eol.lastRebuild < eol.rebuildCooldown {
return false
}
eol.lastRebuild = now
return true
}
// unixNow returns current Unix timestamp.
// Separated for easier testing.
func unixNow() int64 {
return time.Now().Unix()
}
// RequestID middleware adds a unique request ID to each request.
// The ID is added to the context and response headers for tracing.
func RequestID(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check for existing request ID from client
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
// Generate new request ID
idBytes := make([]byte, 8)
if _, err := rand.Read(idBytes); err == nil {
requestID = hex.EncodeToString(idBytes)
} else {
requestID = fmt.Sprintf("%d", time.Now().UnixNano())
}
}
// Add to response header
w.Header().Set("X-Request-ID", requestID)
// Add to context
ctx := context.WithValue(r.Context(), requestIDKey{}, requestID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetRequestID retrieves the request ID from the context.
func GetRequestID(ctx context.Context) string {
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
return id
}
return ""
}
// RequireJSONContentType middleware validates that POST/PUT/PATCH requests
// have application/json Content-Type header.
func RequireJSONContentType(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only check for methods that typically have bodies
if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" {
ct := r.Header.Get("Content-Type")
// Allow empty Content-Type for requests without body
if ct != "" && !strings.HasPrefix(ct, "application/json") {
http.Error(w, "Content-Type must be application/json", http.StatusUnsupportedMediaType)
return
}
}
next.ServeHTTP(w, r)
})
}
// ValidateProjectName checks if a project name is safe to use.
// Returns an error if the name contains path traversal or invalid characters.
func ValidateProjectName(project string) error {
if project == "" {
return nil // Empty is allowed (means no filter)
}
// Check for path traversal
if strings.Contains(project, "..") {
return fmt.Errorf("invalid project name: path traversal detected")
}
// Check for valid characters
if !projectNamePattern.MatchString(project) {
return fmt.Errorf("invalid project name: only alphanumeric, underscore, dash, dot, and slash allowed")
}
// Max length check
if len(project) > 500 {
return fmt.Errorf("project name too long (max 500 chars)")
}
return nil
}
// BulkOperationLimiter provides rate limiting for bulk operations.
// Prevents DoS via repeated bulk requests.
type BulkOperationLimiter struct {
lastBulkOp int64 // Unix timestamp
cooldown int64 // Minimum seconds between operations
mu sync.Mutex
}
// NewBulkOperationLimiter creates a limiter for bulk operations.
func NewBulkOperationLimiter(cooldownSeconds int64) *BulkOperationLimiter {
return &BulkOperationLimiter{
cooldown: cooldownSeconds,
}
}
// CanExecute checks if a bulk operation is allowed.
// Returns false if a bulk operation was triggered too recently.
func (bol *BulkOperationLimiter) CanExecute() bool {
bol.mu.Lock()
defer bol.mu.Unlock()
now := unixNow()
if now-bol.lastBulkOp < bol.cooldown {
return false
}
bol.lastBulkOp = now
return true
}
// TimeSinceLastOp returns seconds since the last bulk operation.
func (bol *BulkOperationLimiter) TimeSinceLastOp() int64 {
bol.mu.Lock()
defer bol.mu.Unlock()
return unixNow() - bol.lastBulkOp
}
// CooldownRemaining returns seconds remaining in the cooldown period.
// Returns 0 if no cooldown is active.
func (bol *BulkOperationLimiter) CooldownRemaining() int64 {
bol.mu.Lock()
defer bol.mu.Unlock()
remaining := bol.cooldown - (unixNow() - bol.lastBulkOp)
if remaining < 0 {
return 0
}
return remaining
}
+515
View File
@@ -0,0 +1,515 @@
package worker
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestSecurityHeaders(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Check all security headers are set
tests := []struct {
header string
expected string
}{
{"X-Frame-Options", "DENY"},
{"X-Content-Type-Options", "nosniff"},
{"X-XSS-Protection", "1; mode=block"},
{"Referrer-Policy", "strict-origin-when-cross-origin"},
}
for _, tt := range tests {
if got := rr.Header().Get(tt.header); got != tt.expected {
t.Errorf("SecurityHeaders() %s = %q, want %q", tt.header, got, tt.expected)
}
}
}
func TestSecurityHeaders_CORS(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
origin string
expectedOrigin string
expectCORS bool
}{
{
name: "localhost:37778 origin allowed",
origin: "http://localhost:37778",
expectCORS: true,
expectedOrigin: "http://localhost:37778",
},
{
name: "127.0.0.1:5173 origin allowed",
origin: "http://127.0.0.1:5173",
expectCORS: true,
expectedOrigin: "http://127.0.0.1:5173",
},
{
name: "localhost without port allowed",
origin: "http://localhost",
expectCORS: true,
expectedOrigin: "http://localhost",
},
{
name: "external origin blocked",
origin: "http://evil.com",
expectCORS: false,
},
{
name: "evil-localhost.com bypass attempt blocked",
origin: "http://evil-localhost.com",
expectCORS: false,
},
{
name: "localhost subdomain bypass attempt blocked",
origin: "http://localhost.evil.com",
expectCORS: false,
},
{
name: "unknown localhost port blocked",
origin: "http://localhost:9999",
expectCORS: false,
},
{
name: "no origin header",
origin: "",
expectCORS: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
cors := rr.Header().Get("Access-Control-Allow-Origin")
if tt.expectCORS {
if cors != tt.expectedOrigin {
t.Errorf("Expected CORS origin %q, got %q", tt.expectedOrigin, cors)
}
} else {
if cors != "" {
t.Errorf("Expected no CORS header, got %q", cors)
}
}
})
}
}
func TestMaxBodySize(t *testing.T) {
maxSize := int64(100)
handler := MaxBodySize(maxSize)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
contentLength int64
expectedStatus int
}{
{
name: "within limit",
contentLength: 50,
expectedStatus: http.StatusOK,
},
{
name: "at limit",
contentLength: 100,
expectedStatus: http.StatusOK,
},
{
name: "exceeds limit",
contentLength: 150,
expectedStatus: http.StatusRequestEntityTooLarge,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/test", nil)
req.ContentLength = tt.contentLength
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != tt.expectedStatus {
t.Errorf("MaxBodySize() status = %d, want %d", rr.Code, tt.expectedStatus)
}
})
}
}
func TestTokenAuth(t *testing.T) {
t.Run("disabled auth allows all requests", func(t *testing.T) {
ta, err := NewTokenAuth(false)
if err != nil {
t.Fatalf("NewTokenAuth() error = %v", err)
}
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK with disabled auth, got %d", rr.Code)
}
})
t.Run("enabled auth requires token", func(t *testing.T) {
ta, err := NewTokenAuth(true)
if err != nil {
t.Fatalf("NewTokenAuth() error = %v", err)
}
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Without token
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Errorf("Expected Unauthorized without token, got %d", rr.Code)
}
// With correct token in X-Auth-Token header
req = httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Auth-Token", ta.Token())
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK with correct token, got %d", rr.Code)
}
// With correct token in Authorization header
req = httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+ta.Token())
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK with Bearer token, got %d", rr.Code)
}
})
t.Run("exempt paths skip auth", func(t *testing.T) {
ta, err := NewTokenAuth(true)
if err != nil {
t.Fatalf("NewTokenAuth() error = %v", err)
}
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
exemptPaths := []string{"/health", "/api/health", "/api/ready"}
for _, path := range exemptPaths {
req := httptest.NewRequest("GET", path, nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK for exempt path %s, got %d", path, rr.Code)
}
}
})
}
func TestExpensiveOperationLimiter(t *testing.T) {
limiter := NewExpensiveOperationLimiter()
// First rebuild should be allowed
if !limiter.CanRebuild() {
t.Error("First rebuild should be allowed")
}
// Immediate second rebuild should be blocked
if limiter.CanRebuild() {
t.Error("Immediate second rebuild should be blocked")
}
}
func TestRequestID(t *testing.T) {
handler := RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request ID is in context
id := GetRequestID(r.Context())
if id == "" {
t.Error("Request ID should be set in context")
}
w.WriteHeader(http.StatusOK)
}))
t.Run("generates new request ID", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-ID") == "" {
t.Error("X-Request-ID header should be set")
}
})
t.Run("uses existing request ID", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Request-ID", "test-id-12345")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-ID") != "test-id-12345" {
t.Errorf("Expected X-Request-ID to be test-id-12345, got %s", rr.Header().Get("X-Request-ID"))
}
})
}
func TestRequireJSONContentType(t *testing.T) {
handler := RequireJSONContentType(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
method string
contentType string
expectedStatus int
}{
{
name: "GET request without content-type",
method: "GET",
contentType: "",
expectedStatus: http.StatusOK,
},
{
name: "POST with application/json",
method: "POST",
contentType: "application/json",
expectedStatus: http.StatusOK,
},
{
name: "POST with application/json; charset=utf-8",
method: "POST",
contentType: "application/json; charset=utf-8",
expectedStatus: http.StatusOK,
},
{
name: "POST without content-type (empty body)",
method: "POST",
contentType: "",
expectedStatus: http.StatusOK,
},
{
name: "POST with text/plain rejected",
method: "POST",
contentType: "text/plain",
expectedStatus: http.StatusUnsupportedMediaType,
},
{
name: "PUT with application/xml rejected",
method: "PUT",
contentType: "application/xml",
expectedStatus: http.StatusUnsupportedMediaType,
},
{
name: "PATCH with form-urlencoded rejected",
method: "PATCH",
contentType: "application/x-www-form-urlencoded",
expectedStatus: http.StatusUnsupportedMediaType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/test", nil)
if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code)
}
})
}
}
func TestValidateProjectName(t *testing.T) {
tests := []struct {
name string
project string
wantError bool
}{
{
name: "empty project allowed",
project: "",
wantError: false,
},
{
name: "simple project name",
project: "my-project",
wantError: false,
},
{
name: "project with path",
project: "org/my-project",
wantError: false,
},
{
name: "project with underscore",
project: "my_project_v2",
wantError: false,
},
{
name: "project with dot",
project: "my.project.name",
wantError: false,
},
{
name: "path traversal attack",
project: "../../../etc/passwd",
wantError: true,
},
{
name: "hidden path traversal",
project: "project/../../secret",
wantError: true,
},
{
name: "shell injection attempt",
project: "project; rm -rf /",
wantError: true,
},
{
name: "backtick injection",
project: "project`whoami`",
wantError: true,
},
{
name: "special characters",
project: "project$HOME",
wantError: true,
},
{
name: "too long project name",
project: string(make([]byte, 501)),
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateProjectName(tt.project)
if tt.wantError && err == nil {
t.Errorf("Expected error for project %q, got nil", tt.project)
}
if !tt.wantError && err != nil {
t.Errorf("Unexpected error for project %q: %v", tt.project, err)
}
})
}
}
func TestBulkOperationLimiter(t *testing.T) {
limiter := NewBulkOperationLimiter(1) // 1 second cooldown for testing
// First operation should be allowed
if !limiter.CanExecute() {
t.Error("First bulk operation should be allowed")
}
// Immediate second operation should be blocked
if limiter.CanExecute() {
t.Error("Immediate second bulk operation should be blocked")
}
// Check cooldown remaining
remaining := limiter.CooldownRemaining()
if remaining <= 0 || remaining > 1 {
t.Errorf("Expected cooldown remaining between 0-1 seconds, got %d", remaining)
}
// Check time since last op
since := limiter.TimeSinceLastOp()
if since < 0 || since > 1 {
t.Errorf("Expected time since last op between 0-1 seconds, got %d", since)
}
}
func TestSecurityHeaders_CSP(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Check CSP header is set
csp := rr.Header().Get("Content-Security-Policy")
if csp == "" {
t.Error("Content-Security-Policy header should be set")
}
if csp != "default-src 'self'" {
t.Errorf("Expected CSP to be \"default-src 'self'\", got %q", csp)
}
// Check Permissions-Policy header
pp := rr.Header().Get("Permissions-Policy")
if pp == "" {
t.Error("Permissions-Policy header should be set")
}
}
func TestSecurityHeaders_Preflight(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("OPTIONS", "/test", nil)
req.Header.Set("Origin", "http://localhost:3000")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// OPTIONS should return 204 No Content
if rr.Code != http.StatusNoContent {
t.Errorf("Expected status 204 for OPTIONS, got %d", rr.Code)
}
// CORS headers should be set for allowed origin
if rr.Header().Get("Access-Control-Allow-Origin") != "http://localhost:3000" {
t.Errorf("CORS origin should be set for allowed origin")
}
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
t.Error("Access-Control-Allow-Methods should be set")
}
}
+226
View File
@@ -0,0 +1,226 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"net/http"
"sync"
"time"
)
// RateLimiter implements a token bucket rate limiter.
type RateLimiter struct {
lastUpdate time.Time
rate float64
burst int
tokens float64
requests int64
rejected int64
mu sync.Mutex
}
// LastUpdateTime returns the last update time.
// Thread-safe - acquires the limiter's lock.
func (rl *RateLimiter) LastUpdateTime() time.Time {
rl.mu.Lock()
defer rl.mu.Unlock()
return rl.lastUpdate
}
// lastUpdateTimeUnlocked returns the last update time without locking.
// Caller must hold rl.mu.
func (rl *RateLimiter) lastUpdateTimeUnlocked() time.Time {
return rl.lastUpdate
}
// NewRateLimiter creates a new rate limiter.
// rate is the number of requests per second to allow.
// burst is the maximum burst of requests to allow.
func NewRateLimiter(rate float64, burst int) *RateLimiter {
return &RateLimiter{
rate: rate,
burst: burst,
tokens: float64(burst),
lastUpdate: time.Now(),
}
}
// Allow checks if a request should be allowed.
// Returns true if the request is allowed, false if rate limited.
func (rl *RateLimiter) Allow() bool {
rl.mu.Lock()
defer rl.mu.Unlock()
rl.requests++
// Calculate tokens added since last update
now := time.Now()
elapsed := now.Sub(rl.lastUpdate).Seconds()
rl.tokens += elapsed * rl.rate
if rl.tokens > float64(rl.burst) {
rl.tokens = float64(rl.burst)
}
rl.lastUpdate = now
// Check if we have a token available
if rl.tokens >= 1 {
rl.tokens--
return true
}
rl.rejected++
return false
}
// Stats returns rate limiter statistics.
func (rl *RateLimiter) Stats() map[string]any {
rl.mu.Lock()
defer rl.mu.Unlock()
return map[string]any{
"rate": rl.rate,
"burst": rl.burst,
"current_tokens": rl.tokens,
"total_requests": rl.requests,
"rejected": rl.rejected,
"rejection_rate": float64(rl.rejected) / max(float64(rl.requests), 1),
}
}
// RateLimitMiddleware creates middleware that applies rate limiting.
// Uses a shared rate limiter for all requests.
func RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !limiter.Allow() {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// PerClientRateLimiter implements per-client rate limiting.
type PerClientRateLimiter struct {
lastCleanup time.Time
clients map[string]*RateLimiter
rate float64
burst int
cleanupInterval time.Duration
maxIdleTime time.Duration
mu sync.Mutex
}
// NewPerClientRateLimiter creates a new per-client rate limiter.
func NewPerClientRateLimiter(rate float64, burst int) *PerClientRateLimiter {
return &PerClientRateLimiter{
rate: rate,
burst: burst,
clients: make(map[string]*RateLimiter),
cleanupInterval: 5 * time.Minute,
maxIdleTime: 10 * time.Minute,
lastCleanup: time.Now(),
}
}
// getLimiter returns a rate limiter for the given client key.
func (pcrl *PerClientRateLimiter) getLimiter(key string) *RateLimiter {
pcrl.mu.Lock()
defer pcrl.mu.Unlock()
// Periodic cleanup of idle clients
if time.Since(pcrl.lastCleanup) > pcrl.cleanupInterval {
pcrl.cleanupLocked()
}
limiter, exists := pcrl.clients[key]
if !exists {
limiter = NewRateLimiter(pcrl.rate, pcrl.burst)
pcrl.clients[key] = limiter
}
return limiter
}
// cleanupLocked removes idle limiters. Must be called with lock held.
// Uses consistent lock ordering: always acquire limiter.mu while holding pcrl.mu.
// This is safe because the limiter.mu critical section is brief (just reading lastUpdate).
func (pcrl *PerClientRateLimiter) cleanupLocked() {
now := time.Now()
keysToDelete := make([]string, 0)
// Check each limiter while holding pcrl.mu
// We briefly acquire limiter.mu but the critical section is minimal
for key, limiter := range pcrl.clients {
limiter.mu.Lock()
lastUpdate := limiter.lastUpdateTimeUnlocked()
limiter.mu.Unlock()
if now.Sub(lastUpdate) > pcrl.maxIdleTime {
keysToDelete = append(keysToDelete, key)
}
}
// Delete collected keys
for _, key := range keysToDelete {
delete(pcrl.clients, key)
}
pcrl.lastCleanup = now
}
// Allow checks if a request from the given client should be allowed.
func (pcrl *PerClientRateLimiter) Allow(clientKey string) bool {
return pcrl.getLimiter(clientKey).Allow()
}
// Stats returns aggregate statistics.
// Uses two-phase approach to avoid nested lock acquisition.
func (pcrl *PerClientRateLimiter) Stats() map[string]any {
// Phase 1: Collect limiters under pcrl.mu
pcrl.mu.Lock()
rate := pcrl.rate
burst := pcrl.burst
activeClients := len(pcrl.clients)
limiters := make([]*RateLimiter, 0, activeClients)
for _, limiter := range pcrl.clients {
limiters = append(limiters, limiter)
}
pcrl.mu.Unlock()
// Phase 2: Collect stats from each limiter (only acquiring limiter.mu, not pcrl.mu)
var totalRequests, totalRejected int64
for _, limiter := range limiters {
limiter.mu.Lock()
totalRequests += limiter.requests
totalRejected += limiter.rejected
limiter.mu.Unlock()
}
return map[string]any{
"rate": rate,
"burst": burst,
"active_clients": activeClients,
"total_requests": totalRequests,
"total_rejected": totalRejected,
}
}
// PerClientRateLimitMiddleware creates middleware that applies per-client rate limiting.
// Uses X-Forwarded-For or RemoteAddr to identify clients.
func PerClientRateLimitMiddleware(limiter *PerClientRateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get client identifier (prefer X-Real-IP from RealIP middleware)
clientKey := r.RemoteAddr
if xff := r.Header.Get("X-Real-IP"); xff != "" {
clientKey = xff
}
if !limiter.Allow(clientKey) {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
+428 -43
View File
@@ -4,11 +4,15 @@ package sdk
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
@@ -20,8 +24,178 @@ import (
"github.com/rs/zerolog/log"
)
// CircuitBreaker implements a simple circuit breaker pattern for CLI calls.
type CircuitBreaker struct {
failures int64 // Current failure count
lastFailure int64 // Unix timestamp of last failure
threshold int64 // Number of failures before opening
resetTimeout int64 // Seconds to wait before trying again
state int32 // 0=closed, 1=open, 2=half-open
}
const (
circuitClosed int32 = 0
circuitOpen int32 = 1
circuitHalfOpen int32 = 2
)
// NewCircuitBreaker creates a new circuit breaker.
func NewCircuitBreaker(threshold int64, resetTimeout int64) *CircuitBreaker {
return &CircuitBreaker{
threshold: threshold,
resetTimeout: resetTimeout,
}
}
// Allow checks if a request should be allowed through.
func (cb *CircuitBreaker) Allow() bool {
state := atomic.LoadInt32(&cb.state)
if state == circuitClosed {
return true
}
if state == circuitOpen {
// Check if reset timeout has passed
lastFail := atomic.LoadInt64(&cb.lastFailure)
if time.Now().Unix()-lastFail > cb.resetTimeout {
// Transition to half-open
atomic.CompareAndSwapInt32(&cb.state, circuitOpen, circuitHalfOpen)
return true
}
return false
}
// Half-open: allow one request through
return true
}
// RecordSuccess records a successful call.
func (cb *CircuitBreaker) RecordSuccess() {
atomic.StoreInt64(&cb.failures, 0)
atomic.StoreInt32(&cb.state, circuitClosed)
}
// RecordFailure records a failed call.
func (cb *CircuitBreaker) RecordFailure() {
failures := atomic.AddInt64(&cb.failures, 1)
atomic.StoreInt64(&cb.lastFailure, time.Now().Unix())
if failures >= cb.threshold {
atomic.StoreInt32(&cb.state, circuitOpen)
log.Warn().Int64("failures", failures).Msg("Circuit breaker opened - Claude CLI calls temporarily disabled")
}
}
// State returns the current state as a string.
func (cb *CircuitBreaker) State() string {
switch atomic.LoadInt32(&cb.state) {
case circuitOpen:
return "open"
case circuitHalfOpen:
return "half-open"
default:
return "closed"
}
}
// CircuitBreakerMetrics contains metrics about the circuit breaker state.
type CircuitBreakerMetrics struct {
State string `json:"state"`
Failures int64 `json:"failures"`
Threshold int64 `json:"threshold"`
ResetTimeoutSecs int64 `json:"reset_timeout_secs"`
LastFailureUnix int64 `json:"last_failure_unix,omitempty"`
SecondsUntilReset int64 `json:"seconds_until_reset,omitempty"`
}
// Metrics returns the current metrics of the circuit breaker.
func (cb *CircuitBreaker) Metrics() CircuitBreakerMetrics {
failures := atomic.LoadInt64(&cb.failures)
lastFail := atomic.LoadInt64(&cb.lastFailure)
state := cb.State()
metrics := CircuitBreakerMetrics{
State: state,
Failures: failures,
Threshold: cb.threshold,
ResetTimeoutSecs: cb.resetTimeout,
}
if lastFail > 0 {
metrics.LastFailureUnix = lastFail
if state == "open" {
remaining := cb.resetTimeout - (time.Now().Unix() - lastFail)
if remaining > 0 {
metrics.SecondsUntilReset = remaining
}
}
}
return metrics
}
// RequestDeduplicator tracks recent requests to prevent duplicates.
type RequestDeduplicator struct {
seen map[string]int64 // hash -> timestamp
mu sync.RWMutex
ttlSecs int64
maxSize int
}
// NewRequestDeduplicator creates a new deduplicator.
func NewRequestDeduplicator(ttlSecs int64, maxSize int) *RequestDeduplicator {
return &RequestDeduplicator{
seen: make(map[string]int64),
ttlSecs: ttlSecs,
maxSize: maxSize,
}
}
// IsDuplicate checks if a request hash was seen recently.
func (d *RequestDeduplicator) IsDuplicate(hash string) bool {
now := time.Now().Unix()
d.mu.RLock()
ts, exists := d.seen[hash]
d.mu.RUnlock()
if exists && now-ts < d.ttlSecs {
return true
}
return false
}
// Record marks a request hash as seen.
func (d *RequestDeduplicator) Record(hash string) {
now := time.Now().Unix()
d.mu.Lock()
defer d.mu.Unlock()
// Evict old entries if at capacity
if len(d.seen) >= d.maxSize {
threshold := now - d.ttlSecs
for k, ts := range d.seen {
if ts < threshold {
delete(d.seen, k)
}
}
}
d.seen[hash] = now
}
// hashRequest creates a hash of a request for deduplication.
func hashRequest(toolName, input, output string) string {
h := sha256.New()
h.Write([]byte(toolName))
h.Write([]byte(input))
h.Write([]byte(output[:min(len(output), 1000)])) // Only hash first 1000 chars of output
return hex.EncodeToString(h.Sum(nil))[:16] // Short hash is sufficient
}
// BroadcastFunc is a callback for broadcasting events to SSE clients.
type BroadcastFunc func(event map[string]interface{})
type BroadcastFunc func(event map[string]any)
// SyncObservationFunc is a callback for syncing observations to vector DB.
type SyncObservationFunc func(obs *models.Observation)
@@ -29,16 +203,26 @@ type SyncObservationFunc func(obs *models.Observation)
// SyncSummaryFunc is a callback for syncing summaries to vector DB.
type SyncSummaryFunc func(summary *models.SessionSummary)
// MaxVectorSyncWorkers is the maximum number of concurrent vector sync operations.
// This prevents unbounded goroutine spawning during high-volume observation ingestion.
const MaxVectorSyncWorkers = 8
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
// Field order optimized for memory alignment (fieldalignment).
type Processor struct {
observationStore *gorm.ObservationStore
summaryStore *gorm.SummaryStore
broadcastFunc BroadcastFunc
syncObservationFunc SyncObservationFunc
syncSummaryFunc SyncSummaryFunc
circuitBreaker *CircuitBreaker
deduplicator *RequestDeduplicator
vectorSyncChan chan *models.Observation
vectorSyncDone chan struct{}
sem chan struct{}
claudePath string
model string
vectorSyncWg sync.WaitGroup
}
// SetBroadcastFunc sets the broadcast callback for SSE events.
@@ -57,7 +241,7 @@ func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) {
}
// broadcast sends an event via the broadcast callback if set.
func (p *Processor) broadcast(event map[string]interface{}) {
func (p *Processor) broadcast(event map[string]any) {
if p.broadcastFunc != nil {
p.broadcastFunc(event)
}
@@ -93,9 +277,65 @@ func NewProcessor(observationStore *gorm.ObservationStore, summaryStore *gorm.Su
observationStore: observationStore,
summaryStore: summaryStore,
sem: make(chan struct{}, MaxConcurrentCLICalls),
circuitBreaker: NewCircuitBreaker(5, 60), // Open after 5 failures, reset after 60s
deduplicator: NewRequestDeduplicator(300, 1000), // 5-minute TTL, 1000 max entries
vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2), // Buffered channel
vectorSyncDone: make(chan struct{}),
}, nil
}
// StartVectorSyncWorkers starts the bounded worker pool for vector sync operations.
// Call this after setting the sync function via SetSyncObservationFunc.
func (p *Processor) StartVectorSyncWorkers() {
for i := 0; i < MaxVectorSyncWorkers; i++ {
p.vectorSyncWg.Add(1)
go p.vectorSyncWorker()
}
log.Info().Int("workers", MaxVectorSyncWorkers).Msg("Vector sync worker pool started")
}
// StopVectorSyncWorkers gracefully stops the worker pool.
func (p *Processor) StopVectorSyncWorkers() {
close(p.vectorSyncDone)
p.vectorSyncWg.Wait()
log.Info().Msg("Vector sync worker pool stopped")
}
// vectorSyncWorker is a worker goroutine that processes vector sync requests.
func (p *Processor) vectorSyncWorker() {
defer p.vectorSyncWg.Done()
for {
select {
case <-p.vectorSyncDone:
// Drain remaining items before exiting
for {
select {
case obs := <-p.vectorSyncChan:
if p.syncObservationFunc != nil {
p.syncObservationFunc(obs)
}
default:
return
}
}
case obs := <-p.vectorSyncChan:
if p.syncObservationFunc != nil {
p.syncObservationFunc(obs)
}
}
}
}
// CircuitBreakerState returns the current state of the circuit breaker.
func (p *Processor) CircuitBreakerState() string {
return p.circuitBreaker.State()
}
// CircuitBreakerMetrics returns detailed metrics about the circuit breaker.
func (p *Processor) CircuitBreakerMetrics() CircuitBreakerMetrics {
return p.circuitBreaker.Metrics()
}
// IsAvailable checks if the Claude CLI is available for processing.
func (p *Processor) IsAvailable() bool {
_, err := os.Stat(p.claudePath)
@@ -103,7 +343,7 @@ func (p *Processor) IsAvailable() bool {
}
// ProcessObservation processes a single tool observation and extracts insights.
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error {
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse any, promptNumber int, cwd string) error {
// Skip certain tools that aren't worth processing
if shouldSkipTool(toolName) {
log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)")
@@ -120,11 +360,23 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
return nil
}
// Check for duplicate request within TTL window
reqHash := hashRequest(toolName, inputStr, outputStr)
if p.deduplicator.IsDuplicate(reqHash) {
log.Debug().Str("tool", toolName).Msg("Skipping duplicate request (dedup)")
return nil
}
// Check circuit breaker before making CLI call
if !p.circuitBreaker.Allow() {
log.Warn().Str("tool", toolName).Msg("Circuit breaker open - skipping CLI call")
return fmt.Errorf("circuit breaker open")
}
log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI")
// Note: Removed the "file already has observations" check
// Each tool execution can produce unique insights even for the same file
// Similarity-based deduplication will handle true duplicates
// Record this request to prevent duplicates
p.deduplicator.Record(reqHash)
// Build the prompt
exec := ToolExecution{
@@ -146,9 +398,11 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
// Call Claude Code CLI
response, err := p.callClaudeCLI(ctx, prompt)
if err != nil {
p.circuitBreaker.RecordFailure()
log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation")
return err
}
p.circuitBreaker.RecordSuccess()
// Parse observations from response
observations := ParseObservations(response, sdkSessionID)
@@ -199,16 +453,26 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
Int("trackedFiles", len(obs.FileMtimes)).
Msg("Observation stored")
// Sync to vector DB if callback is set
if p.syncObservationFunc != nil {
// Sync to vector DB via bounded worker pool (non-blocking to reduce latency)
if p.syncObservationFunc != nil && p.vectorSyncChan != nil {
fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0)
fullObs.ID = id
fullObs.CreatedAtEpoch = createdAtEpoch
p.syncObservationFunc(fullObs)
// Non-blocking send to worker pool - drops if channel is full
select {
case p.vectorSyncChan <- fullObs:
// Sent to worker pool
default:
// Channel full, fall back to direct sync in goroutine (bounded by channel buffer)
log.Debug().Int64("obs_id", id).Msg("Vector sync channel full, using fallback goroutine")
go func(obsToSync *models.Observation) {
p.syncObservationFunc(obsToSync)
}(fullObs)
}
}
// Broadcast new observation event for dashboard refresh
p.broadcast(map[string]interface{}{
p.broadcast(map[string]any{
"type": "observation",
"action": "created",
"id": id,
@@ -310,7 +574,7 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
}
// Broadcast new summary event for dashboard refresh
p.broadcast(map[string]interface{}{
p.broadcast(map[string]any{
"type": "summary",
"action": "created",
"id": id,
@@ -320,8 +584,31 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
return nil
}
// MaxPromptSize is the maximum size of a prompt that can be passed to the Claude CLI.
// This prevents resource exhaustion from extremely large prompts.
const MaxPromptSize = 100 * 1024 // 100KB
// sanitizePrompt removes null bytes and control characters from a prompt.
// Keeps newlines, tabs, and carriage returns as they're valid in prompts.
func sanitizePrompt(s string) string {
return strings.Map(func(r rune) rune {
// Keep printable ASCII, extended Unicode, and common whitespace
if r >= 32 || r == '\n' || r == '\t' || r == '\r' {
return r
}
// Remove null bytes and other control characters
return -1
}, s)
}
// callClaudeCLI calls the Claude Code CLI with the given prompt.
func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) {
// Validate and sanitize prompt
if len(prompt) > MaxPromptSize {
return "", fmt.Errorf("prompt exceeds maximum size of %d bytes", MaxPromptSize)
}
prompt = sanitizePrompt(prompt)
// Build the full prompt with system instructions
fullPrompt := systemPrompt + "\n\n" + prompt
@@ -418,8 +705,11 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
return true
}
// Skip if output indicates an error or empty result
// Pre-compute lowercase strings once to avoid repeated allocations
lowerOutput := strings.ToLower(outputStr)
lowerInput := strings.ToLower(inputStr)
// Skip if output indicates an error or empty result
trivialOutputs := []string{
"no matches found",
"file not found",
@@ -443,13 +733,13 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
// Skip reading config files that rarely contain project-specific insights
boringFiles := []string{
"package-lock.json", "yarn.lock", "pnpm-lock.yaml",
"go.sum", "Cargo.lock", "Gemfile.lock", "poetry.lock",
"go.sum", "cargo.lock", "gemfile.lock", "poetry.lock",
".gitignore", ".dockerignore", ".eslintignore",
"tsconfig.json", "jsconfig.json", "vite.config",
"tailwind.config", "postcss.config",
}
for _, boring := range boringFiles {
if strings.Contains(inputStr, boring) {
if strings.Contains(lowerInput, boring) {
return true
}
}
@@ -461,14 +751,14 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
}
case "Bash":
// Skip simple status commands
// Skip simple status commands (use pre-computed lowerInput)
boringCommands := []string{
"git status", "git diff", "git log", "git branch",
"ls ", "pwd", "echo ", "cat ", "which ", "type ",
"npm list", "npm outdated", "npm audit",
}
for _, boring := range boringCommands {
if strings.Contains(strings.ToLower(inputStr), boring) {
if strings.Contains(lowerInput, boring) {
return true
}
}
@@ -478,7 +768,7 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
}
// toJSONString converts an interface to a JSON string.
func toJSONString(v interface{}) string {
func toJSONString(v any) string {
if v == nil {
return ""
}
@@ -492,38 +782,132 @@ func toJSONString(v interface{}) string {
return string(b)
}
// safeResolvePath resolves a path relative to cwd and validates it doesn't escape the cwd directory.
// Returns the resolved absolute path and true if valid, or empty string and false if path traversal detected.
// This function is a security sanitizer for path traversal attacks.
func safeResolvePath(path, cwd string) (string, bool) {
// Clean the input path to normalize any .. or . components
cleanPath := filepath.Clean(path)
// Reject paths that explicitly contain parent directory traversal after cleaning
if strings.Contains(cleanPath, "..") {
return "", false
}
if filepath.IsAbs(cleanPath) {
// For absolute paths, verify they're within cwd if cwd is specified
if cwd != "" {
cleanCwd := filepath.Clean(cwd)
if !strings.HasPrefix(cleanPath, cleanCwd+string(filepath.Separator)) && cleanPath != cleanCwd {
return "", false
}
}
return cleanPath, true
}
if cwd == "" {
return cleanPath, true
}
// Clean the cwd first
cleanCwd := filepath.Clean(cwd)
// Join and clean the path
absPath := filepath.Join(cleanCwd, cleanPath)
// Use filepath.Rel to verify the path is actually within cwd
// If Rel returns a path starting with "..", it escapes the base
rel, err := filepath.Rel(cleanCwd, absPath)
if err != nil || strings.HasPrefix(rel, "..") {
return "", false
}
return absPath, true
}
// captureFileMtimes captures current modification times for tracked files.
// Returns a map of absolute file paths to their mtime in epoch milliseconds.
// For large file lists (>10 files), uses parallel stat calls for better performance.
func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 {
mtimes := make(map[string]int64)
// Combine all unique file paths
allPaths := make(map[string]struct{}, len(filesRead)+len(filesModified))
for _, path := range filesRead {
allPaths[path] = struct{}{}
}
for _, path := range filesModified {
allPaths[path] = struct{}{}
}
// Helper to get mtime for a file path
getMtime := func(path string) (int64, bool) {
// Resolve relative paths against cwd
absPath := path
if !filepath.IsAbs(path) && cwd != "" {
absPath = filepath.Join(cwd, path)
// For small lists, use sequential processing (goroutine overhead not worth it)
if len(allPaths) <= 10 {
return captureFileMtimesSequential(allPaths, cwd)
}
// For larger lists, parallelize with bounded concurrency
return captureFileMtimesParallel(allPaths, cwd)
}
// captureFileMtimesSequential captures mtimes sequentially (efficient for small lists).
func captureFileMtimesSequential(paths map[string]struct{}, cwd string) map[string]int64 {
mtimes := make(map[string]int64, len(paths))
for path := range paths {
absPath, ok := safeResolvePath(path, cwd)
if !ok {
// Skip paths that attempt directory traversal
continue
}
info, err := os.Stat(absPath)
if err != nil {
return 0, false
}
return info.ModTime().UnixMilli(), true
}
// Capture mtimes for all read files
for _, path := range filesRead {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
if err == nil {
mtimes[path] = info.ModTime().UnixMilli()
}
}
// Capture mtimes for all modified files
for _, path := range filesModified {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
}
return mtimes
}
// captureFileMtimesParallel captures mtimes in parallel with bounded concurrency.
func captureFileMtimesParallel(paths map[string]struct{}, cwd string) map[string]int64 {
type mtimeResult struct {
path string
mtime int64
}
results := make(chan mtimeResult, len(paths))
sem := make(chan struct{}, 8) // Limit to 8 concurrent stat calls
var wg sync.WaitGroup
for path := range paths {
wg.Add(1)
go func(p string) {
defer wg.Done()
sem <- struct{}{} // Acquire
defer func() { <-sem }() // Release
absPath, ok := safeResolvePath(p, cwd)
if !ok {
// Skip paths that attempt directory traversal
return
}
info, err := os.Stat(absPath)
if err == nil {
results <- mtimeResult{path: p, mtime: info.ModTime().UnixMilli()}
}
}(path)
}
// Close results channel when all goroutines complete
go func() {
wg.Wait()
close(results)
}()
// Collect results
mtimes := make(map[string]int64, len(paths))
for res := range results {
mtimes[res.path] = res.mtime
}
return mtimes
@@ -538,12 +922,13 @@ func GetFileMtimes(paths []string, cwd string) map[string]int64 {
// GetFileContent reads file content for verification purposes.
// Returns content and ok status.
func GetFileContent(path, cwd string) (string, bool) {
absPath := path
if !filepath.IsAbs(path) && cwd != "" {
absPath = filepath.Join(cwd, path)
absPath, ok := safeResolvePath(path, cwd)
if !ok {
// Reject paths that attempt directory traversal
return "", false
}
content, err := os.ReadFile(absPath) // #nosec G304 -- intentional file read for verification
content, err := os.ReadFile(absPath) // #nosec G304 -- path validated by safeResolvePath
if err != nil {
return "", false
}
+204
View File
@@ -974,3 +974,207 @@ func TestSyncSummaryFuncType(t *testing.T) {
}
assert.NotNil(t, fn)
}
// TestSanitizePrompt tests prompt sanitization for CLI safety.
func TestSanitizePrompt(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "normal text",
input: "Hello, world!",
expected: "Hello, world!",
},
{
name: "text with newlines",
input: "Line 1\nLine 2\nLine 3",
expected: "Line 1\nLine 2\nLine 3",
},
{
name: "text with tabs",
input: "Key:\tValue",
expected: "Key:\tValue",
},
{
name: "text with carriage return",
input: "Line 1\r\nLine 2",
expected: "Line 1\r\nLine 2",
},
{
name: "text with null bytes",
input: "Hello\x00World",
expected: "HelloWorld",
},
{
name: "text with control characters",
input: "Hello\x01\x02\x03World",
expected: "HelloWorld",
},
{
name: "text with bell character",
input: "Hello\x07World",
expected: "HelloWorld",
},
{
name: "text with backspace",
input: "Hello\x08World",
expected: "HelloWorld",
},
{
name: "text with form feed",
input: "Hello\x0cWorld",
expected: "HelloWorld",
},
{
name: "text with escape",
input: "Hello\x1bWorld",
expected: "HelloWorld",
},
{
name: "unicode text",
input: "Hello 世界 🌍",
expected: "Hello 世界 🌍",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "only control characters",
input: "\x00\x01\x02\x03",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sanitizePrompt(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestMaxPromptSize tests that MaxPromptSize is reasonable.
func TestMaxPromptSize(t *testing.T) {
assert.Equal(t, 100*1024, MaxPromptSize)
}
// BenchmarkSanitizePrompt benchmarks the sanitize function.
func BenchmarkSanitizePrompt(b *testing.B) {
prompt := "Analyze the following code:\n```go\nfunc main() {\n\tfmt.Println(\"Hello, World!\")\n}\n```\n\nPlease identify any issues."
b.ResetTimer()
for i := 0; i < b.N; i++ {
sanitizePrompt(prompt)
}
}
// BenchmarkSanitizePromptWithControlChars benchmarks sanitization with control characters.
func BenchmarkSanitizePromptWithControlChars(b *testing.B) {
prompt := "Hello\x00World\x01Test\x02Data\x03End"
b.ResetTimer()
for i := 0; i < b.N; i++ {
sanitizePrompt(prompt)
}
}
// TestSafeResolvePath tests the path traversal protection.
func TestSafeResolvePath(t *testing.T) {
// Create a temporary directory for testing
tmpDir := t.TempDir()
tests := []struct {
name string
path string
cwd string
wantPath string
wantOk bool
}{
{
name: "simple relative path",
path: "file.txt",
cwd: tmpDir,
wantOk: true,
wantPath: filepath.Join(tmpDir, "file.txt"),
},
{
name: "nested relative path",
path: "subdir/file.txt",
cwd: tmpDir,
wantOk: true,
wantPath: filepath.Join(tmpDir, "subdir", "file.txt"),
},
{
name: "path traversal with ..",
path: "../etc/passwd",
cwd: tmpDir,
wantOk: false,
},
{
name: "path traversal with multiple ..",
path: "../../etc/passwd",
cwd: tmpDir,
wantOk: false,
},
{
name: "path traversal hidden in middle",
path: "subdir/../../../etc/passwd",
cwd: tmpDir,
wantOk: false,
},
{
name: "just parent directory",
path: "..",
cwd: tmpDir,
wantOk: false,
},
{
name: "absolute path without cwd",
path: "/some/absolute/path",
cwd: "",
wantOk: true,
wantPath: "/some/absolute/path",
},
{
name: "relative path without cwd",
path: "relative/path",
cwd: "",
wantOk: true,
wantPath: "relative/path",
},
{
name: "current directory reference",
path: "./file.txt",
cwd: tmpDir,
wantOk: true,
wantPath: filepath.Join(tmpDir, "file.txt"),
},
{
name: "absolute path outside cwd",
path: "/etc/passwd",
cwd: tmpDir,
wantOk: false,
},
{
name: "absolute path inside cwd",
path: filepath.Join(tmpDir, "inside.txt"),
cwd: tmpDir,
wantOk: true,
wantPath: filepath.Join(tmpDir, "inside.txt"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotPath, gotOk := safeResolvePath(tt.path, tt.cwd)
assert.Equal(t, tt.wantOk, gotOk, "ok status mismatch")
if tt.wantPath != "" && gotOk {
assert.Equal(t, tt.wantPath, gotPath, "path mismatch")
}
})
}
}
+678 -443
View File
File diff suppressed because it is too large Load Diff
+16 -4
View File
@@ -5,6 +5,8 @@ import (
"io/fs"
"net/http"
"strings"
"github.com/rs/zerolog/log"
)
//go:embed static/*
@@ -13,16 +15,22 @@ var staticFS embed.FS
// staticSubFS is the static subdirectory filesystem
var staticSubFS fs.FS
// staticInitErr stores any error from static filesystem initialization
var staticInitErr error
func init() {
var err error
staticSubFS, err = fs.Sub(staticFS, "static")
if err != nil {
panic("failed to create sub filesystem: " + err.Error())
staticSubFS, staticInitErr = fs.Sub(staticFS, "static")
if staticInitErr != nil {
log.Warn().Err(staticInitErr).Msg("Static filesystem initialization failed - dashboard will be unavailable")
}
}
// serveIndex serves the index.html file for the root path
func serveIndex(w http.ResponseWriter, r *http.Request) {
if staticInitErr != nil {
http.Error(w, "Dashboard unavailable: static files not initialized", http.StatusServiceUnavailable)
return
}
content, err := fs.ReadFile(staticSubFS, "index.html")
if err != nil {
http.Error(w, "Dashboard not found", http.StatusNotFound)
@@ -38,6 +46,10 @@ func serveIndex(w http.ResponseWriter, r *http.Request) {
// serveAssets serves static assets from the embedded filesystem
func serveAssets(w http.ResponseWriter, r *http.Request) {
if staticInitErr != nil {
http.Error(w, "Assets unavailable: static files not initialized", http.StatusServiceUnavailable)
return
}
// Strip the /assets/ prefix and serve the file
path := strings.TrimPrefix(r.URL.Path, "/")