Files
claude-mnemonic/internal/db/gorm/observation_store.go
T
lukaszraczylo d04b60517a Make things 'betterer' across the board (#23)
* Make things 'betterer' across the board

* fix: reorganize struct fields and config parameters for consistency

- [x] Reorder Config struct fields alphabetically and by related functionality
- [x] Reorganize Observation model fields with archival fields grouped together
- [x] Reorder ObservationStore fields to group related members
- [x] Reorder Store struct fields with health check caching grouped
- [x] Reorganize HealthInfo and PoolMetrics struct field order
- [x] Reorder maintenance Service struct fields logically
- [x] Reorganize MCP server handler parameter structs alphabetically
- [x] Reorder pattern detector candidate tracking fields
- [x] Reorganize search Manager struct fields by functionality
- [x] Reorder vector Client struct fields with mutex protections grouped
- [x] Reorganize handler request/response struct fields
- [x] Update handlers_test.go to expect wrapped response format
- [x] Reorder middleware TokenAuth and rate limiter fields
- [x] Reorganize Service struct fields with grouped functionality
- [x] Fix RateLimiter field ordering for clarity
- [x] Reorder CircuitBreaker metrics fields

* fix(security): improve JSON output safety and path traversal protection

- [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler
- [x] Remove escapeJSONString helper function in favor of standard JSON marshaling
- [x] Add safeResolvePath function to validate paths and prevent directory traversal
- [x] Apply path traversal validation in captureFileMtimes operations
- [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation

* fix(sdk): improve path traversal protection and allocation safety

- [x] Enhance safeResolvePath with stricter validation using filepath.Rel
- [x] Reject paths containing ".." after cleaning to prevent traversal
- [x] Validate absolute paths are within cwd when cwd is specified
- [x] Apply safeResolvePath validation to GetFileContent for consistency
- [x] Add comprehensive test coverage for path traversal protection
- [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
2026-01-11 01:51:20 +00:00

1104 lines
34 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package gorm provides GORM-based database operations for claude-mnemonic.
package gorm
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
	"github.com/rs/zerolog/log"
	"gorm.io/gorm"
)
const (
	// MaxObservationsPerProject caps how many observations are retained for
	// a single project; cleanup removes anything beyond this number.
	MaxObservationsPerProject = 100

	// cleanupQueueSize is the capacity of the buffered channel that carries
	// pending cleanup requests to the background worker.
	cleanupQueueSize = 100
)
// commonWords is the package-level stop-word set consulted when extracting
// search keywords. It is built once so lookups are O(1) and no map has to be
// allocated per call.
var commonWords = map[string]struct{}{
	// Articles and conjunctions.
	"the": {},
	"and": {},
	"or":  {},
	"but": {},

	// Prepositions.
	"in":   {},
	"on":   {},
	"at":   {},
	"to":   {},
	"for":  {},
	"of":   {},
	"with": {},
	"by":   {},
	"from": {},
	"as":   {},

	// Forms of "to be".
	"is":    {},
	"was":   {},
	"are":   {},
	"were":  {},
	"be":    {},
	"been":  {},
	"being": {},

	// Forms of "to have" and "to do".
	"have": {},
	"has":  {},
	"had":  {},
	"do":   {},
	"does": {},
	"did":  {},

	// Modal verbs.
	"will":   {},
	"would":  {},
	"should": {},
	"could":  {},
	"may":    {},
	"might":  {},
	"must":   {},
	"can":    {},
}
// CleanupFunc is a callback invoked after observations have been removed by
// the store's cleanup pass. It receives the context used for that cleanup and
// the IDs of the deleted observations so downstream state (e.g., a vector DB)
// can be cleaned up as well.
type CleanupFunc func(ctx context.Context, deletedIDs []int64)
// ObservationStore provides observation-related database operations using GORM.
// It also owns a background worker that trims old observations per project;
// call Close to stop that worker.
type ObservationStore struct {
// conflictStore and relationStore are stored as-is by the constructor and
// reserved for Phase 4 (conflict and relation detection); may be nil.
conflictStore any
relationStore any
// db is the GORM handle used for regular queries.
db *gorm.DB
// rawDB is the underlying *sql.DB, used for the raw FTS5 query.
rawDB *sql.DB
// cleanupFunc, when non-nil, is called with the IDs removed by a cleanup pass.
cleanupFunc CleanupFunc
// cleanupQueue carries project names awaiting cleanup to the worker.
cleanupQueue chan string
// stopCleanup is closed by Close to tell the worker to drain and exit.
stopCleanup chan struct{}
// cleanupWg lets Close wait for the worker goroutine to finish.
cleanupWg sync.WaitGroup
// cleanupOnce guarantees the worker goroutine is launched at most once.
cleanupOnce sync.Once
// cleanupStarted is set once the worker goroutine is launched; Close
// checks it before signalling shutdown.
cleanupStarted atomic.Bool
}
// NewObservationStore builds an ObservationStore on top of the given Store
// and immediately launches its background cleanup worker.
//
// cleanupFunc may be nil. conflictStore and relationStore are optional (nil is
// accepted); they are only stored for later use in Phase 4.
func NewObservationStore(store *Store, cleanupFunc CleanupFunc, conflictStore, relationStore any) *ObservationStore {
	obsStore := new(ObservationStore)
	obsStore.db = store.DB
	obsStore.rawDB = store.GetRawDB()
	obsStore.cleanupFunc = cleanupFunc
	obsStore.conflictStore = conflictStore
	obsStore.relationStore = relationStore
	obsStore.cleanupQueue = make(chan string, cleanupQueueSize)
	obsStore.stopCleanup = make(chan struct{})

	// The worker must be running before the store is handed to callers.
	obsStore.startCleanupWorker()
	return obsStore
}
// startCleanupWorker launches the background cleanup goroutine. Repeated
// calls are harmless: the launch is guarded by cleanupOnce.
func (s *ObservationStore) startCleanupWorker() {
	launch := func() {
		// Record the start first so Close knows there is a worker to stop.
		s.cleanupStarted.Store(true)
		s.cleanupWg.Add(1)
		go s.cleanupWorker()
	}
	s.cleanupOnce.Do(launch)
}
// cleanupWorker is the body of the background goroutine. It handles cleanup
// requests from cleanupQueue until stopCleanup is closed, then processes
// whatever is still buffered in the queue and exits.
func (s *ObservationStore) cleanupWorker() {
	defer s.cleanupWg.Done()

	for {
		select {
		case project := <-s.cleanupQueue:
			s.processCleanup(project)
			continue
		case <-s.stopCleanup:
			// Shutdown requested: fall through to the drain loop below.
		}

		// Drain the remaining queued projects without blocking, then stop.
		for {
			select {
			case project := <-s.cleanupQueue:
				s.processCleanup(project)
			default:
				return
			}
		}
	}
}
// processCleanup runs one cleanup pass for project under a 10-second timeout.
// Failures are logged as warnings and otherwise ignored. When observations
// were deleted and a cleanup callback is configured, the callback receives
// the deleted IDs.
func (s *ObservationStore) processCleanup(project string) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	removed, err := s.CleanupOldObservations(ctx, project)
	if err != nil {
		log.Warn().Err(err).Str("project", project).Msg("Failed to cleanup old observations")
		return
	}

	// Nothing to report unless something was deleted and someone is listening.
	if len(removed) == 0 || s.cleanupFunc == nil {
		return
	}
	s.cleanupFunc(ctx, removed)
	log.Debug().Str("project", project).Int("count", len(removed)).Msg("Cleaned up old observations")
}
// Close stops the cleanup worker and waits for it to finish.
//
// It is safe to call when the worker was never started, and safe to call
// more than once: only the first call that observes a started worker closes
// the stop channel, so repeated calls no longer panic with "close of closed
// channel".
func (s *ObservationStore) Close() {
	// CompareAndSwap flips the flag exactly once, which guarantees the stop
	// channel is closed at most once even under concurrent Close calls.
	if s.cleanupStarted.CompareAndSwap(true, false) {
		close(s.stopCleanup)
	}
	// Wait returns immediately when no worker was ever registered, and lets
	// every caller (not just the first) block until the worker has exited.
	s.cleanupWg.Wait()
}
// SetCleanupFunc sets the callback for when observations are deleted during cleanup.
// Passing nil disables the callback (processCleanup skips a nil callback).
// NOTE(review): the field is assigned here without synchronization while the
// cleanup worker goroutine reads it; presumably callers set it before any
// observations are stored — confirm.
func (s *ObservationStore) SetCleanupFunc(fn CleanupFunc) {
s.cleanupFunc = fn
}
// StoreObservation persists a new observation for the given session and
// project.
//
// The session is created on demand if it does not exist. When the parsed
// observation carries no scope, one is derived from its concepts. After a
// successful insert, a cleanup request for the project is queued without
// blocking; if the queue is full the request is dropped with a warning.
//
// Returns the new observation ID and its creation time in Unix milliseconds.
func (s *ObservationStore) StoreObservation(ctx context.Context, sdkSessionID, project string, obs *models.ParsedObservation, promptNumber int, discoveryTokens int64) (int64, int64, error) {
	now := time.Now()
	nowEpoch := now.UnixMilli()

	// The observation references a session, so make sure one exists first.
	if err := EnsureSessionExists(ctx, s.db, sdkSessionID, project); err != nil {
		return 0, 0, err
	}

	// An explicitly parsed scope wins; otherwise infer it from the concepts.
	scope := obs.Scope
	if scope == "" {
		scope = models.DetermineScope(obs.Concepts)
	}

	record := Observation{
		SDKSessionID:    sdkSessionID,
		Project:         project,
		Scope:           scope,
		Type:            obs.Type,
		Title:           nullString(obs.Title),
		Subtitle:        nullString(obs.Subtitle),
		Facts:           models.JSONStringArray(obs.Facts),
		Narrative:       nullString(obs.Narrative),
		Concepts:        models.JSONStringArray(obs.Concepts),
		FilesRead:       models.JSONStringArray(obs.FilesRead),
		FilesModified:   models.JSONStringArray(obs.FilesModified),
		FileMtimes:      models.JSONInt64Map(obs.FileMtimes),
		PromptNumber:    nullInt64(promptNumber),
		DiscoveryTokens: discoveryTokens,
		CreatedAt:       now.Format(time.RFC3339),
		CreatedAtEpoch:  nowEpoch,
	}
	if err := s.db.WithContext(ctx).Create(&record).Error; err != nil {
		return 0, 0, err
	}

	// Observations without a project are never trimmed.
	if project == "" {
		return record.ID, nowEpoch, nil
	}

	// Hand the trimming work to the background worker so the caller is not
	// blocked; a full queue means this request is skipped, not retried.
	select {
	case s.cleanupQueue <- project:
	default:
		log.Warn().Str("project", project).Msg("Cleanup queue full, skipping cleanup for this observation")
	}

	// Conflict and relation detection is intentionally absent until the
	// Phase 4 ConflictStore and RelationStore exist.
	return record.ID, nowEpoch, nil
}
// ObservationUpdate contains fields that can be updated on an observation.
// Only non-nil fields will be updated; a nil pointer means "leave the stored
// value unchanged". Slice fields replace the stored list wholesale.
type ObservationUpdate struct {
Title *string // New title
Subtitle *string // New subtitle
Narrative *string // New narrative
Facts *[]string // New facts (replaces existing)
Concepts *[]string // New concepts (replaces existing)
FilesRead *[]string // New files read (replaces existing)
FilesModified *[]string // New files modified (replaces existing)
Scope *string // New scope (project or global)
}
// UpdateObservation updates an existing observation with the provided fields.
//
// Only non-nil fields of update are written. When at least one field is set,
// updated_at_epoch is refreshed as well. When no field is set, nothing is
// written and the stored observation is returned unchanged.
//
// Returns the observation as stored after the update. An error is returned
// when update is nil, when the observation does not exist, when a list field
// cannot be encoded, or when a database operation fails.
func (s *ObservationStore) UpdateObservation(ctx context.Context, id int64, update *ObservationUpdate) (*models.Observation, error) {
	if update == nil {
		return nil, fmt.Errorf("update cannot be nil")
	}

	// Load the current row first so a missing ID is reported precisely.
	var dbObs Observation
	if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil {
		// errors.Is also matches a wrapped sentinel.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("observation not found: %d", id)
		}
		return nil, err
	}

	// Collect only the columns the caller actually provided.
	updates := make(map[string]any)
	setText := func(column string, value *string) {
		if value != nil {
			updates[column] = sql.NullString{String: *value, Valid: true}
		}
	}
	setList := func(column string, values *[]string) error {
		if values == nil {
			return nil
		}
		encoded, err := json.Marshal(*values)
		if err != nil {
			return fmt.Errorf("encoding %s: %w", column, err)
		}
		updates[column] = string(encoded)
		return nil
	}

	setText("title", update.Title)
	setText("subtitle", update.Subtitle)
	setText("narrative", update.Narrative)
	listColumns := []struct {
		column string
		values *[]string
	}{
		{"facts", update.Facts},
		{"concepts", update.Concepts},
		{"files_read", update.FilesRead},
		{"files_modified", update.FilesModified},
	}
	for _, lc := range listColumns {
		if err := setList(lc.column, lc.values); err != nil {
			return nil, err
		}
	}
	setText("scope", update.Scope)

	// This check must run before the timestamp is added; otherwise the map
	// is never empty and the early return below can never trigger.
	if len(updates) == 0 {
		return toModelObservation(&dbObs), nil
	}

	// NOTE(review): this writes Unix seconds, while created_at_epoch and
	// archived_at_epoch in this file use milliseconds — confirm the intended
	// unit before changing it.
	updates["updated_at_epoch"] = sql.NullInt64{Int64: time.Now().Unix(), Valid: true}

	if err := s.db.WithContext(ctx).Model(&Observation{}).Where("id = ?", id).Updates(updates).Error; err != nil {
		return nil, fmt.Errorf("failed to update observation: %w", err)
	}

	// Re-read so the returned value reflects exactly what is stored.
	if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil {
		return nil, err
	}
	return toModelObservation(&dbObs), nil
}
// GetObservationByID retrieves an observation by its ID.
//
// A missing observation is not an error: the function returns (nil, nil) in
// that case. Any other database failure is returned as-is.
func (s *ObservationStore) GetObservationByID(ctx context.Context, id int64) (*models.Observation, error) {
	var dbObs Observation
	if err := s.db.WithContext(ctx).First(&dbObs, id).Error; err != nil {
		// errors.Is also matches a wrapped sentinel, unlike a plain ==.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return toModelObservation(&dbObs), nil
}
// GetObservationsByIDs fetches the observations whose IDs appear in ids.
//
// orderBy selects the sort order: "date_asc", "date_desc", "importance" or
// "score_desc"; any other value sorts by importance (NULL treated as 1.0)
// and then recency. A positive limit caps the number of rows returned.
// An empty ids slice yields (nil, nil) without touching the database.
func (s *ObservationStore) GetObservationsByIDs(ctx context.Context, ids []int64, orderBy string, limit int) ([]*models.Observation, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	// Start from the fallback ordering and override it for known keys.
	orderClause := "COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC"
	switch orderBy {
	case "date_asc":
		orderClause = "created_at_epoch ASC"
	case "date_desc":
		orderClause = "created_at_epoch DESC"
	case "importance", "score_desc":
		orderClause = "importance_score DESC, created_at_epoch DESC"
	}

	tx := s.db.WithContext(ctx).Where("id IN ?", ids).Order(orderClause)
	if limit > 0 {
		tx = tx.Limit(limit)
	}

	var rows []Observation
	if err := tx.Find(&rows).Error; err != nil {
		return nil, err
	}
	return toModelObservations(rows), nil
}
// GetObservationsByIDsPreserveOrder fetches observations by ID and returns
// them in the same order as ids. Use it when the caller has already ranked
// the IDs (e.g., by vector similarity). IDs with no matching row are skipped.
// An empty ids slice yields (nil, nil).
func (s *ObservationStore) GetObservationsByIDsPreserveOrder(ctx context.Context, ids []int64) ([]*models.Observation, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	// One round trip for all rows; ordering is restored in memory below.
	var rows []Observation
	if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&rows).Error; err != nil {
		return nil, err
	}

	// Index the rows by ID for constant-time lookup.
	byID := make(map[int64]*Observation, len(rows))
	for i := range rows {
		row := &rows[i]
		byID[int64(row.ID)] = row
	}

	// Walk the caller's order and emit each row that was found.
	ordered := make([]*models.Observation, 0, len(ids))
	for _, id := range ids {
		row, found := byID[id]
		if !found {
			continue
		}
		ordered = append(ordered, toModelObservation(row))
	}
	return ordered, nil
}
// BatchGetObservationsWithScores loads the observations for ids in a single
// query and returns them keyed by ID for direct lookup. IDs with no matching
// row are simply absent from the map. An empty ids slice yields an empty,
// non-nil map.
func (s *ObservationStore) BatchGetObservationsWithScores(ctx context.Context, ids []int64) (map[int64]*models.Observation, error) {
	if len(ids) == 0 {
		return map[int64]*models.Observation{}, nil
	}

	var rows []Observation
	if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&rows).Error; err != nil {
		return nil, err
	}

	byID := make(map[int64]*models.Observation, len(rows))
	for i := range rows {
		row := &rows[i]
		byID[int64(row.ID)] = toModelObservation(row)
	}
	return byID, nil
}
// GetRecentObservations returns up to limit observations visible to project:
// the project's own project-scoped observations plus all global ones.
// Rows are sorted by importance score (descending), then by creation time
// (descending).
func (s *ObservationStore) GetRecentObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
	tx := s.db.WithContext(ctx).
		Scopes(projectScopeFilter(project), importanceOrdering()).
		Limit(limit)

	var rows []Observation
	if err := tx.Find(&rows).Error; err != nil {
		return nil, err
	}
	return toModelObservations(rows), nil
}
// GetActiveObservations returns up to limit observations visible to project
// (project-scoped plus global) that are neither archived nor superseded.
// Rows are sorted by importance score (descending), then by creation time
// (descending).
func (s *ObservationStore) GetActiveObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
	tx := s.db.WithContext(ctx).
		Scopes(projectScopeFilter(project), activeObservationFilter(), importanceOrdering()).
		Limit(limit)

	var rows []Observation
	if err := tx.Find(&rows).Error; err != nil {
		return nil, err
	}
	return toModelObservations(rows), nil
}
// GetSupersededObservations returns up to limit observations of project that
// are flagged as superseded, newest first (by created_at_epoch).
func (s *ObservationStore) GetSupersededObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
	tx := s.db.WithContext(ctx).
		Where("project = ? AND COALESCE(is_superseded, 0) = 1", project).
		Order("created_at_epoch DESC").
		Limit(limit)

	var rows []Observation
	if err := tx.Find(&rows).Error; err != nil {
		return nil, err
	}
	return toModelObservations(rows), nil
}
// GetObservationsByProjectStrict returns up to limit observations whose
// project column equals project exactly; global observations belonging to
// other projects are not included. Rows are sorted by importance, then recency.
func (s *ObservationStore) GetObservationsByProjectStrict(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
	tx := s.db.WithContext(ctx).
		Where("project = ?", project).
		Scopes(importanceOrdering()).
		Limit(limit)

	var rows []Observation
	if err := tx.Find(&rows).Error; err != nil {
		return nil, err
	}
	return toModelObservations(rows), nil
}
// GetObservationCount reports how many observations are stored for project.
// The count is returned alongside any query error.
func (s *ObservationStore) GetObservationCount(ctx context.Context, project string) (int, error) {
	var total int64
	tx := s.db.WithContext(ctx).Model(&Observation{}).Where("project = ?", project)
	err := tx.Count(&total).Error
	return int(total), err
}
// GetAllRecentObservations returns up to limit observations across every
// project, sorted by importance score and then by creation time (both
// descending).
func (s *ObservationStore) GetAllRecentObservations(ctx context.Context, limit int) ([]*models.Observation, error) {
	tx := s.db.WithContext(ctx).
		Scopes(importanceOrdering()).
		Limit(limit)

	var rows []Observation
	if err := tx.Find(&rows).Error; err != nil {
		return nil, err
	}
	return toModelObservations(rows), nil
}
// GetAllRecentObservationsPaginated returns one page of observations across
// all projects together with the total number of stored observations.
// The page is sorted by importance, then recency, and is selected with
// limit and offset. The count and the page are fetched by separate queries.
func (s *ObservationStore) GetAllRecentObservationsPaginated(ctx context.Context, limit, offset int) ([]*models.Observation, int64, error) {
	// Total first, so callers can compute the number of pages.
	var total int64
	countErr := s.db.WithContext(ctx).Model(&Observation{}).Count(&total).Error
	if countErr != nil {
		return nil, 0, countErr
	}

	page := s.db.WithContext(ctx).
		Scopes(importanceOrdering()).
		Limit(limit).
		Offset(offset)

	var rows []Observation
	if err := page.Find(&rows).Error; err != nil {
		return nil, 0, err
	}
	return toModelObservations(rows), total, nil
}
// GetObservationsByProjectStrictPaginated returns one page of observations
// whose project column equals project, together with the total number of
// observations for that project. The page is sorted by importance, then
// recency, and is selected with limit and offset.
func (s *ObservationStore) GetObservationsByProjectStrictPaginated(ctx context.Context, project string, limit, offset int) ([]*models.Observation, int64, error) {
	// Total for this project only.
	var total int64
	countErr := s.db.WithContext(ctx).
		Model(&Observation{}).
		Where("project = ?", project).
		Count(&total).Error
	if countErr != nil {
		return nil, 0, countErr
	}

	page := s.db.WithContext(ctx).
		Where("project = ?", project).
		Scopes(importanceOrdering()).
		Limit(limit).
		Offset(offset)

	var rows []Observation
	if err := page.Find(&rows).Error; err != nil {
		return nil, 0, err
	}
	return toModelObservations(rows), total, nil
}
// GetAllObservations loads every observation, ordered by ID (used for
// vector rebuilds). Everything is held in memory at once; for large tables
// prefer GetAllObservationsIterator.
func (s *ObservationStore) GetAllObservations(ctx context.Context) ([]*models.Observation, error) {
	var rows []Observation
	if err := s.db.WithContext(ctx).Order("id").Find(&rows).Error; err != nil {
		return nil, err
	}
	return toModelObservations(rows), nil
}
// GetAllObservationsIterator walks all observations in ascending ID order,
// handing them to callback one batch at a time so the whole table never has
// to be in memory.
//
// batchSize is the number of rows per batch; values <= 0 fall back to 500.
// Iteration stops when the table is exhausted or callback returns false
// (both yield a nil error), when ctx is done (ctx.Err() is returned), or
// when a query fails.
func (s *ObservationStore) GetAllObservationsIterator(ctx context.Context, batchSize int, callback func([]*models.Observation) bool) error {
	if batchSize <= 0 {
		batchSize = 500
	}

	// cursor is the highest ID already delivered; the next batch starts after it.
	cursor := int64(0)
	for {
		// Bail out promptly if the caller has given up.
		if err := ctx.Err(); err != nil {
			return err
		}

		var rows []Observation
		fetchErr := s.db.WithContext(ctx).
			Where("id > ?", cursor).
			Order("id ASC").
			Limit(batchSize).
			Find(&rows).Error
		if fetchErr != nil {
			return fetchErr
		}
		if len(rows) == 0 {
			return nil
		}

		// Advance past the last row of this batch before invoking the callback.
		cursor = rows[len(rows)-1].ID

		if keepGoing := callback(toModelObservations(rows)); !keepGoing {
			return nil
		}
	}
}
// SearchObservationsFTS performs full-text search on observations using FTS5.
//
// Keywords are extracted from query and combined with OR. The search covers
// observations of project plus global ones and returns at most limit rows
// (10 when limit <= 0), ordered by FTS rank and then importance.
//
// If the FTS query fails or matches nothing, the search falls back to a LIKE
// query. When query yields no keywords the result is (nil, nil).
func (s *ObservationStore) SearchObservationsFTS(ctx context.Context, query, project string, limit int) ([]*models.Observation, error) {
	if limit <= 0 {
		limit = 10
	}

	keywords := extractKeywords(query)
	if len(keywords) == 0 {
		return nil, nil
	}

	// Quote every keyword as an FTS5 string literal (embedded double quotes
	// are doubled). Unquoted, characters such as '"', ':', '-', '*' or
	// parentheses are parsed as FTS5 query syntax, which either breaks the
	// query or silently changes its meaning (e.g. "col:term" column filters).
	quoted := make([]string, len(keywords))
	for i, kw := range keywords {
		quoted[i] = `"` + strings.ReplaceAll(kw, `"`, `""`) + `"`
	}
	ftsTerms := strings.Join(quoted, " OR ")

	// Raw SQL is required here: GORM has no support for the FTS5 MATCH operator.
	const ftsQuery = `
		SELECT o.id, o.sdk_session_id, o.project, COALESCE(o.scope, 'project') as scope, o.type,
			o.title, o.subtitle, o.facts, o.narrative, o.concepts, o.files_read, o.files_modified,
			o.file_mtimes, o.prompt_number, o.discovery_tokens, o.created_at, o.created_at_epoch,
			COALESCE(o.importance_score, 1.0) as importance_score,
			COALESCE(o.user_feedback, 0) as user_feedback,
			COALESCE(o.retrieval_count, 0) as retrieval_count,
			o.last_retrieved_at_epoch, o.score_updated_at_epoch,
			COALESCE(o.is_superseded, 0) as is_superseded
		FROM observations o
		JOIN observations_fts fts ON o.id = fts.rowid
		WHERE observations_fts MATCH ?
			AND (o.project = ? OR o.scope = 'global')
		ORDER BY rank, COALESCE(o.importance_score, 1.0) DESC
		LIMIT ?
	`
	rows, err := s.rawDB.QueryContext(ctx, ftsQuery, ftsTerms, project, limit)
	if err != nil {
		// Best effort: record why FTS was unavailable, then use the LIKE path.
		log.Debug().Err(err).Str("project", project).Msg("FTS search failed, falling back to LIKE search")
		return s.searchObservationsLike(ctx, keywords, project, limit)
	}
	defer rows.Close()

	observations, err := scanObservationRows(rows)
	if err != nil {
		return nil, err
	}

	// No FTS hits: give the substring-based fallback a chance.
	if len(observations) == 0 {
		return s.searchObservationsLike(ctx, keywords, project, limit)
	}
	return observations, nil
}
// searchObservationsLike is the LIKE-based fallback used when FTS fails or
// returns nothing. Each keyword is matched as a substring of title, subtitle
// or narrative; keywords are OR-ed together and the result is restricted to
// project plus global observations, sorted by importance then recency.
//
// Only the first two keywords are used: every keyword adds three LIKE
// conditions, and this fallback path favors a cheap query over recall.
// An empty keywords slice yields (nil, nil).
func (s *ObservationStore) searchObservationsLike(ctx context.Context, keywords []string, project string, limit int) ([]*models.Observation, error) {
	if len(keywords) == 0 {
		return nil, nil
	}

	// Cap the keyword count to keep the number of LIKE conditions small.
	const maxKeywords = 2
	if len(keywords) > maxKeywords {
		keywords = keywords[:maxKeywords]
	}

	// One parenthesized group per keyword, each taking the pattern three times.
	const perKeyword = "(title LIKE ? OR subtitle LIKE ? OR narrative LIKE ?)"
	groups := make([]string, len(keywords))
	args := make([]any, 0, 3*len(keywords)+1)
	for i, kw := range keywords {
		like := "%" + kw + "%"
		groups[i] = perKeyword
		args = append(args, like, like, like)
	}

	// The project placeholder comes last, so its argument is appended last.
	where := "(" + strings.Join(groups, " OR ") + ") AND (project = ? OR scope = 'global')"
	args = append(args, project)

	tx := s.db.WithContext(ctx).
		Where(where, args...).
		Scopes(importanceOrdering()).
		Limit(limit)

	var rows []Observation
	if err := tx.Find(&rows).Error; err != nil {
		return nil, err
	}
	return toModelObservations(rows), nil
}
// DeleteObservations removes the observations with the given IDs and reports
// how many rows were affected. An empty ids slice is a no-op returning (0, nil).
func (s *ObservationStore) DeleteObservations(ctx context.Context, ids []int64) (int64, error) {
	if len(ids) == 0 {
		return 0, nil
	}
	res := s.db.WithContext(ctx).Delete(&Observation{}, ids)
	return res.RowsAffected, res.Error
}
// DeleteObservation removes the observation with the given ID. It returns an
// error when the delete fails or when no row was affected.
func (s *ObservationStore) DeleteObservation(ctx context.Context, id int64) error {
	res := s.db.WithContext(ctx).Delete(&Observation{}, id)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return fmt.Errorf("observation %d not found", id)
	default:
		return nil
	}
}
// MarkAsSuperseded flags the observation with the given ID as superseded
// (stale). It returns an error when the update fails or when no row was
// affected.
func (s *ObservationStore) MarkAsSuperseded(ctx context.Context, id int64) error {
	target := s.db.WithContext(ctx).Model(&Observation{}).Where("id = ?", id)
	res := target.Update("is_superseded", 1)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return fmt.Errorf("observation %d not found", id)
	default:
		return nil
	}
}
// MarkAsSupersededBatch flags all observations in ids as superseded with a
// single UPDATE. It returns the number of rows affected together with any
// error. An empty ids slice is a no-op returning (0, nil).
func (s *ObservationStore) MarkAsSupersededBatch(ctx context.Context, ids []int64) (int64, error) {
	if len(ids) == 0 {
		return 0, nil
	}
	targets := s.db.WithContext(ctx).Model(&Observation{}).Where("id IN ?", ids)
	res := targets.Update("is_superseded", 1)
	return res.RowsAffected, res.Error
}
// ArchiveObservation marks one observation as archived and stamps the
// archival time in Unix milliseconds. archived_reason is written only when
// reason is non-empty. It returns an error when the update fails or when no
// row was affected.
func (s *ObservationStore) ArchiveObservation(ctx context.Context, id int64, reason string) error {
	fields := make(map[string]any, 3)
	fields["is_archived"] = 1
	fields["archived_at_epoch"] = time.Now().UnixMilli()
	if reason != "" {
		fields["archived_reason"] = reason
	}

	res := s.db.WithContext(ctx).
		Model(&Observation{}).
		Where("id = ?", id).
		Updates(fields)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return fmt.Errorf("observation %d not found", id)
	default:
		return nil
	}
}
// UnarchiveObservation clears the archived flag of one observation and
// resets its archival timestamp and reason to NULL. It returns an error when
// the update fails or when no row was affected.
func (s *ObservationStore) UnarchiveObservation(ctx context.Context, id int64) error {
	// Nil values write NULL for the timestamp and reason columns.
	reset := map[string]any{
		"is_archived":       0,
		"archived_at_epoch": nil,
		"archived_reason":   nil,
	}

	res := s.db.WithContext(ctx).
		Model(&Observation{}).
		Where("id = ?", id).
		Updates(reset)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return fmt.Errorf("observation %d not found", id)
	default:
		return nil
	}
}
// ArchiveOldObservations archives every not-yet-archived observation created
// more than maxAgeDays ago and returns the IDs it archived.
//
// maxAgeDays <= 0 falls back to 90 days. An empty project applies the
// operation to all projects. An empty reason is replaced by a generated
// "auto-archived" message. Selection and update run in one transaction.
func (s *ObservationStore) ArchiveOldObservations(ctx context.Context, project string, maxAgeDays int, reason string) ([]int64, error) {
	if maxAgeDays <= 0 {
		maxAgeDays = 90 // Default retention window.
	}
	cutoffEpoch := time.Now().AddDate(0, 0, -maxAgeDays).UnixMilli()
	if reason == "" {
		reason = fmt.Sprintf("auto-archived: older than %d days", maxAgeDays)
	}

	var archivedIDs []int64
	archive := func(tx *gorm.DB) error {
		// Candidates: older than the cutoff and not archived yet.
		candidates := tx.Model(&Observation{}).
			Where("created_at_epoch < ?", cutoffEpoch).
			Where("COALESCE(is_archived, 0) = 0")
		if project != "" {
			candidates = candidates.Where("project = ?", project)
		}
		if err := candidates.Pluck("id", &archivedIDs).Error; err != nil {
			return err
		}
		if len(archivedIDs) == 0 {
			return nil
		}

		// Flag exactly the rows that were selected above.
		now := time.Now().UnixMilli()
		fields := map[string]any{
			"is_archived":       1,
			"archived_at_epoch": now,
			"archived_reason":   reason,
		}
		return tx.Model(&Observation{}).
			Where("id IN ?", archivedIDs).
			Updates(fields).Error
	}

	if err := s.db.WithContext(ctx).Transaction(archive); err != nil {
		return nil, err
	}
	return archivedIDs, nil
}
// GetArchivedObservations returns up to limit archived observations, most
// recently archived first. An empty project returns archived observations
// from all projects.
func (s *ObservationStore) GetArchivedObservations(ctx context.Context, project string, limit int) ([]*models.Observation, error) {
	tx := s.db.WithContext(ctx).Where("COALESCE(is_archived, 0) = 1")
	if project != "" {
		tx = tx.Where("project = ?", project)
	}
	tx = tx.Order("archived_at_epoch DESC").Limit(limit)

	var rows []Observation
	if err := tx.Find(&rows).Error; err != nil {
		return nil, err
	}
	return toModelObservations(rows), nil
}
// GetArchivalStats reports total, archived and active observation counts,
// plus the oldest and newest archival timestamps. An empty project covers
// all projects. All figures come from one aggregate query using conditional
// aggregation rather than one query per figure.
func (s *ObservationStore) GetArchivalStats(ctx context.Context, project string) (*ArchivalStats, error) {
	// Scan target for the aggregate row. The epoch fields are pointers
	// because MIN/MAX yield NULL when nothing is archived.
	type statsResult struct {
		OldestEpoch   *int64
		NewestEpoch   *int64
		TotalCount    int64
		ArchivedCount int64
	}

	aggregate := s.db.WithContext(ctx).Model(&Observation{}).
		Select(`
COUNT(*) as total_count,
SUM(CASE WHEN COALESCE(is_archived, 0) = 1 THEN 1 ELSE 0 END) as archived_count,
MIN(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as oldest_epoch,
MAX(CASE WHEN COALESCE(is_archived, 0) = 1 THEN archived_at_epoch END) as newest_epoch
`)
	if project != "" {
		aggregate = aggregate.Where("project = ?", project)
	}

	var row statsResult
	if err := aggregate.Scan(&row).Error; err != nil {
		return nil, err
	}

	stats := new(ArchivalStats)
	stats.TotalCount = row.TotalCount
	stats.ArchivedCount = row.ArchivedCount
	stats.ActiveCount = row.TotalCount - row.ArchivedCount
	// The epochs stay zero when there are no archived rows.
	if oldest := row.OldestEpoch; oldest != nil {
		stats.OldestArchivedEpoch = *oldest
	}
	if newest := row.NewestEpoch; newest != nil {
		stats.NewestArchivedEpoch = *newest
	}
	return stats, nil
}
// ArchivalStats contains statistics about archived observations, as computed
// by GetArchivalStats.
type ArchivalStats struct {
// TotalCount is the number of observations matched by the query.
TotalCount int64 `json:"total_count"`
// ActiveCount is TotalCount minus ArchivedCount.
ActiveCount int64 `json:"active_count"`
// ArchivedCount is the number of matched observations flagged as archived.
ArchivedCount int64 `json:"archived_count"`
// OldestArchivedEpoch is the smallest archived_at_epoch among archived
// observations; left at zero (and omitted from JSON) when none are archived.
OldestArchivedEpoch int64 `json:"oldest_archived_epoch,omitempty"`
// NewestArchivedEpoch is the largest archived_at_epoch among archived
// observations; left at zero (and omitted from JSON) when none are archived.
NewestArchivedEpoch int64 `json:"newest_archived_epoch,omitempty"`
}
// CleanupOldObservations trims a project down to its newest
// MaxObservationsPerProject observations (by created_at_epoch) and returns
// the IDs that were deleted. Selection and deletion happen inside a single
// transaction so the keep and delete sets stay consistent with each other.
func (s *ObservationStore) CleanupOldObservations(ctx context.Context, project string) ([]int64, error) {
	var removedIDs []int64

	trim := func(tx *gorm.DB) error {
		// The newest rows, up to the per-project cap, survive.
		var keepIDs []int64
		keepErr := tx.Model(&Observation{}).
			Where("project = ?", project).
			Order("created_at_epoch DESC").
			Limit(MaxObservationsPerProject).
			Pluck("id", &keepIDs).Error
		if keepErr != nil {
			return keepErr
		}
		if len(keepIDs) == 0 {
			return nil
		}

		// Everything else in the project is a deletion candidate.
		findErr := tx.Model(&Observation{}).
			Where("project = ? AND id NOT IN ?", project, keepIDs).
			Pluck("id", &removedIDs).Error
		if findErr != nil {
			return findErr
		}
		if len(removedIDs) == 0 {
			return nil
		}

		return tx.Delete(&Observation{}, removedIDs).Error
	}

	if err := s.db.WithContext(ctx).Transaction(trim); err != nil {
		return nil, err
	}
	return removedIDs, nil
}
// ====================
// GORM Scopes (Reusable Query Filters)
// ====================
// projectScopeFilter returns a GORM scope selecting the observations visible
// to project: rows of that project whose scope is NULL or 'project', plus
// every row whose scope is 'global'.
func projectScopeFilter(project string) func(*gorm.DB) *gorm.DB {
	const visibleToProject = "(project = ? AND (scope IS NULL OR scope = 'project')) OR scope = 'global'"
	return func(tx *gorm.DB) *gorm.DB {
		return tx.Where(visibleToProject, project)
	}
}
// activeObservationFilter returns a GORM scope that keeps only active
// observations, i.e. rows that are neither archived nor superseded. Both
// conditions are expressed in a single WHERE clause.
func activeObservationFilter() func(*gorm.DB) *gorm.DB {
	const activeOnly = "COALESCE(is_archived, 0) = 0 AND COALESCE(is_superseded, 0) = 0"
	return func(tx *gorm.DB) *gorm.DB {
		return tx.Where(activeOnly)
	}
}
// importanceOrdering returns a GORM scope that sorts by importance score
// descending (a NULL score counts as 1.0) and breaks ties by creation time,
// newest first.
func importanceOrdering() func(*gorm.DB) *gorm.DB {
	const byImportanceThenRecency = "COALESCE(importance_score, 1.0) DESC, created_at_epoch DESC"
	return func(tx *gorm.DB) *gorm.DB {
		return tx.Order(byImportanceThenRecency)
	}
}
// ====================
// Helper Functions
// ====================
// extractKeywords lower-cases query, splits it on whitespace and keeps the
// tokens that are longer than three bytes and not in the commonWords stop
// list. The result is never nil, but may be empty.
func extractKeywords(query string) []string {
	tokens := strings.Fields(strings.ToLower(query))
	kept := make([]string, 0, len(tokens)) // Upper bound: every token kept.
	for _, tok := range tokens {
		_, isStopWord := commonWords[tok]
		if len(tok) > 3 && !isStopWord {
			kept = append(kept, tok)
		}
	}
	return kept
}
// scanObservationRows reads every remaining row from rows into observations.
// It stops at the first scan failure and returns that error with a nil
// slice; otherwise it returns the collected observations together with any
// iteration error reported by rows.Err.
func scanObservationRows(rows *sql.Rows) ([]*models.Observation, error) {
	// Start with some headroom so typical result sets do not regrow the slice.
	const initialCapacity = 64
	collected := make([]*models.Observation, 0, initialCapacity)

	for rows.Next() {
		item, scanErr := scanObservation(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		collected = append(collected, item)
	}
	return collected, rows.Err()
}
// scanObservation reads one observation from scanner. The column order must
// match the SELECT list used by the raw FTS query. JSON-encoded columns are
// decoded into their typed fields, and the integer is_superseded column is
// converted to a bool.
func scanObservation(scanner interface{ Scan(...any) error }) (*models.Observation, error) {
	var (
		obs                                                                  models.Observation
		rawFacts, rawConcepts, rawFilesRead, rawFilesModified, rawFileMtimes []byte
		supersededFlag                                                       int
	)

	if err := scanner.Scan(
		&obs.ID, &obs.SDKSessionID, &obs.Project, &obs.Scope, &obs.Type,
		&obs.Title, &obs.Subtitle, &rawFacts, &obs.Narrative, &rawConcepts,
		&rawFilesRead, &rawFilesModified, &rawFileMtimes,
		&obs.PromptNumber, &obs.DiscoveryTokens, &obs.CreatedAt, &obs.CreatedAtEpoch,
		&obs.ImportanceScore, &obs.UserFeedback, &obs.RetrievalCount,
		&obs.LastRetrievedAt, &obs.ScoreUpdatedAt, &supersededFlag,
	); err != nil {
		return nil, err
	}

	// Decode each non-empty JSON column into its destination field. Decode
	// errors are deliberately ignored: the data was written by this package
	// and is expected to be valid.
	jsonColumns := []struct {
		raw []byte
		dst any
	}{
		{rawFacts, &obs.Facts},
		{rawConcepts, &obs.Concepts},
		{rawFilesRead, &obs.FilesRead},
		{rawFilesModified, &obs.FilesModified},
		{rawFileMtimes, &obs.FileMtimes},
	}
	for _, col := range jsonColumns {
		if len(col.raw) > 0 {
			_ = json.Unmarshal(col.raw, col.dst)
		}
	}

	obs.IsSuperseded = supersededFlag != 0
	return &obs, nil
}
// toModelObservation copies a GORM Observation row into a freshly allocated
// pkg/models.Observation. Fields are copied one-to-one, except that the
// integer IsSuperseded column becomes a bool (non-zero means true).
func toModelObservation(o *Observation) *models.Observation {
	m := new(models.Observation)

	// Identity and ownership.
	m.ID = o.ID
	m.SDKSessionID = o.SDKSessionID
	m.Project = o.Project
	m.Scope = o.Scope
	m.Type = o.Type

	// Content.
	m.Title = o.Title
	m.Subtitle = o.Subtitle
	m.Facts = o.Facts
	m.Narrative = o.Narrative
	m.Concepts = o.Concepts
	m.FilesRead = o.FilesRead
	m.FilesModified = o.FilesModified
	m.FileMtimes = o.FileMtimes

	// Provenance and timestamps.
	m.PromptNumber = o.PromptNumber
	m.DiscoveryTokens = o.DiscoveryTokens
	m.CreatedAt = o.CreatedAt
	m.CreatedAtEpoch = o.CreatedAtEpoch

	// Scoring and lifecycle.
	m.ImportanceScore = o.ImportanceScore
	m.UserFeedback = o.UserFeedback
	m.RetrievalCount = o.RetrievalCount
	m.LastRetrievedAt = o.LastRetrievedAt
	m.ScoreUpdatedAt = o.ScoreUpdatedAt
	m.IsSuperseded = o.IsSuperseded != 0

	return m
}
// toModelObservations converts a slice of GORM Observation rows into
// pkg/models.Observation pointers, preserving order. The result is non-nil
// and has the same length as the input.
func toModelObservations(observations []Observation) []*models.Observation {
	converted := make([]*models.Observation, 0, len(observations))
	for i := range observations {
		// Index instead of ranging by value to avoid copying each row.
		converted = append(converted, toModelObservation(&observations[i]))
	}
	return converted
}
// nullInt64 wraps val in a sql.NullInt64. Zero is treated as "no value" and
// produces an invalid (NULL) result; any other value is stored as valid.
func nullInt64(val int) sql.NullInt64 {
	return sql.NullInt64{Int64: int64(val), Valid: val != 0}
}